diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000..f84627cb0d Binary files /dev/null and b/.DS_Store differ diff --git a/README.md b/README.md index e526089c9d..a59d31bbe3 100644 --- a/README.md +++ b/README.md @@ -60,9 +60,16 @@ Note that the way we connect layers is computational efficient. The original SD First create a new conda environment +#### CUDA, CPU + conda env create -f environment.yaml conda activate control +#### MPS + + conda env create -f environment-mps.yaml + conda activate control + All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Make sure that SD models are put in "ControlNet/models" and detectors are put in "ControlNet/annotator/ckpts". Make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including HED edge detection model, Midas depth estimation model, Openpose, and so on. We provide 9 Gradio apps with these models. @@ -73,8 +80,14 @@ All test images can be found at the folder "test_imgs". Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection) +#### CUDA, CPU + python gradio_canny2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_canny2image.py + The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details. Prompt: "bird" @@ -87,8 +100,14 @@ Prompt: "cute dog" Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection) +#### CUDA, CPU + python gradio_hough2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_hough2image.py + The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details. Prompt: "room" @@ -101,8 +120,14 @@ Prompt: "building" Stable Diffusion 1.5 + ControlNet (using soft HED Boundary) +#### CUDA, CPU + python gradio_hed2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_hed2image.py + The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details. Prompt: "oil painting of handsome old man, masterpiece" @@ -115,8 +140,14 @@ Prompt: "Cyberpunk robot" Stable Diffusion 1.5 + ControlNet (using Scribbles) +#### CUDA, CPU + python gradio_scribble2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_scribble2image.py + Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio. Prompt: "turtle" @@ -129,8 +160,14 @@ Prompt: "hot air balloon" We actually provide an interactive interface +#### CUDA, CPU + python gradio_scribble2image_interactive.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_scribble2image_interactive.py + ~~However, because gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now, user need to first set canvas width and heights and then click "Open drawing canvas" to get a drawing area. Please do not upload image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, gradio is really buggy.~~ (Now fixed, will update asap) The below dog sketch is drawn by me. Perhaps we should draw a better dog for showcase.
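The CUDA/CPU and MPS launch commands above differ only in which backend PyTorch ends up on. Each Gradio script in this patch picks the device once at start-up with a small helper and then hands it to the annotators and the sampler; here is a standalone sketch of that helper (the same logic that appears in the scripts later in this diff):

```python
import torch


def get_device() -> str:
    """Prefer CUDA, then Apple's MPS backend, then plain CPU."""
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'


device = get_device()
print(f"ControlNet will run on: {device}")
```

On machines without a supported GPU this resolves to 'cpu', so the same commands work there too. The PYTORCH_ENABLE_MPS_FALLBACK=1 prefix on the MPS commands additionally asks PyTorch to run operators that the MPS backend does not implement on the CPU instead of raising an error; it does not cover every operator, which is presumably why gradio_seg2image.py is listed below as not supported.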
@@ -142,8 +179,14 @@ Prompt: "dog in a room" Stable Diffusion 1.5 + ControlNet (using fake scribbles) +#### CUDA, CPU + python gradio_fake_scribble2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_fake_scribble2image.py + Sometimes we are lazy, and we do not want to draw scribbles. This script use the exactly same scribble-based model but use a simple algorithm to synthesize scribbles from input images. Prompt: "bag" @@ -156,8 +199,13 @@ Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still Stable Diffusion 1.5 + ControlNet (using human pose) +#### CUDA, CPU + python gradio_pose2image.py +#### MPS + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_pose2image.py + Apparently, this model deserves a better UI to directly manipulate pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then the Openpose will detect the pose for you. Prompt: "Chief in the kitchen" @@ -170,8 +218,13 @@ Prompt: "An astronaut on the moon" Stable Diffusion 1.5 + ControlNet (using semantic segmentation) +#### CUDA, CPU + python gradio_seg2image.py +#### MPS + Not Supported (Reason: aten::_slow_conv2d_forward is currently not supported by mps.) + This model use ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then a model called Uniformer will detect the segmentations for you. Just try it for more details. Prompt: "House" @@ -184,8 +237,14 @@ Prompt: "River" Stable Diffusion 1.5 + ControlNet (using depth map) +#### CUDA, CPU + python gradio_depth2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_depth2image.py + Great! Now SD 1.5 also have a depth control. FINALLY. So many possibilities (considering SD1.5 has much more community models than SD2). Note that different from Stability's model, the ControlNet receive the full 512×512 depth map, rather than 64×64 depth. Note that Stability's SD2 depth model use 64*64 depth maps. This means that the ControlNet will preserve more details in the depth map. @@ -199,8 +258,14 @@ Prompt: "Stormtrooper's lecture" Stable Diffusion 1.5 + ControlNet (using normal map) +#### CUDA, CPU + python gradio_normal2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_normal2image.py + This model use normal map. Rightnow in the APP, the normal is computed from the midas depth map and a user threshold (to determine how many area is background with identity normal face to viewer, tune the "Normal background threshold" in the gradio app to get a feeling).
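In other words, the app derives a pseudo normal map from the MiDaS depth prediction: image-space depth gradients become the x/y components of the normal, a constant fills the z component, and pixels whose normalized depth falls below the background threshold are zeroed so they face the viewer. A rough, self-contained sketch of that idea (not the exact code from annotator/midas; the parameter names here are illustrative):

```python
import cv2
import numpy as np


def depth_to_normal(depth: np.ndarray, bg_threshold: float = 0.4,
                    z_scale: float = 2 * np.pi) -> np.ndarray:
    """Turn an HxW depth map into an RGB-encoded normal map."""
    depth = depth.astype(np.float32)
    norm = (depth - depth.min()) / max(float(depth.max() - depth.min()), 1e-6)

    # Image-space gradients provide the x/y components of the normal.
    dx = cv2.Sobel(depth, cv2.CV_32F, 1, 0, ksize=3)
    dy = cv2.Sobel(depth, cv2.CV_32F, 0, 1, ksize=3)
    dz = np.ones_like(dx) * z_scale

    # Background pixels get an identity normal (facing the viewer).
    dx[norm < bg_threshold] = 0
    dy[norm < bg_threshold] = 0

    normal = np.stack([dx, dy, dz], axis=2)
    normal /= np.linalg.norm(normal, axis=2, keepdims=True)
    return (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
```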
Prompt: "Cute toy" diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py index 56532c374d..f882115f15 100644 --- a/annotator/hed/__init__.py +++ b/annotator/hed/__init__.py @@ -93,20 +93,21 @@ def forward(self, tenInput): return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1)) -class HEDdetector: - def __init__(self): +class HEDdetector(): + def __init__(self, device): + self.device = device remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth" modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth") if not os.path.exists(modelpath): from basicsr.utils.download_util import load_file_from_url load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) - self.netNetwork = Network(modelpath).cuda().eval() + self.netNetwork = Network(modelpath).to(device).eval() def __call__(self, input_image): assert input_image.ndim == 3 input_image = input_image[:, :, ::-1].copy() with torch.no_grad(): - image_hed = torch.from_numpy(input_image).float().cuda() + image_hed = torch.from_numpy(input_image).float().to(self.device) image_hed = image_hed / 255.0 image_hed = rearrange(image_hed, 'h w c -> 1 c h w') edge = self.netNetwork(image_hed)[0] diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py index dc5ac03eea..7aa15e5dbb 100644 --- a/annotator/midas/__init__.py +++ b/annotator/midas/__init__.py @@ -7,14 +7,15 @@ class MidasDetector: - def __init__(self): - self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + def __init__(self, device): + self.device = device + self.model = MiDaSInference(model_type="dpt_hybrid").to(device) def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): assert input_image.ndim == 3 image_depth = input_image with torch.no_grad(): - image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = torch.from_numpy(image_depth).float().to(self.device) image_depth = image_depth / 127.5 - 1.0 image_depth = rearrange(image_depth, 'h w c -> 1 c h w') depth = self.model(image_depth)[0] diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py index 42af28c682..3cf840d464 100644 --- a/annotator/mlsd/__init__.py +++ b/annotator/mlsd/__init__.py @@ -15,14 +15,15 @@ class MLSDdetector: - def __init__(self): + def __init__(self, device): model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") if not os.path.exists(model_path): from basicsr.utils.download_util import load_file_from_url load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) model = MobileV2_MLSD_Large() model.load_state_dict(torch.load(model_path), strict=True) - self.model = model.cuda().eval() + self.model = model.to(device).eval() + self.device = device def __call__(self, input_image, thr_v, thr_d): assert input_image.ndim == 3 @@ -30,7 +31,7 @@ def __call__(self, input_image, thr_v, thr_d): img_output = np.zeros_like(img) try: with torch.no_grad(): - lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + lines = pred_lines(img, self.model, self.device, [img.shape[0], img.shape[1]], thr_v, thr_d) for line in lines: x_start, y_start, x_end, y_end = [int(val) for val in line] cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) diff --git a/annotator/mlsd/utils.py b/annotator/mlsd/utils.py index ae3cf9420a..b3ecdfb8b8 100644 --- a/annotator/mlsd/utils.py +++ b/annotator/mlsd/utils.py @@ -44,7 +44,7 @@ def deccode_output_score_and_ptss(tpMap, topk_n 
= 200, ksize = 5): return ptss, scores, displacement -def pred_lines(image, model, +def pred_lines(image, model, device, input_shape=[512, 512], score_thr=0.10, dist_thr=20.0): @@ -58,7 +58,7 @@ def pred_lines(image, model, batch_image = np.expand_dims(resized_image, axis=0).astype('float32') batch_image = (batch_image / 127.5) - 1.0 - batch_image = torch.from_numpy(batch_image).float().cuda() + batch_image = torch.from_numpy(batch_image).float().to(device) outputs = model(batch_image) pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) start = vmap[:, :, :2] diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py index 6be429542e..d1f2d00bca 100644 --- a/annotator/uniformer/__init__.py +++ b/annotator/uniformer/__init__.py @@ -9,13 +9,13 @@ class UniformerDetector: - def __init__(self): + def __init__(self, device): modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth") if not os.path.exists(modelpath): from basicsr.utils.download_util import load_file_from_url load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path) config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py") - self.model = init_segmentor(config_file, modelpath).cuda() + self.model = init_segmentor(config_file, modelpath).to(device) def __call__(self, img): result = inference_segmentor(self.model, img) diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py index 25b1bc9472..76ba15bfbc 100644 --- a/cldm/ddim_hacked.py +++ b/cldm/ddim_hacked.py @@ -8,16 +8,20 @@ class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, device, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.device): + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/cldm/model.py b/cldm/model.py index fed3c31ac1..387938c101 100644 --- a/cldm/model.py +++ b/cldm/model.py @@ -9,8 +9,10 @@ def get_state_dict(d): return d.get('state_dict', d) -def load_state_dict(ckpt_path, location='cpu'): +def load_state_dict(ckpt_path, location): _, extension = os.path.splitext(ckpt_path) + if str(location) == "mps": + location = "cpu" if extension.lower() == ".safetensors": import safetensors.torch state_dict = safetensors.torch.load_file(ckpt_path, device=location) diff --git a/environment-mps.yaml b/environment-mps.yaml new file mode 100644 index 0000000000..ef87ae47b6 --- /dev/null +++ b/environment-mps.yaml @@ -0,0 +1,34 @@ +name: control +channels: + - pytorch + - defaults +dependencies: + - python=3.8 + - pip + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 + - pip: + - gradio==3.16.2 + - albumentations==1.3.0 + - opencv-contrib-python + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.5.0 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.3.0 + - transformers==4.19.2 + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.0.2 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.6.0 + - 
timm==0.6.12 + - addict==2.4.0 + - yapf==0.32.0 + - prettytable==3.6.0 + - safetensors==0.2.7 + - basicsr==1.4.2 diff --git a/gradio_annotator.py b/gradio_annotator.py index 2b1a29ebbe..2d33b52919 100644 --- a/gradio_annotator.py +++ b/gradio_annotator.py @@ -1,9 +1,20 @@ import gradio as gr +import torch from annotator.util import resize_image, HWC3 +def get_device(): + if torch.cuda.is_available(): + return 'cuda' + elif torch.backends.mps.is_available(): + return 'mps' + else: + return 'cpu' + + model_canny = None +device = get_device() def canny(img, res, l, h): @@ -24,7 +35,7 @@ def hed(img, res): global model_hed if model_hed is None: from annotator.hed import HEDdetector - model_hed = HEDdetector() + model_hed = HEDdetector(device) result = model_hed(img) return [result] @@ -37,7 +48,7 @@ def mlsd(img, res, thr_v, thr_d): global model_mlsd if model_mlsd is None: from annotator.mlsd import MLSDdetector - model_mlsd = MLSDdetector() + model_mlsd = MLSDdetector(device) result = model_mlsd(img, thr_v, thr_d) return [result] @@ -50,7 +61,7 @@ def midas(img, res, a): global model_midas if model_midas is None: from annotator.midas import MidasDetector - model_midas = MidasDetector() + model_midas = MidasDetector(device) results = model_midas(img, a) return results @@ -76,7 +87,7 @@ def uniformer(img, res): global model_uniformer if model_uniformer is None: from annotator.uniformer import UniformerDetector - model_uniformer = UniformerDetector() + model_uniformer = UniformerDetector(device) result = model_uniformer(img) return [result] diff --git a/gradio_canny2image.py b/gradio_canny2image.py index 9866cac5b3..d8f5a31be7 100644 --- a/gradio_canny2image.py +++ b/gradio_canny2image.py @@ -15,12 +15,22 @@ from cldm.ddim_hacked import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + apply_canny = CannyDetector() +device = get_device() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold): @@ -31,7 +41,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = apply_canny(img, low_threshold, high_threshold) detected_map = HWC3(detected_map) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_depth2image.py b/gradio_depth2image.py index ee678999ae..7289be078e 100644 --- a/gradio_depth2image.py +++ b/gradio_depth2image.py @@ -15,12 +15,22 @@ from cldm.ddim_hacked import DDIMSampler -apply_midas = MidasDetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_midas = MidasDetector(device) model = create_model('./models/cldm_v15.yaml').cpu() 
-model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_fake_scribble2image.py b/gradio_fake_scribble2image.py index a7cd375f75..117a8bf716 100644 --- a/gradio_fake_scribble2image.py +++ b/gradio_fake_scribble2image.py @@ -15,12 +15,23 @@ from cldm.ddim_hacked import DDIMSampler -apply_hed = HEDdetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_hed = HEDdetector(device) + model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): @@ -37,7 +48,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map[detected_map > 4] = 255 detected_map[detected_map < 255] = 0 - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_hed2image.py b/gradio_hed2image.py index 1ceff67969..4befea21fb 100644 --- a/gradio_hed2image.py +++ b/gradio_hed2image.py @@ -15,12 +15,22 @@ from cldm.ddim_hacked import DDIMSampler -apply_hed = HEDdetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_hed = HEDdetector(device) model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) - control = 
torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_hough2image.py b/gradio_hough2image.py index 6095eeb676..bdefaeecff 100644 --- a/gradio_hough2image.py +++ b/gradio_hough2image.py @@ -15,12 +15,22 @@ from cldm.ddim_hacked import DDIMSampler -apply_mlsd = MLSDdetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_mlsd = MLSDdetector(device) model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_normal2image.py b/gradio_normal2image.py index 30aea2f8d4..2f92f56a81 100644 --- a/gradio_normal2image.py +++ b/gradio_normal2image.py @@ -15,12 +15,22 @@ from cldm.ddim_hacked import DDIMSampler -apply_midas = MidasDetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_midas = MidasDetector(device) model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) - control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_pose2image.py b/gradio_pose2image.py index 700973bfab..cbf2b7a8df 100644 --- a/gradio_pose2image.py +++ b/gradio_pose2image.py @@ -15,12 +15,22 @@ from cldm.ddim_hacked import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 
'mps' + else: + return 'cpu' + + +device = get_device() apply_openpose = OpenposeDetector() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_scribble2image.py b/gradio_scribble2image.py index 8abbc25bde..9d3a43edd8 100644 --- a/gradio_scribble2image.py +++ b/gradio_scribble2image.py @@ -14,10 +14,19 @@ from cldm.ddim_hacked import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + +device = get_device() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): @@ -28,7 +37,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = np.zeros_like(img, dtype=np.uint8) detected_map[np.min(img, axis=2) < 127] = 255 - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_scribble2image_interactive.py b/gradio_scribble2image_interactive.py index 7308bcc1bb..4e4908d7fd 100644 --- a/gradio_scribble2image_interactive.py +++ b/gradio_scribble2image_interactive.py @@ -14,10 +14,20 @@ from cldm.ddim_hacked import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): @@ -28,7 +38,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = np.zeros_like(img, 
dtype=np.uint8) detected_map[np.min(img, axis=2) > 127] = 255 - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 27ead0ea91..09c0c7f893 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -8,16 +8,19 @@ class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, device, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8cf3..f0509b8a35 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -11,16 +11,20 @@ class DPMSolverSampler(object): - def __init__(self, model, **kwargs): + def __init__(self, model, device, **kwargs): super().__init__() self.model = model to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.device): + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) @torch.no_grad() diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 7002a365d2..20be051b7b 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -9,17 +9,21 @@ from ldm.models.diffusion.sampling_util import norm_thresholding class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, device, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.device): + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 4edd5496b9..2cd4e1dc9f 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -8,6 +8,15 @@ from ldm.util import default, count_params +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + class AbstractEncoder(nn.Module): def __init__(self):
super().__init__() @@ -42,7 +51,7 @@ def forward(self, batch, key=None, disable_dropout=False): c = self.embedding(c) return c - def get_unconditional_conditioning(self, bs, device="cuda"): + def get_unconditional_conditioning(self, bs, device=get_device()): uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) uc = torch.ones((bs,), device=device) * uc_class uc = {self.key: uc} @@ -57,7 +66,7 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + def __init__(self, version="google/t5-v1_1-large", device=get_device(), max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) @@ -92,7 +101,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): "pooled", "hidden" ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + def __init__(self, version="openai/clip-vit-large-patch14", device=get_device(), max_length=77, freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -140,7 +149,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): "last", "penultimate" ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77, freeze=True, layer="last"): super().__init__() assert layer in self.LAYERS @@ -194,7 +203,7 @@ def encode(self, text): class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=get_device(), clip_max_length=77, t5_max_length=77): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
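Taken together, the sampler and checkpoint-loading changes above follow one pattern: tensors are moved to whatever device get_device() returned, with two MPS-specific accommodations visible in the diff: buffers are forced to float32 when the target is MPS (presumably because the backend lacks float64 support), and load_state_dict remaps an 'mps' location to 'cpu' so torch.load deserializes on the CPU before the model is moved with .to(device). A condensed sketch of that pattern with illustrative helper names (to_device and load_checkpoint are not functions from this repository):

```python
import torch


def to_device(t: torch.Tensor, device: str) -> torch.Tensor:
    """Move a tensor the way the patched register_buffer does."""
    if str(device) == 'mps':
        # MPS cannot hold float64 buffers, so down-cast while moving.
        return t.to(device, torch.float32)
    return t.to(device)


def load_checkpoint(path: str, device: str) -> dict:
    """Load a checkpoint the way the patched load_state_dict does."""
    # Deserialize on the CPU when targeting MPS; the model is moved
    # onto the device afterwards via model.to(device).
    location = 'cpu' if str(device) == 'mps' else device
    state = torch.load(path, map_location=location)
    return state.get('state_dict', state)
```

With these changes, DDIMSampler, PLMSSampler, and DPMSolverSampler all take the device explicitly, so callers construct them as, for example, DDIMSampler(model, device) instead of relying on a hard-coded "cuda".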