Skip to content

Commit 64c313f

Browse files
hejiang0116Orbax Authors
authored and committed
Update sidv2, yt_prod_v3 test data and re-enable the test.
PiperOrigin-RevId: 802622705
1 parent ce363b7 commit 64c313f

File tree

8 files changed

+68
-22
lines changed

8 files changed

+68
-22
lines changed

export/orbax/export/protos/oex_orchestration.proto

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,24 @@ message Pipeline {
2222
// nth pre-processor serves as the input to the (n+1)th pre-processor. While
2323
// multi-stage tf preprocessing is doable, we recommend consolidating the
2424
// steps into a single module to optimize for latency.
25-
repeated string pre_processor_names = 20;
25+
repeated string pre_processor_names = 3;
2626

2727
// The model functions to be executed in the pipeline. The order of the
2828
// model functions determines the order of execution. The output of the nth
2929
// model function serves as the input to the (n+1)th model function.
3030
// Currently, we ONLY support AT MOST ONE model function (thus, having no
3131
// model functions is allowed). See b/421976850.
32-
repeated BoundModelFunction model_functions = 30;
32+
repeated BoundModelFunction model_functions = 4;
3333

3434
// The names (in manifest) of the post-processor functions. For multi-stage
3535
// postprocessing, the postprocessors are executed in sequence. The output of
3636
// the nth post-processor serves as the input to the (n+1)th post-processor.
3737
// While multi-stage tf postprocessing is doable, we recommend consolidating
3838
// the steps into a single module to optimize for latency.
39-
repeated string post_processor_names = 50;
39+
repeated string post_processor_names = 5;
4040

4141
// The batch options for the model. If not set, the model will not be batched.
42-
optional BatchOptions batch_options = 60;
42+
optional BatchOptions batch_options = 6;
4343
}
4444

4545
// A bound model function.
@@ -53,7 +53,7 @@ message BoundModelFunction {
5353

5454
// The name (in manifest) of the value representing the weights PyTree.
5555
// This PyTree will be given as the 1st arg to the model function.
56-
optional string weights_name = 10;
56+
optional string weights_name = 2;
5757
}
5858

5959
message BatchOptions {

model/orbax/experimental/model/core/protos/manifest.proto

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ import "orbax/experimental/model/core/protos/type.proto";
77
message Manifest {
88
map<string, TopLevelObject> objects = 1;
99
map<string, UnstructuredData> supplemental_info = 2;
10-
optional DeviceAssignmentByCoords device_assignment_by_coords = 4;
10+
optional DeviceAssignmentByCoords device_assignment_by_coords = 3;
1111

1212
// A list of paths frequently used in this manifest. Once a path is added to
1313
// this list, use sites can refer to it by its index here instead of spelling
1414
// out the full string.
15-
repeated string frequent_paths = 5;
15+
repeated string frequent_paths = 4;
1616
}
1717

1818
// Copied from /third_party/australis/google/ifrt/ifrt_australis.proto.
@@ -32,7 +32,7 @@ message TopLevelObject {
3232
oneof case {
3333
Function function = 1;
3434
Value value = 2;
35-
PolymorphicFunction poly_fn = 30;
35+
PolymorphicFunction poly_fn = 3;
3636
}
3737
}
3838

@@ -41,10 +41,10 @@ message Function {
4141
FunctionBody body = 2;
4242
// The names of the dependent data. Each of these corresponds to a key in the
4343
// checkpoint metadata.
44-
repeated string data_names = 5;
44+
repeated string data_names = 3;
4545

46-
Visibility visibility = 3;
47-
optional string gradient_function_name = 4;
46+
Visibility visibility = 4;
47+
optional string gradient_function_name = 5;
4848
}
4949

5050
message Value {

model/orbax/experimental/model/core/python/manifest_constants.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,18 @@
1414

1515
"""Manifest model format constants."""
1616

17-
MANIFEST_FILENAME = 'manifest.pb'
17+
MANIFEST_VERSION_FILENAME = 'orbax_model_version.txt'
1818

19-
MANIFEST_VERSION_FILENAME = 'manifest_version.txt'
19+
# The file path of the manifest proto file
20+
MANIFEST_FILE_PATH_KEY = 'manifest_file_path'
21+
# TODO(b/439870164): Update the `MANIFEST_FILENAME` to be `MANIFEST_FILE_PATH`
22+
# and treat it as a configurable path
23+
MANIFEST_FILENAME = 'manifest.pb'
2024

25+
# The version of the manifest
26+
VERSION_KEY = 'version'
2127
MANIFEST_VERSION = '0.0.1'
28+
29+
# The mime type of the manifest proto file
30+
MIME_TYPE_KEY = 'mime_type'
31+
MANIFEST_MIME_TYPE = 'application/protobuf; type=orbax_model_manifest.Manifest'

model/orbax/experimental/model/core/python/manifest_util.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from collections.abc import Mapping, Sequence
1919
from absl import logging
2020
from orbax.experimental.model.core.protos import manifest_pb2
21+
from orbax.experimental.model.core.python import manifest_constants
2122
from orbax.experimental.model.core.python import unstructured_data
2223
from orbax.experimental.model.core.python.device_assignment import DeviceAssignment
2324
from orbax.experimental.model.core.python.function import Function
@@ -28,7 +29,6 @@
2829
from orbax.experimental.model.core.python.unstructured_data import UnstructuredData
2930
from orbax.experimental.model.core.python.value import ExternalValue
3031

31-
3232
def _build_function(
3333
fn: Function,
3434
path: str,
@@ -115,6 +115,28 @@ def _is_seq_of_functions(obj: Saveable) -> bool:
115115
)
116116

117117

118+
def build_manifest_version_file() -> str:
119+
"""Builds a manifest version file content."""
120+
121+
# TODO(b/365967674): Remove this check once the manifest filename is
122+
# configurable by the manifest version file. Currently, the manifest filename
123+
# is hardcoded to "manifest.pb" in OBM & JSV codebase and that needs to be
124+
# updated first.
125+
if manifest_constants.MANIFEST_FILENAME != "manifest.pb":
126+
raise ValueError(
127+
"Currently, only manifest.pb is supported as the manifest filename."
128+
)
129+
130+
return (
131+
f"{manifest_constants.MANIFEST_FILE_PATH_KEY}:"
132+
f' "{manifest_constants.MANIFEST_FILENAME}"\n'
133+
f"{manifest_constants.VERSION_KEY}:"
134+
f' "{manifest_constants.MANIFEST_VERSION}"\n'
135+
f"{manifest_constants.MIME_TYPE_KEY}:"
136+
f' "{manifest_constants.MANIFEST_MIME_TYPE}"\n'
137+
)
138+
139+
118140
def build_manifest_proto(
119141
obm_module: dict[str, Saveable],
120142
path: str,

model/orbax/experimental/model/core/python/manifest_util_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,15 @@ def test_build_device_assignment_by_coords_proto(self):
6464
self.assertEqual(device.core_on_chip, 0) # Proto default
6565

6666

67+
def test_build_manifest_version_file_content(self):
68+
content = manifest_util.build_manifest_version_file()
69+
expected_content = (
70+
'manifest_file_path: "manifest.pb"\n'
71+
'version: "0.0.1"\n'
72+
'mime_type: "application/protobuf; type=orbax_model_manifest.Manifest"\n'
73+
)
74+
self.assertEqual(content, expected_content)
75+
76+
6777
if __name__ == '__main__':
6878
absltest.main()

model/orbax/experimental/model/core/python/save_lib.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from orbax.experimental.model.core.python import unstructured_data
3030
from orbax.experimental.model.core.python.device_assignment import DeviceAssignment
3131
from orbax.experimental.model.core.python.manifest_util import build_manifest_proto
32+
from orbax.experimental.model.core.python.manifest_util import build_manifest_version_file
3233
from orbax.experimental.model.core.python.saveable import Saveable
3334
from orbax.experimental.model.core.python.unstructured_data import UnstructuredData
3435

@@ -137,8 +138,11 @@ def save(
137138
os.path.join(path, manifest_constants.MANIFEST_FILENAME), 'wb'
138139
) as f:
139140
f.write(manifest_proto.SerializeToString())
140-
# Write the manifest version.
141+
142+
# Write the main metadata to detect and parse an Orbax Model. The version file
143+
# should be THE LAST file to be written. It is used to validate the export and
144+
# identify an Orbax Model.
141145
with file_utils.open_file(
142146
os.path.join(path, manifest_constants.MANIFEST_VERSION_FILENAME), 'w'
143147
) as f:
144-
f.write(manifest_constants.MANIFEST_VERSION)
148+
f.write(build_manifest_version_file())

model/orbax/experimental/model/test_utils/simple_orchestration.proto

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ package orbax_model_simple_orchestration;
55
// An orchestration pipeline.
66
message Pipeline {
77
// The name (in manifest) of the pre-processor function.
8-
optional string pre_processor_name = 20;
8+
optional string pre_processor_name = 1;
99

1010
// The name (in manifest) of the model function.
11-
optional string model_function_name = 30;
11+
optional string model_function_name = 2;
1212

1313
// The name (in manifest) of the value representing the weights.
14-
optional string weights_name = 40;
14+
optional string weights_name = 3;
1515

1616
// The name (in manifest) of the post-processor function.
17-
optional string post_processor_name = 50;
17+
optional string post_processor_name = 4;
1818
}

model/orbax/experimental/model/tf2obm/tf_concrete_function_handle.proto

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ message TfConcreteFunctionHandle {
1414
// SavedModel) but we want to allow calling the function with positional
1515
// arguments.
1616
//
17-
repeated string input_names = 20;
17+
repeated string input_names = 2;
1818

1919
// Similar to `input_names`, but for outputs.
20-
repeated string output_names = 30;
20+
repeated string output_names = 3;
2121
}

0 commit comments

Comments
 (0)