Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ repos:
hooks:
- id: compileall
name: Compile all python files
entry: sh -c 'PYTHONWARNINGS=error python3 -m compileall -q .'
entry: sh -c 'PYTHONWARNINGS=error python3 -m compileall -q . -x "\.venv|venv|\.git"'
language: python
pass_filenames: false
9 changes: 8 additions & 1 deletion verl/utils/checkpoint/fsdp_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
raise NotImplementedError(f"Unknown architecture {model_config['architectures']}")

with init_empty_weights():
save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)
# infer trust_remote_code from model structure
has_remote_code = hasattr(model_config, "auto_map") and any(
model_config.architectures[0] in val for val in model_config.auto_map.values()
)
save_model = auto_model_cls.from_config(
model_config, torch_dtype=torch.bfloat16, trust_remote_code=has_remote_code
)

save_model.to_empty(device="cpu")

if save_model.can_generate():
Expand Down