Skip to content

Commit e548ebb

Browse files
committed
refactor(archon): extract utility functions and simplify engine code
Move optimizer/scheduler creation, activation checkpoint config, zero-bubble validation, deterministic mode setup, and pad_to_maximum validation into archon_utils.py for reuse and testability. Cache tp/cp parallel groups to avoid repeated lookups, and use context managers for DistributedLock.

Key changes:
- Extract 6 utility functions into new archon_utils.py module
- Cache _tp_group and _cp_group on engine initialization
- Add __enter__/__exit__ to DistributedLock for context manager usage
- Replace manual lock acquire/release with `with` statements
- Add venv activation note to installation docs
1 parent 6093109 commit e548ebb

7 files changed

Lines changed: 486 additions & 340 deletions

File tree

areal/experimental/engine/archon_checkpoint.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -250,8 +250,7 @@ def save_model_to_hf(
250250
shutil.rmtree(tmp_path)
251251
dist.barrier(group=engine.cpu_group)
252252
os.makedirs(tmp_path, exist_ok=True)
253-
cpu_offload = not is_async
254-
options = StateDictOptions(full_state_dict=False, cpu_offload=cpu_offload)
253+
options = StateDictOptions(full_state_dict=False, cpu_offload=not is_async)
255254
state_dict = _get_merged_state_dict(engine, options)
256255

257256
hf_state_dict = engine.state_dict_adapter.to_hf(state_dict)
@@ -352,12 +351,13 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None:
352351
# Add a placeholder with embed_tokens key so DCP will load it
353352
embed_key = "model.embed_tokens.weight"
354353
if embed_key not in hf_state_dict:
355-
output_tensor = state_dict["output.weight"]
356-
hf_state_dict[embed_key] = torch.empty_like(output_tensor)
354+
hf_state_dict[embed_key] = torch.empty_like(state_dict["output.weight"])
357355

358356
# Load using DCP with HuggingFaceStorageReader
359-
hf_reader = engine.state_dict_adapter.get_hf_storage_reader(path)
360-
dcp.load(hf_state_dict, storage_reader=hf_reader)
357+
dcp.load(
358+
hf_state_dict,
359+
storage_reader=engine.state_dict_adapter.get_hf_storage_reader(path),
360+
)
361361

362362
# Convert back to Archon format
363363
archon_state_dict = engine.state_dict_adapter.from_hf(hf_state_dict)

0 commit comments

Comments (0)