@@ -56,26 +56,23 @@ def test_output_signature_raises_error_without_calling_prepare(self):
5656 _ = processor .output_signature
5757
5858 def test_prepare_fails_with_multiple_calls (self ):
59- processor = tf_data_processor .TfDataProcessor (lambda x : x )
59+ processor = tf_data_processor .TfDataProcessor (lambda x : x , name = 'add' )
6060 processor .prepare (
61- 'add' ,
62- input_signature = (tf .TensorSpec ([None , 3 ], tf .float32 ),),
61+ (tf .TensorSpec ([None , 3 ], tf .float32 ),),
6362 )
6463 with self .assertRaisesWithLiteralMatch (
6564 RuntimeError , '`prepare()` can only be called once.'
6665 ):
6766 processor .prepare (
68- 'add' ,
69- input_signature = (tf .TensorSpec ([None , 3 ], tf .float32 ),),
67+ (tf .TensorSpec ([None , 3 ], tf .float32 ),),
7068 )
7169
7270 def test_prepare_succeeds (self ):
7371 processor = tf_data_processor .TfDataProcessor (
74- tf .function (lambda x , y : x + y )
72+ tf .function (lambda x , y : x + y ), name = 'add'
7573 )
7674 processor .prepare (
77- 'add' ,
78- input_signature = (
75+ (
7976 tf .TensorSpec ([None , 3 ], tf .float64 ),
8077 tf .TensorSpec ([None , 3 ], tf .float64 ),
8178 ),
@@ -107,10 +104,11 @@ def test_prepare_polymorphic_function_with_default_input_signature(self):
107104 def preprocessor_callable (x , y ):
108105 return x + y
109106
110- processor = tf_data_processor .TfDataProcessor (preprocessor_callable )
107+ processor = tf_data_processor .TfDataProcessor (
108+ preprocessor_callable , name = 'add'
109+ )
111110 processor .prepare (
112- 'add' ,
113- input_signature = (
111+ (
114112 tf .TensorSpec ([None , 3 ], tf .float32 ),
115113 tf .TensorSpec ([None , 3 ], tf .float32 ),
116114 ),
@@ -136,25 +134,27 @@ def test_suppress_x64_output(self):
136134 processor = tf_data_processor .TfDataProcessor (
137135 tf .function (
138136 lambda x , y : tf .cast (x , tf .float64 ) + tf .cast (y , tf .float64 )
139- )
137+ ),
138+ name = 'add_f64' ,
140139 )
141140 input_signature = (
142141 tf .TensorSpec ([None , 3 ], tf .float32 ),
143142 tf .TensorSpec ([None , 3 ], tf .float32 ),
144143 )
145144
146145 # With suppress_x64_output=True, f64 output is suppressed to f32.
147- processor .prepare ('add_f64' , input_signature , suppress_x64_output = True )
146+ processor .prepare (input_signature , suppress_x64_output = True )
148147 self .assertEqual (
149148 processor .output_signature ,
150149 obm .ShloTensorSpec (shape = (None , 3 ), dtype = obm .ShloDType .f32 ),
151150 )
152151
153152 def test_convert_to_bfloat16 (self ):
154- processor = tf_data_processor .TfDataProcessor (lambda x : 0.5 + x )
153+ processor = tf_data_processor .TfDataProcessor (
154+ lambda x : 0.5 + x , name = 'preprocessor'
155+ )
155156 processor .prepare (
156- 'preprocessor' ,
157- input_signature = (tf .TensorSpec ((), tf .float32 )),
157+ (tf .TensorSpec ((), tf .float32 )),
158158 bfloat16_options = converter_options_v2_pb2 .ConverterOptionsV2 (
159159 bfloat16_optimization_options = converter_options_v2_pb2 .BFloat16OptimizationOptions (
160160 scope = converter_options_v2_pb2 .BFloat16OptimizationOptions .ALL ,
@@ -168,15 +168,16 @@ def test_convert_to_bfloat16(self):
168168 )
169169
170170 def test_bfloat16_convert_error (self ):
171- processor = tf_data_processor .TfDataProcessor (lambda x : 0.5 + x )
171+ processor = tf_data_processor .TfDataProcessor (
172+ lambda x : 0.5 + x , name = 'preprocessor'
173+ )
172174 with self .assertRaisesRegex (
173175 google_error .StatusNotOk ,
174176 'Found bfloat16 ops in the model. The model may have been converted'
175177 ' before. It should not be converted again.' ,
176178 ):
177179 processor .prepare (
178- 'preprocessor' ,
179- input_signature = (tf .TensorSpec ((), tf .bfloat16 )),
180+ (tf .TensorSpec ((), tf .bfloat16 )),
180181 bfloat16_options = converter_options_v2_pb2 .ConverterOptionsV2 (
181182 bfloat16_optimization_options = converter_options_v2_pb2 .BFloat16OptimizationOptions (
182183 scope = converter_options_v2_pb2 .BFloat16OptimizationOptions .ALL ,
@@ -185,12 +186,9 @@ def test_bfloat16_convert_error(self):
185186 )
186187
187188 def test_prepare_with_shlo_bf16_inputs (self ):
188- processor = tf_data_processor .TfDataProcessor (lambda x : x )
189+ processor = tf_data_processor .TfDataProcessor (lambda x : x , name = 'identity' )
189190 processor .prepare (
190- 'identity' ,
191- input_signature = (
192- obm .ShloTensorSpec (shape = (1 ,), dtype = obm .ShloDType .bf16 ),
193- ),
191+ (obm .ShloTensorSpec (shape = (1 ,), dtype = obm .ShloDType .bf16 ),),
194192 )
195193 self .assertEqual (
196194 processor .concrete_function .structured_input_signature [0 ][0 ].dtype ,
0 commit comments