19 changes: 18 additions & 1 deletion src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -64,8 +64,25 @@ def __init__(
self.linear = nn.Linear(embedding_dim, output_dim)
self.linear.weight.data[:embedding_dim, :] = torch.eye(embedding_dim)
self.linear.weight.data[embedding_dim:, :] = torch.eye(embedding_dim)
# Initialize bias to match the original scale_shift_table initialization pattern
self.linear.bias.data = torch.randn(output_dim) / embedding_dim**0.5
@tolgacangoz (Owner) commented on Aug 21, 2025:
@copilot Wouldn't self.linear = nn.Linear(embedding_dim, output_dim) do a similar thing already?

Copilot AI commented on Aug 21, 2025:

The bias initialization uses torch.randn(output_dim) but output_dim is not defined in this scope. Based on the context, it should be 2 * embedding_dim to match the scale and shift parameters.

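A minimal sketch of the fix Copilot is describing, assuming output_dim is meant to be 2 * embedding_dim (one half for scale, one for shift); the names mirror the __init__ in the diff above, the size is made up, and this is illustrative rather than the merged code:

import torch
import torch.nn as nn

embedding_dim = 8                # example size only
output_dim = 2 * embedding_dim   # scale half + shift half

linear = nn.Linear(embedding_dim, output_dim)
with torch.no_grad():
    # Identity mapping into both halves, as in the diff above
    linear.weight[:embedding_dim, :] = torch.eye(embedding_dim)
    linear.weight[embedding_dim:, :] = torch.eye(embedding_dim)
    # Bias drawn like the original scale_shift_table: randn / sqrt(embedding_dim)
    linear.bias.copy_(torch.randn(output_dim) / embedding_dim**0.5)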

self.norm = FP32LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)

def load_from_scale_shift_table(self, scale_shift_table: torch.Tensor) -> None:
"""
Helper method to transfer scale_shift_table values from old model checkpoints.
This can be used to migrate models saved with the old format to the new AdaLayerNorm format.

Args:
scale_shift_table: Tensor of shape (1, 2, embedding_dim) from old model format
"""
if scale_shift_table.shape[0] != 1 or scale_shift_table.shape[1] != 2:
raise ValueError(f"Expected scale_shift_table shape (1, 2, embedding_dim), got {scale_shift_table.shape}")
Comment on lines +79 to +80

Severity: medium

The shape check for scale_shift_table is good, but it could be made more specific to prevent potential silent errors if a tensor with an incorrect embedding_dim is passed. You could also check the third dimension against the linear layer's in_features and include the expected dimension in the error message. This would make the helper method more robust.

Suggested change
-if scale_shift_table.shape[0] != 1 or scale_shift_table.shape[1] != 2:
-    raise ValueError(f"Expected scale_shift_table shape (1, 2, embedding_dim), got {scale_shift_table.shape}")
+embedding_dim = self.linear.in_features
+if scale_shift_table.shape != (1, 2, embedding_dim):
+    raise ValueError(f"Expected scale_shift_table shape (1, 2, {embedding_dim}), got {scale_shift_table.shape}")


with torch.no_grad():
# Flatten the scale_shift_table to match the bias shape
Copilot AI commented on Aug 21, 2025:

The view(-1) operation assumes the scale_shift_table dimensions will always produce the correct bias size. Consider adding a size validation to ensure the flattened tensor matches the expected bias dimensions.

Suggested change
-# Flatten the scale_shift_table to match the bias shape
+# Flatten the scale_shift_table to match the bias shape
+if scale_shift_table.numel() != self.linear.bias.data.numel():
+    raise ValueError(
+        f"scale_shift_table has {scale_shift_table.numel()} elements, but expected {self.linear.bias.data.numel()} for bias. "
+        f"scale_shift_table shape: {scale_shift_table.shape}, expected bias shape: {self.linear.bias.data.shape}"
+    )


self.linear.bias.data = scale_shift_table.view(-1)

def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
if temb.ndim == 2:
# If temb is 2D, we assume it has 1-D time embedding values for each batch.
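As a side note on the two validation suggestions above, here is a self-contained toy sketch of the migration this helper performs, assuming the new layer's bias holds 2 * embedding_dim values so a (1, 2, embedding_dim) table flattens onto it exactly; the sizes are invented for illustration:

import torch
import torch.nn as nn

embedding_dim = 4
linear = nn.Linear(embedding_dim, 2 * embedding_dim)

# Stand-in for an old checkpoint's scale_shift_table of shape (1, 2, embedding_dim)
scale_shift_table = torch.randn(1, 2, embedding_dim) / embedding_dim**0.5

if scale_shift_table.shape != (1, 2, embedding_dim):
    raise ValueError(f"Expected (1, 2, {embedding_dim}), got {scale_shift_table.shape}")

with torch.no_grad():
    # (1, 2, D) flattens to (2*D,), which is exactly the layout the new bias expects
    linear.bias.copy_(scale_shift_table.view(-1))

print(linear.bias.shape)  # torch.Size([8])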
@@ -443,7 +460,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
@tolgacangoz (Owner) commented on Aug 21, 2025:
@copilot Shouldn't we keep scale_shift_table for SkyReelsV2TransformerBlock?

_keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3", "norm_out"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q", "scale_shift_table"]
_keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"]
