@@ -1653,6 +1653,44 @@ def override_loss_scale(self, loss_scale):
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale

+    def scaled_global_norm(self, norm_type=2):
+        assert norm_type == 2, "only L2 norm supported"
+        norm_groups = []
+        for i, group in enumerate(self.bit16_groups):
+            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+            if self.cpu_offload:
+                norm_groups.append(
+                    self.complete_grad_norm_calculation_for_cpu_offload(
+                        self.params_in_partition[i]))
+                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
+            else:
+                norm_groups.append(
+                    self.get_grad_norm_direct(self.averaged_gradients[i],
+                                              self.params_in_partition[i]))
+
+        if self.has_moe_layers:
+            self._average_expert_grad_norms(norm_groups)
+
+        # note that the get_global_norm function only supports l2 norm
+        return get_global_norm(norm_list=norm_groups)
+
+    def get_bit16_param_group(self, group_no):
+        bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
+        partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
+        return [
+            bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]
+        ]
+
+    def _optimizer_step(self, group_no):
+        original_param_groups = self.optimizer.param_groups
+        self.optimizer.param_groups = [original_param_groups[group_no]]
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+        if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
+            self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
+        else:
+            self.optimizer.step()
+        self.optimizer.param_groups = original_param_groups
+
    def step(self, closure=None):
        """
        Not supporting closure.
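The `_optimizer_step` helper above steps one parameter group at a time by temporarily replacing `optimizer.param_groups` with a single-element list and restoring it afterwards; this is what lets `step()` below interleave unscaling, the optimizer update, and the fp32-to-bit16 copy per group instead of handling all groups at once. A minimal standalone sketch of that pattern in plain PyTorch (the model and optimizer names here are illustrative only, not part of this change):

# Illustrative only: stepping one param group at a time by temporarily
# narrowing optimizer.param_groups, mirroring the _optimizer_step pattern above.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD([
    {"params": model[0].parameters(), "lr": 0.1},
    {"params": model[1].parameters(), "lr": 0.01},
])

loss = model(torch.randn(8, 4)).sum()
loss.backward()

original_param_groups = optimizer.param_groups
for group_no in range(len(original_param_groups)):
    # restrict the optimizer to a single group, step it, then restore
    optimizer.param_groups = [original_param_groups[group_no]]
    optimizer.step()
optimizer.param_groups = original_param_groups
optimizer.zero_grad()

Reassigning `param_groups` works because PyTorch optimizers keep it as a plain list attribute. In the helper above, the bit16 partition is additionally passed to `DeepSpeedCPUAdam.step` when offloading with fp16, which is why the explicit bit16 copy is skipped for that case in the offload branch below.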
@@ -1671,7 +1709,6 @@ def step(self, closure=None):
        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:
-
            if dist.get_rank() == 0:
                logger.info(
                    "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
@@ -1692,22 +1729,33 @@ def step(self, closure=None):
            self.stop_timers(timer_names)
            return

-        self.start_timers([OPTIMIZER_GRADIENTS])
-        norm_groups = []
-        single_partition_grad_groups = []
-        # skip = False
+        # Step 1:- Calculate gradient norm using fp-16 grads
+        see_memory_usage('Before norm calculation')
+        scaled_global_grad_norm = self.scaled_global_norm()
+        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
+
+        see_memory_usage('After norm before optimizer')
+        # Step 2:- run optimizer and upscaling simultaneously
        for i, group in enumerate(self.bit16_groups):
+            self.start_timers([OPTIMIZER_GRADIENTS])
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            if self.cpu_offload:
-                norm_groups.append(
-                    self.complete_grad_norm_calculation_for_cpu_offload(
-                        self.params_in_partition[i]))
                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-            else:
-                norm_groups.append(
-                    self.get_grad_norm_direct(self.averaged_gradients[i],
-                                              self.params_in_partition[i]))
+                self.unscale_and_clip_grads([single_grad_partition],
+                                            scaled_global_grad_norm)
+                self.stop_timers([OPTIMIZER_GRADIENTS])
+                self.start_timers([OPTIMIZER_STEP])
+                self._optimizer_step(i)
+
+                from deepspeed.ops.adam import DeepSpeedCPUAdam
+                if not (type(self.optimizer) == DeepSpeedCPUAdam
+                        and self.dtype == torch.half):
+                    bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+                    fp32_partition = self.single_partition_of_fp32_groups[i]
+                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)

+                self.stop_timers([OPTIMIZER_STEP])
+            else:
                # free gradients for all the parameters that are not updated by this process(ZeRO stage2)
                self.free_grad_in_param_list(self.params_not_in_partition[i])

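With the global norm computed once in Step 1, each group's flat fp32 gradient partition is unscaled and clipped against that single value inside the loop, and `self._global_grad_norm` keeps the unscaled norm (`scaled_global_grad_norm / self.loss_scale`). The arithmetic behind such an unscale-and-clip is roughly the following; this is a hedged sketch of the conventional formula, not the body of `unscale_and_clip_grads`:

# Rough sketch of unscale-and-clip (assumed semantics, not DeepSpeed's exact code):
# gradients were computed on a loss multiplied by loss_scale, so divide them by
# loss_scale, and shrink further if the unscaled global norm exceeds clip_grad.
import torch

def unscale_and_clip_(grad_tensors, scaled_global_norm, loss_scale, clip_grad=1.0):
    combined_scale = loss_scale
    if clip_grad > 0.0:
        # norm passed in is the norm of the *scaled* gradients
        clip = (scaled_global_norm / loss_scale) / clip_grad
        if clip > 1.0:
            combined_scale = clip * loss_scale
    for g in grad_tensors:
        g.mul_(1.0 / combined_scale)

grads = [torch.randn(10) * 1024.0]  # pretend these are loss-scaled gradients
scaled_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2).item()
unscale_and_clip_(grads, scaled_norm, loss_scale=1024.0)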
@@ -1732,53 +1780,22 @@ def step(self, closure=None):

                self.averaged_gradients[i] = None

-            single_partition_grad_groups.append(single_grad_partition)
-
-        if self.has_moe_layers:
-            self._average_expert_grad_norms(norm_groups)
-
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
-        self.unscale_and_clip_grads(single_partition_grad_groups,
-                                    scaled_global_grad_norm)
-
-        # Stash unscaled gradient norm
-        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
-
-        self.stop_timers([OPTIMIZER_GRADIENTS])
-
-        self.start_timers([OPTIMIZER_STEP])
-        if self.deepspeed_adam_offload:
-            from deepspeed.ops.adam import DeepSpeedCPUAdam
-            if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
-                bit16_param_groups = [
-                    [
-                        bit16_partitions[dist.get_rank(
-                            group=self.real_dp_process_group[group_id])]
-                    ] for group_id,
-                    bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
-                ]
-                self.optimizer.step(fp16_param_groups=bit16_param_groups)
-            else:
-                self.optimizer.step()
-                for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                    partition_id = dist.get_rank(
-                        group=self.real_dp_process_group[group_id])
-
-                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
-        else:
-            self.optimizer.step()
-
-            # get rid of the fp32 gradients. Not needed anymore
-            if not self.cpu_offload:
-                for group in self.single_partition_of_fp32_groups:
-                    group.grad = None  # in step
-
-            for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
+                self.unscale_and_clip_grads([single_grad_partition],
+                                            scaled_global_grad_norm)
+                self.stop_timers([OPTIMIZER_GRADIENTS])
+
+                # Step 3:- run the optimizer if no offloading
+                self.start_timers([OPTIMIZER_STEP])
+                self._optimizer_step(i)
+                # Step 4:- get rid of the fp32 gradients. Not needed anymore
+                self.single_partition_of_fp32_groups[i].grad = None
+                del single_grad_partition
+                bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+                fp32_partition = self.single_partition_of_fp32_groups[i]
                bit16_partitions[partition_id].data.copy_(fp32_partition.data)
+                self.stop_timers([OPTIMIZER_STEP])

-        self.stop_timers([OPTIMIZER_STEP])
-
+        see_memory_usage('After optimizer before all-gather')
        if self.cpu_offload:
            self.reset_cpu_buffers()

@@ -1794,7 +1811,7 @@ def step(self, closure=None):
        self.stop_timers([OPTIMIZER_ALLGATHER])

        # TODO: we probably don't need this? just to be safe
-        for i in range(len(norm_groups)):
+        for i in range(len(self.bit16_groups)):
            self._update_model_bit16_weights(i)

        self.log_timers(timer_names)
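`scaled_global_norm()` collects one L2 norm per group (from the CPU-offload buffers or the averaged fp16 gradients) and reduces them with `get_global_norm`, which for `norm_type=2` amounts to the square root of the sum of squared per-group norms. A minimal sketch of that reduction, assumed equivalent rather than copied from DeepSpeed:

# Assumed behaviour of combining per-group L2 norms into one global L2 norm
# (what get_global_norm(norm_list=...) amounts to for norm_type=2).
import math

def global_l2_norm(norm_list):
    # sqrt of the sum of squared per-group norms == L2 norm over all groups
    return math.sqrt(sum(float(n) ** 2 for n in norm_list))

print(global_l2_norm([3.0, 4.0]))  # 5.0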