#!/usr/bin/env python
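"""Evaluate a trained image classifier on the test set and save its predictions.

Usage sketch (the config path below is an example, not a file shipped with this script):
    python evaluate.py --config /path/to/config.yaml [KEY VALUE ...]
"""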
import argparse
import pathlib
import time
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from fvcore.common.checkpoint import Checkpointer
from pytorch_image_classification import (
apply_data_parallel_wrapper,
create_dataloader,
create_loss,
create_model,
get_default_config,
update_config,
)
from pytorch_image_classification.utils import (
AverageMeter,
create_logger,
get_rank,
)


def load_config():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True)
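    # Everything after the known flags is captured and merged into the config
    # as key/value overrides.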
parser.add_argument('options', default=None, nargs=argparse.REMAINDER)
args = parser.parse_args()
config = get_default_config()
config.merge_from_file(args.config)
config.merge_from_list(args.options)
update_config(config)
config.freeze()
return config


def evaluate(config, model, test_loader, loss_func, logger):
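    """Run inference over the test set and return raw logits, softmax
    probabilities, predicted labels, mean loss, and accuracy."""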
device = torch.device(config.device)
model.eval()
loss_meter = AverageMeter()
correct_meter = AverageMeter()
start = time.time()
pred_raw_all = []
pred_prob_all = []
pred_label_all = []
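    # Inference only: disable autograd to save memory and compute.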
with torch.no_grad():
for data, targets in tqdm.tqdm(test_loader):
data = data.to(device)
targets = targets.to(device)
outputs = model(data)
loss = loss_func(outputs, targets)
pred_raw_all.append(outputs.cpu().numpy())
pred_prob_all.append(F.softmax(outputs, dim=1).cpu().numpy())
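            # The predicted label is the index of the largest logit.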
_, preds = torch.max(outputs, dim=1)
pred_label_all.append(preds.cpu().numpy())
loss_ = loss.item()
correct_ = preds.eq(targets).sum().item()
num = data.size(0)
loss_meter.update(loss_, num)
correct_meter.update(correct_, 1)

    accuracy = correct_meter.sum / len(test_loader.dataset)
elapsed = time.time() - start
logger.info(f'Elapsed {elapsed:.2f}')
logger.info(f'Loss {loss_meter.avg:.4f} Accuracy {accuracy:.4f}')
preds = np.concatenate(pred_raw_all)
probs = np.concatenate(pred_prob_all)
labels = np.concatenate(pred_label_all)
return preds, probs, labels, loss_meter.avg, accuracy


def main():
config = load_config()
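    # Default to the checkpoint's directory when no output directory is configured.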
if config.test.output_dir is None:
output_dir = pathlib.Path(config.test.checkpoint).parent
else:
output_dir = pathlib.Path(config.test.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
logger = create_logger(name=__name__, distributed_rank=get_rank())
model = create_model(config)
model = apply_data_parallel_wrapper(config, model)
checkpointer = Checkpointer(model,
checkpoint_dir=output_dir,
logger=logger,
distributed_rank=get_rank())
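    # Restore the trained weights named in the config before evaluating.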
checkpointer.load(config.test.checkpoint)
test_loader = create_dataloader(config, is_train=False)
_, test_loss = create_loss(config)
preds, probs, labels, loss, acc = evaluate(config, model, test_loader,
test_loss, logger)
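    # Persist raw logits, probabilities, predicted labels, and summary metrics.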
    output_path = output_dir / 'predictions.npz'
np.savez(output_path,
preds=preds,
probs=probs,
labels=labels,
loss=loss,
acc=acc)


if __name__ == '__main__':
main()