-
Notifications
You must be signed in to change notification settings - Fork 501
Expand file tree
/
Copy patharchon_checkpoint.py
More file actions
463 lines (381 loc) · 16.8 KB
/
archon_checkpoint.py
File metadata and controls
463 lines (381 loc) · 16.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
from __future__ import annotations
import json
import os
import shutil
import struct
from typing import TYPE_CHECKING, Any
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch import nn
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from areal.utils.logging import getLogger
if TYPE_CHECKING:
    # These names are only used in annotations, which
    # ``from __future__ import annotations`` keeps as strings, so they are
    # imported for type checkers only and never at runtime.
    from transformers import AutoProcessor, PreTrainedTokenizerFast

    from areal.experimental.engine.archon_engine import ArchonEngine
    from areal.utils.async_checkpoint import AsyncCheckpointManager

# Module-level logger; most functions below log through ``engine.logger``
# instead, this one is used where no engine logger is at hand.
logger = getLogger("ArchonCheckpoint")
# NOTE: Upgrading PyTorch may resolve this in the future.
def _consolidate_shards_distributed(
    input_dir: str,
    output_dir: str,
    fqn_to_index_mapping: dict[str, int],
    num_threads: int = 8,
    process_group: dist.ProcessGroup | None = None,
) -> None:
    """Consolidate HF safetensors shards in parallel across a process group.

    Drop-in replacement for torch's
    ``consolidate_safetensors_files_on_every_rank``. The upstream helper ends
    with a ``dist.barrier()`` that ignores its *process_group* argument and
    synchronizes on the default (NCCL) group instead. When the background
    consolidation thread issues that NCCL barrier while the main thread is
    running its own NCCL collectives, ranks may enqueue them in different
    orders and deadlock. This version always barriers on *process_group*.

    Output files are assigned round-robin: a file with index ``i`` is written
    by the rank (within *process_group*) equal to ``i % world_size``.

    Args:
        input_dir: Directory containing the sharded safetensors files.
        output_dir: Directory that receives the consolidated files.
        fqn_to_index_mapping: Maps each HF parameter FQN to its output file
            index.
        num_threads: Thread count forwarded to the consolidation routine.
        process_group: Group used for rank/world-size lookup and for the
            final barrier; ``None`` means the default group.
    """
    from torch.distributed.checkpoint._consolidate_hf_safetensors import (
        _consolidate_safetensors_files,
    )
    from torch.distributed.checkpoint._hf_utils import _gen_file_name

    my_rank = dist.get_rank(group=process_group)
    n_ranks = dist.get_world_size(group=process_group)

    # Only the FQNs whose output file this rank owns.
    owned = {
        fqn: file_idx
        for fqn, file_idx in fqn_to_index_mapping.items()
        if file_idx % n_ranks == my_rank
    }

    if owned:
        # The "-of-NNNNN" part of every file name uses the largest index in
        # the full mapping, not just the indices owned by this rank.
        last_idx = max(fqn_to_index_mapping.values())
        _consolidate_safetensors_files(
            input_dir=input_dir,
            output_dir=output_dir,
            fqn_to_file_mapping={
                fqn: _gen_file_name(file_idx, last_idx)
                for fqn, file_idx in owned.items()
            },
            num_threads=num_threads,
        )

    # Reached by every rank, including ranks that own no files, and scoped
    # to process_group rather than the default PG.
    dist.barrier(group=process_group)
class DCPState(Stateful):
"""DCP wrapper for archon models.
Key design decisions:
- Uses flatten_optimizer_state_dict=True to avoid param_group index collisions
(without flatten, each optimizer uses indices 0, 1, 2... which collide across
PP stages; with flatten, keys become parameter FQNs which are unique)
- For PP (len(model_parts) > 1): uses strict=False when loading because each
PP stage only has subset of keys
- For non-PP (len(model_parts) == 1): uses strict=True to catch real issues
"""
def __init__(
self,
model_parts: list[nn.Module] | nn.Module,
optimizer: torch.optim.Optimizer | None = None,
):
"""Initialize DCPState.
Args:
model_parts: Single model or list of model parts from pipeline_llm
optimizer: Optimizer for the model(s)
"""
if isinstance(model_parts, nn.Module):
self.model_parts = [model_parts]
else:
self.model_parts = model_parts
self.optimizer = optimizer
# PP mode uses non-strict loading since each stage only has subset of keys
self._is_pp = len(self.model_parts) > 1
def state_dict(self) -> dict[str, Any]:
"""Get state dict for model parts and optimizer using DCP utilities."""
# Merge model state dicts from all parts
# cpu_offload=True ensures tensors are on CPU for DCP filesystem writer
model_state: dict[str, Any] = {}
model_options = StateDictOptions(cpu_offload=True)
for model_part in self.model_parts:
part_state = get_model_state_dict(model_part, options=model_options)
model_state.update(part_state)
state: dict[str, Any] = {"model": model_state}
if self.optimizer is not None:
optim_options = StateDictOptions(
flatten_optimizer_state_dict=True,
cpu_offload=True,
)
# Get optimizer state for each model part and merge
optim_state: dict[str, Any] = {}
for model_part in self.model_parts:
part_optim = get_optimizer_state_dict(
model_part, self.optimizer, options=optim_options
)
optim_state.update(part_optim)
state["optim"] = optim_state
return state
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load state dicts onto model parts and optimizer."""
model_state = state_dict["model"]
model_options = StateDictOptions(strict=not self._is_pp)
for model_part in self.model_parts:
set_model_state_dict(model_part, model_state, options=model_options)
if self.optimizer is not None and "optim" in state_dict:
optim_state = state_dict["optim"]
optim_options = StateDictOptions(
strict=not self._is_pp,
flatten_optimizer_state_dict=True,
)
for model_part in self.model_parts:
set_optimizer_state_dict(
model_part, self.optimizer, optim_state, options=optim_options
)
def _validate_model_initialized(engine: ArchonEngine) -> None:
    """Ensure the engine has model parts before any checkpoint operation.

    Raises:
        RuntimeError: If ``engine.model_parts`` is empty (or otherwise falsy).
    """
    if engine.model_parts:
        return
    raise RuntimeError("Model parts not initialized")
def _get_merged_state_dict(
engine: ArchonEngine,
options: StateDictOptions,
) -> dict[str, Any]:
"""Get merged model state dict, handling PP mode."""
if engine.parallel_dims.pp_enabled:
state_dict: dict = {}
for model_part in engine.model_parts:
part_state = get_model_state_dict(model_part, options=options)
state_dict.update(part_state)
return state_dict
return get_model_state_dict(engine.model, options=options)
def _write_safetensors_index(
    output_dir: str, fqn_to_index_mapping: dict[str, int]
) -> None:
    """Write ``model.safetensors.index.json`` for a multi-file HF checkpoint.

    The index maps each parameter FQN to the file
    ``model-{idx:05d}-of-{max_idx:05d}.safetensors`` holding it, and stores
    ``metadata.total_size``. That value is the summed byte size of all
    tensors, computed from the safetensors headers without loading any
    tensor data.

    Args:
        output_dir: Directory that already contains the consolidated
            ``.safetensors`` files; the index is written there as well.
        fqn_to_index_mapping: Maps each parameter FQN to its file index.

    Raises:
        ValueError: If the mapping is empty or a file's header is truncated
            or malformed.
        FileNotFoundError: If a referenced safetensors file does not exist.
    """
    if not fqn_to_index_mapping:
        # max() below would otherwise fail with an unhelpful message.
        raise ValueError(
            "fqn_to_index_mapping is empty; cannot write a safetensors index"
        )
    max_index = max(fqn_to_index_mapping.values())
    weight_map = {
        fqn: f"model-{idx:05d}-of-{max_index:05d}.safetensors"
        for fqn, idx in fqn_to_index_mapping.items()
    }

    total_size = 0
    # sorted() only makes the read order deterministic; the sum is unaffected.
    for filename in sorted(set(weight_map.values())):
        filepath = os.path.join(output_dir, filename)
        with open(filepath, "rb") as f:
            # safetensors layout: 8-byte little-endian header length, followed
            # by a JSON header mapping tensor names to their metadata.
            raw_len = f.read(8)
            if len(raw_len) != 8:
                raise ValueError(f"Truncated safetensors header in {filepath}")
            (header_size,) = struct.unpack("<Q", raw_len)
            # A short/corrupt JSON header raises json.JSONDecodeError (a ValueError).
            header = json.loads(f.read(header_size))
        for key, meta in header.items():
            if key == "__metadata__":
                continue
            start, end = meta["data_offsets"]
            total_size += end - start

    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        # sort_keys gives deterministic, diff-friendly output across runs.
        json.dump(index, f, indent=2, sort_keys=True)
def save_model_to_hf(
    engine: ArchonEngine,
    path: str,
    tokenizer: PreTrainedTokenizerFast | None,
    processor: AutoProcessor | None = None,
    async_mgr: AsyncCheckpointManager | None = None,
) -> None:
    """Save the model in HuggingFace format using DCP infrastructure.

    The save proceeds in four steps:

    1. Every rank writes its shards with ``HuggingFaceStorageWriter`` into
       ``<path>.tmp/sharded``.
    2. The shards are consolidated into ``<path>.tmp``, spread across the
       ranks of the consolidation process group.
    3. Rank 0 of that group writes the index JSON, the model config and the
       optional tokenizer/processor.
    4. Rank 0 replaces ``path`` with ``<path>.tmp``.

    Args:
        engine: The ArchonEngine instance.
        path: Output directory for the HF checkpoint. Trailing path separators
            are ignored.
        tokenizer: Optional tokenizer to save alongside the model.
        processor: Optional processor to save alongside the model.
        async_mgr: Optional async checkpoint manager. When provided and async
            is enabled, the save is delegated to ``async_mgr.save()`` (instead
            of ``dcp.save()``), with the consolidation step passed as its
            ``post_fn`` and run on the manager's consolidation process group.

    Raises:
        RuntimeError: If the model is not initialized or the engine has no
            ``state_dict_adapter``.
    """
    from torch.distributed.checkpoint import HuggingFaceStorageWriter

    _validate_model_initialized(engine)
    if engine.state_dict_adapter is None:
        raise RuntimeError("state_dict_adapter is required for HF format")

    # Normalize away trailing separators. For "ckpt/", the temp dir would
    # otherwise be "ckpt/.tmp" -- *inside* the target -- and the final
    # rmtree(path) would delete the new checkpoint right before the rename.
    path = os.path.normpath(path)
    engine.logger.info(f"Saving HF checkpoint to {path}")

    # In async mode, let the stager handle GPU->CPU transfer
    is_async = async_mgr is not None and async_mgr.is_async

    # Write to a temp dir first, then swap it into the final path.
    tmp_path = path + ".tmp"
    if dist.get_rank() == 0 and os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)
    dist.barrier(group=engine.cpu_group)
    os.makedirs(tmp_path, exist_ok=True)

    options = StateDictOptions(full_state_dict=False, cpu_offload=not is_async)
    state_dict = _get_merged_state_dict(engine, options)
    hf_state_dict = engine.state_dict_adapter.to_hf(state_dict)
    fqn_to_index_mapping = engine.state_dict_adapter.fqn_to_index_mapping

    # HuggingFaceStorageWriter creates a sharded/ subdir we must clean up after consolidation.
    sharded_dir = os.path.join(tmp_path, "sharded")

    if fqn_to_index_mapping:
        consolidation_mapping = fqn_to_index_mapping
    else:
        # Single-file fallback: every tensor goes to file index 1. The mapping
        # decides which FQNs get consolidated and indexed. Under PP each rank
        # only holds its own stage's keys, so gather the keys from all ranks.
        # Otherwise the output would cover a single stage.
        local_keys = list(hf_state_dict.keys())
        if engine.parallel_dims.pp_enabled:
            gathered: list[list[str] | None] = [None] * dist.get_world_size(
                group=engine.cpu_group
            )
            dist.all_gather_object(gathered, local_keys, group=engine.cpu_group)
            # Rank-ordered union, so the mapping is identical on every rank.
            all_keys = (key for keys in gathered for key in keys)
        else:
            all_keys = local_keys
        consolidation_mapping = dict.fromkeys(all_keys, 1)

    hf_writer = HuggingFaceStorageWriter(
        path=sharded_dir,
        save_distributed=True,
        fqn_to_index_mapping=fqn_to_index_mapping,
        enable_consolidation=False,
    )
    consolidation_pg = (
        async_mgr.consolidation_process_group if is_async else engine.cpu_group
    )

    def _consolidate_and_cleanup(process_group=consolidation_pg):
        """Consolidate shards, then finalize the checkpoint dir on rank 0.

        Collective over ``process_group``: every rank in it must call this.
        """
        try:
            _consolidate_shards_distributed(
                input_dir=sharded_dir,
                output_dir=tmp_path,
                fqn_to_index_mapping=consolidation_mapping,
                num_threads=8,
                process_group=process_group,
            )
        except Exception:
            # Must re-raise: this function contains a collective barrier, so
            # swallowing the exception on a subset of ranks causes deadlock.
            logger.error("Consolidation failed, keeping sharded dir", exc_info=True)
            raise

        if dist.get_rank(group=process_group) == 0:
            # _consolidate_shards_distributed does not write the index JSON
            # that HuggingFace from_pretrained() needs. consolidation_mapping
            # is defined for both the multi-file and single-file cases.
            _write_safetensors_index(tmp_path, consolidation_mapping)
            if os.path.exists(sharded_dir):
                shutil.rmtree(sharded_dir)

            # Write config / tokenizer / processor into temp dir
            engine.model_config.save_pretrained(tmp_path)
            if tokenizer is not None:
                tokenizer.save_pretrained(tmp_path)
            if processor is not None:
                processor.save_pretrained(tmp_path)

            # Swap the temp dir into the final path. This is not atomic:
            # path is briefly absent between the rmtree and the rename.
            if os.path.exists(path):
                shutil.rmtree(path)
            os.rename(tmp_path, path)
        dist.barrier(group=process_group)

    if is_async:
        async_mgr.save(
            state_dict=hf_state_dict,
            storage_writer=hf_writer,
            post_fn=_consolidate_and_cleanup,
        )
    else:
        dcp.save(hf_state_dict, storage_writer=hf_writer)
        _consolidate_and_cleanup()
        dist.barrier(group=engine.cpu_group)
def load_model_from_hf(engine: ArchonEngine, path: str) -> None:
    """Load a HuggingFace checkpoint into the engine's model via DCP.

    The load proceeds as follows:

    1. Build this rank's state dict as a template and convert it to HF key
       names with ``engine.state_dict_adapter``.
    2. Let ``dcp.load`` fill it from ``path`` and convert it back.
    3. Apply the result non-strictly to the model (or to every model part
       under PP).

    Keys missing from, or unexpected in, the checkpoint are logged as
    warnings on global rank 0. ``rotary_emb`` keys are never reported as
    missing. Finishes with a barrier on ``engine.cpu_group``.

    Raises:
        RuntimeError: If the model is not initialized or the engine has no
            ``state_dict_adapter``.
    """
    _validate_model_initialized(engine)
    adapter = engine.state_dict_adapter
    if adapter is None:
        raise RuntimeError("state_dict_adapter is required for HF format")
    engine.logger.info(f"Loading HF checkpoint from {path}")

    pp_on = engine.parallel_dims.pp_enabled

    # Template for this rank; dcp.load fills the HF-keyed copy in place.
    template = _get_merged_state_dict(
        engine, StateDictOptions(full_state_dict=False, cpu_offload=True)
    )
    hf_template = adapter.to_hf(template)

    # PP + weight tying: with tie_word_embeddings=True the HF checkpoint only
    # stores embed_tokens.weight, not lm_head.weight. The last PP stage owns
    # output.weight but no tok_embeddings, so request embed_tokens.weight
    # explicitly even though it is absent from the template.
    needs_tied_embed = (
        pp_on
        and engine.pp_has_last_stage
        and getattr(adapter, "enable_weight_tying", False)
        and "output.weight" in template
    )
    embed_key = "model.embed_tokens.weight"
    if needs_tied_embed and embed_key not in hf_template:
        # Placeholder shaped like output.weight so DCP loads into it.
        hf_template[embed_key] = torch.empty_like(template["output.weight"])

    dcp.load(
        hf_template,
        storage_reader=adapter.get_hf_storage_reader(path),
    )

    loaded = adapter.from_hf(hf_template)
    model_keys = set(template.keys())
    if pp_on:
        # Keep only the keys owned by this rank's model parts.
        loaded = {key: val for key, val in loaded.items() if key in model_keys}

    loaded_keys = set(loaded.keys())
    missing_keys = model_keys - loaded_keys
    unexpected_keys = loaded_keys - model_keys
    # rotary_emb is computed at runtime and never stored in the checkpoint.
    missing_keys -= {key for key in missing_keys if "rotary_emb" in key}

    if dist.get_rank() == 0:
        if missing_keys:
            engine.logger.warning(
                f"Unexpected missing keys in checkpoint: {missing_keys}"
            )
        if unexpected_keys:
            engine.logger.warning(
                f"Unexpected extra keys in checkpoint: {unexpected_keys}"
            )

    relaxed = StateDictOptions(strict=False)
    targets = engine.model_parts if pp_on else [engine.model]
    for target in targets:
        set_model_state_dict(
            target,
            model_state_dict=loaded,
            options=relaxed,
        )

    dist.barrier(group=engine.cpu_group)
def save_to_dcp(engine: ArchonEngine, path: str, with_optim: bool) -> None:
    """Write a DCP checkpoint of the model (and optionally optimizer) to ``path``.

    The state is wrapped in :class:`DCPState` under the top-level key
    ``"dcp"``. ``path`` is created if needed, and all ranks synchronize on
    ``engine.cpu_group`` afterwards.
    """
    _validate_model_initialized(engine)
    os.makedirs(path, exist_ok=True)
    optimizer = engine.optimizer if with_optim else None
    dcp.save({"dcp": DCPState(engine.model_parts, optimizer)}, checkpoint_id=path)
    dist.barrier(group=engine.cpu_group)
def load_from_dcp(engine: ArchonEngine, path: str, with_optim: bool) -> None:
    """Restore the model (and optionally optimizer) from a DCP checkpoint.

    ``dcp.load`` restores into the modules wrapped by :class:`DCPState`
    (stored under the key ``"dcp"``). All ranks then synchronize on
    ``engine.cpu_group``.
    """
    _validate_model_initialized(engine)
    wrapper = DCPState(engine.model_parts, engine.optimizer if with_optim else None)
    dcp.load(state_dict={"dcp": wrapper}, checkpoint_id=path)
    dist.barrier(group=engine.cpu_group)
def save_optimizer_state(engine: ArchonEngine, path: str) -> None:
    """Save this rank's optimizer state to its own shard file.

    Each rank writes ``optim_world_size_{world_size}_rank_{rank}.pt`` into
    ``path`` (created if missing), then all ranks synchronize on
    ``engine.cpu_group``.

    Raises:
        RuntimeError: If the engine has no optimizer or torch.distributed is
            not initialized.
    """
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    if engine.optimizer is None:
        raise RuntimeError("Optimizer is not initialized")
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized")
    rank = dist.get_rank()
    # torch.save does not create parent directories (save_to_dcp does the same).
    os.makedirs(path, exist_ok=True)
    shard_path = os.path.join(
        path, f"optim_world_size_{engine.world_size}_rank_{rank}.pt"
    )
    state_dict = engine.optimizer.state_dict()
    torch.save(state_dict, shard_path)
    dist.barrier(group=engine.cpu_group)
def load_optimizer_state(engine: ArchonEngine, path: str) -> None:
    """Load this rank's optimizer shard written by ``save_optimizer_state``.

    Reads ``optim_world_size_{world_size}_rank_{rank}.pt`` from ``path`` into
    ``engine.optimizer``, then all ranks synchronize on ``engine.cpu_group``.

    Raises:
        RuntimeError: If the engine has no optimizer or torch.distributed is
            not initialized.
        FileNotFoundError: If this rank's shard does not exist. Shard names
            encode the world size, so this also covers a changed world size.
    """
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    if engine.optimizer is None:
        raise RuntimeError("Optimizer is not initialized")
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized")
    rank = dist.get_rank()
    shard_path = os.path.join(
        path, f"optim_world_size_{engine.world_size}_rank_{rank}.pt"
    )
    if not os.path.exists(shard_path):
        raise FileNotFoundError(
            f"Optimizer shard not found: {shard_path} "
            f"(expected a checkpoint saved with world_size={engine.world_size})"
        )
    # SECURITY: weights_only=False unpickles arbitrary objects. Only load
    # optimizer shards from trusted checkpoint locations.
    optimizer_state_dict = torch.load(shard_path, weights_only=False)
    engine.optimizer.load_state_dict(optimizer_state_dict)
    dist.barrier(group=engine.cpu_group)