feat: support adaLN layer while converting LoRAs from diffusers to flux #7708

Open
wants to merge 19 commits into base: main
31 changes: 30 additions & 1 deletion invokeai/backend/patches/layers/utils.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Tuple

import torch

@@ -33,3 +33,32 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
return NormLayer.from_state_dict_values(state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")


def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
"""Swap shift/scale for given linear layer back and forth"""
# In SD3 and Flux implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. This will flip them around
chunk1, chunk2 = weight.chunk(2, dim=0)
return torch.cat([chunk2, chunk1], dim=0)
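For context, a minimal sketch (not part of the diff; shapes are arbitrary) of why swapping the two row-blocks of the weight is enough: it is equivalent to swapping the (scale, shift) halves of the projected output, converting the diffusers ordering into the Flux/SD3 ordering.

import torch
from invokeai.backend.patches.layers.utils import swap_shift_scale_for_linear_weight

w = torch.randn(6, 4)  # combined scale/shift projection weight
x = torch.randn(4)     # conditioning embedding

scale, shift = (w @ x).chunk(2)                                        # diffusers order: (scale, shift)
shift2, scale2 = (swap_shift_scale_for_linear_weight(w) @ x).chunk(2)  # Flux/SD3 order: (shift, scale)

assert torch.allclose(shift, shift2) and torch.allclose(scale, scale2)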


def decomposite_weight_matric_with_rank(
delta: torch.Tensor,
rank: int,
epsilon: float = 1e-8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Decompose given matrix with a specified rank."""
U, S, V = torch.svd(delta)

# Truncate to rank r:
U_r = U[:, :rank]
S_r = S[:rank]
V_r = V[:, :rank]

S_sqrt = torch.sqrt(S_r + epsilon) # regularization

up = torch.matmul(U_r, torch.diag(S_sqrt))
down = torch.matmul(torch.diag(S_sqrt), V_r.T)

return up, down
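A quick sanity sketch (not part of the diff): when the delta's true rank is at most the requested rank, the product up @ down reconstructs it, and the sqrt(S) split spreads each singular value evenly across the two factors.

import torch
from invokeai.backend.patches.layers.utils import decomposite_weight_matric_with_rank

delta = torch.randn(64, 8).double() @ torch.randn(8, 128).double()  # exact rank 8
up, down = decomposite_weight_matric_with_rank(delta, rank=8)

assert up.shape == (64, 8)
assert down.shape == (8, 128)
assert torch.allclose(up @ down, delta, atol=1e-6)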
invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
@@ -3,8 +3,13 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.layers.utils import (
+    any_lora_layer_from_state_dict,
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

@@ -30,6 +35,50 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
return all_keys_in_peft_format and all_expected_keys_present


def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
"""Approximate given diffusers AdaLN loRA layer in our Flux model"""

if "lora_up.weight" not in state_dict:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")

if "lora_down.weight" not in state_dict:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")

up = state_dict.pop("lora_up.weight")
down = state_dict.pop("lora_down.weight")

# The layer patcher upcasts everything to float32 anyway;
# keep this layer in float32 to preserve the precision of the approximation.
dtype = torch.float32

device = up.device
up_shape = up.shape
down_shape = down.shape

# desired low rank
rank = up_shape[1]

# upcast for a more precise computation
up = up.to(torch.float32)
down = down.to(torch.float32)

weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)

# swap to our linear format
swapped = swap_shift_scale_for_linear_weight(weight)

_up, _down = decomposite_weight_matric_with_rank(swapped, rank)

assert _up.shape == up_shape
assert _down.shape == down_shape

# cast back to the target dtype and original device
state_dict["lora_up.weight"] = _up.to(dtype).to(device=device)
state_dict["lora_down.weight"] = _down.to(dtype).to(device=device)

return LoRALayer.from_state_dict_values(state_dict)
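Roughly how this helper behaves end to end, as a hedged sketch (not part of the diff): the recomposed factors should equal the shift/scale-swapped delta up to SVD truncation error. The hidden/rank sizes are arbitrary, and the .up/.down attribute names on the returned LoRALayer are an assumption, so that check is left commented out.

import torch

hidden, rank = 16, 4
orig_up = torch.randn(2 * hidden, rank)   # diffusers lora_up for norm_out.linear
orig_down = torch.randn(rank, hidden)     # diffusers lora_down

sd = {"lora_up.weight": orig_up.clone(), "lora_down.weight": orig_down.clone()}
layer = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(sd)

# expected = swap_shift_scale_for_linear_weight(orig_up @ orig_down)
# assert torch.allclose(layer.up @ layer.down, expected, atol=1e-4)  # assumes LoRALayer exposes .up/.down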


def lora_model_from_flux_diffusers_state_dict(
state_dict: Dict[str, torch.Tensor], alpha: float | None
) -> ModelPatchRaw:
@@ -82,6 +131,12 @@ def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
values = get_lora_layer_values(src_layer_dict)
layers[dst_key] = any_lora_layer_from_state_dict(values)

def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
values = get_lora_layer_values(src_layer_dict)
layers[dst_key] = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(values)

def add_qkv_lora_layer_if_present(
src_keys: list[str],
src_weight_shapes: list[tuple[int, int]],
@@ -124,8 +179,8 @@ def add_qkv_lora_layer_if_present(
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")

# time_text_embed.guidance_embedder -> guidance_in.
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in.in_layer")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in.out_layer")

# context_embedder -> txt_in.
add_lora_layer_if_present("context_embedder", "txt_in")
@@ -223,6 +278,10 @@ def add_qkv_lora_layer_if_present(

# Final layer.
add_lora_layer_if_present("proj_out", "final_layer.linear")
add_adaLN_lora_layer_if_present(
"norm_out.linear",
"final_layer.adaLN_modulation.1",
)

# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
48 changes: 48 additions & 0 deletions tests/backend/patches/layers/test_layer_utils.py
@@ -0,0 +1,48 @@
import torch

from invokeai.backend.patches.layers.utils import (
decomposite_weight_matric_with_rank,
swap_shift_scale_for_linear_weight,
)


def test_swap_shift_scale_for_linear_weight():
"""Test that swaping should work"""
original = torch.Tensor([1, 2])
expected = torch.Tensor([2, 1])

swapped = swap_shift_scale_for_linear_weight(original)
assert torch.allclose(expected, swapped)

size = (3, 4)
first = torch.randn(size)
second = torch.randn(size)

original = torch.concat([first, second])
expected = torch.concat([second, first])

swapped = swap_shift_scale_for_linear_weight(original)
assert torch.allclose(expected, swapped)

# calling this twice reconstructs the original
reconstructed = swap_shift_scale_for_linear_weight(swapped)
assert torch.allclose(reconstructed, original)


def test_decomposite_weight_matric_with_rank():
"""Test that decompsition of given matrix into 2 low rank matrices work"""
input_dim = 1024
output_dim = 1024
rank = 8 # Low rank

A = torch.randn(input_dim, rank).double()
B = torch.randn(rank, output_dim).double()
W0 = A @ B

C, D = decomposite_weight_matric_with_rank(W0, rank)
R = C @ D

assert C.shape == A.shape
assert D.shape == B.shape

assert torch.allclose(W0, R)
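A possible extra case (a sketch, not part of the PR): when the requested rank exceeds the true rank, the epsilon regularization keeps the extra singular values finite and the reconstruction still holds.

def test_decomposite_weight_matric_with_rank_overspecified():
    """Sketch: requesting a higher rank than the true rank still reconstructs the matrix."""
    A = torch.randn(256, 4).double()
    B = torch.randn(4, 256).double()
    W0 = A @ B  # true rank is 4

    C, D = decomposite_weight_matric_with_rank(W0, rank=8)

    assert C.shape == (256, 8)
    assert D.shape == (8, 256)
    assert torch.allclose(C @ D, W0)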