
Commit c1af73f

Improving memory utilization of Z2+MoE (#2079)
* Shard expert parameter groups
* Do upscaling, optimizer step, and deletion of fp32 grads one-by-one on each parameter group in ZeRO stage 2

Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent b052378 commit c1af73f

File tree

2 files changed: +105 -63


deepspeed/moe/utils.py

+30 -5
@@ -59,8 +59,9 @@ def split_params_grads_into_shared_and_expert_params(
     return shared_grads, expert_grads


-def split_params_into_different_moe_groups_for_optimizer(
-        param_groups: Tuple[Dict]) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
+                                                         max_group_size=178956971
+                                                         ) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer

     Args:
@@ -112,8 +113,32 @@ def split_params_into_different_moe_groups_for_optimizer(
         param_group['params'] = new_params

     # Flatten the moe groups
-    for k, v in group_moe.items():
-        for k1, v1 in v.items():
-            param_groups.append(v1)
+    if max_group_size is not None:
+        for k, v in group_moe.items():
+            for k1, v1 in v.items():
+                cur_group = []
+                all_groups = []
+                size_of_cur_group = 0
+                for param in v1['params']:
+                    if size_of_cur_group + param.numel() <= max_group_size:
+                        cur_group.append(param)
+                        size_of_cur_group += param.numel()
+                    else:
+                        all_groups.append(cur_group)
+                        cur_group = [param]
+                        size_of_cur_group = param.numel()
+                if cur_group:
+                    all_groups.append(cur_group)
+                for group in all_groups:
+                    new_dict = {}
+                    for key, val in v1.items():
+                        if key != 'params':
+                            new_dict[key] = val
+                    new_dict['params'] = group
+                    param_groups.append(new_dict)
+    else:
+        for k, v in group_moe.items():
+            for k1, v1 in v.items():
+                param_groups.append(v1)

     return tuple(param_groups)
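
To see what the new `max_group_size` path does, here is a minimal standalone sketch of the same greedy packing; the helper name `shard_param_group` and the toy sizes are illustrative, not part of this commit. Parameters are packed in their original order into sub-groups whose total `numel()` stays at or below `max_group_size`, and each sub-group inherits the original group's non-`params` keys, so the optimizer sees several smaller expert groups instead of one large one.

import torch

def shard_param_group(param_group, max_group_size):
    """Greedily pack params into sub-groups whose total numel() stays <= max_group_size."""
    cur_group, all_groups, size_of_cur_group = [], [], 0
    for param in param_group['params']:
        if size_of_cur_group + param.numel() <= max_group_size:
            cur_group.append(param)
            size_of_cur_group += param.numel()
        else:
            all_groups.append(cur_group)
            cur_group = [param]
            size_of_cur_group = param.numel()
    if cur_group:
        all_groups.append(cur_group)
    # Every shard inherits the original group's non-'params' keys (lr, name, ...).
    return [dict({k: v for k, v in param_group.items() if k != 'params'}, params=g)
            for g in all_groups]

# Toy expert group: four parameters of 40, 20, 50, and 10 elements, capped at 64 per shard.
group = {'name': 'expert_group', 'lr': 1e-3,
         'params': [torch.nn.Parameter(torch.zeros(n)) for n in (40, 20, 50, 10)]}
shards = shard_param_group(group, max_group_size=64)
print([sum(p.numel() for p in s['params']) for s in shards])  # [60, 60]

A parameter larger than `max_group_size` still ends up in a sub-group of its own; the cap only limits how many parameters are packed together, it does not split individual tensors.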

deepspeed/runtime/zero/stage_1_and_2.py

+75 -58
@@ -1653,6 +1653,44 @@ def override_loss_scale(self, loss_scale):
         self.custom_loss_scaler = True
         self.external_loss_scale = loss_scale

+    def scaled_global_norm(self, norm_type=2):
+        assert norm_type == 2, "only L2 norm supported"
+        norm_groups = []
+        for i, group in enumerate(self.bit16_groups):
+            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+            if self.cpu_offload:
+                norm_groups.append(
+                    self.complete_grad_norm_calculation_for_cpu_offload(
+                        self.params_in_partition[i]))
+                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
+            else:
+                norm_groups.append(
+                    self.get_grad_norm_direct(self.averaged_gradients[i],
+                                              self.params_in_partition[i]))
+
+        if self.has_moe_layers:
+            self._average_expert_grad_norms(norm_groups)
+
+        # note that the get_global_norm function only supports l2 norm
+        return get_global_norm(norm_list=norm_groups)
+
+    def get_bit16_param_group(self, group_no):
+        bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
+        partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
+        return [
+            bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]
+        ]
+
+    def _optimizer_step(self, group_no):
+        original_param_groups = self.optimizer.param_groups
+        self.optimizer.param_groups = [original_param_groups[group_no]]
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+        if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
+            self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
+        else:
+            self.optimizer.step()
+        self.optimizer.param_groups = original_param_groups
+
     def step(self, closure=None):
         """
         Not supporting closure.
@@ -1671,7 +1709,6 @@ def step(self, closure=None):
         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
         if self.overflow:
-
             if dist.get_rank() == 0:
                 logger.info(
                     "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
@@ -1692,22 +1729,33 @@ def step(self, closure=None):
             self.stop_timers(timer_names)
             return

-        self.start_timers([OPTIMIZER_GRADIENTS])
-        norm_groups = []
-        single_partition_grad_groups = []
-        # skip = False
+        # Step 1:- Calculate gradient norm using fp-16 grads
+        see_memory_usage('Before norm calculation')
+        scaled_global_grad_norm = self.scaled_global_norm()
+        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
+
+        see_memory_usage('After norm before optimizer')
+        # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
+            self.start_timers([OPTIMIZER_GRADIENTS])
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if self.cpu_offload:
-                norm_groups.append(
-                    self.complete_grad_norm_calculation_for_cpu_offload(
-                        self.params_in_partition[i]))
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-            else:
-                norm_groups.append(
-                    self.get_grad_norm_direct(self.averaged_gradients[i],
-                                              self.params_in_partition[i]))
+                self.unscale_and_clip_grads([single_grad_partition],
+                                            scaled_global_grad_norm)
+                self.stop_timers([OPTIMIZER_GRADIENTS])
+                self.start_timers([OPTIMIZER_STEP])
+                self._optimizer_step(i)
+
+                from deepspeed.ops.adam import DeepSpeedCPUAdam
+                if not (type(self.optimizer) == DeepSpeedCPUAdam
+                        and self.dtype == torch.half):
+                    bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+                    fp32_partition = self.single_partition_of_fp32_groups[i]
+                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)

+                self.stop_timers([OPTIMIZER_STEP])
+            else:
                 # free gradients for all the parameters that are not updated by this process(ZeRO stage2)
                 self.free_grad_in_param_list(self.params_not_in_partition[i])

@@ -1732,53 +1780,22 @@ def step(self, closure=None):

                 self.averaged_gradients[i] = None

-            single_partition_grad_groups.append(single_grad_partition)
-
-        if self.has_moe_layers:
-            self._average_expert_grad_norms(norm_groups)
-
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
-        self.unscale_and_clip_grads(single_partition_grad_groups,
-                                    scaled_global_grad_norm)
-
-        # Stash unscaled gradient norm
-        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
-
-        self.stop_timers([OPTIMIZER_GRADIENTS])
-
-        self.start_timers([OPTIMIZER_STEP])
-        if self.deepspeed_adam_offload:
-            from deepspeed.ops.adam import DeepSpeedCPUAdam
-            if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
-                bit16_param_groups = [
-                    [
-                        bit16_partitions[dist.get_rank(
-                            group=self.real_dp_process_group[group_id])]
-                    ] for group_id,
-                    bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
-                ]
-                self.optimizer.step(fp16_param_groups=bit16_param_groups)
-            else:
-                self.optimizer.step()
-                for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                    partition_id = dist.get_rank(
-                        group=self.real_dp_process_group[group_id])
-
-                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
-        else:
-            self.optimizer.step()
-
-            # get rid of the fp32 gradients. Not needed anymore
-            if not self.cpu_offload:
-                for group in self.single_partition_of_fp32_groups:
-                    group.grad = None  # in step
-
-            for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
+                self.unscale_and_clip_grads([single_grad_partition],
+                                            scaled_global_grad_norm)
+                self.stop_timers([OPTIMIZER_GRADIENTS])
+
+                # Step 3:- run the optimizer if no offloading
+                self.start_timers([OPTIMIZER_STEP])
+                self._optimizer_step(i)
+                # Step 4:- get rid of the fp32 gradients. Not needed anymore
+                self.single_partition_of_fp32_groups[i].grad = None
+                del single_grad_partition
+                bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+                fp32_partition = self.single_partition_of_fp32_groups[i]
                 bit16_partitions[partition_id].data.copy_(fp32_partition.data)
+                self.stop_timers([OPTIMIZER_STEP])

-        self.stop_timers([OPTIMIZER_STEP])
-
+        see_memory_usage('After optimizer before all-gather')
         if self.cpu_offload:
             self.reset_cpu_buffers()

@@ -1794,7 +1811,7 @@ def step(self, closure=None):
         self.stop_timers([OPTIMIZER_ALLGATHER])

         # TODO: we probably don't need this? just to be safe
-        for i in range(len(norm_groups)):
+        for i in range(len(self.bit16_groups)):
            self._update_model_bit16_weights(i)

        self.log_timers(timer_names)
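
The memory win in `step()` comes from `_optimizer_step` temporarily narrowing `optimizer.param_groups` to a single group, so each group is unscaled, stepped, and has its fp32 gradients freed before the next one is touched. The snippet below is a minimal sketch of that pattern with a plain PyTorch optimizer; the toy model and group split are illustrative, not DeepSpeed code.

import torch

# Two parameter groups standing in for shared vs. expert parameters.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
optimizer = torch.optim.Adam([
    {'params': list(model[0].parameters()), 'lr': 1e-3},
    {'params': list(model[1].parameters()), 'lr': 1e-4},
])

model(torch.randn(2, 8)).sum().backward()

# Step one group at a time: restrict the optimizer to a single group,
# step it, then free that group's gradients before moving on.
original_param_groups = optimizer.param_groups
for group_no in range(len(original_param_groups)):
    optimizer.param_groups = [original_param_groups[group_no]]
    optimizer.step()
    for p in original_param_groups[group_no]['params']:
        p.grad = None  # drop this group's grads immediately
optimizer.param_groups = original_param_groups

Because each group's fp32 gradients are dropped right after its own step, peak gradient memory is bounded by the largest group rather than the sum of all groups, which is also why the first file caps expert groups at `max_group_size` elements.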
