diff --git a/eval.py b/eval.py index 1e3c43a6..86f4f598 100644 --- a/eval.py +++ b/eval.py @@ -117,8 +117,18 @@ # Load weights. if args['weights'] is not None: model = create_model(num_classes=NUM_CLASSES, coco_model=False) + # Load checkpoint with DDP checkpoint = torch.load(args['weights'], map_location=DEVICE) - model.load_state_dict(checkpoint['model_state_dict']) + # Delete possible prefix "module." of model_state_dict + state_dict = checkpoint['model_state_dict'] + new_state_dict = {} + for key in state_dict.keys(): + if key.startswith('module.'): + new_key = key[7:] # Delete "module." + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[key] = state_dict[key] + model.load_state_dict(new_state_dict) valid_dataset = create_valid_dataset( VALID_DIR_IMAGES, VALID_DIR_LABELS, @@ -215,4 +225,4 @@ def evaluate( print('-'*num_hyphens) print(f"|{CLASSES[1]:<15} | {np.array(stats['map']):.3f}{empty_string:<15}| {np.array(stats['mar_100']):.3f}{empty_string:<15}|") print('-'*num_hyphens) - print(f"|Avg{empty_string:<12} | {np.array(stats['map']):.3f}{empty_string:<15}| {np.array(stats['mar_100']):.3f}{empty_string:<15}|") \ No newline at end of file + print(f"|Avg{empty_string:<12} | {np.array(stats['map']):.3f}{empty_string:<15}| {np.array(stats['mar_100']):.3f}{empty_string:<15}|")