# resnet_block.py — ResNet residual block (113 lines, 91 loc)
import torch.nn as nn
# Define a ResNet block class
class ResnetBlock(nn.Module):
    """Residual block: ``out = x + conv_block(x)``.

    The convolutional block is two 3x3 convolutions, each preceded by one
    unit of padding and followed by a normalization layer, with ReLU after
    the first conv and an optional Dropout(0.5) between the two convs.
    Padding of 1 with stride 1 preserves spatial dimensions, so the
    residual addition is always shape-compatible.
    """

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
        """
        Initialize the ResNet block.

        Args:
            dim (int): Number of input/output channels.
            padding_type (str): Type of padding ('reflect', 'replicate', or 'zero').
            norm_layer: Normalization layer class to use (e.g., nn.BatchNorm2d).
            use_dropout (bool): Whether to include a dropout layer in the block.
            use_bias (bool): Whether the convolutional layers should use bias.

        Raises:
            NotImplementedError: If ``padding_type`` is not a supported value.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padded_conv(dim, padding_type, use_bias):
        """Return ``[optional padding layer, 3x3 conv]`` for *padding_type*.

        Factors out the padding-selection logic that both conv layers share
        (the original duplicated this if/elif chain verbatim).

        Raises:
            NotImplementedError: If ``padding_type`` is unsupported.
        """
        layers = []
        p = 0
        if padding_type == 'reflect':
            layers.append(nn.ReflectionPad2d(1))
        elif padding_type == 'replicate':
            layers.append(nn.ReplicationPad2d(1))
        elif padding_type == 'zero':
            # Let Conv2d perform the (zero) padding itself.
            p = 1
        else:
            raise NotImplementedError(f'Padding type [{padding_type}] is not implemented.')
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias))
        return layers

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """
        Build the sequence of layers for the convolutional block.

        Args:
            dim (int): Number of input/output channels.
            padding_type (str): Type of padding ('reflect', 'replicate', or 'zero').
            norm_layer: Normalization layer class to use.
            use_dropout (bool): Whether to include a dropout layer.
            use_bias (bool): Whether the convolutional layers should use bias.

        Returns:
            nn.Sequential: pad -> conv -> norm -> ReLU [-> dropout] -> pad -> conv -> norm.
        """
        conv_block = self._padded_conv(dim, padding_type, use_bias)
        conv_block += [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        conv_block += self._padded_conv(dim, padding_type, use_bias)
        conv_block += [norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """
        Forward pass for the ResNet block.

        Args:
            x (torch.Tensor): Input tensor of shape (N, dim, H, W).

        Returns:
            torch.Tensor: ``x + conv_block(x)`` — same shape as the input.
        """
        # NOTE(review): the original printed the input/output shapes here on
        # every call, flooding stdout during training; removed as debug noise.
        return x + self.conv_block(x)