Commit 4ce64b6

Modular backend - LoRA/LyCORIS (#6667)
## Summary

Code for LoRA patching from #6577. Additionally, a LoRA layer can now patch not only `weight` but also `bias`, since some LoRAs in the wild do this.

## Related Issues / Discussions

- #6606
- https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.

## Merge Plan

Replace the old LoRA patcher with the new one once review is done. If you think tests are needed, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2 parents 94d64b8 + 5a9173f commit 4ce64b6
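As the Summary notes, a layer patch may now touch `bias` as well as `weight`. As a rough illustration of that idea only (not the repository's actual `ModelPatcher` code; `apply_layer_patch` and the tensors below are hypothetical), a patcher can add a scaled delta to every parameter a layer reports and keep the originals for later restoration:

```python
import torch


def apply_layer_patch(
    module: torch.nn.Module, params: dict[str, torch.Tensor], weight: float
) -> dict[str, torch.Tensor]:
    """Add scaled deltas to a module's parameters and return the originals for later restore."""
    originals: dict[str, torch.Tensor] = {}
    with torch.no_grad():
        for name, delta in params.items():  # e.g. {"weight": ..., "bias": ...}
            param = getattr(module, name)
            originals[name] = param.detach().clone()
            param.add_(delta.to(device=param.device, dtype=param.dtype) * weight)
    return originals


# Hypothetical usage: patch both weight and bias of a small linear layer at strength 0.5.
layer = torch.nn.Linear(4, 4)
saved = apply_layer_patch(layer, {"weight": torch.zeros(4, 4), "bias": torch.ones(4)}, weight=0.5)
```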

File tree

9 files changed: +331 -129 lines changed


invokeai/app/invocations/compel.py

+4 -4

@@ -80,12 +80,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(
                 text_encoder,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(
                 text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
-                model_state_dict=state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
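The `compel.py` change above only renames the value yielded by `model_on_device()` from `model_state_dict` to `cached_weights`. A minimal sketch of the pattern it feeds, with a hypothetical `apply_deltas` helper standing in for the real `ModelPatcher.apply_lora_text_encoder`: patch parameters in place while the model sits on the target device, then restore them from the cached copies on exit.

```python
from contextlib import contextmanager
from typing import Dict, Iterator

import torch


@contextmanager
def apply_deltas(
    model: torch.nn.Module,
    deltas: Dict[str, torch.Tensor],
    cached_weights: Dict[str, torch.Tensor],
) -> Iterator[None]:
    """Add deltas to matching parameters; on exit, restore them from cached_weights."""
    params = dict(model.named_parameters())
    try:
        with torch.no_grad():
            for key, delta in deltas.items():
                params[key].add_(delta.to(device=params[key].device, dtype=params[key].dtype))
        yield
    finally:
        with torch.no_grad():
            for key in deltas:
                # Restoring from the cached (CPU) copy avoids keeping a second on-device snapshot.
                params[key].copy_(cached_weights[key].to(device=params[key].device))


# Hypothetical usage mirroring the `with (...)` block above.
model = torch.nn.Linear(2, 2)
cache = {k: v.detach().cpu().clone() for k, v in model.named_parameters()}
with apply_deltas(model, {"weight": torch.full((2, 2), 0.1)}, cache):
    pass  # run the patched text encoder here
```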

invokeai/app/invocations/denoise_latents.py

+13 -2

@@ -62,6 +62,7 @@
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
 from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
@@ -845,6 +846,16 @@ def step_callback(state: PipelineIntermediateState) -> None:
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

+        ### lora
+        if self.unet.loras:
+            for lora_field in self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAExt(
+                        node_context=context,
+                        model_id=lora_field.lora,
+                        weight=lora_field.weight,
+                    )
+                )
         ### seamless
         if self.unet.seamless_axes:
             ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
@@ -964,14 +975,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
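In `denoise_latents.py`, the modular path registers one extension per LoRA field instead of patching inside a `with` block. A toy sketch of that registration shape, with `LoRAExtSketch` and `ExtManagerSketch` as hypothetical stand-ins for the real `LoRAExt` and extension manager (whose actual constructors and hooks live under `invokeai.backend.stable_diffusion.extensions`):

```python
from dataclasses import dataclass, field
from typing import List, Protocol


class Extension(Protocol):
    def patch(self) -> None: ...


@dataclass
class LoRAExtSketch:
    """Stand-in for LoRAExt: remembers which LoRA to load and at what weight."""

    model_id: str
    weight: float

    def patch(self) -> None:
        print(f"would apply {self.model_id} to the UNet at weight {self.weight}")


@dataclass
class ExtManagerSketch:
    extensions: List[Extension] = field(default_factory=list)

    def add_extension(self, ext: Extension) -> None:
        self.extensions.append(ext)


# Mirrors the loop added above: one extension per entry in `self.unet.loras`.
manager = ExtManagerSketch()
for model_id, weight in [("lora-a", 0.75), ("lora-b", 1.0)]:
    manager.add_extension(LoRAExtSketch(model_id=model_id, weight=weight))
```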

invokeai/backend/lora.py

+85 -50

@@ -3,12 +3,13 @@

 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch
 from safetensors.torch import load_file
 from typing_extensions import Self

+import invokeai.backend.util.logging as logger
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel

@@ -46,9 +47,19 @@ def __init__(
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()

+    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.bias
+
+    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        params = {"weight": self.get_weight(orig_module.weight)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias
+        return params
+
     def calc_size(self) -> int:
         model_size = 0
         for val in [self.bias]:
@@ -60,6 +71,17 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)

+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        """Log a warning if values contains unhandled keys."""
+        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
+        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            logger.warning(
+                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+            )
+

 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
@@ -76,14 +98,19 @@ def __init__(

         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
+        self.mid = values.get("lora_mid.weight", None)

         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -125,20 +152,23 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
         self.w1_b = values["hada_w1_b"]
         self.w2_a = values["hada_w2_a"]
         self.w2_b = values["hada_w2_b"]
-
-        if "hada_t1" in values:
-            self.t1: Optional[torch.Tensor] = values["hada_t1"]
-        else:
-            self.t1 = None
-
-        if "hada_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["hada_t2"]
-        else:
-            self.t2 = None
+        self.t1 = values.get("hada_t1", None)
+        self.t2 = values.get("hada_t2", None)

         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -186,37 +216,39 @@ def __init__(
     ):
         super().__init__(layer_key, values)

-        if "lokr_w1" in values:
-            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-            self.w1_a = None
-            self.w1_b = None
-        else:
-            self.w1 = None
+        self.w1 = values.get("lokr_w1", None)
+        if self.w1 is None:
             self.w1_a = values["lokr_w1_a"]
             self.w1_b = values["lokr_w1_b"]

-        if "lokr_w2" in values:
-            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-            self.w2_a = None
-            self.w2_b = None
-        else:
-            self.w2 = None
+        self.w2 = values.get("lokr_w2", None)
+        if self.w2 is None:
             self.w2_a = values["lokr_w2_a"]
             self.w2_b = values["lokr_w2_b"]

-        if "lokr_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-        else:
-            self.t2 = None
+        self.t2 = values.get("lokr_t2", None)

-        if "lokr_w1_b" in values:
-            self.rank = values["lokr_w1_b"].shape[0]
-        elif "lokr_w2_b" in values:
-            self.rank = values["lokr_w2_b"].shape[0]
+        if self.w1_b is not None:
+            self.rank = self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            self.rank = self.w2_b.shape[0]
         else:
             self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t2",
+            },
+        )
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
@@ -272,7 +304,9 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]


 class FullLayer(LoRALayerBase):
+    # bias handled in LoRALayerBase(calc_size, to)
     # weight: torch.Tensor
+    # bias: Optional[torch.Tensor]

     def __init__(
         self,
@@ -282,15 +316,12 @@ def __init__(
         super().__init__(layer_key, values)

         self.weight = values["diff"]
-
-        if len(values.keys()) > 1:
-            _keys = list(values.keys())
-            _keys.remove("diff")
-            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+        self.bias = values.get("diff_b", None)

         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight

     def calc_size(self) -> int:
@@ -319,8 +350,9 @@ def __init__(
         self.on_input = values["on_input"]

         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
@@ -458,24 +490,27 @@ def from_checkpoint(
             state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
             # lora and locon
-            if "lora_down.weight" in values:
+            if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)

             # loha
-            elif "hada_w1_b" in values:
+            elif "hada_w1_a" in values:
                 layer = LoHALayer(layer_key, values)

             # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
+            elif "lokr_w1" in values or "lokr_w1_a" in values:
                 layer = LoKRLayer(layer_key, values)

             # diff
             elif "diff" in values:
                 layer = FullLayer(layer_key, values)

             # ia3
-            elif "weight" in values and "on_input" in values:
+            elif "on_input" in values:
                 layer = IA3Layer(layer_key, values)

             else:
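For reference, the classic LoRA/LoCon delta that `LoRALayer.get_weight` reconstructs is just the low-rank product of the `lora_up.weight` and `lora_down.weight` tensors, typically scaled by `alpha / rank` (the `alpha` key is consumed by `LoRALayerBase`). A small worked example with made-up shapes, not taken from the repository:

```python
import torch

# Hypothetical rank-4 LoRA for a 16x16 linear layer.
rank, out_features, in_features = 4, 16, 16
up = torch.randn(out_features, rank)    # "lora_up.weight"
down = torch.randn(rank, in_features)   # "lora_down.weight"  (rank == down.shape[0])
alpha = 4.0                             # "alpha", handled by the base class

# Low-rank weight delta, scaled the usual LoRA way.
delta = (up @ down) * (alpha / rank)
assert delta.shape == (out_features, in_features)

# In get_parameters() terms this becomes {"weight": delta}, plus a "bias" entry
# when the checkpoint also carries a bias-style tensor (e.g. "diff_b" in FullLayer).
params = {"weight": delta}
```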
