diff --git a/unet/unet_parts.py b/unet/unet_parts.py index 986ba251f4..6d0fc0e66e 100644 --- a/unet/unet_parts.py +++ b/unet/unet_parts.py @@ -8,17 +8,17 @@ class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels + def __init__(self, in_channels, out_channels, dropout_prob=0.1): + super(DoubleConv, self).__init__() self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(mid_channels), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), # <- اضافه شدن BatchNorm بعد از کانولوشن nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.Dropout2d(dropout_prob), # <- اضافه شدن Dropout برای جلوگیری از بیش‌برازش + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), + nn.Dropout2d(dropout_prob) ) def forward(self, x):