|
| 1 | +import math |
1 | 2 | from contextlib import ExitStack |
2 | | -from typing import Callable, Iterator, Optional, Tuple |
| 3 | +from typing import Callable, ClassVar, Iterator, Optional, Tuple |
3 | 4 |
|
4 | 5 | import torch |
5 | 6 | import torchvision.transforms as tv_transforms |
@@ -176,6 +177,72 @@ def _unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
176 | 177 | latents = latents.reshape(batch_size, channels // 4, h, w) |
177 | 178 | return latents |
178 | 179 |
|
| 180 | + @staticmethod |
| 181 | + def _align_ref_latent_dims(rh: int, rw: int) -> tuple[int, int]: |
| 182 | + """Trim reference latent spatial dims to even values for 2x2 packing. |
| 183 | +
|
| 184 | + Raises ValueError if the aligned dims would be < 2 (i.e., the reference |
| 185 | + latent is too small to produce any valid tokens). |
| 186 | + """ |
| 187 | + rh_aligned = rh - (rh % 2) |
| 188 | + rw_aligned = rw - (rw % 2) |
| 189 | + if rh_aligned < 2 or rw_aligned < 2: |
| 190 | + raise ValueError( |
| 191 | + f"Reference latent spatial dims must be >= 2 after even alignment; " |
| 192 | + f"got ({rh_aligned}, {rw_aligned}) from input shape ({rh}, {rw}). " |
| 193 | + "Ensure the reference image is at least 16 pixels in each dimension." |
| 194 | + ) |
| 195 | + return rh_aligned, rw_aligned |
| 196 | + |
| 197 | + @staticmethod |
| 198 | + def _build_img_shapes( |
| 199 | + latent_height: int, |
| 200 | + latent_width: int, |
| 201 | + ref_latent_height: int | None = None, |
| 202 | + ref_latent_width: int | None = None, |
| 203 | + ) -> list[list[tuple[int, int, int]]]: |
| 204 | + """Build the img_shapes argument for the transformer. |
| 205 | +
|
| 206 | + The reference segment (if present) must use its own dims so QwenEmbedRope's |
| 207 | + spatial frequencies position ref tokens distinctly from noisy tokens — |
| 208 | + otherwise reference content bleeds into the generation as a ghost. |
| 209 | + """ |
| 210 | + shapes: list[tuple[int, int, int]] = [(1, latent_height // 2, latent_width // 2)] |
| 211 | + if ref_latent_height is not None and ref_latent_width is not None: |
| 212 | + shapes.append((1, ref_latent_height // 2, ref_latent_width // 2)) |
| 213 | + return [shapes] |
| 214 | + |
| 215 | + # diffusers' QwenImageEdit(Plus)Pipeline uses VAE_IMAGE_SIZE = 1024 * 1024 pixels; |
| 216 | + # ref images are resized to this area (preserving aspect, snapped to multiples |
| 217 | + # of 32) before VAE encoding. We mirror this clamp in latent space so direct |
| 218 | + # backend callers — whose i2l may not pass explicit width/height — don't feed |
| 219 | + # the transformer an out-of-distribution reference sequence length (which |
| 220 | + # also causes a VRAM spike for large inputs). |
| 221 | + _REF_TARGET_PIXEL_AREA: ClassVar[int] = 1024 * 1024 |
| 222 | + _VAE_SCALE_FACTOR: ClassVar[int] = 8 |
| 223 | + |
| 224 | + @classmethod |
| 225 | + def _maybe_clamp_ref_latent_size(cls, ref_latents: torch.Tensor) -> torch.Tensor: |
| 226 | + """Bilinear-downscale the reference latent if it exceeds diffusers' |
| 227 | + VAE_IMAGE_SIZE budget. |
| 228 | +
|
| 229 | + Returns the latent unchanged if it's already within budget. |
| 230 | + """ |
| 231 | + _, _, rh, rw = ref_latents.shape |
| 232 | + target_cells = cls._REF_TARGET_PIXEL_AREA // (cls._VAE_SCALE_FACTOR**2) |
| 233 | + if rh * rw <= target_cells: |
| 234 | + return ref_latents |
| 235 | + aspect = rw / rh |
| 236 | + target_w_px = math.sqrt(cls._REF_TARGET_PIXEL_AREA * aspect) |
| 237 | + target_h_px = target_w_px / aspect |
| 238 | + target_w_px = max(32, round(target_w_px / 32) * 32) |
| 239 | + target_h_px = max(32, round(target_h_px / 32) * 32) |
| 240 | + target_rh = target_h_px // cls._VAE_SCALE_FACTOR |
| 241 | + target_rw = target_w_px // cls._VAE_SCALE_FACTOR |
| 242 | + return torch.nn.functional.interpolate( |
| 243 | + ref_latents, size=(target_rh, target_rw), mode="bilinear", antialias=False |
| 244 | + ) |
| 245 | + |
179 | 246 | def _run_diffusion(self, context: InvocationContext): |
180 | 247 | inference_dtype = torch.bfloat16 |
181 | 248 | device = TorchDevice.choose_torch_device() |
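A minimal standalone sketch of the even-alignment rule added above (the free function and example inputs are illustrative, not part of the commit):

```python
# Sketch of _align_ref_latent_dims' trimming rule: drop at most one row and
# one column so both spatial dims are even, as 2x2 patch packing requires.
def align_even(rh: int, rw: int) -> tuple[int, int]:
    rh_aligned, rw_aligned = rh - (rh % 2), rw - (rw % 2)
    if rh_aligned < 2 or rw_aligned < 2:
        raise ValueError("reference latent too small to produce any tokens")
    return rh_aligned, rw_aligned

print(align_even(93, 181))  # (92, 180): odd dims each lose one cell
print(align_even(2, 3))     # (2, 2): the smallest latent that still packs
```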
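To make the clamp arithmetic in `_maybe_clamp_ref_latent_size` concrete, a worked example with an assumed oversized input (a 4096x2048 px reference image, which encodes to a 256x512 latent):

```python
import math

REF_TARGET_PIXEL_AREA = 1024 * 1024  # 1_048_576 px
VAE_SCALE_FACTOR = 8
target_cells = REF_TARGET_PIXEL_AREA // VAE_SCALE_FACTOR**2  # 16_384 latent cells

rh, rw = 256, 512  # 131_072 cells, well over budget, so the clamp fires
aspect = rw / rh                                         # 2.0
target_w_px = math.sqrt(REF_TARGET_PIXEL_AREA * aspect)  # ~1448.15
target_h_px = target_w_px / aspect                       # ~724.08
target_w_px = max(32, round(target_w_px / 32) * 32)      # 1440
target_h_px = max(32, round(target_h_px / 32) * 32)      # 736
print(target_h_px // VAE_SCALE_FACTOR, target_w_px // VAE_SCALE_FACTOR)  # 92 180
# 92 * 180 = 16_560 cells: snapping to 32-px multiples can land slightly over
# budget, matching how diffusers rounds dimensions before VAE encoding.
```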
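And the `img_shapes` that `_build_img_shapes` then hands the transformer for, say, a 1024x1024 generation using that clamped reference (the generation size is assumed for illustration):

```python
latent_height, latent_width = 128, 128         # 1024x1024 px generation, VAE /8
ref_latent_height, ref_latent_width = 92, 180  # clamped reference from above

# One (frames, h, w) entry per segment, halved because tokens are 2x2 patches.
# Distinct dims give the reference segment its own RoPE positions.
img_shapes = [[
    (1, latent_height // 2, latent_width // 2),          # (1, 64, 64)
    (1, ref_latent_height // 2, ref_latent_width // 2),  # (1, 46, 90)
]]
```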
@@ -332,35 +399,37 @@ def _run_diffusion(self, context: InvocationContext): |
332 | 399 | use_ref_latents = has_zero_cond_t |
333 | 400 |
|
334 | 401 | ref_latents_packed = None |
| 402 | + ref_latent_height = latent_height |
| 403 | + ref_latent_width = latent_width |
335 | 404 | if use_ref_latents: |
336 | 405 | if ref_latents is not None: |
337 | | - _, ref_ch, rh, rw = ref_latents.shape |
338 | | - if rh != latent_height or rw != latent_width: |
339 | | - ref_latents = torch.nn.functional.interpolate( |
340 | | - ref_latents, size=(latent_height, latent_width), mode="bilinear" |
341 | | - ) |
| 406 | + # Defense-in-depth: backend callers (direct API, older graph JSON) |
| 407 | + # may wire qwen_image_i2l without explicit width/height, producing |
| 408 | + # a native-resolution reference latent. Clamp here so the |
| 409 | + # transformer always sees an in-distribution sequence length. |
| 410 | + ref_latents = self._maybe_clamp_ref_latent_size(ref_latents) |
| 411 | + _, _, rh, rw = ref_latents.shape |
| 412 | + ref_latent_height, ref_latent_width = self._align_ref_latent_dims(rh, rw) |
| 413 | + if ref_latent_height != rh or ref_latent_width != rw: |
| 414 | + ref_latents = ref_latents[..., :ref_latent_height, :ref_latent_width] |
342 | 415 | else: |
343 | 416 | # No reference image provided — use zeros so the model still gets the |
344 | 417 | # expected sequence layout. |
345 | 418 | ref_latents = torch.zeros( |
346 | 419 | 1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype |
347 | 420 | ) |
348 | | - ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width) |
349 | | - |
350 | | - # img_shapes tells the transformer the spatial layout of patches. |
| 421 | + ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, ref_latent_height, ref_latent_width) |
| 422 | + |
| 423 | + # img_shapes tells the transformer the spatial layout of patches. The reference |
| 424 | + # segment must use the reference latent's own dimensions so RoPE positions it |
| 425 | + # distinctly from the noisy latent — otherwise the two segments share spatial |
| 426 | + # positional encoding and the model can't disentangle them, producing a |
| 427 | + # ghost/doubling artifact across the whole frame. Matches diffusers' |
| 428 | + # QwenImageEditPipeline / QwenImageEditPlusPipeline. |
351 | 429 | if use_ref_latents: |
352 | | - img_shapes = [ |
353 | | - [ |
354 | | - (1, latent_height // 2, latent_width // 2), |
355 | | - (1, latent_height // 2, latent_width // 2), |
356 | | - ] |
357 | | - ] |
| 430 | + img_shapes = self._build_img_shapes(latent_height, latent_width, ref_latent_height, ref_latent_width) |
358 | 431 | else: |
359 | | - img_shapes = [ |
360 | | - [ |
361 | | - (1, latent_height // 2, latent_width // 2), |
362 | | - ] |
363 | | - ] |
| 432 | + img_shapes = self._build_img_shapes(latent_height, latent_width) |
364 | 433 |
|
365 | 434 | # Prepare inpaint extension (operates in 4D space, so unpack/repack around it) |
366 | 435 | inpaint_mask = self._prep_inpaint_mask(context, noise) # noise has the right 4D shape |
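For context on why the even alignment matters at the `_pack_latents` call site: a sketch of the usual diffusers-style 2x2 packing (an assumption, since `_pack_latents` itself is outside this diff, though it is consistent with the `channels // 4` reshape visible in `_unpack_latents` at new line 177):

```python
import torch

# Assumed Qwen/Flux-style packing: every 2x2 latent patch becomes one token,
# so view(..., h // 2, 2, w // 2, 2) requires even spatial dims, which is
# exactly what _align_ref_latent_dims guarantees.
def pack_latents(latents: torch.Tensor, batch: int, ch: int, h: int, w: int) -> torch.Tensor:
    latents = latents.view(batch, ch, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch, (h // 2) * (w // 2), ch * 4)

ref = torch.zeros(1, 16, 92, 180)  # 16 assumed latent channels, even dims
print(pack_latents(ref, 1, 16, 92, 180).shape)  # torch.Size([1, 4140, 64])
```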