-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathesnet.py
132 lines (106 loc) · 4.86 KB
/
esnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Paper: ESNet: An Efficient Symmetric Network for Real-time Semantic Segmentation
Url: https://arxiv.org/abs/1906.09826
Created by: zh320
Date: 2023/09/24
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import ConvBNAct, DeConvBNAct, Activation
from .enet import InitialBlock as DownsamplingUnit
from .model_registry import register_model
@register_model()
class ESNet(nn.Module):
    """ESNet segmentation network (see the paper cited in the module docstring).

    The model is a straight pipeline: three ``DownsamplingUnit`` stages, each
    followed by a stack of FCU / PFCU blocks, then two ``DeConvBNAct`` stages,
    each followed by an FCU stack, and a final ``DeConvBNAct`` that maps to
    ``num_class`` channels.

    Args:
        num_class: Number of output channels produced by the last layer.
        n_channel: Number of channels of the input fed to the first stage.
        act_type: Activation name forwarded to every sub-module.

    NOTE(review): presumably every ``DownsamplingUnit`` halves and every
    ``DeConvBNAct`` doubles the spatial resolution, so the output matches the
    input size only when it is divisible by 8 -- confirm against ``.enet`` and
    ``.modules``.
    """

    def __init__(self, num_class=1, n_channel=3, act_type='relu'):
        super().__init__()

        # ---- Encoder -------------------------------------------------------
        self.block1_down = DownsamplingUnit(n_channel, 16, act_type)
        self.block1 = build_blocks('fcu', channels=16, num_block=3,
                                   K=3, act_type=act_type)

        self.block2_down = DownsamplingUnit(16, 64, act_type)
        self.block2 = build_blocks('fcu', channels=64, num_block=2,
                                   K=5, act_type=act_type)

        self.block3_down = DownsamplingUnit(64, 128, act_type)
        self.block3 = build_blocks('pfcu', channels=128, num_block=3,
                                   r1=2, r2=5, r3=9, act_type=act_type)

        # ---- Decoder -------------------------------------------------------
        self.block4_up = DeConvBNAct(128, 64, act_type=act_type)
        self.block4 = build_blocks('fcu', channels=64, num_block=2,
                                   K=5, act_type=act_type)

        self.block5_up = DeConvBNAct(64, 16, act_type=act_type)
        self.block5 = build_blocks('fcu', channels=16, num_block=2,
                                   K=3, act_type=act_type)

        # Final layer producing one channel per class.
        self.full_conv = DeConvBNAct(16, num_class, act_type=act_type)

    def forward(self, x):
        """Run ``x`` through the encoder stages, decoder stages and output layer."""
        # Encoder: each resampling unit feeds its block stack.
        x = self.block1(self.block1_down(x))
        x = self.block2(self.block2_down(x))
        x = self.block3(self.block3_down(x))

        # Decoder: mirror of the encoder, FCU stacks only.
        x = self.block4(self.block4_up(x))
        x = self.block5(self.block5_up(x))

        return self.full_conv(x)
def build_blocks(block_type, channels, num_block, K=None, r1=None, r2=None, r3=None,
                 act_type='relu'):
    """Stack ``num_block`` identical FCU or PFCU units into one ``nn.Sequential``.

    Args:
        block_type: ``'fcu'`` to build ``FCU`` units or ``'pfcu'`` to build
            ``PFCU`` units.
        channels: Channel count passed to every unit.
        num_block: How many units to stack; ``0`` yields an empty
            ``nn.Sequential``.
        K: Kernel size forwarded to ``FCU`` (only used for ``'fcu'``).
        r1, r2, r3: Dilation rates forwarded to ``PFCU`` (only used for
            ``'pfcu'``).
        act_type: Activation name forwarded to every unit.

    Returns:
        An ``nn.Sequential`` holding the freshly constructed units in order.

    Raises:
        NotImplementedError: If ``block_type`` is neither ``'fcu'`` nor
            ``'pfcu'``.
    """
    # Validate once, before building anything: checking inside the
    # construction loop would let an unsupported type slip through unnoticed
    # whenever num_block is 0.
    if block_type not in ('fcu', 'pfcu'):
        raise NotImplementedError(f'Unsupported block type: {block_type}.\n')

    if block_type == 'fcu':
        layers = [FCU(channels, K, act_type) for _ in range(num_block)]
    else:
        layers = [PFCU(channels, r1, r2, r3, act_type) for _ in range(num_block)]

    return nn.Sequential(*layers)
class FCU(nn.Module):
    """Residual unit built from two pairs of factorized (K x 1, 1 x K) convolutions.

    Each pair is a bias-free ``K x 1`` ``nn.Conv2d`` followed by an activation
    and a ``1 x K`` ``ConvBNAct``. The second ``ConvBNAct`` is created with
    ``act_type='none'``; the unit's activation is applied after the input has
    been added back.

    Args:
        channels: Input and output channel count of every convolution.
        K: Kernel length of the factorized convolutions; must not be ``None``.
        act_type: Activation name passed to ``Activation`` / ``ConvBNAct``.

    NOTE(review): padding for the ``1 x K`` convolutions is left to
    ``ConvBNAct`` -- presumably it pads to keep the spatial size; confirm in
    ``.modules``.
    """

    def __init__(self, channels, K, act_type):
        super().__init__()
        assert K is not None, 'K should not be None.\n'

        vertical_kernel = (K, 1)
        horizontal_kernel = (1, K)
        # Pad only along the height axis, which is the axis the K x 1 kernel spans.
        vertical_pad = ((K - 1) // 2, 0)

        stages = [
            # First factorized pair.
            nn.Conv2d(channels, channels, vertical_kernel,
                      padding=vertical_pad, bias=False),
            Activation(act_type, inplace=True),
            ConvBNAct(channels, channels, horizontal_kernel,
                      act_type=act_type, inplace=True),
            # Second factorized pair; its activation is deferred to self.act.
            nn.Conv2d(channels, channels, vertical_kernel,
                      padding=vertical_pad, bias=False),
            Activation(act_type, inplace=True),
            ConvBNAct(channels, channels, horizontal_kernel, act_type='none'),
        ]
        self.conv = nn.Sequential(*stages)
        self.act = Activation(act_type)

    def forward(self, x):
        """Return ``act(conv(x) + x)``."""
        out = self.conv(x)
        # In-place add of the skip connection onto the branch output.
        out += x
        return self.act(out)
class PFCU(nn.Module):
    """Residual unit with a shared 3x1 / 1x3 stem and three parallel dilated branches.

    The stem (``conv0``) output is fed to three branches that differ only in
    their dilation rate (``r1``, ``r2``, ``r3``). The branch outputs and the
    unit's input are summed and passed through the activation.

    Args:
        channels: Input and output channel count of every convolution.
        r1: Dilation rate of the left branch; must not be ``None``.
        r2: Dilation rate of the middle branch; must not be ``None``.
        r3: Dilation rate of the right branch; must not be ``None``.
        act_type: Activation name passed to ``Activation`` / ``ConvBNAct``.
    """

    def __init__(self, channels, r1, r2, r3, act_type):
        super().__init__()
        assert all(rate is not None for rate in (r1, r2, r3))

        def dilated_branch(rate):
            # 3x1 dilated conv (padding == dilation along the height axis),
            # activation, then a 1x3 dilated ConvBNAct without activation.
            return nn.Sequential(
                nn.Conv2d(channels, channels, (3, 1), padding=(rate, 0),
                          dilation=rate, bias=False),
                Activation(act_type, inplace=True),
                ConvBNAct(channels, channels, (1, 3), dilation=rate, act_type='none'),
            )

        # Shared, non-dilated stem applied before the branches split.
        self.conv0 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 1), padding=(1, 0), bias=False),
            Activation(act_type, inplace=True),
            ConvBNAct(channels, channels, (1, 3), act_type=act_type, inplace=True),
        )

        self.conv_left = dilated_branch(r1)
        self.conv_mid = dilated_branch(r2)
        self.conv_right = dilated_branch(r3)

        self.act = Activation(act_type)

    def forward(self, x):
        """Return ``act(left + mid + right + x)`` computed from the shared stem."""
        shared = self.conv0(x)

        left = self.conv_left(shared)
        mid = self.conv_mid(shared)
        right = self.conv_right(shared)

        # Fuse the three branches and add the skip connection last.
        return self.act(left + mid + right + x)