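"""Guidance-extension nodes for gradient and per-channel inpaint masks.

Builds blurred/expanded denoise masks (plus a paste-back image of the total affected
area) and per-latent-channel masks, returning each as a GuidanceField that names a
registered guidance extension and the kwargs used to construct it.
"""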

from typing import Literal, Optional, Union, List

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
    OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import callback
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.util.devices import TorchDevice

from .extension_classes import GuidanceField, base_guidance_extension, GuidanceDataOutput


@invocation_output("gradient_mask_extension_output")
class GradientMaskExtensionOutput(BaseInvocationOutput):
    """Outputs a denoise mask and an image representing the total gradient of the mask."""

    mask_extension: GuidanceField = OutputField(
        description="Guidance Extension for masked denoise",
    )
    expanded_mask_area: ImageField = OutputField(
        description="Image representing the total gradient area of the mask. For paste-back purposes."
    )


@base_guidance_extension("InpaintMaskGuidance")
class InpaintMaskGuidance(InpaintExt):
    def __init__(
        self,
        context: InvocationContext,
        mask_name: str,
        is_gradient_mask: bool,
    ):
        """Initialize the extension.

        This override adapts the InvokeAI-internal InpaintExt to accept the mask as a
        saved tensor name instead of a tensor.
        """
        # Deliberately skip InpaintExt.__init__ (which expects a tensor) and run the
        # grandparent initializer instead.
        super(InpaintExt, self).__init__()
        self._mask = context.tensors.load(mask_name)
        self._is_gradient_mask = is_gradient_mask
        self._noise: Optional[torch.Tensor] = None

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def init_tensors(self, ctx: DenoiseContext):
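        # The saved mask tensor may be at image resolution; match it to the latent grid
        # before the base class builds its per-step tensors.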
        self._mask = tv_resize(self._mask, ctx.latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
        super().init_tensors(ctx)
        # Debug: print the value range of each mask channel (the mask may be
        # single-channel or per-latent-channel).
        for i in range(self._mask.shape[1]):
            print(f"Layer {i}: Min: {self._mask[0, i, :, :].min()}, Max: {self._mask[0, i, :, :].max()}")
        print(f"Mask shape: {self._mask.shape}")


@invocation(
    "gradient_mask_extension",
    title="Gradient Mask [Extension]",
    tags=["mask", "denoise", "extension"],
    category="extension",
    version="1.4.0",
)
class GradientMaskExtensionInvocation(BaseInvocation):
    """Creates a mask for a denoising model run."""

    mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
    edge_radius: int = InputField(
        default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
    )
    coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
    minimum_denoise: float = InputField(
        default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
    )
    unet: Optional[UNetField] = InputField(
        default=None,
        description="OPTIONAL: If the UNet is a specialized inpainting model, masked_latents will be generated from the image with the VAE",
        input=Input.Connection,
        title="[OPTIONAL] UNet",
        ui_order=5,
    )
    image: Optional[ImageField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized inpainting models; masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] Image",
        ui_order=6,
    )
    vae: Optional[VAEField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized inpainting models; masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] VAE",
        input=Input.Connection,
        ui_order=7,
    )
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
    fp32: bool = InputField(
        default=False,
        description=FieldDescriptions.fp32,
        ui_order=9,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> GradientMaskExtensionOutput:
        mask_image = context.images.get_pil(self.mask.image_name, mode="L")
        if self.edge_radius > 0:
            if self.coherence_mode == "Box Blur":
                blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
            else:  # Gaussian Blur OR Staged
                # Gaussian blur takes a standard deviation; radius / 2 is a good approximation.
                blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)

            # Redistribute the blur so the original mask edge sits at 0 and the falloff
            # climbs to 1: map [0.5, 1] -> [0, 1] and clamp everything below 0.5 to 0.
            blur_tensor = (blur_tensor - 0.5) * 2
            blur_tensor[blur_tensor < 0] = 0.0

            threshold = 1 - self.minimum_denoise
            if self.coherence_mode == "Staged":
                # Wherever the tensor is partially (but not fully) masked, snap it to the threshold.
                blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
            else:
                # Wherever the tensor is above the threshold but below 1, drop it to the threshold.
                blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
        else:
            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)

        mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
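
        # Paste-back helper: derive a binary image separating pixels the denoise can
        # touch (black) from fully preserved pixels (white). The round trip below
        # (binarize, downscale by 8, re-binarize, nearest-upscale) presumably snaps the
        # region outward to 8-px latent-tile boundaries so paste-back never clips a
        # partially denoised tile.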
        expanded_mask = torch.where((blur_tensor < 1), 0, 1)
        resized_expanded_mask = tv_resize(
            expanded_mask,
            (expanded_mask.shape[-2] // LATENT_SCALE_FACTOR, expanded_mask.shape[-1] // LATENT_SCALE_FACTOR),
            T.InterpolationMode.BILINEAR,
            antialias=False,
        )
        expanded_mask = torch.where((resized_expanded_mask < 1), 0, 1)
        upscaled_expanded_mask = tv_resize(
            expanded_mask,
            (expanded_mask.shape[-2] * LATENT_SCALE_FACTOR, expanded_mask.shape[-1] * LATENT_SCALE_FACTOR),
            T.InterpolationMode.NEAREST,
            antialias=False,
        )
        expanded_mask_image = Image.fromarray(
            (upscaled_expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L"
        )
        expanded_image_dto = context.images.save(expanded_mask_image)

        masked_latents_name = None
        # Specialized inpainting models also need masked latents; all three optional
        # inputs must be connected together.
        if self.unet is not None and self.vae is not None and self.image is not None:
            main_model_config = context.models.get_config(self.unet.unet.key)
            assert isinstance(main_model_config, MainConfigBase)
            if main_model_config.variant is ModelVariantType.Inpaint:
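                # Inpaint-variant UNets additionally condition on the VAE encoding of the
                # input image with the masked region zeroed out, so encode a masked copy
                # of the image here.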
                mask = blur_tensor
                vae_info: LoadedModel = context.models.load(self.vae.vae)
                image = context.images.get_pil(self.image.image_name)
                image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
                if image_tensor.dim() == 3:
                    image_tensor = image_tensor.unsqueeze(0)
                img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
                masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
                context.util.signal_progress("Running VAE encoder")
                masked_latents = ImageToLatentsInvocation.vae_encode(
                    vae_info, self.fp32, self.tiled, masked_image.clone()
                )
                masked_latents_name = context.tensors.save(tensor=masked_latents)

        return GradientMaskExtensionOutput(
            mask_extension=GuidanceField(
                guidance_name="InpaintMaskGuidance",
                extension_kwargs={"mask_name": mask_name, "is_gradient_mask": True},
            ),
            expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
        )


# @invocation(
#     "create_gradient_mask_v2",
#     title="Gradient Mask V2",
#     tags=["mask", "denoise"],
#     category="latents",
#     version="2.0.0",
# )
# class CreateGradientMaskV2Invocation(BaseInvocation):
#     """Creates a mask for a denoising model run."""
#
#     mask: Union[ImageField, List[ImageField]] = InputField(default=None, description="Image which will be masked", ui_order=1)
#     max_mask_expansion: int = InputField(
#         default=24, ge=0, multiple_of=8, description="How far to expand the edges of the mask", ui_order=2
#     )
#     minimum_denoise: float = InputField(
#         default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
#     )
#     latent_scale: bool = InputField(default=True, description="Scale the mask to the latent size before processing", ui_order=5, ui_hidden=True)
#     process_on_device: bool = InputField(default=False, description="Process the mask on the same device as inference (GPU, typically)", ui_order=6)
#
#     @torch.no_grad()
#     def invoke(self, context: InvocationContext) -> GradientMaskOutput:
#         if not isinstance(self.mask, list):
#             mask = [self.mask]
#         else:
#             mask = self.mask
#         mask_images = [context.images.get_pil(m.image_name, mode="L") for m in mask]
#         # Convert to tensors and combine, keeping the lowest value of each pixel.
#         tensor_images = [image_resized_to_grid_as_tensor(m, normalize=False) for m in mask_images]
#         tensor_images = [tv_resize(m, tensor_images[0].shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) for m in tensor_images]
#         mask_tensor = torch.stack(tensor_images, dim=0).min(dim=0).values
#         # Downscale by a factor of 8 to match the latent size.
#         if self.latent_scale:
#             mask_tensor = tv_resize(mask_tensor, [s // LATENT_SCALE_FACTOR for s in mask_tensor.shape[-2:]], T.InterpolationMode.BILINEAR, antialias=False)
#             expansion_count = self.max_mask_expansion // LATENT_SCALE_FACTOR
#         else:
#             expansion_count = self.max_mask_expansion
#         # Expansion steps are linearly spaced between 0 and 1.
#         expansion_steps = torch.linspace(0, 1, expansion_count + 1).flip(0)
#         # Investigating for a speed improvement:
#         if self.process_on_device:
#             mask_tensor = mask_tensor.to(TorchDevice.choose_torch_device())
#             expansion_steps = expansion_steps.to(TorchDevice.choose_torch_device())
#         device = mask_tensor.device
#         # Expand the mask.
#         # A convolution is used to grow darker areas of the mask into the lighter areas.
#         # The input mask(s) may not be binary and could already be gradients.
#         # Start by inverting the mask so white is full denoise and black is fully preserved.
#         # The convolution is split into multiple steps over binned values so each step can
#         # apply a different expansion factor, giving a more gradual expansion of the mask.
#         mask_tensor = 1 - mask_tensor
#         combine_mask_bool_tensor = (mask_tensor >= expansion_steps[0])  # catches 100% regions
#         for i in range(expansion_count):
#             # Boolean mask for the current expansion-step value bin.
#             mask_bin_bool_tensor = (mask_tensor >= expansion_steps[i + 1]) & (mask_tensor < expansion_steps[i])
#             mask_bin_tensor = torch.where(mask_bin_bool_tensor, mask_tensor, torch.zeros_like(mask_tensor))
#             mask_bin_tensor = mask_bin_tensor.unsqueeze(0).unsqueeze(0)
#             # Apply the convolution: dilate by 1.
#             expanded_mask_bin_tensor = torch.nn.functional.conv2d(mask_bin_tensor, torch.ones(1, 1, 3, 3).to(device), padding=1)
#             # Set newly expanded regions to the next expansion step.
#             mask_tensor = torch.where(mask_bin_bool_tensor, expanded_mask_bin_tensor, mask_tensor)
#         # TODO: masked_latents, mask_name, and expanded_image_dto are not yet defined in this draft.
#         masked_latents_name = context.tensors.save(tensor=masked_latents)
#         return GradientMaskOutput(
#             denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
#             expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
#         )


@base_guidance_extension("InpaintChannelMaskGuidance")
class InpaintChannelMaskGuidance(InpaintMaskGuidance):
    def __init__(
        self,
        context: InvocationContext,
        mask_name: str,
        is_gradient_mask: bool,
        channel_mask: list[float],
    ):
        """Initialize the extension.

        Like InpaintMaskGuidance, but additionally scales each latent channel of the
        mask by a per-channel factor.
        """
        # Deliberately skip InpaintExt.__init__ and run the grandparent initializer.
        super(InpaintExt, self).__init__()
        mask = context.tensors.load(mask_name)
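        # Scale each of the 4 SD latent channels of the mask by its factor: 0 zeroes
        # that channel's mask entirely, 1 leaves it unchanged, so channels can be
        # masked independently.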
        for i in range(4):
            mask[:, i, :, :] = mask[:, i, :, :] * channel_mask[i]
        self._mask = mask
        self._is_gradient_mask = is_gradient_mask
        self._noise: Optional[torch.Tensor] = None


@invocation(
    "channel_mask_extension",
    title="Channel Mask [Extension]",
    tags=["mask", "denoise", "extension"],
    category="extension",
    version="1.0.0",
)
class GradientChannelMaskExtensionInvocation(BaseInvocation):
    """Creates a per-channel latent mask for a denoising model run."""

    width: int = InputField(default=1024, ge=64, multiple_of=8, description="Width of the channel mask", ui_order=10)
    height: int = InputField(default=1024, ge=64, multiple_of=8, description="Height of the channel mask", ui_order=11)
    channel_0: bool = InputField(default=True, description="Mask the first channel", ui_order=12)
    channel_1: bool = InputField(default=True, description="Mask the second channel", ui_order=13)
    channel_2: bool = InputField(default=False, description="Mask the third channel", ui_order=14)
    channel_3: bool = InputField(default=False, description="Mask the fourth channel", ui_order=15)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> GuidanceDataOutput:
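        # Build an all-ones mask directly at latent resolution (one value per 8-px
        # image tile across the 4 SD latent channels); InpaintChannelMaskGuidance
        # scales the channels when it loads the tensor.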
        mask_name = context.tensors.save(
            tensor=torch.ones(1, 4, self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR)
        )
        channel_mask = [float(self.channel_0), float(self.channel_1), float(self.channel_2), float(self.channel_3)]
        kwargs = {
            "mask_name": mask_name,
            "is_gradient_mask": False,
            "channel_mask": channel_mask,
        }
        return GuidanceDataOutput(
            guidance_data_output=GuidanceField(
                guidance_name="InpaintChannelMaskGuidance",
                extension_kwargs=kwargs,
            )
        )
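

# Note: each @base_guidance_extension("Name") registration above ties the string used
# in a GuidanceField's guidance_name to an extension class, and the field's
# extension_kwargs must match that class's __init__ parameters (besides context); a
# downstream denoise node presumably looks the class up by name and instantiates it
# with those kwargs.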