Missing weights not initialized properly #35437 #35913
Conversation
Hi @sambhavnoobcoder, I see some failing tests in the CI! I think the cause is that in some cases, models have tied weights, meaning that the input embeddings and output projection are identical. In these cases, only one of those tensors may exist in the checkpoint.
@Rocketknight1 Thank you for pointing out the issue with tied weights. I've modified the initialization logic to detect tied parameters using …
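The exact detection mechanism is cut off in the comment above. For reference, one common way to find tied parameters in a PyTorch model is to group parameter names by their underlying storage pointer; a minimal sketch (the helper name is ours, not necessarily what the PR uses):

```python
from collections import defaultdict
from torch import nn

def find_tied_parameter_groups(model: nn.Module) -> list:
    """Group parameter names that share the same underlying storage."""
    groups = defaultdict(list)
    # remove_duplicate=False keeps every alias of a shared tensor visible
    for name, param in model.named_parameters(remove_duplicate=False):
        groups[param.data_ptr()].append(name)
    return [names for names in groups.values() if len(names) > 1]
```

For a model with tied embeddings this would return a group such as `[["embeddings.word_embeddings.weight", "lm_head.weight"]]` (illustrative names); if any member of a group was loaded from the checkpoint, the other members should not be treated as missing.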
Hi @sambhavnoobcoder, debugging issues like this can be tricky. I suggest the following approach:
This is quite advanced debugging, but it's unfortunately necessary sometimes. Good luck!
This is important IMO, but it does not need a special additional test file!
I don't have my head in this part of the code, so @Cyrilvallez, have a look if you want! 🤗
src/transformers/modeling_utils.py (outdated)
```diff
@@ -4939,6 +4939,37 @@ def _find_mismatched_keys(
                 " to use it for predictions and inference."
             )

+        # After loading weights, initialize missing ones properly
+        missing_keys = set(model.state_dict().keys()) - set(loaded_keys)
```
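For readers skimming the diff, this set difference yields exactly the parameters that still need `_init_weights()` after loading. A toy illustration with made-up key names:

```python
# Made-up keys, illustrative only: whatever exists in the model's
# state_dict but not in the checkpoint is "missing" and must be initialized.
model_keys = {"encoder.layer.0.weight", "classifier.weight", "classifier.bias"}
loaded_keys = {"encoder.layer.0.weight"}
missing_keys = model_keys - loaded_keys  # {'classifier.weight', 'classifier.bias'}
```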
Why don't we use the `missing_keys` defined above?
Okay, addressed this in commit 566028f.
As for tests in a separate file: I added the tests to test_modeling_utils.py only, in commit 54dfcb4. No changes were required to the tests after the refactor.

```python
from transformers import AutoModel, AutoConfig
import torch


def test_real_model_initialization():
    """Test initialization with a real model by adding a new classification layer"""
    # Create base model and save
    model = AutoModel.from_pretrained("bert-base-uncased")
    model.save_pretrained("./test-model")

    # Modify config to add new classification layer
    config = AutoConfig.from_pretrained("./test-model")

    class BertWithClassification(type(model)):
        def __init__(self, config):
            super().__init__(config)
            # Add a new classification layer
            self.classifier = torch.nn.Linear(config.hidden_size, 3)  # 3 classes

        def _init_weights(self, module):
            super()._init_weights(module)
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    # Load with new architecture - should initialize new layer properly
    new_model = BertWithClassification.from_pretrained("./test-model", config=config)

    # Verify no NaN values in new layer
    assert not torch.isnan(new_model.classifier.weight).any(), "NaN found in classifier weights"
    assert not torch.isnan(new_model.classifier.bias).any(), "NaN found in classifier bias"

    # Verify base model weights are preserved
    for name, param in new_model.named_parameters():
        if "classifier" not in name:  # Skip the new layer
            orig_param = model.get_parameter(name)
            assert torch.equal(param, orig_param), f"Original weights not preserved for {name}"

    print("Real model test passed successfully!")
    return new_model


if __name__ == "__main__":
    test_real_model_initialization()
```

Thanks for the reviews. I'll make any other changes required as well. @ArthurZucker @Cyrilvallez
Hey @sambhavnoobcoder! The logic of …
Hey @Cyrilvallez, …
By the way, I saw the large refactor, and the changes coming are truly awesome and much needed. However, since you have a better perspective on the upcoming changes than I do, I'd appreciate it if you could look at my PR and tell me whether I need to make any more changes to make it future-proof against them.
Hey, sorry for the late reply! We are actively working on improving the loading logic these days.
Problem Statement
When using `from_pretrained()` to load a model with new parameters that weren't in the original saved model, these new parameters were not being properly initialized according to the model's `_init_weights()` method. Instead, they remained in their default PyTorch initialization state, sometimes resulting in NaN values.
Root Cause Analysis
The issue was identified in the `_load_pretrained_model` method, where missing weights weren't being properly initialized when `_fast_init=True` (the default behavior). This caused inconsistent behavior between direct model initialization and loading via `from_pretrained()`.
Solution
Modified the `_load_pretrained_model` method to properly initialize missing weights using the model's `_init_weights()` method, regardless of the `_fast_init` setting. The solution maintains backward compatibility while ensuring consistent initialization behavior.
Implementation Details
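The body of this section was not captured. As a rough, hypothetical sketch of what the described change amounts to (the helper name and the module lookup are ours, not the PR's actual code):

```python
from torch import nn

def initialize_missing_weights(model: nn.Module, missing_keys: set) -> None:
    """Hypothetical sketch, not the actual patch: re-run the model's own
    _init_weights() on every module that owns a missing parameter, so that
    loading matches direct construction regardless of _fast_init."""
    # "classifier.weight" -> owning module name "classifier"
    owner_names = {key.rsplit(".", 1)[0] for key in missing_keys}
    for name, module in model.named_modules():
        if name in owner_names:
            # Delegate to the model's init scheme instead of leaving the
            # default (possibly uninitialized) PyTorch values in place.
            model._init_weights(module)
```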
Testing Strategy
Created a comprehensive test suite verifying:
- `_fast_init`
Test Results
Related Issues