From 6d4b725bca6f0061a5ea61d5779e38e8635417aa Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 7 Jun 2024 09:30:05 +0800 Subject: [PATCH 1/2] disable grad scaler when dtype is bf16 --- mmengine/optim/optimizer/amp_optimizer_wrapper.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 4f3323f2cc..074cd1cdf7 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -91,16 +91,21 @@ def __init__(self, else: scaler_type = GradScaler + enable_loss_scaler = dtype != torch.bfloat16 + if loss_scale == 'dynamic': # If loss_scale is a string, it must be 'dynamic', then dynamic # loss scaling will be used. - self.loss_scaler = scaler_type() + self.loss_scaler = scaler_type(enabled=enable_loss_scaler) elif isinstance(loss_scale, float): # Static loss scaling self._scale_update_param = loss_scale - self.loss_scaler = scaler_type(init_scale=loss_scale) + self.loss_scaler = scaler_type( + init_scale=loss_scale, enabled=enable_loss_scaler) elif isinstance(loss_scale, dict): # More specific configuration. + loss_scale[ + 'enabled'] = loss_scale['enabled'] and enable_loss_scaler self.loss_scaler = scaler_type(**loss_scale) else: raise TypeError('loss_scale must be of type float, dict, or ' From 3db2ee7a35ba3d306977fd430c9a22fb17afb671 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 7 Jun 2024 11:28:06 +0800 Subject: [PATCH 2/2] fix bugs --- mmengine/optim/optimizer/amp_optimizer_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 074cd1cdf7..d2c8cf1f19 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -104,8 +104,8 @@ def __init__(self, init_scale=loss_scale, enabled=enable_loss_scaler) elif isinstance(loss_scale, dict): # More specific configuration. - loss_scale[ - 'enabled'] = loss_scale['enabled'] and enable_loss_scaler + loss_scale['enabled'] = loss_scale.pop('enabled', + True) and enable_loss_scaler self.loss_scaler = scaler_type(**loss_scale) else: raise TypeError('loss_scale must be of type float, dict, or '