
Commit 8ac19bf
add model/dataset registry
1 parent: 0e89ca6


56 files changed (+524 -412 lines)

README.md (+1 -1)

@@ -7,7 +7,7 @@ PyTorch implementation of realtime semantic segmentation models, support multi-g
 # Requirements

 torch == 1.8.1
-segmentation-models-pytorch
+segmentation-models-pytorch (optional)
 torchmetrics
 albumentations
 loguru

configs/my_config.py (+1 -1)

@@ -3,7 +3,7 @@

 class MyConfig(BaseConfig):
     def __init__(self,):
-        super(MyConfig, self).__init__()
+        super().__init__()
         # Dataset
         self.dataset = 'cityscapes'
         self.data_root = '/path/to/your/dataset'

configs/optuna_config.py (+1 -1)

@@ -7,7 +7,7 @@

 class OptunaConfig(BaseConfig):
     def __init__(self,):
-        super(OptunaConfig, self).__init__()
+        super().__init__()
         # Dataset
         self.dataset = 'cityscapes'
         self.data_root = '/path/to/your/dataset'

configs/parser.py (+8 -10)

@@ -1,5 +1,8 @@
 import argparse

+from datasets import list_available_datasets
+from models import list_available_models
+

 def load_parser(config):
     args = get_parser()

@@ -16,7 +19,8 @@ def load_parser(config):
 def get_parser():
     parser = argparse.ArgumentParser()
     # Dataset
-    parser.add_argument('--dataset', type=str, default=None, choices=['cityscapes'],
+    dataset_list = list_available_datasets()
+    parser.add_argument('--dataset', type=str, default=None, choices=dataset_list,
                         help='choose which dataset you want to use')
     parser.add_argument('--dataroot', type=str, default=None,
                         help='path to your dataset')

@@ -26,14 +30,8 @@ def get_parser():
                         help='ignore index used for cross_entropy/ohem loss')

     # Model
-    parser.add_argument('--model', type=str, default=None,
-                        choices=['adscnet', 'aglnet', 'bisenetv1', 'bisenetv2', 'canet', 'cfpnet',
-                                 'cgnet', 'contextnet', 'dabnet', 'ddrnet', 'dfanet', 'edanet',
-                                 'enet', 'erfnet', 'esnet', 'espnet', 'espnetv2', 'fanet', 'farseenet',
-                                 'fastscnn', 'fddwnet', 'fpenet', 'fssnet', 'icnet', 'lednet',
-                                 'linknet', 'lite_hrnet', 'liteseg', 'mininet', 'mininetv2', 'ppliteseg',
-                                 'regseg', 'segnet', 'shelfnet', 'sqnet', 'stdc', 'swiftnet',
-                                 'smp'],
+    model_list = list_available_models()
+    parser.add_argument('--model', type=str, default=None, choices=model_list,
                         help='choose which model you want to use')
     parser.add_argument('--encoder', type=str, default=None,
                         help='choose which encoder of SMP model you want to use (please refer to SMP repo)')

@@ -179,4 +177,4 @@ def get_parser():
                         help='temperature used for KL divergence loss')

     args = parser.parse_args()
-    return args
\ No newline at end of file
+    return args
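With the hard-coded choice lists gone, the CLI options now track the registries at import time. A quick sanity check (illustrative only; the exact lists depend on which models and datasets are registered and on whether segmentation-models-pytorch is installed):

from datasets import list_available_datasets
from models import list_available_models

print(list_available_models())    # e.g. ['adscnet', 'aglnet', ..., 'smp']
print(list_available_datasets())  # e.g. ['cityscapes']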

core/__init__.py (-1)

@@ -1,3 +1,2 @@
-from .base_trainer import BaseTrainer
 from .seg_trainer import SegTrainer
 from .loss import get_loss_fn, kd_loss_fn, get_detail_loss_fn

core/base_trainer.py (+1 -1)

@@ -12,7 +12,7 @@

 class BaseTrainer:
     def __init__(self, config):
-        super(BaseTrainer, self).__init__()
+        super().__init__()
         # DDP parameters, DO NOT CHANGE
         self.rank = int(os.getenv('RANK', -1))
         self.local_rank = int(os.getenv('LOCAL_RANK', -1))

core/loss.py (+3 -3)

@@ -5,7 +5,7 @@

 class OhemCELoss(nn.Module):
     def __init__(self, thresh, ignore_index=255):
-        super(OhemCELoss, self).__init__()
+        super().__init__()
         self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda()
         self.ignore_index = ignore_index
         self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')

@@ -22,7 +22,7 @@ def forward(self, logits, labels):

 class DiceLoss(nn.Module):
     def __init__(self, smooth=1):
-        super(DiceLoss, self).__init__()
+        super().__init__()
         self.smooth = smooth

     def forward(self, logits, labels):

@@ -39,7 +39,7 @@ class DetailLoss(nn.Module):
     '''Implement detail loss used in paper
     `Rethinking BiSeNet For Real-time Semantic Segmentation`'''
     def __init__(self, dice_loss_coef=1., bce_loss_coef=1., smooth=1):
-        super(DetailLoss, self).__init__()
+        super().__init__()
         self.dice_loss_coef = dice_loss_coef
         self.bce_loss_coef = bce_loss_coef
         self.dice_loss_fn = DiceLoss(smooth)

datasets/__init__.py (+9 -3)

@@ -1,7 +1,7 @@
 from torch.utils.data import DataLoader
-from .cityscapes import Cityscapes

-dataset_hub = {'cityscapes':Cityscapes,}
+from .cityscapes import Cityscapes
+from .dataset_registry import dataset_hub


 def get_dataset(config):

@@ -58,4 +58,10 @@ def get_test_loader(config):
     test_loader = DataLoader(dataset, batch_size=config.test_bs,
                              shuffle=False, num_workers=config.num_workers)

-    return test_loader
\ No newline at end of file
+    return test_loader
+
+
+def list_available_datasets():
+    dataset_list = list(dataset_hub.keys())
+
+    return dataset_list
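Note that `from .cityscapes import Cityscapes` must remain even though the name is no longer used to build the hub: registration happens as an import side effect, since `@register_dataset` only runs when the `datasets.cityscapes` module is imported.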

datasets/cityscapes.py (+3)

@@ -5,9 +5,12 @@
 from torch.utils.data import Dataset
 import albumentations as AT
 from albumentations.pytorch import ToTensorV2
+
 from utils import transforms
+from .dataset_registry import register_dataset


+@register_dataset
 class Cityscapes(Dataset):
     # Codes are based on https://github.com/mcordts/cityscapesScripts

datasets/dataset_registry.py (+6, new file)

@@ -0,0 +1,6 @@
+dataset_hub = {}
+
+
+def register_dataset(dataset_class):
+    dataset_hub[dataset_class.__name__.lower()] = dataset_class
+    return dataset_class
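A minimal, self-contained sketch of the workflow this new file enables; the CamVid class below is hypothetical and not part of this commit:

dataset_hub = {}


def register_dataset(dataset_class):
    # Key each dataset by its lowercased class name: Cityscapes -> 'cityscapes'.
    dataset_hub[dataset_class.__name__.lower()] = dataset_class
    return dataset_class


@register_dataset
class CamVid:  # hypothetical; a real one would subclass torch.utils.data.Dataset
    def __init__(self, config, mode='train'):
        pass


assert 'camvid' in dataset_hub  # now discoverable via list_available_datasets()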

models/__init__.py (+28 -41)

@@ -1,5 +1,4 @@
 import os, torch
-import segmentation_models_pytorch as smp

 from .adscnet import ADSCNet
 from .aglnet import AGLNet

@@ -38,51 +37,25 @@
 from .sqnet import SQNet
 from .stdc import STDC, LaplacianConv
 from .swiftnet import SwiftNet
-
-
-decoder_hub = {'deeplabv3':smp.DeepLabV3, 'deeplabv3p':smp.DeepLabV3Plus, 'fpn':smp.FPN,
-               'linknet':smp.Linknet, 'manet':smp.MAnet, 'pan':smp.PAN, 'pspnet':smp.PSPNet,
-               'unet':smp.Unet, 'unetpp':smp.UnetPlusPlus}
+from .model_registry import model_hub, aux_models, detail_head_models


 def get_model(config):
-    model_hub = {'adscnet':ADSCNet, 'aglnet':AGLNet, 'bisenetv1':BiSeNetv1,
-                 'bisenetv2':BiSeNetv2, 'canet':CANet, 'cfpnet':CFPNet,
-                 'cgnet':CGNet, 'contextnet':ContextNet, 'dabnet':DABNet,
-                 'ddrnet':DDRNet, 'dfanet':DFANet, 'edanet':EDANet,
-                 'enet':ENet, 'erfnet':ERFNet, 'esnet':ESNet,
-                 'espnet':ESPNet, 'espnetv2':ESPNetv2, 'fanet':FANet, 'farseenet':FarSeeNet,
-                 'fastscnn':FastSCNN, 'fddwnet':FDDWNet, 'fpenet':FPENet,
-                 'fssnet':FSSNet, 'icnet':ICNet, 'lednet':LEDNet,
-                 'linknet':LinkNet, 'lite_hrnet':LiteHRNet, 'liteseg':LiteSeg, 'mininet':MiniNet,
-                 'mininetv2':MiniNetv2, 'ppliteseg':PPLiteSeg, 'regseg':RegSeg,
-                 'segnet':SegNet, 'shelfnet':ShelfNet, 'sqnet':SQNet,
-                 'stdc':STDC, 'swiftnet':SwiftNet,}
-
-    # The following models currently support auxiliary heads
-    aux_models = ['bisenetv2', 'ddrnet', 'icnet']
-
-    # The following models currently support detail heads
-    detail_head_models = ['stdc']
-
     if config.model == 'smp': # Use segmentation models pytorch
-        if config.decoder not in decoder_hub:
-            raise ValueError(f"Unsupported decoder type: {config.decoder}")
+        from .smp_wrapper import get_smp_model

-        model = decoder_hub[config.decoder](encoder_name=config.encoder,
-                                            encoder_weights=config.encoder_weights,
-                                            in_channels=3, classes=config.num_class)
+        model = get_smp_model(config.encoder, config.decoder, config.encoder_weights, config.num_class)

     elif config.model in model_hub.keys():
-        if config.model in aux_models:
-            model = model_hub[config.model](num_class=config.num_class, use_aux=config.use_aux)
-        elif config.model in detail_head_models:
-            model = model_hub[config.model](num_class=config.num_class, use_detail_head=config.use_detail_head, use_aux=config.use_aux)
+        if config.model in aux_models: # models support auxiliary heads
+            if config.model in detail_head_models: # models support detail heads
+                model = model_hub[config.model](num_class=config.num_class, use_detail_head=config.use_detail_head, use_aux=config.use_aux)
+            else:
+                model = model_hub[config.model](num_class=config.num_class, use_aux=config.use_aux)
+
         else:
             if config.use_aux:
                 raise ValueError(f'Model {config.model} does not support auxiliary heads.\n')
-            if config.use_detail_head:
-                raise ValueError(f'Model {config.model} does not support detail heads.\n')

             model = model_hub[config.model](num_class=config.num_class)

@@ -92,16 +65,30 @@ def get_model(config):
     return model


+def list_available_models():
+    model_list = list(model_hub.keys())
+
+    try:
+        import segmentation_models_pytorch as smp
+        model_list.append('smp')
+    except:
+        pass
+
+    return model_list
+
+
 def get_teacher_model(config, device):
     if config.kd_training:
         if not os.path.isfile(config.teacher_ckpt):
-            raise ValueError(f'Could not find teacher checkpoint at path {config.teacher_ckpt}.')
+            raise ValueError(f'Could not find teacher checkpoint at path {config.teacher_ckpt}.')

-        if config.teacher_decoder not in decoder_hub.keys():
-            raise ValueError(f"Unsupported teacher decoder type: {config.teacher_decoder}")
+        if config.teacher_model == 'smp':
+            from .smp_wrapper import get_smp_model

-        model = decoder_hub[config.teacher_decoder](encoder_name=config.teacher_encoder,
-                                                    encoder_weights=None, in_channels=3, classes=config.num_class)
+            model = get_smp_model(config.teacher_encoder, config.teacher_decoder, None, config.num_class)
+
+        else:
+            raise NotImplementedError()

         teacher_ckpt = torch.load(config.teacher_ckpt, map_location=torch.device('cpu'))
         model.load_state_dict(teacher_ckpt['state_dict'])
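models/model_registry.py itself is hidden in this view, so its contents are not shown above. Judging from the imported names (`model_hub`, `aux_models`, `detail_head_models`) and the `@register_model()` call sites below, a plausible sketch might look like the following; the decorator-factory signature and the lowercased-class-name keying are assumptions, not confirmed by the visible diff:

model_hub = {}
aux_models = []          # names of models that accept use_aux
detail_head_models = []  # names of models that accept use_detail_head


def register_model(*extra_hubs):
    # Decorator factory: @register_model() registers a class under its
    # lowercased name; passing extra hubs (e.g. aux_models) could be how
    # the capability lists are populated.
    def decorator(model_class):
        name = model_class.__name__.lower()
        model_hub[name] = model_class
        for hub in extra_hubs:
            hub.append(name)
        return model_class
    return decorator

One small caveat in `list_available_models()` above: the bare `except:` swallows non-import errors as well; `except ImportError:` would be the narrower guard.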

models/adscnet.py (+10 -8)

@@ -10,11 +10,13 @@
 import torch.nn as nn

 from .modules import conv1x1, ConvBNAct, DWConvBNAct, DeConvBNAct, Activation
+from .model_registry import register_model


+@register_model()
 class ADSCNet(nn.Module):
     def __init__(self, num_class=1, n_channel=3, act_type='relu6'):
-        super(ADSCNet, self).__init__()
+        super().__init__()
         self.conv0 = ConvBNAct(n_channel, 32, 3, 2, act_type=act_type, inplace=True)
         self.conv1 = ADSCModule(32, 1, act_type=act_type)
         self.conv2_4 = nn.Sequential(

@@ -54,7 +56,7 @@ def forward(self, x):

 class ADSCModule(nn.Module):
     def __init__(self, channels, stride, dilation=1, act_type='relu'):
-        super(ADSCModule, self).__init__()
+        super().__init__()
         assert stride in [1, 2], 'Unsupported stride type.\n'
         self.use_skip = stride == 1
         self.conv = nn.Sequential(

@@ -80,7 +82,7 @@ def forward(self, x):

 class DDCC(nn.Module):
     def __init__(self, channels, dilations, act_type):
-        super(DDCC, self).__init__()
+        super().__init__()
         assert len(dilations)==4, 'Length of dilations should be 4.\n'
         self.block1 = nn.Sequential(
             nn.AvgPool2d(dilations[0], 1, dilations[0]//2),

@@ -109,17 +111,17 @@ def __init__(self, channels, dilations, act_type):

     def forward(self, x):
         x1 = self.block1(x)
-
+
         x2 = torch.cat([x, x1], dim=1)
         x2 = self.block2(x2)
-
+
         x3 = torch.cat([x, x1, x2], dim=1)
         x3 = self.block3(x3)
-
+
         x4 = torch.cat([x, x1, x2, x3], dim=1)
         x4 = self.block4(x4)
-
+
         x = torch.cat([x, x1, x2, x3, x4], dim=1)
         x = self.conv_last(x)

-        return x
\ No newline at end of file
+        return x
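End to end, registration replaces the old hand-maintained dict: once a model file carries the decorator, it is picked up by `get_model` and by the CLI without touching `models/__init__.py` or `configs/parser.py`. A rough usage sketch (the config fields mirror those referenced in `get_model`; treat it as illustrative):

from types import SimpleNamespace

from models import get_model, list_available_models

config = SimpleNamespace(model='adscnet', num_class=19,
                         use_aux=False, use_detail_head=False)

assert config.model in list_available_models()
net = get_model(config)  # resolves ADSCNet through model_hub and instantiates it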
