@@ -3,12 +3,13 @@

import bisect
from pathlib import Path
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from safetensors.torch import load_file
from typing_extensions import Self

+ import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel

@@ -46,9 +47,19 @@ def __init__(
        self.rank = None  # set in layer implementation
        self.layer_key = layer_key

-     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

+     def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+         return self.bias
+
+     def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+         params = {"weight": self.get_weight(orig_module.weight)}
+         bias = self.get_bias(orig_module.bias)
+         if bias is not None:
+             params["bias"] = bias
+         return params
+
    def calc_size(self) -> int:
        model_size = 0
        for val in [self.bias]:
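A minimal sketch of how a patcher might consume the new `get_parameters()` hook. The function name and the `scale` argument are illustrative assumptions, not InvokeAI's actual patching API; the point is that "weight" and, when present, "bias" are handled by one uniform loop.

import torch

def apply_layer_patch(module: torch.nn.Module, layer: "LoRALayerBase", scale: float) -> None:
    # get_parameters() returns {"weight": ...} and adds "bias" only when the
    # layer actually carries one, so both tensors are patched the same way.
    for name, delta in layer.get_parameters(module).items():
        param = getattr(module, name)
        param.data += delta.to(device=param.device, dtype=param.dtype) * scale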
@@ -60,6 +71,17 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=dtype)

+     def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+         """Log a warning if values contains unhandled keys."""
+         # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are
+         # handled by `LoRALayerBase`. Sub-classes should provide the known_keys that they handle.
+         all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+         unknown_keys = set(values.keys()) - all_known_keys
+         if unknown_keys:
+             logger.warning(
+                 f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+             )
+

# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
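To make the guard concrete, a toy example (the "dora_scale" key is a hypothetical unhandled key, chosen only for illustration):

import torch

values = {
    "lora_up.weight": torch.randn(8, 4),
    "lora_down.weight": torch.randn(4, 8),
    "dora_scale": torch.randn(8, 1),  # hypothetical key this code does not handle
}
# LoRALayer.__init__ passes known_keys={"lora_up.weight", "lora_down.weight",
# "lora_mid.weight"}; check_keys() subtracts those plus the base-class keys
# ("alpha", "bias_indices", "bias_values", "bias_size") and warns about {"dora_scale"}.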
@@ -76,14 +98,19 @@ def __init__(

        self.up = values["lora_up.weight"]
        self.down = values["lora_down.weight"]
-         if "lora_mid.weight" in values:
-             self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-         else:
-             self.mid = None
+         self.mid = values.get("lora_mid.weight", None)

        self.rank = self.down.shape[0]
+         self.check_keys(
+             values,
+             {
+                 "lora_up.weight",
+                 "lora_down.weight",
+                 "lora_mid.weight",
+             },
+         )

-     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.mid is not None:
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
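For orientation (a sketch, not part of the diff): when `mid` is absent, the LoRA delta is just the low-rank product of the two factors, which is why `rank` is read off `down.shape[0]`.

import torch

rank, n_in, n_out = 4, 768, 320
down = torch.randn(rank, n_in)  # values["lora_down.weight"]
up = torch.randn(n_out, rank)   # values["lora_up.weight"]
delta_w = up @ down             # (n_out, n_in); matrix rank is at most `rank`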
@@ -125,20 +152,23 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
        self.w1_b = values["hada_w1_b"]
        self.w2_a = values["hada_w2_a"]
        self.w2_b = values["hada_w2_b"]
-
-         if "hada_t1" in values:
-             self.t1: Optional[torch.Tensor] = values["hada_t1"]
-         else:
-             self.t1 = None
-
-         if "hada_t2" in values:
-             self.t2: Optional[torch.Tensor] = values["hada_t2"]
-         else:
-             self.t2 = None
+         self.t1 = values.get("hada_t1", None)
+         self.t2 = values.get("hada_t2", None)

        self.rank = self.w1_b.shape[0]
+         self.check_keys(
+             values,
+             {
+                 "hada_w1_a",
+                 "hada_w1_b",
+                 "hada_w2_a",
+                 "hada_w2_b",
+                 "hada_t1",
+                 "hada_t2",
+             },
+         )

-     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.t1 is None:
            weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

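As a reference sketch of what the `t1 is None` branch computes: LoHA combines two low-rank factorizations with an element-wise (Hadamard) product, so the delta can reach rank up to rank(w1)·rank(w2) while storing only low-rank factors.

import torch

rank, n_in, n_out = 4, 320, 320
w1_a, w1_b = torch.randn(n_out, rank), torch.randn(rank, n_in)
w2_a, w2_b = torch.randn(n_out, rank), torch.randn(rank, n_in)
# Element-wise product of two rank-4 matrices; the result can have rank up to 16.
delta_w = (w1_a @ w1_b) * (w2_a @ w2_b)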
@@ -186,37 +216,39 @@ def __init__(
    ):
        super().__init__(layer_key, values)

-         if "lokr_w1" in values:
-             self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-             self.w1_a = None
-             self.w1_b = None
-         else:
-             self.w1 = None
+         self.w1 = values.get("lokr_w1", None)
+         if self.w1 is None:
            self.w1_a = values["lokr_w1_a"]
            self.w1_b = values["lokr_w1_b"]

-         if "lokr_w2" in values:
-             self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-             self.w2_a = None
-             self.w2_b = None
-         else:
-             self.w2 = None
+         self.w2 = values.get("lokr_w2", None)
+         if self.w2 is None:
            self.w2_a = values["lokr_w2_a"]
            self.w2_b = values["lokr_w2_b"]

-         if "lokr_t2" in values:
-             self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-         else:
-             self.t2 = None
+         self.t2 = values.get("lokr_t2", None)

-         if "lokr_w1_b" in values:
-             self.rank = values["lokr_w1_b"].shape[0]
-         elif "lokr_w2_b" in values:
-             self.rank = values["lokr_w2_b"].shape[0]
+         if self.w1_b is not None:
+             self.rank = self.w1_b.shape[0]
+         elif self.w2_b is not None:
+             self.rank = self.w2_b.shape[0]
        else:
            self.rank = None  # unscaled

-     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+         self.check_keys(
+             values,
+             {
+                 "lokr_w1",
+                 "lokr_w1_a",
+                 "lokr_w1_b",
+                 "lokr_w2",
+                 "lokr_w2_a",
+                 "lokr_w2_b",
+                 "lokr_t2",
+             },
+         )
+
+     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        w1: Optional[torch.Tensor] = self.w1
        if w1 is None:
            assert self.w1_a is not None
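The rest of `get_weight` falls outside this hunk; as a sketch of the LoKR idea (hedged, based on the LyCORIS algorithm rather than lines shown here): each factor may be stored whole or as a low-rank product, and the delta is their Kronecker product.

import torch

w1 = torch.randn(4, 4)                      # values["lokr_w1"], or w1_a @ w1_b
w2 = torch.randn(8, 2) @ torch.randn(2, 8)  # values["lokr_w2"], or w2_a @ w2_b
delta_w = torch.kron(w1, w2)                # (32, 32): the Kronecker product scales up the shape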
@@ -272,7 +304,9 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]


class FullLayer(LoRALayerBase):
+     # bias handled in LoRALayerBase (calc_size, to)
    # weight: torch.Tensor
+     # bias: Optional[torch.Tensor]

    def __init__(
        self,
@@ -282,15 +316,12 @@ def __init__(
        super().__init__(layer_key, values)

        self.weight = values["diff"]
-
-         if len(values.keys()) > 1:
-             _keys = list(values.keys())
-             _keys.remove("diff")
-             raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+         self.bias = values.get("diff_b", None)

        self.rank = None  # unscaled
+         self.check_keys(values, {"diff", "diff_b"})

-     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

    def calc_size(self) -> int:
@@ -319,8 +350,9 @@ def __init__(
        self.on_input = values["on_input"]

        self.rank = None  # unscaled
+         self.check_keys(values, {"weight", "on_input"})

-     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if not self.on_input:
            weight = weight.reshape(-1, 1)
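A note on the reshape above (illustration with assumed shapes, not lines from the diff): IA3 stores one scale per channel, and `on_input` decides which axis of the original weight that vector broadcasts over.

import torch

orig_weight = torch.randn(16, 32)  # (out_features, in_features)
scale = torch.randn(16)            # self.weight: one factor per output channel
# on_input is False here, so the reshape makes the factor a column that
# broadcasts across rows, i.e. scales each output channel:
scaled = orig_weight * scale.reshape(-1, 1)  # still (16, 32)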
@@ -458,24 +490,27 @@ def from_checkpoint(
            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

        for layer_key, values in state_dict.items():
+             # Detect layers according to LyCORIS detection logic (`weight_list_det`)
+             # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
            # lora and locon
-             if "lora_down.weight" in values:
+             if "lora_up.weight" in values:
                layer: AnyLoRALayer = LoRALayer(layer_key, values)

            # loha
-             elif "hada_w1_b" in values:
+             elif "hada_w1_a" in values:
                layer = LoHALayer(layer_key, values)

            # lokr
-             elif "lokr_w1_b" in values or "lokr_w1" in values:
+             elif "lokr_w1" in values or "lokr_w1_a" in values:
                layer = LoKRLayer(layer_key, values)

            # diff
            elif "diff" in values:
                layer = FullLayer(layer_key, values)

            # ia3
-             elif "weight" in values and "on_input" in values:
+             elif "on_input" in values:
                layer = IA3Layer(layer_key, values)

            else:
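To make the detection ordering explicit, the same elif chain as a standalone sketch (`detect_layer_type` is a hypothetical helper, not part of this PR; the error message in the final branch is likewise assumed):

import torch
from typing import Dict

def detect_layer_type(values: Dict[str, torch.Tensor]) -> str:
    # Mirrors the chain above, in the same order: the first matching key wins.
    if "lora_up.weight" in values:
        return "lora/locon"
    elif "hada_w1_a" in values:
        return "loha"
    elif "lokr_w1" in values or "lokr_w1_a" in values:
        return "lokr"
    elif "diff" in values:
        return "full"
    elif "on_input" in values:
        return "ia3"
    else:
        raise ValueError(f"Unknown LyCORIS layer type, keys: {sorted(values)}")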