@@ -32,6 +32,7 @@ def test_inception_download():
3232 model , style = load_inception (inception_args , metric )
3333
3434 if torch .__version__ < '1.6.0' :
35+ print (inception_args , metric , tar_style )
3536 assert style == 'pytorch'
3637 else :
3738 assert style == tar_style
@@ -175,6 +176,34 @@ def test_fid(self):
175176 fid_score , mean , cov = fid .summary ()
176177 assert fid_score > 0 and mean > 0 and cov > 0
177178
179+ # test load
180+ inception_pkl = osp .expanduser ('~/.cache/openmmlab/mmgen/cifar10.pkl' )
181+ fid = FID (
182+ 3 ,
183+ inception_args = dict (
184+ normalize_input = False , load_fid_inception = False ),
185+ inception_pkl = inception_pkl )
186+ fid .prepare ()
187+ assert fid .num_real_feeded == 3
188+ for b in self .reals :
189+ fid .feed (b , 'reals' )
190+
191+ for b in self .fakes :
192+ fid .feed (b , 'fakes' )
193+
194+ fid_score , mean , cov = fid .summary ()
195+ assert fid_score > 0 and mean > 0 and cov > 0
196+
197+ # test raise load error
198+ inception_pkl = 'wrong_path'
199+ fid = FID (
200+ 3 ,
201+ inception_args = dict (
202+ normalize_input = False , load_fid_inception = False ),
203+ inception_pkl = inception_pkl )
204+ with pytest .raises (FileNotFoundError ):
205+ fid .prepare ()
206+
178207
179208class TestPR :
180209
0 commit comments