diff --git a/__init__.py b/__init__.py index 11674484..b93a43bd 100644 --- a/__init__.py +++ b/__init__.py @@ -49,6 +49,14 @@ except ImportError: logger.exception("Nodes `NunchakuFluxLoraLoader` and `NunchakuFluxLoraStack` import failed:") +try: + from .nodes.lora.qwenimage import NunchakuQwenImageLoraLoader, NunchakuQwenImageLoraStack + + NODE_CLASS_MAPPINGS["NunchakuQwenImageLoraLoader"] = NunchakuQwenImageLoraLoader + NODE_CLASS_MAPPINGS["NunchakuQwenImageLoraStack"] = NunchakuQwenImageLoraStack +except ImportError: + logger.exception("Nodes `NunchakuQwenImageLoraLoader` and `NunchakuQwenImageLoraStack` import failed:") + try: from .nodes.models.text_encoder import NunchakuTextEncoderLoader, NunchakuTextEncoderLoaderV2 diff --git a/model_base/qwenimage.py b/model_base/qwenimage.py index 2c9fe372..e431a3be 100644 --- a/model_base/qwenimage.py +++ b/model_base/qwenimage.py @@ -71,4 +71,48 @@ def load_model_weights(self, sd: dict[str, torch.Tensor], unet_prefix: str = "") if isinstance(m, SVDQW4A4Linear): if m.wtscale is not None: m.wtscale = sd.pop(f"{n}.wtscale", 1.0) + + # CRITICAL FIX: Fill _quantized_part_sd for LoRA support (following Flux approach) + # Store proj_down, proj_up, and qweight for LoRA merging + new_quantized_part_sd = {} + for k, v in sd.items(): + if v.ndim == 1: + # Store all 1D tensors (biases, scales) + new_quantized_part_sd[k] = v + elif "qweight" in k: + # Store qweight shape info + new_quantized_part_sd[k] = v.to("meta") + elif "proj_down" in k or "proj_up" in k: + # Store REAL low-rank branches for LoRA merging and restoration + # Unlike Flux (which uses empty tensors), Qwen Image needs real weights + # for proper cleanup when removing LoRAs + new_quantized_part_sd[k] = v + elif "lora" in k: + # Store all lora-related keys (same as Flux implementation) + # This ensures reset_lora() can properly restore original weights + new_quantized_part_sd[k] = v + + diffusion_model._quantized_part_sd = new_quantized_part_sd + + # CRITICAL FIX: Initialize clean LoRA state to prevent pollution + # Clear any existing LoRA state and reset all LoRA strengths to 0 + diffusion_model.comfy_lora_meta_list = [] + diffusion_model.comfy_lora_sd_list = [] + diffusion_model._lora_state_cache = {} + + # Reset all LoRA strengths to 0 for a clean start + self._reset_all_lora_strength_clean(diffusion_model) + diffusion_model.load_state_dict(sd, strict=True) + + def _reset_all_lora_strength_clean(self, diffusion_model): + """ + Reset LoRA strength to 0 for all SVDQW4A4Linear layers in the diffusion model. + This ensures a clean start without any residual LoRA effects. + """ + from nunchaku.models.linear import SVDQW4A4Linear + + for name, module in diffusion_model.named_modules(): + if isinstance(module, SVDQW4A4Linear): + # Reset to 0 to ensure no residual LoRA effects + module.lora_strength = 0.0 diff --git a/models/qwenimage.py b/models/qwenimage.py index 600339c0..a7e96f93 100644 --- a/models/qwenimage.py +++ b/models/qwenimage.py @@ -29,6 +29,109 @@ from ..mixins.model import NunchakuModelMixin +class LoRAConfigContainer(nn.Module): + """ + Lightweight container for LoRA configuration. + + This class acts as a transparent proxy to the transformer, + storing only the LoRA configuration separately for each model copy. + All method calls and attribute access are forwarded to the transformer. 
+ + This design avoids the problems encountered with full wrapper implementations: + - No need to customize forward() for parameter name conversion + - No need to handle 5D/4D dimension mismatches + - No need to implement to_safely() and other ComfyUI methods + - No type checking failures + + Inherits from nn.Module to satisfy PyTorch's module hierarchy requirements. + + Attributes + ---------- + _transformer : NunchakuQwenImageTransformer2DModel + The shared transformer instance (contains LoRA cache). + _lora_config_list : list + Independent LoRA configuration for this container. + + Examples + -------- + >>> transformer = NunchakuQwenImageTransformer2DModel(...) + >>> container = LoRAConfigContainer(transformer) + >>> container._lora_config_list.append(("path/to/lora.safetensors", 1.0)) + >>> # All other attributes/methods transparently forwarded to transformer + >>> output = container(x, timestep, context, ...) # Calls transformer's forward + """ + + def __init__(self, transformer): + """ + Initialize the container with a transformer instance. + + Parameters + ---------- + transformer : NunchakuQwenImageTransformer2DModel + The transformer to wrap. + """ + super().__init__() + # Use object.__setattr__ to bypass nn.Module's __setattr__ for private attributes + object.__setattr__(self, "_transformer", transformer) + object.__setattr__(self, "_lora_config_list", []) + + def __getattr__(self, name): + """ + Forward all attribute access to the transformer. + + This makes the container transparent for all operations + except accessing _transformer and _lora_config_list. + + Note: This is called AFTER checking self.__dict__ and self.__class__.__dict__, + so it won't interfere with nn.Module's internal attributes. + """ + # Avoid recursion for _transformer + if name == "_transformer": + return object.__getattribute__(self, "_transformer") + return getattr(object.__getattribute__(self, "_transformer"), name) + + def __setattr__(self, name, value): + """ + Store private attributes in container, everything else in transformer. + + Private attributes (starting with '_') are stored in the container itself. + All other attributes are forwarded to the transformer. + """ + if name.startswith("_"): + object.__setattr__(self, name, value) + else: + setattr(object.__getattribute__(self, "_transformer"), name, value) + + def forward(self, *args, **kwargs): + """ + Forward pass - handles LoRA composition then delegates to transformer. + + This is the entry point when ComfyUI calls the model. + We inject the LoRA config into the transformer before calling it. + """ + # Temporarily inject LoRA config into transformer for this forward pass + transformer = object.__getattribute__(self, "_transformer") + lora_config_list = object.__getattribute__(self, "_lora_config_list") + + # Save original config (if any) + original_config = getattr(transformer, "_lora_config_list", None) + + # Inject our config + transformer._lora_config_list = lora_config_list + + try: + # Call transformer's forward + result = transformer(*args, **kwargs) + finally: + # Restore original config + if original_config is not None: + transformer._lora_config_list = original_config + elif hasattr(transformer, "_lora_config_list"): + delattr(transformer, "_lora_config_list") + + return result + + class NunchakuGELU(GELU): """ GELU activation with a quantized linear projection. 
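Note: the `LoRAConfigContainer` added above relies on a transparent-proxy idiom: per-copy private state stays on the container, everything else is forwarded to a shared transformer, and the LoRA config is injected only for the duration of a forward call. A minimal, self-contained sketch of that idiom is shown below; the names (`_ConfigProxy`, `_inner`, `_config`) are illustrative placeholders, not the PR's class, and the sketch only demonstrates the attribute-forwarding and temporary-injection mechanics.

```python
import torch
import torch.nn as nn


class _ConfigProxy(nn.Module):
    """Illustrative stand-in: forwards everything to a shared module,
    keeping only a per-copy _config list (hypothetical names)."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        # Bypass nn.Module's __setattr__ so the shared module is not re-registered
        object.__setattr__(self, "_inner", inner)
        object.__setattr__(self, "_config", [])

    def __getattr__(self, name):
        # Called only when normal lookup fails; forward to the shared module
        if name == "_inner":
            return object.__getattribute__(self, "_inner")
        return getattr(object.__getattribute__(self, "_inner"), name)

    def __setattr__(self, name, value):
        # Private attributes stay on the proxy; everything else goes to the shared module
        if name.startswith("_"):
            object.__setattr__(self, name, value)
        else:
            setattr(object.__getattribute__(self, "_inner"), name, value)

    def forward(self, *args, **kwargs):
        inner = object.__getattribute__(self, "_inner")
        saved = getattr(inner, "_config", None)
        # Temporarily inject this proxy's config for the duration of the call
        inner._config = object.__getattribute__(self, "_config")
        try:
            return inner(*args, **kwargs)
        finally:
            if saved is not None:
                inner._config = saved
            elif hasattr(inner, "_config"):
                delattr(inner, "_config")


shared = nn.Linear(4, 4)
a, b = _ConfigProxy(shared), _ConfigProxy(shared)
a._config.append(("lora_a.safetensors", 1.0))  # only proxy `a` carries this config
out = a(torch.randn(2, 4))                     # runs the shared module's forward
assert b._config == []                         # `b` keeps its own independent state
```

The key point, mirrored by the container in the diff, is that `object.__setattr__`/`object.__getattribute__` bypass `nn.Module`'s attribute machinery, so each proxy keeps its own lightweight config list while all heavy parameters live once on the shared transformer.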
@@ -149,6 +252,72 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = module(hidden_states) return hidden_states + def update_lora_params(self, lora_dict: dict[str, torch.Tensor]): + """ + Update LoRA parameters for the feed-forward network. + """ + + # Helper function to apply LoRA to a SVDQW4A4Linear layer + def apply_lora_to_linear(linear_layer, lora_dict, layer_prefix): + lora_down_key = None + lora_up_key = None + + # Find lora_down and lora_up for this layer + for k in lora_dict.keys(): + if layer_prefix in k: + if "lora_down" in k: + lora_down_key = k + elif "lora_up" in k: + lora_up_key = k + + if lora_down_key is None or lora_up_key is None: + return False + + lora_down_packed = lora_dict[lora_down_key] + lora_up_packed = lora_dict[lora_up_key] + + # The LoRA weights are already merged with original low-rank branches in the converter + # Just directly apply them + device = linear_layer.proj_down.device + dtype = linear_layer.proj_down.dtype + + # Directly replace parameters with merged weights + linear_layer.proj_down.data = lora_down_packed.to(device=device, dtype=dtype) + linear_layer.proj_up.data = lora_up_packed.to(device=device, dtype=dtype) + linear_layer.rank = lora_down_packed.shape[1] + + return True + + # Apply LoRA to each SVDQW4A4Linear layer in the network + for i, module in enumerate(self.net): + if isinstance(module, SVDQW4A4Linear): + apply_lora_to_linear(module, lora_dict, f"net.{i}") + elif ( + isinstance(module, NunchakuGELU) and hasattr(module, "proj") and isinstance(module.proj, SVDQW4A4Linear) + ): + # For GELU with proj attribute + apply_lora_to_linear(module.proj, lora_dict, f"net.{i}.proj") + + def restore_original_params(self): + """ + Restore original parameters for all quantized linear layers in the feed-forward network. + """ + + def restore_linear_layer(linear_layer, layer_prefix): + if hasattr(linear_layer, "_original_proj_down"): + linear_layer.proj_down = linear_layer._original_proj_down + linear_layer.proj_up = linear_layer._original_proj_up + linear_layer.rank = linear_layer._original_rank + + # Restore parameters for each SVDQW4A4Linear layer + for i, module in enumerate(self.net): + if isinstance(module, SVDQW4A4Linear): + restore_linear_layer(module, f"net.{i}") + elif ( + isinstance(module, NunchakuGELU) and hasattr(module, "proj") and isinstance(module.proj, SVDQW4A4Linear) + ): + restore_linear_layer(module.proj, f"net.{i}.proj") + class Attention(nn.Module): """ @@ -320,6 +489,88 @@ def forward( return img_attn_output, txt_attn_output + def update_lora_params(self, lora_dict: dict[str, torch.Tensor]): + """ + Update LoRA parameters for the attention module. + + This applies LoRA by concatenating LoRA projections with existing low-rank projections + in SVDQW4A4Linear layers. 
+ """ + + # Helper function to apply LoRA to a SVDQW4A4Linear layer + def apply_lora_to_linear(linear_layer, lora_dict, layer_prefix): + lora_down_key = None + lora_up_key = None + + # Find lora_down/lora_up (Nunchaku format) or lora_A/lora_B (Diffusers format) + for k in lora_dict.keys(): + if layer_prefix in k: + if "lora_down" in k or "lora_A" in k: + lora_down_key = k + elif "lora_up" in k or "lora_B" in k: + lora_up_key = k + + if lora_down_key is None or lora_up_key is None: + return False # No LoRA for this layer + + lora_down_packed = lora_dict[lora_down_key] + lora_up_packed = lora_dict[lora_up_key] + + # The LoRA weights are already packed and merged in the converter + # Directly replace proj_down and proj_up (following official implementation) + # The LoRA weights are already merged with original low-rank branches in the converter + # Just directly apply them + device = linear_layer.proj_down.device + dtype = linear_layer.proj_down.dtype + + # Directly replace parameters with merged weights + linear_layer.proj_down.data = lora_down_packed.to(device=device, dtype=dtype) + linear_layer.proj_up.data = lora_up_packed.to(device=device, dtype=dtype) + linear_layer.rank = lora_down_packed.shape[1] + + return True + + # Apply LoRA to each quantized linear layer + applied = False + if isinstance(self.to_qkv, SVDQW4A4Linear): + applied |= apply_lora_to_linear(self.to_qkv, lora_dict, "to_qkv") + + if isinstance(self.add_qkv_proj, SVDQW4A4Linear): + applied |= apply_lora_to_linear(self.add_qkv_proj, lora_dict, "add_qkv_proj") + + if isinstance(self.to_out[0], SVDQW4A4Linear): + applied |= apply_lora_to_linear(self.to_out[0], lora_dict, "to_out.0") + + if isinstance(self.to_add_out, SVDQW4A4Linear): + applied |= apply_lora_to_linear(self.to_add_out, lora_dict, "to_add_out") + + # Summary log disabled - will show overall count instead + return applied + + def restore_original_params(self): + """ + Restore original parameters for all quantized linear layers in the attention module. + """ + + def restore_linear_layer(linear_layer, layer_prefix): + if hasattr(linear_layer, "_original_proj_down"): + linear_layer.proj_down = linear_layer._original_proj_down + linear_layer.proj_up = linear_layer._original_proj_up + linear_layer.rank = linear_layer._original_rank + + # Restore parameters for each quantized linear layer + if isinstance(self.to_qkv, SVDQW4A4Linear): + restore_linear_layer(self.to_qkv, "to_qkv") + + if isinstance(self.add_qkv_proj, SVDQW4A4Linear): + restore_linear_layer(self.add_qkv_proj, "add_qkv_proj") + + if isinstance(self.to_out[0], SVDQW4A4Linear): + restore_linear_layer(self.to_out[0], "to_out.0") + + if isinstance(self.to_add_out, SVDQW4A4Linear): + restore_linear_layer(self.to_add_out, "to_add_out") + class NunchakuQwenImageTransformerBlock(nn.Module): """ @@ -502,6 +753,51 @@ def forward( return encoder_hidden_states, hidden_states + def update_lora_params(self, lora_dict: dict): + """ + Update LoRA parameters for the transformer block. + + Directly applies LoRA to attention and MLP layers by calling their update methods. + + Parameters + ---------- + lora_dict : dict + Dictionary containing LoRA weights for this block in Nunchaku format (lora_down/lora_up). 
+ """ + # Apply LoRA to attention + if hasattr(self.attn, "update_lora_params"): + attn_lora = {k: v for k, v in lora_dict.items() if "attn" in k} + if attn_lora: + self.attn.update_lora_params(attn_lora) + + # Apply LoRA to image stream MLP + if hasattr(self.img_mlp, "update_lora_params"): + img_mlp_lora = {k: v for k, v in lora_dict.items() if "img_mlp" in k} + if img_mlp_lora: + self.img_mlp.update_lora_params(img_mlp_lora) + + # Apply LoRA to text stream MLP + if hasattr(self.txt_mlp, "update_lora_params"): + txt_mlp_lora = {k: v for k, v in lora_dict.items() if "txt_mlp" in k} + if txt_mlp_lora: + self.txt_mlp.update_lora_params(txt_mlp_lora) + + def restore_original_params(self): + """ + Restore original parameters for all components in this transformer block. + """ + # Restore attention parameters + if hasattr(self.attn, "restore_original_params"): + self.attn.restore_original_params() + + # Restore image MLP parameters + if hasattr(self.img_mlp, "restore_original_params"): + self.img_mlp.restore_original_params() + + # Restore text MLP parameters + if hasattr(self.txt_mlp, "restore_original_params"): + self.txt_mlp.restore_original_params() + class NunchakuQwenImageTransformer2DModel(NunchakuModelMixin, QwenImageTransformer2DModel): """ @@ -563,6 +859,20 @@ def __init__( **kwargs, ): super(QwenImageTransformer2DModel, self).__init__() + + # LoRA support attributes (similar to nunchaku library implementation) + self._unquantized_part_sd: dict[str, torch.Tensor] = {} + self._unquantized_part_loras: dict[str, torch.Tensor] = {} + self._quantized_part_sd: dict[str, torch.Tensor] = {} + self._quantized_part_vectors: dict[str, torch.Tensor] = {} + + # ComfyUI LoRA related attributes + # Note: comfy_lora_meta_list and comfy_lora_sd_list are now initialized dynamically in _forward + # to support Flux-style caching. _lora_config_list is set by LoRA Loader nodes. + + # VAE scale factor for img_shapes calculation (same as diffusers pipeline) + self.vae_scale_factor = 8 # Default for Qwen Image + self.dtype = dtype self.patch_size = patch_size self.out_channels = out_channels or in_channels @@ -614,6 +924,170 @@ def __init__( ) self.gradient_checkpointing = False + def process_img(self, x, index=0, h_offset=0, w_offset=0): + """ + Preprocess an input image tensor for the model. + + Overrides the base class method to handle 4D tensors (batch, channels, height, width) + instead of 5D tensors required by ComfyUI's base implementation. + + Supports both Qwen Image (T2I) and Qwen Image Edit (I2I) models. + + Parameters + ---------- + x : torch.Tensor + Input image tensor of shape (batch, channels, height, width) or + (batch, channels, 1, height, width) for Image Edit models. + index : int, optional + Index for image ID encoding. + h_offset : int, optional + Height offset for patch IDs. + w_offset : int, optional + Width offset for patch IDs. + + Returns + ------- + img : torch.Tensor + Rearranged image tensor of shape (batch, num_patches, patch_dim). + img_ids : torch.Tensor + Image ID tensor of shape (batch, num_patches, 3). + orig_shape : tuple + Original shape (batch, channels, height, width) for unpatchify. 
+ """ + from comfy.ldm.common_dit import pad_to_patch_size + from einops import rearrange, repeat + + # Handle 5D input for Image Edit models (batch, channels, 1, height, width) + # This happens when processing ref_latents in Qwen Image Edit + if x.ndim == 5: + x = x.squeeze(2) # Remove middle dimension -> (batch, channels, height, width) + + bs, c, h_orig, w_orig = x.shape + x = pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + # CRITICAL: The key insight is that rearrange() creates patches for the ENTIRE padded tensor + # So img_ids must match the actual patch grid created by rearrange() + _, _, h_padded, w_padded = x.shape + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) + + # img.shape[1] is the actual number of patches created by rearrange() + actual_patches = img.shape[1] + + # Calculate patch grid dimensions using original dimensions (consistent with diffusers) + # This matches the original QwenImageTransformer2DModel implementation + h_len = (h_orig + (self.patch_size // 2)) // self.patch_size + w_len = (w_orig + (self.patch_size // 2)) // self.patch_size + + # Verify that our calculation matches the actual patches + assert ( + h_len * w_len == actual_patches + ), f"Patch count mismatch: calculated={h_len * w_len}, actual={actual_patches}" + + h_offset = (h_offset + (self.patch_size // 2)) // self.patch_size + w_offset = (w_offset + (self.patch_size // 2)) // self.patch_size + + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + + # EXPERIMENTAL: Center-aligned position IDs (like Diffusers pipeline) + # Instead of 0 to h_len-1, use -(h_len//2) to +(h_len//2) + # This should make objects appear centered in non-square aspect ratios + h_center = h_len // 2 + w_center = w_len // 2 + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace( + -h_center + h_offset, h_len - 1 - h_center + h_offset, steps=h_len, device=x.device, dtype=x.dtype + ).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace( + -w_center + w_offset, w_len - 1 - w_center + w_offset, steps=w_len, device=x.device, dtype=x.dtype + ).unsqueeze(0) + + # Return orig_shape as tuple: (bs, c, h_padded, w_padded, h_orig, w_orig) + # h_padded/w_padded for unpatchify reshape, h_orig/w_orig for final cropping + return img, repeat(img_ids, "h w c -> b (h w) c", b=bs), (bs, c, h_padded, w_padded, h_orig, w_orig) + + def _rope_position_embedding(self, ids: torch.Tensor) -> torch.Tensor: + """ + RoPE-based position embedding using Nunchaku's implementation. + This should match the official Diffusers pipeline behavior. + + Uses axes_dims_rope to allocate dimensions for each axis (index, h, w). 
+ For example, axes_dims_rope=(16, 56, 56) means: + - index: 16 dims + - h_pos: 56 dims + - w_pos: 56 dims + Total: 128 dims = attention_head_dim + """ + from nunchaku.models.embeddings import rope + + # Extract position indices for each axis + # ids shape: (batch, seq_len, 3) where 3 = [index, h_pos, w_pos] + batch_size, seq_len, n_axes = ids.shape + + # Apply RoPE for each axis with the correct dimension from axes_dims_rope + rope_embs = [] + for i in range(n_axes): + pos = ids[:, :, i] # Extract position for axis i + axis_dim = self.axes_dims_rope[i] # Get dimension for this axis + rope_emb = rope(pos, axis_dim, self.rope_theta) + rope_embs.append(rope_emb) + + # Concatenate along the dimension axis + image_rotary_emb = torch.cat(rope_embs, dim=-3) + # Apply the same transform as ComfyUI's pe_embedder: .squeeze(1).unsqueeze(2) + # Nunchaku rope outputs (batch, seq_len, dim, 1, 2) + # We need to add batch dimension first, then apply squeeze/unsqueeze + image_rotary_emb = image_rotary_emb.unsqueeze(0) # (1, batch, seq_len, dim, 1, 2) + image_rotary_emb = image_rotary_emb.squeeze(1) # (1, seq_len, dim, 1, 2) + image_rotary_emb = image_rotary_emb.unsqueeze(2) # (1, seq_len, 1, dim, 1, 2) + return image_rotary_emb + + def process_img_packed(self, x, index=0, h_offset=0, w_offset=0): + """ + Process image input to get packed latents (same as diffusers pipeline). + This method is compatible with the official diffusers pipeline approach. + """ + # Store orig_shape for later use in _forward + img, _, orig_shape = self.process_img(x, index, h_offset, w_offset) + self.last_orig_shape = orig_shape + return img + + def forward( + self, + hidden_states=None, + encoder_hidden_states=None, + encoder_hidden_states_mask=None, + timestep=None, + x=None, + context=None, + attention_mask=None, + **kwargs, + ): + """ + Forward pass adapter for ComfyUI compatibility. + + This method handles parameter name conversion between ComfyUI's convention + (hidden_states, encoder_hidden_states) and the internal implementation + (x, context). + + Parameters can be provided in either naming convention: + - ComfyUI style: hidden_states, encoder_hidden_states, encoder_hidden_states_mask, timestep + - Internal style: x, context, attention_mask, timesteps + + This method delegates to _forward() with the correct parameter names. 
+ """ + # Convert parameter names from ComfyUI to internal format + if x is None and hidden_states is not None: + x = hidden_states + if context is None and encoder_hidden_states is not None: + context = encoder_hidden_states + if attention_mask is None and encoder_hidden_states_mask is not None: + attention_mask = encoder_hidden_states_mask + if "timesteps" not in kwargs and timestep is not None: + kwargs["timesteps"] = timestep + + # Call internal _forward with correct parameter names + return self._forward(x=x, context=context, attention_mask=attention_mask, **kwargs) + def _forward( self, x, @@ -624,6 +1098,7 @@ def _forward( ref_latents=None, transformer_options={}, control=None, + controlnet_block_samples=None, **kwargs, ): """ @@ -658,14 +1133,89 @@ def _forward( if self.offload: self.offload_manager.set_device(device) - timestep = timesteps - encoder_hidden_states = context - encoder_hidden_states_mask = attention_mask + # CRITICAL: Handle both control dict and controlnet_block_samples list + # Wrapper now passes complete control dict (with weight/scale) + # But for backward compatibility, also support old controlnet_block_samples list + if control is None and controlnet_block_samples is not None: + # Old format: list of tensors → convert to dict + control = {"input": controlnet_block_samples} + + # LoRA composition logic with caching (Flux-style) + # Note: self is always the transformer (not the container) + # The container injects _lora_config_list into the transformer before calling forward + + if hasattr(self, "_lora_config_list"): + # If config is empty, clear all LoRA parameters + if len(self._lora_config_list) == 0: + if hasattr(self, "comfy_lora_meta_list") and len(self.comfy_lora_meta_list) > 0: + self.reset_lora() + self.comfy_lora_meta_list = [] + self.comfy_lora_sd_list = [] + # If config is not empty, execute sync logic + elif len(self._lora_config_list) > 0: + from nunchaku.lora.qwenimage import compose_lora + from nunchaku.utils import load_state_dict_in_safetensors + + # Initialize cache lists if not present (on transformer, shared) + if not hasattr(self, "comfy_lora_meta_list"): + self.comfy_lora_meta_list = [] + if not hasattr(self, "comfy_lora_sd_list"): + self.comfy_lora_sd_list = [] + + # Smart sync: compare config with applied state + if self._lora_config_list != self.comfy_lora_meta_list: + # Remove excess cache entries if config list shortened + for _ in range(max(0, len(self.comfy_lora_meta_list) - len(self._lora_config_list))): + self.comfy_lora_meta_list.pop() + self.comfy_lora_sd_list.pop() + + # Sync each LoRA + lora_to_be_composed = [] + for i in range(len(self._lora_config_list)): + meta = self._lora_config_list[i] # (path, strength) + + # New LoRA: load and cache + if i >= len(self.comfy_lora_meta_list): + sd = load_state_dict_in_safetensors(meta[0]) + self.comfy_lora_meta_list.append(meta) + self.comfy_lora_sd_list.append(sd) + # LoRA config changed + elif self.comfy_lora_meta_list[i] != meta: + # Path changed: reload file + if meta[0] != self.comfy_lora_meta_list[i][0]: + sd = load_state_dict_in_safetensors(meta[0]) + self.comfy_lora_sd_list[i] = sd + # Only strength changed: reuse cache + self.comfy_lora_meta_list[i] = meta + + # Add to composition list (always recompose with current strength) + lora_to_be_composed.append(({k: v for k, v in self.comfy_lora_sd_list[i].items()}, meta[1])) + + # Compose all LoRAs + composed_lora = compose_lora(lora_to_be_composed) + + # Apply to model + if len(composed_lora) == 0: + self.reset_lora() + else: + 
self.update_lora_params(composed_lora) + # Activate LoRA + from nunchaku.models.linear import SVDQW4A4Linear + + for block in self.transformer_blocks: + for module in block.modules(): + if isinstance(module, SVDQW4A4Linear): + module.lora_strength = 1.0 + + # CRITICAL: Use the EXACT original ComfyUI approach + # Process image input to get both hidden_states and img_ids hidden_states, img_ids, orig_shape = self.process_img(x) + self.last_orig_shape = orig_shape # Set for later use in unpatchify num_embeds = hidden_states.shape[1] if ref_latents is not None: + # Handle reference latents (for Kontext, etc.) - use original method h = 0 w = 0 index = 0 @@ -690,79 +1240,101 @@ def _forward( hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + # Extract dimensions from orig_shape for unpatchify + # orig_shape = (bs, c, h_padded, w_padded, h_orig, w_orig) + bs, c, h_padded, w_padded, h_orig, w_orig = self.last_orig_shape + + # Prepare ControlNet parameters + if control is not None and controlnet_block_samples is not None: + # Merge control dict with controlnet_block_samples list + if isinstance(control, dict): + # Convert list format to dict format for internal processing + control_dict = {} + for i, block_sample in enumerate(controlnet_block_samples): + control_dict[f"block_{i}"] = block_sample + control.update(control_dict) + controlnet_block_samples = control + elif control is not None: + controlnet_block_samples = control + elif controlnet_block_samples is not None: + pass # Use as-is + else: + controlnet_block_samples = None + + # Implement the official Nunchaku forward logic directly + # This matches the nunchaku/nunchaku/models/transformers/transformer_qwenimage.py implementation + device = hidden_states.device + if self.offload: + self.offload_manager.set_device(device) + + hidden_states = self.img_in(hidden_states) + + timesteps = timesteps.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(context) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timesteps, hidden_states) + if guidance is None + else self.time_text_embed(timesteps, guidance, hidden_states) + ) + + # Calculate txt_start using the original ComfyUI method txt_start = round( max( ((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2, ) ) + + # Generate txt_ids exactly like original ComfyUI txt_ids = ( torch.arange(txt_start, txt_start + context.shape[1], device=x.device) .reshape(1, -1, 1) .repeat(x.shape[0], 1, 3) ) + + # Combine txt_ids and img_ids exactly like original ComfyUI ids = torch.cat((txt_ids, img_ids), dim=1) image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) del ids, txt_ids, img_ids - hidden_states = self.img_in(hidden_states) - encoder_hidden_states = self.txt_norm(encoder_hidden_states) - encoder_hidden_states = self.txt_in(encoder_hidden_states) - - if guidance is not None: - guidance = guidance * 1000 - - temb = ( - self.time_text_embed(timestep, hidden_states) - if guidance is None - else self.time_text_embed(timestep, guidance, hidden_states) - ) - - patches_replace = transformer_options.get("patches_replace", {}) - blocks_replace = patches_replace.get("dit", {}) - - # Setup compute stream for offloading compute_stream = torch.cuda.current_stream() if self.offload: 
self.offload_manager.initialize(compute_stream) - - for i, block in enumerate(self.transformer_blocks): + for block_idx, block in enumerate(self.transformer_blocks): with torch.cuda.stream(compute_stream): if self.offload: - block = self.offload_manager.get_block(i) - if ("double_block", i) in blocks_replace: - - def block_wrap(args): - out = {} - out["txt"], out["img"] = block( - hidden_states=args["img"], - encoder_hidden_states=args["txt"], - encoder_hidden_states_mask=encoder_hidden_states_mask, - temb=args["vec"], - image_rotary_emb=args["pe"], - ) - return out - - out = blocks_replace[("double_block", i)]( - {"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, - {"original_block": block_wrap}, + block = self.offload_manager.get_block(block_idx) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + attention_mask, + temb, + image_rotary_emb, ) - hidden_states = out["img"] - encoder_hidden_states = out["txt"] else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=attention_mask, temb=temb, image_rotary_emb=image_rotary_emb, ) - # ControlNet helpers(device/dtype-safe residual adds) + + # ControlNet helpers (device/dtype-safe residual adds) _control = ( control if control is not None else (transformer_options.get("control", None) if isinstance(transformer_options, dict) else None) ) + if isinstance(_control, dict): control_i = _control.get("input") try: @@ -772,17 +1344,24 @@ def block_wrap(args): else: control_i = None _scale = 1.0 - if control_i is not None and i < len(control_i): - add = control_i[i] + + if control_i is not None and block_idx < len(control_i): + add = control_i[block_idx] if add is not None: if ( getattr(add, "device", None) != hidden_states.device or getattr(add, "dtype", None) != hidden_states.dtype ): add = add.to(device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True) - t = min(hidden_states.shape[1], add.shape[1]) - if t > 0: - hidden_states[:, :t].add_(add[:, :t], alpha=_scale) + # Check if shapes match exactly (following official nunchaku implementation) + if hidden_states.shape == add.shape: + # Shapes match - simple addition (like official implementation) + hidden_states = hidden_states + add * _scale + else: + # Shapes don't match - use safe slicing + t = min(hidden_states.shape[1], add.shape[1]) + if t > 0: + hidden_states[:, :t] = hidden_states[:, :t] + add[:, :t] * _scale if self.offload: self.offload_manager.step(compute_stream) @@ -790,11 +1369,123 @@ def block_wrap(args): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) + # Unpatchify: convert from (batch, num_patches, patch_dim) to (batch, channels, height, width) + bs, c, h_padded, w_padded, h_orig, w_orig = self.last_orig_shape + h_len = (h_orig + (self.patch_size // 2)) // self.patch_size + w_len = (w_orig + (self.patch_size // 2)) // self.patch_size + num_embeds = h_len * w_len + + # Reshape to image: (batch, num_patches, patch_dim) -> (batch, channels, height, width) hidden_states = hidden_states[:, :num_embeds].view( - orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2 + bs, h_len, w_len, self.out_channels, self.patch_size, self.patch_size ) hidden_states = 
hidden_states.permute(0, 3, 1, 4, 2, 5) - return hidden_states.reshape(orig_shape)[:, :, :, : x.shape[-2], : x.shape[-1]] + # Use padded dimensions for reshape, then crop to original + output = hidden_states.reshape(bs, self.out_channels, h_padded, w_padded)[:, :, :h_orig, :w_orig] + + torch.cuda.empty_cache() + + return (output,) + + def update_lora_params(self, lora_dict: dict, num_loras: int = 1): + """ + Update LoRA parameters for the Qwen Image model. + + This method applies LoRA weights to the model. + For ComfyUI-nunchaku, we use a simplified approach that directly applies + LoRA weights without the complex quantization handling. + + Parameters + ---------- + lora_dict : dict + Dictionary containing LoRA weights in Diffusers or Nunchaku format. + num_loras : int, optional + Number of LoRAs that were composed. If > 1, this is a composed LoRA. + Used to determine whether to merge with base model. + """ + import logging + + logger = logging.getLogger(__name__) + + # Import necessary functions + from nunchaku.lora.qwenimage import is_nunchaku_format, to_nunchaku + + # Convert to nunchaku format if needed + if not is_nunchaku_format(lora_dict): + logger.debug("Converting LoRA to Nunchaku format") + + # Check if this is a composed LoRA + is_composed = num_loras > 1 + + # Always use skip_base_merge=False (Qwen Image requires base model low-rank branches) + if is_composed: + logger.debug(f"Detected composed LoRA ({num_loras} LoRAs)") + else: + logger.debug("Single LoRA detected") + + lora_dict = to_nunchaku(lora_dict, base_sd=self._quantized_part_sd, skip_base_merge=False) + logger.debug(f"Converted LoRA to Nunchaku format: {len(lora_dict)} keys") + else: + logger.debug("LoRA already in Nunchaku format") + + # Apply LoRA to transformer blocks + blocks_updated = 0 + for i, block in enumerate(self.transformer_blocks): + # Extract LoRA weights for this block + block_lora = {} + for k, v in lora_dict.items(): + if f"transformer_blocks.{i}." in k or f"blocks.{i}." in k: + # Remove all prefixes to get relative key + parts = k.split(f".{i}.") + if len(parts) > 1: + relative_key = parts[-1] + block_lora[relative_key] = v + + # Apply LoRA to this block if it has any weights + if block_lora: + # Disabled detailed logging - only show final summary + # if i == 0: # Only log first block to reduce noise + # logger.info(f" Block {i}: {len(block_lora)} LoRA keys") + if hasattr(block, "update_lora_params"): + block.update_lora_params(block_lora) + blocks_updated += 1 + + logger.info(f"LoRA applied to {blocks_updated}/{len(self.transformer_blocks)} blocks") + + def restore_original_params(self): + """ + Restore original parameters for all transformer blocks. + This method should be called when LoRA is no longer needed. + """ + import logging + + logger = logging.getLogger(__name__) + + logger.info("🔄 Restoring original model parameters...") + blocks_restored = 0 + for block in self.transformer_blocks: + if hasattr(block, "restore_original_params"): + block.restore_original_params() + blocks_restored += 1 + + logger.info(f"Restored original parameters for {blocks_restored}/{len(self.transformer_blocks)} blocks") + + def reset_lora(self): + """ + Reset LoRA parameters to remove all LoRA effects. 
+ """ + # Import the nunchaku library's transformer model + from nunchaku.models.transformers.transformer_qwenimage import ( + NunchakuQwenImageTransformer2DModel as NunchakuQwenImageTransformer2DModelLib, + ) + + # Check if the nunchaku library's model has the reset_lora method + if hasattr(NunchakuQwenImageTransformer2DModelLib, "reset_lora"): + NunchakuQwenImageTransformer2DModelLib.reset_lora(self) + else: + # Fallback: clear LoRA lists + self.comfy_lora_meta_list = [] + self.comfy_lora_sd_list = [] def set_offload(self, offload: bool, **kwargs): """ @@ -836,3 +1527,38 @@ def set_offload(self, offload: bool, **kwargs): self.offload_manager = None gc.collect() torch.cuda.empty_cache() + + def set_lora_strength(self, strength: float): + """ + Sets the LoRA scaling strength for the model. + + This method allows dynamic adjustment of LoRA strength, similar to Flux's setLoraScale. + The strength is applied only to the LoRA part (ranks beyond original_rank), while + the original low-rank branches remain at strength 1.0. + + Parameters + ---------- + strength : float, optional + LoRA scaling strength (default: 1). + + Note: This function will change the strength of all the LoRAs. So only use it when you only have a single LoRA. + """ + # Set LoRA strength for all SVDQW4A4Linear layers in transformer blocks + from nunchaku.models.linear import SVDQW4A4Linear + + for block in self.transformer_blocks: + # Set strength for all SVDQW4A4Linear layers in this block + for module in block.modules(): + if isinstance(module, SVDQW4A4Linear): + module.set_lora_strength(strength) + + # Handle unquantized part (similar to Flux implementation) + if len(self._unquantized_part_loras) > 0: + self._update_unquantized_part_lora_params(strength) + if len(self._quantized_part_vectors) > 0: + from nunchaku.lora.qwenimage.utils import fuse_vectors + + vector_dict = fuse_vectors(self._quantized_part_vectors, self._quantized_part_sd, strength) + for block in self.transformer_blocks: + if hasattr(block, "update_lora_params"): + block.update_lora_params(vector_dict) diff --git a/nodes/lora/qwenimage.py b/nodes/lora/qwenimage.py new file mode 100644 index 00000000..73a804ff --- /dev/null +++ b/nodes/lora/qwenimage.py @@ -0,0 +1,321 @@ +""" +This module provides the :class:`NunchakuQwenImageLoraLoader` node +for applying LoRA weights to Nunchaku Qwen Image models within ComfyUI. +""" + +import copy +import logging +import os + +import folder_paths + +# from nunchaku.lora.qwenimage import compose_lora, to_diffusers # Not needed here + +# from ...wrappers.qwenimage import ComfyQwenImageWrapper # Not needed - working directly with transformer + +# Get log level from environment variable (default to INFO) +log_level = os.getenv("LOG_LEVEL", "INFO").upper() + +# Configure logging +logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class NunchakuQwenImageLoraLoader: + """ + Node for loading and applying a LoRA to a Nunchaku Qwen Image model. + + This implementation follows the Flux LoRA Loader design pattern: + - LoRAs are stored as metadata (path + strength) without immediate application + - Actual composition and application happens lazily in the wrapper's forward pass + - This allows flexible strength adjustment and avoids redundant conversions + + Attributes + ---------- + RETURN_TYPES : tuple + The return type of the node ("MODEL",). + OUTPUT_TOOLTIPS : tuple + Tooltip for the output. 
+ FUNCTION : str + The function to call ("load_lora"). + TITLE : str + Node title. + CATEGORY : str + Node category. + DESCRIPTION : str + Node description. + """ + + @classmethod + def INPUT_TYPES(s): + """ + Defines the input types and tooltips for the node. + + Returns + ------- + dict + A dictionary specifying the required inputs and their descriptions for the node interface. + """ + return { + "required": { + "model": ( + "MODEL", + { + "tooltip": "The diffusion model the LoRA will be applied to. " + "Make sure the model is loaded by `Nunchaku Qwen Image DiT Loader`." + }, + ), + "lora_name": ( + folder_paths.get_filename_list("loras"), + {"tooltip": "The file name of the LoRA."}, + ), + "lora_strength": ( + "FLOAT", + { + "default": 1.0, + "min": -100.0, + "max": 100.0, + "step": 0.01, + "tooltip": "How strongly to modify the diffusion model. This value can be negative.", + }, + ), + } + } + + RETURN_TYPES = ("MODEL",) + OUTPUT_TOOLTIPS = ("The modified diffusion model.",) + FUNCTION = "load_lora" + TITLE = "Nunchaku Qwen Image LoRA Loader" + + CATEGORY = "Nunchaku" + DESCRIPTION = ( + "LoRAs are used to modify the diffusion model, " + "altering the way in which latents are denoised such as applying styles. " + "You can link multiple LoRA nodes." + ) + + def load_lora(self, model, lora_name: str, lora_strength: float): + """ + Apply a LoRA to a Nunchaku Qwen Image diffusion model. + + Following Flux's design pattern, this method stores LoRA metadata in the wrapper. + The actual composition and application happens lazily in the wrapper's forward pass. + + Parameters + ---------- + model : object + The diffusion model to modify. + lora_name : str + The name of the LoRA to apply. + lora_strength : float + The strength with which to apply the LoRA. + + Returns + ------- + tuple + A tuple containing the modified diffusion model. + """ + if abs(lora_strength) < 1e-5: + return (model,) # If the strength is too small, return the original model + + model_wrapper = model.model.diffusion_model + + # Check if this is a ComfyQwenImageWrapper + from ...wrappers.qwenimage import ComfyQwenImageWrapper + + if not isinstance(model_wrapper, ComfyQwenImageWrapper): + logger.error("❌ Model type mismatch!") + logger.error(" Expected: ComfyQwenImageWrapper (Nunchaku Qwen Image model)") + logger.error(f" Got: {type(model_wrapper).__name__}") + logger.error(" Please make sure you're using 'Nunchaku Qwen Image DiT Loader' to load the model.") + raise TypeError( + f"This LoRA loader only works with Nunchaku Qwen Image models. " + f"Got {type(model_wrapper).__name__} instead. " + f"Please use 'Nunchaku Qwen Image DiT Loader' to load your model." 
+ ) + + transformer = model_wrapper.model + + # Flux-style deepcopy: temporarily remove transformer to avoid copying it + model_wrapper.model = None + ret_model = copy.deepcopy(model) # copy everything except the model + ret_model_wrapper = ret_model.model.diffusion_model + + if not isinstance(ret_model_wrapper, ComfyQwenImageWrapper): + raise TypeError(f"Model wrapper type changed after deepcopy: {type(ret_model_wrapper).__name__}") + + model_wrapper.model = transformer + ret_model_wrapper.model = transformer # Share the same transformer + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + ret_model_wrapper.loras.append((lora_path, lora_strength)) + + logger.info(f"LoRA added: {lora_name} (strength={lora_strength})") + logger.debug(f"Total LoRAs: {len(ret_model_wrapper.loras)}") + + return (ret_model,) + + +class NunchakuQwenImageLoraStack: + """ + Node for loading and applying multiple LoRAs to a Nunchaku Qwen Image model with dynamic input. + + This node allows you to configure multiple LoRAs with their respective strengths + in a single node, providing the same effect as chaining multiple LoRA nodes. + + Following Flux's design pattern, this method only stores LoRA metadata. + The actual composition and application happens lazily in the wrapper's forward pass. + + Attributes + ---------- + RETURN_TYPES : tuple + The return type of the node ("MODEL",). + OUTPUT_TOOLTIPS : tuple + Tooltip for the output. + FUNCTION : str + The function to call ("load_lora_stack"). + TITLE : str + Node title. + CATEGORY : str + Node category. + DESCRIPTION : str + Node description. + """ + + @classmethod + def INPUT_TYPES(s): + """ + Defines the input types for the LoRA stack node. + + Returns + ------- + dict + A dictionary specifying the required inputs and optional LoRA inputs. + """ + # Base inputs + inputs = { + "required": { + "model": ( + "MODEL", + { + "tooltip": "The diffusion model the LoRAs will be applied to. " + "Make sure the model is loaded by `Nunchaku Qwen Image DiT Loader`." + }, + ), + }, + "optional": {}, + } + + # Add fixed number of LoRA inputs (15 slots) + for i in range(1, 16): # Support up to 15 LoRAs + inputs["optional"][f"lora_name_{i}"] = ( + ["None"] + folder_paths.get_filename_list("loras"), + {"tooltip": f"The file name of LoRA {i}. Select 'None' to skip this slot."}, + ) + inputs["optional"][f"lora_strength_{i}"] = ( + "FLOAT", + { + "default": 1.0, + "min": -100.0, + "max": 100.0, + "step": 0.01, + "tooltip": f"Strength for LoRA {i}. This value can be negative.", + }, + ) + + return inputs + + RETURN_TYPES = ("MODEL",) + OUTPUT_TOOLTIPS = ("The modified diffusion model with all LoRAs applied.",) + FUNCTION = "load_lora_stack" + TITLE = "Nunchaku Qwen Image LoRA Stack" + + CATEGORY = "Nunchaku" + DESCRIPTION = ( + "Apply multiple LoRAs to a diffusion model in a single node. " + "Equivalent to chaining multiple LoRA nodes but more convenient for managing many LoRAs. " + "Supports up to 15 LoRAs simultaneously. Set unused slots to 'None' to skip them." + ) + + def load_lora_stack(self, model, **kwargs): + """ + Apply multiple LoRAs to a Nunchaku Qwen Image diffusion model. + + Following Flux's design pattern, this method uses shared transformer instances + to avoid memory overhead and enable efficient LoRA caching. + + Parameters + ---------- + model : object + The diffusion model to modify. + **kwargs + Dynamic LoRA name and strength parameters. + + Returns + ------- + tuple + A tuple containing the modified diffusion model. 
+ """ + # Collect LoRA information to apply + loras_to_apply = [] + + for i in range(1, 16): # Check all 15 LoRA slots + lora_name = kwargs.get(f"lora_name_{i}") + lora_strength = kwargs.get(f"lora_strength_{i}", 1.0) + + # Skip unset or None LoRAs + if lora_name is None or lora_name == "None" or lora_name == "": + continue + + # Skip LoRAs with zero strength + if abs(lora_strength) < 1e-5: + continue + + loras_to_apply.append((lora_name, lora_strength)) + + # If no LoRAs need to be applied, return the original model + if not loras_to_apply: + return (model,) + + model_wrapper = model.model.diffusion_model + + # Check if this is a ComfyQwenImageWrapper + from ...wrappers.qwenimage import ComfyQwenImageWrapper + + if not isinstance(model_wrapper, ComfyQwenImageWrapper): + logger.error("❌ Model type mismatch!") + logger.error(" Expected: ComfyQwenImageWrapper (Nunchaku Qwen Image model)") + logger.error(f" Got: {type(model_wrapper).__name__}") + logger.error(" Please make sure you're using 'Nunchaku Qwen Image DiT Loader' to load the model.") + raise TypeError( + f"This LoRA loader only works with Nunchaku Qwen Image models. " + f"Got {type(model_wrapper).__name__} instead. " + f"Please use 'Nunchaku Qwen Image DiT Loader' to load your model." + ) + + transformer = model_wrapper.model + + # Flux-style deepcopy: temporarily remove transformer to avoid copying it + model_wrapper.model = None + ret_model = copy.deepcopy(model) # copy everything except the model + ret_model_wrapper = ret_model.model.diffusion_model + + if not isinstance(ret_model_wrapper, ComfyQwenImageWrapper): + raise TypeError(f"Model wrapper type changed after deepcopy: {type(ret_model_wrapper).__name__}") + + model_wrapper.model = transformer + ret_model_wrapper.model = transformer # Share the same transformer + + # Clear existing LoRA list + ret_model_wrapper.loras = [] + + # Add all LoRAs + for lora_name, lora_strength in loras_to_apply: + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + ret_model_wrapper.loras.append((lora_path, lora_strength)) + + logger.debug(f"LoRA added to stack: {lora_name} (strength={lora_strength})") + + logger.info(f"Total LoRAs in stack: {len(ret_model_wrapper.loras)}") + + return (ret_model,) diff --git a/nodes/models/qwenimage.py b/nodes/models/qwenimage.py index c76c6f3b..343ad1b3 100644 --- a/nodes/models/qwenimage.py +++ b/nodes/models/qwenimage.py @@ -214,4 +214,17 @@ def load_model( cpu_offload_enabled, num_blocks_on_gpu=num_blocks_on_gpu, use_pin_memory=use_pin_memory == "enable" ) + # Wrap transformer in ComfyQwenImageWrapper for LoRA support (Flux-style) + from ...models.qwenimage import NunchakuQwenImageTransformer2DModel + from ...wrappers.qwenimage import ComfyQwenImageWrapper + + if isinstance(model.model.diffusion_model, NunchakuQwenImageTransformer2DModel): + # Only wrap if not already wrapped + if not isinstance(model.model.diffusion_model, ComfyQwenImageWrapper): + wrapper = ComfyQwenImageWrapper( + model=model.model.diffusion_model, config=model.model.model_config.unet_config + ) + model.model.diffusion_model = wrapper + logger.debug("Wrapped transformer in ComfyQwenImageWrapper for LoRA support") + return (model,) diff --git a/wrappers/qwenimage.py b/wrappers/qwenimage.py new file mode 100644 index 00000000..ae595b3f --- /dev/null +++ b/wrappers/qwenimage.py @@ -0,0 +1,377 @@ +""" +This module provides a wrapper for the :class:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenImageTransformer2DModel`, +enabling integration with ComfyUI 
forward, LoRA composition, and advanced caching strategies. +""" + +from typing import Callable + +import torch +from comfy.ldm.common_dit import pad_to_patch_size +from einops import rearrange, repeat +from torch import nn + +from nunchaku import NunchakuQwenImageTransformer2DModel +from nunchaku.caching.fbcache import cache_context, create_cache_context +from nunchaku.lora.qwenimage.compose import compose_lora +from nunchaku.utils import load_state_dict_in_safetensors + + +class ComfyQwenImageWrapper(nn.Module): + """ + Wrapper for :class:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenImageTransformer2DModel` + to support ComfyUI workflows, LoRA composition, and caching. + + Parameters + ---------- + model : :class:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenImageTransformer2DModel` + The underlying Nunchaku model to wrap. + config : dict + Model configuration dictionary. + customized_forward : Callable, optional + Optional custom forward function. + forward_kwargs : dict, optional + Additional keyword arguments for the forward pass. + + Attributes + ---------- + model : :class:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenImageTransformer2DModel` + The wrapped model. + dtype : torch.dtype + Data type of the model parameters. + config : dict + Model configuration. + loras : list + List of LoRA metadata for composition. + customized_forward : Callable or None + Custom forward function if provided. + forward_kwargs : dict + Additional arguments for the forward pass. + """ + + def __init__( + self, + model: NunchakuQwenImageTransformer2DModel, + config, + customized_forward: Callable = None, + forward_kwargs: dict | None = {}, + ): + super(ComfyQwenImageWrapper, self).__init__() + self.model = model + self.dtype = next(model.parameters()).dtype + self.config = config + self.loras = [] + + self.customized_forward = customized_forward + self.forward_kwargs = {} if forward_kwargs is None else forward_kwargs + + self._prev_timestep = None # for first-block cache + self._cache_context = None + + def to_safely(self, device): + """ + Safely move the model to the specified device. + Required by NunchakuModelPatcher for device management. + """ + if hasattr(self.model, "to_safely"): + return self.model.to_safely(device) + else: + return self.model.to(device) + + def process_img(self, x, index=0, h_offset=0, w_offset=0): + """ + Preprocess an input image tensor for the model. + + Pads and rearranges the image into patches and generates corresponding image IDs. + + Parameters + ---------- + x : torch.Tensor + Input image tensor of shape (batch, channels, height, width). + index : int, optional + Index for image ID encoding. + h_offset : int, optional + Height offset for patch IDs. + w_offset : int, optional + Width offset for patch IDs. + + Returns + ------- + img : torch.Tensor + Rearranged image tensor of shape (batch, num_patches, patch_dim). + img_ids : torch.Tensor + Image ID tensor of shape (batch, num_patches, 3). 
+ """ + bs, c, h, w = x.shape + patch_size = self.config.get("patch_size", 2) + x = pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + h_len = (h + (patch_size // 2)) // patch_size + w_len = (w + (patch_size // 2)) // patch_size + + h_offset = (h_offset + (patch_size // 2)) // patch_size + w_offset = (w_offset + (patch_size // 2)) // patch_size + + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace( + h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype + ).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace( + w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype + ).unsqueeze(0) + return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) + + def forward( + self, + x, + timestep, + context=None, + y=None, + guidance=None, + control=None, + transformer_options={}, + **kwargs, + ): + """ + Forward pass for the wrapped model. + + Handles LoRA composition, caching, and dual-stream processing for Qwen Image. + + Parameters + ---------- + x : torch.Tensor + Input image tensor. + timestep : float or torch.Tensor + Diffusion timestep. + context : torch.Tensor + Context tensor (e.g., text embeddings). + y : torch.Tensor + Pooled projections or additional conditioning. + guidance : torch.Tensor + Guidance embedding or value. + control : dict, optional + ControlNet input and output samples. + transformer_options : dict, optional + Additional transformer options. + **kwargs + Additional keyword arguments, e.g., 'ref_latents'. + + Returns + ------- + out : torch.Tensor + Output tensor of the same spatial size as the input. 
+ """ + if isinstance(timestep, torch.Tensor): + if timestep.numel() == 1: + timestep_float = timestep.item() + else: + timestep_float = timestep.flatten()[0].item() + else: + assert isinstance(timestep, float) + timestep_float = timestep + + model = self.model + if model is None: + raise ValueError("Wrapped model is None!") + if not ( + type(model).__name__ == "NunchakuQwenImageTransformer2DModel" + or isinstance(model, NunchakuQwenImageTransformer2DModel) + ): + raise TypeError(f"Expected NunchakuQwenImageTransformer2DModel, got {type(model).__name__}") + + # Check if x is already processed or needs processing + input_is_5d = False + if x.ndim == 5: + # x is (batch, channels, 1, height, width) - squeeze the middle dimension + input_is_5d = True + x = x.squeeze(2) # Now (batch, channels, height, width) + + # Keep x in 4D format and let model's _forward handle process_img + if x.ndim != 4: + raise ValueError(f"Unexpected input shape: {x.shape}, expected 4D tensor") + + # load and compose LoRA + if self.loras != model.comfy_lora_meta_list: + from nunchaku.lora.qwenimage import is_nunchaku_format + + lora_to_be_composed = [] + nunchaku_lora_count = 0 + + for _ in range(max(0, len(model.comfy_lora_meta_list) - len(self.loras))): + model.comfy_lora_meta_list.pop() + model.comfy_lora_sd_list.pop() + + for i in range(len(self.loras)): + meta = self.loras[i] + if i >= len(model.comfy_lora_meta_list): + sd = load_state_dict_in_safetensors(meta[0]) + model.comfy_lora_meta_list.append(meta) + model.comfy_lora_sd_list.append(sd) + elif model.comfy_lora_meta_list[i] != meta: + if meta[0] != model.comfy_lora_meta_list[i][0]: + sd = load_state_dict_in_safetensors(meta[0]) + model.comfy_lora_sd_list[i] = sd + model.comfy_lora_meta_list[i] = meta + + # Check if this LoRA is already in Nunchaku format + sd_to_compose = model.comfy_lora_sd_list[i] + if is_nunchaku_format(sd_to_compose): + nunchaku_lora_count += 1 + # Convert back to Diffusers format for composition + from nunchaku.lora.qwenimage import to_diffusers + + sd_to_compose = to_diffusers(sd_to_compose) + # Update the cache with Diffusers version + model.comfy_lora_sd_list[i] = sd_to_compose + + lora_to_be_composed.append(({k: v for k, v in sd_to_compose.items()}, meta[1])) + + # Now all LoRAs are in Diffusers format, can safely compose + composed_lora = compose_lora(lora_to_be_composed) + + if len(composed_lora) == 0: + # CRITICAL: Manually restore original proj_down/proj_up weights + import torch.nn as nn + + from nunchaku.models.linear import SVDQW4A4Linear + + restored_count = 0 + for name, module in model.named_modules(): + if isinstance(module, SVDQW4A4Linear): + proj_down_key = f"{name}.proj_down" + proj_up_key = f"{name}.proj_up" + if proj_down_key in model._quantized_part_sd: + original_proj_down = model._quantized_part_sd[proj_down_key] + module.proj_down = nn.Parameter( + original_proj_down.clone().to( + device=module.proj_down.device, dtype=module.proj_down.dtype + ), + requires_grad=False, + ) + restored_count += 1 + if proj_up_key in model._quantized_part_sd: + original_proj_up = model._quantized_part_sd[proj_up_key] + module.proj_up = nn.Parameter( + original_proj_up.clone().to(device=module.proj_up.device, dtype=module.proj_up.dtype), + requires_grad=False, + ) + restored_count += 1 + if proj_down_key in model._quantized_part_sd: + original_rank = model._quantized_part_sd[proj_down_key].shape[1] + module.rank = original_rank + if not hasattr(module, "original_rank"): + module.original_rank = original_rank + module.lora_strength = 
0.0 + + model.reset_lora() + else: + # Pass number of LoRAs to help detect composed LoRAs + model.update_lora_params(composed_lora, num_loras=len(self.loras)) + + # CRITICAL: For composed LoRAs, strength is already baked by compose_lora + # Setting lora_strength to a uniform value will destroy the individual strength differences + # SOLUTION: Calculate average strength or use 1.0 as neutral value + from nunchaku.models.linear import SVDQW4A4Linear + + # Calculate weighted average strength for composed LoRAs + if len(self.loras) > 1: + avg_strength = sum(s for _, s in self.loras) / len(self.loras) + else: + avg_strength = 1.0 + + for block in model.transformer_blocks: + for module in block.modules(): + if isinstance(module, SVDQW4A4Linear): + module.lora_strength = avg_strength + + # Note: nunchaku's attention processor doesn't accept 'wrappers' and other ComfyUI-specific keys + # We handle this by not passing transformer_options to the underlying model + + if getattr(model, "residual_diff_threshold_multi", 0) != 0 or getattr(model, "_is_cached", False): + # A more robust caching strategy + cache_invalid = False + + # Check if timestamps have changed or are out of valid range + if self._prev_timestep is None: + cache_invalid = True + elif self._prev_timestep < timestep_float + 1e-5: # allow a small tolerance to reuse the cache + cache_invalid = True + + if cache_invalid: + self._cache_context = create_cache_context() + + # Update the previous timestamp + self._prev_timestep = timestep_float + with cache_context(self._cache_context): + if self.customized_forward is None: + out = model( + hidden_states=x, + encoder_hidden_states=context, + encoder_hidden_states_mask=None, # Qwen Image doesn't use mask in ComfyUI + timestep=timestep, + ref_latents=kwargs.get("ref_latents"), + guidance=guidance if self.config.get("guidance_embed", False) else None, + control=control, + transformer_options=transformer_options, + ) + else: + out = self.customized_forward( + model, + hidden_states=x, + encoder_hidden_states=context, + encoder_hidden_states_mask=None, + timestep=timestep, + ref_latents=kwargs.get("ref_latents"), + guidance=guidance if self.config.get("guidance_embed", False) else None, + control=control, + transformer_options=transformer_options, + **self.forward_kwargs, + ) + else: + if self.customized_forward is None: + # Pass original 4D x to model, let model handle process_img + # Model's forward will convert parameters and call _forward + out = model( + hidden_states=x, # Pass original 4D x + encoder_hidden_states=context, + encoder_hidden_states_mask=None, + timestep=timestep, + ref_latents=kwargs.get("ref_latents"), + guidance=guidance if self.config.get("guidance_embed", False) else None, + control=control, + transformer_options=transformer_options, + ) + else: + out = self.customized_forward( + model, + hidden_states=x, # Pass original 4D x + encoder_hidden_states=context, + encoder_hidden_states_mask=None, + timestep=timestep, + ref_latents=kwargs.get("ref_latents"), + guidance=guidance if self.config.get("guidance_embed", False) else None, + control=control, + transformer_options=transformer_options, + **self.forward_kwargs, + ) + + # Model returns a tuple (output,), unpack it + if isinstance(out, tuple): + out = out[0] + + # Model already returns unpatchified output (4D) + # out = out[:, :img_tokens] + # out = rearrange( + # out, + # "b (h w) (c ph pw) -> b c (h ph) (w pw)", + # h=h_len, + # w=w_len, + # ph=patch_size, + # pw=patch_size, + # ) + + # If input was 5D, unsqueeze output 
back to 5D + if input_is_5d: + out = out.unsqueeze(2) # (batch, channels, height, width) -> (batch, channels, 1, height, width) + return out
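For reference, the Flux-style LoRA cache sync that appears both in `_forward` and in `ComfyQwenImageWrapper.forward` (reload a LoRA file only when its path changes, reuse the cached state dict when only the strength changes, recompose on every change) boils down to the framework-free sketch below. `sync_lora_cache` and `fake_load` are hypothetical helpers used purely for illustration; they are not functions from this PR or from nunchaku.

```python
from typing import Any


def sync_lora_cache(requested: list[tuple[str, float]],
                    cached_meta: list[tuple[str, float]],
                    cached_sd: list[dict[str, Any]],
                    load_fn) -> list[tuple[dict[str, Any], float]]:
    """Simplified version of the metadata sync used in the wrapper's forward."""
    # Drop cache entries beyond the requested list
    while len(cached_meta) > len(requested):
        cached_meta.pop()
        cached_sd.pop()

    to_compose = []
    for i, meta in enumerate(requested):           # meta = (path, strength)
        if i >= len(cached_meta):                  # new LoRA: load from disk
            cached_meta.append(meta)
            cached_sd.append(load_fn(meta[0]))
        elif cached_meta[i] != meta:
            if meta[0] != cached_meta[i][0]:       # path changed: reload the file
                cached_sd[i] = load_fn(meta[0])
            cached_meta[i] = meta                  # strength-only change: keep the cached sd
        to_compose.append((dict(cached_sd[i]), meta[1]))
    return to_compose


def fake_load(path):
    # Stand-in for load_state_dict_in_safetensors, for illustration only
    return {"layer.lora_down": f"weights-from-{path}"}


# Second call reuses the cached dict because only the strength changed
meta_cache, sd_cache = [], []
sync_lora_cache([("a.safetensors", 1.0)], meta_cache, sd_cache, fake_load)
composed = sync_lora_cache([("a.safetensors", 0.5)], meta_cache, sd_cache, fake_load)
assert composed[0][1] == 0.5 and len(sd_cache) == 1
```

Keeping the raw state dicts cached and re-running the composition step whenever the config differs keeps strength adjustments cheap while avoiding repeated disk reads, which matches the behaviour the diff describes for both the transformer-side and wrapper-side code paths.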