@@ -173,48 +173,53 @@ def testFoldBatchNorms(self):
       self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
 
   def testFoldFusedBatchNorms(self):
-    with self.test_session() as sess:
-      inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
-      input_op = constant_op.constant(
-          np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
-      weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
-      weights_op = constant_op.constant(
-          np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
-      conv_op = nn_ops.conv2d(
-          input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op")
-      mean_op = constant_op.constant(
-          np.array([10, 20]), shape=[2], dtype=dtypes.float32)
-      variance_op = constant_op.constant(
-          np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
-      beta_op = constant_op.constant(
-          np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
-      gamma_op = constant_op.constant(
-          np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
-      ops.get_default_graph().graph_def_versions.producer = 9
-      gen_nn_ops._fused_batch_norm(
-          conv_op,
-          gamma_op,
-          beta_op,
-          mean_op,
-          variance_op,
-          0.00001,
-          is_training=False,
-          name="output")
-      original_graph_def = sess.graph_def
-      original_result = sess.run(["output:0"])
-    optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
-        original_graph_def)
-
-    with self.test_session() as sess:
-      _ = importer.import_graph_def(
-          optimized_graph_def, input_map={}, name="optimized")
-      optimized_result = sess.run(["optimized/output:0"])
-
-    self.assertAllClose(
-        original_result, optimized_result, rtol=1e-04, atol=1e-06)
-
-    for node in optimized_graph_def.node:
-      self.assertNotEqual("FusedBatchNorm", node.op)
+    for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
+      with self.test_session(use_gpu=use_gpu) as sess:
+        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+        input_op = constant_op.constant(
+            np.array(inputs),
+            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+            dtype=dtypes.float32)
+        weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+        weights_op = constant_op.constant(
+            np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+        conv_op = nn_ops.conv2d(
+            input_op, weights_op, [1, 1, 1, 1], padding="SAME",
+            data_format=data_format, name="conv_op")
+        mean_op = constant_op.constant(
+            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+        variance_op = constant_op.constant(
+            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+        beta_op = constant_op.constant(
+            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+        gamma_op = constant_op.constant(
+            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+        ops.get_default_graph().graph_def_versions.producer = 9
+        gen_nn_ops._fused_batch_norm(
+            conv_op,
+            gamma_op,
+            beta_op,
+            mean_op,
+            variance_op,
+            0.00001,
+            is_training=False,
+            data_format=data_format,
+            name="output")
+        original_graph_def = sess.graph_def
+        original_result = sess.run(["output:0"])
+      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+          original_graph_def)
+
+      with self.test_session(use_gpu=use_gpu) as sess:
+        _ = importer.import_graph_def(
+            optimized_graph_def, input_map={}, name="optimized")
+        optimized_result = sess.run(["optimized/output:0"])
+
+      self.assertAllClose(
+          original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+      for node in optimized_graph_def.node:
+        self.assertNotEqual("FusedBatchNorm", node.op)
 
   def testFuseResizePadAndConv(self):
     with self.test_session() as sess:
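For reviewers who want the arithmetic behind the rewrite this test exercises: inference-mode batch norm is a per-channel affine transform, so it can be folded into the preceding convolution as a rescale of each output channel of the filter plus a constant bias. The sketch below is a minimal NumPy illustration of that folding, not the actual `optimize_for_inference_lib` implementation; `fold_fused_batch_norm` and its argument names are hypothetical, and it assumes the HWIO filter layout used by the test's NHWC path.

```python
import numpy as np

def fold_fused_batch_norm(weights, gamma, beta, mean, variance, eps=1e-5):
  # Inference-mode batch norm computes:
  #   y = gamma * (x - mean) / sqrt(variance + eps) + beta
  # which is affine per output channel, so it can be merged into the conv
  # that feeds it: rescale each output channel of the filter and add a
  # constant bias after the conv.
  scale = gamma / np.sqrt(variance + eps)  # per-channel multiplier
  folded_weights = weights * scale         # broadcasts over the last (O) axis
  folded_bias = beta - mean * scale        # constant offset added post-conv
  return folded_weights, folded_bias

# The same constants the test feeds to gen_nn_ops._fused_batch_norm.
w = np.array([1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4],
             dtype=np.float32).reshape([1, 2, 2, 2])
folded_w, folded_b = fold_fused_batch_norm(
    w,
    gamma=np.array([1.0, 2.0], dtype=np.float32),
    beta=np.array([0.1, 0.6], dtype=np.float32),
    mean=np.array([10.0, 20.0], dtype=np.float32),
    variance=np.array([0.25, 0.5], dtype=np.float32))
print(folded_w.shape, folded_b)  # (1, 2, 2, 2) filter and one bias per channel
```

The folded graph computes the same values in a different floating-point order, which is presumably why the test compares with `rtol`/`atol` in `assertAllClose` rather than demanding exact equality.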