Skip to content

Commit ad1b2ef

Browse files
committed
feat(cloud_cache): normalize model_id and include precision
1 parent c2cb368 commit ad1b2ef

5 files changed

Lines changed: 119 additions & 28 deletions

File tree

app.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from getPipeline import getPipelineForModel, listAvailablePipelines, clearPipelines
1919
import re
2020
import requests
21-
from download import download_model
21+
from download import download_model, normalize_model_id
2222
import traceback
2323

2424
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
@@ -130,11 +130,14 @@ def inference(all_inputs: dict) -> dict:
130130
if not model_id:
131131
model_id = MODEL_ID
132132
result["$meta"].update({"MODEL_ID": MODEL_ID})
133+
normalized_model_id = model_id
133134

134135
if RUNTIME_DOWNLOADS:
135136
global downloaded_models
136-
if last_model_id != model_id:
137-
if not downloaded_models.get(model_id, None):
137+
model_precision = call_inputs.get("MODEL_PRECISION", None)
138+
normalized_model_id = normalize_model_id(model_id, model_precision)
139+
if last_model_id != normalized_model_id:
140+
if not downloaded_models.get(normalized_model_id, None):
138141
model_url = call_inputs.get("MODEL_URL", None)
139142
if not model_url:
140143
return {
@@ -143,18 +146,22 @@ def inference(all_inputs: dict) -> dict:
143146
"message": "Currently RUNTIME_DOWNOADS requires a MODEL_URL callInput",
144147
}
145148
}
146-
download_model(model_id=model_id, model_url=model_url)
147-
downloaded_models.update({model_id: True})
148-
model = loadModel(model_id)
149+
download_model(
150+
model_id=model_id,
151+
model_url=model_url,
152+
model_revision=model_precision,
153+
)
154+
downloaded_models.update({normalized_model_id: True})
155+
model = loadModel(normalized_model_id)
149156
if PIPELINE == "ALL":
150157
clearPipelines()
151-
last_model_id = model_id
158+
last_model_id = normalized_model_id
152159

153160
if MODEL_ID == "ALL":
154-
if last_model_id != model_id:
155-
model = loadModel(model_id)
161+
if last_model_id != normalized_model_id:
162+
model = loadModel(normalized_model_id)
156163
clearPipelines()
157-
last_model_id = model_id
164+
last_model_id = normalized_model_id
158165
else:
159166
if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
160167
return {
@@ -172,7 +179,7 @@ def inference(all_inputs: dict) -> dict:
172179
pipeline_name = "StableDiffusionPipeline"
173180
result["$meta"].update({"PIPELINE": pipeline_name})
174181

175-
pipeline = getPipelineForModel(pipeline_name, model, model_id)
182+
pipeline = getPipelineForModel(pipeline_name, model, normalized_model_id)
176183
if not pipeline:
177184
return {
178185
"$error": {
@@ -190,7 +197,7 @@ def inference(all_inputs: dict) -> dict:
190197
scheduler_name = "DPMSolverMultistepScheduler"
191198
result["$meta"].update({"SCHEDULER": scheduler_name})
192199

193-
pipeline.scheduler = getScheduler(model_id, scheduler_name)
200+
pipeline.scheduler = getScheduler(normalized_model_id, scheduler_name)
194201
if pipeline.scheduler == None:
195202
return {
196203
"$error": {
@@ -289,7 +296,9 @@ def inference(all_inputs: dict) -> dict:
289296
}
290297
}
291298
torch.set_grad_enabled(True)
292-
result = result | TrainDreamBooth(model_id, pipeline, model_inputs, call_inputs)
299+
result = result | TrainDreamBooth(
300+
normalized_model_id, pipeline, model_inputs, call_inputs
301+
)
293302
torch.set_grad_enabled(False)
294303
send("inference", "done", {"startRequestId": startRequestId})
295304
result.update({"$timings": getTimings()})

docs/internal_safetensor_cache_flow.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,32 @@ e.g. stabilityai/stable-diffusion-2-1-base
1717
1. Run inference with HF model.
1818

1919
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/diffusers/models--stabilityai--stable-diffusion-2-1-base/refs/main'
20+
21+
22+
NVIDIA RTX Quadro 5000
23+
24+
NO SAFETENSORS
25+
Downloaded in 462557 ms
26+
Loading model: stabilityai/stable-diffusion-2-1 (fp32)
27+
Loaded from disk in 3113 ms, to gpu in 1644 ms
28+
29+
SAFETENSORS_FAST_GPU=0
30+
Loaded from disk in 2741 ms, to gpu in 557 ms
31+
32+
SAFETENSORS_FAST_GPU=1
33+
Loaded from disk in 1153 ms, to gpu in 1495 ms
34+
35+
36+
37+
NVIDIA RTX Quadro 5000 (fp16)
38+
39+
NO SAFETENSORS
40+
Downloaded in 462557 ms
41+
Loading model: stabilityai/stable-diffusion-2-1-base (fp16)
42+
Loaded from disk in 2043 ms, to gpu in 1539 ms
43+
44+
SAFETENSORS_FAST_GPU=0
45+
46+
47+
SAFETENSORS_FAST_GPU=1
48+
Loaded from disk in 1134 ms, to gpu in 1184 ms

download.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from loadModel import loadModel, MODEL_IDS
66
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
77
from transformers import CLIPTextModel, CLIPTokenizer
8-
from precision import revision
8+
from precision import PRECISION, revision_from_precision, torch_dtype_from_precision
99
from utils import Storage
1010
import subprocess
1111
from pathlib import Path
@@ -24,12 +24,30 @@ def send(type: str, status: str, payload: dict = {}):
2424
_send(type, status, payload)
2525

2626

27-
def download_model(model_url=None, model_id=None):
27+
def normalize_model_id(model_id: str, model_revision):
28+
normalized_model_id = "models--" + model_id.replace("/", "--")
29+
if model_revision:
30+
normalized_model_id += "--" + model_revision
31+
return normalized_model_id
32+
33+
34+
def download_model(model_url=None, model_id=None, model_revision=None):
35+
print(
36+
"download_model",
37+
{
38+
"model_url": model_url,
39+
"model_id": model_id,
40+
"model_revision": model_revision,
41+
},
42+
)
2843
id = model_id or MODEL_ID
2944
url = model_url or MODEL_URL
45+
revision = model_revision or revision_from_precision()
46+
normalized_model_id = id
3047

3148
if url != "":
32-
normalized_model_id = "models--" + model_id.replace("/", "--")
49+
normalized_model_id = normalize_model_id(model_id, model_revision)
50+
print({"normalized_model_id": normalized_model_id})
3351
filename = url.split("/").pop()
3452
if not filename:
3553
filename = normalized_model_id + ".tar.zst"
@@ -38,17 +56,31 @@ def download_model(model_url=None, model_id=None):
3856
if exists:
3957
storage.download_file(filename)
4058
# os.mkdir(id)
41-
Path(id).mkdir(parents=True, exist_ok=False)
59+
# Path(id).mkdir(parents=True, exist_ok=False)
60+
os.mkdir(normalized_model_id)
4261
subprocess.run(
43-
["tar", "--use-compress-program=unzstd", "-C", id, "-xvf", filename],
62+
[
63+
"tar",
64+
"--use-compress-program=unzstd",
65+
"-C",
66+
normalized_model_id,
67+
"-xvf",
68+
filename,
69+
],
4470
check=True,
4571
)
4672
subprocess.run(["ls", "-l"])
4773
else:
4874
print("Does not exist, let's try find it on huggingface")
49-
download_model(model_id=model_id)
50-
model = loadModel(model_id, True)
51-
dir = "models--" + model_id.replace("/", "--") + "--dda"
75+
print("precision = ", {"model_revision": model_revision})
76+
# This would be quicker to just model.to("cuda") afterwards, but
77+
# this conveniently logs all the timings (and doesn't happen often)
78+
print("download")
79+
model = loadModel(model_id, False, precision=model_revision) # download
80+
print("load")
81+
model = loadModel(model_id, True, precision=model_revision) # load
82+
# dir = "models--" + model_id.replace("/", "--") + "--dda"
83+
dir = normalized_model_id
5284
model.save_pretrained(dir, safe_serialization=True)
5385

5486
# This is all duped from train_dreambooth, need to refactor TODO XXX
@@ -67,7 +99,10 @@ def download_model(model_url=None, model_id=None):
6799
send("upload", "done")
68100
print(upload_result)
69101
os.remove(filename)
70-
shutil.rmtree(dir)
102+
103+
# leave model dir for future loads... make configurable?
104+
# shutil.rmtree(dir)
105+
71106
# TODO, swap directories, inside HF's cache structure.
72107

73108
return
@@ -76,9 +111,9 @@ def download_model(model_url=None, model_id=None):
76111
# For local dev & preview deploys, download all the models (terrible for serverless deploys)
77112
if MODEL_ID == "ALL":
78113
for MODEL_I in MODEL_IDS:
79-
loadModel(MODEL_I, False)
114+
loadModel(MODEL_I, False, precision=model_revision)
80115
else:
81-
loadModel(MODEL_ID, False)
116+
loadModel(normalized_model_id, False, precision=model_revision)
82117

83118
# if USE_DREAMBOOTH:
84119
# Actually we can re-use these from the above loaded model

loadModel.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from diffusers import pipelines as _pipelines, StableDiffusionPipeline
44
from getScheduler import getScheduler, DEFAULT_SCHEDULER
5-
from precision import revision, torch_dtype
5+
from precision import revision_from_precision, torch_dtype_from_precision
66
import time
77

88
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
@@ -21,8 +21,16 @@
2121
]
2222

2323

24-
def loadModel(model_id: str, load=True):
25-
print(("Loading" if load else "Downloading") + " model: " + model_id)
24+
def loadModel(model_id: str, load=True, precision=None):
25+
print("loadModel", {"model_id": model_id, "load": load, "precision": precision})
26+
revision = revision_from_precision(precision)
27+
torch_dtype = torch_dtype_from_precision(precision)
28+
print(
29+
("Loading" if load else "Downloading")
30+
+ " model: "
31+
+ model_id
32+
+ (f" ({revision})" if revision else "")
33+
)
2634

2735
pipeline = (
2836
StableDiffusionPipeline if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
@@ -51,4 +59,4 @@ def loadModel(model_id: str, load=True):
5159
else:
5260
print(f"Downloaded in {from_pretrained} ms")
5361

54-
return model.to("cuda") if load else None
62+
return model if load else None

precision.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,13 @@
55

66
revision = None if PRECISION == "" else PRECISION
77
torch_dtype = None if PRECISION == "" else torch.float16
8+
9+
10+
def revision_from_precision(precision=PRECISION):
11+
return precision if precision else None
12+
13+
14+
def torch_dtype_from_precision(precision=PRECISION):
15+
if precision == "fp16":
16+
return torch.float16
17+
return None

0 commit comments

Comments
 (0)