-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathprepcache.py
63 lines (50 loc) · 1.56 KB
/
prepcache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import sys
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.utils.data import DataLoader
from util.util import enumerateWithEstimate
from .dsets import LunaDataset
from util.logconf import logging
from .model import LunaModel
# Module-level logger. INFO is the chosen verbosity for this script; switch the
# level to logging.WARN for quieter runs or logging.DEBUG for more detail.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
class LunaPrepCacheApp:
    """Command-line app that iterates once over every batch of ``LunaDataset``.

    Every batch produced by the ``DataLoader`` is discarded; the loop exists
    only for the side effects of loading the samples.
    NOTE(review): presumably ``LunaDataset`` caches samples as they are loaded,
    so this run pre-populates that cache ("Stuffing cache") -- confirm in dsets.
    """

    def __init__(self, sys_argv=None):
        """Parse command-line options into ``self.cli_args``.

        Args:
            sys_argv: Argument list to parse, without the program name.
                Defaults to ``sys.argv[1:]`` when ``None``.

        Raises:
            SystemExit: raised by ``argparse`` on invalid arguments or ``--help``.
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        # Stored on the instance (not the class) so that separately constructed
        # apps do not share, or overwrite, each other's options.
        self.cli_args = parser.parse_args(sys_argv)

    def main(self):
        """Load every batch of ``LunaDataset`` once, discarding the results.

        ``self.cli_args.batch_size`` and ``self.cli_args.num_workers`` configure
        the ``DataLoader``, which is kept on ``self.prep_dl``.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            # NOTE(review): presumably ordering by series_uid keeps samples from
            # the same series together so each series is read in one pass --
            # confirm in dsets.
            LunaDataset(
                sortby_str='series_uid',
            ),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        batch_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            # NOTE(review): presumably makes the time estimate skip the first
            # num_workers batches, which include worker start-up -- confirm in
            # util.util.
            start_ndx=self.prep_dl.num_workers,
        )
        # The batches themselves are unused; only the loading side effects matter.
        for _ in batch_iter:
            pass
if __name__ == '__main__':
    # Run as a script: options are taken from the command line.
    app = LunaPrepCacheApp()
    app.main()