
Commit 6afe900

yegord authored and rmlarsen committed
optimize_for_inference_lib.fold_batch_norms() preserves data_format (tensorflow#16075)
Fixes tensorflow#15034
1 parent c24e3dd commit 6afe900

2 files changed: +48 -42 lines


tensorflow/python/tools/optimize_for_inference_lib.py

+1
@@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def):
     bias_add_op.op = "BiasAdd"
     bias_add_op.name = node.name
     bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
+    bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
     bias_add_op.input.extend([new_conv_op.name, offset_op.name])
     new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])
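
The one added line copies the convolution's data_format attribute onto the BiasAdd node that replaces the folded batch norm, so an NCHW graph no longer falls back to BiasAdd's default NHWC layout. Below is a minimal sketch (not part of this commit) of how the preserved attribute could be inspected on a folded graph; original_graph_def stands in for a hypothetical frozen GraphDef containing Conv2D followed by FusedBatchNorm:

    # Sketch only: run fold_batch_norms() and print the data_format carried by
    # each resulting BiasAdd node. `original_graph_def` is a hypothetical
    # frozen GraphDef, not something defined in this commit.
    from tensorflow.python.tools import optimize_for_inference_lib

    optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
        original_graph_def)
    for node in optimized_graph_def.node:
      if node.op == "BiasAdd":
        # NodeDef attrs are AttrValue protos; string attrs are stored in .s.
        print(node.name, node.attr["data_format"].s)  # e.g. b"NCHW"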

tensorflow/python/tools/optimize_for_inference_test.py

+47 -42
@@ -173,48 +173,53 @@ def testFoldBatchNorms(self):
       self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
 
   def testFoldFusedBatchNorms(self):
-    with self.test_session() as sess:
-      inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
-      input_op = constant_op.constant(
-          np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
-      weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
-      weights_op = constant_op.constant(
-          np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
-      conv_op = nn_ops.conv2d(
-          input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op")
-      mean_op = constant_op.constant(
-          np.array([10, 20]), shape=[2], dtype=dtypes.float32)
-      variance_op = constant_op.constant(
-          np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
-      beta_op = constant_op.constant(
-          np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
-      gamma_op = constant_op.constant(
-          np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
-      ops.get_default_graph().graph_def_versions.producer = 9
-      gen_nn_ops._fused_batch_norm(
-          conv_op,
-          gamma_op,
-          beta_op,
-          mean_op,
-          variance_op,
-          0.00001,
-          is_training=False,
-          name="output")
-      original_graph_def = sess.graph_def
-      original_result = sess.run(["output:0"])
-    optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
-        original_graph_def)
-
-    with self.test_session() as sess:
-      _ = importer.import_graph_def(
-          optimized_graph_def, input_map={}, name="optimized")
-      optimized_result = sess.run(["optimized/output:0"])
-
-    self.assertAllClose(
-        original_result, optimized_result, rtol=1e-04, atol=1e-06)
-
-    for node in optimized_graph_def.node:
-      self.assertNotEqual("FusedBatchNorm", node.op)
+    for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
+      with self.test_session(use_gpu=use_gpu) as sess:
+        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+        input_op = constant_op.constant(
+            np.array(inputs),
+            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+            dtype=dtypes.float32)
+        weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+        weights_op = constant_op.constant(
+            np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+        conv_op = nn_ops.conv2d(
+            input_op, weights_op, [1, 1, 1, 1], padding="SAME",
+            data_format=data_format, name="conv_op")
+        mean_op = constant_op.constant(
+            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+        variance_op = constant_op.constant(
+            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+        beta_op = constant_op.constant(
+            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+        gamma_op = constant_op.constant(
+            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+        ops.get_default_graph().graph_def_versions.producer = 9
+        gen_nn_ops._fused_batch_norm(
+            conv_op,
+            gamma_op,
+            beta_op,
+            mean_op,
+            variance_op,
+            0.00001,
+            is_training=False,
+            data_format=data_format,
+            name="output")
+        original_graph_def = sess.graph_def
+        original_result = sess.run(["output:0"])
+      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+          original_graph_def)
+
+      with self.test_session(use_gpu=use_gpu) as sess:
+        _ = importer.import_graph_def(
+            optimized_graph_def, input_map={}, name="optimized")
+        optimized_result = sess.run(["optimized/output:0"])
+
+      self.assertAllClose(
+          original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+      for node in optimized_graph_def.node:
+        self.assertNotEqual("FusedBatchNorm", node.op)
 
   def testFuseResizePadAndConv(self):
     with self.test_session() as sess:
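
For the NCHW case the test feeds the same flat input values with shape [1, 2, 1, 6], i.e. the channel dimension sits at axis 1 instead of axis 3, and pairs it with use_gpu=True, which reflects that NCHW convolution kernels in this TensorFlow version are primarily available on GPU. As a small orientation sketch (plain NumPy, illustrative values rather than the test's data) of how the two layouts relate:

    import numpy as np

    # Illustration only: an NHWC tensor becomes NCHW by moving the channel
    # axis from position 3 to position 1.
    nhwc = np.arange(12, dtype=np.float32).reshape([1, 1, 6, 2])  # [N, H, W, C]
    nchw = np.transpose(nhwc, [0, 3, 1, 2])                       # [N, C, H, W]
    print(nhwc.shape, nchw.shape)  # (1, 1, 6, 2) (1, 2, 1, 6)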
