
Commit 7f71b1b

hejiang0116 and Orbax Authors authored and committed

Internal change

PiperOrigin-RevId: 834940834

1 parent d966ddf · commit 7f71b1b

File tree: 9 files changed, +165 −44 lines

export/orbax/export/constants.py
Lines changed: 9 additions & 0 deletions

```diff
@@ -97,9 +97,18 @@ class ExportModelType(enum.Enum):
 # Mesh for the model.
 JAX_MESH = 'jax_mesh'
 
+# TODO: b/459991985 - Remove this flag and use PERSIST_XLA_FLAGS instead.
 # Whether to strip XLA flags from the model.
 STRIP_XLA_FLAGS = 'strip_xla_flags'
 
+# Whether to persist XLA flags in the model.
+PERSIST_XLA_FLAGS = 'persist_xla_flags'
+
+# Whether to enable bf16 optimization for the model.
+# TODO: b/422170690 - (1) Apply this flag to the pre/post processors. (2) Add
+# filter flags once the flag is applied to the pre/post processors.
+ENABLE_BF16_OPTIMIZATION = 'enable_bf16_optimization'
+
 ################################################################################
 # Proto field names
 ################################################################################
```
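The new `ENABLE_BF16_OPTIMIZATION` key is consumed from `jax2obm_kwargs` in `obm_module.py` (below). A minimal sketch of passing it through, mirroring the test added in this commit; `_linear` and the shapes are illustrative stand-ins, not part of this diff:

```python
import jax
import jax.numpy as jnp
from orbax.export import constants
from orbax.export.modules import obm_module


def _linear(params, x):  # illustrative apply_fn, not from this commit
  return x @ params['w'] + params['b']


module = obm_module.ObmModule(
    params={
        'w': jax.ShapeDtypeStruct((2, 2), jnp.float32),
        'b': jax.ShapeDtypeStruct((2,), jnp.float32),
    },
    apply_fn=_linear,
    input_polymorphic_shape={constants.DEFAULT_METHOD_KEY: 'b, ...'},
    jax2obm_kwargs={constants.ENABLE_BF16_OPTIMIZATION: True},
)
# With the flag set, params and apply_fn are converted to bfloat16 via
# utils.to_bfloat16 (see obm_module.py and utils.py below).
```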

export/orbax/export/data_processors/tf_data_processor_test.py
Lines changed: 22 additions & 24 deletions

```diff
@@ -56,26 +56,23 @@ def test_output_signature_raises_error_without_calling_prepare(self):
       _ = processor.output_signature
 
   def test_prepare_fails_with_multiple_calls(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: x)
+    processor = tf_data_processor.TfDataProcessor(lambda x: x, name='add')
     processor.prepare(
-        'add',
-        input_signature=(tf.TensorSpec([None, 3], tf.float32),),
+        (tf.TensorSpec([None, 3], tf.float32),),
     )
     with self.assertRaisesWithLiteralMatch(
         RuntimeError, '`prepare()` can only be called once.'
     ):
       processor.prepare(
-          'add',
-          input_signature=(tf.TensorSpec([None, 3], tf.float32),),
+          (tf.TensorSpec([None, 3], tf.float32),),
       )
 
   def test_prepare_succeeds(self):
     processor = tf_data_processor.TfDataProcessor(
-        tf.function(lambda x, y: x + y)
+        tf.function(lambda x, y: x + y), name='add'
     )
     processor.prepare(
-        'add',
-        input_signature=(
+        (
             tf.TensorSpec([None, 3], tf.float64),
             tf.TensorSpec([None, 3], tf.float64),
         ),
@@ -107,10 +104,11 @@ def test_prepare_polymorphic_function_with_default_input_signature(self):
     def preprocessor_callable(x, y):
       return x + y
 
-    processor = tf_data_processor.TfDataProcessor(preprocessor_callable)
+    processor = tf_data_processor.TfDataProcessor(
+        preprocessor_callable, name='add'
+    )
     processor.prepare(
-        'add',
-        input_signature=(
+        (
             tf.TensorSpec([None, 3], tf.float32),
             tf.TensorSpec([None, 3], tf.float32),
         ),
@@ -136,25 +134,27 @@ def test_suppress_x64_output(self):
     processor = tf_data_processor.TfDataProcessor(
         tf.function(
             lambda x, y: tf.cast(x, tf.float64) + tf.cast(y, tf.float64)
-        )
+        ),
+        name='add_f64',
     )
     input_signature = (
         tf.TensorSpec([None, 3], tf.float32),
         tf.TensorSpec([None, 3], tf.float32),
     )
 
     # With suppress_x64_output=True, f64 output is suppressed to f32.
-    processor.prepare('add_f64', input_signature, suppress_x64_output=True)
+    processor.prepare(input_signature, suppress_x64_output=True)
     self.assertEqual(
         processor.output_signature,
         obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f32),
     )
 
   def test_convert_to_bfloat16(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: 0.5 + x)
+    processor = tf_data_processor.TfDataProcessor(
+        lambda x: 0.5 + x, name='preprocessor'
+    )
     processor.prepare(
-        'preprocessor',
-        input_signature=(tf.TensorSpec((), tf.float32)),
+        (tf.TensorSpec((), tf.float32)),
         bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2(
             bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
                 scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
@@ -168,15 +168,16 @@ def test_convert_to_bfloat16(self):
     )
 
   def test_bfloat16_convert_error(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: 0.5 + x)
+    processor = tf_data_processor.TfDataProcessor(
+        lambda x: 0.5 + x, name='preprocessor'
+    )
     with self.assertRaisesRegex(
         google_error.StatusNotOk,
         'Found bfloat16 ops in the model. The model may have been converted'
         ' before. It should not be converted again.',
     ):
       processor.prepare(
-          'preprocessor',
-          input_signature=(tf.TensorSpec((), tf.bfloat16)),
+          (tf.TensorSpec((), tf.bfloat16)),
           bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2(
               bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
                   scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
@@ -185,12 +186,9 @@ def test_bfloat16_convert_error(self):
       )
 
   def test_prepare_with_shlo_bf16_inputs(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: x)
+    processor = tf_data_processor.TfDataProcessor(lambda x: x, name='identity')
     processor.prepare(
-        'identity',
-        input_signature=(
-            obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16),
-        ),
+        (obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16),),
     )
     self.assertEqual(
         processor.concrete_function.structured_input_signature[0][0].dtype,
```
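These tests encode an API migration for `TfDataProcessor`: the processor name moves from `prepare()` to the constructor, and `prepare()` now takes the input signature as its first positional argument. A hedged before/after sketch of the call pattern:

```python
import tensorflow as tf
from orbax.export.data_processors import tf_data_processor

# Old call pattern (removed by this commit):
#   processor = tf_data_processor.TfDataProcessor(lambda x: x)
#   processor.prepare(
#       'add', input_signature=(tf.TensorSpec([None, 3], tf.float32),)
#   )

# New call pattern:
processor = tf_data_processor.TfDataProcessor(lambda x: x, name='add')
processor.prepare((tf.TensorSpec([None, 3], tf.float32),))
```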

export/orbax/export/jax_module.py
Lines changed: 10 additions & 0 deletions

```diff
@@ -197,6 +197,16 @@ def jax2tf_kwargs_map(self) -> Mapping[str, Any]:
         tensorflow_module.TensorFlowModule, self._export_module
     ).jax2tf_kwargs_map
 
+  @property
+  def jax2obm_kwargs(self) -> Mapping[str, Any]:
+    """Returns the jax2obm_kwargs."""
+    if self._export_version == constants.ExportModelType.TF_SAVEDMODEL:
+      raise TypeError(
+          'jax2obm_kwargs is not implemented for export version'
+          ' ExportModelType.TF_SAVEDMODEL.'
+      )
+    return cast(obm_module.ObmModule, self._export_module).jax2obm_kwargs
+
   @property
   def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]:
     """Returns the polymorphic shapes."""
```

export/orbax/export/modules/obm_module.py
Lines changed: 28 additions & 20 deletions

```diff
@@ -73,34 +73,43 @@ def __init__(
     )
 
     # It is possible for jax2obm_kwargs to be None if the key is present.
-    if not jax2obm_kwargs:
-      jax2obm_kwargs = {}
 
+    self._jax2obm_kwargs = jax2obm_kwargs if jax2obm_kwargs else {}
+
+    enable_bf16_optimization = self.jax2obm_kwargs.get(
+        constants.ENABLE_BF16_OPTIMIZATION, False
+    )
+
+    if enable_bf16_optimization:
+      mapped_apply_fn = utils.to_bfloat16(apply_fn)
+      self._params_args_spec = utils.to_bfloat16(params)
+    else:
+      mapped_apply_fn = apply_fn
+      self._params_args_spec = params
     (
         self._apply_fn_map,
         self.input_polymorphic_shape_map,
         self.input_polymorphic_shape_symbol_values_map,
     ) = self._normalize_apply_fn_map(
-        apply_fn,
+        mapped_apply_fn,
         input_polymorphic_shape,
         input_polymorphic_shape_symbol_values,
     )
 
-    self._jax_mesh = jax2obm_kwargs.get(constants.JAX_MESH, None)
-    self._strip_xla_flags = jax2obm_kwargs.get(constants.STRIP_XLA_FLAGS, False)
-
-    self.polymorphic_constraints = self._maybe_set_polymorphic_constraints(
-        jax2obm_kwargs
+    self._jax_mesh = self.jax2obm_kwargs.get(constants.JAX_MESH, None)
+    self._strip_xla_flags = self.jax2obm_kwargs.get(
+        constants.STRIP_XLA_FLAGS, False
     )
+
+    self.polymorphic_constraints = self._maybe_set_polymorphic_constraints()
     self._native_serialization_platforms = utils.get_lowering_platforms(
-        jax2obm_kwargs
+        self.jax2obm_kwargs
    )
-    self._params_args_spec = params
 
     self._checkpoint_path: str = None
     # Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
-    self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)
-    self._load_all_checkpoint_weights = jax2obm_kwargs.get(
+    self._maybe_set_orbax_checkpoint_path(self.jax2obm_kwargs)
+    self._load_all_checkpoint_weights = self.jax2obm_kwargs.get(
         constants.LOAD_ALL_CHECKPOINT_WEIGHTS, False
     )
@@ -203,15 +212,9 @@ def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
         else constants.DEFAULT_WEIGHTS_NAME
     )
 
-  def _maybe_set_polymorphic_constraints(
-      self, jax2obm_kwargs
-  ) -> Mapping[str, Sequence[Any]]:
+  def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[Any]]:
     """Sets the polymorphic constraints for the model.
 
-    Args:
-      jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion
-        library.
-
     Returns:
       A mapping of function name to polymorphic constraints.
@@ -221,7 +224,7 @@ def _maybe_set_polymorphic_constraints(
       size of the apply_fn_map or if a key in apply_fn_map is not found in
       polymorphic_constraints.
     """
-    polymorphic_constraints = jax2obm_kwargs.get(
+    polymorphic_constraints = self.jax2obm_kwargs.get(
         constants.POLYMORPHIC_CONSTRAINTS, None
     )
     if not isinstance(polymorphic_constraints, Mapping):
@@ -300,3 +303,8 @@ def methods(self) -> Mapping[str, Callable[..., Any]]:
   def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
     """Named methods in JAX context for validation."""
     raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')
+
+  @property
+  def jax2obm_kwargs(self) -> Mapping[str, Any]:
+    """Returns the jax2obm_kwargs."""
+    return self._jax2obm_kwargs
```
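Because `jax2obm_kwargs` is normalized to `{}` when `None` and then retained, the new read-only property makes the flags queryable after construction. A minimal sketch, reusing the illustrative `_linear` and imports from the constants.py example above:

```python
module = obm_module.ObmModule(
    params={
        'w': jax.ShapeDtypeStruct((2, 2), jnp.float32),
        'b': jax.ShapeDtypeStruct((2,), jnp.float32),
    },
    apply_fn=_linear,
    input_polymorphic_shape={constants.DEFAULT_METHOD_KEY: 'b, ...'},
    jax2obm_kwargs=None,  # normalized to {} by __init__
)
assert module.jax2obm_kwargs == {}
assert not module.jax2obm_kwargs.get(constants.ENABLE_BF16_OPTIMIZATION, False)
```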

export/orbax/export/modules/obm_module_test.py
Lines changed: 26 additions & 0 deletions

```diff
@@ -357,6 +357,32 @@ def test_obm_module_multiple_apply_fns(
         jax2obm_kwargs=jax2obm_kwargs,
     )
 
+  @parameterized.named_parameters(
+      {'testcase_name': 'enable_bf16', 'enable_bf16_optimization': True},
+      {'testcase_name': 'disable_bf16', 'enable_bf16_optimization': False},
+  )
+  def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
+    params_spec = {
+        'w': jax.ShapeDtypeStruct((2, 2), jnp.float32),
+        'b': jax.ShapeDtypeStruct((2,), jnp.float32),
+    }
+    input_spec = {constants.DEFAULT_METHOD_KEY: 'b, ...'}
+
+    module = obm_module.ObmModule(
+        params=params_spec,
+        apply_fn=_linear,
+        input_polymorphic_shape=input_spec,
+        jax2obm_kwargs={
+            constants.ENABLE_BF16_OPTIMIZATION: enable_bf16_optimization
+        },
+    )
+
+    expected_dtype = jnp.bfloat16 if enable_bf16_optimization else jnp.float32
+    with self.subTest('test_weights_w_dtype'):
+      self.assertEqual(module.model_params['w'].dtype, expected_dtype)
+    with self.subTest('test_weights_b_dtype'):
+      self.assertEqual(module.model_params['b'].dtype, expected_dtype)
+
 
 if __name__ == '__main__':
   absltest.main()
```

export/orbax/export/oex_orchestration.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -14,7 +14,11 @@
 
 """Pipeline: pre-processor + model-function + post-processor."""
 
+import dataclasses
 from typing import Any, Dict, List, Sequence, Tuple, TypeVar
 
 from absl import logging
 import jax
+import jaxtyping
+from orbax.export.data_processors import data_processor_base
+from orbax.export.modules import obm_module
```

export/orbax/export/serving_config.py
Lines changed: 27 additions & 0 deletions

```diff
@@ -21,6 +21,7 @@
 import jax
 import jaxtyping
 from orbax.export.data_processors import data_processor_base
+from orbax.export.data_processors import tf_data_processor
 import tensorflow as tf
 
 
@@ -108,6 +109,32 @@ def get_signature_keys(self) -> Sequence[str]:
     else:
       return self.signature_key
 
+  def get_preprocessors(self) -> Sequence[data_processor_base.DataProcessor]:
+    """Returns the preprocessors for this serving config."""
+    if self.preprocessors:
+      return self.preprocessors
+    elif self.tf_preprocessor:
+      return [
+          tf_data_processor.TfDataProcessor(
+              self.tf_preprocessor,
+          )
+      ]
+    else:
+      return []
+
+  def get_postprocessors(self) -> Sequence[data_processor_base.DataProcessor]:
+    """Returns the postprocessors for this serving config."""
+    if self.postprocessors:
+      return self.postprocessors
+    elif self.tf_postprocessor:
+      return [
+          tf_data_processor.TfDataProcessor(
+              self.tf_postprocessor,
+          )
+      ]
+    else:
+      return []
+
   def get_input_signature(self, required=True) -> Any:
     """Gets the input signature from the explicit one or tf_preprocessor."""
     input_signature = self.input_signature
```
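A hedged sketch of the new accessors: with only a legacy `tf_preprocessor` configured, `get_preprocessors()` wraps it in a `TfDataProcessor`; explicit `preprocessors` are returned as-is, and an empty list otherwise. Field names follow the `ServingConfig` dataclass in this file; the spec values are illustrative:

```python
import tensorflow as tf
from orbax.export import serving_config as osc

sc = osc.ServingConfig(
    signature_key='serving_default',
    input_signature=[tf.TensorSpec([None, 3], tf.float32)],
    tf_preprocessor=lambda x: x * 2.0,  # legacy-style preprocessor
)
preprocessors = sc.get_preprocessors()  # [TfDataProcessor(tf_preprocessor)]
assert sc.get_postprocessors() == []    # nothing configured on the output side
```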

export/orbax/export/utils.py
Lines changed: 38 additions & 0 deletions

```diff
@@ -18,6 +18,7 @@
 import dataclasses
 import functools
 import inspect
+import jax.numpy as jnp
 import os
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -532,3 +533,40 @@ def get_lowering_platforms(
     )
 
   return native_serialization_platforms
+
+
+def to_bfloat16(x: Any) -> Any:
+  """Helper to convert leaves of a pytree to bfloat16.
+
+  It handles `float`, `jax.ShapeDtypeStruct`, and other array-like objects with
+  a floating point `dtype`.
+
+  Args:
+    x: The input pytree to convert.
+
+  Returns:
+    The input `x` with floating point values converted to `jnp.bfloat16`.
+  """
+
+  def _to_bfloat16_leaf(x: Any) -> Any:
+    if isinstance(x, jax.ShapeDtypeStruct) and jnp.issubdtype(
+        x.dtype, jnp.floating
+    ):
+      return jax.ShapeDtypeStruct(
+          x.shape,
+          jnp.bfloat16,
+          sharding=x.sharding,
+      )
+    if isinstance(x, jax.ShapeDtypeStruct):
+      return x
+    if hasattr(x, 'dtype') and jnp.issubdtype(x.dtype, jnp.floating):
+      return x.astype(jnp.bfloat16)
+    if isinstance(x, float):
+      return jnp.bfloat16(x)
+    return x
+
+  flattened_x, treedef = jax.tree_util.tree_flatten(x)
+  flattened_y = [
+      jax.tree_util.tree_map(_to_bfloat16_leaf, y) for y in flattened_x
+  ]
+  return jax.tree_util.tree_unflatten(treedef, flattened_y)
```
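A hedged usage sketch of the new helper on a mixed pytree, assuming `orbax.export.utils` is importable as below: floating-point leaves (arrays, `jax.ShapeDtypeStruct`s, plain Python floats) become bfloat16, while non-float leaves pass through unchanged:

```python
import jax
import jax.numpy as jnp
from orbax.export import utils

tree = {
    'w': jnp.ones((2, 2), jnp.float32),               # array -> bf16
    'spec': jax.ShapeDtypeStruct((4,), jnp.float32),  # spec -> bf16 spec
    'step': jnp.array(3, jnp.int32),                  # non-float: unchanged
    'lr': 1e-3,                                       # float -> jnp.bfloat16
}
out = utils.to_bfloat16(tree)
assert out['w'].dtype == jnp.bfloat16
assert out['spec'].dtype == jnp.bfloat16
assert out['step'].dtype == jnp.int32
```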

model/orbax/experimental/model/cli/README.md
Lines changed: 1 addition & 0 deletions

```diff
@@ -2,6 +2,7 @@
 
 A command-line tool for inspecting Orbax models.
 
+
 ## Examples
 
 To inspect the model:
```
