@@ -107,25 +107,27 @@ def get_ov_output(x, ov_type=None):
107
107
x = ov_opset .constant (x , ov_type ).output (0 )
108
108
elif isinstance (x , (list , tuple )):
109
109
if len (x ) == 0 :
110
- raise ValueError (f"Input { type (x )} must not be empty." )
111
- constants = []
112
- for i , element in enumerate (x ):
113
- elem_output = get_ov_output (element , ov_type )
114
- elem_shape = elem_output .get_partial_shape ()
110
+ raise ValueError (f"Input { type (x ).__name__ } must not be empty." )
111
+ constants = [get_ov_output (element , ov_type ) for element in x ]
112
+ for i , c in enumerate (constants ):
113
+ elem_shape = c .get_partial_shape ()
115
114
if elem_shape .rank .get_length () == 0 :
116
- elem_output = ov_opset .unsqueeze (elem_output , 0 ).output (0 )
117
- elem_shape = elem_output .get_partial_shape ()
118
- if i > 0 :
119
- first_shape = constants [0 ].get_partial_shape ()
115
+ constants [i ] = ov_opset .unsqueeze (c , 0 ).output (0 )
116
+ if len (constants ) > 1 :
117
+ first_shape = constants [0 ].get_partial_shape ()
118
+ for i , c in enumerate (constants [1 :], 1 ):
119
+ elem_shape = c .get_partial_shape ()
120
120
if not first_shape .compatible (elem_shape ):
121
121
raise ValueError (
122
- "Shapes of elements must match."
122
+ "Shapes of elements must match.\n "
123
123
f"Got { first_shape } e(0) vs { elem_shape } e({ i } )."
124
124
)
125
- constants [0 ], elem_output = align_operand_types (
126
- constants [0 ], elem_output , "list/tuple concatenation"
127
- )
128
- constants .append (elem_output )
125
+ keras_types = [ov_to_keras_type (c .element_type ) for c in constants ]
126
+ result_keras_type = dtypes .result_type (* keras_types )
127
+ result_ov_type = OPENVINO_DTYPES [result_keras_type ]
128
+ for i , c in enumerate (constants ):
129
+ if c .element_type != result_ov_type :
130
+ constants [i ] = ov_opset .convert (c , result_ov_type ).output (0 )
129
131
x = ov_opset .concat (constants , axis = 0 ).output (0 )
130
132
elif isinstance (x , np .ndarray ):
131
133
if x .dtype == np .dtype ("bfloat16" ):
0 commit comments