17 changes: 11 additions & 6 deletions python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py
@@ -666,13 +666,18 @@ def _get_var_ref(var):
             for name in self.destination_state_shard_info:
                 model_state_key, _ = split_optimizer_state_key(name)
                 if model_state_key not in self.output_vars:
-                    self.output_vars[model_state_key] = (
-                        None
-                        if model_state_key in self.need_add_output_vars
-                        else self.input_vars[
+                    if model_state_key in self.need_add_output_vars:
+                        self.output_vars[model_state_key] = None
+                    else:
+                        assert model_state_key in self.input_vars, (
+                            f"{model_state_key} is in dst_keys (needs to be loaded), "
+                            f"but not found in src_keys. "
+                            f"If it is a new key and you want to load it, please use the add primitive in aoa_statements: "
+                            f"_ -> {model_state_key}, and {model_state_key} will be randomly initialized."
+                        )
+                        self.output_vars[model_state_key] = self.input_vars[
                             model_state_key
-                        ]  # Assertion implied by direct access
-                    )
+                        ]
         else:
             # When destination_state_shard_info is not provided, the AOAEngine automatically derives it
             # from source_state_shard_info and aoa_statements. In this case, all destination_states
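The assertion message above refers to the AOA add primitive. A minimal usage sketch (not from this PR; `model`, `ckpt_path`, and the key name are placeholders) of declaring a destination key that has no source counterpart, so it is randomly initialized instead of failing the assertion:

from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
    load_state_dict,
)

# "new_head.weight" is a hypothetical key that exists only in the destination
# model; the add primitive tells the AOA engine to initialize it randomly.
aoa_config = {"aoa_statements": ["_ -> new_head.weight"]}
load_state_dict(
    model.sharded_state_dict(),  # placeholder model
    ckpt_path,                   # placeholder checkpoint directory
    safetensors=True,
    aoa_config=aoa_config,
)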
61 changes: 60 additions & 1 deletion python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py
@@ -540,7 +540,9 @@ def _handle_aoa(
         ]
         for param_name, local_tensor_metas in state_dict_metadata.items()
     }
-
+    logger_missing_key_and_unexpected_keys_before_aoa(
+        metadata, load_dict, process_group, safetensors, use_dist
+    )
     aoa_engine = AOAEngine(
         source_state_shard_info=source_state_shard_info,
         destination_state_shard_info=destination_state_shard_info,
@@ -1668,6 +1670,63 @@ def slice_dict(d, start, end):
)


def logger_missing_key_and_unexpected_keys_before_aoa(
    metadata: Metadata,
    state_dict: dict[str, Tensor] | dict[str, ShardedWeight],
    process_group: Group | None = None,
    safetensors: bool = False,
    use_dist: bool = False,
):
    """Warn about destination keys that are missing from the checkpoint and
    checkpoint keys that no destination tensor requests, before AOA remapping."""
    first_key = next(iter(state_dict), None)
    if isinstance(first_key, tuple):
        flat_state_dict = state_dict
    else:
        flat_state_dict, mapping = flatten_state_dict(state_dict)

    global_src_key_list = []
    dst_key_list = [
        key if isinstance(key, str) else key[0]
        for key in flat_state_dict.keys()
    ]
    global_dst_key_list = []
    if use_dist:
        # Gather every rank's destination keys so the comparison covers the
        # full sharded state dict.
        paddle.distributed.all_gather_object(
            global_dst_key_list, dst_key_list, process_group
        )
        flatten_global_dst_key_list = [
            item for sublist in global_dst_key_list for item in sublist
        ]
    else:
        global_dst_key_list.extend(dst_key_list)
        flatten_global_dst_key_list = global_dst_key_list

    for local_tensor_index, file_name in metadata.storage_metadata.items():
        # Count each source tensor only once, ignoring extra replicas.
        if (
            local_tensor_index.replica_id is not None
            and local_tensor_index.replica_id != 0
        ):
            continue
        global_src_key_list.append(local_tensor_index.tensor_key)
    missing_keys = set(flatten_global_dst_key_list) - set(global_src_key_list)
    unexpected_keys = set(global_src_key_list) - set(
        flatten_global_dst_key_list
    )
    if len(missing_keys) > 0:
        print(
            f"Missing keys: {missing_keys}, check whether the checkpoint is complete."
        )
        logger.warning(
            f"Missing keys: {missing_keys}, check whether the checkpoint is complete."
        )
    if len(unexpected_keys) > 0:
        print(
            f"Unexpected keys: {unexpected_keys}, these keys exist in the checkpoint but are not requested by the destination state dict."
        )
        logger.warning(
            f"Unexpected keys: {unexpected_keys}, these keys exist in the checkpoint but are not requested by the destination state dict."
        )


class SavePartialSafetensors:
def __init__(self, output_path, process_group, prefix="model"):
self.output_path = output_path
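For intuition, a minimal single-process sketch of the comparison logger_missing_key_and_unexpected_keys_before_aoa performs (key names are placeholders; the real helper additionally all-gathers destination keys across ranks and skips storage entries whose replica_id is non-zero):

checkpoint_keys = {"embedding.weight", "linear1.weight"}  # src: stored in the checkpoint
requested_keys = {"embedding.weight", "linear2.weight"}   # dst: wanted by the model

missing_keys = requested_keys - checkpoint_keys     # {"linear2.weight"}: requested but absent
unexpected_keys = checkpoint_keys - requested_keys  # {"linear1.weight"}: stored but unused
print(missing_keys, unexpected_keys)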
19 changes: 19 additions & 0 deletions test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py
@@ -145,6 +145,25 @@ def test_save_load_state_dict_with_aoa_config_reverse(self):
)
ckpt_path.cleanup()

def test_save_load_with_missing_key_and_unexpected_keys(self):
"""Test logger missing key and unexpected keys."""
ckpt_path = tempfile.TemporaryDirectory()
envs_list = test_base.gen_product_envs_list(
self._default_envs, {"device_num": ["1", "2"]}
)
for envs in envs_list:
envs["ckpt_path"] = ckpt_path.name
super().setUp(
num_of_devices=int(envs["device_num"]),
timeout=60,
nnode=1,
)
self.run_test_case(
"test_save_load_with_missing_key_and_unexpected_keys.py",
user_defined_envs=envs,
)
ckpt_path.cleanup()


if __name__ == '__main__':
unittest.main()
83 changes: 83 additions & 0 deletions test/auto_parallel/hybrid_strategy/test_save_load_with_missing_key_and_unexpected_keys.py
@@ -0,0 +1,83 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
load_state_dict,
)


class HuggingFaceModel(nn.Layer):
def __init__(self):
super().__init__()
self.huggingface = nn.Linear(2, 2, bias_attr=False)


class FCModel(nn.Layer):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(1, 2, bias_attr=False)
self.fc2 = nn.Linear(1, 2, bias_attr=False)


def init_hf_model_weights(model):
with paddle.no_grad():
w = paddle.to_tensor([[0, 1], [2, 3]], dtype="float16")
model.huggingface.weight.set_value(w)


def save_safetensors_model(model, ckpt_path):
import safetensors.numpy

os.makedirs(ckpt_path, exist_ok=True)
weight_np = model.huggingface.weight.numpy()
file_path = os.path.join(ckpt_path, "tensor1.safetensors")
safetensors.numpy.save_file({"huggingface.weight": weight_np}, file_path)


def test_save_load_with_missing_key_and_unexpected_keys():
ckpt_path = os.getenv("ckpt_path")
dist.init_parallel_env()

hf_model = HuggingFaceModel()
fc_model = FCModel()
hf_model = paddle.amp.decorate(
models=hf_model, optimizers=None, level="O2", dtype="float16"
)
init_hf_model_weights(hf_model)

save_safetensors_model(hf_model, ckpt_path)

aoa_statements = []
aoa_config = {"aoa_statements": aoa_statements}

    try:
        load_state_dict(
            fc_model.sharded_state_dict(),
            ckpt_path,
            safetensors=True,
            aoa_config=aoa_config,
        )
    except Exception as e:
        # Expected path: fc1/fc2 weights are absent from the checkpoint, so the
        # missing/unexpected keys are logged and loading ultimately fails.
        print(e)
    else:
        raise AssertionError(
            "load_state_dict should fail when the checkpoint is missing keys"
        )


if __name__ == "__main__":
test_save_load_with_missing_key_and_unexpected_keys()
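
For reference, a standalone sketch (not part of the test) of why this scenario triggers both warnings and then an error: the checkpoint written by save_safetensors_model contains only huggingface.weight, while FCModel asks for fc1.weight and fc2.weight. The checkpoint path below is a placeholder.

import safetensors.numpy

ckpt_keys = set(
    safetensors.numpy.load_file("/path/to/ckpt/tensor1.safetensors").keys()
)  # {"huggingface.weight"}
model_keys = set(FCModel().state_dict().keys())  # {"fc1.weight", "fc2.weight"}

print("Missing keys:", model_keys - ckpt_keys)     # both FC weights -> warning, then load fails
print("Unexpected keys:", ckpt_keys - model_keys)  # huggingface.weight -> warning only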