Skip to content

Commit d91f3d5

Browse files
committed
clean up arguments
1 parent d224f03 commit d91f3d5

11 files changed

+40
-47
lines changed

Diff for: config/ade20k-mobilenetv2dilated-c1_deepsup.yaml

+3-5
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@ DATASET:
1212
MODEL:
1313
arch_encoder: "mobilenetv2dilated"
1414
arch_decoder: "c1_deepsup"
15-
weights_encoder: ""
16-
weights_decoder: ""
1715
fc_dim: 320
1816

1917
TRAIN:
20-
batch_size_per_gpu: 2
18+
batch_size_per_gpu: 3
2119
num_epoch: 20
2220
start_epoch: 0
2321
epoch_iters: 5000
@@ -35,10 +33,10 @@ TRAIN:
3533

3634
VAL:
3735
visualize: False
38-
suffix: "_epoch_20.pth"
36+
checkpoint: "epoch_20.pth"
3937

4038
TEST:
41-
suffix: "_epoch_20.pth"
39+
checkpoint: "epoch_20.pth"
4240
result: "./"
4341

4442
DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup"

Diff for: config/ade20k-resnet101-upernet.yaml

+2-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ DATASET:
1212
MODEL:
1313
arch_encoder: "resnet101"
1414
arch_decoder: "upernet"
15-
weights_encoder: ""
16-
weights_decoder: ""
1715
fc_dim: 2048
1816

1917
TRAIN:
@@ -35,10 +33,10 @@ TRAIN:
3533

3634
VAL:
3735
visualize: False
38-
suffix: "_epoch_40.pth"
36+
checkpoint: "epoch_20.pth"
3937

4038
TEST:
41-
suffix: "_epoch_40.pth"
39+
checkpoint: "epoch_20.pth"
4240
result: "./"
4341

4442
DIR: "ckpt/ade20k-resnet101-upernet"

Diff for: config/ade20k-resnet101dilated-ppm_deepsup.yaml

+2-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ DATASET:
1212
MODEL:
1313
arch_encoder: "resnet50dilated"
1414
arch_decoder: "ppm_deepsup"
15-
weights_encoder: ""
16-
weights_decoder: ""
1715
fc_dim: 2048
1816

1917
TRAIN:
@@ -35,10 +33,10 @@ TRAIN:
3533

3634
VAL:
3735
visualize: False
38-
suffix: "_epoch_20.pth"
36+
checkpoint: "epoch_20.pth"
3937

4038
TEST:
41-
suffix: "_epoch_20.pth"
39+
checkpoint: "epoch_20.pth"
4240
result: "./"
4341

4442
DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"

Diff for: config/ade20k-resnet18dilated-ppm_deepsup.yaml

+2-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ DATASET:
1212
MODEL:
1313
arch_encoder: "resnet18dilated"
1414
arch_decoder: "ppm_deepsup"
15-
weights_encoder: ""
16-
weights_decoder: ""
1715
fc_dim: 512
1816

1917
TRAIN:
@@ -35,10 +33,10 @@ TRAIN:
3533

3634
VAL:
3735
visualize: False
38-
suffix: "_epoch_20.pth"
36+
checkpoint: "epoch_20.pth"
3937

4038
TEST:
41-
suffix: "_epoch_20.pth"
39+
checkpoint: "epoch_20.pth"
4240
result: "./"
4341

4442
DIR: "ckpt/ade20k-resnet18dilated-ppm_deepsup"

Diff for: config/ade20k-resnet50dilated-ppm_deepsup.yaml

+2-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ DATASET:
1212
MODEL:
1313
arch_encoder: "resnet50dilated"
1414
arch_decoder: "ppm_deepsup"
15-
weights_encoder: ""
16-
weights_decoder: ""
1715
fc_dim: 2048
1816

1917
TRAIN:
@@ -35,10 +33,10 @@ TRAIN:
3533

3634
VAL:
3735
visualize: False
38-
suffix: "_epoch_20.pth"
36+
checkpoint: "epoch_20.pth"
3937

4038
TEST:
41-
suffix: "_epoch_20.pth"
39+
checkpoint: "epoch_20.pth"
4240
result: "./"
4341

4442
DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"

Diff for: config/defaults.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
# epochs to train for
5050
_C.TRAIN.num_epoch = 20
5151
# epoch to start training from; useful when resuming from a checkpoint
52-
_C.TRAIN.start_epoch = 1
52+
_C.TRAIN.start_epoch = 0
5353
# iterations of each epoch (irrelevant to batch size)
5454
_C.TRAIN.epoch_iters = 5000
5555

@@ -83,7 +83,7 @@
8383
# output visualization during validation
8484
_C.VAL.visualize = False
8585
# the checkpoint to evaluate on
86-
_C.VAL.suffix = "_epoch_20.pth"
86+
_C.VAL.checkpoint = "epoch_20.pth"
8787

8888
# -----------------------------------------------------------------------------
8989
# Testing
@@ -92,6 +92,6 @@
9292
# currently only supports 1
9393
_C.TEST.batch_size = 1
9494
# the checkpoint to test on
95-
_C.TEST.suffix = "_epoch_20.pth"
95+
_C.TEST.checkpoint = "epoch_20.pth"
9696
# folder to output visualization results
9797
_C.TEST.result = "./"

Diff for: dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import json
33
import torch
4-
import lib.utils.data as torchdata
54
import cv2
65
from torchvision import transforms
76
import numpy as np
@@ -23,7 +22,7 @@ def imresize(im, size, interp='bilinear'):
2322
)
2423

2524

26-
class BaseDataset(torchdata.Dataset):
25+
class BaseDataset(torch.utils.data.Dataset):
2726
def __init__(self, odgt, opt, **kwargs):
2827
# parse options
2928
self.imgSizes = opt.imgSizes
@@ -110,6 +109,7 @@ def _get_sub_batch(self):
110109
def __getitem__(self, index):
111110
# NOTE: shuffle once, on the first call; shuffling in __init__ is useless
112111
if not self.if_shuffled:
112+
np.random.seed(index)
113113
np.random.shuffle(self.list_sample)
114114
self.if_shuffled = True
115115

Diff for: eval.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, setup_logger
1616
from lib.nn import user_scattered_collate, async_copy_to
1717
from lib.utils import as_numpy
18-
import lib.utils.data as torchdata
1918
import cv2
2019
from tqdm import tqdm
2120

@@ -133,7 +132,7 @@ def main(cfg, gpu):
133132
cfg.DATASET.root_dataset,
134133
cfg.DATASET.list_val,
135134
cfg.DATASET)
136-
loader_val = torchdata.DataLoader(
135+
loader_val = torch.utils.data.DataLoader(
137136
dataset_val,
138137
batch_size=cfg.VAL.batch_size,
139138
shuffle=False,
@@ -186,10 +185,9 @@ def main(cfg, gpu):
186185

187186
# absolute paths of model weights
188187
cfg.MODEL.weights_encoder = os.path.join(
189-
cfg.DIR, 'encoder' + cfg.VAL.suffix)
188+
cfg.DIR, 'encoder_' + cfg.VAL.checkpoint)
190189
cfg.MODEL.weights_decoder = os.path.join(
191-
cfg.DIR, 'decoder' + cfg.VAL.suffix)
192-
190+
cfg.DIR, 'decoder_' + cfg.VAL.checkpoint)
193191
assert os.path.exists(cfg.MODEL.weights_encoder) and \
194192
os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!"
195193

Diff for: eval_multipro.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, parse_devices, setup_logger
1717
from lib.nn import user_scattered_collate, async_copy_to
1818
from lib.utils import as_numpy
19-
import lib.utils.data as torchdata
2019
import cv2
2120
from tqdm import tqdm
2221

@@ -94,7 +93,7 @@ def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
9493
cfg.DATASET.list_val,
9594
cfg.DATASET,
9695
start_idx=start_idx, end_idx=end_idx)
97-
loader_val = torchdata.DataLoader(
96+
loader_val = torch.utils.data.DataLoader(
9897
dataset_val,
9998
batch_size=cfg.VAL.batch_size,
10099
shuffle=False,
@@ -211,10 +210,9 @@ def main(cfg, gpus):
211210

212211
# absolute paths of model weights
213212
cfg.MODEL.weights_encoder = os.path.join(
214-
cfg.DIR, 'encoder' + cfg.VAL.suffix)
213+
cfg.DIR, 'encoder_' + cfg.VAL.checkpoint)
215214
cfg.MODEL.weights_decoder = os.path.join(
216-
cfg.DIR, 'decoder' + cfg.VAL.suffix)
217-
215+
cfg.DIR, 'decoder_' + cfg.VAL.checkpoint)
218216
assert os.path.exists(cfg.MODEL.weights_encoder) and \
219217
os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!"
220218

Diff for: test.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from utils import colorEncode, find_recursive, setup_logger
1515
from lib.nn import user_scattered_collate, async_copy_to
1616
from lib.utils import as_numpy
17-
import lib.utils.data as torchdata
1817
import cv2
1918
from tqdm import tqdm
2019
from config import cfg
@@ -116,7 +115,7 @@ def main(cfg, gpu):
116115
dataset_test = TestDataset(
117116
cfg.list_test,
118117
cfg.DATASET)
119-
loader_test = torchdata.DataLoader(
118+
loader_test = torch.utils.data.DataLoader(
120119
dataset_test,
121120
batch_size=cfg.TEST.batch_size,
122121
shuffle=False,
@@ -179,9 +178,9 @@ def main(cfg, gpu):
179178

180179
# absolute paths of model weights
181180
cfg.MODEL.weights_encoder = os.path.join(
182-
cfg.DIR, 'encoder' + cfg.TEST.suffix)
181+
cfg.DIR, 'encoder_' + cfg.TEST.checkpoint)
183182
cfg.MODEL.weights_decoder = os.path.join(
184-
cfg.DIR, 'decoder' + cfg.TEST.suffix)
183+
cfg.DIR, 'decoder_' + cfg.TEST.checkpoint)
185184

186185
assert os.path.exists(cfg.MODEL.weights_encoder) and \
187186
os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!"

Diff for: train.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from models import ModelBuilder, SegmentationModule
1515
from utils import AverageMeter, parse_devices, setup_logger
1616
from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback
17-
import lib.utils.data as torchdata
1817

1918

2019
# train one epoch
@@ -31,7 +30,6 @@ def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
3130
for i in range(cfg.TRAIN.epoch_iters):
3231
batch_data = next(iterator)
3332
data_time.update(time.time() - tic)
34-
3533
segmentation_module.zero_grad()
3634

3735
# forward pass
@@ -72,7 +70,7 @@ def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
7270
adjust_learning_rate(optimizers, cur_iter, cfg)
7371

7472

75-
def checkpoint(nets, history, cfg, epoch_num):
73+
def checkpoint(nets, history, cfg, epoch):
7674
print('Saving checkpoints...')
7775
(net_encoder, net_decoder, crit) = nets
7876

@@ -81,13 +79,13 @@ def checkpoint(nets, history, cfg, epoch_num):
8179

8280
torch.save(
8381
history,
84-
'{}/history_epoch_{}.pth'.format(cfg.DIR, epoch_num))
82+
'{}/history_epoch_{}.pth'.format(cfg.DIR, epoch))
8583
torch.save(
8684
dict_encoder,
87-
'{}/encoder_epoch_{}.pth'.format(cfg.DIR, epoch_num))
85+
'{}/encoder_epoch_{}.pth'.format(cfg.DIR, epoch))
8886
torch.save(
8987
dict_decoder,
90-
'{}/decoder_epoch_{}.pth'.format(cfg.DIR, epoch_num))
88+
'{}/decoder_epoch_{}.pth'.format(cfg.DIR, epoch))
9189

9290

9391
def group_weight(module):
@@ -169,7 +167,7 @@ def main(cfg, gpus):
169167
cfg.DATASET,
170168
batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)
171169

172-
loader_train = torchdata.DataLoader(
170+
loader_train = torch.utils.data.DataLoader(
173171
dataset_train,
174172
batch_size=len(gpus), # we have modified data_parallel
175173
shuffle=False, # we do not use this param
@@ -242,12 +240,22 @@ def main(cfg, gpus):
242240
logger.info("Loaded configuration file {}".format(args.cfg))
243241
logger.info("Running with config:\n{}".format(cfg))
244242

243+
# Output directory
245244
if not os.path.isdir(cfg.DIR):
246245
os.makedirs(cfg.DIR)
247246
logger.info("Outputing checkpoints to: {}".format(cfg.DIR))
248247
with open(os.path.join(cfg.DIR, 'config.yaml'), 'w') as f:
249248
f.write("{}".format(cfg))
250249

250+
# Start from checkpoint
251+
if cfg.TRAIN.start_epoch > 0:
252+
cfg.MODEL.weights_encoder = os.path.join(
253+
cfg.DIR, 'encoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
254+
cfg.MODEL.weights_decoder = os.path.join(
255+
cfg.DIR, 'decoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
256+
assert os.path.exists(cfg.MODEL.weights_encoder) and \
257+
os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!"
258+
251259
# Parse gpu ids
252260
gpus = parse_devices(args.gpus)
253261
gpus = [x.replace('gpu', '') for x in gpus]

0 commit comments

Comments
 (0)