Skip to content

Commit fa0fa95

Browse files
committed
update flux pipeline
1 parent 41ea2f8 commit fa0fa95

File tree

6 files changed

+32
-22
lines changed

6 files changed

+32
-22
lines changed

diffsynth/models/flux_dit.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -364,6 +364,7 @@ def forward(
364364

365365
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
366366
if self.guidance_embedder is not None:
367+
guidance = guidance * 1000
367368
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
368369
prompt_emb = self.context_embedder(prompt_emb)
369370
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

diffsynth/pipelines/base.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,9 +65,11 @@ def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
6565
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
6666
return prompt, local_prompts, masks, mask_scales
6767

68+
6869
def enable_cpu_offload(self):
6970
self.cpu_offload = True
7071

72+
7173
def load_models_to_device(self, loadmodel_names=[]):
7274
# only load models to device if cpu_offload is enabled
7375
if not self.cpu_offload:
@@ -85,3 +87,9 @@ def load_models_to_device(self, loadmodel_names=[]):
8587
model.to(self.device)
8688
# fresh the cuda cache
8789
torch.cuda.empty_cache()
90+
91+
92+
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
93+
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
94+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
95+
return noise

diffsynth/pipelines/flux_image.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
5858
return image
5959

6060

61-
def encode_prompt(self, prompt, positive=True, t5_sequence_length=256):
61+
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
6262
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
6363
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
6464
)
@@ -80,7 +80,7 @@ def __call__(
8080
mask_scales= None,
8181
negative_prompt="",
8282
cfg_scale=1.0,
83-
embedded_guidance=1.0,
83+
embedded_guidance=3.5,
8484
input_image=None,
8585
denoising_strength=1.0,
8686
height=1024,
@@ -90,6 +90,7 @@ def __call__(
9090
tiled=False,
9191
tile_size=128,
9292
tile_stride=64,
93+
seed=None,
9394
progress_bar_cmd=tqdm,
9495
progress_bar_st=None,
9596
):
@@ -104,10 +105,10 @@ def __call__(
104105
self.load_models_to_device(['vae_encoder'])
105106
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
106107
latents = self.encode_image(image, **tiler_kwargs)
107-
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
108+
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
108109
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
109110
else:
110-
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
111+
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
111112

112113
# Extend prompt
113114
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])

diffsynth/prompters/flux_prompter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def encode_prompt(
5757
prompt,
5858
positive=True,
5959
device="cuda",
60-
t5_sequence_length=256,
60+
t5_sequence_length=512,
6161
):
6262
prompt = self.process_prompt(prompt, positive=positive)
6363

examples/image_synthesis/flux_text_to_image.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -12,30 +12,30 @@
1212
])
1313
pipe = FluxImagePipeline.from_model_manager(model_manager)
1414

15-
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
16-
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
15+
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
16+
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
1717

1818
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
19-
torch.manual_seed(6)
19+
torch.manual_seed(9)
2020
image = pipe(
2121
prompt=prompt,
22-
num_inference_steps=30, embedded_guidance=3.5
22+
num_inference_steps=50, embedded_guidance=3.5
2323
)
2424
image.save("image_1024.jpg")
2525

2626
# Enable classifier-free guidance
27-
torch.manual_seed(6)
27+
torch.manual_seed(9)
2828
image = pipe(
2929
prompt=prompt, negative_prompt=negative_prompt,
30-
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
30+
num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
3131
)
3232
image.save("image_1024_cfg.jpg")
3333

3434
# Highres-fix
35-
torch.manual_seed(7)
35+
torch.manual_seed(10)
3636
image = pipe(
3737
prompt=prompt,
38-
num_inference_steps=30, embedded_guidance=3.5,
38+
num_inference_steps=50, embedded_guidance=3.5,
3939
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
4040
)
4141
image.save("image_2048_highres.jpg")

examples/image_synthesis/flux_text_to_image_low_vram.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -22,30 +22,30 @@
2222
pipe.enable_cpu_offload()
2323
pipe.dit.quantize()
2424

25-
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
26-
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
25+
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
26+
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
2727

2828
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
29-
torch.manual_seed(6)
29+
torch.manual_seed(9)
3030
image = pipe(
3131
prompt=prompt,
32-
num_inference_steps=30, embedded_guidance=3.5
32+
num_inference_steps=50, embedded_guidance=3.5
3333
)
3434
image.save("image_1024.jpg")
3535

3636
# Enable classifier-free guidance
37-
torch.manual_seed(6)
37+
torch.manual_seed(9)
3838
image = pipe(
3939
prompt=prompt, negative_prompt=negative_prompt,
40-
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
40+
num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
4141
)
4242
image.save("image_1024_cfg.jpg")
4343

4444
# Highres-fix
45-
torch.manual_seed(7)
45+
torch.manual_seed(10)
4646
image = pipe(
4747
prompt=prompt,
48-
num_inference_steps=30, embedded_guidance=3.5,
48+
num_inference_steps=50, embedded_guidance=3.5,
4949
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
5050
)
51-
image.save("image_2048_highres.jpg")
51+
image.save("image_2048_highres.jpg")

0 commit comments

Comments
 (0)