Skip to content

Commit d966ddf

Browse files
olegshaldybinOrbax Authors
authored andcommitted
obm: Match TF input/output signature argument names.
During serving the argument names in TF SavedModel SignatureDef must match the argument names saved in TfConcreteFunctionHandle. Previously this was achieved by renaming the flattened input/argument names to `input_{i}` and `output_{i}`. This change aims to use the same argument names as TF SavedModel, because some tooling may expect the names in the signature definition to match the original argument names. The input names are generally derived from the original argument function argument names and/or tensor names in the input signature, potentially mangled internally by TF. We source them from the saved function definition. The output names match TF's naming scheme: - `output_{i}` for lists, tuples and scalar return values; - `{key}` for dictionaries (including registered dataclasses that can be flattened to a dictionary). In addition we support nested data structures as TF function outputs. Normally these are not supported by TF but since we flatten the function outputs we have to support nested output shapes (for backward compatibility). Nested outputs naming follows a simple dot-separated key path derived from JAX tree flattening on the original output signature. PiperOrigin-RevId: 834443321
1 parent 83207b6 commit d966ddf

File tree

2 files changed

+244
-168
lines changed

2 files changed

+244
-168
lines changed

model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737

3838
TF_SAVED_MODEL_SUPPLEMENTAL_NAME = 'tensorflow_saved_model'
3939

40-
_INPUT_NAME_PREFIX = 'input'
4140
_OUTPUT_NAME_PREFIX = 'output'
4241

4342

@@ -53,9 +52,6 @@ def is_args_kwargs_pattern(tree: utils.TfSignature) -> bool:
5352
)
5453

5554

56-
_NamesAndSequence = Tuple[Sequence[str], Sequence[Any]]
57-
58-
5955
def tf_concrete_function_name_to_obm_function(
6056
name: str,
6157
*,
@@ -95,10 +91,8 @@ def tf_concrete_function_name_to_obm_function(
9591
input_signature = utils.get_input_signature(fn)
9692
output_signature = utils.get_output_signature(fn)
9793

98-
input_names, _, _ = _get_flat_signature(input_signature, _INPUT_NAME_PREFIX)
99-
output_names, _, _ = _get_flat_signature(
100-
output_signature, _OUTPUT_NAME_PREFIX
101-
)
94+
input_names, _, _ = _flat_input_signature(fn)
95+
output_names = _output_names(fn)
10296
unstructured_data = obm.manifest_pb2.UnstructuredData(
10397
inlined_bytes=tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(
10498
fn_name=name,
@@ -260,33 +254,62 @@ class SignatureFlat(NamedTuple):
260254
tree_def: jax_tree_util.PyTreeDef
261255

262256

263-
# We choose to rely solely on a concrete function's TF signature to
264-
# determine its argument names, not using any other information (such
265-
# as the argument names in the original Python `def`, or the `name`
266-
# field in `TensorSpec`). Currently in TF SavedModel, if a concrete
267-
# function's TF signature is a list, SavedModel may use the argument
268-
# names in the original Python `def` to generate a keyword-based
269-
# version of this function (which is needed for Servomatic which only
270-
# supports keyword-based calling conventions). We think relying on
271-
# this SavedModel behavior is a mistake and the user should make the
272-
# TF signature a dict instead if they want to serve the function on
273-
# Servomatic. If we find that there are too many users relying on this
274-
# SavedModel behavior, we can revisit the decision here.
275-
def _get_flat_signature(
276-
signature: utils.TfSignature, name_prefix: str
257+
def _flat_input_signature(
258+
fn: tf.types.experimental.ConcreteFunction,
277259
) -> SignatureFlat:
278-
"""Gets the flattened signature.
279-
280-
Args:
281-
signature: The TF signature.
282-
name_prefix: The prefix for generating names.
283-
284-
Returns:
285-
A SignatureFlat object `(names, leaves, treedef)`.
286-
"""
287-
leaves, tree_def = jax_tree_util.tree_flatten(signature)
288-
names = tuple(f'{name_prefix}_{i}' for i in range(len(leaves)))
289-
return SignatureFlat(names, leaves, tree_def)
260+
"""Returns the flattened input signature of the given function."""
261+
leaves, tree_def = jax_tree_util.tree_flatten(utils.get_input_signature(fn))
262+
# The argument names in SavedModel's SignatureDef may not match the names in
263+
# the input signature due to internal name mangling, hence we're looking
264+
# it up in the FunctionDef.
265+
input_names = [arg.name for arg in fn.function_def.signature.input_arg]
266+
if len(input_names) < len(leaves):
267+
# There could be more arguments in the FunctionDef than in the input
268+
# signature, because it also contains the captured inputs appended
269+
# to the flattened list of the input arguments.
270+
raise ValueError(
271+
f'The number of input arguments in FunctionDef ({len(input_names)}) is'
272+
' smaller than the number of leaves in the flattened input signature'
273+
f' ({len(leaves)})'
274+
)
275+
return SignatureFlat(input_names[: len(leaves)], leaves, tree_def)
276+
277+
278+
def _output_name_for_key(key: Any) -> str:
279+
if isinstance(key, jax_tree_util.SequenceKey):
280+
return f'{_OUTPUT_NAME_PREFIX}_{key.idx}'
281+
elif isinstance(key, jax_tree_util.DictKey):
282+
# The order is stable as guaranteed by `jax.tree.flatten`.
283+
return f'{key.key}'
284+
elif isinstance(key, jax_tree_util.GetAttrKey):
285+
return f'{key.name}'
286+
raise ValueError(f'Invalid output key type: {key}')
287+
288+
289+
def _output_name(path: Sequence[Any]) -> str:
290+
"""Returns the output name based on its path in the output signature."""
291+
if not path:
292+
# Scalar return value (single tensor).
293+
return f'{_OUTPUT_NAME_PREFIX}_0'
294+
295+
# Multiple levels of nesting is normally not suppported for
296+
# TF concrete function outputs. However, we already
297+
# support the case of nested sturctures in Orbax TF export,
298+
# so we will explicitly support nested structures here.
299+
return '.'.join(_output_name_for_key(k) for k in path)
300+
301+
302+
def _output_names(
303+
fn: tf.types.experimental.ConcreteFunction,
304+
) -> Sequence[str]:
305+
"""Returns the flattened output signature of the given function."""
306+
leaves_with_path = jax_tree_util.tree_leaves_with_path(
307+
utils.get_output_signature(fn)
308+
)
309+
if not leaves_with_path:
310+
return []
311+
paths, _ = zip(*leaves_with_path)
312+
return [_output_name(path) for path in paths]
290313

291314

292315
def to_keyword_only_fn(
@@ -300,12 +323,8 @@ def to_keyword_only_fn(
300323
Returns:
301324
The wrapped function (also a TF concrete function).
302325
"""
303-
input_names, input_leaves, input_def = _get_flat_signature(
304-
utils.get_input_signature(f), _INPUT_NAME_PREFIX
305-
)
306-
output_names, _, _ = _get_flat_signature(
307-
utils.get_output_signature(f), _OUTPUT_NAME_PREFIX
308-
)
326+
input_names, input_leaves, input_def = _flat_input_signature(f)
327+
output_names = _output_names(f)
309328

310329
if input_names is None and output_names is None:
311330
return f

0 commit comments

Comments
 (0)