eval_omni.py

"""Evaluate a trained ANML model on Omniglot via repeated test-train runs."""
import argparse
import os
import sys
import warnings
from pathlib import Path

import numpy as np
import torch
from tqdm import trange

from anml import test_train
from datasets.OmniSampler import OmniSampler

# Suppress all warnings to keep the console output clean.
warnings.filterwarnings("ignore")

def check_path(path):
    """Argparse type checker: the model path must exist on disk."""
    if Path(path).exists():
        return path
    raise argparse.ArgumentTypeError(f"model: {path} is not a valid path")

def repeats(runs, path, classes, train_examples, lr, device):
    """Evaluate the model `runs` times and report mean/std accuracy."""
    omni_sampler = OmniSampler(root="../data/omni")

    def run():
        return test_train(
            path,
            sampler=omni_sampler,
            num_classes=classes,
            train_examples=train_examples,
            device=device,
            lr=lr,
        )

    # Each call returns an array of accuracies; reduce it to one score per
    # run, then aggregate across runs.
    results = [run().mean() for _ in trange(runs)]
    print(
        f"Classes {classes} Accuracy {np.mean(results):.2f}"
        f" (std {np.std(results):.2f})"
    )
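
# A minimal sketch of calling the evaluation without the CLI (hypothetical
# checkpoint path and learning rate; both depend on your training setup):
#
#   repeats(runs=10, path="trained_anmls/anml.net", classes=10,
#           train_examples=15, lr=0.001, device="cuda")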

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ANML evaluation on Omniglot")
    parser.add_argument(
        "-l",
        "--lr",
        type=float,
        help="learning rate to use (check README for suggestions)",
    )
    parser.add_argument(
        "-c", "--classes", type=int, help="number of classes to test"
    )
    parser.add_argument(
        "-r", "--runs", type=int, help="number of repetitions to run"
    )
    parser.add_argument(
        "-t",
        "--train-examples",
        type=int,
        default=15,
        help="how many examples to use for training (max 20, default 15)",
    )
    parser.add_argument(
        "-m", "--model", type=check_path, help="path to the model to use"
    )
    parser.add_argument(
        "-d",
        "--device",
        choices=["cpu", "cuda"],
        type=str.lower,
        help="device to use for PyTorch",
    )
    args = parser.parse_args()

    # Default to CUDA when available; refuse an explicit CUDA request when
    # torch reports no CUDA device.
    device = args.device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    elif device == "cuda" and not torch.cuda.is_available():
        print(
            "Torch says CUDA is not available."
            " Remove it from your command to proceed on CPU.",
            file=sys.stderr,
        )
        sys.exit(os.EX_UNAVAILABLE)  # NOTE: os.EX_* codes exist only on Unix

    repeats(
        runs=args.runs,
        path=args.model,
        classes=args.classes,
        train_examples=args.train_examples,
        lr=args.lr,
        device=device,
    )
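
# Example invocation (hypothetical checkpoint path and hyperparameters; the
# learning rate suggested in the README should take precedence):
#
#   python eval_omni.py -m trained_anmls/anml.net -c 10 -r 10 -l 0.001 -d cuda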