Propose to refactor `norm_out` in `SkyReelsV2Transformer3DModel` for device handling in multi-GPU case #11850
Conversation
Replace the final `FP32LayerNorm` and manual shift/scale application with a single `AdaLayerNorm` module in both the `WanTransformer3DModel` and `WanVACETransformer3DModel`. This change simplifies the forward pass by encapsulating the adaptive normalization logic within the `AdaLayerNorm` layer, removing the need for a separate `scale_shift_table`. The `_no_split_modules` list is also updated to include `norm_out` for compatibility with model parallelism.
…anVACE transformers
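The replacement described in this commit can be sketched as a small module that folds the `LayerNorm` and the shift/scale projection into one component. This is a minimal illustration, not the actual diffusers `AdaLayerNorm` API; the class name, `cond_dim`, and all shapes are assumptions for the sketch.

```python
import torch
import torch.nn as nn

class AdaLayerNormSketch(nn.Module):
    """LayerNorm whose shift/scale are projected from a conditioning embedding."""
    def __init__(self, cond_dim: int, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        # Projects the conditioning embedding to (shift, scale).
        self.linear = nn.Linear(cond_dim, 2 * dim)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        shift, scale = self.linear(temb).chunk(2, dim=-1)
        # Broadcast shift/scale over the sequence dimension.
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

dim = 8
ada = AdaLayerNormSketch(cond_dim=dim, dim=dim)
x = torch.randn(2, 4, dim)   # (batch, seq, dim) hidden states
temb = torch.randn(2, dim)   # conditioning embedding
out = ada(x, temb)
assert out.shape == x.shape
```

Because the shift/scale parameters now live inside a `norm_out` submodule rather than as a free-floating `scale_shift_table` on the model, device-placement hooks (together with the `_no_split_modules` entry) can treat them like any other submodule, which is the multi-GPU motivation of this PR.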
Updates the key mapping for the `head.modulation` layer to `norm_out.linear` in the model conversion script. This correction ensures that weights are loaded correctly for both standard and VACE transformer models.
… in Wan and WanVACE transformers
Replaces the manual implementation of adaptive layer normalization, which used a separate `scale_shift_table` and `nn.LayerNorm`, with the unified `AdaLayerNorm` module. This change simplifies the forward pass logic in several transformer models by encapsulating the normalization and modulation steps into a single component. It also adds `norm_out` to `_no_split_modules` for model parallelism compatibility.
Corrects the target key for `head.modulation` to `norm_out.linear.weight`. This ensures the weights are correctly mapped to the weight parameter of the output normalization layer during model conversion for both transformer types.
Adds a default zero-initialized bias tensor for the transformer's output normalization layer if it is missing from the original state dictionary.
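The two conversion-script changes above can be sketched together: rename the legacy `head.modulation` key to `norm_out.linear.weight` and add a zero-initialized bias when the checkpoint lacks one. The helper name and tensor shapes are illustrative assumptions, not the actual diffusers conversion script.

```python
import torch

def convert_norm_out_keys(state_dict):
    """Remap legacy keys for the output normalization layer (sketch)."""
    sd = dict(state_dict)  # avoid mutating the caller's dict
    if "head.modulation" in sd:
        # Corrected target key per the commit above.
        sd["norm_out.linear.weight"] = sd.pop("head.modulation")
    if "norm_out.linear.bias" not in sd:
        w = sd["norm_out.linear.weight"]
        # Default zero bias so load_state_dict finds every expected key.
        sd["norm_out.linear.bias"] = torch.zeros(w.shape[0], dtype=w.dtype)
    return sd

old = {"head.modulation": torch.randn(16, 16)}
new = convert_norm_out_keys(old)
assert "head.modulation" not in new
assert new["norm_out.linear.bias"].shape == (16,)
```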
force-pushed from dad0e68 to 65639d5
…and normalization
…_table` to `SkyReelsV2AdaLayerNorm` format in `_load_from_state_dict`
…ion for device and dtype consistency
…le in SkyReelsV2Transformer3DModel
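The `scale_shift_table` remap mentioned in the commits above can be sketched as a `_load_from_state_dict` override that rewrites legacy keys before delegating to PyTorch's default loader. This is a toy model, not the SkyReelsV2 implementation; `TinyModel`, the key names, and the shapes are assumptions, and the identity-weight/bias trick follows the PR description.

```python
import torch
import torch.nn as nn

dim = 4

class TinyModel(nn.Module):
    """Stand-in for the transformer; `norm_out` plays AdaLayerNorm's linear."""
    def __init__(self):
        super().__init__()
        self.norm_out = nn.Linear(2 * dim, 2 * dim)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        legacy_key = prefix + "scale_shift_table"
        if legacy_key in state_dict:
            # Legacy checkpoints store an additive (1, 2, dim) table. Fold it
            # into the linear layer: identity weight, table as bias, so that
            # linear(temb) == temb + table.flatten().
            table = state_dict.pop(legacy_key)
            state_dict[prefix + "norm_out.weight"] = torch.eye(table.numel())
            state_dict[prefix + "norm_out.bias"] = table.flatten()
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

table = torch.randn(1, 2, dim)
model = TinyModel()
model.load_state_dict({"scale_shift_table": table})
temb = torch.randn(3, 2 * dim)
# The new linear path reproduces the old `table + temb` computation.
assert torch.allclose(model.norm_out(temb), temb + table.flatten())
```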
@tolgacangoz Thanks for the PR, but I'm not sure we feel good about the `_load_from_state_dict` part in each model implementation. It definitely helps with standardizing existing implementations, but we are moving toward a more single-file implementation like transformers, so somewhat differing implementations are okay to have. We're also moving away from centralized layers and toward writing what's necessary directly within the transformer files themselves (such as the normalization/embedding layers).
Pinging @yiyixuxu as well to get her thoughts.
`_load_from_state_dict` was put in as a thought experiment, so it's not important now, as I implied in the first comment. AFAIU, the shift and scale parameters should be packed into an Ada layer norm; wasn't this what Yiyi requested? I think I attempted to support the single-file policy in the skyreels file, didn't I?
In the first comment, I wanted to highlight that this PR was opened to refactor `norm_out` in `SkyReelsV2Transformer3DModel` for device handling in the multi-GPU case.
Obliterated it. Is there a way to do this without re-running the parameter conversion scripts?
This PR attempts to propose enhancing device handling for the multi-GPU use case. Since other Ada layer norms have a linear layer, I thought I was supposed to do a similar thing; i.e., converting the previous `scale_shift_table`-based calculations into a `Linear` layer such that its multiplication is the identity and its bias is the `scale_shift_table`. I am unsure if this seems over-engineered; I would appreciate guidance.

This PR is now open for review for `SkyReelsV2Transformer3DModel`. After taking feedback on it, I may continue with the others.

Adding `_keys_to_ignore_on_load_missing` doesn't seem to remove this warning.

After concluding the fix, I will continue to consider backward compatibility. Open for thoughts.

Context: #11518 (comment)
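The identity-weight idea described above can be checked numerically: a `Linear` whose weight is the identity and whose bias is the flattened `scale_shift_table` reproduces the old `table + temb` modulation exactly. This is a minimal sketch under assumed shapes, not the actual SkyReelsV2 code.

```python
import torch
import torch.nn as nn

dim = 8
table = torch.randn(1, 2, dim)   # legacy scale_shift_table
x = torch.randn(2, 16, dim)      # hidden states (batch, seq, dim)
temb = torch.randn(2, 2 * dim)   # conditioning embedding
norm = nn.LayerNorm(dim, elementwise_affine=False)

# Old path: add the table to temb, then split into shift/scale.
shift, scale = (table + temb.view(2, 2, dim)).chunk(2, dim=1)
old = norm(x) * (1 + scale) + shift

# New path: Linear with identity weight and the table as its bias,
# so linear(temb) == temb + table.flatten().
linear = nn.Linear(2 * dim, 2 * dim)
with torch.no_grad():
    linear.weight.copy_(torch.eye(2 * dim))
    linear.bias.copy_(table.flatten())
shift2, scale2 = linear(temb).view(2, 2, dim).chunk(2, dim=1)
new = norm(x) * (1 + scale2) + shift2

assert torch.allclose(old, new, atol=1e-6)
```

The upside of this encoding is that the table becomes an ordinary parameter of the `norm_out` submodule, so device-map hooks move it together with the layer; the cost is an extra (identity) matmul per forward pass, which is the over-engineering concern raised above.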
@yiyixuxu @a-r-r-o-w