 import torch

 import PIL

-from diffusers.configuration_utils import FrozenDict
+from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, is_accelerate_available, logging
-
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
-from packaging import version
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


-if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.Resampling.BILINEAR,
-        "bilinear": PIL.Image.Resampling.BILINEAR,
-        "bicubic": PIL.Image.Resampling.BICUBIC,
-        "lanczos": PIL.Image.Resampling.LANCZOS,
-        "nearest": PIL.Image.Resampling.NEAREST,
-    }
-else:
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.LINEAR,
-        "bilinear": PIL.Image.BILINEAR,
-        "bicubic": PIL.Image.BICUBIC,
-        "lanczos": PIL.Image.LANCZOS,
-        "nearest": PIL.Image.NEAREST,
-    }
-# ------------------------------------------------------------------------------
-
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 re_attention = re.compile(
@@ -146,7 +121,7 @@ def multiply_range(start_position, multiplier):
     return res


-def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
     r"""
     Tokenize a list of prompts and return its tokens with weights of each token.

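# Aside, not part of the diff: a sketch of the weighting syntax that `re_attention` parses.
# Parentheses scale a span's attention weight up by 1.1 per nesting level, square brackets
# scale it down by 1.1 per level, and "(text:factor)" sets an explicit multiplier. The
# prompt below is illustrative only.
prompt = "a photo of an ((astronaut)) riding a [horse] on (the moon:1.4)"
# get_prompts_with_weights tokenizes this and attaches a per-token weight, e.g. about 1.21
# for "astronaut", 1 / 1.1 for "horse", and 1.4 for "the moon".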
@@ -207,7 +182,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd


 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline,
+    pipe: StableDiffusionPipeline,
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
@@ -247,10 +222,10 @@ def get_unweighted_text_embeddings(


 def get_weighted_text_embeddings(
-    pipe: DiffusionPipeline,
+    pipe: StableDiffusionPipeline,
     prompt: Union[str, List[str]],
     uncond_prompt: Optional[Union[str, List[str]]] = None,
-    max_embeddings_multiples: Optional[int] = 1,
+    max_embeddings_multiples: Optional[int] = 3,
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
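# Aside, not part of the diff: raising the default max_embeddings_multiples from 1 to 3 is
# what lifts CLIP's 77-token window. A sketch of the resulting token budget, assuming the
# standard CLIP tokenizer (model_max_length == 77, of which 2 slots are BOS/EOS):
model_max_length = 77
max_embeddings_multiples = 3
max_length = (model_max_length - 2) * max_embeddings_multiples + 2  # 227 tokens in total
# Longer prompts are split into 75-token chunks, encoded separately, and concatenated.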
@@ -264,14 +239,14 @@ def get_weighted_text_embeddings(
     Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.

     Args:
-        pipe (`DiffusionPipeline`):
+        pipe (`StableDiffusionPipeline`):
             Pipe to provide access to the tokenizer and the text encoder.
         prompt (`str` or `List[str]`):
             The prompt or prompts to guide the image generation.
         uncond_prompt (`str` or `List[str]`):
             The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
             is provided, the embeddings of prompt and uncond_prompt are concatenated.
-        max_embeddings_multiples (`int`, *optional*, defaults to `1`):
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
             The max multiple length of the prompt embeddings compared to the max output length of the text encoder.
         no_boseos_middle (`bool`, *optional*, defaults to `False`):
             If the length of the text tokens is a multiple of the text encoder capacity, whether to reserve the starting and
@@ -387,11 +362,11 @@ def preprocess_image(image):
     return 2.0 * image - 1.0


-def preprocess_mask(mask):
+def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; transpose(0, 1, 2, 3) is the identity
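# Aside, not part of the diff: a shape walk-through of preprocess_mask for a 512x512 mask
# with the default scale_factor=8 (numbers illustrative):
#   convert("L")              -> (512, 512) grayscale image
#   resize(w // 8, h // 8)    -> (64, 64), nearest-neighbour so mask edges stay hard
#   np.tile(mask, (4, 1, 1))  -> (4, 64, 64), one copy per latent channel
#   mask[None]                -> (1, 4, 64, 64), matching the latent tensor layout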
@@ -400,7 +375,7 @@ def preprocess_mask(mask):
     return mask


-class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
+class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
     weighting in the prompt.
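# Aside, not part of the diff: this file is the "lpw_stable_diffusion" community pipeline
# and is usually loaded via custom_pipeline rather than instantiated directly. A minimal
# usage sketch (model id and prompt are illustrative):
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")
image = pipe(prompt="a (masterpiece:1.2), best quality, extremely detailed cityscape").images[0]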
@@ -435,102 +410,184 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
     ):
-        super().__init__()
-
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
-                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
-                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
-                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
-                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                " file"
-            )
-            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(scheduler.config)
-            new_config["steps_offset"] = 1
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
-                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
-                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
-                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
-                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
-            )
-            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(scheduler.config)
-            new_config["clip_sample"] = False
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if safety_checker is None:
-            logger.warning(
-                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
-            )
-
-        self.register_modules(
+        super().__init__(
             vae=vae,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            requires_safety_checker=requires_safety_checker,
         )
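# Aside, not part of the diff: deriving from StableDiffusionPipeline moves module
# registration, the scheduler config deprecations, and the safety-checker warning into the
# parent __init__. A sketch of the new opt-out, assuming the component objects are already
# loaded (names illustrative):
pipe = StableDiffusionLongPromptWeightingPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
    requires_safety_checker=False,  # silences the parent's safety-checker warning
)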

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        max_embeddings_multiples,
+    ):
         r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+        Encodes the prompt into text encoder hidden states.

         Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
         """
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
-        self.unet.set_attention_slice(slice_size)
+        batch_size = len(prompt) if isinstance(prompt, list) else 1

-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )

-    def enable_sequential_cpu_offload(self):
-        r"""
-        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
-        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
-        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
-        """
-        if is_accelerate_available():
-            from accelerate import cpu_offload
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
+        )
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            bs_embed, seq_len, _ = uncond_embeddings.shape
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
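# Aside, not part of the diff: a sketch of what _encode_prompt returns with guidance on,
# assuming a loaded pipeline `pipe` (argument values illustrative):
embeds = pipe._encode_prompt(
    prompt="a castle on a hill",
    device=pipe._execution_device,
    num_images_per_prompt=2,
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
    max_embeddings_multiples=3,
)
# embeds stacks the negative-prompt rows first, then the prompt rows:
# shape == (2 * batch_size * num_images_per_prompt, seq_len, hidden_dim), here (4, ...),
# so one UNet forward pass can serve both halves of the guidance equation.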
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+        if is_text2img:
+            return self.scheduler.timesteps.to(device), num_inference_steps
+        else:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:].to(device)
+            return timesteps, num_inference_steps - t_start
+
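# Aside, not part of the diff: worked numbers for the img2img branch of get_timesteps,
# assuming num_inference_steps=50, strength=0.6 and a scheduler with steps_offset=1:
#   init_timestep = int(50 * 0.6) + 1 = 31   (capped at num_inference_steps)
#   t_start       = max(50 - 31 + 1, 0) = 20
# so the method returns (scheduler.timesteps[20:], 30): only the last 30 of the 50
# scheduled steps run, which is what makes low strength values preserve the init image.
# The text2img branch simply returns the full schedule.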
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
         else:
-            raise ImportError("Please install accelerate via `pip install accelerate`")
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]

-        device = self.device
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
+        if image is None:
+            shape = (
+                batch_size,
+                self.unet.in_channels,
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            )
+
+            if latents is None:
+                if device.type == "mps":
+                    # randn does not work reproducibly on mps
+                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+                else:
+                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            else:
+                if latents.shape != shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+                latents = latents.to(device)
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = latents * self.scheduler.init_noise_sigma
+            return latents, None, None
+        else:
+            init_latent_dist = self.vae.encode(image).latent_dist
+            init_latents = init_latent_dist.sample(generator=generator)
+            init_latents = 0.18215 * init_latents
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
+            init_latents_orig = init_latents
+            shape = init_latents.shape

-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            # add noise to latents using the timesteps
+            if device.type == "mps":
+                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timestep)
+            return latents, init_latents_orig, noise
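# Aside, not part of the diff: the two shapes prepare_latents can return, sketched for a
# 512x512 generation with batch size 1 (vae_scale_factor == 8):
#   text2img (image is None): pure noise of shape (1, 4, 64, 64), scaled by
#       scheduler.init_noise_sigma, returned as (latents, None, None).
#   img2img (image given): the VAE posterior sample scaled by 0.18215 and noised to
#       `timestep`, returned as (latents, init_latents_orig, noise) so the inpainting
#       branch can re-blend the unmasked region after every scheduler step.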

     @torch.no_grad()
     def __call__(
@@ -634,221 +691,111 @@ def __call__(
         init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
         image = init_image or image

-        if isinstance(prompt, str):
-            batch_size = 1
-            prompt = [prompt]
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor

-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        # get prompt text embeddings
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, strength, callback_steps)

+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        if batch_size != len(negative_prompt):
-            raise ValueError(
-                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                " the batch size of `prompt`."
-            )

-        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
-            pipe=self,
-            prompt=prompt,
-            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
-            max_embeddings_multiples=max_embeddings_multiples,
-            **kwargs,
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            max_embeddings_multiples,
         )
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance:
-            bs_embed, seq_len, _ = uncond_embeddings.shape
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        latents_dtype = text_embeddings.dtype
-        init_latents_orig = None
-        mask = None
-        noise = None
-
-        if image is None:
-            # get the initial random noise unless the user supplied it
-
-            # Unlike in other pipelines, latents need to be generated in the target device
-            # for 1-to-1 results reproducibility with the CompVis implementation.
-            # However this currently doesn't work in `mps`.
-            latents_shape = (
-                batch_size * num_images_per_prompt,
-                self.unet.in_channels,
-                height // 8,
-                width // 8,
-            )
-
-            if latents is None:
-                if self.device.type == "mps":
-                    # randn does not exist on mps
-                    latents = torch.randn(
-                        latents_shape,
-                        generator=generator,
-                        device="cpu",
-                        dtype=latents_dtype,
-                    ).to(self.device)
-                else:
-                    latents = torch.randn(
-                        latents_shape,
-                        generator=generator,
-                        device=self.device,
-                        dtype=latents_dtype,
-                    )
-            else:
-                if latents.shape != latents_shape:
-                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-                latents = latents.to(self.device)
-
-            timesteps = self.scheduler.timesteps.to(self.device)
-
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = latents * self.scheduler.init_noise_sigma
+        dtype = text_embeddings.dtype
+
+        # 4. Preprocess image and mask
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess_image(image)
+        if image is not None:
+            image = image.to(device=self.device, dtype=dtype)
+        if isinstance(mask_image, PIL.Image.Image):
+            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        if mask_image is not None:
+            mask = mask_image.to(device=self.device, dtype=dtype)
+            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
         else:
-            if isinstance(image, PIL.Image.Image):
-                image = preprocess_image(image)
-            # encode the init image into latents and scale the latents
-            image = image.to(device=self.device, dtype=latents_dtype)
-            init_latent_dist = self.vae.encode(image).latent_dist
-            init_latents = init_latent_dist.sample(generator=generator)
-            init_latents = 0.18215 * init_latents
-            init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-            init_latents_orig = init_latents
-
-            # preprocess mask
-            if mask_image is not None:
-                if isinstance(mask_image, PIL.Image.Image):
-                    mask_image = preprocess_mask(mask_image)
-                mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
-                mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-                # check sizes
-                if not mask.shape == init_latents.shape:
-                    raise ValueError("The mask and image should be the same size!")
-
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-            # add noise to latents using the timesteps
-            if self.device.type == "mps":
-                # randn does not exist on mps
-                noise = torch.randn(
-                    init_latents.shape,
-                    generator=generator,
-                    device="cpu",
-                    dtype=latents_dtype,
-                ).to(self.device)
-            else:
-                noise = torch.randn(
-                    init_latents.shape,
-                    generator=generator,
-                    device=self.device,
-                    dtype=latents_dtype,
-                )
-            latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-                latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
-
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
-            image, has_nsfw_concept = self.safety_checker(
-                images=image,
-                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
-            )
-        else:
-            has_nsfw_concept = None
+            mask = None
+
+        # 5. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents, init_latents_orig, noise = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            device,
+            generator,
+            latents,
+        )

+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                if mask is not None:
+                    # masking
+                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                if i % callback_steps == 0:
+                    if callback is not None:
+                        callback(i, t, latents)
+                    if is_cancelled_callback is not None and is_cancelled_callback():
+                        return None
+
+        # 9. Post-processing
+        image = self.decode_latents(latents)
+
+        # 10. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+        # 11. Convert to PIL

         if output_type == "pil":
             image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return image, has_nsfw_concept

         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

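# Aside, not part of the diff: the restructured __call__ dispatches on its arguments: no
# image runs text2img, an image runs img2img, and image plus mask_image runs inpainting.
# A sketch of the cancellation hook, assuming a loaded `pipe` and using an illustrative
# wall-clock deadline:
import time

deadline = time.monotonic() + 30  # give up after 30 seconds
result = pipe(
    prompt="a highly detailed (oil painting:1.3) of a ship at sea",
    num_inference_steps=50,
    is_cancelled_callback=lambda: time.monotonic() > deadline,
)
if result is None:
    print("generation was cancelled")  # __call__ returns None when cancelled
else:
    result.images[0].save("ship.png")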
@@ -868,6 +815,7 @@ def text2img(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -915,6 +863,9 @@ def text2img(
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
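# Aside, not part of the diff: callback and is_cancelled_callback are polled together every
# callback_steps steps. A sketch combining both, assuming a loaded `pipe` (names illustrative):
cancel_requested = False

def on_step(step: int, timestep: int, latents) -> None:
    print(f"step {step}, timestep {timestep}, latents shape {tuple(latents.shape)}")

output = pipe.text2img(
    prompt="a (sunlit:1.2) forest clearing",
    callback=on_step,
    is_cancelled_callback=lambda: cancel_requested,
    callback_steps=5,  # poll every 5 steps instead of every step
)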
@@ -940,6 +891,7 @@ def text2img(
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
@@ -959,6 +911,7 @@ def img2img(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -1007,6 +960,9 @@ def img2img(
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
@@ -1031,6 +987,7 @@ def img2img(
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
@@ -1051,6 +1008,7 @@ def inpaint(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -1103,6 +1061,9 @@ def inpaint(
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
@@ -1128,6 +1089,7 @@ def inpaint(
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
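# Aside, not part of the diff: all three thin wrappers now just forward
# is_cancelled_callback to __call__. A sketch of the inpaint entry point (file names
# illustrative; in this pipeline family, white mask pixels conventionally mark the region
# to repaint):
import PIL.Image

init_image = PIL.Image.open("photo.png").convert("RGB")
mask_image = PIL.Image.open("mask.png").convert("L")
result = pipe.inpaint(
    prompt="a (red:1.3) brick wall",
    image=init_image,
    mask_image=mask_image,
    strength=0.75,  # how much of the masked area to re-noise
)
result.images[0].save("inpainted.png")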