@@ -163,6 +163,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
163163 return x * s
164164
165165
class ECALayer(nn.Module):
    """Efficient Channel Attention (ECA) gate.

    Each channel is collapsed to one value by global average pooling. A
    bias-free 1D convolution then slides along the channel axis, and a sigmoid
    turns its output into per-channel multipliers that rescale the input.

    Args:
        channels: Channel count of the expected input. It is only validated
            (must be > 0); no layer in this module is sized by it.
        kernel_size: Width of the 1D convolution over the channel axis. Must
            be a positive odd integer.

    Raises:
        ValueError: If ``channels`` is not positive, or ``kernel_size`` is not
            a positive odd integer.
    """

    def __init__(self, channels: int, *, kernel_size: int = 3) -> None:
        super().__init__()
        num_channels = int(channels)
        width = int(kernel_size)
        if num_channels < 1:
            raise ValueError("channels must be > 0")
        if width < 1 or width % 2 != 1:
            raise ValueError("kernel_size must be a positive odd integer")

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # With an odd kernel, padding of width // 2 keeps the channel-axis
        # length unchanged, so the gate has exactly one entry per channel.
        self.conv = nn.Conv1d(1, 1, kernel_size=width, padding=width // 2, bias=False)
        self.gate = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale ``x`` (expected as (B, C, H, W)) by its channel gate."""
        pooled = self.pool(x)  # (B, C, 1, 1)
        # Lay the channels out as a length-C sequence with a single feature.
        channel_seq = pooled.squeeze(-1).transpose(-1, -2)  # (B, 1, C)
        weights = self.gate(self.conv(channel_seq))  # (B, 1, C)
        weights = weights.transpose(-1, -2).unsqueeze(-1)  # (B, C, 1, 1)
        return x * weights
192+
193+
class _CBAMChannelAttention(nn.Module):
    """Channel-attention stage of CBAM.

    The input is reduced to two per-channel descriptors (global average pool
    and global max pool). Both pass through the same two-layer 1x1-conv
    bottleneck, their outputs are summed, and a sigmoid of the sum scales the
    input channel-wise.

    Args:
        channels: Number of input channels; must be > 0.
        reduction: Divisor used to derive the bottleneck width
            (``channels // reduction``). Values below 1 are clamped to 1, and
            the bottleneck width is never smaller than 8.

    Raises:
        ValueError: If ``channels`` is not positive.
    """

    def __init__(self, channels: int, *, reduction: int = 16) -> None:
        super().__init__()
        c = int(channels)
        if c <= 0:
            # Same check and message as ECALayer: reject a bad channel count
            # here rather than letting it reach layer construction.
            raise ValueError("channels must be > 0")
        r = max(1, int(reduction))
        # Floor of 8 keeps the bottleneck from collapsing for narrow inputs.
        hidden = max(8, c // r)

        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.max = nn.AdaptiveMaxPool2d((1, 1))
        # One MLP instance, so both pooled descriptors share its weights.
        self.mlp = nn.Sequential(
            nn.Conv2d(c, hidden, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, c, kernel_size=1, bias=True),
        )
        self.gate = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the sigmoid of the summed MLP outputs."""
        a = self.mlp(self.avg(x))
        m = self.mlp(self.max(x))
        return x * self.gate(a + m)
214+
215+
class _CBAMSpatialAttention(nn.Module):
    """Spatial-attention stage of CBAM.

    The channel axis is summarised by its mean and its maximum; the two maps
    are stacked and fed to a bias-free 2-in/1-out convolution whose sigmoid
    output scales every channel of the input at each spatial position.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, *, kernel_size: int = 7) -> None:
        super().__init__()
        size = int(kernel_size)
        if size < 1 or size % 2 != 1:
            raise ValueError("kernel_size must be a positive odd integer")
        # Odd kernel with size // 2 padding leaves the spatial size unchanged.
        self.conv = nn.Conv2d(2, 1, kernel_size=size, padding=size // 2, bias=False)
        self.gate = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the spatial attention map."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.amax(dim=1, keepdim=True)
        descriptors = torch.cat([channel_mean, channel_max], dim=1)
        return x * self.gate(self.conv(descriptors))
230+
231+
class CBAM(nn.Module):
    """Convolutional Block Attention Module (CBAM), simplified.

    Applies channel attention first and spatial attention second.

    Args:
        channels: Channel count handed to the channel-attention stage.
        reduction: Bottleneck reduction handed to the channel-attention stage.
        spatial_kernel: Kernel size handed to the spatial-attention stage.
    """

    def __init__(self, channels: int, *, reduction: int = 16, spatial_kernel: int = 7) -> None:
        super().__init__()
        num_channels = int(channels)
        ratio = int(reduction)
        self.ca = _CBAMChannelAttention(num_channels, reduction=ratio)
        self.sa = _CBAMSpatialAttention(kernel_size=int(spatial_kernel))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after channel attention followed by spatial attention."""
        return self.sa(self.ca(x))
244+
245+
166246class BasicBlock (nn .Module ):
167247 expansion = 1
168248
@@ -196,6 +276,80 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
196276 return self .relu (out )
197277
198278
class ECABasicBlock(nn.Module):
    """Basic residual block with an ECA gate on the residual branch.

    Residual branch: ``_conv3x3`` -> BatchNorm -> ReLU -> ``_conv3x3`` ->
    BatchNorm -> ECA. The result is added to the shortcut and passed through a
    final ReLU. When ``stride != 1`` or ``in_ch != out_ch`` the shortcut is a
    ``_conv1x1`` + BatchNorm projection; otherwise it is the input itself.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        stride: Stride of the first conv and of the shortcut projection.
        groups: Coerced to ``int`` but otherwise unused; both convs are built
            with ``groups=1``.
        width_per_group: Coerced to ``int`` but otherwise unused.
        eca_kernel: Kernel size forwarded to ``ECALayer``.
    """

    expansion = 1

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        stride: int,
        *,
        groups: int = 1,
        width_per_group: int = 64,
        eca_kernel: int = 3,
    ) -> None:
        super().__init__()
        # These keywords also appear on the bottleneck blocks; here they are
        # only coerced, which still rejects values that int() cannot convert.
        int(groups)
        int(width_per_group)
        planes = int(out_ch)

        self.conv1 = _conv3x3(in_ch, out_ch, stride=stride, groups=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = _conv3x3(out_ch, out_ch, stride=1, groups=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.eca = ECALayer(planes, kernel_size=int(eca_kernel))

        # The shortcut must match the residual branch in both spatial size
        # and channel count before the two can be summed.
        if int(stride) != 1 or int(in_ch) != planes:
            shortcut: nn.Module | None = nn.Sequential(
                _conv1x1(int(in_ch), planes, stride=int(stride)),
                nn.BatchNorm2d(planes),
            )
        else:
            shortcut = None
        self.downsample = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(eca(residual(x)) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.eca(self.bn2(self.conv2(residual)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)
314+
315+
class CBAMBasicBlock(nn.Module):
    """Basic residual block with a CBAM module on the residual branch.

    Residual branch: ``_conv3x3`` -> BatchNorm -> ReLU -> ``_conv3x3`` ->
    BatchNorm -> CBAM. The result is added to the shortcut and passed through
    a final ReLU. When ``stride != 1`` or ``in_ch != out_ch`` the shortcut is
    a ``_conv1x1`` + BatchNorm projection; otherwise it is the input itself.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        stride: Stride of the first conv and of the shortcut projection.
        groups: Coerced to ``int`` but otherwise unused; both convs are built
            with ``groups=1``.
        width_per_group: Coerced to ``int`` but otherwise unused.
        reduction: Channel-attention reduction forwarded to ``CBAM``. The
            spatial kernel is left at ``CBAM``'s default.
    """

    expansion = 1

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        stride: int,
        *,
        groups: int = 1,
        width_per_group: int = 64,
        reduction: int = 16,
    ) -> None:
        super().__init__()
        # These keywords also appear on the bottleneck blocks; here they are
        # only coerced, which still rejects values that int() cannot convert.
        int(groups)
        int(width_per_group)
        planes = int(out_ch)

        self.conv1 = _conv3x3(in_ch, out_ch, stride=stride, groups=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = _conv3x3(out_ch, out_ch, stride=1, groups=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.cbam = CBAM(planes, reduction=int(reduction))

        # The shortcut must match the residual branch in both spatial size
        # and channel count before the two can be summed.
        if int(stride) != 1 or int(in_ch) != planes:
            shortcut: nn.Module | None = nn.Sequential(
                _conv1x1(int(in_ch), planes, stride=int(stride)),
                nn.BatchNorm2d(planes),
            )
        else:
            shortcut = None
        self.downsample = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(cbam(residual(x)) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.cbam(self.bn2(self.conv2(residual)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)
351+
352+
199353class SEBasicBlock (nn .Module ):
200354 expansion = 1
201355
@@ -278,6 +432,109 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
278432 return self .relu (out )
279433
280434
class ECABottleneck(nn.Module):
    """Bottleneck residual block with an ECA gate on the residual branch.

    Residual branch: ``_conv1x1`` -> BN -> ReLU -> grouped ``_conv3x3`` (this
    one carries the stride) -> BN -> ReLU -> ``_conv1x1`` -> BN -> ECA. The
    block emits ``out_ch * expansion`` channels. The result is added to the
    shortcut and passed through a final ReLU. The shortcut is a ``_conv1x1``
    + BatchNorm projection whenever ``stride != 1`` or ``in_ch`` differs from
    the block's output channel count.

    Args:
        in_ch: Input channel count.
        out_ch: Base channel count; the block outputs ``out_ch * expansion``.
        stride: Stride of the middle conv and of the shortcut projection.
        groups: Group count of the middle conv; must be >= 1.
        width_per_group: Scales the inner width relative to a base of 64;
            must be >= 1.
        eca_kernel: Kernel size forwarded to ``ECALayer``.

    Raises:
        ValueError: If ``groups`` or ``width_per_group`` is not positive.
    """

    expansion = 4

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        stride: int,
        *,
        groups: int = 1,
        width_per_group: int = 64,
        eca_kernel: int = 3,
    ) -> None:
        super().__init__()
        num_groups = int(groups)
        base_width = int(width_per_group)
        if num_groups < 1:
            raise ValueError("groups must be >= 1")
        if base_width < 1:
            raise ValueError("width_per_group must be >= 1")

        # Inner width is a multiple of num_groups; the max() keeps at least
        # one channel per group when the floor division yields zero.
        inner = max(num_groups, int(out_ch) * base_width // 64 * num_groups)
        expanded = int(out_ch) * self.expansion

        self.conv1 = _conv1x1(in_ch, inner, stride=1)
        self.bn1 = nn.BatchNorm2d(inner)
        self.conv2 = _conv3x3(inner, inner, stride=stride, groups=num_groups)
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv3 = _conv1x1(inner, expanded, stride=1)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.eca = ECALayer(expanded, kernel_size=int(eca_kernel))
        self.relu = nn.ReLU(inplace=True)

        # The shortcut must match the residual branch in both spatial size
        # and channel count before the two can be summed.
        if int(stride) != 1 or int(in_ch) != expanded:
            shortcut: nn.Module | None = nn.Sequential(
                _conv1x1(int(in_ch), expanded, stride=int(stride)),
                nn.BatchNorm2d(expanded),
            )
        else:
            shortcut = None
        self.downsample = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(eca(residual(x)) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.eca(self.bn3(self.conv3(residual)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)
481+
482+
class CBAMBottleneck(nn.Module):
    """Bottleneck residual block with a CBAM module on the residual branch.

    Residual branch: ``_conv1x1`` -> BN -> ReLU -> grouped ``_conv3x3`` (this
    one carries the stride) -> BN -> ReLU -> ``_conv1x1`` -> BN -> CBAM. The
    block emits ``out_ch * expansion`` channels. The result is added to the
    shortcut and passed through a final ReLU. The shortcut is a ``_conv1x1``
    + BatchNorm projection whenever ``stride != 1`` or ``in_ch`` differs from
    the block's output channel count.

    Args:
        in_ch: Input channel count.
        out_ch: Base channel count; the block outputs ``out_ch * expansion``.
        stride: Stride of the middle conv and of the shortcut projection.
        groups: Group count of the middle conv; must be >= 1.
        width_per_group: Scales the inner width relative to a base of 64;
            must be >= 1.
        reduction: Channel-attention reduction forwarded to ``CBAM``. The
            spatial kernel is left at ``CBAM``'s default.

    Raises:
        ValueError: If ``groups`` or ``width_per_group`` is not positive.
    """

    expansion = 4

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        stride: int,
        *,
        groups: int = 1,
        width_per_group: int = 64,
        reduction: int = 16,
    ) -> None:
        super().__init__()
        num_groups = int(groups)
        base_width = int(width_per_group)
        if num_groups < 1:
            raise ValueError("groups must be >= 1")
        if base_width < 1:
            raise ValueError("width_per_group must be >= 1")

        # Inner width is a multiple of num_groups; the max() keeps at least
        # one channel per group when the floor division yields zero.
        inner = max(num_groups, int(out_ch) * base_width // 64 * num_groups)
        expanded = int(out_ch) * self.expansion

        self.conv1 = _conv1x1(in_ch, inner, stride=1)
        self.bn1 = nn.BatchNorm2d(inner)
        self.conv2 = _conv3x3(inner, inner, stride=stride, groups=num_groups)
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv3 = _conv1x1(inner, expanded, stride=1)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.cbam = CBAM(expanded, reduction=int(reduction))
        self.relu = nn.ReLU(inplace=True)

        # The shortcut must match the residual branch in both spatial size
        # and channel count before the two can be summed.
        if int(stride) != 1 or int(in_ch) != expanded:
            shortcut: nn.Module | None = nn.Sequential(
                _conv1x1(int(in_ch), expanded, stride=int(stride)),
                nn.BatchNorm2d(expanded),
            )
        else:
            shortcut = None
        self.downsample = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(cbam(residual(x)) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.cbam(self.bn3(self.conv3(residual)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)
536+
537+
281538class SEBottleneck (nn .Module ):
282539 expansion = 4
283540
@@ -486,6 +743,14 @@ def build_resnet_classifier(
486743 block : type [nn .Module ] = BasicBlock
487744 elif name == "bottleneck" :
488745 block = Bottleneck
746+ elif name == "eca_basic" :
747+ block = ECABasicBlock
748+ elif name == "eca_bottleneck" :
749+ block = ECABottleneck
750+ elif name == "cbam_basic" :
751+ block = CBAMBasicBlock
752+ elif name == "cbam_bottleneck" :
753+ block = CBAMBottleneck
489754 elif name == "se_basic" :
490755 block = SEBasicBlock
491756 elif name == "se_bottleneck" :
@@ -496,7 +761,7 @@ def build_resnet_classifier(
496761 block = PreActBottleneck
497762 else :
498763 raise ValueError (
499- "Unknown ResNet variant. Supported: basic, bottleneck, se_basic, se_bottleneck, preact_basic, preact_bottleneck"
764+ "Unknown ResNet variant. Supported: basic, bottleneck, eca_basic, eca_bottleneck, cbam_basic, cbam_bottleneck, se_basic, se_bottleneck, preact_basic, preact_bottleneck"
500765 )
501766
502767 return ResNetClassifier (
0 commit comments