55from loadModel import loadModel , MODEL_IDS
66from diffusers import AutoencoderKL , UNet2DConditionModel , DDPMScheduler
77from transformers import CLIPTextModel , CLIPTokenizer
8- from precision import revision
8+ from precision import PRECISION , revision_from_precision , torch_dtype_from_precision
99from utils import Storage
1010import subprocess
1111from pathlib import Path
@@ -24,12 +24,30 @@ def send(type: str, status: str, payload: dict = {}):
2424 _send (type , status , payload )
2525
2626
27- def download_model (model_url = None , model_id = None ):
27+ def normalize_model_id (model_id : str , model_revision ):
28+ normalized_model_id = "models--" + model_id .replace ("/" , "--" )
29+ if model_revision :
30+ normalized_model_id += "--" + model_revision
31+ return normalized_model_id
32+
33+
34+ def download_model (model_url = None , model_id = None , model_revision = None ):
35+ print (
36+ "download_model" ,
37+ {
38+ "model_url" : model_url ,
39+ "model_id" : model_id ,
40+ "model_revision" : model_revision ,
41+ },
42+ )
2843 id = model_id or MODEL_ID
2944 url = model_url or MODEL_URL
45+ revision = model_revision or revision_from_precision ()
46+ normalized_model_id = id
3047
3148 if url != "" :
32- normalized_model_id = "models--" + model_id .replace ("/" , "--" )
49+ normalized_model_id = normalize_model_id (model_id , model_revision )
50+ print ({"normalized_model_id" : normalized_model_id })
3351 filename = url .split ("/" ).pop ()
3452 if not filename :
3553 filename = normalized_model_id + ".tar.zst"
@@ -38,17 +56,31 @@ def download_model(model_url=None, model_id=None):
3856 if exists :
3957 storage .download_file (filename )
4058 # os.mkdir(id)
41- Path (id ).mkdir (parents = True , exist_ok = False )
59+ # Path(id).mkdir(parents=True, exist_ok=False)
60+ os .mkdir (normalized_model_id )
4261 subprocess .run (
43- ["tar" , "--use-compress-program=unzstd" , "-C" , id , "-xvf" , filename ],
62+ [
63+ "tar" ,
64+ "--use-compress-program=unzstd" ,
65+ "-C" ,
66+ normalized_model_id ,
67+ "-xvf" ,
68+ filename ,
69+ ],
4470 check = True ,
4571 )
4672 subprocess .run (["ls" , "-l" ])
4773 else :
4874 print ("Does not exist, let's try find it on huggingface" )
49- download_model (model_id = model_id )
50- model = loadModel (model_id , True )
51- dir = "models--" + model_id .replace ("/" , "--" ) + "--dda"
75+ print ("precision = " , {"model_revision" : model_revision })
76+ # This would be quicker to just model.to("cuda") afterwards, but
77+ # this conveniently logs all the timings (and doesn't happen often)
78+ print ("download" )
79+ model = loadModel (model_id , False , precision = model_revision ) # download
80+ print ("load" )
81+ model = loadModel (model_id , True , precision = model_revision ) # load
82+ # dir = "models--" + model_id.replace("/", "--") + "--dda"
83+ dir = normalized_model_id
5284 model .save_pretrained (dir , safe_serialization = True )
5385
5486 # This is all duped from train_dreambooth, need to refactor TODO XXX
@@ -67,7 +99,10 @@ def download_model(model_url=None, model_id=None):
6799 send ("upload" , "done" )
68100 print (upload_result )
69101 os .remove (filename )
70- shutil .rmtree (dir )
102+
103+ # leave model dir for future loads... make configurable?
104+ # shutil.rmtree(dir)
105+
71106 # TODO, swap directories, inside HF's cache structure.
72107
73108 return
@@ -76,9 +111,9 @@ def download_model(model_url=None, model_id=None):
76111 # For local dev & preview deploys, download all the models (terrible for serverless deploys)
77112 if MODEL_ID == "ALL" :
78113 for MODEL_I in MODEL_IDS :
79- loadModel (MODEL_I , False )
114+ loadModel (MODEL_I , False , precision = model_revision )
80115 else :
81- loadModel (MODEL_ID , False )
116+ loadModel (normalized_model_id , False , precision = model_revision )
82117
83118 # if USE_DREAMBOOTH:
84119 # Actually we can re-use these from the above loaded model
0 commit comments