Skip to content

Commit ebba7d6

Browse files
update get_ov_output
1 parent c278dbf commit ebba7d6

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

keras/src/backend/openvino/core.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,25 +107,27 @@ def get_ov_output(x, ov_type=None):
107107
x = ov_opset.constant(x, ov_type).output(0)
108108
elif isinstance(x, (list, tuple)):
109109
if len(x) == 0:
110-
raise ValueError(f"Input {type(x)} must not be empty.")
111-
constants = []
112-
for i, element in enumerate(x):
113-
elem_output = get_ov_output(element, ov_type)
114-
elem_shape = elem_output.get_partial_shape()
110+
raise ValueError(f"Input {type(x).__name__} must not be empty.")
111+
constants = [get_ov_output(element, ov_type) for element in x]
112+
for i, c in enumerate(constants):
113+
elem_shape = c.get_partial_shape()
115114
if elem_shape.rank.get_length() == 0:
116-
elem_output = ov_opset.unsqueeze(elem_output, 0).output(0)
117-
elem_shape = elem_output.get_partial_shape()
118-
if i > 0:
119-
first_shape = constants[0].get_partial_shape()
115+
constants[i] = ov_opset.unsqueeze(c, 0).output(0)
116+
if len(constants) > 1:
117+
first_shape = constants[0].get_partial_shape()
118+
for i, c in enumerate(constants[1:], 1):
119+
elem_shape = c.get_partial_shape()
120120
if not first_shape.compatible(elem_shape):
121121
raise ValueError(
122-
"Shapes of elements must match."
122+
"Shapes of elements must match.\n"
123123
f"Got {first_shape} e(0) vs {elem_shape} e({i})."
124124
)
125-
constants[0], elem_output = align_operand_types(
126-
constants[0], elem_output, "list/tuple concatenation"
127-
)
128-
constants.append(elem_output)
125+
keras_types = [ov_to_keras_type(c.element_type) for c in constants]
126+
result_keras_type = dtypes.result_type(*keras_types)
127+
result_ov_type = OPENVINO_DTYPES[result_keras_type]
128+
for i, c in enumerate(constants):
129+
if c.element_type != result_ov_type:
130+
constants[i] = ov_opset.convert(c, result_ov_type).output(0)
129131
x = ov_opset.concat(constants, axis=0).output(0)
130132
elif isinstance(x, np.ndarray):
131133
if x.dtype == np.dtype("bfloat16"):

0 commit comments

Comments
 (0)