Skip to content

[OpenVINO backend] support export model from the supported backends to openvino format #21486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 24, 2025
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,8 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
dtype = standardize_dtype(type(x))
ov_type = OPENVINO_DTYPES[dtype]
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)
elif isinstance(x, ov.Output):
return OpenVINOKerasTensor(x)
if isinstance(x, Variable):
x = x.value
if dtype and dtype != x.dtype:
Expand Down
1 change: 1 addition & 0 deletions keras/src/export/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from keras.src.export.onnx import export_onnx
from keras.src.export.openvino import export_openvino
from keras.src.export.saved_model import ExportArchive
from keras.src.export.saved_model import export_saved_model
from keras.src.export.tfsm_layer import TFSMLayer
178 changes: 178 additions & 0 deletions keras/src/export/openvino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import warnings

from keras.src import backend
from keras.src import tree
from keras.src.export.export_utils import convert_spec_to_tensor
from keras.src.export.export_utils import get_input_signature
from keras.src.export.export_utils import make_tf_tensor_spec
from keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME
from keras.src.export.saved_model import ExportArchive
from keras.src.utils import io_utils


def export_openvino(
    model, filepath, verbose=None, input_signature=None, **kwargs
):
    """Export the model as an OpenVINO IR artifact for inference.

    This method exports the model to the OpenVINO IR format,
    which includes two files:
    a `.xml` file containing the model structure and a `.bin` file
    containing the weights.
    The exported model contains only the forward pass
    (i.e., the model's `call()` method), and can be deployed with the
    OpenVINO Runtime for fast inference on CPU and other Intel hardware.

    Args:
        filepath: `str` or `pathlib.Path`. Path to the output `.xml` file.
            The corresponding `.bin` file will be saved alongside it.
        verbose: Optional `bool`. Whether to print a confirmation message
            after export. If `None`, it defaults to `True`.
        input_signature: Optional. Specifies the shape and dtype of the
            model inputs. If not provided, it will be inferred.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If `filepath` does not end with `.xml`.
        NotImplementedError: If the current backend is not one of
            OpenVINO, TensorFlow, JAX or Torch.

    Example:

    ```python
    import keras

    # Define or load a Keras model
    model = keras.models.Sequential([
        keras.layers.Input(shape=(128,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(10)
    ])

    # Export to OpenVINO IR
    model.export("model.xml", format="openvino")
    ```
    """
    # `filepath` may be a `pathlib.Path`, which has no `endswith` method;
    # normalize to `str` before validating. Use an explicit exception
    # rather than `assert`, which is stripped under `python -O`.
    filepath = str(filepath)
    if not filepath.endswith(".xml"):
        raise ValueError(
            "The OpenVINO export requires the filepath to end with '.xml'. "
            f"Got: {filepath}"
        )

    import openvino as ov
    from openvino.runtime import opset14 as ov_opset

    from keras.src.backend.openvino.core import OPENVINO_DTYPES
    from keras.src.backend.openvino.core import OpenVINOKerasTensor

    actual_verbose = verbose if verbose is not None else True

    if input_signature is None:
        input_signature = get_input_signature(model)

    if backend.backend() == "openvino":
        import inspect

        def parameterize_inputs(inputs, prefix=""):
            # Recursively replace each tensor in the input structure with
            # an OpenVINO `Parameter` node, preserving nesting and naming
            # each parameter after its position/key in the structure.
            if isinstance(inputs, (list, tuple)):
                return [
                    parameterize_inputs(e, f"{prefix}{i}")
                    for i, e in enumerate(inputs)
                ]
            elif isinstance(inputs, dict):
                return {k: parameterize_inputs(v, k) for k, v in inputs.items()}
            elif isinstance(inputs, OpenVINOKerasTensor):
                ov_type = OPENVINO_DTYPES[str(inputs.dtype)]
                ov_shape = list(inputs.shape)
                param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)
                param.set_friendly_name(prefix)
                return OpenVINOKerasTensor(param.output(0))
            else:
                raise TypeError(f"Unknown input type: {type(inputs)}")

        if isinstance(input_signature, list) and len(input_signature) == 1:
            input_signature = input_signature[0]

        # Build concrete sample tensors from the specs (dynamic dims
        # replaced by 1) so the model can be traced by calling it.
        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        params = parameterize_inputs(sample_inputs)
        signature = inspect.signature(model.call)
        if len(signature.parameters) > 1 and isinstance(params, (list, tuple)):
            outputs = model(*params)
        else:
            outputs = model(params)
        parameters = [p.output.get_node() for p in tree.flatten(params)]
        results = [ov_opset.result(r.output) for r in tree.flatten(outputs)]
        ov_model = ov.Model(results=results, parameters=parameters)
        flat_specs = tree.flatten(input_signature)
        for ov_input, spec in zip(ov_model.inputs, flat_specs):
            # Respect the dynamic axes from the original input signature.
            dynamic_shape_dims = [
                -1 if dim is None else dim for dim in spec.shape
            ]
            dynamic_shape = ov.PartialShape(dynamic_shape_dims)
            ov_input.get_node().set_partial_shape(dynamic_shape)

    elif backend.backend() in ("tensorflow", "jax"):
        inputs = tree.map_structure(make_tf_tensor_spec, input_signature)
        decorated_fn = get_concrete_fn(model, inputs, **kwargs)
        ov_model = ov.convert_model(decorated_fn)
    elif backend.backend() == "torch":
        import torch

        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        sample_inputs = tuple(sample_inputs)
        # Switch to inference mode so layers such as dropout behave
        # deterministically during tracing.
        if hasattr(model, "eval"):
            model.eval()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            traced = torch.jit.trace(model, sample_inputs)
        ov_model = ov.convert_model(traced)
    else:
        raise NotImplementedError(
            "`export_openvino` is only compatible with OpenVINO, "
            "TensorFlow, JAX and Torch backends."
        )

    ov.serialize(ov_model, filepath)

    if actual_verbose:
        io_utils.print_msg(f"Saved OpenVINO IR at '{filepath}'.")


def _check_jax_kwargs(kwargs):
kwargs = kwargs.copy()
if "is_static" not in kwargs:
kwargs["is_static"] = True
if "jax2tf_kwargs" not in kwargs:
kwargs["jax2tf_kwargs"] = {
"enable_xla": False,
"native_serialization": False,
}
if kwargs["is_static"] is not True:
raise ValueError(
"`is_static` must be `True` in `kwargs` when using the jax backend."
)
if kwargs["jax2tf_kwargs"]["enable_xla"] is not False:
raise ValueError(
"`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` "
"when using the jax backend."
)
if kwargs["jax2tf_kwargs"]["native_serialization"] is not False:
raise ValueError(
"`native_serialization` must be `False` in "
"`kwargs['jax2tf_kwargs']` when using the jax backend."
)
return kwargs


def get_concrete_fn(model, input_signature, **kwargs):
    """Build a concrete, traceable function for the model's forward pass.

    Wraps `model` in an `ExportArchive` endpoint so the TensorFlow and
    JAX backends can hand OpenVINO a traced callable to convert.
    """
    current_backend = backend.backend()
    if current_backend == "jax":
        kwargs = _check_jax_kwargs(kwargs)
    archive = ExportArchive()
    archive.track_and_add_endpoint(
        DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs
    )
    if current_backend == "tensorflow":
        # TF resources (variables, tables) must be tracked explicitly so
        # they are captured in the concrete function.
        archive._filter_and_track_resources()
    return archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME)
Loading