Skip to content

Commit 9898ceb

Browse files
committed
revise inception load behavior for pt <= 1.6
1 parent 370fdea commit 9898ceb

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

mmgen/core/evaluation/metrics.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def load_inception(inception_args, metric):
5757
inceptoin_type = _inception_args.pop('type', None)
5858

5959
if torch.__version__ < '1.6.0':
60+
# reset inception_args for FID (Inception for IS do not use
61+
# inception_args)
62+
if metric == 'FID':
63+
_inception_args = dict(normalize_input=False)
64+
6065
mmcv.print_log(
6166
'Current Pytorch Version not support script module, load '
6267
'Inception Model from torch model zoo. If you want to use '

tests/test_cores/test_metrics.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,12 @@ def test_inception_download():
4040
with pytest.raises(TypeError):
4141
load_inception(args_empty, 'FID')
4242

43-
args_error_path = dict(type='StyleGAN', inception_path='error-path')
44-
with pytest.raises(RuntimeError):
45-
load_inception(args_error_path, 'FID')
43+
# pt lower than this version cannot load Tero's inception and direct use
44+
# torch ones, only test this for pt >= 1.6
45+
if torch.__version__ >= '1.6.0':
46+
args_error_path = dict(type='StyleGAN', inception_path='error-path')
47+
with pytest.raises(RuntimeError):
48+
load_inception(args_error_path, 'FID')
4649

4750
with pytest.raises(AssertionError):
4851
load_inception(dict(type='pytorch', normalize_input=False), 'PPL')

0 commit comments

Comments
 (0)