
Commit bc2c9fa

Merge branch 'master' into patch-1
2 parents: 07cc866 + fae714d

File tree: 5 files changed, +53 −15 lines


deepspeed/inference/config.py

Lines changed: 9 additions & 0 deletions
@@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     values for :any:`DeepSpeedMoEConfig`.
     """
 
+    keep_module_on_host: bool = False
+    """
+    When loading checkpoints to model parameters, they are moved to the device. In very large models
+    this might fill the device and cause OOM. Setting this flag to True will keep checkpoints on the
+    host and not move them directly to the device (for example, giving an option to quantize
+    checkpoint data before moving it to the device).
+    Set only for models with injection policies and auto TP.
+    """
+
     quant: QuantizationConfig = {}
     """
     NOTE: only works for int8 dtype.
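
For orientation (not part of the diff), a minimal usage sketch of the new option through the public deepspeed.init_inference entry point; the model handle and mp_size here are placeholders, mirroring how the flag is passed in the tests below:

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Placeholder model; any module covered by auto TP / injection policies applies.
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")

# keep_module_on_host=True leaves the sharded checkpoint weights in host (CPU) memory
# instead of moving them to the accelerator during weight injection.
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.float16,
                                  keep_module_on_host=True)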

deepspeed/inference/engine.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def __init__(self, model, config):
         is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
         if is_meta_device:
             self.module.to_empty(device=device)
-        else:
+        elif not config.keep_module_on_host:
             self.module.to(device)
 
         if config.tensor_parallel.tp_size > 1:
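
In plain terms, the engine's placement decision after this change can be sketched as the standalone helper below (illustrative only, not DeepSpeed API; place_module and its arguments are hypothetical names):

import torch

def place_module(module: torch.nn.Module, device, keep_module_on_host: bool):
    # Mirror of the branch above: meta-initialized modules get empty storage on the
    # target device, keep_module_on_host leaves weights in host memory, and all
    # other modules are moved to the accelerator as before.
    is_meta_device = hasattr(module, "device") and module.device.type == 'meta'
    if is_meta_device:
        module.to_empty(device=device)
    elif not keep_module_on_host:
        module.to(device)
    return module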

deepspeed/module_inject/auto_tp.py

Lines changed: 22 additions & 11 deletions
@@ -17,14 +17,14 @@
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 
 
-def move(tensor, device):
+def move(tensor, device, copy=True):
     if tensor.is_meta:
         return torch.empty_like(tensor, device=device)
     else:
         # Returning a new tensor helps free memory (e.g. after a split); this was previously done by calling clone().
         # Using copy=True instead of clone() will help in case of cpu --> cpu.
         # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-        return tensor.to(device, copy=True)
+        return tensor.to(device, copy=copy)
 
 
 class ReplaceWithTensorSlicing:

@@ -189,7 +189,14 @@ def load(module, state_dict, prefix, mp_group=None):
 
 class AutoTP():
 
-    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
+    def __init__(self,
+                 module,
+                 all_reduce_linears,
+                 prefix,
+                 state_dict,
+                 linear_layer_setting,
+                 orig_layer_impl,
+                 keep_module_on_host=False):
         self.module = module
         self.all_reduce_linears = all_reduce_linears
         self.prefix = prefix

@@ -201,6 +208,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, ...):
         self.orig_layer_impl = orig_layer_impl
         self.linear_policies = None
         self.conv_linear_layer = False
+        self.keep_module_on_host = keep_module_on_host
 
     def in_module_list(module, module_list):
         for item in module_list:

@@ -331,6 +339,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group):
     def _replace(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
+        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
+        # keep_module_on_host keeps the module on the host. Checkpoints are loaded to the host first (in some
+        # cases even from disk, to avoid filling host memory), so there is no need to create a new copy.
+        return_new_copy = not self.keep_module_on_host
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip

@@ -368,18 +380,17 @@ def _replace(self, child, name, conv_linear_layer):
             data = child.weight.data.split(get_shard_size_list(
                 weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
-            data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+            data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
             del data
 
             setattr(child, "replaced", True)
             if name == "lm_head" or name == 'embed_out':
                 return LmHeadLinearAllreduce(
                     torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
                     child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                        move(child.bias,
-                             get_accelerator().current_device_name())), self.mp_group)
+                        move(child.bias, device_name, return_new_copy)), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
+                        torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
         else:
 
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]

@@ -392,22 +403,22 @@ def _replace(self, child, name, conv_linear_layer):
                 # The copy is a regular copy; the shapes of dst and src are the same
                 data_dc = move(
                     prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    device_name, return_new_copy)
 
                 bias_data_dc = None if child.bias is None else move(
                     prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    device_name, return_new_copy)
             else:
                 data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+                data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
                 del data
 
                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
                         weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
+                    bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                     del bias_data
                 else:
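
The copy argument matters because Tensor.to() with copy=False returns the input itself when device and dtype already match, so a shard produced by split() would keep the full weight tensor alive through shared storage. With keep_module_on_host the shards stay on the host where the checkpoint was already materialized, so the extra copy is skipped (return_new_copy = not self.keep_module_on_host). A short self-contained check of this behavior, independent of DeepSpeed:

import torch

full = torch.randn(4, 8)            # stands in for a full, unsharded weight
shard = full.split(2, dim=0)[0]     # a view into `full`, as produced by the TP split

kept = shard.to('cpu', copy=False)  # same object: still shares storage with `full`
fresh = shard.to('cpu', copy=True)  # new storage: `full` can be freed afterwards

assert kept.data_ptr() == full.data_ptr()
assert fresh.data_ptr() != full.data_ptr()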

deepspeed/module_inject/replace_module.py

Lines changed: 2 additions & 1 deletion
@@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)
 
         # 1. Create AutoTP object
-        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
+        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
+                         config.keep_module_on_host)
 
         # 2. Set the tensor parallelism config
         _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

tests/unit/inference/test_inference.py

Lines changed: 19 additions & 2 deletions
@@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dtype):
 
 
 @pytest.mark.seq_inference
+@pytest.mark.parametrize('keep_module_on_host', [True, False])
 @pytest.mark.parametrize(
     "model_w_task",
     [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],

@@ -570,6 +571,7 @@ def test(
         inf_kwargs,
         assert_fn,
         dtype,
+        keep_module_on_host,
     ):
         invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
         if invalid_test_msg:

@@ -592,13 +594,20 @@ def test(
                         framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=world_size,
+                                              dtype=dtype,
+                                              keep_module_on_host=keep_module_on_host)
         ds_output = pipe(query, **inf_kwargs)
 
         print(local_rank, "baseline", bs_output)
         print(local_rank, "deepspeed", ds_output)
         assert assert_fn(bs_output, ds_output)
 
+        if keep_module_on_host:
+            for name, param in model.named_parameters():
+                assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
+
     @pytest.mark.world_size(3)
     def test_odd_world_size(
         self,

@@ -607,6 +616,7 @@ def test_odd_world_size(
         inf_kwargs,
         assert_fn,
         dtype,
+        keep_module_on_host,
     ):
         invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
         if invalid_test_msg:

@@ -624,13 +634,20 @@ def test_odd_world_size(
                         framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=world_size,
+                                              dtype=dtype,
+                                              keep_module_on_host=keep_module_on_host)
         ds_output = pipe(query, **inf_kwargs)
 
         print(local_rank, "baseline", bs_output)
         print(local_rank, "deepspeed", ds_output)
         assert assert_fn(bs_output, ds_output)
 
+        if keep_module_on_host:
+            for name, param in model.named_parameters():
+                assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
+
 
 @pytest.mark.nightly
 @pytest.mark.parametrize(
