|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import nn |
| 5 | +import torch.nn.functional as F |
| 6 | + |
| 7 | +from ._utils import pad_to_multiple, unpad |
| 8 | + |
| 9 | + |
| 10 | +def _act() -> nn.Module: |
| 11 | + return nn.ReLU(inplace=True) |
| 12 | + |
| 13 | + |
| 14 | +class ConvBlock(nn.Module): |
| 15 | + def __init__(self, in_ch: int, out_ch: int, *, kernel_size: int = 3, act: bool = True) -> None: |
| 16 | + super().__init__() |
| 17 | + k = int(kernel_size) |
| 18 | + p = k // 2 |
| 19 | + self.conv = nn.Conv2d(int(in_ch), int(out_ch), kernel_size=k, padding=p, bias=True) |
| 20 | + self.act = _act() if bool(act) else nn.Identity() |
| 21 | + |
| 22 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 23 | + return self.act(self.conv(x)) |
| 24 | + |
| 25 | + |
| 26 | +class ChannelAttention(nn.Module): |
| 27 | + def __init__(self, channels: int, *, reduction: int = 8) -> None: |
| 28 | + super().__init__() |
| 29 | + c = int(channels) |
| 30 | + r = int(reduction) |
| 31 | + hidden = max(8, c // max(1, r)) |
| 32 | + self.pool = nn.AdaptiveAvgPool2d(1) |
| 33 | + self.mlp = nn.Sequential( |
| 34 | + nn.Conv2d(c, hidden, kernel_size=1, bias=True), |
| 35 | + _act(), |
| 36 | + nn.Conv2d(hidden, c, kernel_size=1, bias=True), |
| 37 | + nn.Sigmoid(), |
| 38 | + ) |
| 39 | + |
| 40 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 41 | + w = self.mlp(self.pool(x)) |
| 42 | + return x * w |
| 43 | + |
| 44 | + |
| 45 | +class DualAttentionUnit(nn.Module): |
| 46 | + """A small residual block with channel attention (toy DU/DAU).""" |
| 47 | + |
| 48 | + def __init__(self, channels: int) -> None: |
| 49 | + super().__init__() |
| 50 | + c = int(channels) |
| 51 | + self.conv1 = ConvBlock(c, c, kernel_size=3, act=True) |
| 52 | + self.conv2 = ConvBlock(c, c, kernel_size=3, act=False) |
| 53 | + self.ca = ChannelAttention(c, reduction=8) |
| 54 | + |
| 55 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 56 | + y = self.conv2(self.conv1(x)) |
| 57 | + y = self.ca(y) |
| 58 | + return x + y |
| 59 | + |
| 60 | + |
| 61 | +class SKFF(nn.Module): |
| 62 | + """Selective kernel feature fusion (toy, channel-wise softmax).""" |
| 63 | + |
| 64 | + def __init__(self, channels: int, num_branches: int, *, reduction: int = 8) -> None: |
| 65 | + super().__init__() |
| 66 | + c = int(channels) |
| 67 | + b = int(num_branches) |
| 68 | + if b <= 1: |
| 69 | + raise ValueError("num_branches must be > 1") |
| 70 | + hidden = max(8, c // max(1, int(reduction))) |
| 71 | + self.pool = nn.AdaptiveAvgPool2d(1) |
| 72 | + self.fc = nn.Sequential(nn.Conv2d(c, hidden, 1, bias=True), _act()) |
| 73 | + self.fcs = nn.ModuleList([nn.Conv2d(hidden, c, 1, bias=True) for _ in range(b)]) |
| 74 | + |
| 75 | + def forward(self, feats: list[torch.Tensor]) -> torch.Tensor: |
| 76 | + if len(feats) != len(self.fcs): |
| 77 | + raise ValueError(f"Expected {len(self.fcs)} feature maps, got {len(feats)}") |
| 78 | + base = feats[0] |
| 79 | + for f in feats[1:]: |
| 80 | + if f.shape != base.shape: |
| 81 | + raise ValueError("All SKFF branches must have the same shape") |
| 82 | + |
| 83 | + s = torch.zeros_like(base) |
| 84 | + for f in feats: |
| 85 | + s = s + f |
| 86 | + |
| 87 | + z = self.fc(self.pool(s)) # (B, hidden, 1, 1) |
| 88 | + logits = torch.stack([fc(z) for fc in self.fcs], dim=1) # (B, Bn, C, 1, 1) |
| 89 | + w = torch.softmax(logits, dim=1) |
| 90 | + out = torch.zeros_like(base) |
| 91 | + for i, f in enumerate(feats): |
| 92 | + out = out + f * w[:, i] |
| 93 | + return out |
| 94 | + |
| 95 | + |
| 96 | +class MultiScaleResidualBlock(nn.Module): |
| 97 | + """Toy MIRNet multi-scale residual block: 3-scale processing + SKFF + residual.""" |
| 98 | + |
| 99 | + def __init__(self, channels: int) -> None: |
| 100 | + super().__init__() |
| 101 | + c = int(channels) |
| 102 | + self.dau1 = DualAttentionUnit(c) |
| 103 | + self.dau2 = DualAttentionUnit(c) |
| 104 | + self.dau3 = DualAttentionUnit(c) |
| 105 | + self.fuse = SKFF(c, 3, reduction=8) |
| 106 | + self.out = ConvBlock(c, c, kernel_size=3, act=False) |
| 107 | + |
| 108 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 109 | + b, c, h, w = x.shape |
| 110 | + |
| 111 | + x1 = self.dau1(x) |
| 112 | + |
| 113 | + x2 = F.avg_pool2d(x, kernel_size=2, stride=2, ceil_mode=False) |
| 114 | + x2 = self.dau2(x2) |
| 115 | + x2 = F.interpolate(x2, size=(h, w), mode="nearest") |
| 116 | + |
| 117 | + x3 = F.avg_pool2d(x, kernel_size=4, stride=4, ceil_mode=False) |
| 118 | + x3 = self.dau3(x3) |
| 119 | + x3 = F.interpolate(x3, size=(h, w), mode="nearest") |
| 120 | + |
| 121 | + y = self.fuse([x1, x2, x3]) |
| 122 | + y = self.out(y) |
| 123 | + return x + y |
| 124 | + |
| 125 | + |
| 126 | +class MIRNet(nn.Module): |
| 127 | + """MIRNet-style denoiser (toy-first, pure torch). |
| 128 | +
|
| 129 | + Notes: |
| 130 | + - This is a simplified MIRNet-inspired architecture with multi-scale residual blocks and SKFF-like fusion. |
| 131 | + - It performs residual learning: output = input + predicted_residual. |
| 132 | + """ |
| 133 | + |
| 134 | + def __init__( |
| 135 | + self, |
| 136 | + *, |
| 137 | + in_channels: int, |
| 138 | + width: int = 32, |
| 139 | + depth: int = 5, |
| 140 | + ) -> None: |
| 141 | + super().__init__() |
| 142 | + c_in = int(in_channels) |
| 143 | + w0 = int(width) |
| 144 | + d = int(depth) |
| 145 | + if c_in <= 0: |
| 146 | + raise ValueError("in_channels must be > 0") |
| 147 | + if w0 < 8: |
| 148 | + raise ValueError("width must be >= 8") |
| 149 | + if d <= 0: |
| 150 | + raise ValueError("depth must be > 0") |
| 151 | + |
| 152 | + self.intro = nn.Conv2d(c_in, w0, kernel_size=3, padding=1, bias=True) |
| 153 | + self.body = nn.Sequential(*[MultiScaleResidualBlock(w0) for _ in range(d)]) |
| 154 | + self.outro = nn.Conv2d(w0, c_in, kernel_size=3, padding=1, bias=True) |
| 155 | + |
| 156 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 157 | + x = x.to(torch.float32) |
| 158 | + if x.ndim != 4: |
| 159 | + raise ValueError(f"Expected input shape (B, C, H, W), got {tuple(x.shape)}") |
| 160 | + |
| 161 | + x_pad, pad_hw = pad_to_multiple(x, 4, mode="reflect") |
| 162 | + inp = x_pad |
| 163 | + h = self.body(self.intro(x_pad)) |
| 164 | + y = inp + self.outro(h) |
| 165 | + return unpad(y, pad_hw) |
| 166 | + |
| 167 | + |
| 168 | +_VARIANTS: dict[str, dict] = { |
| 169 | + "mirnet_tiny": {"width": 24, "depth": 3}, |
| 170 | + "mirnet_small": {"width": 32, "depth": 5}, |
| 171 | + "mirnet_base": {"width": 48, "depth": 7}, |
| 172 | +} |
| 173 | + |
| 174 | + |
| 175 | +def build_mirnet_denoiser( |
| 176 | + *, |
| 177 | + in_channels: int, |
| 178 | + variant: str = "mirnet_small", |
| 179 | +) -> nn.Module: |
| 180 | + name = str(variant).lower().strip() |
| 181 | + if name not in _VARIANTS: |
| 182 | + raise ValueError(f"Unknown MIRNet variant: {variant!r}. Supported: {sorted(_VARIANTS)}") |
| 183 | + spec = _VARIANTS[name] |
| 184 | + return MIRNet(in_channels=int(in_channels), width=int(spec["width"]), depth=int(spec["depth"])) |
| 185 | + |
| 186 | + |
| 187 | +if __name__ == "__main__": |
| 188 | + torch.manual_seed(0) |
| 189 | + x = torch.rand(2, 1, 64, 64) |
| 190 | + noisy = (x + torch.randn_like(x) * 0.1).clamp(0.0, 1.0) |
| 191 | + m = build_mirnet_denoiser(in_channels=1, variant="mirnet_tiny") |
| 192 | + y = m(noisy) |
| 193 | + print("mirnet_tiny", tuple(y.shape)) |
| 194 | + loss = (y - x).pow(2).mean() |
| 195 | + loss.backward() |
| 196 | + print("ok") |
| 197 | + |
0 commit comments