-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAdd_FPN.py
More file actions
72 lines (57 loc) · 2.11 KB
/
Add_FPN.py
File metadata and controls
72 lines (57 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Dict, Optional, List
import torch.nn as nn
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
from torch import nn, Tensor
from torchvision.models._utils import IntermediateLayerGetter
class Effnet_BackboneWithFPN(nn.Module):
    """Attach a Feature Pyramid Network to an EfficientNet-style backbone.

    The backbone must expose ``conv_stem``, ``bn1`` and ``blocks``
    attributes; ``act1`` is used when present. The stem layers are applied
    first, then selected children of ``backbone.blocks`` are tapped through
    ``IntermediateLayerGetter`` and their outputs are passed to a
    ``FeaturePyramidNetwork``.

    Args:
        backbone: Network providing the stem layers and ``blocks``.
        return_layers: Maps names of direct children of ``backbone.blocks``
            to the keys under which their outputs are returned.
        in_channels_list: Channel count of each tapped feature map, in the
            order the maps are produced.
        out_channels: Channel count of every FPN output map.
        extra_blocks: Optional extra FPN block. ``LastLevelMaxPool`` is used
            when this is ``None``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
    ) -> None:
        super().__init__()

        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        # ``back`` and ``body`` reference modules that ``backbone`` already
        # owns. All three attributes are kept (in this order) so existing
        # attribute names and the parameter keys derived from them stay
        # available to callers and saved checkpoints.
        self.backbone = backbone
        self.back = self.backbone.blocks
        self.body = IntermediateLayerGetter(self.back, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
        )
        # Exposed as a plain attribute for code that needs the FPN width.
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the stem, the tapped blocks and the FPN.

        Args:
            x: Input passed to ``backbone.conv_stem``.

        Returns:
            The mapping produced by the FPN for the tapped feature maps.
        """
        stem = self.backbone
        x = stem.conv_stem(x)
        x = stem.bn1(x)

        # Some backbones fold the stem activation into ``bn1`` and expose no
        # separate ``act1`` (NOTE(review): believed to be the case for newer
        # timm releases - confirm). Skip the step instead of raising
        # AttributeError; backbones that do have ``act1`` behave as before.
        stem_act = getattr(stem, "act1", None)
        if stem_act is not None:
            x = stem_act(x)

        features = self.body(x)
        return self.fpn(features)
class Densenet_BackboneWithFPN(nn.Module):
    """Attach a Feature Pyramid Network to a backbone exposing ``features``.

    Selected children of ``backbone.features`` are tapped through
    ``IntermediateLayerGetter`` and their outputs are passed to a
    ``FeaturePyramidNetwork``.

    Args:
        backbone: Network whose ``features`` attribute holds the layers to tap.
        return_layers: Maps names of direct children of ``backbone.features``
            to the keys under which their outputs are returned.
        in_channels_list: Channel count of each tapped feature map, in the
            order the maps are produced.
        out_channels: Channel count of every FPN output map.
        extra_blocks: Optional extra FPN block. ``LastLevelMaxPool`` is used
            when this is ``None``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
    ) -> None:
        super().__init__()

        # Fall back to the max-pool extra level when the caller gives none.
        pyramid_tail = LastLevelMaxPool() if extra_blocks is None else extra_blocks

        # Registration order (back, body, fpn) is kept as-is; ``body`` reuses
        # layers that ``back`` already owns.
        self.back = backbone
        self.body = IntermediateLayerGetter(
            self.back.features,
            return_layers=return_layers,
        )
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=pyramid_tail,
        )
        # Exposed as a plain attribute for code that needs the FPN width.
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Return the FPN outputs computed from the tapped feature maps of ``x``."""
        tapped = self.body(x)
        return self.fpn(tapped)