Skip to content

Commit

Permalink
layernorm_decay_fix (huggingface#35927)
Browse files Browse the repository at this point in the history
* layernorm_decay_fix

* W293 fix

* ruff format fix

* black format

* ruff format

* erase last layer

* add test_get_parameter_names_rmsnorm

* rmsnorm fix
  • Loading branch information
Ryoo72 authored and elvircrn committed Feb 13, 2025
1 parent fec93af commit b1218bb
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 15 deletions.
3 changes: 1 addition & 2 deletions docs/source/en/perf_train_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,7 @@ from transformers.trainer_pt_utils import get_parameter_names

training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)

decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
Expand Down
3 changes: 1 addition & 2 deletions docs/source/ja/perf_train_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ from transformers.trainer_pt_utils import get_parameter_names

training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)

decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,7 @@ def compute_metrics(pred):
# Instantiate custom data collator
data_collator = DataCollatorCTCWithPadding(processor=processor)

decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,13 +1177,13 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):

def get_decay_parameter_names(self, model) -> List[str]:
"""
Get all parameter names that weight decay will be applied to
Get all parameter names that weight decay will be applied to.
Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
apply to those modules since this function only filter out instance of nn.LayerNorm
This function filters out parameters in two ways:
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
2. By parameter name patterns (containing 'bias', 'layernorm', or 'rmsnorm')
"""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS, ["bias", "layernorm", "rmsnorm"])
return decay_parameters

def create_optimizer(self):
Expand Down
14 changes: 10 additions & 4 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,19 +1120,25 @@ def numel(p):
return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)


def get_parameter_names(model, forbidden_layer_types):
def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None):
"""
Returns the names of the model parameters that are not inside a forbidden layer.
"""
if forbidden_layer_names is None:
forbidden_layer_names = []
result = []
for name, child in model.named_children():
child_params = get_parameter_names(child, forbidden_layer_types, forbidden_layer_names)
result += [
f"{name}.{n}"
for n in get_parameter_names(child, forbidden_layer_types)
for n in child_params
if not isinstance(child, tuple(forbidden_layer_types))
and not any(forbidden in f"{name}.{n}".lower() for forbidden in forbidden_layer_names)
]
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
result += list(model._parameters.keys())
# Add model specific parameters that are not in any child
result += [
k for k in model._parameters.keys() if not any(forbidden in k.lower() for forbidden in forbidden_layer_names)
]
return result


Expand Down
27 changes: 27 additions & 0 deletions tests/trainer/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,33 @@ def test_get_parameter_names(self):
)
# fmt: on

def test_get_parameter_names_rmsnorm(self):
class RMSNorm(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))

class ModelWithRMSNorm(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(128, 128)
self.rmsnorm = RMSNorm(128)
self.bias = nn.Parameter(torch.zeros(128))

model = ModelWithRMSNorm()
# Test both type-based and name-based filtering
decay_parameters = get_parameter_names(model, [], ["bias", "rmsnorm"])

# Parameters that should be in weight decay
self.assertIn("linear.weight", decay_parameters)

# Parameters that should NOT be in weight decay
self.assertNotIn("linear.bias", decay_parameters)
self.assertNotIn("rmsnorm.weight", decay_parameters)
self.assertNotIn("rmsnorm.bias", decay_parameters)
self.assertNotIn("bias", decay_parameters)

def test_distributed_sampler_with_loop(self):
batch_size = 16
for length in [23, 64, 123]:
Expand Down

0 comments on commit b1218bb

Please sign in to comment.