Skip to content

Commit 92b282f

Browse files
committed
stablediffusion 3.5, deps update
1 parent b97aebb commit 92b282f

File tree

6 files changed

+1409
-1238
lines changed

6 files changed

+1409
-1238
lines changed

app/llms/ollamamultimodal.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,7 @@ async def astream_chat(
222222

223223

224224

225-
class OllamaMultiModal2(OllamaMultiModal):
225+
class OllamaMultiModalInternal(OllamaMultiModal):
226226
system: str = Field(
227227
default="", description="Default system message to send to the model."
228228
)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,48 @@
1+
import base64
2+
import io
3+
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
4+
from diffusers import StableDiffusion3Pipeline
5+
import torch
6+
from transformers import T5EncoderModel
7+
8+
from app.config import RESTAI_DEFAULT_DEVICE
9+
10+
11+
def worker(prompt, sharedmem):
    """Render one image for *prompt* and publish it through *sharedmem*.

    Loads the ``stabilityai/stable-diffusion-3.5-large-turbo`` pipeline with
    its transformer quantized to 4-bit NF4 and with the T5 text encoder taken
    from ``diffusers/t5-nf4``, enables model CPU offload, runs a 4-step
    generation with guidance disabled, and stores the first image as a
    base64-encoded JPEG string.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        sharedmem: Mapping-like object; the encoded image is written to its
            ``"image"`` key. (Presumably a multiprocessing-shared dict --
            confirm against the caller.)

    Returns:
        None. The result is delivered only via ``sharedmem["image"]``.
    """
    model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
    dtype = torch.bfloat16

    # 4-bit NF4 quantization settings for the diffusion transformer.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=dtype,
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=quant_cfg,
        torch_dtype=dtype,
    )

    # Third text encoder (T5) comes from a separate repository.
    text_encoder = T5EncoderModel.from_pretrained(
        "diffusers/t5-nf4", torch_dtype=dtype
    )

    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_id,
        transformer=transformer,
        text_encoder_3=text_encoder,
        torch_dtype=dtype,
    )
    pipe.enable_model_cpu_offload()

    result = pipe(
        prompt=prompt,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=512,
    )
    picture = result.images[0]

    # Serialize to JPEG in memory, then base64 so it can travel as text.
    buffer = io.BytesIO()
    picture.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue())

    sharedmem["image"] = encoded.decode('utf-8')

app/llms/workers/stablediffusion.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
#from app.llms.workers.children.stablediffusion import worker
44
#from app.llms.workers.children.sdxl_lightning import worker
5-
from app.llms.workers.children.stablediffusion3 import worker
5+
from app.llms.workers.children.stablediffusion35 import worker
66

77
try:
88
set_start_method('spawn')

app/tools.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,12 @@ def get_llm_class(llm_class_name):
2929
if llm_class_name == "Ollama":
3030
from app.llms.ollama import Ollama
3131
return Ollama, {}
32-
elif llm_class_name == "OllamaMultiModal2":
33-
from app.llms.ollamamultimodal import OllamaMultiModal2
34-
return OllamaMultiModal2, {}
32+
elif llm_class_name == "OllamaMultiModal":
33+
from llama_index.multi_modal_llms.ollama import OllamaMultiModal
34+
return OllamaMultiModal, {}
35+
elif llm_class_name == "OllamaMultiModalInternal" or llm_class_name == "OllamaMultiModal2":
36+
from app.llms.ollamamultimodal import OllamaMultiModalInternal
37+
return OllamaMultiModalInternal, {}
3538
elif llm_class_name == "OpenAI":
3639
from llama_index.llms.openai import OpenAI
3740
return OpenAI, {}

0 commit comments

Comments
 (0)