
Propose to refactor norm_out in SkyReelsV2Transformer3DModel for device handling in multi-GPU case #11850

Open · tolgacangoz wants to merge 29 commits into main from transfer-shift_scale_norm-to-AdaLayerNorm

Conversation

tolgacangoz
Contributor

@tolgacangoz tolgacangoz commented Jul 2, 2025

This PR proposes enhancing device handling for the multi-GPU use case. Since the other AdaLayerNorm variants contain a linear layer, I took a similar approach: converting the previous scale_shift_table-based calculation into a Linear layer whose weight is the identity and whose bias is scale_shift_table, as sketched below. I am unsure whether this is over-engineered; I would appreciate guidance.
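A minimal sketch of the idea, assuming a hidden size `dim` and the usual Wan-style modulation `shift, scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)`; the exact diffusers internals may differ:

```python
import torch
import torch.nn as nn

dim = 1536  # assumed hidden size, for illustration only

# Original pattern: a bare parameter applied manually in forward().
scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

# Proposed pattern: a Linear whose weight is the identity (stacked once per
# chunk) and whose bias is the flattened table, so that
# linear(temb) == cat([temb + table[:, 0], temb + table[:, 1]], dim=-1).
linear = nn.Linear(dim, 2 * dim)
with torch.no_grad():
    linear.weight.copy_(torch.eye(dim).repeat(2, 1))
    linear.bias.copy_(scale_shift_table.flatten())

temb = torch.randn(2, dim)
shift, scale = linear(temb).chunk(2, dim=-1)

# Equivalent to the original computation:
ref_shift, ref_scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
assert torch.allclose(shift, ref_shift.squeeze(1), atol=1e-6)
assert torch.allclose(scale, ref_scale.squeeze(1), atol=1e-6)
```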

This PR is now open for review for SkyReelsV2Transformer3DModel. After receiving feedback on it, I may continue with the others.

Adding _keys_to_ignore_on_load_missing doesn't seem to remove this warning:

Some weights of SkyReelsV2Transformer3DModel were not initialized from the model checkpoint at /teamspace/studios/this_studio/.cache/huggingface/hub/models--Skywork--SkyReels-V2-DF-1.3B-540P-Diffusers/snapshots/958acd63685c7e632e4b194549f2a703e34bd98b/transformer and are newly initialized: ['norm_out.linear.bias', 'norm_out.linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
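For reference, the attempt looked roughly like this (a sketch; whether `ModelMixin` actually honors this attribute here is exactly what's in question):

```python
class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin):
    # Attempted (unsuccessfully, per the warning above) to silence the
    # missing-key warning for the newly introduced Linear parameters.
    _keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"]
```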
After concluding the fix, I will continue to consider backward compatibility. Open to thoughts:
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 14
     12 flow_shift = 8.0  # 8.0 for T2V, 5.0 for I2V
     13 pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
---> 14 pipeline = pipeline.to("cuda")
     16 prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
     18 output = pipeline(
     19     prompt=prompt,
     20     num_inference_steps=30,
   (...)
     28     addnoise_condition=20,  # Improves consistency in long video generation
     29 ).frames[0]

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py:541, in DiffusionPipeline.to(self, *args, **kwargs)
    539     module.to(device=device)
    540 elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
--> 541     module.to(device, dtype)
    543 if (
    544     module.dtype == torch.float16
    545     and str(device) in ["cpu"]
    546     and not silence_dtype_warnings
    547     and not is_offloaded
    548 ):
    549     logger.warning(
    550         "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
    551         " is not recommended to move them to `cpu` as running them will fail. Please make"
   (...)
    554         " `torch_dtype=torch.float16` argument, or use another device for inference."
    555     )

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/diffusers/models/modeling_utils.py:1424, in ModelMixin.to(self, *args, **kwargs)
   1419     logger.warning(
   1420         f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
   1421     )
   1422     return self
-> 1424 return super().to(*args, **kwargs)

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1355, in Module.to(self, *args, **kwargs)
   1352         else:
   1353             raise
-> 1355 return self._apply(convert)

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:915, in Module._apply(self, fn, recurse)
    913 if recurse:
    914     for module in self.children():
--> 915         module._apply(fn)
    917 def compute_should_use_set_data(tensor, tensor_applied):
    918     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    919         # If the new tensor has compatible tensor type as the existing tensor,
    920         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    925         # global flag to let the user control whether they want the future
    926         # behavior of overwriting the existing tensor or not.

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:915, in Module._apply(self, fn, recurse)
    913 if recurse:
    914     for module in self.children():
--> 915         module._apply(fn)
    917 def compute_should_use_set_data(tensor, tensor_applied):
    918     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    919         # If the new tensor has compatible tensor type as the existing tensor,
    920         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    925         # global flag to let the user control whether they want the future
    926         # behavior of overwriting the existing tensor or not.

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:942, in Module._apply(self, fn, recurse)
    938 # Tensors stored in modules are graph leaves, and we don't want to
    939 # track autograd history of `param_applied`, so we have to use
    940 # `with torch.no_grad():`
    941 with torch.no_grad():
--> 942     param_applied = fn(param)
    943 p_should_use_set_data = compute_should_use_set_data(param, param_applied)
    945 # subclasses may have multiple child tensors so we need to use swap_tensors

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1348, in Module.to.<locals>.convert(t)
   1346 except NotImplementedError as e:
   1347     if str(e) == "Cannot copy out of meta tensor; no data!":
-> 1348         raise NotImplementedError(
   1349             f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
   1350             f"when moving module from meta to a different device."
   1351         ) from None
   1352     else:
   1353         raise

NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

Context: #11518 (comment)
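For readers landing here from the traceback, a minimal repro of the failure mode, assuming the root cause is a bare parameter left on the meta device during multi-GPU loading (as discussed in the linked issue):

```python
import torch
import torch.nn as nn

# A parameter materialized on the meta device has no data, so a plain
# .to() cannot copy it to a real device and raises NotImplementedError.
module = nn.Module()
with torch.device("meta"):
    module.scale_shift_table = nn.Parameter(torch.empty(1, 2, 1536))

try:
    module.to("cpu")  # same failure as pipeline.to("cuda") above
except NotImplementedError as e:
    print(e)  # Cannot copy out of meta tensor; no data! ...
```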

@yiyixuxu @a-r-r-o-w

Replace the final `FP32LayerNorm` and manual shift/scale application with a single `AdaLayerNorm` module in both the `WanTransformer3DModel` and `WanVACETransformer3DModel`.

This change simplifies the forward pass by encapsulating the adaptive normalization logic within the `AdaLayerNorm` layer, removing the need for a separate `scale_shift_table`. The `_no_split_modules` list is also updated to include `norm_out` for compatibility with model parallelism.
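Illustratively, the tail of the forward pass changes roughly like this (a sketch with variable names assumed from the Wan-style models; see the diff for the exact code):

```python
# Before: manual modulation with a bare scale_shift_table parameter.
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)

# After: normalization and modulation are encapsulated in AdaLayerNorm.
hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)
```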

Updates the key mapping for the `head.modulation` layer to `norm_out.linear` in the model conversion script.

This correction ensures that weights are loaded correctly for both standard and VACE transformer models.

Replaces the manual implementation of adaptive layer normalization, which used a separate `scale_shift_table` and `nn.LayerNorm`, with the unified `AdaLayerNorm` module.

This change simplifies the forward pass logic in several transformer models by encapsulating the normalization and modulation steps into a single component. It also adds `norm_out` to `_no_split_modules` for model parallelism compatibility.

Corrects the target key for `head.modulation` to `norm_out.linear.weight`.

This ensures the weights are correctly mapped to the weight parameter of the output normalization layer during model conversion for both transformer types.

Adds a default zero-initialized bias tensor for the transformer's output normalization layer if it is missing from the original state dictionary.
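Taken together, the conversion-script commits above amount to something like this sketch (key names are from the commit messages; any needed reshape and the zero-bias shape are assumptions):

```python
import torch

def remap_norm_out_keys(state_dict):
    # Map the original modulation parameter onto the new AdaLayerNorm linear;
    # depending on the checkpoint layout a reshape may be needed here.
    if "head.modulation" in state_dict:
        state_dict["norm_out.linear.weight"] = state_dict.pop("head.modulation")
    # Synthesize a zero-initialized bias if the checkpoint lacks one.
    if "norm_out.linear.bias" not in state_dict:
        out_features = state_dict["norm_out.linear.weight"].shape[0]
        state_dict["norm_out.linear.bias"] = torch.zeros(out_features)
    return state_dict
```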
@tolgacangoz tolgacangoz changed the title Refactor output normalization in several transformers Propose to refactor output normalization in several transformers Jul 3, 2025
@tolgacangoz tolgacangoz force-pushed the transfer-shift_scale_norm-to-AdaLayerNorm branch from dad0e68 to 65639d5 Compare July 18, 2025 07:09
@tolgacangoz tolgacangoz marked this pull request as ready for review August 17, 2025 17:25
Member

@a-r-r-o-w a-r-r-o-w left a comment


@tolgacangoz Thanks for the PR, but I'm not sure we feel good about the _load_from_state_dict part in each model implementation. It definitely helps with standardizing existing implementations, but we are moving toward a more single-file implementation like transformers, so somewhat differing implementations are okay to have. We're also moving away from centralized layers and toward directly writing what's necessary within the transformer files themselves (such as the normalization/embedding layers).

@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu August 17, 2025 22:31
@a-r-r-o-w
Member

Pinging @yiyixuxu as well to get her thoughts

@tolgacangoz
Contributor Author

tolgacangoz commented Aug 17, 2025

_load_from_state_dict was put in as a thought experiment, so it's not important now, as I implied in the first comment. AFAIU, the shift and scale parameters should be packed into an AdaLayerNorm; wasn't this what Yiyi requested? I believe I also kept to the single-file policy in the SkyReels file, didn't I?

@tolgacangoz
Contributor Author

tolgacangoz commented Aug 18, 2025

In the first comment, I wanted to highlight that this PR was opened for SkyReelsV2Transformer3DModel. I have now removed the other, unrelated models. Sorry for the likely confusion.

_load_from_state_dict: this was an ugly and premature attempt at backward compatibility. Since some of the parameter names have changed, will the parameter conversion scripts need to be run again? Or am I misunderstanding something?
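For the record, the kind of hook I tried looks roughly like this (a sketch only, assuming the legacy key is `scale_shift_table`):

```python
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    # Rewrite the legacy key on the fly so old checkpoints keep loading.
    old_key = prefix + "scale_shift_table"
    if old_key in state_dict:
        state_dict[prefix + "norm_out.linear.bias"] = state_dict.pop(old_key).flatten()
    super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
```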

@tolgacangoz tolgacangoz requested a review from a-r-r-o-w August 18, 2025 08:11
@tolgacangoz tolgacangoz changed the title Propose to refactor output normalization in several transformers Propose to refactor norm_out in SkyReelsV2Transformer3DModel for device handling in multi-GPU case Aug 18, 2025
@tolgacangoz
Contributor Author

Obliterated it. Is there a way to handle this without re-running the parameter conversion scripts?
