
Commit 3d82b51

flux 2x24gb vram parallel strategy
1 parent 92b282f commit 3d82b51

File tree

3 files changed (+149, -0 lines)


app/llms/workers/children/flux1.py

+93
@@ -0,0 +1,93 @@
+import base64
+import io
+import torch
+from diffusers import FluxPipeline
+import gc
+from diffusers import FluxTransformer2DModel
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+
+from app.config import RESTAI_DEFAULT_DEVICE
+
+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_allocated()
+    torch.cuda.reset_peak_memory_stats()
+
+def worker(prompt, sharedmem):
+
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        transformer=None,
+        vae=None,
+        device_map="balanced",
+        max_memory={0: "24GB", 1: "24GB"},
+        torch_dtype=torch.bfloat16
+    )
+    with torch.no_grad():
+        print("Encoding prompts.")
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            prompt=prompt, prompt_2=None, max_sequence_length=512
+        )
+
+    del pipeline.text_encoder
+    del pipeline.text_encoder_2
+    del pipeline.tokenizer
+    del pipeline.tokenizer_2
+    del pipeline
+
+    flush()
+
+    transformer = FluxTransformer2DModel.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        subfolder="transformer",
+        device_map="auto",
+        torch_dtype=torch.bfloat16
+    )
+
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        text_encoder=None,
+        text_encoder_2=None,
+        tokenizer=None,
+        tokenizer_2=None,
+        vae=None,
+        transformer=transformer,
+        torch_dtype=torch.bfloat16
+    )
+
+    print("Running denoising.")
+    height, width = 768, 1360
+    latents = pipeline(
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        num_inference_steps=50,
+        guidance_scale=3.5,
+        height=height,
+        width=width,
+        output_type="latent",
+    ).images
+
+    del pipeline.transformer
+    del pipeline
+
+    flush()
+
+    vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+    with torch.no_grad():
+        print("Running decoding.")
+        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
+        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+
+        image = vae.decode(latents, return_dict=False)[0]
+        image = image_processor.postprocess(image, output_type="pil")
+
+        image_data = io.BytesIO()
+        image[0].save(image_data, format="JPEG")
+        image_base64 = base64.b64encode(image_data.getvalue()).decode('utf-8')
+
+    sharedmem["image"] = image_base64
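
The worker above runs in three stages (encode prompts, denoise, decode), flushing GPU memory between them; the first stage relies on accelerate's balanced device map to shard the two text encoders across both 24 GB cards. A minimal sketch of how that placement can be inspected, assuming two visible CUDA devices and local access to the gated black-forest-labs/FLUX.1-dev weights; `hf_device_map` is populated by accelerate whenever a pipeline is loaded with a `device_map`:

    import torch
    from diffusers import FluxPipeline

    # Load only the text encoders, sharded across both GPUs,
    # mirroring the first stage of flux1.py above.
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=None,
        vae=None,
        device_map="balanced",
        max_memory={0: "24GB", 1: "24GB"},
        torch_dtype=torch.bfloat16,
    )

    # Shows which GPU accelerate assigned each component to,
    # e.g. {'text_encoder': 0, 'text_encoder_2': 1}.
    print(pipeline.hf_device_map)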

app/llms/workers/flux.py

+54
@@ -0,0 +1,54 @@
+from torch.multiprocessing import Process, set_start_method, Manager
+
+from app.llms.workers.children.flux1 import worker
+
+try:
+    set_start_method('spawn')
+except RuntimeError:
+    pass
+from langchain.tools import BaseTool
+from langchain.chains import LLMChain
+from langchain_community.chat_models import ChatOpenAI
+from langchain.prompts import PromptTemplate
+
+from typing import Optional
+from langchain.callbacks.manager import (
+    CallbackManagerForToolRun,
+)
+from ilock import ILock, ILockException
+
+
+class FluxImage(BaseTool):
+    name = "Flux Image Generator"
+    description = "use this tool when you need to generate an image using Flux."
+    return_direct = True
+
+    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
+        if run_manager.tags[0].boost == True:
+            llm = ChatOpenAI(temperature=0.9, model_name="gpt-3.5-turbo")
+            prompt = PromptTemplate(
+                input_variables=["image_desc"],
+                template="Generate a detailed prompt to generate an image based on the following description: {image_desc}",
+            )
+            chain = LLMChain(llm=llm, prompt=prompt)
+
+            fprompt = chain.run(query)
+        else:
+            fprompt = run_manager.tags[0].question
+
+        manager = Manager()
+        sharedmem = manager.dict()
+
+        with ILock('flux', timeout=180):
+            p = Process(target=worker, args=(fprompt, sharedmem))
+            p.start()
+            p.join()
+            p.kill()
+
+        if "image" not in sharedmem or not sharedmem["image"]:
+            raise Exception("An error occurred while processing the image. Please try again.")
+
+        return {"type": "flux", "image": sharedmem["image"], "prompt": fprompt}
+
+    async def _arun(self, query: str) -> str:
+        raise NotImplementedError("N/A")
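
The `ILock('flux', timeout=180)` block serializes Flux jobs machine-wide, and spawning the worker as a child process guarantees its VRAM is returned to the system when the process exits. The dict returned by `_run` carries the image as a base64-encoded JPEG; a minimal consumption sketch, where `save_flux_image` is a hypothetical helper and not part of this commit:

    import base64

    # Hypothetical helper (illustration only): persist the payload returned
    # by FluxImage._run, i.e. {"type": "flux", "image": <base64 JPEG>, "prompt": ...}.
    def save_flux_image(result: dict, path: str = "flux.jpg") -> None:
        with open(path, "wb") as f:
            f.write(base64.b64decode(result["image"]))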

app/projects/vision.py

+2
@@ -52,7 +52,9 @@ def question(self, project: Project, questionModel: QuestionModel, user: User, d
         from app.llms.workers.stablediffusion import StableDiffusionImage
         from app.llms.workers.describeimage import DescribeImage
         from app.llms.workers.instantid import InstantID
+        from app.llms.workers.flux import FluxImage
         tools.append(StableDiffusionImage())
+        tools.append(FluxImage())
         tools.append(DescribeImage())
         tools.append(InstantID())
