
Missing weights not initialized properly #35437 #35913

Open · sambhavnoobcoder wants to merge 17 commits into main

Conversation

sambhavnoobcoder (Contributor)

Problem Statement

When using from_pretrained() to load a model with new parameters that weren't in the original saved model, these new parameters were not being initialized according to the model's _init_weights() method. Instead, they were left uninitialized (or with PyTorch's default layer initialization), sometimes resulting in NaN values.
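A minimal sketch of the scenario (the ExtendedBert class and the ./saved-bert path are illustrative, not part of the PR):

import torch
from transformers import BertModel

class ExtendedBert(BertModel):
    def __init__(self, config):
        super().__init__(config)
        # New parameter that does not exist in the saved checkpoint.
        self.extra_head = torch.nn.Linear(config.hidden_size, 2)
        self.post_init()

# Save a plain checkpoint, then reload it into the extended architecture.
BertModel.from_pretrained("bert-base-uncased").save_pretrained("./saved-bert")
model = ExtendedBert.from_pretrained("./saved-bert")

# extra_head is a "missing key": before the fix it could bypass _init_weights()
# and keep whatever values the fast-init path left behind.
print(torch.isnan(model.extra_head.weight).any())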

Root Cause Analysis

The issue was identified in the _load_pretrained_model method where missing weights weren't being properly initialized when _fast_init=True (default behavior). This caused inconsistent behavior between direct model initialization and loading via from_pretrained().

Solution

Modified the _load_pretrained_model method to properly initialize missing weights using the model's _init_weights() method, regardless of the _fast_init setting. The solution maintains backward compatibility while ensuring consistent initialization behavior.

Implementation Details

  • Added weight initialization for missing keys after state-dict loading (a condensed sketch follows this list)
  • Implemented proper module hierarchy traversal for initialization
  • Maintained existing logging behavior
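A condensed sketch of that idea, not the exact PR diff: it reuses the model and loaded_keys names available inside _load_pretrained_model, here faked with a fresh BERT whose pooler weights are treated as missing.

from transformers import BertConfig, BertModel

# Stand-ins for the state available inside _load_pretrained_model.
model = BertModel(BertConfig())
loaded_keys = [k for k in model.state_dict() if not k.startswith("pooler.")]

# Initialize only the modules that own a missing key, reusing the model's
# own _init_weights() rules instead of PyTorch defaults.
missing_keys = set(model.state_dict().keys()) - set(loaded_keys)
for name, module in model.named_modules():
    prefix = f"{name}." if name else ""
    owns_missing = any(
        key.startswith(prefix) and "." not in key[len(prefix):]  # direct params only
        for key in missing_keys
    )
    if owns_missing:
        model._init_weights(module)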

Testing Strategy

Created a comprehensive test suite verifying the following (a condensed sketch follows the list):

  1. Backward compatibility with existing weights
  2. Consistent initialization behavior with/without _fast_init
  3. Proper initialization of new weights
  4. Original issue reproduction case
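A condensed sketch of tests 2 and 3 (BertWithNewHead and ./test-model are illustrative names; this assumes the private _fast_init kwarg is still accepted by from_pretrained):

import torch
from transformers import BertModel

class BertWithNewHead(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.new_head = torch.nn.Linear(config.hidden_size, 2)  # missing from the checkpoint
        self.post_init()

BertModel.from_pretrained("bert-base-uncased").save_pretrained("./test-model")

# The new head must be NaN-free and properly initialized with or without _fast_init.
for fast_init in (True, False):
    loaded = BertWithNewHead.from_pretrained("./test-model", _fast_init=fast_init)
    for name, param in loaded.named_parameters():
        assert not torch.isnan(param).any(), f"NaN in {name} with _fast_init={fast_init}"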

Test Results

[Screenshot: passing test run, 2025-01-27]

Related Issues

  • #35437

@Rocketknight1 (Member)

Hi @sambhavnoobcoder, I see some failing tests in the CI! I think the cause is that in some cases, models have tied weights, meaning that the input embeddings and output projection are identical. In these cases, only one of those tensors may exist in the safetensors file. I think the problem is that this PR might overwrite / re-initialize the output weights in this case, but I'm not certain.
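A quick illustration of the tied-weight situation described above (using GPT-2, whose LM head is tied to the input embeddings by default):

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
wte = model.get_input_embeddings().weight
lm_head = model.get_output_embeddings().weight

# Both names point at the same storage, so only one tensor is written to the
# safetensors file and the other one shows up as a "missing" key on load.
print(wte.data_ptr() == lm_head.data_ptr())  # True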

@sambhavnoobcoder (Contributor, Author)

sambhavnoobcoder commented Feb 2, 2025

@Rocketknight1 Thank you for pointing out the issue with tied weights. I've modified the initialization logic to detect tied parameters using id_tensor_storage and to skip initialization if any weight in a tied group exists in loaded_keys. This should prevent reinitializing tied weights when their counterpart exists in the safetensors file. Please let me know if you think we need to handle this differently or if there are other edge cases to consider. I'd also appreciate your help in figuring out the failing CI tests, and in identifying any other changes needed from my end.
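A condensed sketch of that guard (simplified, not the exact PR code; remove_duplicate=False is needed so both names of a tied parameter are enumerated):

from collections import defaultdict
from transformers import GPT2LMHeadModel
from transformers.pytorch_utils import id_tensor_storage

# GPT-2 stands in for any tied-weight model; loaded_keys mimics a checkpoint
# that stores only the input embedding, not the tied lm_head.
model = GPT2LMHeadModel.from_pretrained("gpt2")
loaded_keys = {"transformer.wte.weight"}

tied_groups = defaultdict(set)
for name, param in model.named_parameters(remove_duplicate=False):
    tied_groups[id_tensor_storage(param)].add(name)

def is_tied_to_loaded_weight(param_name):
    # Skip re-initialization when any name sharing this storage was loaded.
    for group in tied_groups.values():
        if param_name in group:
            return any(other in loaded_keys for other in group)
    return False

print(is_tied_to_loaded_weight("lm_head.weight"))  # True -> do not re-initialize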

@Rocketknight1 (Member)

Rocketknight1 commented Feb 5, 2025

Hi @sambhavnoobcoder, debugging issues like this can be tricky. I suggest the following approach:

  1. Install your local copy of transformers with pip install -e . in the transformers directory
  2. pip install pytest parameterized and try running one of the failing tests, like this: pytest tests/models/paligemma/test_modeling_paligemma.py -k 'test_can_use_safetensors'
  3. Once you can reproduce the failure on your local machine, add a breakpoint just before the failure occurs (or wrap the failing line in a try/except that calls breakpoint())
  4. Use the breakpoint to figure out what's happening. You can also try running the same code on main and adding a breakpoint to see what's different
  5. Try to figure out how the code changes caused the test to fail, and adjust the code accordingly! For example, if the error relates to an lm_head key, you can add a breakpoint in your code like if "lm_head" in key_name: breakpoint() or something like that. This will let you figure out why the behaviour changed in your PR, and that should hopefully help you diagnose the problem!
  6. Push changes and see how the CI responds, and try new approaches until you can get the tests to pass.

This is quite advanced debugging, but it's unfortunately necessary sometimes. Good luck!
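A tiny illustration of steps 3 and 5 (flaky_step and key_name are placeholders for whatever your traceback points at):

def flaky_step(key_name):
    # Stand-in for the line the pytest traceback reports as failing.
    if "lm_head" in key_name:          # step 5: only stop on the interesting key
        breakpoint()
    raise AssertionError(f"tensor mismatch for {key_name}")

try:
    flaky_step("lm_head.weight")
except Exception:
    breakpoint()                       # step 3: inspect the live failure state in pdb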

@ArthurZucker (Collaborator) left a comment:


This is important IMO, but it does not need a special additional test file!
I don't have my head in this part of the code, so @Cyrilvallez if you want to have a look! 🤗

@@ -4939,6 +4939,37 @@ def _find_mismatched_keys(
" to use it for predictions and inference."
)

# After loading weights, initialize missing ones properly
missing_keys = set(model.state_dict().keys()) - set(loaded_keys)
Collaborator (inline review comment):

why don't we use the missing_keys defined above?

@sambhavnoobcoder (Contributor, Author):

Okay, addressed this in commit 566028f.

@sambhavnoobcoder (Contributor, Author)

sambhavnoobcoder commented Feb 13, 2025

As for the tests in a separate file: the tests are now added in the existing test_modeling_utils.py file only (no separate file), in commit 54dfcb4. No changes were required to the tests after the refactor.
Also, for anyone wanting to test the changes on a real model, feel free to use the following script:

from transformers import AutoModel, AutoConfig
import torch

def test_real_model_initialization():
    """Test initialization with a real model by adding a new classification layer"""
    
    # Create base model and save
    model = AutoModel.from_pretrained("bert-base-uncased")
    model.save_pretrained("./test-model")

    # Modify config to add new classification layer
    config = AutoConfig.from_pretrained("./test-model")
    
    class BertWithClassification(type(model)):
        def __init__(self, config):
            super().__init__(config)
            # Add a new classification layer
            self.classifier = torch.nn.Linear(config.hidden_size, 3)  # 3 classes
            
        def _init_weights(self, module):
            super()._init_weights(module)
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    # Load with new architecture - should initialize new layer properly
    new_model = BertWithClassification.from_pretrained("./test-model", config=config)

    # Verify no NaN values in new layer
    assert not torch.isnan(new_model.classifier.weight).any(), "NaN found in classifier weights"
    assert not torch.isnan(new_model.classifier.bias).any(), "NaN found in classifier bias"
    
    # Verify base model weights are preserved
    for name, param in new_model.named_parameters():
        if "classifier" not in name:  # Skip the new layer
            orig_param = model.get_parameter(name)
            assert torch.equal(param, orig_param), f"Original weights not preserved for {name}"

    print("Real model test passed successfully!")
    return new_model

if __name__ == "__main__":
    test_real_model_initialization()

Thanks for the reviews. I'll make any other changes required as well. @ArthurZucker @Cyrilvallez

@Cyrilvallez (Member)

Hey @sambhavnoobcoder! The logic of from_pretrained is currently very hard to understand. See here for the first iteration of a big refactor coming up. Currently, this part of the code is responsible for doing what you're talking about - initializing missing keys if using _fast_init. Could you please point me to the exact scenario where this does not work?

@sambhavnoobcoder (Contributor, Author)

Hey @Cyrilvallez,
Thanks for the feedback. The issue occurs when loading a pretrained model that conditionally adds new layers (for example, when a flag like use_new is enabled). In these cases, the base model's parameters load correctly, but the newly added modules aren’t present in the checkpoint. This means that with _fast_init=True, those new modules aren’t identified for initialization, leaving their weights uninitialized (often resulting in NaNs).
To address this, I removed my previous extra initialization block and instead refined the existing logic in set_initialized_submodules(). The updated function now marks a module as initialized only if all its parameters are found in the loaded state dict; otherwise, it is flagged for proper initialization. This change ensures that new or partially-loaded submodules are correctly handled by our existing fast initialization workflow.
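A simplified sketch of the refined check described above (the real set_initialized_submodules in modeling_utils also handles buffers; this follows the description rather than the exact diff):

def set_initialized_submodules(model, loaded_keys):
    loaded_keys = set(loaded_keys)
    for name, module in model.named_modules():
        prefix = f"{name}." if name else ""
        own_params = {prefix + p for p, _ in module.named_parameters(recurse=False)}
        # Only mark the module as initialized when every one of its own
        # parameters was present in the loaded state dict; otherwise the
        # fast-init path will run _init_weights() on it.
        module._is_hf_initialized = own_params.issubset(loaded_keys)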
Additionally, I updated our test suite to cover these scenarios—ensuring that:
• Newly added layers are properly initialized,
• Tied weights remain correctly identified and preserved, and
• Partial initialization cases are handled as expected.
I believe these updates fully resolve the issue. Please let me know if you need any further details or adjustments!

@sambhavnoobcoder (Contributor, Author)

BTW, I saw the large refactor, and the changes coming are truly awesome and much needed. However, since you have a better perspective on the upcoming changes than I do, I would appreciate it if you could look at my PR and tell me whether I need to make any more changes to make it future-proof against those changes.

@Cyrilvallez (Member)

Hey, sorry for the late reply, we are actively working on improving the loading logic these days!
Do you mind sharing a code snippet to reproduce the issue so that I can have a better look? I'm not sure I understand exactly what corner case you are hitting!

Successfully merging this pull request may close these issues.

Missing weights are not properly initialized when using model.from_pretrained()