Fix SkyReelsV2AdaLayerNorm bias initialization for mathematical equivalence #7
```diff
@@ -64,8 +64,25 @@ def __init__(
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.linear.weight.data[:embedding_dim, :] = torch.eye(embedding_dim)
         self.linear.weight.data[embedding_dim:, :] = torch.eye(embedding_dim)
+        # Initialize bias to match the original scale_shift_table initialization pattern
+        self.linear.bias.data = torch.randn(output_dim) / embedding_dim**0.5
```
**Review comment (Copilot):** The bias initialization uses …
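For context, the "original scale_shift_table initialization pattern" named in the added comment presumably follows the common diffusers pattern sketched below; this is an illustration under that assumption, not code from this PR, and `dim` is a hypothetical size:

```python
import torch
import torch.nn as nn

dim = 1536  # hypothetical embedding_dim

# Assumed old-format parameter: a (1, 2, dim) table drawn from a normal
# distribution and scaled by 1/sqrt(dim).
scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

# The new module folds the same statistics into the Linear bias instead:
# a (2 * dim,) vector with the same scaling.
bias = torch.randn(2 * dim) / dim**0.5
```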
```diff
 
         self.norm = FP32LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)
 
+    def load_from_scale_shift_table(self, scale_shift_table: torch.Tensor) -> None:
+        """
+        Helper method to transfer scale_shift_table values from old model checkpoints.
+        This can be used to migrate models saved with the old format to the new AdaLayerNorm format.
+
+        Args:
+            scale_shift_table: Tensor of shape (1, 2, embedding_dim) from old model format
+        """
+        if scale_shift_table.shape[0] != 1 or scale_shift_table.shape[1] != 2:
+            raise ValueError(f"Expected scale_shift_table shape (1, 2, embedding_dim), got {scale_shift_table.shape}")
```
**Review comment (Copilot) on lines +79 to +80:** The shape check for …
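The attached suggestion is not preserved above. Purely as an illustration, and not the reviewer's actual proposal, one way to harden the check would be to validate the tensor rank before indexing `shape`, since `scale_shift_table.shape[1]` raises an `IndexError` for a 1-D tensor rather than the intended `ValueError`:

```python
import torch

def _validate_scale_shift_table(scale_shift_table: torch.Tensor) -> None:
    # Hypothetical hardening of the shape check: verify the rank first so a
    # malformed tensor fails with ValueError instead of IndexError.
    if scale_shift_table.ndim != 3 or scale_shift_table.shape[:2] != (1, 2):
        raise ValueError(
            f"Expected scale_shift_table shape (1, 2, embedding_dim), got {tuple(scale_shift_table.shape)}"
        )

_validate_scale_shift_table(torch.randn(1, 2, 1536))  # passes for a well-formed table
```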
```diff
+
+        with torch.no_grad():
+            # Flatten the scale_shift_table to match the bias shape
+            self.linear.bias.data = scale_shift_table.view(-1)
```

**Review comment (Copilot):** The `view(-1)` operation assumes the scale_shift_table dimensions will always produce the correct bias size. Consider adding a size validation to ensure the flattened tensor matches the expected bias dimensions.
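A minimal sketch of the kind of size validation the comment asks for, using hypothetical sizes and standalone tensors rather than the PR's actual follow-up (none is shown here):

```python
import torch
import torch.nn as nn

embedding_dim, output_dim = 1536, 2 * 1536  # hypothetical sizes
linear = nn.Linear(embedding_dim, output_dim)
scale_shift_table = torch.randn(1, 2, embedding_dim)

flattened = scale_shift_table.reshape(-1)
# Confirm the flattened table matches the bias before copying it in.
if flattened.shape != linear.bias.shape:
    raise ValueError(
        f"scale_shift_table flattens to {tuple(flattened.shape)}, "
        f"but the bias has shape {tuple(linear.bias.shape)}"
    )
with torch.no_grad():
    linear.bias.copy_(flattened)
```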
```diff
+
     def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
         if temb.ndim == 2:
             # If temb is 2D, we assume it has 1-D time embedding values for each batch.
```
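To see why this refactor can be mathematically equivalent to the old format: with the weight initialized to two stacked identity matrices and the bias set to the flattened table, the linear layer computes exactly `scale_shift_table + temb` broadcast over the batch, just laid out as one `(batch, 2 * embedding_dim)` tensor. A small numeric check under these assumptions (and assuming the old format adds the table to the time embedding before chunking into shift/scale, as in similar diffusers transformers):

```python
import torch
import torch.nn as nn

embedding_dim = 8
output_dim = 2 * embedding_dim

# New format: linear with identity-stacked weights and the table folded into the bias.
linear = nn.Linear(embedding_dim, output_dim)
scale_shift_table = torch.randn(1, 2, embedding_dim) / embedding_dim**0.5
with torch.no_grad():
    linear.weight[:embedding_dim, :] = torch.eye(embedding_dim)
    linear.weight[embedding_dim:, :] = torch.eye(embedding_dim)
    linear.bias.copy_(scale_shift_table.view(-1))

temb = torch.randn(4, embedding_dim)  # hypothetical batch of time embeddings
old = (scale_shift_table + temb.unsqueeze(1)).flatten(1)  # old format: (4, 2 * embedding_dim)
new = linear(temb)                                        # new format: (4, 2 * embedding_dim)
assert torch.allclose(old, new, atol=1e-6)
```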
```diff
@@ -443,7 +460,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"]
-    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3", "norm_out"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q", "scale_shift_table"]
     _keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"]
```

**Review comment:** @copilot Shouldn't we keep `scale_shift_table` for `SkyReelsV2TransformerBlock`?
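Taken together with `_keys_to_ignore_on_load_unexpected` and `_keys_to_ignore_on_load_missing`, an old checkpoint could be migrated roughly as sketched below. This is an assumption-laden illustration: the module name `norm_out` and the state-dict key `"scale_shift_table"` are inferred from the key lists above, not confirmed by the PR.

```python
import torch
import torch.nn as nn

def migrate_old_checkpoint(model: nn.Module, old_state_dict: dict) -> None:
    """Hypothetical migration helper for old-format SkyReelsV2 checkpoints."""
    # Old checkpoints carry a transformer-level "scale_shift_table" the new model no
    # longer owns (hence _keys_to_ignore_on_load_unexpected); pull it out first.
    old_table = old_state_dict.pop("scale_shift_table")  # assumed key name
    model.load_state_dict(old_state_dict, strict=False)
    # norm_out is assumed to be the new SkyReelsV2AdaLayerNorm instance, per
    # _keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"].
    model.norm_out.load_from_scale_shift_table(old_table)
```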
**Review comment:** @copilot Wouldn't `self.linear = nn.Linear(embedding_dim, output_dim)` do a similar thing already?
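For context on that question: PyTorch's default `nn.Linear` initialization draws the bias from a uniform distribution bounded by `1/sqrt(in_features)`, whereas the PR sets it from a scaled normal, so the two are similar in scale but not identical in distribution. A quick illustrative comparison with hypothetical sizes:

```python
import torch
import torch.nn as nn

embedding_dim, output_dim = 1536, 2 * 1536  # hypothetical sizes

# Default nn.Linear: bias ~ Uniform(-1/sqrt(in_features), 1/sqrt(in_features)).
default_bias = nn.Linear(embedding_dim, output_dim).bias
# PR initialization: bias ~ Normal(0, 1) / sqrt(embedding_dim).
pr_bias = torch.randn(output_dim) / embedding_dim**0.5

print(default_bias.std().item())  # about 1/sqrt(3 * embedding_dim) ~= 0.015
print(pr_bias.std().item())       # about 1/sqrt(embedding_dim)     ~= 0.026
```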