1
+ import torch
2
+ import torch .nn as nn
3
+ import torchvision
4
+ from efficientnet_pytorch import EfficientNet
5
+
6
+ class VGG19 (nn .Module ):
7
+ def __init__ (self , pretrained = True , requires_grad = True ):
8
+ super (VGG19 , self ).__init__ ()
9
+ _vgg = torchvision .models .vgg19 (pretrained = pretrained ).features
10
+ self .vgg_pool3 = torch .nn .Sequential ()
11
+ self .vgg_pool4 = torch .nn .Sequential ()
12
+ self .vgg_pool5 = torch .nn .Sequential ()
13
+
14
+ for x in range (19 ):
15
+ self .vgg_pool3 .add_module (str (x ), _vgg [x ])
16
+ for x in range (19 , 28 ):
17
+ self .vgg_pool4 .add_module (str (x ), _vgg [x ])
18
+ for x in range (28 , 37 ):
19
+ self .vgg_pool5 .add_module (str (x ), _vgg [x ])
20
+
21
+
22
+ if not requires_grad :
23
+ for param in self .parameters ():
24
+ param .requires_grad = False
25
+
26
+ def forward (self , x ):
27
+ pool_3_out = self .vgg_pool3 (x ) #torch.Size([1, 256, 128, 128])
28
+ pool_4_out = self .vgg_pool4 (pool_3_out ) #torch.Size([1, 512, 64, 64])
29
+ pool_5_out = self .vgg_pool5 (pool_4_out ) #torch.Size([1, 512, 32, 32])
30
+ return (pool_3_out , pool_4_out , pool_5_out )
31
+
32
+ class ResNet (nn .Module ):
33
+ def __init__ (self , pretrained = True , requires_grad = True ):
34
+ super (ResNet , self ).__init__ ()
35
+ resnet18 = torchvision .models .resnet34 (pretrained = True )
36
+
37
+ self .layer_1 = nn .Sequential (
38
+ resnet18 .conv1 ,
39
+ resnet18 .bn1 ,
40
+ resnet18 .relu ,
41
+ resnet18 .maxpool ,
42
+ resnet18 .layer1
43
+ )
44
+ self .layer_2 = resnet18 .layer2
45
+ self .layer_3 = resnet18 .layer3
46
+ self .layer_4 = resnet18 .layer4
47
+
48
+ if not requires_grad :
49
+ for param in self .parameters ():
50
+ param .requires_grad = False
51
+
52
+ def forward (self , x ):
53
+
54
+ out_1 = self .layer_2 (self .layer_1 (x )) #torch.Size([1, 128, 128, 128])
55
+ out_2 = self .layer_3 (out_1 ) #torch.Size([1, 256, 64, 64])
56
+ out_3 = self .layer_4 (out_2 ) #torch.Size([1, 512, 32, 32])
57
+ return out_1 , out_2 , out_3
58
+
59
+
60
+ class DenseNet (nn .Module ):
61
+ def __init__ (self , pretrained = True , requires_grad = True ):
62
+ super (DenseNet , self ).__init__ ()
63
+ denseNet = torchvision .models .densenet121 (pretrained = True ).features
64
+ self .densenet_out_1 = torch .nn .Sequential ()
65
+ self .densenet_out_2 = torch .nn .Sequential ()
66
+ self .densenet_out_3 = torch .nn .Sequential ()
67
+
68
+ for x in range (8 ):
69
+ self .densenet_out_1 .add_module (str (x ), denseNet [x ])
70
+ for x in range (8 ,10 ):
71
+ self .densenet_out_2 .add_module (str (x ), denseNet [x ])
72
+
73
+ self .densenet_out_3 .add_module (str (10 ), denseNet [10 ])
74
+
75
+ if not requires_grad :
76
+ for param in self .parameters ():
77
+ param .requires_grad = False
78
+
79
+ def forward (self , x ):
80
+
81
+ out_1 = self .densenet_out_1 (x ) #torch.Size([1, 256, 64, 64])
82
+ out_2 = self .densenet_out_2 (out_1 ) #torch.Size([1, 512, 32, 32])
83
+ out_3 = self .densenet_out_3 (out_2 ) #torch.Size([1, 1024, 32, 32])
84
+ return out_1 , out_2 , out_3
85
+
86
+ class efficientNet_B0 (nn .Module ):
87
+ def __init__ (self , pretrained = True , requires_grad = True ):
88
+ super (efficientNet_B0 , self ).__init__ ()
89
+ eNet = EfficientNet .from_pretrained ('efficientnet-b0' )
90
+
91
+ self .eNet_out_1 = torch .nn .Sequential ()
92
+ self .eNet_out_2 = torch .nn .Sequential ()
93
+ self .eNet_out_3 = torch .nn .Sequential ()
94
+
95
+ blocks = eNet ._blocks
96
+
97
+ self .eNet_out_1 .add_module ('_conv_stem' , eNet ._conv_stem )
98
+ self .eNet_out_1 .add_module ('_bn0' , eNet ._bn0 )
99
+
100
+ for x in range (14 ):
101
+ self .eNet_out_1 .add_module (str (x ), blocks [x ])
102
+
103
+ self .eNet_out_2 .add_module (str (14 ), blocks [14 ])
104
+ self .eNet_out_3 .add_module (str (15 ), blocks [15 ])
105
+
106
+
107
+ def forward (self , x ):
108
+ out_1 = self .eNet_out_1 (x ) #torch.Size([1, 192, 32, 32])
109
+ out_2 = self .eNet_out_2 (out_1 ) #torch.Size([1, 192, 32, 32])
110
+ out_3 = self .eNet_out_3 (out_2 ) #torch.Size([1, 320, 32, 32])
111
+ return out_1 , out_2 , out_3
112
+
113
+ class efficientNet (nn .Module ):
114
+ def __init__ (self , model_type = 'efficientnet-b0' , pretrained = True , requires_grad = True ):
115
+ super (efficientNet , self ).__init__ ()
116
+ eNet = EfficientNet .from_pretrained (model_type )
117
+
118
+ self .eNet_out_1 = torch .nn .Sequential ()
119
+ self .eNet_out_2 = torch .nn .Sequential ()
120
+ self .eNet_out_3 = torch .nn .Sequential ()
121
+
122
+ blocks = eNet ._blocks
123
+
124
+ self .eNet_out_1 .add_module ('_conv_stem' , eNet ._conv_stem )
125
+ self .eNet_out_1 .add_module ('_bn0' , eNet ._bn0 )
126
+
127
+ for x in range (len (blocks )- 3 ):
128
+ self .eNet_out_1 .add_module (str (x ), blocks [x ])
129
+
130
+ self .eNet_out_2 .add_module (str (len (blocks )- 2 ), blocks [len (blocks )- 2 ])
131
+ self .eNet_out_3 .add_module (str (len (blocks )- 1 ), blocks [len (blocks )- 1 ])
132
+
133
+
134
+ def forward (self , x ):
135
+ out_1 = self .eNet_out_1 (x ) #torch.Size([1, 192, 32, 32])
136
+ out_2 = self .eNet_out_2 (out_1 ) #torch.Size([1, 192, 32, 32])
137
+ out_3 = self .eNet_out_3 (out_2 ) #torch.Size([1, 320, 32, 32])
138
+
139
+
140
+ """
141
+ shapes of b1
142
+ torch.Size([1, 192, 32, 32])
143
+ torch.Size([1, 320, 32, 32])
144
+ torch.Size([1, 320, 32, 32])
145
+
146
+ shapes of b2
147
+ torch.Size([1, 208, 32, 32])
148
+ torch.Size([1, 352, 32, 32])
149
+ torch.Size([1, 352, 32, 32])
150
+ """
151
+
152
+ return out_1 , out_2 , out_3
153
+
154
+
155
+
156
+ if __name__ == '__main__' :
157
+ model = efficientNet ()
158
+ x = torch .randn (1 ,3 ,1024 ,1024 )
159
+ model (x )
0 commit comments