99import torch
1010
1111from fast_llm import __version__
12+ from fast_llm .config import MISSING
1213from fast_llm .engine .base_model .config import BaseModelArchitectureConfig
1314from fast_llm .engine .checkpoint .config import (
1415 CheckpointLoadConfig ,
2425logger = logging .getLogger (__name__ )
2526
2627
27- @dataclasses .dataclass
28- class ParamConverter :
29- fast_llm_name : tuple [str , ...] | None
30- export_name : tuple [str , ...] | str | None
28+ @dataclasses .dataclass ( kw_only = True )
29+ class ParamConverter ( abc . ABC ) :
30+ fast_llm_names : tuple [tuple [ str , ...], ...] = () # Array of fast-llm names, in nested (tuple) format.
31+ export_names : tuple [tuple [ str , ...], ...] = () # Array of export names, in nested (tuple) format.
3132
32- def export_param (self , fast_llm_value ):
33- return fast_llm_value
33+ @abc .abstractmethod
34+ def export_params (self , fast_llm_values : tuple [typing .Any , ...]) -> tuple [typing .Any , ...]:
35+ pass
36+
37+ @abc .abstractmethod
38+ def import_params (self , export_values : tuple [typing .Any , ...]) -> tuple [typing .Any , ...]:
39+ pass
40+
41+
42+ @dataclasses .dataclass (kw_only = True )
43+ class RenameParamConverter (ParamConverter ):
3444
35- def import_param (self , export_value ):
36- return export_value
45+ def __post_init__ (self ):
46+ Assert .eq (len (self .fast_llm_names ), 1 )
47+ Assert .eq (len (self .export_names ), 1 )
3748
49+ def export_params (self , fast_llm_values ):
50+ return fast_llm_values
3851
39- @dataclasses .dataclass
52+ def import_params (self , export_values ):
53+ return export_values
54+
55+
56+ # def __repr__(self):
57+ # return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})"
58+
59+
60+ @dataclasses .dataclass (kw_only = True )
4061class ConstantImportParamConverter (ParamConverter ):
41- fast_llm_value : typing .Any
62+ fast_llm_value : typing .Any = MISSING
63+
64+ def __post_init__ (self ):
65+ Assert .eq (len (self .fast_llm_names ), 1 )
66+ Assert .eq (len (self .export_names ), 0 )
4267
43- def export_param (self , fast_llm_value ):
44- Assert .eq (fast_llm_value , self .fast_llm_value )
68+ def export_params (self , fast_llm_values ):
69+ Assert .eq (fast_llm_values [0 ], self .fast_llm_value )
70+ return ()
4571
46- def import_param (self , export_value ):
47- return self .fast_llm_value
72+ def import_params (self , export_values ):
73+ return ( self .fast_llm_value ,)
4874
4975
50- @dataclasses .dataclass
76+ @dataclasses .dataclass ( kw_only = True )
5177class ConstantExportParamConverter (ParamConverter ):
52- export_value : typing .Any
78+ export_value : typing .Any = MISSING
5379
54- def export_param (self , fast_llm_value ):
55- return self .export_value
80+ def __post_init__ (self ):
81+ Assert .eq (len (self .fast_llm_names ), 0 )
82+ Assert .eq (len (self .export_names ), 1 )
5683
57- def import_param (self , export_value ):
58- Assert .eq (export_value , self .export_value )
84+ def export_params (self , fast_llm_values ):
85+ return (self .export_value ,)
86+
87+ def import_params (self , export_values ):
88+ Assert .eq (export_values [0 ], self .export_value )
89+ return ()
5990
6091
61- @dataclasses .dataclass
92+ @dataclasses .dataclass ( kw_only = True )
6293class IgnoreImportParamConverter (ParamConverter ):
63- ignore_export_value : typing .Any
94+ ignore_export_value : typing .Any = MISSING
6495
65- def export_param (self , fast_llm_value ):
66- pass
96+ def __post_init__ (self ):
97+ Assert .eq (len (self .fast_llm_names ), 0 )
98+ Assert .eq (len (self .export_names ), 1 )
6799
68- def import_param (self , export_value ):
69- if export_value is not self .ignore_export_value :
100+ def export_params (self , fast_llm_values ):
101+ return (MISSING ,)
102+
103+ def import_params (self , export_values ):
104+ if export_values [0 ] not in (self .ignore_export_value , MISSING ):
70105 logger .warning (
71- f"The configuration parameter `{ self .export_name } ={ export_value } ` is ignored during conversion."
106+ f"The configuration parameter `{ self .export_names [ 0 ] } ={ export_values [ 0 ] } ` is ignored during conversion."
72107 f" If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration."
73108 )
109+ return ()
74110
75111
76- @dataclasses .dataclass
112+ @dataclasses .dataclass ( kw_only = True )
77113class MappedConfigParamConverter (ParamConverter ):
78- fast_llm_value : typing .Callable
79- export_value : typing .Callable
114+ fast_llm_value : typing .Callable = lambda x : x
115+ export_value : typing .Callable = lambda x : x
116+
117+ def __post_init__ (self ):
118+ Assert .eq (len (self .fast_llm_names ), 1 )
119+ Assert .eq (len (self .export_names ), 1 )
80120
81- def export_param (self , fast_llm_value ):
82- return self .export_value (fast_llm_value )
121+ def export_params (self , fast_llm_values ):
122+ return ( self .export_value (fast_llm_values [ 0 ]), )
83123
84- def import_param (self , export_value ):
85- return self .fast_llm_value (export_value )
124+ def import_params (self , export_values ):
125+ return ( self .fast_llm_value (export_values [ 0 ]), )
86126
87127
88128class WeightConverter :
@@ -197,13 +237,18 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
197237 # TODO v0.3: not used in this class
198238 exported_config = {}
199239 for converter in cls ._get_config_converters ():
200- value = converter .export_param (
201- None
202- if converter .fast_llm_name is None
203- else cls ._get_fast_llm_attribute (config , converter .fast_llm_name ) # Noqa
204- )
205- if converter .export_name is not None :
206- set_nested_dict_value (exported_config , converter .export_name , value )
240+ try :
241+ values = converter .export_params (
242+ tuple (
243+ cls ._get_fast_llm_attribute (config , fast_llm_name )
244+ for fast_llm_name in converter .fast_llm_names
245+ )
246+ )
247+ for export_name , value in zip (converter .export_names , values , strict = True ):
248+ if value is not MISSING :
249+ set_nested_dict_value (exported_config , export_name , value )
250+ except Exception as e :
251+ raise RuntimeError (f"Config conversion failed for converter { converter } " , * e .args )
207252
208253 return exported_config # Noqa
209254
@@ -214,12 +259,25 @@ def _import_config(
214259 kwargs = {}
215260 for converter in cls ._get_config_converters ():
216261 try :
217- value = None if converter .export_name is None else get_nested_dict_value (config , converter .export_name )
218- except KeyError :
219- value = None
220- value = converter .import_param (value )
221- if converter .fast_llm_name is not None :
222- kwargs [converter .fast_llm_name ] = value
262+ values = ()
263+ for export_name in converter .export_names :
264+ try :
265+ value = get_nested_dict_value (config , export_name )
266+ except KeyError :
267+ value = MISSING
268+ values = values + (value ,)
269+ values = converter .import_params (values )
270+ for fast_llm_name , value in zip (converter .fast_llm_names , values , strict = True ):
271+ if value is MISSING :
272+ # Missing values need to be handled in dedicated converters,
273+ # because implicit / default values may not match.
274+ # TODO: Different behavior from other uses of MISSING. Use different tag?
275+ raise ValueError (f"Missing converted value for fast-llm parameter { fast_llm_name } " )
276+ if fast_llm_name in kwargs :
277+ raise ValueError (f"Duplicate converted value for fast-llm parameter { fast_llm_name } " )
278+ kwargs [fast_llm_name ] = value
279+ except Exception as e :
280+ raise RuntimeError (f"Config conversion failed for converter { converter } " , * e .args )
223281
224282 config_class = cls ._model_class .get_base_model_config_class ()
225283 if architecture_only :
@@ -335,7 +393,11 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str:
335393 @classmethod
336394 @abc .abstractmethod
337395 def _create_config_converters (cls ) -> list [ParamConverter ]:
338- return [ConstantExportParamConverter (None , "model_type" , cls .get_huggingface_model_type ())]
396+ return [
397+ ConstantExportParamConverter (
398+ export_names = (("model_type" ,),), export_value = cls .get_huggingface_model_type ()
399+ )
400+ ]
339401
340402 @classmethod
341403 def _load_config (cls , directory : pathlib .Path | str ) -> dict :
0 commit comments