
Commit 1fa7b0a

【Cherry-pick】Cherry-pick FlexCheckpoint PRs from develop to fleety_12 (#76252)
* 【FlexCheckpoint】fix_the_layer_id_macro (#75556)
  * fix_the_layer_id_macro
  * fix the ctest
  * add expert_id_macro
  * fix the assert bug
  * fix the code style
* Pr support load hf checkpoint (#75928)
  * support hf checkpoint, fix, support cast, add id macro, fix
  * add test and fix some bug
  * fix full param bug
  * add full param cast test
  Co-authored-by: xingmingyyj <[email protected]>
* 【Flexcheckpoint】add_get_var_mapping_chain_macro (#76013)
  * add_get_var_mapping_chain_macro
  * add note
  * fix the bug input_vars and resolve_mapping_chain
  * fix the code style
  * fit the dtype assert bug
  * fix the bug
  * fix the merge_sharded_state_dict bug
* fix aoa transpose corner case (#76234)

Co-authored-by: Tianyu Zheng <[email protected]>
1 parent fe8aaaa commit 1fa7b0a

File tree

13 files changed: +1253 -201 lines

python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 95 additions & 27 deletions
@@ -90,6 +90,8 @@ def __init__(
     ) -> None:
         self.source_state_shard_info = source_state_shard_info
         self.destination_state_shard_info = destination_state_shard_info
+        self.left_var_to_right_var_mapping = {}
+        self.right_var_from_left_var_mapping = {}

     def get_all_dst_state_keys(self):
         dst_state_keys = set()
@@ -108,7 +110,9 @@ def get_all_src_state_keys(self):
         return src_state_keys

     def get_num_hidden_layers(
-        self, name_with_layer_id: str, layer_id_macro_tag: str
+        self,
+        name_with_layer_id: str,
+        layer_id_macro_tag: str,
     ) -> int:
         if layer_id_macro_tag not in name_with_layer_id:
             raise ValueError(
@@ -133,11 +137,16 @@ def get_src_state_shard_num(self, src_state_key: str) -> int:
                 "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
             )

+        # Only need to parse the model state key for the optimizer state shard num, because the optimizer state slice info is completely consistent with the model state slice info.
+        resolved_model_state_key = self.resolve_mapping_chain(
+            model_state_key, reverse=True
+        )
+
         state_keys = [
-            model_state_key,
-            f"{model_state_key}.w_0",
-            f"{model_state_key}.moment1_0",
-            f"{model_state_key}.moment2_0",
+            resolved_model_state_key,
+            f"{resolved_model_state_key}.w_0",
+            f"{resolved_model_state_key}.moment1_0",
+            f"{resolved_model_state_key}.moment2_0",
         ]

         shard_nums = {
@@ -152,10 +161,10 @@ def get_src_state_shard_num(self, src_state_key: str) -> int:
         }

         if not shard_nums:
-            raise ValueError(
-                f"No shard information found for any of the keys: {state_keys}"
+            logger.warning(
+                f"No shard information found for any of the keys: {state_keys}, return 1."
             )
-
+            return 1
         if len(shard_nums) > 1:
             raise AssertionError(
                 f"Inconsistent shard numbers among keys in source_sharded_state_dict: {shard_nums}."
@@ -166,7 +175,6 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
         if self.destination_state_shard_info is None:
             # Default `dst_state_shard_num=1` if `destination_state_shard_info` is missing.
             return 1
-
         model_state_key, opt_state_name = split_optimizer_state_key(
             dst_state_key
         )
@@ -175,11 +183,16 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
                 "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
             )

+        # Only need to parse the model state key for the optimizer state shard num, because the optimizer state slice info is completely consistent with the model state slice info.
+        resolved_model_state_key = self.resolve_mapping_chain(
+            model_state_key, reverse=False
+        )
+
         state_keys = [
-            model_state_key,
-            f"{model_state_key}.w_0",
-            f"{model_state_key}.moment1_0",
-            f"{model_state_key}.moment2_0",
+            resolved_model_state_key,
+            f"{resolved_model_state_key}.w_0",
+            f"{resolved_model_state_key}.moment1_0",
+            f"{resolved_model_state_key}.moment2_0",
         ]

         shard_nums = {
@@ -194,16 +207,54 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
         }

         if not shard_nums:
-            raise ValueError(
-                f"No shard information found for any of the keys: {state_keys}"
+            logger.warning(
+                f"No shard information found for any of the keys: {state_keys}, return 1."
             )
-
+            return 1
         if len(shard_nums) > 1:
             raise AssertionError(
                 f"Inconsistent shard numbers among keys in destination_state_shard_info: {shard_nums}."
             )
         return shard_nums.pop()

+    def resolve_mapping_chain(self, key: str, reverse: bool = False) -> str:
+        """
+        Recursively resolve the mapping chain and find the final leaf node.
+
+        Args:
+            key: The key to be resolved.
+            reverse: False uses left_var_to_right_var_mapping, True uses right_var_from_left_var_mapping.
+
+        For example:
+            - reverse=False: temp_var -> dst_key
+            - reverse=True: temp_var -> src_key
+        """
+        visited = set()  # avoid infinite loops
+        current_key = key
+
+        if reverse:
+            mapping_dict = self.right_var_from_left_var_mapping
+        else:
+            mapping_dict = self.left_var_to_right_var_mapping
+
+        while current_key in mapping_dict:
+            assert current_key not in visited, (
+                "Infinite loop detected in resolve_mapping_chain, which means the start key is not a src_key or the end key is not a dst_key; the aoa_config is incorrect."
+            )
+            visited.add(current_key)
+            if reverse and current_key in self.get_all_src_state_keys():
+                break
+            elif not reverse and current_key in self.get_all_dst_state_keys():
+                break
+
+            mapped_vars = mapping_dict[current_key]
+            if mapped_vars and len(mapped_vars) > 0:
+                current_key = mapped_vars[0]
+            else:
+                break
+
+        return current_key
+

 class AOAEngine:
     def __init__(
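A minimal standalone sketch of the chain walk that the new resolve_mapping_chain method performs, using a hypothetical mapping and key names; the real method walks self.left_var_to_right_var_mapping (or the reverse mapping) and stops at a known source or destination key.

# Hypothetical mapping: src key -> temp var -> dst key.
left_var_to_right_var_mapping = {
    "src.qkv_proj.weight": ["tmp_var_0"],
    "tmp_var_0": ["dst.q_proj.weight"],
}
dst_state_keys = {"dst.q_proj.weight"}

def resolve(key, mapping, terminal_keys):
    visited = set()
    current = key
    while current in mapping:
        assert current not in visited, "cycle detected in AOA mapping chain"
        visited.add(current)
        if current in terminal_keys:
            break
        mapped = mapping[current]
        if not mapped:
            break
        current = mapped[0]
    return current

print(resolve("src.qkv_proj.weight", left_var_to_right_var_mapping, dst_state_keys))
# -> "dst.q_proj.weight"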
@@ -246,14 +297,20 @@ def make_input_tensor(

     def build_input_vars(self):
         input_vars = {}
-        for key, shards in self.source_state_shard_info.items():
+        dtype = None
+        for key, shards in sorted(self.source_state_shard_info.items()):
             global_shape = shards[0].global_shape
-            dtype = shards[0].dtype
             model_state_key, opt_state_name = split_optimizer_state_key(key)
-            if opt_state_name in [".w_0", ".moment1_0", ".moment2_0", None]:
-                input_vars[model_state_key] = self.make_input_tensor(
-                    model_state_key, global_shape, dtype
-                )
+            if opt_state_name is None:
+                dtype = shards[0].dtype
+            if model_state_key in input_vars.keys() or opt_state_name in [
+                ".beta1_pow_acc_0",
+                ".beta2_pow_acc_0",
+            ]:
+                continue
+            input_vars[model_state_key] = self.make_input_tensor(
+                model_state_key, global_shape, dtype
+            )
         return input_vars

     def split(
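A small sketch of the grouping behaviour introduced in build_input_vars, using a hypothetical key-splitting helper and state dict: keys are walked in sorted order, the dtype is taken from the bare model-state entry, one input var is built per model-state key, and the scalar Adam accumulators are skipped.

# split_key and shard_dtypes below are hypothetical stand-ins for
# split_optimizer_state_key and the real sharded state info.
OPT_SUFFIXES = (".w_0", ".moment1_0", ".moment2_0", ".beta1_pow_acc_0", ".beta2_pow_acc_0")

def split_key(key):
    for suffix in OPT_SUFFIXES:
        if key.endswith(suffix):
            return key[: -len(suffix)], suffix
    return key, None

shard_dtypes = {
    "linear_0.weight": "bfloat16",
    "linear_0.weight.moment1_0": "float32",
    "linear_0.weight.beta1_pow_acc_0": "float32",
}

input_vars, dtype = {}, None
for key, key_dtype in sorted(shard_dtypes.items()):
    model_key, opt_name = split_key(key)
    if opt_name is None:
        dtype = key_dtype  # dtype comes from the model-state entry
    if model_key in input_vars or opt_name in (".beta1_pow_acc_0", ".beta2_pow_acc_0"):
        continue
    input_vars[model_key] = dtype

print(input_vars)  # -> {'linear_0.weight': 'bfloat16'}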
@@ -487,7 +544,7 @@ def _get_var_ref(var):
            elif attr.key == "dtype":
                result = self.cast(in_ref, attr.value)
            elif attr.key == "axis":
-                pass
+                result = in_ref
            else:
                raise ValueError(f"Unsupported attribute: {attr}")

@@ -530,6 +587,8 @@ def find_source_slices(
     ) -> list[SliceRef]:
         assert key in self.output_vars
         tensor = self.output_vars[key]
+        if tensor is None:
+            return []
         results = []
         assert len(local_slice) == len(tensor.shape)
         ndim = len(tensor.shape)
@@ -648,10 +707,19 @@ def find_shard_sources(

         for src_key, src_slices, local_slices, pp_list in results:
             src_var = self.input_vars[src_key]
-            assert src_var.dtype == target.dtype, (
-                "Direct assignment of Tensors with different types is prohibited in AOA. "
-                "If you want to achieve this functionality, please use the cast semantics provided by AOA."
+            target_model_state_key, target_opt_state_name = (
+                split_optimizer_state_key(target.key)
             )
+            if target_opt_state_name is None:
+                if src_var.dtype != target.dtype:
+                    assert pp_list is not None and target.dtype in str(
+                        pp_list
+                    ), (
+                        "Direct assignment of Tensors with different types is prohibited in AOA. "
+                        "If you want to achieve this functionality, please use the cast semantics provided by AOA."
+                    )
+            else:
+                src_var.dtype = target.dtype

             src_global_shape = src_var.shape

@@ -674,7 +742,7 @@ def find_shard_sources(
                src_local_shape,
                tuple(src_global_shape),
                src_global_offset,
-                target.dtype,
+                src_var.dtype,
            )
            target_sharded_weight = ShardedWeightDesc(
                target_key,
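The relaxed dtype handling in the two hunks above can be summarised with the following sketch (the helper name and arguments are illustrative, not the engine's API): a mismatch on a model-state target is only tolerated when the postprocess list of the AOA expression carries a cast to the target dtype, optimizer-state targets simply adopt the target dtype, and the source dtype is what ends up in the source ShardedWeightDesc.

# Illustrative only; mirrors the relaxed check, not the engine's actual API.
def resolve_src_dtype(src_dtype, target_dtype, target_opt_state_name, pp_list):
    if target_opt_state_name is None:  # model-state target
        if src_dtype != target_dtype:
            assert pp_list is not None and target_dtype in str(pp_list), (
                "Direct assignment of Tensors with different types is prohibited in AOA. "
                "Use the cast semantics provided by AOA instead."
            )
        return src_dtype
    return target_dtype  # optimizer-state target follows the target dtype

print(resolve_src_dtype("bfloat16", "float32", None, ["cast: float32"]))  # -> bfloat16
print(resolve_src_dtype("bfloat16", "float32", ".moment1_0", None))       # -> float32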

python/paddle/distributed/flex_checkpoint/aoa/lexer.py

Lines changed: 18 additions & 10 deletions
@@ -88,12 +88,6 @@ def tokenize(self, text):
             mo = self.get_token(text, pos)
         return tokens

-    def apply_macros(self, expression):
-        expressions = [expression]
-        for macro in self.macros:
-            expressions = self.apply_macro(expressions, macro)
-        return expressions
-
     def apply_macro(self, expression, macro):
         if isinstance(expression, str):
             expression = [expression]
@@ -106,10 +100,24 @@ def apply_macro(self, expression, macro):
             new_expression.extend(results)
         return new_expression

+    def apply_single_macro_to_all(self, expressions, macro):
+        new_expressions = []
+        for expr in expressions:
+            results = macro(self.tokenize(expr), expr, self.context)
+            if isinstance(results, str):
+                new_expressions.append(results)
+            else:
+                new_expressions.extend(results)
+        return new_expressions
+
     def all_tokens(self, expressions):
+        current_expressions = expressions
+        for macro in self.macros:
+            current_expressions = self.apply_single_macro_to_all(
+                current_expressions, macro
+            )
+
         tokens = []
-        for expr in expressions:
-            expanded_expressions = self.apply_macros(expr)
-            for e in expanded_expressions:
-                tokens.extend(self.tokenize(e))
+        for expr in current_expressions:
+            tokens.extend(self.tokenize(expr))
         return tokens