diff --git a/mmgen/apis/train.py b/mmgen/apis/train.py index 56c738984..cde663403 100644 --- a/mmgen/apis/train.py +++ b/mmgen/apis/train.py @@ -62,7 +62,7 @@ def train_model(model, k: v for k, v in cfg.data.items() if k not in [ 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', - 'test_dataloader' + 'test_dataloader', 'val_samples_per_gpu', 'val_workers_per_gpu' ] }) @@ -179,8 +179,13 @@ def train_model(model, **loader_cfg, 'shuffle': False, **cfg.data.get('val_data_loader', {}) } - val_dataloader = build_dataloader( - val_dataset, dist=distributed, **val_loader_cfg) + val_loader_cfg.update({ + 'samples_per_gpu': + cfg.data.get('val_samples_per_gpu', cfg.data.samples_per_gpu), + 'workers_per_gpu': + cfg.data.get('val_workers_per_gpu', cfg.data.workers_per_gpu) + }) + val_dataloader = build_dataloader(val_dataset, **val_loader_cfg) eval_cfg = deepcopy(cfg.get('evaluation')) priority = eval_cfg.pop('priority', 'LOW') eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))