8 changes: 8 additions & 0 deletions __init__.py
@@ -49,6 +49,14 @@
except ImportError:
    logger.exception("Nodes `NunchakuFluxLoraLoader` and `NunchakuFluxLoraStack` import failed:")

try:
    from .nodes.lora.qwenimage import NunchakuQwenImageLoraLoader, NunchakuQwenImageLoraStack

    NODE_CLASS_MAPPINGS["NunchakuQwenImageLoraLoader"] = NunchakuQwenImageLoraLoader
    NODE_CLASS_MAPPINGS["NunchakuQwenImageLoraStack"] = NunchakuQwenImageLoraStack
except ImportError:
    logger.exception("Nodes `NunchakuQwenImageLoraLoader` and `NunchakuQwenImageLoraStack` import failed:")


try:
    from .nodes.models.text_encoder import NunchakuTextEncoderLoader, NunchakuTextEncoderLoaderV2
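As a quick sanity check on the registration above, a snippet like the following could be run at the end of this `__init__.py`. It is illustrative only and not part of the PR; it reuses the `NODE_CLASS_MAPPINGS` dict and `logger` that already exist in this module.

# Illustrative check, not part of the PR: warn if the guarded import above
# left the Qwen-Image LoRA nodes out of the registry.
missing = [
    node_name
    for node_name in ("NunchakuQwenImageLoraLoader", "NunchakuQwenImageLoraStack")
    if node_name not in NODE_CLASS_MAPPINGS
]
if missing:
    logger.warning("Qwen-Image LoRA nodes not registered: %s", ", ".join(missing))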
44 changes: 44 additions & 0 deletions model_base/qwenimage.py
@@ -71,4 +71,48 @@ def load_model_weights(self, sd: dict[str, torch.Tensor], unet_prefix: str = "")
            if isinstance(m, SVDQW4A4Linear):
                if m.wtscale is not None:
                    m.wtscale = sd.pop(f"{n}.wtscale", 1.0)

        # CRITICAL FIX: Fill _quantized_part_sd for LoRA support (following the Flux approach).
        # Store proj_down, proj_up, and qweight entries for LoRA merging.
        new_quantized_part_sd = {}
        for k, v in sd.items():
            if v.ndim == 1:
                # Store all 1D tensors (biases, scales)
                new_quantized_part_sd[k] = v
            elif "qweight" in k:
                # Store qweight as a meta tensor (shape/dtype info only, no data)
                new_quantized_part_sd[k] = v.to("meta")
            elif "proj_down" in k or "proj_up" in k:
                # Store the REAL low-rank branches for LoRA merging and restoration.
                # Unlike Flux (which uses empty tensors), Qwen Image needs the real
                # weights for proper cleanup when removing LoRAs.
                new_quantized_part_sd[k] = v
            elif "lora" in k:
                # Store all LoRA-related keys (same as the Flux implementation)
                # so that reset_lora() can properly restore the original weights.
                new_quantized_part_sd[k] = v

        diffusion_model._quantized_part_sd = new_quantized_part_sd

        # CRITICAL FIX: Initialize a clean LoRA state to prevent pollution.
        # Clear any existing LoRA bookkeeping and reset all LoRA strengths to 0.
        diffusion_model.comfy_lora_meta_list = []
        diffusion_model.comfy_lora_sd_list = []
        diffusion_model._lora_state_cache = {}

        # Reset all LoRA strengths to 0 for a clean start
        self._reset_all_lora_strength_clean(diffusion_model)

        diffusion_model.load_state_dict(sd, strict=True)

    def _reset_all_lora_strength_clean(self, diffusion_model):
        """
        Reset LoRA strength to 0 for all SVDQW4A4Linear layers in the diffusion model.

        This ensures a clean start without any residual LoRA effects.
        """
        from nunchaku.models.linear import SVDQW4A4Linear

        for _name, module in diffusion_model.named_modules():
            if isinstance(module, SVDQW4A4Linear):
                # Reset to 0 so no residual LoRA effect is applied
                module.lora_strength = 0.0
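
To make the purpose of the `_quantized_part_sd` snapshot concrete, here is a minimal sketch of how the saved `proj_down`/`proj_up` branches could be written back once all LoRAs are removed. This is illustrative only, not the repository's `reset_lora()` implementation, and the helper name `restore_original_low_rank` is hypothetical.

# Illustrative sketch, not part of the PR and not the repository's reset_lora():
# write the snapshotted low-rank branches back and zero every LoRA strength.
from nunchaku.models.linear import SVDQW4A4Linear


def restore_original_low_rank(diffusion_model) -> None:
    """Restore the proj_down / proj_up weights captured in _quantized_part_sd."""
    saved = getattr(diffusion_model, "_quantized_part_sd", {})
    low_rank = {k: v for k, v in saved.items() if "proj_down" in k or "proj_up" in k}
    # The keys were taken from the original state dict, so they load back directly;
    # strict=False because only the low-rank branches are restored here.
    diffusion_model.load_state_dict(low_rank, strict=False)
    for module in diffusion_model.modules():
        if isinstance(module, SVDQW4A4Linear):
            module.lora_strength = 0.0  # same clean-start convention as the diff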