Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

layernorm_decay_fix #35927

Merged
merged 9 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/source/en/perf_train_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,7 @@ from transformers.trainer_pt_utils import get_parameter_names

training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)

decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
Expand Down
3 changes: 1 addition & 2 deletions docs/source/ja/perf_train_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ from transformers.trainer_pt_utils import get_parameter_names

training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)

decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,7 @@ def compute_metrics(pred):
# Instantiate custom data collator
data_collator = DataCollatorCTCWithPadding(processor=processor)

decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,12 +1178,12 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
def get_decay_parameter_names(self, model) -> List[str]:
"""
Get all parameter names that weight decay will be applied to

Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
apply to those modules since this function only filter out instance of nn.LayerNorm

This function filters out parameters in two ways:
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
2. By parameter name patterns (containing 'bias', 'layernorm', or 'rmsnorm')
"""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS, ["bias", "layernorm", "rmsnorm"])
return decay_parameters

def create_optimizer(self):
Expand Down
18 changes: 14 additions & 4 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,19 +1120,29 @@ def numel(p):
return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)


def get_parameter_names(model, forbidden_layer_types):
def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None):
"""
Returns the names of the model parameters that are not inside a forbidden layer.
"""
if forbidden_layer_names is None:
forbidden_layer_names = []

result = []
for name, child in model.named_children():
child_params = get_parameter_names(child, forbidden_layer_types, forbidden_layer_names)
result += [
f"{name}.{n}"
for n in get_parameter_names(child, forbidden_layer_types)
for n in child_params
if not isinstance(child, tuple(forbidden_layer_types))
and not any(forbidden in n.lower() for forbidden in forbidden_layer_names)
]
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
result += list(model._parameters.keys())

# Add model specific parameters that are not in any child
result += [
k for k in model._parameters.keys()
if not any(forbidden in k.lower() for forbidden in forbidden_layer_names)
]

return result


Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_get_parameter_names(self):
model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
# fmt: off
self.assertEqual(
get_parameter_names(model, [nn.LayerNorm]),
get_parameter_names(model, [nn.LayerNorm], ["layernorm", "rmsnorm"]),
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
)
# fmt: on
Expand Down