@@ -85,29 +85,31 @@ def quant_module(module: nn.Module, c_activation_val: float = 8.0, c_weight_val:
8585 m .output_quantizer .data_bits = out_bits
8686 m .output_quantizer .clamp_activation_value = c_activation_val
8787
88- def constrain (model : nn .Module , config_file : str = None , disable_module = None , disable_submodel = [] ):
88+ def constrain (model : nn .Module , config_file : str = None , disable_module = None , disable_submodel = None ):
8989 c_configs = QUANT_CONFIGS
9090 if config_file is not None :
9191 c_configs ._load_from_yaml (config_file )
92+ disabled_submodels = [] if disable_submodel is None else list (disable_submodel ) # None means "nothing disabled"; list() makes a local copy of the caller's sequence
9293
9394 if disable_module is not None :
9495 for name in disable_module :
9596 if _CMODULE_TABLE .get (name , None ) is not None :
9697 _CMODULE_TABLE .pop (name )
9798
9899 for name , m in model .named_modules ():
99- if any (name .startswith (p ) for p in disable_submodel ): # disable_submodel: skip every module whose name starts with one of the listed prefixes
100+ if any (name .startswith (p ) for p in disabled_submodels ): # disable_submodel: skip every module whose name starts with one of the listed prefixes
100101 continue
101102 _constrain_submodule (model , name , m , c_configs .clamp_info .to_dict ())
102103
103104 model .to (c_configs .device )
104105 return model
105106
106- def init (model : nn .Module , config_file : str = None , disable_module = None , disable_submodel = [] ):
107+ def init (model : nn .Module , config_file : str = None , disable_module = None , disable_submodel = None ):
107108
108109 q_configs = QUANT_CONFIGS
109110 if config_file is not None :
110111 q_configs ._load_from_yaml (config_file )
112+ disabled_submodels = [] if disable_submodel is None else list (disable_submodel )
111113
112114 if disable_module is not None :
113115 for name in disable_module :
@@ -118,15 +120,15 @@ def init(model: nn.Module, config_file: str = None, disable_module=None, disable
118120 # model = _replace_ops(traced_model, q_configs)
119121
120122 for name , m in model .named_modules ():
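121- if any (name .startswith (p ) for p in disable_submodel ): # disable_submodel: exclude from quantization every module whose name starts with one of the listed prefixes
123+ if any (name .startswith (p ) for p in disabled_submodels ): # disable_submodel: exclude from quantization every module whose name starts with one of the listed prefixes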
122124 continue
123125
124126 m .register_forward_pre_hook (hook_pre_forward )
125127 m .register_forward_hook (hook_forward )
126128
127129 is_replaced = _quantize_submodule (model , name , m , weights_cfg = q_configs .quant_info .to_dict (), activations_cfg = q_configs .quant_info .to_dict (), bias_cfg = q_configs .quant_info .to_dict (), constrain = q_configs .clamp_info .to_dict ())
128130 if is_replaced :
129- disable_submodel .append (name )
131+ disabled_submodels .append (name )
130132
131133 def quant_tensor_pre_hook (state_dict , prefix , local_metadata , strict , missing_keys , unexpected_keys , error_msgs ):
132134 quantizer_tags = (
@@ -192,6 +194,14 @@ def quant_tensor_layer(module, prefix=''):
192194 iq_layer .training = model .training
193195 # iq_layer = iq_layer.to(device)
194196 setattr (module , input_name , iq_layer )
197+ elif '_qgelu_' in input_name :
198+ iq_layer = QGelu (activate_config = activate_cfg , num_input = 1 )
199+ iq_layer .training = model .training
200+ setattr (module , input_name , iq_layer )
201+ elif '_qswish_' in input_name :
202+ iq_layer = QSwish (activate_config = activate_cfg , num_input = 1 )
203+ iq_layer .training = model .training
204+ setattr (module , input_name , iq_layer )
195205 elif '_qsoftmax_' in input_name :
196206 iq_layer = QSoftmax (activate_config = activate_cfg , num_input = 1 )
197207 iq_layer .training = model .training
0 commit comments