Skip to content

Commit dd82dbc

Browse files
author
jhlu
committed
feat(vision): add mirnet/mprnet/uformer denoisers
1 parent ce6859b commit dd82dbc

7 files changed

Lines changed: 596 additions & 1 deletion

File tree

dlhub/vision/denoising/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,14 @@
1010
from .dncnn import DnCNN, DnCNNDenoiser, build_dncnn_denoiser
1111
from .drunet import DRUNet, build_drunet_denoiser
1212
from .ffdnet import FFDNetDenoiser, build_ffdnet_denoiser
13+
from .mirnet import MIRNet, build_mirnet_denoiser
14+
from .mprnet import MPRNet, build_mprnet_denoiser
1315
from .nafnet import NAFNet, build_nafnet_denoiser
1416
from .noise2noise import Noise2NoiseUNet, build_noise2noise_denoiser
1517
from .restormer import Restormer, build_restormer_denoiser
1618
from .ridnet import RIDNetDenoiser, build_ridnet_denoiser
1719
from .swinir import SwinIR, build_swinir_denoiser
20+
from .uformer import UFormer, build_uformer_denoiser
1821

1922
__all__ = [
2023
"BM3D",
@@ -24,19 +27,25 @@
2427
"DnCNNDenoiser",
2528
"DRUNet",
2629
"FFDNetDenoiser",
30+
"MIRNet",
31+
"MPRNet",
2732
"NAFNet",
2833
"Noise2NoiseUNet",
2934
"Restormer",
3035
"RIDNetDenoiser",
3136
"SwinIR",
37+
"UFormer",
3238
"build_bm3d_denoiser",
3339
"build_ddpm_unet_denoiser",
3440
"build_dncnn_denoiser",
3541
"build_drunet_denoiser",
3642
"build_ffdnet_denoiser",
43+
"build_mirnet_denoiser",
44+
"build_mprnet_denoiser",
3745
"build_nafnet_denoiser",
3846
"build_noise2noise_denoiser",
3947
"build_restormer_denoiser",
4048
"build_ridnet_denoiser",
4149
"build_swinir_denoiser",
50+
"build_uformer_denoiser",
4251
]

dlhub/vision/denoising/mirnet.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
from torch import nn
5+
import torch.nn.functional as F
6+
7+
from ._utils import pad_to_multiple, unpad
8+
9+
10+
def _act() -> nn.Module:
    """Create the activation layer used throughout this module.

    Returns:
        A new ``nn.ReLU`` configured to operate in place.
    """
    relu = nn.ReLU(inplace=True)
    return relu
12+
13+
14+
class ConvBlock(nn.Module):
    """A biased 2-D convolution followed by an optional activation.

    The convolution is padded with ``kernel_size // 2`` on every side, which
    keeps the spatial size unchanged for odd kernel sizes.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Side length of the square kernel.
        act: When true, apply the module's shared activation after the
            convolution; otherwise the output is passed through unchanged.
    """

    def __init__(self, in_ch: int, out_ch: int, *, kernel_size: int = 3, act: bool = True) -> None:
        super().__init__()
        ksize = int(kernel_size)
        self.conv = nn.Conv2d(
            int(in_ch),
            int(out_ch),
            kernel_size=ksize,
            padding=ksize // 2,
            bias=True,
        )
        if bool(act):
            self.act = _act()
        else:
            self.act = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` and apply the configured activation."""
        features = self.conv(x)
        return self.act(features)
24+
25+
26+
class ChannelAttention(nn.Module):
    """Re-weight channels using gates computed from globally pooled features.

    Each feature map is averaged down to a single value, passed through a
    two-layer 1x1-convolution bottleneck ending in a sigmoid, and the
    resulting per-channel gate multiplies the input.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck divisor. The hidden width is
            ``channels // reduction`` but never below 8; values below 1 are
            treated as 1.
    """

    def __init__(self, channels: int, *, reduction: int = 8) -> None:
        super().__init__()
        n_ch = int(channels)
        ratio = max(1, int(reduction))
        squeezed = max(8, n_ch // ratio)

        self.pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(n_ch, squeezed, kernel_size=1, bias=True),
            _act(),
            nn.Conv2d(squeezed, n_ch, kernel_size=1, bias=True),
            nn.Sigmoid(),
        ]
        self.mlp = nn.Sequential(*gate_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled channel-wise by its attention gates."""
        pooled = self.pool(x)
        gate = self.mlp(pooled)
        return x * gate
43+
44+
45+
class DualAttentionUnit(nn.Module):
    """A small residual block with channel attention (toy DU/DAU).

    Two 3x3 convolutions (activation after the first only) feed a channel
    attention gate; the gated result is added back to the input.

    Args:
        channels: Number of channels, identical for input and output.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        n_ch = int(channels)
        self.conv1 = ConvBlock(n_ch, n_ch, kernel_size=3, act=True)
        self.conv2 = ConvBlock(n_ch, n_ch, kernel_size=3, act=False)
        self.ca = ChannelAttention(n_ch, reduction=8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the attention-gated convolutional residual."""
        residual = self.conv1(x)
        residual = self.conv2(residual)
        residual = self.ca(residual)
        return x + residual
59+
60+
61+
class SKFF(nn.Module):
    """Selective kernel feature fusion (toy, channel-wise softmax).

    The branches are summed, globally pooled and squeezed through a shared
    1x1 convolution. One 1x1 convolution per branch then produces per-channel
    logits; a softmax across branches turns them into mixing weights for the
    final weighted sum.

    Args:
        channels: Number of channels in every branch.
        num_branches: Number of feature maps to fuse; must be greater than 1.
        reduction: Bottleneck divisor. The hidden width is
            ``channels // reduction`` but never below 8; values below 1 are
            treated as 1.

    Raises:
        ValueError: If ``num_branches`` is not greater than 1.
    """

    def __init__(self, channels: int, num_branches: int, *, reduction: int = 8) -> None:
        super().__init__()
        n_ch = int(channels)
        n_branches = int(num_branches)
        if n_branches <= 1:
            raise ValueError("num_branches must be > 1")
        squeezed = max(8, n_ch // max(1, int(reduction)))

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(nn.Conv2d(n_ch, squeezed, 1, bias=True), _act())
        # One projection head per branch, each mapping the shared descriptor
        # back to per-channel logits.
        self.fcs = nn.ModuleList(
            nn.Conv2d(squeezed, n_ch, 1, bias=True) for _ in range(n_branches)
        )

    def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
        """Fuse ``feats`` into a single tensor of the same shape.

        Args:
            feats: One feature map per branch, all with identical shapes.

        Returns:
            The softmax-weighted sum of the branches.

        Raises:
            ValueError: If the number of feature maps does not match the
                number of branches, or if their shapes differ.
        """
        if len(feats) != len(self.fcs):
            raise ValueError(f"Expected {len(self.fcs)} feature maps, got {len(feats)}")

        reference = feats[0]
        if any(f.shape != reference.shape for f in feats[1:]):
            raise ValueError("All SKFF branches must have the same shape")

        # Element-wise sum of all branches, accumulated from a zero tensor.
        combined = sum(feats, torch.zeros_like(reference))

        descriptor = self.fc(self.pool(combined))  # (B, hidden, 1, 1)
        branch_logits = torch.stack(
            [head(descriptor) for head in self.fcs], dim=1
        )  # (B, Bn, C, 1, 1)
        weights = torch.softmax(branch_logits, dim=1)

        weighted = (f * weights[:, i] for i, f in enumerate(feats))
        return sum(weighted, torch.zeros_like(reference))
94+
95+
96+
class MultiScaleResidualBlock(nn.Module):
    """Toy MIRNet multi-scale residual block: 3-scale processing + SKFF + residual.

    The input is processed at full, 1/2 and 1/4 resolution by three separate
    dual attention units. The two coarse results are upsampled back to the
    input size, the three branches are fused with SKFF, passed through a 3x3
    convolution, and added to the input.

    Args:
        channels: Number of channels, identical for input and output.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        c = int(channels)
        self.dau1 = DualAttentionUnit(c)
        self.dau2 = DualAttentionUnit(c)
        self.dau3 = DualAttentionUnit(c)
        self.fuse = SKFF(c, 3, reduction=8)
        self.out = ConvBlock(c, c, kernel_size=3, act=False)

    @staticmethod
    def _coarse_branch(
        x: torch.Tensor,
        unit: nn.Module,
        factor: int,
        size: tuple[int, int],
    ) -> torch.Tensor:
        """Average-pool ``x`` by ``factor``, apply ``unit``, upsample to ``size``."""
        # ceil_mode=False floors the pooled size, so trailing rows/columns that
        # do not fill a whole window are dropped; nearest-neighbour upsampling
        # to the explicit target size restores the original resolution.
        pooled = F.avg_pool2d(x, kernel_size=factor, stride=factor, ceil_mode=False)
        return F.interpolate(unit(pooled), size=size, mode="nearest")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the fused multi-scale residual.

        Args:
            x: A 4-D tensor; the last two dimensions are treated as the
                spatial size that the coarse branches are restored to.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Unpacking four values keeps the original 4-D requirement; only the
        # spatial extent is actually needed.
        _, _, h, w = x.shape

        full = self.dau1(x)
        half = self._coarse_branch(x, self.dau2, 2, (h, w))
        quarter = self._coarse_branch(x, self.dau3, 4, (h, w))

        fused = self.out(self.fuse([full, half, quarter]))
        return x + fused
124+
125+
126+
class MIRNet(nn.Module):
    """MIRNet-style denoiser (toy-first, pure torch).

    Notes:
        - This is a simplified MIRNet-inspired architecture with multi-scale
          residual blocks and SKFF-like fusion.
        - It performs residual learning: output = input + predicted_residual.

    Args:
        in_channels: Number of image channels; must be positive.
        width: Number of feature channels inside the network; at least 8.
        depth: Number of stacked multi-scale residual blocks; must be positive.

    Raises:
        ValueError: If any of the arguments is outside its allowed range.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        width: int = 32,
        depth: int = 5,
    ) -> None:
        super().__init__()
        # Coerce everything first so conversion errors surface before the
        # range checks, then validate in a fixed order.
        n_in, n_feat, n_blocks = int(in_channels), int(width), int(depth)
        if n_in <= 0:
            raise ValueError("in_channels must be > 0")
        if n_feat < 8:
            raise ValueError("width must be >= 8")
        if n_blocks <= 0:
            raise ValueError("depth must be > 0")

        self.intro = nn.Conv2d(n_in, n_feat, kernel_size=3, padding=1, bias=True)
        self.body = nn.Sequential(
            *(MultiScaleResidualBlock(n_feat) for _ in range(n_blocks))
        )
        self.outro = nn.Conv2d(n_feat, n_in, kernel_size=3, padding=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Denoise a batch of images.

        Args:
            x: A 4-D tensor of shape (B, C, H, W); it is cast to float32.

        Returns:
            The result of ``unpad`` applied to the padded input plus the
            predicted residual.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        x = x.to(torch.float32)
        if x.ndim != 4:
            raise ValueError(f"Expected input shape (B, C, H, W), got {tuple(x.shape)}")

        # NOTE(review): pad_to_multiple/unpad live in ._utils; presumably they
        # reflect-pad H and W up to a multiple of 4 and crop back - confirm.
        padded, pad_hw = pad_to_multiple(x, 4, mode="reflect")
        features = self.intro(padded)
        features = self.body(features)
        restored = padded + self.outro(features)
        return unpad(restored, pad_hw)
166+
167+
168+
# Named presets: variant name -> MIRNet hyper-parameters (feature width and
# number of multi-scale residual blocks).
_VARIANTS: dict[str, dict] = {
    "mirnet_tiny": {
        "width": 24,
        "depth": 3,
    },
    "mirnet_small": {
        "width": 32,
        "depth": 5,
    },
    "mirnet_base": {
        "width": 48,
        "depth": 7,
    },
}
173+
174+
175+
def build_mirnet_denoiser(
    *,
    in_channels: int,
    variant: str = "mirnet_small",
) -> nn.Module:
    """Construct a MIRNet denoiser from a named preset.

    Args:
        in_channels: Number of image channels the model consumes and produces.
        variant: Preset name; matched case-insensitively after stripping
            surrounding whitespace.

    Returns:
        A ``MIRNet`` configured with the preset's width and depth.

    Raises:
        ValueError: If ``variant`` does not name a known preset.
    """
    key = str(variant).lower().strip()
    preset = _VARIANTS.get(key)
    if preset is None:
        raise ValueError(f"Unknown MIRNet variant: {variant!r}. Supported: {sorted(_VARIANTS)}")
    return MIRNet(
        in_channels=int(in_channels),
        width=int(preset["width"]),
        depth=int(preset["depth"]),
    )
185+
186+
187+
if __name__ == "__main__":
    # Smoke test: run the smallest variant forward and backward once on
    # synthetic noisy data and report the output shape.
    torch.manual_seed(0)
    clean = torch.rand(2, 1, 64, 64)
    noise = torch.randn_like(clean) * 0.1
    noisy = (clean + noise).clamp(0.0, 1.0)

    model = build_mirnet_denoiser(in_channels=1, variant="mirnet_tiny")
    denoised = model(noisy)
    print("mirnet_tiny", tuple(denoised.shape))

    mse = (denoised - clean).pow(2).mean()
    mse.backward()
    print("ok")
197+

0 commit comments

Comments
 (0)