Skip to content

Commit c894f7f

Browse files
committed
Add test onnx.Cast bf16 to f16
1 parent 46216e0 commit c894f7f

File tree

1 file changed

+25
-0
lines changed
  • alt_e2eshark/onnx_tests/operators

1 file changed

+25
-0
lines changed
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc.
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
from onnx import TensorProto
7+
from onnx.helper import make_node, make_tensor_value_info, make_attribute
8+
9+
from ..helper_classes import BuildAModel
10+
from e2e_testing.registry import register_test
11+
12+
class CastModel(BuildAModel):
13+
def construct_i_o_value_info(self):
14+
self.input_vi = [
15+
make_tensor_value_info("X", TensorProto.BFLOAT16, [1]),
16+
]
17+
self.output_vi = [make_tensor_value_info("Y", TensorProto.FLOAT16, [1])]
18+
19+
def construct_nodes(self):
20+
cast_node = make_node("Cast", ["X"], ["Y"], "castnode")
21+
cast_node.attribute.append(make_attribute("to", 10))
22+
self.node_list.append(cast_node)
23+
24+
register_test(CastModel, "cast_test")
25+

0 commit comments

Comments
 (0)