
Commit 3e9e8f6

Small fixes to torch.compile tutorial (pytorch#2601)
* small fixes to torch.compile tutorial
* Update intermediate_source/torch_compile_tutorial.py
1 parent 5003542 commit 3e9e8f6

File tree

1 file changed: +9 -8 lines changed


intermediate_source/torch_compile_tutorial.py

@@ -138,21 +138,18 @@ def init_model():
 # Note that in the call to ``torch.compile``, we have have the additional
 # ``mode`` argument, which we will discuss below.
 
-def evaluate(mod, inp):
-    return mod(inp)
-
 model = init_model()
 
 # Reset since we are using a different mode.
 import torch._dynamo
 torch._dynamo.reset()
 
-evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")
+model_opt = torch.compile(model, mode="reduce-overhead")
 
 inp = generate_data(16)[0]
 with torch.no_grad():
-    print("eager:", timed(lambda: evaluate(model, inp))[1])
-    print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
+    print("eager:", timed(lambda: model(inp))[1])
+    print("compile:", timed(lambda: model_opt(inp))[1])
 
 ######################################################################
 # Notice that ``torch.compile`` takes a lot longer to complete
@@ -166,7 +163,7 @@ def evaluate(mod, inp):
 for i in range(N_ITERS):
     inp = generate_data(16)[0]
     with torch.no_grad():
-        _, eager_time = timed(lambda: evaluate(model, inp))
+        _, eager_time = timed(lambda: model(inp))
     eager_times.append(eager_time)
     print(f"eager eval time {i}: {eager_time}")
 
@@ -176,7 +173,7 @@ def evaluate(mod, inp):
 for i in range(N_ITERS):
     inp = generate_data(16)[0]
     with torch.no_grad():
-        _, compile_time = timed(lambda: evaluate_opt(model, inp))
+        _, compile_time = timed(lambda: model_opt(inp))
     compile_times.append(compile_time)
     print(f"compile eval time {i}: {compile_time}")
 print("~" * 10)
@@ -250,6 +247,10 @@ def train(mod, data):
 # Again, we can see that ``torch.compile`` takes longer in the first
 # iteration, as it must compile the model, but in subsequent iterations, we see
 # significant speedups compared to eager.
+#
+# We remark that the speedup numbers presented in this tutorial are for
+# demonstration purposes only. Official speedup values can be seen at the
+# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.
 
 ######################################################################
 # Comparison to TorchScript and FX Tracing
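The net effect of the commit: the ``nn.Module`` is compiled directly with
``torch.compile(model, mode="reduce-overhead")`` instead of going through an
``evaluate`` wrapper function. A minimal runnable sketch of the resulting
pattern follows; the ``Sequential`` stand-in model, the input shape, and the
wall-clock ``timed`` helper are illustrative assumptions, not the tutorial's
actual ``init_model``/``generate_data``/``timed`` definitions.

import time
import torch
import torch._dynamo

# Stand-in model; the tutorial builds a ResNet via init_model().
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

# Reset compilation state, as the tutorial does before switching modes.
torch._dynamo.reset()

# Compile the module directly; "reduce-overhead" targets per-call
# launch overhead (e.g. via CUDA graphs when running on GPU).
model_opt = torch.compile(model, mode="reduce-overhead")

def timed(fn):
    # Simplified wall-clock timer; the tutorial's timed() uses CUDA events.
    start = time.time()
    result = fn()
    return result, time.time() - start

inp = torch.randn(16, 64)  # stand-in for generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])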
