from colossalai.interface import ModelWrapper
from colossalai.utils import get_non_persistent_buffers_set
+ from colossalai.shardformer.layer.parallel_module import ParallelModule
+ from contextlib import contextmanager

from .index_file import CheckpointIndexFile
from .utils import (
MODEL_META_PREFIX = "pytorch_model-meta-dist-"
MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
SHARD_META_SUFFIX = ".index.json"
+ UNSHARD_META_SUFFIX = ".json"


- def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
-     destination = dict()
-     # Save parameters.
-     for name, param in model.named_parameters():
-         if param is None:
-             continue
-         destination[prefix + name] = param
-     # Save buffers.
-     non_persist_buffers_set = get_non_persistent_buffers_set(model)
-     for name, buf in model.named_buffers():
-         if buf is not None and name not in non_persist_buffers_set:
-             buffer = buf if keep_vars else buf.detach()
-             destination[prefix + name] = buffer
-
-     # Save extra states.
-     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-     if (
-         getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-         is not torch.nn.Module.get_extra_state
-     ):
-         extra_state = model.get_extra_state()
-         destination[extra_state_key] = extra_state
-     return destination
-
-
- def load_state_dict_into_dist_model(
-     model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
- ):
-     destination = dict()
-     # Save parameters.
-     for name, param in model.named_parameters():
-         if param is None:
-             continue
-         with torch.no_grad():
-             param.copy_(state_dict[prefix + name])
-     # Save buffers.
-     non_persist_buffers_set = get_non_persistent_buffers_set(model)
-     for name, buf in model.named_buffers():
-         if buf is not None and name not in non_persist_buffers_set:
-             with torch.no_grad():
-                 buf.copy_(state_dict[prefix + name])
-
-     # Save extra states.
-     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-     if (
-         getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-         is not torch.nn.Module.get_extra_state
-     ):
-         extra_state = model.get_extra_state()
-         with torch.no_grad():
-             extra_state.copy_(state_dict[extra_state_key])
-     return destination
+ @contextmanager
+ def RestoreDefaultStateDictBehavior(model):
+     original_methods = {}
+     for name, module in model.named_modules():
+         if isinstance(module, ParallelModule):
+             original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
+             module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module)
+             module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module)
+     try:
+         yield model
+     finally:
+         for module, original_method in original_methods.items():
+             module._save_to_state_dict, module._load_from_state_dict = original_method
+
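# Usage sketch (illustrative, assuming `model` is a ModelWrapper around a Shardformer-sharded
# module): the context manager swaps ParallelModule's _save_to_state_dict /
# _load_from_state_dict overrides for the stock nn.Module implementations, so
# state_dict()/load_state_dict() operate on each rank's local tensors as-is, and the
# try/finally restores the overrides even if serialization raises.
#
#     with RestoreDefaultStateDictBehavior(model.unwrap()):
#         local_state = model.unwrap().state_dict()   # plain per-rank state dict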


def create_model_metadata(
-     model: nn.Module,
+     model: ModelWrapper,
    prefix: str = "",
-     tp_size=None,
-     tp_rank=None,
+     tp_size: int = None,
+     tp_rank: int = None,
+     zero_size: int = None,
+     zero_rank: int = None,
):
    param_origin_shape = model.param_origin_shape
    model = model.unwrap()
@@ -105,7 +72,7 @@ def create_model_metadata(
        tp_partition_dim = search_tp_partition_dim(
            current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
        )
-         model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+         model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
        model_metadata[prefix + name]["lengths"] = list(param.shape)
        model_metadata[prefix + name]["global_shape"] = list(original_shape)
        if tp_partition_dim is not None:
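# Illustrative shape of one resulting metadata entry (hypothetical parameter name and sizes,
# shown for tp_rank 0 with tp_size 4 splitting dim 1): with the change above, "offsets" is a
# plain Python list instead of a torch tensor, keeping the entry serializable for the .json
# metadata files written by save_metadata.
#
#     model_metadata["encoder.linear.weight"] = {
#         "offsets": [0, 0],              # per-dim start of this rank's shard
#         "lengths": [1024, 2048],        # local (tp-sharded) shape
#         "global_shape": [1024, 8192],   # original full shape
#     }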
@@ -257,119 +224,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file):
    return False


- def dist_model_sharder(
-     model: nn.Module,
-     prefix: str = "",
-     keep_vars: bool = False,
-     size_per_shard: int = 1024,
-     pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
- ) -> Iterator[Tuple[OrderedDict, int]]:
-     # An internel method that breaks state_dict of model into shards within limited size.
-
-     state_dict_sharder = StateDictSharder(size_per_shard)
-
-     # Save parameters.
-     for name, param in model.named_parameters():
-         if param is None:
-             continue
-         if pinned_state_dicts is not None:
-             if (prefix + name) not in pinned_state_dicts:
-                 pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu")
-             pinned_state_dicts[prefix + name].copy_(param)
-             param = pinned_state_dicts[prefix + name]
-         block, block_size = state_dict_sharder.append_param(prefix + name, param)
-         if block is not None:
-             yield block, block_size
-
-     # Save buffers.
-     non_persist_buffers_set = get_non_persistent_buffers_set(model)
-     for name, buf in model.named_buffers():
-         if buf is not None and name not in non_persist_buffers_set:
-             buffer = buf if keep_vars else buf.detach()
-             if pinned_state_dicts is not None:
-                 if (prefix + name) not in pinned_state_dicts:
-                     pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
-                 pinned_state_dicts[prefix + name].copy_(buffer)
-                 buffer = pinned_state_dicts[prefix + name]
-             block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
-             if block is not None:
-                 yield block, block_size
-
-     # Save extra states.
-     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-     if (
-         getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-         is not torch.nn.Module.get_extra_state
-     ):
-         extra_state = model.get_extra_state()
-         if pinned_state_dicts is not None:
-             if extra_state_key not in pinned_state_dicts:
-                 pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
-             pinned_state_dicts[extra_state_key].copy_(extra_state)
-             extra_state = pinned_state_dicts[extra_state_key]
-         block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
-         if block is not None:
-             yield block, block_size
-
-     # Return the last block in sharder.
-     yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
-
-
- def save_dist_unshard_model(
-     model: ModelWrapper,
-     model_metadata: Dict,
-     checkpoint: str,
-     use_safetensors: bool,
-     use_async: bool = False,
-     dist_id=0,
-     pinned_state_dicts=None,
- ):
-     """
-     Save model state dict to a single file with given checkpointing path.
-
-     Args:
-         model (nn.Module): Model on local device to be saved.
-         checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
-         gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
-         use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-         use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-     """
-
-     model = model.unwrap()
-
-     # The logic of collecting parameter shards along tp degree
-     # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
-     state_dict = dist_model_state_dict(model)
-
-     Path(checkpoint).mkdir(parents=True, exist_ok=True)
-     file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
-     if use_async:
-         file_name = file_name.replace(".bin", ".safetensors")
-     checkpoint_file = os.path.join(checkpoint, file_name)
-     metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
-     save_metadata(model_metadata, metadata_file, file_name)
-
-     if use_async:
-         from colossalai.utils.safetensors import save
-
-         if id(model) not in pinned_state_dicts:
-             pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-         for name, param in state_dict.items():
-             pinned_state_dicts[id(model)][name].copy_(param)
-             state_dict[name] = pinned_state_dicts[id(model)][name]
-         writer = save(path=checkpoint_file, state_dict=state_dict)
-         return writer
-     else:
-         save_state_dict(state_dict, checkpoint_file, use_safetensors)
-         return None
-
-
def load_dist_model(
-     model: ModelWrapper,
    model_metadata: Dict,
    checkpoint: str,
-     low_cpu_mem_mode: bool = True,
-     num_threads: int = 1,
):
    """
    Load model from a single file with the given path of checkpoint.
@@ -380,10 +237,6 @@ def load_dist_model(
        strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                                 This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
    """
-
-     model_before_wrapping = model
-     model = model.unwrap()
-
    metadata_loaded = load_metadata(checkpoint)

    load_files = {}
@@ -420,92 +273,14 @@ def load_dist_model(
        )
        state_dict[key] = state

-     if not low_cpu_mem_mode:
-         state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-
-     load_state_dict_into_dist_model(model=model, state_dict=state_dict)
-
-     # Update master params if mixed-precision training is enabled.
-     model_before_wrapping.update_master_params()
-
+     return state_dict

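# Caller-side sketch (hypothetical wiring; how the checkpoint IO consumes the return value is
# not shown in this diff): load_dist_model now only rebuilds the state dict from the dist
# checkpoint files, so applying it to the wrapped model and refreshing mixed-precision master
# params is left to the caller, e.g.:
#
#     state_dict = load_dist_model(model_metadata=model_metadata, checkpoint=checkpoint)
#     with RestoreDefaultStateDictBehavior(model.unwrap()):
#         model.unwrap().load_state_dict(state_dict, strict=False)
#     model.update_master_params()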
- def save_dist_sharded_model(
-     model: ModelWrapper,
-     model_metadata: Dict,
-     checkpoint: str,
-     prefix: Optional[str] = None,
-     size_per_shard: int = 1024,
-     use_safetensors: bool = False,
-     use_async: bool = False,
-     dist_id: int = 0,
-     pinned_state_dicts=None,
- ) -> None:
-     """
-     Save sharded model checkpoint under the given checkpointing path.
-     The following files will be created under the path:
-     - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
-     - Multiple files that store state tensors of models.
-       If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
-       If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
-
-
-     Args:
-         model (nn.Module): Model on local device to be saved.
-         checkpoint (str): Checkpointing path which should be a directory path.
-         gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
-         prefix (str, optional): Perfix of file to save. Defaults to None.
-         size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
-         use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-         use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-     """
-
-     model = model.unwrap()
-
-     if os.path.isfile(checkpoint):
-         logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
-         return
-
-     Path(checkpoint).mkdir(parents=True, exist_ok=True)
-     # Devices along the same dp_group share the same copies of model.
-     # So only let the device with dp_rank == 0 and sp_rank == 0 save the model.
-
-     if use_async:
-         if id(model) not in pinned_state_dicts:
-             pinned_state_dicts[id(model)] = {}
-         pinned_state_dicts = pinned_state_dicts[id(model)]
-     else:
-         pinned_state_dicts = None
-     state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
-     weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
-     index_file = CheckpointIndexFile(checkpoint)
-
-     # Manage filenames of sharded weights and index file for each pipeline stage.
+ def get_dist_files_name(weights_name, dist_id):
    weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
    weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
-     metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-     async_writers = []
-     if use_async:
-         total_size, writers = async_save_state_dict_shards(
-             sharded_state_dict=state_dict_shard,
-             checkpoint=checkpoint,
-             index_file=index_file,
-             base_filename=weights_name,
-             is_master=True,
-             state_preprocess=False,
-         )
-         async_writers.extend(writers)
-     else:
-         total_size = save_state_dict_shards(
-             sharded_state_dict=state_dict_shard,
-             checkpoint=checkpoint,
-             index_file=index_file,
-             base_filename=weights_name,
-             is_master=True,
-             use_safetensors=use_safetensors,
-             use_pp_format=True,
-         )
-     for k, _ in model_metadata.items():
-         model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
+     return weights_name

-     save_metadata(model_metadata, metadata_file, total_size=total_size)
-     return async_writers
+ def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
+     if use_safetensors:
+         return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
+     return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
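# Filename sketch for the two helpers above (illustrative dist_id and directory name):
#
#     get_dist_files_name("pytorch_model.bin", dist_id=3)
#     # -> "pytorch_model-dist-00003-shard.bin"
#     get_dist_meta_file_name("ckpt_dir", dist_id=3, use_safetensors=True)
#     # -> "ckpt_dir/pytorch_model-meta-dist-00003.index.json"
#     get_dist_meta_file_name("ckpt_dir", dist_id=3, use_safetensors=False)
#     # -> "ckpt_dir/pytorch_model-meta-dist-00003.json"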