@@ -12,10 +12,9 @@
 import torch
 import torch.fx as fx
 
-import torchdynamo
-from torchdynamo import config
-from torchdynamo.optimizations.backends import register_backend
-from torchdynamo.utils import clone_inputs
+from . import config
+from .optimizations.backends import register_backend
+from .utils import clone_inputs
 
 log = logging.getLogger(__name__)
 
@@ -132,11 +131,11 @@ def _cuda_system_info_comment():
 
 def generate_compiler_repro_string(gm, args):
     model_str = textwrap.dedent(
-        """
+        f"""
 import torch
 from torch import tensor, device
 import torch.fx as fx
-from torchdynamo.testing import rand_strided
+from {config.dynamo_import}.testing import rand_strided
 from math import inf
 from torch.fx.experimental.proxy_tensor import make_fx
 
@@ -180,7 +179,7 @@ def dump_compiler_graph_state(gm, args, compiler_name):
     print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
     with open(file_name, "w") as fd:
         save_graph_repro(fd, gm, args, compiler_name)
-    repro_path = os.path.join(torchdynamo.config.base_dir, "repro.py")
+    repro_path = os.path.join(config.base_dir, "repro.py")
     shutil.copyfile(file_name, repro_path)
 
 
@@ -210,7 +209,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
         fd.write(
             textwrap.dedent(
                 f"""
-                from torchdynamo.debug_utils import {fail_fn}
+                from {__name__} import {fail_fn}
                 """
             )
         )
@@ -290,7 +289,7 @@ def helper_for_dump_minify(contents):
         log.exception(e)
         raise NotImplementedError("Could not write to {minified_repro_path}")
 
-    local_path = os.path.join(torchdynamo.config.base_dir, "minifier_launcher.py")
+    local_path = os.path.join(config.base_dir, "minifier_launcher.py")
     try:
         shutil.copyfile(minified_repro_path, local_path)
         log.warning(
@@ -308,7 +307,7 @@ def dump_to_minify(gm, args, compiler_name: str):
 {generate_compiler_repro_string(gm, args)}
 
 from functools import partial
-from torchdynamo.debug_utils import (
+from {__name__} import (
     isolate_fails,
     dump_compiler_graph_state,
 )
@@ -385,9 +384,9 @@ def run_fwd_maybe_bwd(gm, args):
     """
     Runs a forward and possibly backward iteration for a given mod and args.
     """
-    from torchdynamo.testing import collect_results
-    from torchdynamo.testing import reduce_to_scalar_loss
-    from torchdynamo.testing import requires_bwd_pass
+    from .testing import collect_results
+    from .testing import reduce_to_scalar_loss
+    from .testing import requires_bwd_pass
 
     gm = copy.deepcopy(gm)
     args = clone_inputs(args)
@@ -406,7 +405,7 @@ def same_two_models(gm, opt_gm, example_inputs):
     """
     Check two models have same accuracy.
     """
-    from torchdynamo.utils import same
+    from .utils import same
 
     ref = run_fwd_maybe_bwd(gm, example_inputs)
 
@@ -447,21 +446,21 @@ def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
 
     return textwrap.dedent(
         f"""
+from math import inf
 import torch
-import torchdynamo
 from torch import tensor, device
 import torch.fx as fx
-from torchdynamo.testing import rand_strided
-from math import inf
-from torchdynamo.debug_utils import run_fwd_maybe_bwd
+import {config.dynamo_import}
+from {config.dynamo_import}.testing import rand_strided
+from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
 
 args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
 args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
 
 {model_str}
 
 mod = Repro().cuda()
-opt_mod = torchdynamo.optimize("{compiler_name}")(mod)
+opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod)
 
 with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
     ref = run_fwd_maybe_bwd(mod, args)
@@ -487,7 +486,7 @@ def dump_backend_repro_as_file(gm, args, compiler_name):
     log.warning(f"Copying {file_name} to {latest_repro} for convenience")
     shutil.copyfile(file_name, latest_repro)
 
-    local_path = os.path.join(torchdynamo.config.base_dir, "repro.py")
+    local_path = os.path.join(config.base_dir, "repro.py")
     try:
         shutil.copyfile(file_name, local_path)
         log.warning(
@@ -542,11 +541,11 @@ def dump_backend_repro_as_file(gm, args, compiler_name):
 # )
 # )
 
-# local_dir = os.path.join(torchdynamo.config.base_dir, "repro")
+# local_dir = os.path.join(config.base_dir, "repro")
 # if os.path.exists(local_dir):
 #     shutil.rmtree(local_dir)
 # shutil.copytree(tmp_dir, local_dir)
-# local_tar_file = os.path.join(torchdynamo.config.base_dir, "repro.tar.gz")
+# local_tar_file = os.path.join(config.base_dir, "repro.tar.gz")
 # print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}")
 # with tarfile.open(local_tar_file, "w:gz") as tar:
 #     tar.add(local_dir, arcname=os.path.basename(local_dir))
@@ -595,18 +594,18 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
 
     contents = textwrap.dedent(
         f"""
+import os
+from math import inf
 import torch
-import torchdynamo
 from torch import tensor, device
 import torch.fx as fx
-from torchdynamo.testing import rand_strided
-from math import inf
-from torchdynamo.debug_utils import run_fwd_maybe_bwd
-from torchdynamo.optimizations.backends import BACKENDS
 import functools
-import os
+import {config.dynamo_import}
+from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
+from {config.dynamo_import}.optimizations.backends import BACKENDS
+from {config.dynamo_import}.testing import rand_strided
 
-torchdynamo.config.repro_dir = \"{minifier_dir()}\"
+{config.dynamo_import}.config.repro_dir = \"{minifier_dir()}\"
 
 args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
 args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
@@ -620,7 +619,7 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
     compiler_fn,
     compiler_name="{compiler_name}",
 )
-opt_mod = torchdynamo.optimize(dynamo_minifier_backend)(mod)
+opt_mod = {config.dynamo_import}.optimize(dynamo_minifier_backend)(mod)
 
 opt_mod(*args)
 """
@@ -678,7 +677,7 @@ def debug_wrapper(gm, example_inputs, **kwargs):
 def dynamo_minifier_backend(gm, example_inputs, compiler_name):
     from functorch.compile import minifier
 
-    from torchdynamo.optimizations.backends import BACKENDS
+    from .optimizations.backends import BACKENDS
 
     if compiler_name == "inductor":
         from torchinductor.compile_fx import compile_fx