Skip to content

Commit 7448e45

Browse files
author
jhlu
committed
feat(vision): add eca/cbam resnet variants
1 parent 2549768 commit 7448e45

3 files changed

Lines changed: 330 additions & 1 deletion

File tree

dlhub/vision/backbones/cnn.py

Lines changed: 266 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
163163
return x * s
164164

165165

166+
class ECALayer(nn.Module):
167+
"""Efficient Channel Attention (ECA).
168+
169+
Minimal, CPU-friendly implementation:
170+
- Global average pool -> 1D conv over channels -> sigmoid gate.
171+
"""
172+
173+
def __init__(self, channels: int, *, kernel_size: int = 3) -> None:
174+
super().__init__()
175+
c = int(channels)
176+
k = int(kernel_size)
177+
if c <= 0:
178+
raise ValueError("channels must be > 0")
179+
if k <= 0 or k % 2 == 0:
180+
raise ValueError("kernel_size must be a positive odd integer")
181+
self.pool = nn.AdaptiveAvgPool2d((1, 1))
182+
self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=k // 2, bias=False)
183+
self.gate = nn.Sigmoid()
184+
185+
def forward(self, x: torch.Tensor) -> torch.Tensor:
186+
# x: (B, C, H, W)
187+
y = self.pool(x) # (B, C, 1, 1)
188+
y = y.squeeze(-1).transpose(1, 2) # (B, 1, C)
189+
y = self.conv(y)
190+
y = self.gate(y).transpose(1, 2).unsqueeze(-1) # (B, C, 1, 1)
191+
return x * y
192+
193+
194+
class _CBAMChannelAttention(nn.Module):
195+
def __init__(self, channels: int, *, reduction: int = 16) -> None:
196+
super().__init__()
197+
c = int(channels)
198+
r = max(1, int(reduction))
199+
hidden = max(8, c // r)
200+
201+
self.avg = nn.AdaptiveAvgPool2d((1, 1))
202+
self.max = nn.AdaptiveMaxPool2d((1, 1))
203+
self.mlp = nn.Sequential(
204+
nn.Conv2d(c, hidden, kernel_size=1, bias=True),
205+
nn.ReLU(inplace=True),
206+
nn.Conv2d(hidden, c, kernel_size=1, bias=True),
207+
)
208+
self.gate = nn.Sigmoid()
209+
210+
def forward(self, x: torch.Tensor) -> torch.Tensor:
211+
a = self.mlp(self.avg(x))
212+
m = self.mlp(self.max(x))
213+
return x * self.gate(a + m)
214+
215+
216+
class _CBAMSpatialAttention(nn.Module):
217+
def __init__(self, *, kernel_size: int = 7) -> None:
218+
super().__init__()
219+
k = int(kernel_size)
220+
if k <= 0 or k % 2 == 0:
221+
raise ValueError("kernel_size must be a positive odd integer")
222+
self.conv = nn.Conv2d(2, 1, kernel_size=k, padding=k // 2, bias=False)
223+
self.gate = nn.Sigmoid()
224+
225+
def forward(self, x: torch.Tensor) -> torch.Tensor:
226+
avg = x.mean(dim=1, keepdim=True)
227+
mx = x.amax(dim=1, keepdim=True)
228+
attn = self.conv(torch.cat([avg, mx], dim=1))
229+
return x * self.gate(attn)
230+
231+
232+
class CBAM(nn.Module):
233+
"""Convolutional Block Attention Module (CBAM), simplified."""
234+
235+
def __init__(self, channels: int, *, reduction: int = 16, spatial_kernel: int = 7) -> None:
236+
super().__init__()
237+
self.ca = _CBAMChannelAttention(int(channels), reduction=int(reduction))
238+
self.sa = _CBAMSpatialAttention(kernel_size=int(spatial_kernel))
239+
240+
def forward(self, x: torch.Tensor) -> torch.Tensor:
241+
x = self.ca(x)
242+
x = self.sa(x)
243+
return x
244+
245+
166246
class BasicBlock(nn.Module):
167247
expansion = 1
168248

@@ -196,6 +276,80 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
196276
return self.relu(out)
197277

198278

279+
class ECABasicBlock(nn.Module):
280+
expansion = 1
281+
282+
def __init__(
283+
self, in_ch: int, out_ch: int, stride: int, *, groups: int = 1, width_per_group: int = 64, eca_kernel: int = 3
284+
) -> None:
285+
super().__init__()
286+
_ = int(groups)
287+
_ = int(width_per_group)
288+
self.conv1 = _conv3x3(in_ch, out_ch, stride=stride, groups=1)
289+
self.bn1 = nn.BatchNorm2d(int(out_ch))
290+
self.relu = nn.ReLU(inplace=True)
291+
self.conv2 = _conv3x3(out_ch, out_ch, stride=1, groups=1)
292+
self.bn2 = nn.BatchNorm2d(int(out_ch))
293+
self.eca = ECALayer(int(out_ch), kernel_size=int(eca_kernel))
294+
295+
self.downsample: nn.Module | None = None
296+
if int(stride) != 1 or int(in_ch) != int(out_ch):
297+
self.downsample = nn.Sequential(
298+
_conv1x1(int(in_ch), int(out_ch), stride=int(stride)),
299+
nn.BatchNorm2d(int(out_ch)),
300+
)
301+
302+
def forward(self, x: torch.Tensor) -> torch.Tensor:
303+
identity = x
304+
305+
out = self.relu(self.bn1(self.conv1(x)))
306+
out = self.bn2(self.conv2(out))
307+
out = self.eca(out)
308+
309+
if self.downsample is not None:
310+
identity = self.downsample(identity)
311+
312+
out = out + identity
313+
return self.relu(out)
314+
315+
316+
class CBAMBasicBlock(nn.Module):
317+
expansion = 1
318+
319+
def __init__(
320+
self, in_ch: int, out_ch: int, stride: int, *, groups: int = 1, width_per_group: int = 64, reduction: int = 16
321+
) -> None:
322+
super().__init__()
323+
_ = int(groups)
324+
_ = int(width_per_group)
325+
self.conv1 = _conv3x3(in_ch, out_ch, stride=stride, groups=1)
326+
self.bn1 = nn.BatchNorm2d(int(out_ch))
327+
self.relu = nn.ReLU(inplace=True)
328+
self.conv2 = _conv3x3(out_ch, out_ch, stride=1, groups=1)
329+
self.bn2 = nn.BatchNorm2d(int(out_ch))
330+
self.cbam = CBAM(int(out_ch), reduction=int(reduction))
331+
332+
self.downsample: nn.Module | None = None
333+
if int(stride) != 1 or int(in_ch) != int(out_ch):
334+
self.downsample = nn.Sequential(
335+
_conv1x1(int(in_ch), int(out_ch), stride=int(stride)),
336+
nn.BatchNorm2d(int(out_ch)),
337+
)
338+
339+
def forward(self, x: torch.Tensor) -> torch.Tensor:
340+
identity = x
341+
342+
out = self.relu(self.bn1(self.conv1(x)))
343+
out = self.bn2(self.conv2(out))
344+
out = self.cbam(out)
345+
346+
if self.downsample is not None:
347+
identity = self.downsample(identity)
348+
349+
out = out + identity
350+
return self.relu(out)
351+
352+
199353
class SEBasicBlock(nn.Module):
200354
expansion = 1
201355

@@ -278,6 +432,109 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
278432
return self.relu(out)
279433

280434

435+
class ECABottleneck(nn.Module):
436+
expansion = 4
437+
438+
def __init__(
439+
self, in_ch: int, out_ch: int, stride: int, *, groups: int = 1, width_per_group: int = 64, eca_kernel: int = 3
440+
) -> None:
441+
super().__init__()
442+
g = int(groups)
443+
wpg = int(width_per_group)
444+
if g <= 0:
445+
raise ValueError("groups must be >= 1")
446+
if wpg <= 0:
447+
raise ValueError("width_per_group must be >= 1")
448+
449+
width = int(out_ch) * wpg // 64 * g
450+
width = max(g, width)
451+
452+
self.conv1 = _conv1x1(in_ch, width, stride=1)
453+
self.bn1 = nn.BatchNorm2d(width)
454+
self.conv2 = _conv3x3(width, width, stride=stride, groups=g)
455+
self.bn2 = nn.BatchNorm2d(width)
456+
self.conv3 = _conv1x1(width, int(out_ch) * self.expansion, stride=1)
457+
self.bn3 = nn.BatchNorm2d(int(out_ch) * self.expansion)
458+
self.eca = ECALayer(int(out_ch) * self.expansion, kernel_size=int(eca_kernel))
459+
self.relu = nn.ReLU(inplace=True)
460+
461+
self.downsample: nn.Module | None = None
462+
if int(stride) != 1 or int(in_ch) != int(out_ch) * self.expansion:
463+
self.downsample = nn.Sequential(
464+
_conv1x1(int(in_ch), int(out_ch) * self.expansion, stride=int(stride)),
465+
nn.BatchNorm2d(int(out_ch) * self.expansion),
466+
)
467+
468+
def forward(self, x: torch.Tensor) -> torch.Tensor:
469+
identity = x
470+
471+
out = self.relu(self.bn1(self.conv1(x)))
472+
out = self.relu(self.bn2(self.conv2(out)))
473+
out = self.bn3(self.conv3(out))
474+
out = self.eca(out)
475+
476+
if self.downsample is not None:
477+
identity = self.downsample(identity)
478+
479+
out = out + identity
480+
return self.relu(out)
481+
482+
483+
class CBAMBottleneck(nn.Module):
484+
expansion = 4
485+
486+
def __init__(
487+
self,
488+
in_ch: int,
489+
out_ch: int,
490+
stride: int,
491+
*,
492+
groups: int = 1,
493+
width_per_group: int = 64,
494+
reduction: int = 16,
495+
) -> None:
496+
super().__init__()
497+
g = int(groups)
498+
wpg = int(width_per_group)
499+
if g <= 0:
500+
raise ValueError("groups must be >= 1")
501+
if wpg <= 0:
502+
raise ValueError("width_per_group must be >= 1")
503+
504+
width = int(out_ch) * wpg // 64 * g
505+
width = max(g, width)
506+
507+
self.conv1 = _conv1x1(in_ch, width, stride=1)
508+
self.bn1 = nn.BatchNorm2d(width)
509+
self.conv2 = _conv3x3(width, width, stride=stride, groups=g)
510+
self.bn2 = nn.BatchNorm2d(width)
511+
self.conv3 = _conv1x1(width, int(out_ch) * self.expansion, stride=1)
512+
self.bn3 = nn.BatchNorm2d(int(out_ch) * self.expansion)
513+
self.cbam = CBAM(int(out_ch) * self.expansion, reduction=int(reduction))
514+
self.relu = nn.ReLU(inplace=True)
515+
516+
self.downsample: nn.Module | None = None
517+
if int(stride) != 1 or int(in_ch) != int(out_ch) * self.expansion:
518+
self.downsample = nn.Sequential(
519+
_conv1x1(int(in_ch), int(out_ch) * self.expansion, stride=int(stride)),
520+
nn.BatchNorm2d(int(out_ch) * self.expansion),
521+
)
522+
523+
def forward(self, x: torch.Tensor) -> torch.Tensor:
524+
identity = x
525+
526+
out = self.relu(self.bn1(self.conv1(x)))
527+
out = self.relu(self.bn2(self.conv2(out)))
528+
out = self.bn3(self.conv3(out))
529+
out = self.cbam(out)
530+
531+
if self.downsample is not None:
532+
identity = self.downsample(identity)
533+
534+
out = out + identity
535+
return self.relu(out)
536+
537+
281538
class SEBottleneck(nn.Module):
282539
expansion = 4
283540

@@ -486,6 +743,14 @@ def build_resnet_classifier(
486743
block: type[nn.Module] = BasicBlock
487744
elif name == "bottleneck":
488745
block = Bottleneck
746+
elif name == "eca_basic":
747+
block = ECABasicBlock
748+
elif name == "eca_bottleneck":
749+
block = ECABottleneck
750+
elif name == "cbam_basic":
751+
block = CBAMBasicBlock
752+
elif name == "cbam_bottleneck":
753+
block = CBAMBottleneck
489754
elif name == "se_basic":
490755
block = SEBasicBlock
491756
elif name == "se_bottleneck":
@@ -496,7 +761,7 @@ def build_resnet_classifier(
496761
block = PreActBottleneck
497762
else:
498763
raise ValueError(
499-
"Unknown ResNet variant. Supported: basic, bottleneck, se_basic, se_bottleneck, preact_basic, preact_bottleneck"
764+
"Unknown ResNet variant. Supported: basic, bottleneck, eca_basic, eca_bottleneck, cbam_basic, cbam_bottleneck, se_basic, se_bottleneck, preact_basic, preact_bottleneck"
500765
)
501766

502767
return ResNetClassifier(

dlhub/vision/local_zoo.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,44 @@ def _registry() -> dict[str, Builder]:
157157
width_per_group=64,
158158
)
159159

160+
# ECA/CBAM ResNet
161+
r["eca_resnet18"] = lambda cfg: build_resnet_classifier(
162+
in_channels=cfg.in_channels,
163+
num_classes=cfg.num_classes,
164+
layers=(2, 2, 2, 2),
165+
variant="eca_basic",
166+
width_mult=cfg.width_mult,
167+
dropout=cfg.dropout,
168+
)
169+
r["cbam_resnet18"] = lambda cfg: build_resnet_classifier(
170+
in_channels=cfg.in_channels,
171+
num_classes=cfg.num_classes,
172+
layers=(2, 2, 2, 2),
173+
variant="cbam_basic",
174+
width_mult=cfg.width_mult,
175+
dropout=cfg.dropout,
176+
)
177+
r["eca_resnet50"] = lambda cfg: build_resnet_classifier(
178+
in_channels=cfg.in_channels,
179+
num_classes=cfg.num_classes,
180+
layers=(3, 4, 6, 3),
181+
variant="eca_bottleneck",
182+
width_mult=cfg.width_mult,
183+
dropout=cfg.dropout,
184+
groups=1,
185+
width_per_group=64,
186+
)
187+
r["cbam_resnet50"] = lambda cfg: build_resnet_classifier(
188+
in_channels=cfg.in_channels,
189+
num_classes=cfg.num_classes,
190+
layers=(3, 4, 6, 3),
191+
variant="cbam_bottleneck",
192+
width_mult=cfg.width_mult,
193+
dropout=cfg.dropout,
194+
groups=1,
195+
width_per_group=64,
196+
)
197+
160198
# SE-ResNet
161199
for name, layers, variant in [
162200
("se_resnet50", (3, 4, 6, 3), "se_bottleneck"),
@@ -201,6 +239,26 @@ def _registry() -> dict[str, Builder]:
201239
groups=32,
202240
width_per_group=4,
203241
)
242+
r["eca_resnext50_32x4d"] = lambda cfg: build_resnet_classifier(
243+
in_channels=cfg.in_channels,
244+
num_classes=cfg.num_classes,
245+
layers=(3, 4, 6, 3),
246+
variant="eca_bottleneck",
247+
width_mult=cfg.width_mult,
248+
dropout=cfg.dropout,
249+
groups=32,
250+
width_per_group=4,
251+
)
252+
r["cbam_resnext50_32x4d"] = lambda cfg: build_resnet_classifier(
253+
in_channels=cfg.in_channels,
254+
num_classes=cfg.num_classes,
255+
layers=(3, 4, 6, 3),
256+
variant="cbam_bottleneck",
257+
width_mult=cfg.width_mult,
258+
dropout=cfg.dropout,
259+
groups=32,
260+
width_per_group=4,
261+
)
204262
r["resnext101_32x8d"] = lambda cfg: build_resnet_classifier(
205263
in_channels=cfg.in_channels,
206264
num_classes=cfg.num_classes,
@@ -563,6 +621,8 @@ def _registry() -> dict[str, Builder]:
563621
r["poolformer"] = r["poolformer_tiny"]
564622
r["gmlp"] = r["gmlp_tiny"]
565623
r["resmlp"] = r["resmlp_tiny"]
624+
r["eca_resnet"] = r["eca_resnet18"]
625+
r["cbam_resnet"] = r["cbam_resnet18"]
566626

567627
return r
568628

0 commit comments

Comments
 (0)