diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
index 99f70646db2..c440526b9b9 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
@@ -13,12 +13,6 @@
 def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
     """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
-    # up matrix and down matrix have different ranks so we can't simply multiply them
-    if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
-        x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
-        x *= lora_weight * lora_layer.scale()
-        return x
-
     x = torch.nn.functional.linear(input, lora_layer.down)
     if lora_layer.mid is not None:
         x = torch.nn.functional.linear(x, lora_layer.mid)
diff --git a/invokeai/backend/patches/layers/lora_layer.py b/invokeai/backend/patches/layers/lora_layer.py
index cf79f520519..c9210dce933 100644
--- a/invokeai/backend/patches/layers/lora_layer.py
+++ b/invokeai/backend/patches/layers/lora_layer.py
@@ -19,7 +19,6 @@ def __init__(
         self.up = up
         self.mid = mid
         self.down = down
-        self.are_ranks_equal = up.shape[1] == down.shape[0]
 
     @classmethod
     def from_state_dict_values(
@@ -59,42 +58,12 @@ def from_state_dict_values(
     def _rank(self) -> int:
         return self.down.shape[0]
 
-    def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
-        """
-        Fuse the weights of the up and down matrices of a LoRA layer with different ranks.
-
-        Since the Huggingface implementation of KQV projections are fused, when we convert to Kohya format
-        the LoRA weights have different ranks. This function handles the fusion of these differently sized
-        matrices.
-        """
-
-        fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
-        rank_diff = down.shape[0] / up.shape[1]
-
-        if rank_diff > 1:
-            rank_diff = down.shape[0] / up.shape[1]
-            w_down = down.chunk(int(rank_diff), dim=0)
-            for w_down_chunk in w_down:
-                fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
-        else:
-            rank_diff = up.shape[1] / down.shape[0]
-            w_up = up.chunk(int(rank_diff), dim=0)
-            for w_up_chunk in w_up:
-                fused_lora = fused_lora + (torch.mm(w_up_chunk, down))
-
-        return fused_lora
-
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
             weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
         else:
-            # up matrix and down matrix have different ranks so we can't simply multiply them
-            if not self.are_ranks_equal:
-                weight = self.fuse_weights(self.up, self.down)
-                return weight
-
             weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
 
         return weight
diff --git a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
index 7b5f3468963..41e41dbb517 100644
--- a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
@@ -20,14 +20,6 @@
 FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
     r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
 )
-
-# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format.
-# Example keys:
-#   lora_unet_final_layer_linear.alpha
-#   lora_unet_final_layer_linear.lora_down.weight
-#   lora_unet_final_layer_linear.lora_up.weight
-FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"
-
 # A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
 # Example keys:
 #   lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
@@ -52,7 +44,6 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
     """
     return all(
         re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
-        or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
         for k in state_dict.keys()
@@ -74,9 +65,6 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
     t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
     for layer_name, layer_state_dict in grouped_state_dict.items():
         if layer_name.startswith("lora_unet"):
-            # Skip the final layer. This is incompatible with current model definition.
-            if layer_name.startswith("lora_unet_final_layer"):
-                continue
             transformer_grouped_sd[layer_name] = layer_state_dict
         elif layer_name.startswith("lora_te1"):
             clip_grouped_sd[layer_name] = layer_state_dict