 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


-def move(tensor, device):
+def move(tensor, device, copy=True):
     if tensor.is_meta:
         return torch.empty_like(tensor, device=device)
     else:
         # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
         # Using copy=True instead of clone() will help in case of cpu --> cpu.
         # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-        return tensor.to(device, copy=True)
+        return tensor.to(device, copy=copy)


 class ReplaceWithTensorSlicing:
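
Background for the new copy flag (illustrative only, not part of the patch): torch.Tensor.to(device, copy=True) always materializes a new tensor, even for a cpu --> cpu move, so a shard produced by split() stops aliasing the full parent tensor and the parent can be freed; with copy=False and a matching device/dtype, to() simply returns the original view. A minimal standalone sketch (the tensor names here are made up):

import torch

full = torch.randn(4, 1024)
shard = full.split(2, dim=0)[0]        # a view into `full`

aliased = shard.to('cpu', copy=False)  # same storage; keeps `full` referenced
fresh = shard.to('cpu', copy=True)     # new storage; `full` can be released

assert aliased.data_ptr() == shard.data_ptr()
assert fresh.data_ptr() != shard.data_ptr()
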
@@ -189,7 +189,14 @@ def load(module, state_dict, prefix, mp_group=None):

 class AutoTP():

-    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
+    def __init__(self,
+                 module,
+                 all_reduce_linears,
+                 prefix,
+                 state_dict,
+                 linear_layer_setting,
+                 orig_layer_impl,
+                 keep_module_on_host=False):
         self.module = module
         self.all_reduce_linears = all_reduce_linears
         self.prefix = prefix
@@ -201,6 +208,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_
         self.orig_layer_impl = orig_layer_impl
         self.linear_policies = None
         self.conv_linear_layer = False
+        self.keep_module_on_host = keep_module_on_host

     def in_module_list(module, module_list):
         for item in module_list:
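
For illustration, a hypothetical direct construction showing where the new keep_module_on_host flag enters. In practice AutoTP is created inside DeepSpeed's module-replacement machinery rather than by user code, the placeholder arguments below are not meaningful values, and the import path assumes the class remains in deepspeed.module_inject.auto_tp:

import torch
from deepspeed.module_inject.auto_tp import AutoTP

model = torch.nn.Linear(16, 16)             # placeholder module for this sketch
autotp = AutoTP(model,
                all_reduce_linears=[],      # placeholder
                prefix="",
                state_dict=None,
                linear_layer_setting=None,
                orig_layer_impl=None,
                keep_module_on_host=True)   # sharded weights will stay on the CPU
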
@@ -331,6 +339,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group):
     def _replace(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
+        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
+        # keep_module_on_host keeps the module on the host. Checkpoints are loaded to the host first (in some
+        # cases they can even be streamed from disk to avoid filling host memory), so there is no need to create a new copy.
+        return_new_copy = not self.keep_module_on_host
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
@@ -368,18 +380,17 @@ def _replace(self, child, name, conv_linear_layer):
             data = child.weight.data.split(get_shard_size_list(
                 weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
-            data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+            data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
             del data

             setattr(child, "replaced", True)
             if name == "lm_head" or name == 'embed_out':
                 return LmHeadLinearAllreduce(
                     torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
                     child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                        move(child.bias,
-                             get_accelerator().current_device_name())), self.mp_group)
+                        move(child.bias, device_name, return_new_copy)), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
+                        torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
         else:

             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -392,22 +403,22 @@ def _replace(self, child, name, conv_linear_layer):
                 #The copy is a regular copy, The shape of dst and src is the same
                 data_dc = move(
                     prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    device_name, return_new_copy)

                 bias_data_dc = None if child.bias is None else move(
                     prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    device_name, return_new_copy)
             else:
                 data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+                data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
                 del data

                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
                         weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
+                    bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                     del bias_data
                 else:
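
To summarize the unfused branch above as a standalone sketch (plain integers stand in for get_shard_size_list, 'cpu' with copy=True stands in for the default device_name / return_new_copy pair, and mp_size, rank, and weight are placeholder names):

import torch

mp_size, rank = 2, 0
weight = torch.randn(8, 16)                           # full (out_features, in_features) weight on the host
shard_sizes = [weight.shape[0] // mp_size] * mp_size  # stand-in for get_shard_size_list(weight_shape[0], mp_size, name)

data = weight.split(shard_sizes, dim=0)               # row shards, one view per rank
data_dc = data[rank].to('cpu', copy=True).detach()    # move(..., device_name, return_new_copy) with the defaults
del data                                              # the fresh copy no longer aliases the full weight
assert data_dc.shape == (weight.shape[0] // mp_size, weight.shape[1])
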