@@ -90,6 +90,8 @@ def __init__(
     ) -> None:
         self.source_state_shard_info = source_state_shard_info
         self.destination_state_shard_info = destination_state_shard_info
+        self.left_var_to_right_var_mapping = {}
+        self.right_var_from_left_var_mapping = {}
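+        # var -> [next vars] rename chains; left->right resolves toward
+        # dst keys, right<-left toward src keys (see resolve_mapping_chain).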

     def get_all_dst_state_keys(self):
         dst_state_keys = set()
@@ -108,7 +110,9 @@ def get_all_src_state_keys(self):
         return src_state_keys

     def get_num_hidden_layers(
-        self, name_with_layer_id: str, layer_id_macro_tag: str
+        self,
+        name_with_layer_id: str,
+        layer_id_macro_tag: str,
     ) -> int:
         if layer_id_macro_tag not in name_with_layer_id:
             raise ValueError(
@@ -133,11 +137,16 @@ def get_src_state_shard_num(self, src_state_key: str) -> int:
                 "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
             )

+        # Only the model state key needs to be resolved here: the optimizer state slice info is fully consistent with the model state slice info.
+        resolved_model_state_key = self.resolve_mapping_chain(
+            model_state_key, reverse=True
+        )
+
         state_keys = [
-            model_state_key,
-            f"{model_state_key}.w_0",
-            f"{model_state_key}.moment1_0",
-            f"{model_state_key}.moment2_0",
+            resolved_model_state_key,
+            f"{resolved_model_state_key}.w_0",
+            f"{resolved_model_state_key}.moment1_0",
+            f"{resolved_model_state_key}.moment2_0",
         ]

         shard_nums = {
@@ -152,10 +161,10 @@ def get_src_state_shard_num(self, src_state_key: str) -> int:
         }

         if not shard_nums:
-            raise ValueError(
-                f"No shard information found for any of the keys: {state_keys}"
+            logger.warning(
+                f"No shard information found for any of the keys: {state_keys}; returning 1."
             )
-
+            return 1
         if len(shard_nums) > 1:
             raise AssertionError(
                 f"Inconsistent shard numbers among keys in source_sharded_state_dict: {shard_nums}."
@@ -166,7 +175,6 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
         if self.destination_state_shard_info is None:
             # Default `dst_state_shard_num=1` if `destination_state_shard_info` is missing.
             return 1
-
         model_state_key, opt_state_name = split_optimizer_state_key(
             dst_state_key
         )
@@ -175,11 +183,16 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
                 "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
             )

+        # Only the model state key needs to be resolved here: the optimizer state slice info is fully consistent with the model state slice info.
+        resolved_model_state_key = self.resolve_mapping_chain(
+            model_state_key, reverse=False
+        )
+
         state_keys = [
-            model_state_key,
-            f"{model_state_key}.w_0",
-            f"{model_state_key}.moment1_0",
-            f"{model_state_key}.moment2_0",
+            resolved_model_state_key,
+            f"{resolved_model_state_key}.w_0",
+            f"{resolved_model_state_key}.moment1_0",
+            f"{resolved_model_state_key}.moment2_0",
         ]

         shard_nums = {
@@ -194,16 +207,54 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
         }

         if not shard_nums:
-            raise ValueError(
-                f"No shard information found for any of the keys: {state_keys}"
+            logger.warning(
+                f"No shard information found for any of the keys: {state_keys}; returning 1."
             )
-
+            return 1
         if len(shard_nums) > 1:
             raise AssertionError(
                 f"Inconsistent shard numbers among keys in destination_state_shard_info: {shard_nums}."
             )
         return shard_nums.pop()

+    def resolve_mapping_chain(self, key: str, reverse: bool = False) -> str:
+        """
+        Iteratively resolve the mapping chain to find the final leaf node.
+
+        Args:
+            key: The key to resolve.
+            reverse: If False, follow left_var_to_right_var_mapping;
+                if True, follow right_var_from_left_var_mapping.
+
+        For example:
+            - reverse=False: temp_var -> dst_key
+            - reverse=True: temp_var -> src_key
+        """
+        visited = set()  # guard against cycles in the mapping chain
+        current_key = key
+
+        if reverse:
+            mapping_dict = self.right_var_from_left_var_mapping
+        else:
+            mapping_dict = self.left_var_to_right_var_mapping
+
+        while current_key in mapping_dict:
+            assert current_key not in visited, (
+                "Infinite loop detected in resolve_mapping_chain, which means the start key is not a src_key or the end key is not a dst_key; the aoa_config is invalid."
+            )
+            visited.add(current_key)
+            if reverse and current_key in self.get_all_src_state_keys():
+                break
+            elif not reverse and current_key in self.get_all_dst_state_keys():
+                break
+
+            mapped_vars = mapping_dict[current_key]
+            if mapped_vars:
+                current_key = mapped_vars[0]
+            else:
+                break
+
+        return current_key
+

 class AOAEngine:
     def __init__(
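The chain walk above is easier to see in isolation. Below is a minimal standalone sketch of just the loop, with hypothetical keys and a hypothetical walk_chain helper (the real method additionally stops early once it reaches a known src/dst state key):

def walk_chain(mapping: dict[str, list[str]], key: str) -> str:
    # Follow var -> [next vars] links until a leaf, guarding against cycles.
    visited = set()
    while key in mapping:
        assert key not in visited, "cycle in mapping chain"
        visited.add(key)
        next_vars = mapping[key]
        if not next_vars:
            break
        key = next_vars[0]
    return key

# An AOA rewrite that chains src -> tmp -> dst resolves to the leaf key:
chain = {"linear.w": ["tmp_0"], "tmp_0": ["decoder.linear.w"]}
assert walk_chain(chain, "linear.w") == "decoder.linear.w"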
@@ -246,14 +297,20 @@ def make_input_tensor(

     def build_input_vars(self):
         input_vars = {}
-        for key, shards in self.source_state_shard_info.items():
+        dtype = None
+        for key, shards in sorted(self.source_state_shard_info.items()):
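+            # sorted() makes the order deterministic and visits the bare
+            # model-state key before its optimizer-state entries.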
             global_shape = shards[0].global_shape
-            dtype = shards[0].dtype
             model_state_key, opt_state_name = split_optimizer_state_key(key)
-            if opt_state_name in [".w_0", ".moment1_0", ".moment2_0", None]:
-                input_vars[model_state_key] = self.make_input_tensor(
-                    model_state_key, global_shape, dtype
-                )
+            if opt_state_name is None:
+                dtype = shards[0].dtype
+            if model_state_key in input_vars or opt_state_name in [
+                ".beta1_pow_acc_0",
+                ".beta2_pow_acc_0",
+            ]:
+                continue
+            input_vars[model_state_key] = self.make_input_tensor(
+                model_state_key, global_shape, dtype
+            )
         return input_vars

     def split(
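As a sanity check on the reworked loop, a hypothetical trace over a sorted Adam-style source state dict (key names invented for illustration; split_optimizer_state_key is assumed to strip suffixes such as .moment1_0 and to return None for the bare weight):

# "linear.w"                  -> opt_state_name is None: dtype captured, var created
# "linear.w.beta1_pow_acc_0"  -> skipped (accumulator)
# "linear.w.moment1_0"        -> skipped ("linear.w" already in input_vars)
# "linear.w.moment2_0"        -> skipped (same reason)
# "linear.w.w_0"              -> skipped (same reason)

Because the items are sorted, the bare model-state key is visited before its optimizer entries, so dtype is already set when the input var is built.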
@@ -487,7 +544,7 @@ def _get_var_ref(var):
             elif attr.key == "dtype":
                 result = self.cast(in_ref, attr.value)
             elif attr.key == "axis":
-                pass
+                result = in_ref
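+                # "axis" needs no data transformation here; pass the input
+                # through unchanged so `result` is always bound.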
             else:
                 raise ValueError(f"Unsupported attribute: {attr}")

@@ -530,6 +587,8 @@ def find_source_slices(
     ) -> list[SliceRef]:
         assert key in self.output_vars
         tensor = self.output_vars[key]
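+        # An output var may be None; such a key contributes no source slices.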
+        if tensor is None:
+            return []
         results = []
         assert len(local_slice) == len(tensor.shape)
         ndim = len(tensor.shape)
@@ -648,10 +707,19 @@ def find_shard_sources(

         for src_key, src_slices, local_slices, pp_list in results:
             src_var = self.input_vars[src_key]
-            assert src_var.dtype == target.dtype, (
-                "Direct assignment of Tensors with different types is prohibited in AOA. "
-                "If you want to achieve this functionality, please use the cast semantics provided by AOA."
+            target_model_state_key, target_opt_state_name = (
+                split_optimizer_state_key(target.key)
             )
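+            # A dtype mismatch on a model-state target is only tolerated when
+            # the target dtype appears in pp_list (i.e. a cast is pending);
+            # optimizer-state targets take the target dtype directly.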
+            if target_opt_state_name is None:
+                if src_var.dtype != target.dtype:
+                    assert pp_list is not None and target.dtype in str(
+                        pp_list
+                    ), (
+                        "Direct assignment of Tensors with different types is prohibited in AOA. "
+                        "If you want to achieve this functionality, please use the cast semantics provided by AOA."
+                    )
+            else:
+                src_var.dtype = target.dtype

             src_global_shape = src_var.shape

@@ -674,7 +742,7 @@ def find_shard_sources(
                     src_local_shape,
                     tuple(src_global_shape),
                     src_global_offset,
-                    target.dtype,
+                    src_var.dtype,
                 )
                 target_sharded_weight = ShardedWeightDesc(
                     target_key,