Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions unet/unet_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""

def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
def __init__(self, in_channels, out_channels, dropout_prob=0.1):
super(DoubleConv, self).__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels), # <- اضافه شدن BatchNorm بعد از کانولوشن
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.Dropout2d(dropout_prob), # <- اضافه شدن Dropout برای جلوگیری از بیش‌برازش
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
nn.ReLU(inplace=True),
nn.Dropout2d(dropout_prob)
)

def forward(self, x):
Expand Down