| 
 | 1 | +from __future__ import annotations  | 
 | 2 | + | 
 | 3 | +from typing import TYPE_CHECKING, Optional  | 
 | 4 | + | 
 | 5 | +import einops  | 
 | 6 | +import torch  | 
 | 7 | +from diffusers import UNet2DConditionModel  | 
 | 8 | + | 
 | 9 | +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType  | 
 | 10 | +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback  | 
 | 11 | + | 
 | 12 | +if TYPE_CHECKING:  | 
 | 13 | +    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext  | 
 | 14 | + | 
 | 15 | + | 
 | 16 | +class InpaintExt(ExtensionBase):  | 
 | 17 | +    """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting  | 
 | 18 | +    models.  | 
 | 19 | +    """  | 
 | 20 | + | 
 | 21 | +    def __init__(  | 
 | 22 | +        self,  | 
 | 23 | +        mask: torch.Tensor,  | 
 | 24 | +        is_gradient_mask: bool,  | 
 | 25 | +    ):  | 
 | 26 | +        """Initialize InpaintExt.  | 
 | 27 | +        Args:  | 
 | 28 | +            mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are  | 
 | 29 | +                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be  | 
 | 30 | +                inpainted.  | 
 | 31 | +            is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range  | 
 | 32 | +                from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or  | 
 | 33 | +                1.  | 
 | 34 | +        """  | 
 | 35 | +        super().__init__()  | 
 | 36 | +        self._mask = mask  | 
 | 37 | +        self._is_gradient_mask = is_gradient_mask  | 
 | 38 | + | 
 | 39 | +        # Noise, which used to noisify unmasked part of image  | 
 | 40 | +        # if noise provided to context, then it will be used  | 
 | 41 | +        # if no noise provided, then noise will be generated based on seed  | 
 | 42 | +        self._noise: Optional[torch.Tensor] = None  | 
 | 43 | + | 
 | 44 | +    @staticmethod  | 
 | 45 | +    def _is_normal_model(unet: UNet2DConditionModel):  | 
 | 46 | +        """Checks if the provided UNet belongs to a regular model.  | 
 | 47 | +        The `in_channels` of a UNet vary depending on model type:  | 
 | 48 | +        - normal - 4  | 
 | 49 | +        - depth - 5  | 
 | 50 | +        - inpaint - 9  | 
 | 51 | +        """  | 
 | 52 | +        return unet.conv_in.in_channels == 4  | 
 | 53 | + | 
 | 54 | +    def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:  | 
 | 55 | +        batch_size = latents.size(0)  | 
 | 56 | +        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)  | 
 | 57 | +        if t.dim() == 0:  | 
 | 58 | +            # some schedulers expect t to be one-dimensional.  | 
 | 59 | +            # TODO: file diffusers bug about inconsistency?  | 
 | 60 | +            t = einops.repeat(t, "-> batch", batch=batch_size)  | 
 | 61 | +        # Noise shouldn't be re-randomized between steps here. The multistep schedulers  | 
 | 62 | +        # get very confused about what is happening from step to step when we do that.  | 
 | 63 | +        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)  | 
 | 64 | +        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?  | 
 | 65 | +        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)  | 
 | 66 | +        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)  | 
 | 67 | +        if self._is_gradient_mask:  | 
 | 68 | +            threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps  | 
 | 69 | +            mask_bool = mask < 1 - threshold  | 
 | 70 | +            masked_input = torch.where(mask_bool, latents, mask_latents)  | 
 | 71 | +        else:  | 
 | 72 | +            masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))  | 
 | 73 | +        return masked_input  | 
 | 74 | + | 
 | 75 | +    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)  | 
 | 76 | +    def init_tensors(self, ctx: DenoiseContext):  | 
 | 77 | +        if not self._is_normal_model(ctx.unet):  | 
 | 78 | +            raise ValueError(  | 
 | 79 | +                "InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "  | 
 | 80 | +                "inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "  | 
 | 81 | +                "fixed by removing and re-adding the model (so that it gets re-probed)."  | 
 | 82 | +            )  | 
 | 83 | + | 
 | 84 | +        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)  | 
 | 85 | + | 
 | 86 | +        self._noise = ctx.inputs.noise  | 
 | 87 | +        # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).  | 
 | 88 | +        # We still need noise for inpainting, so we generate it from the seed here.  | 
 | 89 | +        if self._noise is None:  | 
 | 90 | +            self._noise = torch.randn(  | 
 | 91 | +                ctx.latents.shape,  | 
 | 92 | +                dtype=torch.float32,  | 
 | 93 | +                device="cpu",  | 
 | 94 | +                generator=torch.Generator(device="cpu").manual_seed(ctx.seed),  | 
 | 95 | +            ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)  | 
 | 96 | + | 
 | 97 | +    # Use negative order to make extensions with default order work with patched latents  | 
 | 98 | +    @callback(ExtensionCallbackType.PRE_STEP, order=-100)  | 
 | 99 | +    def apply_mask_to_initial_latents(self, ctx: DenoiseContext):  | 
 | 100 | +        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)  | 
 | 101 | + | 
 | 102 | +    # TODO: redo this with preview events rewrite  | 
 | 103 | +    # Use negative order to make extensions with default order work with patched latents  | 
 | 104 | +    @callback(ExtensionCallbackType.POST_STEP, order=-100)  | 
 | 105 | +    def apply_mask_to_step_output(self, ctx: DenoiseContext):  | 
 | 106 | +        timestep = ctx.scheduler.timesteps[-1]  | 
 | 107 | +        if hasattr(ctx.step_output, "denoised"):  | 
 | 108 | +            ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)  | 
 | 109 | +        elif hasattr(ctx.step_output, "pred_original_sample"):  | 
 | 110 | +            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)  | 
 | 111 | +        else:  | 
 | 112 | +            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)  | 
 | 113 | + | 
 | 114 | +    # Restore unmasked part after the last step is completed  | 
 | 115 | +    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)  | 
 | 116 | +    def restore_unmasked(self, ctx: DenoiseContext):  | 
 | 117 | +        if self._is_gradient_mask:  | 
 | 118 | +            ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)  | 
 | 119 | +        else:  | 
 | 120 | +            ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)  | 
0 commit comments