1
1
import os , torch
2
- import segmentation_models_pytorch as smp
3
2
4
3
from .adscnet import ADSCNet
5
4
from .aglnet import AGLNet
38
37
from .sqnet import SQNet
39
38
from .stdc import STDC , LaplacianConv
40
39
from .swiftnet import SwiftNet
41
-
42
-
43
- decoder_hub = {'deeplabv3' :smp .DeepLabV3 , 'deeplabv3p' :smp .DeepLabV3Plus , 'fpn' :smp .FPN ,
44
- 'linknet' :smp .Linknet , 'manet' :smp .MAnet , 'pan' :smp .PAN , 'pspnet' :smp .PSPNet ,
45
- 'unet' :smp .Unet , 'unetpp' :smp .UnetPlusPlus }
40
+ from .model_registry import model_hub , aux_models , detail_head_models
46
41
47
42
48
43
def get_model (config ):
49
- model_hub = {'adscnet' :ADSCNet , 'aglnet' :AGLNet , 'bisenetv1' :BiSeNetv1 ,
50
- 'bisenetv2' :BiSeNetv2 , 'canet' :CANet , 'cfpnet' :CFPNet ,
51
- 'cgnet' :CGNet , 'contextnet' :ContextNet , 'dabnet' :DABNet ,
52
- 'ddrnet' :DDRNet , 'dfanet' :DFANet , 'edanet' :EDANet ,
53
- 'enet' :ENet , 'erfnet' :ERFNet , 'esnet' :ESNet ,
54
- 'espnet' :ESPNet , 'espnetv2' :ESPNetv2 , 'fanet' :FANet , 'farseenet' :FarSeeNet ,
55
- 'fastscnn' :FastSCNN , 'fddwnet' :FDDWNet , 'fpenet' :FPENet ,
56
- 'fssnet' :FSSNet , 'icnet' :ICNet , 'lednet' :LEDNet ,
57
- 'linknet' :LinkNet , 'lite_hrnet' :LiteHRNet , 'liteseg' :LiteSeg , 'mininet' :MiniNet ,
58
- 'mininetv2' :MiniNetv2 , 'ppliteseg' :PPLiteSeg , 'regseg' :RegSeg ,
59
- 'segnet' :SegNet , 'shelfnet' :ShelfNet , 'sqnet' :SQNet ,
60
- 'stdc' :STDC , 'swiftnet' :SwiftNet ,}
61
-
62
- # The following models currently support auxiliary heads
63
- aux_models = ['bisenetv2' , 'ddrnet' , 'icnet' ]
64
-
65
- # The following models currently support detail heads
66
- detail_head_models = ['stdc' ]
67
-
68
44
if config .model == 'smp' : # Use segmentation models pytorch
69
- if config .decoder not in decoder_hub :
70
- raise ValueError (f"Unsupported decoder type: { config .decoder } " )
45
+ from .smp_wrapper import get_smp_model
71
46
72
- model = decoder_hub [config .decoder ](encoder_name = config .encoder ,
73
- encoder_weights = config .encoder_weights ,
74
- in_channels = 3 , classes = config .num_class )
47
+ model = get_smp_model (config .encoder , config .decoder , config .encoder_weights , config .num_class )
75
48
76
49
elif config .model in model_hub .keys ():
77
- if config .model in aux_models :
78
- model = model_hub [config .model ](num_class = config .num_class , use_aux = config .use_aux )
79
- elif config .model in detail_head_models :
80
- model = model_hub [config .model ](num_class = config .num_class , use_detail_head = config .use_detail_head , use_aux = config .use_aux )
50
+ if config .model in aux_models : # models support auxiliary heads
51
+ if config .model in detail_head_models : # models support detail heads
52
+ model = model_hub [config .model ](num_class = config .num_class , use_detail_head = config .use_detail_head , use_aux = config .use_aux )
53
+ else :
54
+ model = model_hub [config .model ](num_class = config .num_class , use_aux = config .use_aux )
55
+
81
56
else :
82
57
if config .use_aux :
83
58
raise ValueError (f'Model { config .model } does not support auxiliary heads.\n ' )
84
- if config .use_detail_head :
85
- raise ValueError (f'Model { config .model } does not support detail heads.\n ' )
86
59
87
60
model = model_hub [config .model ](num_class = config .num_class )
88
61
@@ -92,16 +65,30 @@ def get_model(config):
92
65
return model
93
66
94
67
68
+ def list_available_models ():
69
+ model_list = list (model_hub .keys ())
70
+
71
+ try :
72
+ import segmentation_models_pytorch as smp
73
+ model_list .append ('smp' )
74
+ except :
75
+ pass
76
+
77
+ return model_list
78
+
79
+
95
80
def get_teacher_model (config , device ):
96
81
if config .kd_training :
97
82
if not os .path .isfile (config .teacher_ckpt ):
98
- raise ValueError (f'Could not find teacher checkpoint at path { config .teacher_ckpt } .' )
83
+ raise ValueError (f'Could not find teacher checkpoint at path { config .teacher_ckpt } .' )
99
84
100
- if config .teacher_decoder not in decoder_hub . keys () :
101
- raise ValueError ( f"Unsupported teacher decoder type: { config . teacher_decoder } " )
85
+ if config .teacher_model == 'smp' :
86
+ from . smp_wrapper import get_smp_model
102
87
103
- model = decoder_hub [config .teacher_decoder ](encoder_name = config .teacher_encoder ,
104
- encoder_weights = None , in_channels = 3 , classes = config .num_class )
88
+ model = get_smp_model (config .teacher_encoder , config .teacher_decoder , None , config .num_class )
89
+
90
+ else :
91
+ raise NotImplementedError ()
105
92
106
93
teacher_ckpt = torch .load (config .teacher_ckpt , map_location = torch .device ('cpu' ))
107
94
model .load_state_dict (teacher_ckpt ['state_dict' ])
0 commit comments