Commit 4eb9ad0

[Community Pipeline] fix lpw_stable_diffusion (huggingface#1570)
Authored Dec 7, 2022

* fix lpw_stable_diffusion
* rollback preprocess_mask resample
1 parent 896c98a commit 4eb9ad0
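For reviewers unfamiliar with this community pipeline, a minimal, hedged usage sketch follows (the model id and prompts are illustrative placeholders, not part of this commit; `custom_pipeline="lpw_stable_diffusion"` is how diffusers resolves community pipelines):

```python
# Illustrative only: loading the long-prompt-weighting (lpw) community pipeline.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",          # placeholder model id
    custom_pipeline="lpw_stable_diffusion",   # resolves to examples/community/lpw_stable_diffusion.py
    torch_dtype=torch.float16,
).to("cuda")

# Prompts may exceed 77 tokens, and (token:weight) syntax adjusts attention weights.
image = pipe.text2img(
    "a photo of a (fantasy castle:1.2), highly detailed, best quality",
    negative_prompt="lowres, bad anatomy",
    max_embeddings_multiples=3,
).images[0]
```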

File tree

2 files changed (+563, -537 lines)

examples/community/lpw_stable_diffusion.py

+275 -313
@@ -6,38 +6,13 @@
 import torch

 import PIL
-from diffusers.configuration_utils import FrozenDict
+from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, is_accelerate_available, logging
-
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
-from packaging import version
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


-if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.Resampling.BILINEAR,
-        "bilinear": PIL.Image.Resampling.BILINEAR,
-        "bicubic": PIL.Image.Resampling.BICUBIC,
-        "lanczos": PIL.Image.Resampling.LANCZOS,
-        "nearest": PIL.Image.Resampling.NEAREST,
-    }
-else:
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.LINEAR,
-        "bilinear": PIL.Image.BILINEAR,
-        "bicubic": PIL.Image.BICUBIC,
-        "lanczos": PIL.Image.LANCZOS,
-        "nearest": PIL.Image.NEAREST,
-    }
-# ------------------------------------------------------------------------------
-
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
@@ -146,7 +121,7 @@ def multiply_range(start_position, multiplier):
     return res


-def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
     r"""
     Tokenize a list of prompts and return its tokens with weights of each token.

@@ -207,7 +182,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd


 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline,
+    pipe: StableDiffusionPipeline,
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
@@ -247,10 +222,10 @@ def get_unweighted_text_embeddings(


 def get_weighted_text_embeddings(
-    pipe: DiffusionPipeline,
+    pipe: StableDiffusionPipeline,
     prompt: Union[str, List[str]],
     uncond_prompt: Optional[Union[str, List[str]]] = None,
-    max_embeddings_multiples: Optional[int] = 1,
+    max_embeddings_multiples: Optional[int] = 3,
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
@@ -264,14 +239,14 @@ def get_weighted_text_embeddings(
     Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

     Args:
-        pipe (`DiffusionPipeline`):
+        pipe (`StableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
         prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
         uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts for guide the image generation. If unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
-        max_embeddings_multiples (`int`, *optional*, defaults to `1`):
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
@@ -387,11 +362,11 @@ def preprocess_image(image):
     return 2.0 * image - 1.0


-def preprocess_mask(mask):
+def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
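The hunk above turns the hard-coded `// 8` downscale into a `scale_factor` parameter so callers can pass `self.vae_scale_factor`. A quick sketch of the new signature (input size arbitrary; assumes `preprocess_mask` from this file is in scope):

```python
import PIL.Image

mask = PIL.Image.new("L", (512, 512), 255)        # all-white mask: repaint everything
processed = preprocess_mask(mask, scale_factor=8)
print(processed.shape)                            # (1, 4, 64, 64), matching the latent grid
```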
@@ -400,7 +375,7 @@ def preprocess_mask(mask):
     return mask


-class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
+class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
     weighting in prompt.
@@ -435,102 +410,184 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
     ):
-        super().__init__()
-
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
-                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
-                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
-                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
-                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                " file"
-            )
-            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(scheduler.config)
-            new_config["steps_offset"] = 1
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
-                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
-                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
-                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
-                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
-            )
-            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(scheduler.config)
-            new_config["clip_sample"] = False
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if safety_checker is None:
-            logger.warning(
-                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
-            )
-
-        self.register_modules(
+        super().__init__(
             vae=vae,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            requires_safety_checker=requires_safety_checker,
         )

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        max_embeddings_multiples,
+    ):
         r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+        Encodes the prompt into text encoder hidden states.

         Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
         """
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
-        self.unet.set_attention_slice(slice_size)
+        batch_size = len(prompt) if isinstance(prompt, list) else 1

-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )

-    def enable_sequential_cpu_offload(self):
-        r"""
-        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
-        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
-        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
-        """
-        if is_accelerate_available():
-            from accelerate import cpu_offload
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
+        )
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            bs_embed, seq_len, _ = uncond_embeddings.shape
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+        if is_text2img:
+            return self.scheduler.timesteps.to(device), num_inference_steps
+        else:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:].to(device)
+            return timesteps, num_inference_steps - t_start
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
         else:
-            raise ImportError("Please install accelerate via `pip install accelerate`")
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]

-        device = self.device
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
+        if image is None:
+            shape = (
+                batch_size,
+                self.unet.in_channels,
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            )
+
+            if latents is None:
+                if device.type == "mps":
+                    # randn does not work reproducibly on mps
+                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+                else:
+                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            else:
+                if latents.shape != shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+                latents = latents.to(device)
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = latents * self.scheduler.init_noise_sigma
+            return latents, None, None
+        else:
+            init_latent_dist = self.vae.encode(image).latent_dist
+            init_latents = init_latent_dist.sample(generator=generator)
+            init_latents = 0.18215 * init_latents
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
+            init_latents_orig = init_latents
+            shape = init_latents.shape

-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            # add noise to latents using the timesteps
+            if device.type == "mps":
+                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timestep)
+            return latents, init_latents_orig, noise

     @torch.no_grad()
     def __call__(
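Among the helpers added above, `get_timesteps` replaces the inline timestep arithmetic of the old `__call__`. A sketch of its two branches, assuming `pipe` is an instantiated `StableDiffusionLongPromptWeightingPipeline`:

```python
import torch

device = torch.device("cpu")
pipe.scheduler.set_timesteps(50, device=device)

# text2img: the full schedule is used.
timesteps, steps = pipe.get_timesteps(50, strength=0.8, device=device, is_text2img=True)

# img2img/inpaint: only the last int(50 * 0.8) steps (plus any steps_offset) are kept,
# so a lower strength preserves more of the init image.
timesteps, steps = pipe.get_timesteps(50, strength=0.8, device=device, is_text2img=False)
```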
@@ -634,221 +691,111 @@ def __call__(
         init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
         image = init_image or image

-        if isinstance(prompt, str):
-            batch_size = 1
-            prompt = [prompt]
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor

-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        # get prompt text embeddings
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, strength, callback_steps)

+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        if batch_size != len(negative_prompt):
-            raise ValueError(
-                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                " the batch size of `prompt`."
-            )

-        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
-            pipe=self,
-            prompt=prompt,
-            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
-            max_embeddings_multiples=max_embeddings_multiples,
-            **kwargs,
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            max_embeddings_multiples,
         )
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance:
-            bs_embed, seq_len, _ = uncond_embeddings.shape
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        latents_dtype = text_embeddings.dtype
-        init_latents_orig = None
-        mask = None
-        noise = None
-
-        if image is None:
-            # get the initial random noise unless the user supplied it
-
-            # Unlike in other pipelines, latents need to be generated in the target device
-            # for 1-to-1 results reproducibility with the CompVis implementation.
-            # However this currently doesn't work in `mps`.
-            latents_shape = (
-                batch_size * num_images_per_prompt,
-                self.unet.in_channels,
-                height // 8,
-                width // 8,
-            )
-
-            if latents is None:
-                if self.device.type == "mps":
-                    # randn does not exist on mps
-                    latents = torch.randn(
-                        latents_shape,
-                        generator=generator,
-                        device="cpu",
-                        dtype=latents_dtype,
-                    ).to(self.device)
-                else:
-                    latents = torch.randn(
-                        latents_shape,
-                        generator=generator,
-                        device=self.device,
-                        dtype=latents_dtype,
-                    )
-            else:
-                if latents.shape != latents_shape:
-                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-                latents = latents.to(self.device)
-
-            timesteps = self.scheduler.timesteps.to(self.device)
-
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = latents * self.scheduler.init_noise_sigma
+        dtype = text_embeddings.dtype
+
+        # 4. Preprocess image and mask
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess_image(image)
+        if image is not None:
+            image = image.to(device=self.device, dtype=dtype)
+        if isinstance(mask_image, PIL.Image.Image):
+            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        if mask_image is not None:
+            mask = mask_image.to(device=self.device, dtype=dtype)
+            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
         else:
-            if isinstance(image, PIL.Image.Image):
-                image = preprocess_image(image)
-            # encode the init image into latents and scale the latents
-            image = image.to(device=self.device, dtype=latents_dtype)
-            init_latent_dist = self.vae.encode(image).latent_dist
-            init_latents = init_latent_dist.sample(generator=generator)
-            init_latents = 0.18215 * init_latents
-            init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-            init_latents_orig = init_latents
-
-            # preprocess mask
-            if mask_image is not None:
-                if isinstance(mask_image, PIL.Image.Image):
-                    mask_image = preprocess_mask(mask_image)
-                mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
-                mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-                # check sizes
-                if not mask.shape == init_latents.shape:
-                    raise ValueError("The mask and image should be the same size!")
-
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-            # add noise to latents using the timesteps
-            if self.device.type == "mps":
-                # randn does not exist on mps
-                noise = torch.randn(
-                    init_latents.shape,
-                    generator=generator,
-                    device="cpu",
-                    dtype=latents_dtype,
-                ).to(self.device)
-            else:
-                noise = torch.randn(
-                    init_latents.shape,
-                    generator=generator,
-                    device=self.device,
-                    dtype=latents_dtype,
-                )
-            latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-                latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
-
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
-            image, has_nsfw_concept = self.safety_checker(
-                images=image,
-                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
-            )
-        else:
-            has_nsfw_concept = None
+            mask = None
+
+        # 5. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents, init_latents_orig, noise = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            device,
+            generator,
+            latents,
+        )

+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                if mask is not None:
+                    # masking
+                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                if i % callback_steps == 0:
+                    if callback is not None:
+                        callback(i, t, latents)
+                    if is_cancelled_callback is not None and is_cancelled_callback():
+                        return None
+
+        # 9. Post-processing
+        image = self.decode_latents(latents)
+
+        # 10. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return image, has_nsfw_concept

         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

@@ -868,6 +815,7 @@ def text2img(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -915,6 +863,9 @@ def text2img(
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
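A sketch of the `is_cancelled_callback` hook documented above (the cancellation flag is illustrative; `pipe` is an instantiated pipeline):

```python
cancelled = False  # flip from another thread or a UI event

def should_cancel() -> bool:
    return cancelled

result = pipe.text2img(
    "a landscape painting",
    is_cancelled_callback=should_cancel,
    callback_steps=1,
)
# When the callback returns True mid-run, the pipeline returns None instead of an output.
if result is None:
    print("generation was cancelled")
```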
@@ -940,6 +891,7 @@ def text2img(
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
@@ -959,6 +911,7 @@ def img2img(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -1007,6 +960,9 @@ def img2img(
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1031,6 +987,7 @@ def img2img(
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
@@ -1051,6 +1008,7 @@ def inpaint(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -1103,6 +1061,9 @@ def inpaint(
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1128,6 +1089,7 @@ def inpaint(
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
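Taken together, `text2img`, `img2img`, and `inpaint` now all forward `is_cancelled_callback` into `__call__`. An illustrative inpaint call (file paths are placeholders); note from `preprocess_mask` that white mask pixels are repainted and black pixels are kept:

```python
from PIL import Image

init_image = Image.open("input.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("L").resize((512, 512))

result = pipe.inpaint(
    prompt="a (red:1.3) brick wall",
    image=init_image,
    mask_image=mask_image,
    strength=0.75,
).images[0]
```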

examples/community/lpw_stable_diffusion_onnx.py

+288 -224
@@ -6,35 +6,13 @@
 import torch

 import PIL
-from diffusers.onnx_utils import OnnxRuntimeModel
-from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
+from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
-
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
-from packaging import version
+from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from transformers import CLIPFeatureExtractor, CLIPTokenizer


-if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.Resampling.BILINEAR,
-        "bilinear": PIL.Image.Resampling.BILINEAR,
-        "bicubic": PIL.Image.Resampling.BICUBIC,
-        "lanczos": PIL.Image.Resampling.LANCZOS,
-        "nearest": PIL.Image.Resampling.NEAREST,
-    }
-else:
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.LINEAR,
-        "bilinear": PIL.Image.BILINEAR,
-        "bicubic": PIL.Image.BICUBIC,
-        "lanczos": PIL.Image.LANCZOS,
-        "nearest": PIL.Image.NEAREST,
-    }
-# ------------------------------------------------------------------------------
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
@@ -262,7 +240,7 @@ def get_weighted_text_embeddings(
     Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

     Args:
-        pipe (`DiffusionPipeline`):
+        pipe (`OnnxStableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
         prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
@@ -392,19 +370,19 @@ def preprocess_image(image):
     return 2.0 * image - 1.0


-def preprocess_mask(mask):
+def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
     mask = 1 - mask  # repaint white, keep black
     return mask


-class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
+class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
     weighting in prompt.
@@ -420,12 +398,12 @@ def __init__(
         text_encoder: OnnxRuntimeModel,
         tokenizer: CLIPTokenizer,
         unet: OnnxRuntimeModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: SchedulerMixin,
         safety_checker: OnnxRuntimeModel,
         feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
     ):
-        super().__init__()
-        self.register_modules(
+        super().__init__(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
             text_encoder=text_encoder,
@@ -434,8 +412,171 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            requires_safety_checker=requires_safety_checker,
+        )
+        self.unet_in_channels = 4
+        self.vae_scale_factor = 8
+
+    def _encode_prompt(
+        self,
+        prompt,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        max_embeddings_multiples,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
         )

+        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
+        if do_classifier_free_guidance:
+            uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
+            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def get_timesteps(self, num_inference_steps, strength, is_text2img):
+        if is_text2img:
+            return self.scheduler.timesteps, num_inference_steps
+        else:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:]
+            return timesteps, num_inference_steps - t_start
+
+    def run_safety_checker(self, image):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # There will throw an error if use safety_checker directly and batchsize>1
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
+        else:
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
+        image = np.clip(image / 2 + 0.5, 0, 1)
+        image = image.transpose((0, 2, 3, 1))
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
+        if image is None:
+            shape = (
+                batch_size,
+                self.unet_in_channels,
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            )
+
+            if latents is None:
+                latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+            else:
+                if latents.shape != shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
+            return latents, None, None
+        else:
+            init_latents = self.vae_encoder(sample=image)[0]
+            init_latents = 0.18215 * init_latents
+            init_latents = np.concatenate([init_latents] * batch_size, axis=0)
+            init_latents_orig = init_latents
+            shape = init_latents.shape
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+            latents = self.scheduler.add_noise(
+                torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
+            ).numpy()
+            return latents, init_latents_orig, noise
+
     @torch.no_grad()
     def __call__(
         self,
@@ -450,7 +591,7 @@ def __call__(
         strength: float = 0.8,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         latents: Optional[np.ndarray] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
@@ -501,8 +642,9 @@ def __call__(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`np.random.RandomState`, *optional*):
-                A np.random.RandomState to make generation deterministic.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
             latents (`np.ndarray`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -537,204 +679,123 @@ def __call__(
         init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
         image = init_image or image

-        if isinstance(prompt, str):
-            batch_size = 1
-            prompt = [prompt]
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor

-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        # get prompt text embeddings
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, strength, callback_steps)

+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        if batch_size != len(negative_prompt):
-            raise ValueError(
-                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                " the batch size of `prompt`."
-            )

-        if generator is None:
-            generator = np.random
-
-        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
-            pipe=self,
-            prompt=prompt,
-            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
-            max_embeddings_multiples=max_embeddings_multiples,
-            **kwargs,
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            max_embeddings_multiples,
         )
-
-        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
-        if do_classifier_free_guidance:
-            uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
-            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        latents_dtype = text_embeddings.dtype
-        init_latents_orig = None
-        mask = None
-        noise = None
-
-        if image is None:
-            latents_shape = (
-                batch_size * num_images_per_prompt,
-                4,
-                height // 8,
-                width // 8,
-            )
-
-            if latents is None:
-                latents = generator.randn(*latents_shape).astype(latents_dtype)
-            elif latents.shape != latents_shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-
-            timesteps = self.scheduler.timesteps.to(self.device)
-
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = latents * self.scheduler.init_noise_sigma
+        dtype = text_embeddings.dtype
+
+        # 4. Preprocess image and mask
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess_image(image)
+        if image is not None:
+            image = image.astype(dtype)
+        if isinstance(mask_image, PIL.Image.Image):
+            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        if mask_image is not None:
+            mask = mask_image.astype(dtype)
+            mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
         else:
-            if isinstance(image, PIL.Image.Image):
-                image = preprocess_image(image)
-            # encode the init image into latents and scale the latents
-            image = image.astype(latents_dtype)
-            init_latents = self.vae_encoder(sample=image)[0]
-            init_latents = 0.18215 * init_latents
-            init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt)
-            init_latents_orig = init_latents
-
-            # preprocess mask
-            if mask_image is not None:
-                if isinstance(mask_image, PIL.Image.Image):
-                    mask_image = preprocess_mask(mask_image)
-                mask_image = mask_image.astype(latents_dtype)
-                mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt)
-
-                # check sizes
-                if not mask.shape == init_latents.shape:
-                    print(mask.shape, init_latents.shape)
-                    raise ValueError("The mask and image should be the same size!")
-
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
+            mask = None

-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt)
-
-            # add noise to latents using the timesteps
-            noise = generator.randn(*init_latents.shape).astype(latents_dtype)
-            latents = self.scheduler.add_noise(
-                torch.from_numpy(init_latents), torch.from_numpy(noise), timesteps
-            ).numpy()
-
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:]
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(
-                sample=latent_model_input,
-                timestep=np.array([t]),
-                encoder_hidden_states=text_embeddings,
-            )
-            noise_pred = noise_pred[0]
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.numpy()
-
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(
-                    torch.from_numpy(init_latents_orig),
-                    torch.from_numpy(noise),
-                    torch.tensor([t]),
-                ).numpy()
-                latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
+        # 5. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents, init_latents_orig, noise = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            generator,
+            latents,
+        )

-        latents = 1 / 0.18215 * latents
-        # image = self.vae_decoder(latent_sample=latents)[0]
-        # it seems likes there is a problem for using half-precision vae decoder if batchsize>1
-        image = []
-        for i in range(latents.shape[0]):
-            image.append(self.vae_decoder(latent_sample=latents[i : i + 1])[0])
-        image = np.concatenate(image)
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+                latent_model_input = latent_model_input.numpy()
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    sample=latent_model_input,
+                    timestep=np.array([t], dtype=timestep_dtype),
+                    encoder_hidden_states=text_embeddings,
+                )
+                noise_pred = noise_pred[0]

-        image = np.clip(image / 2 + 0.5, 0, 1)
-        image = image.transpose((0, 2, 3, 1))
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(
-                self.numpy_to_pil(image), return_tensors="np"
-            ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker directly and batchsize>1
-            images, has_nsfw_concept = [], []
-            for i in range(image.shape[0]):
-                image_i, has_nsfw_concept_i = self.safety_checker(
-                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                # compute the previous noisy sample x_t -> x_t-1
+                scheduler_output = self.scheduler.step(
+                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
                 )
-                images.append(image_i)
-                has_nsfw_concept.append(has_nsfw_concept_i[0])
-            image = np.concatenate(images)
-        else:
-            has_nsfw_concept = None
-
+                latents = scheduler_output.prev_sample.numpy()
+
+                if mask is not None:
+                    # masking
+                    init_latents_proper = self.scheduler.add_noise(
+                        torch.from_numpy(init_latents_orig),
+                        torch.from_numpy(noise),
+                        t,
+                    ).numpy()
+                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                if i % callback_steps == 0:
+                    if callback is not None:
+                        callback(i, t, latents)
+                    if is_cancelled_callback is not None and is_cancelled_callback():
+                        return None
+        # 9. Post-processing
+        image = self.decode_latents(latents)
+
+        # 10. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image)
+
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return image, has_nsfw_concept

         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
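Because ONNX Runtime consumes NumPy arrays while diffusers schedulers operate on torch tensors, the rewritten loop above round-trips through torch at each step. A minimal self-contained sketch of that bridging pattern (the scheduler choice and shapes are illustrative):

```python
import numpy as np
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)

noise_pred = np.zeros((1, 4, 64, 64), dtype=np.float32)     # stand-in UNet output
latents = np.random.randn(1, 4, 64, 64).astype(np.float32)  # stand-in latents
t = scheduler.timesteps[0]

# numpy -> torch for the scheduler step, then back to numpy for the next ONNX call.
out = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents))
latents = out.prev_sample.numpy()
```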

@@ -748,7 +809,7 @@ def text2img(
         guidance_scale: float = 7.5,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         latents: Optional[np.ndarray] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
@@ -783,8 +844,9 @@ def text2img(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`np.random.RandomState`, *optional*):
-                A np.random.RandomState to make generation deterministic.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
             latents (`np.ndarray`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -839,7 +901,7 @@ def img2img(
         guidance_scale: Optional[float] = 7.5,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -878,8 +940,9 @@ def img2img(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`np.random.RandomState`, *optional*):
-                A np.random.RandomState to make generation deterministic.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
             max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                 The max multiple length of prompt embeddings compared to the max output length of text encoder.
             output_type (`str`, *optional*, defaults to `"pil"`):
@@ -930,7 +993,7 @@ def inpaint(
         guidance_scale: Optional[float] = 7.5,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -973,8 +1036,9 @@ def inpaint(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`np.random.RandomState`, *optional*):
-                A np.random.RandomState to make generation deterministic.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
             max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                 The max multiple length of prompt embeddings compared to the max output length of text encoder.
             output_type (`str`, *optional*, defaults to `"pil"`):
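One user-visible consequence of this commit: the ONNX variant now seeds with `torch.Generator` instead of `np.random.RandomState`. A hedged migration sketch (the model id is a placeholder and must point at an ONNX export of the weights):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",            # placeholder; needs ONNX weights
    revision="onnx",
    custom_pipeline="lpw_stable_diffusion_onnx",
    provider="CPUExecutionProvider",
)

# Before this commit: generator = np.random.RandomState(0)
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe.text2img("a starry night sky", generator=generator).images[0]
```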
