-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
81 lines (72 loc) · 2.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import os
import numpy as np
from time import time
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder
# Input resolution; echoed into the log after every epoch.
SHAPE = (1024, 1024)
# Directory that holds the training files (listed once, at import time).
ROOT = 'dataset/train/'
# Every file whose name contains 'sat' counts as an input image.  Its sample
# id is the name minus its last 8 characters -- presumably an 8-char suffix
# such as '_sat.jpg' (TODO confirm against data.ImageFolder).
imagelist = [fname for fname in os.listdir(ROOT) if 'sat' in fname]
trainlist = [fname[:-8] for fname in imagelist]
# Run name: used for both the log file and the checkpoint file names.
NAME = 'log01_dink34'
# Samples per CUDA device; scaled by the device count to form the batch size.
BATCHSIZE_PER_CARD = 2
if __name__ == '__main__':
    # Train DinkNet34 with the Dice+BCE loss (initial lr 2e-4) on the sample
    # ids in `trainlist`.  Progress goes to logs/NAME.log and stdout; the
    # weights with the lowest mean training loss go to weights/NAME.th.
    solver = MyFrame(DinkNet34, dice_bce_loss, 2e-4)

    # BATCHSIZE_PER_CARD samples per visible CUDA device.  device_count() is 0
    # on a host without GPUs and DataLoader rejects batch_size=0, so count at
    # least one device to avoid that ValueError.
    batchsize = max(torch.cuda.device_count(), 1) * BATCHSIZE_PER_CARD
    dataset = ImageFolder(trainlist, ROOT)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        num_workers=4)
    # Batches per epoch; constant across epochs, so compute it once.
    num_batches = len(data_loader)

    # open() (and, presumably, solver.save()) do not create missing parent
    # directories, so make sure both output locations exist up front.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('weights', exist_ok=True)
    weights_path = 'weights/' + NAME + '.th'

    # `with` guarantees the log is flushed and closed even if training raises.
    with open('logs/' + NAME + '.log', 'w') as mylog:

        def report(*fields):
            """Print one line to the log file and to stdout."""
            print(*fields, file=mylog)
            print(*fields)

        tic = time()
        no_optim = 0  # consecutive epochs without a new best training loss
        total_epoch = 20
        # Start at +inf so the first epoch always writes a checkpoint; with a
        # finite sentinel (previously 100.) a high early loss could leave no
        # weights file for the solver.load() below to restore.
        train_epoch_best_loss = float('inf')
        for epoch in range(1, total_epoch + 1):
            train_epoch_loss = 0
            print('batches per epoch:', num_batches)
            for index, (img, mask) in enumerate(data_loader):
                if index % 100 == 0:
                    # Fraction of the epoch completed, shown as a percentage.
                    print("{:.1%} of one epoch.".format(index / num_batches))
                solver.set_input(img, mask)
                train_loss = solver.optimize()
                train_epoch_loss += train_loss
            # Mean per-batch training loss for this epoch.
            train_epoch_loss /= num_batches

            report('********')
            report('epoch:', epoch, ' time:', int(time() - tic))
            report('train_loss:', train_epoch_loss)
            report('SHAPE:', SHAPE)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
                solver.save(weights_path)

            # More than 6 consecutive stalled epochs: stop training.
            if no_optim > 6:
                report('early stop at %d epoch' % epoch)
                break
            # More than 3 stalled epochs: roll back to the best checkpoint and
            # shrink the learning rate, unless it is already below 5e-7.
            if no_optim > 3:
                if solver.old_lr < 5e-7:
                    report('learning rate below 5e-7, stop at %d epoch' % epoch)
                    break
                solver.load(weights_path)
                # NOTE(review): with factor=True MyFrame presumably divides the
                # current lr by 5.0 and logs to mylog -- confirm in framework.py.
                solver.update_lr(5.0, factor=True, mylog=mylog)
            mylog.flush()

        report('Finish!')