
Commit 4ecb95d

Refactor torchdynamo imports to be relative (#1494)
1 parent a83c21a commit 4ecb95d

31 files changed, +167 −620 lines changed

.flake8

+1
@@ -1,3 +1,4 @@
 [flake8]
+exclude = minifier_launcher.py,repro.py
 ignore = E203,W503,C101,C4,EXE,B,Y
 max-line-length = 120

Makefile

+4 −5

@@ -1,8 +1,10 @@
 .PHONY: default develop test torchbench format lint setup clean
 
 PY_FILES := $(wildcard *.py) $(wildcard torchdynamo/*.py) $(wildcard torchdynamo/*/*.py) \
-	$(wildcard test/*.py) $(wildcard torchinductor/*.py) $(wildcard torchinductor/*/*.py) \
-	$(wildcard benchmarks/*.py) $(wildcard benchmarks/*/*.py) $(wildcard .circleci/*.py)
+	$(wildcard torchinductor/*.py) $(wildcard torchinductor/*/*.py) \
+	$(wildcard benchmarks/*.py) $(wildcard benchmarks/*/*.py) \
+	$(wildcard test/*.py) $(wildcard test/*/*.py) \
+	$(wildcard .circleci/*.py)
 C_FILES := $(wildcard torchdynamo/*.c torchdynamo/*.cpp)
 CLANG_TIDY ?= clang-tidy-10
 CLANG_FORMAT ?= clang-format-10
@@ -74,7 +76,6 @@ clone-deps:
 	&& (test -e pytorch || git clone --recursive https://github.com/pytorch/pytorch pytorch) \
 	&& (test -e torchvision || git clone --recursive https://github.com/pytorch/vision torchvision) \
 	&& (test -e torchtext || git clone --recursive https://github.com/pytorch/text torchtext) \
-	&& (test -e torchaudio || git clone --recursive https://github.com/pytorch/audio torchaudio) \
 	&& (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \
 	&& (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \
 	&& (test -e triton || git clone --recursive https://github.com/openai/triton.git) \
@@ -84,7 +85,6 @@ pull-deps:
 	(cd ../pytorch && git pull && git submodule update --init --recursive)
 	(cd ../torchvision && git pull && git submodule update --init --recursive)
 	(cd ../torchtext && git pull && git submodule update --init --recursive)
-	(cd ../torchaudio && git pull && git submodule update --init --recursive)
 	(cd ../detectron2 && git pull && git submodule update --init --recursive)
 	(cd ../torchbenchmark && git pull && git submodule update --init --recursive)
 	(cd ../triton && git checkout master && git pull && git checkout $(TRITON_VERSION) && git submodule update --init --recursive)
@@ -102,7 +102,6 @@ build-deps: clone-deps
 	(cd ../pytorch && python setup.py clean && python setup.py develop)
 	(cd ../torchvision && python setup.py clean && python setup.py develop)
 	(cd ../torchtext && python setup.py clean && python setup.py develop)
-	(cd ../torchaudio && python setup.py clean && python setup.py develop)
 	(cd ../detectron2 && python setup.py clean && python setup.py develop)
 	(cd ../torchbenchmark && python install.py --continue_on_fail)
 	(cd ../triton/python && python setup.py clean && python setup.py develop)

test/mock_modules/mock_module2.py

+1 −1

@@ -14,6 +14,6 @@ def method2(self, x):
 
 
 def method1(x, y):
-    z = torch.ones(1, 1)
+    z = torch.ones(1, 1)  # noqa
     x.append(y)
     return x

test/mock_modules/mock_module3.py

+1 −1

@@ -2,6 +2,6 @@
 
 
 def method1(x, y):
-    z = torch.ones(1, 1)
+    z = torch.ones(1, 1)  # noqa
    x.append(y)
     return x

torchdynamo/codegen.py

+1 −3

@@ -7,8 +7,6 @@
 
 import torch.nn
 
-import torchdynamo
-
 from .bytecode_transformation import Instruction
 from .bytecode_transformation import create_instruction
 from .exc import unimplemented
@@ -42,7 +40,7 @@ class PyCodegen(object):
 
     def __init__(
         self,
-        tx: "torchdynamo.symbolic_convert.InstructionTranslator" = None,
+        tx=None,
         root: torch.nn.Module = None,
         graph_output_var: str = None,
         tempvars=None,
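Note: the only reason codegen.py pulled in the top-level torchdynamo package was the string annotation on tx, so dropping the annotation removes the absolute-import dependency. A minimal sketch, not part of this commit, of how such an annotation could instead be kept without a runtime import via the standard typing.TYPE_CHECKING guard (module path assumed for illustration):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only type checkers evaluate this import; it never runs at runtime,
        # so no absolute package name is needed when the code executes.
        from .symbolic_convert import InstructionTranslator


    class PyCodegen:
        def __init__(self, tx: "InstructionTranslator" = None):
            # The annotation stays a plain string at runtime.
            self.tx = tx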

torchdynamo/config.py

+9 −4

@@ -7,8 +7,6 @@
 
 import torch
 
-import torchdynamo.utils
-
 try:
     import torch._prims
     import torch._refs
@@ -141,21 +139,28 @@
 # to allow DDP comm/compute overlap
 optimize_ddp = False
 
-
 # If True, raises exception if TorchDynamo is called with a context manager
 raise_on_ctx_manager_usage = True
 
 # If True, raise when aot autograd is unsafe to use
 raise_on_unsafe_aot_autograd = False
 
+# How to import torchdynamo, either torchdynamo or torch.dynamo
+dynamo_import = __name__.replace(".config", "")
+
+# How to import torchinductor, either torchinductor or torch.inductor
+inductor_import = dynamo_import.replace("dynamo", "inductor")
+
 
 class _AccessLimitingConfig(ModuleType):
     def __setattr__(self, name, value):
         if name not in _allowed_config_names:
             raise AttributeError(f"{__name__}.{name} does not exist")
         # automatically set logger level whenever config.log_level is modified
         if name == "log_level":
-            torchdynamo.utils.set_loggers_level(value)
+            from .utils import set_loggers_level
+
+            set_loggers_level(value)
         return object.__setattr__(self, name, value)
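The two new config values derive the import prefix from the package this config module is actually loaded from, so the same source works whether it is installed as torchdynamo or shipped as torch.dynamo (the two layouts named in the comments above). A rough illustration of the string manipulation, using those two package names as inputs:

    # __name__ of the config module is "<package>.config"; stripping the
    # ".config" suffix recovers the prefix that user-facing messages and
    # generated repro scripts should use.
    for module_name in ("torchdynamo.config", "torch.dynamo.config"):
        dynamo_import = module_name.replace(".config", "")
        inductor_import = dynamo_import.replace("dynamo", "inductor")
        print(dynamo_import, inductor_import)
    # torchdynamo torchinductor
    # torch.dynamo torch.inductor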

torchdynamo/convert_frame.py

+4 −2

@@ -199,7 +199,7 @@ def replay_record_msg():
         and record_filename is not None
     ):
         return f"\nLast frame execution written to {record_filename}. To run only this frame while debugging, run\
- torchdynamo.replay('{record_filename}').\n"
+ {config.dynamo_import}.replay('{record_filename}').\n"
     else:
         return ""
 
@@ -238,7 +238,9 @@ def replay_record_msg():
 
     msg += replay_record_msg()
 
-    msg += "\nSet torchdynamo.config.verbose=True for more information\n"
+    msg += (
+        f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
+    )
     msg += "=" * 10
     return msg
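With the prefix read from config rather than hard-coded, the debugging hints name whichever package the user actually has. A small sketch of the resulting text, assuming config.dynamo_import is "torchdynamo" and using a hypothetical record path:

    dynamo_import = "torchdynamo"  # stand-in for config.dynamo_import
    record_filename = "/tmp/frame_record.bin"  # hypothetical, for illustration

    msg = (
        f"\nLast frame execution written to {record_filename}. "
        f"To run only this frame while debugging, run "
        f"{dynamo_import}.replay('{record_filename}').\n"
    )
    msg += f"\nSet {dynamo_import}.config.verbose=True for more information\n"
    print(msg)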

torchdynamo/debug_utils.py

+30 −31

@@ -12,10 +12,9 @@
 import torch
 import torch.fx as fx
 
-import torchdynamo
-from torchdynamo import config
-from torchdynamo.optimizations.backends import register_backend
-from torchdynamo.utils import clone_inputs
+from . import config
+from .optimizations.backends import register_backend
+from .utils import clone_inputs
 
 log = logging.getLogger(__name__)
 
@@ -132,11 +131,11 @@ def _cuda_system_info_comment():
 
 def generate_compiler_repro_string(gm, args):
     model_str = textwrap.dedent(
-        """
+        f"""
 import torch
 from torch import tensor, device
 import torch.fx as fx
-from torchdynamo.testing import rand_strided
+from {config.dynamo_import}.testing import rand_strided
 from math import inf
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -180,7 +179,7 @@ def dump_compiler_graph_state(gm, args, compiler_name):
     print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
     with open(file_name, "w") as fd:
         save_graph_repro(fd, gm, args, compiler_name)
-    repro_path = os.path.join(torchdynamo.config.base_dir, "repro.py")
+    repro_path = os.path.join(config.base_dir, "repro.py")
     shutil.copyfile(file_name, repro_path)
 
 
@@ -210,7 +209,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
         fd.write(
             textwrap.dedent(
                 f"""
-                from torchdynamo.debug_utils import {fail_fn}
+                from {__name__} import {fail_fn}
                 """
             )
         )
@@ -290,7 +289,7 @@ def helper_for_dump_minify(contents):
         log.exception(e)
         raise NotImplementedError("Could not write to {minified_repro_path}")
 
-    local_path = os.path.join(torchdynamo.config.base_dir, "minifier_launcher.py")
+    local_path = os.path.join(config.base_dir, "minifier_launcher.py")
     try:
         shutil.copyfile(minified_repro_path, local_path)
         log.warning(
@@ -308,7 +307,7 @@ def dump_to_minify(gm, args, compiler_name: str):
 {generate_compiler_repro_string(gm, args)}
 
 from functools import partial
-from torchdynamo.debug_utils import (
+from {__name__} import (
     isolate_fails,
     dump_compiler_graph_state,
 )
@@ -385,9 +384,9 @@ def run_fwd_maybe_bwd(gm, args):
     """
     Runs a forward and possibly backward iteration for a given mod and args.
     """
-    from torchdynamo.testing import collect_results
-    from torchdynamo.testing import reduce_to_scalar_loss
-    from torchdynamo.testing import requires_bwd_pass
+    from .testing import collect_results
+    from .testing import reduce_to_scalar_loss
+    from .testing import requires_bwd_pass
 
     gm = copy.deepcopy(gm)
     args = clone_inputs(args)
@@ -406,7 +405,7 @@ def same_two_models(gm, opt_gm, example_inputs):
     """
     Check two models have same accuracy.
     """
-    from torchdynamo.utils import same
+    from .utils import same
 
     ref = run_fwd_maybe_bwd(gm, example_inputs)
 
@@ -447,21 +446,21 @@ def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
 
     return textwrap.dedent(
         f"""
+from math import inf
 import torch
-import torchdynamo
 from torch import tensor, device
 import torch.fx as fx
-from torchdynamo.testing import rand_strided
-from math import inf
-from torchdynamo.debug_utils import run_fwd_maybe_bwd
+import {config.dynamo_import}
+from {config.dynamo_import}.testing import rand_strided
+from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
 
 args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
 args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
 
 {model_str}
 
 mod = Repro().cuda()
-opt_mod = torchdynamo.optimize("{compiler_name}")(mod)
+opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod)
 
 with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
     ref = run_fwd_maybe_bwd(mod, args)
@@ -487,7 +486,7 @@ def dump_backend_repro_as_file(gm, args, compiler_name):
     log.warning(f"Copying {file_name} to {latest_repro} for convenience")
     shutil.copyfile(file_name, latest_repro)
 
-    local_path = os.path.join(torchdynamo.config.base_dir, "repro.py")
+    local_path = os.path.join(config.base_dir, "repro.py")
     try:
         shutil.copyfile(file_name, local_path)
         log.warning(
@@ -542,11 +541,11 @@ def dump_backend_repro_as_file(gm, args, compiler_name):
 # )
 # )
 
-# local_dir = os.path.join(torchdynamo.config.base_dir, "repro")
+# local_dir = os.path.join(config.base_dir, "repro")
 # if os.path.exists(local_dir):
 # shutil.rmtree(local_dir)
 # shutil.copytree(tmp_dir, local_dir)
-# local_tar_file = os.path.join(torchdynamo.config.base_dir, "repro.tar.gz")
+# local_tar_file = os.path.join(config.base_dir, "repro.tar.gz")
 # print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}")
 # with tarfile.open(local_tar_file, "w:gz") as tar:
 # tar.add(local_dir, arcname=os.path.basename(local_dir))
@@ -595,18 +594,18 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
 
     contents = textwrap.dedent(
         f"""
+import os
+from math import inf
 import torch
-import torchdynamo
 from torch import tensor, device
 import torch.fx as fx
-from torchdynamo.testing import rand_strided
-from math import inf
-from torchdynamo.debug_utils import run_fwd_maybe_bwd
-from torchdynamo.optimizations.backends import BACKENDS
 import functools
-import os
+import {config.dynamo_import}
+from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
+from {config.dynamo_import}.optimizations.backends import BACKENDS
+from {config.dynamo_import}.testing import rand_strided
 
-torchdynamo.config.repro_dir = \"{minifier_dir()}\"
+{config.dynamo_import}.config.repro_dir = \"{minifier_dir()}\"
 
 args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
 args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
@@ -620,7 +619,7 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
     compiler_fn,
     compiler_name="{compiler_name}",
 )
-opt_mod = torchdynamo.optimize(dynamo_minifier_backend)(mod)
+opt_mod = {config.dynamo_import}.optimize(dynamo_minifier_backend)(mod)
 
 opt_mod(*args)
 """
@@ -678,7 +677,7 @@ def debug_wrapper(gm, example_inputs, **kwargs):
 def dynamo_minifier_backend(gm, example_inputs, compiler_name):
     from functorch.compile import minifier
 
-    from torchdynamo.optimizations.backends import BACKENDS
+    from .optimizations.backends import BACKENDS
 
     if compiler_name == "inductor":
         from torchinductor.compile_fx import compile_fx
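Because the repro and minifier scripts are generated as plain text, their import lines must also track the install location; the templates therefore interpolate {config.dynamo_import} (and {__name__} for debug_utils itself) instead of spelling out torchdynamo. A stripped-down sketch of the templating pattern, with a simplified helper name that is not part of the real module:

    import textwrap


    def repro_header(dynamo_import: str) -> str:
        # Build the generated script's imports from the configured package
        # prefix rather than a hard-coded "torchdynamo".
        return textwrap.dedent(
            f"""
            import torch
            import {dynamo_import}
            from {dynamo_import}.testing import rand_strided
            from {dynamo_import}.debug_utils import run_fwd_maybe_bwd
            """
        )


    print(repro_header("torchdynamo"))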

torchdynamo/eval_frame.py

+11 −11

@@ -15,19 +15,17 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn.parallel.distributed import DistributedDataParallel
 
-import torchdynamo
-from torchdynamo.optimizations.distributed import DDPOptimizer
-from torchdynamo.utils import checkpoint_params
-from torchdynamo.utils import clone_inputs
-from torchdynamo.utils import compile_times
-from torchdynamo.utils import same
-
 from . import config
 from . import convert_frame
 from . import skipfiles
 from . import utils
 from .exc import ResetRequired
 from .mutation_guard import install_generation_tagging_init
+from .optimizations.distributed import DDPOptimizer
+from .utils import checkpoint_params
+from .utils import clone_inputs
+from .utils import compile_times
+from .utils import same
 
 log = logging.getLogger(__name__)
 
@@ -65,7 +63,7 @@ def remove_from_cache(f):
     elif hasattr(getattr(f, "forward", None), "__code__"):
         reset_code(f.forward.__code__)
     else:
-        from torchdynamo import reset
+        from . import reset
 
         reset()
         log.warning("could not determine __code__ for %s", f)
@@ -301,7 +299,7 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
 
 
 def get_compiler_fn(compiler_fn):
-    from torchdynamo.debug_utils import wrap_backend_debug
+    from .debug_utils import wrap_backend_debug
 
     """Expand backend strings to functions"""
     compiler_str = compiler_fn if isinstance(compiler_fn, str) else None
@@ -368,7 +366,9 @@ def toy_example(a, b):
 @patch("torchdynamo.symbolic_convert.explain", True)
 def explain(f, *args, **kwargs):
     # TODO(voz): Do we want a decorator for this?
-    torchdynamo.reset()
+    from . import reset
+
+    reset()
 
     out_guards = []
     graphs = []
@@ -428,7 +428,7 @@ def guard_export_print(guards):
     explanation += compile_times()
 
     # TODO(voz): Do we want a decorator for this?
-    torchdynamo.reset()
+    reset()
     return explanation, out_guards, graphs, ops_per_graph, break_reasons
 
 
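Entry points like reset and wrap_backend_debug are now imported relatively, and inside the function body where pulling them in at module load time would risk a circular import with the package __init__. A minimal sketch of that deferred-import pattern for a module living inside a hypothetical package (names illustrative only):

    # pkg/eval_frame.py -- pkg/__init__.py can import this module freely,
    # because the sibling symbol is only resolved when the function runs.
    def explain_sketch(fn, *args):
        from . import reset  # deferred relative import, bound at call time

        reset()
        return fn(*args)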

torchdynamo/exc.py

+1 −1

@@ -2,7 +2,7 @@
 import os
 import textwrap
 
-from torchdynamo.utils import counters
+from .utils import counters
 
 
 class TorchDynamoException(RuntimeError):

torchdynamo/guards.py

+1 −2

@@ -20,9 +20,8 @@
 import numpy as np
 import torch
 
-from torchdynamo import convert_frame
-
 from . import config
+from . import convert_frame
 from . import mutation_guard
 from ._guards import TensorGuards
 from ._guards import check_obj_id
