feat: sketch to image
zweifisch committed Mar 20, 2024
1 parent 78b7b9c commit 82bea60
Showing 7 changed files with 259 additions and 1 deletion.
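This commit adds a sketch-to-image plugin built on pix2pix-turbo: it registers a --sketch flag (path to a sketch image) and a --sketch-gamma interpolation amount. As a rough usage sketch (the console-script name and the prompt flag come from the rest of sd_tools, not from this diff, so treat them as assumptions):

    sd --sketch sketch.png --sketch-gamma 0.4 --prompt "a cozy cabin in the woods"

If sketch.png does not exist yet, the plugin first saves a blank white canvas at that path and waits for you to draw before generating.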
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@
.DS_Store
*.safetensors
/dist/
*.pkl
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "sd_tools"
version = "1.0.1"
version = "1.0.2"
authors = [{ name = "Feng Zhou", email = "[email protected]" }]
description = "command line tool for stable diffusion"
license = { file = "LICENSE" }
45 changes: 45 additions & 0 deletions src/sd_tools/plugins/pix2pix/__init__.py
@@ -0,0 +1,45 @@
import os
import torch
from argparse import ArgumentParser
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as F
from .pix2pix_turbo import Pix2Pix_Turbo
from ..utils import obj, to8
from ..base import PluginBase


class PluginSketchToImage(PluginBase):

    def setup_args(self, parser: ArgumentParser):
        group = parser.add_argument_group("Sketch to Image")
        group.add_argument("--sketch", type=str, help="Path to sketch image")
        group.add_argument("--sketch-gamma", type=float, default=0.4, help="The sketch interpolation guidance amount")

    def setup_pipe(self):
        if not self.ctx.args.sketch:
            return

        # if the sketch file does not exist yet, create a blank white canvas
        # and wait for the user to draw on it
        if not os.path.exists(self.ctx.args.sketch):
            image = Image.new("RGB", (self.ctx.pipe_opts.width, self.ctx.pipe_opts.height), (255, 255, 255))
            image.save(self.ctx.args.sketch)
            input(f"Open {self.ctx.args.sketch}, draw, close the image editor, then press Enter to continue ")

        model = Pix2Pix_Turbo('sketch_to_image_stochastic', device=self.ctx.device)

        def run(**kwargs):
            with torch.no_grad():
                # threshold the sketch: dark strokes become True, white paper False
                image_t = F.to_tensor(to8(Image.open(self.ctx.args.sketch).convert('RGB'))) < 0.5
                c_t = image_t.unsqueeze(0).to(self.ctx.device).float()
                B, C, H, W = c_t.shape
                images = model(
                    c_t,
                    kwargs['prompt'],
                    deterministic=False,
                    r=self.ctx.args.sketch_gamma,
                    noise_map=torch.randn((1, 4, H // 8, W // 8), device=c_t.device),
                )

            # model outputs are in [-1, 1]; rescale to [0, 1] for PIL
            return obj(images=[transforms.ToPILImage()(images[0].cpu() * 0.5 + 0.5)])

        self.ctx.pipe = run
52 changes: 52 additions & 0 deletions src/sd_tools/plugins/pix2pix/model.py
@@ -0,0 +1,52 @@
# adapted from https://github.com/GaParmar/img2img-turbo/blob/b1add3ec17f59ec94f6839b39b9eb17984c5dd65/src/model.py
from diffusers import DDPMScheduler


def make_1step_sched(device):
    noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device=device)
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.to(device)
    return noise_scheduler_1step


def my_vae_encoder_fwd(self, sample):
    sample = self.conv_in(sample)
    l_blocks = []
    # down
    for down_block in self.down_blocks:
        l_blocks.append(sample)
        sample = down_block(sample)
    # middle
    sample = self.mid_block(sample)
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    self.current_down_blocks = l_blocks
    return sample


def my_vae_decoder_fwd(self, sample, latent_embeds=None):
    sample = self.conv_in(sample)
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # middle
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # up
        for idx, up_block in enumerate(self.up_blocks):
            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
            # add skip
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # post-process
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample
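Neither function above is defined on a class; pix2pix_turbo.py binds them onto the VAE's encoder and decoder instances at runtime. A minimal, self-contained sketch of that __get__ binding trick (all names here are illustrative):

    # assigning function.__get__(instance, cls) produces a bound method,
    # so `self` inside the function refers to that instance
    class Encoder:
        scale = 2

    def patched_forward(self, x):
        return x * self.scale

    enc = Encoder()
    enc.forward = patched_forward.__get__(enc, Encoder)
    assert enc.forward(3) == 6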
141 changes: 141 additions & 0 deletions src/sd_tools/plugins/pix2pix/pix2pix_turbo.py
@@ -0,0 +1,141 @@
# adapted from https://github.com/GaParmar/img2img-turbo/blob/b1add3ec17f59ec94f6839b39b9eb17984c5dd65/src/pix2pix_turbo.py
import os
import copy
import requests
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig
from .model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd

class TwinConv(torch.nn.Module):
    """Blend the pretrained conv_in with the finetuned one by a ratio r."""

    def __init__(self, convin_pretrained, convin_curr):
        super().__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * self.r


class AutoencoderKLPeft(AutoencoderKL, PeftAdapterMixin):
    # mixing in PeftAdapterMixin lets us attach a LoRA adapter to the VAE
    pass


class Pix2Pix_Turbo(torch.nn.Module):
    def __init__(self, name, ckpt_folder="checkpoints", device=None):
        super().__init__()
        self._device = device
        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder", variant='fp16').to(self._device)
        self.sched = make_1step_sched(device)

        vae = AutoencoderKLPeft.from_pretrained("stabilityai/sd-turbo", subfolder="vae", variant='fp16')
        unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet", variant='fp16')

        if name == "edge_to_image":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                with open(outf, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location=self._device)
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])

        if name == "sketch_to_image_stochastic":
            # download the LoRA checkpoint if it is not cached yet
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                with open(outf, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location=self._device)
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
            # blend the pretrained and finetuned conv_in so r can interpolate between them
            convin_pretrained = copy.deepcopy(unet.conv_in)
            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)

        # swap in the skip-connection-aware encoder/decoder forwards
        vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
        vae.decoder.ignore_skip = False
        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        unet.add_adapter(unet_lora_config)
        # merge the checkpoint weights into the freshly built modules
        _sd_unet = unet.state_dict()
        for k in sd["state_dict_unet"]:
            _sd_unet[k] = sd["state_dict_unet"][k]
        unet.load_state_dict(_sd_unet)
        # unet.enable_xformers_memory_efficient_attention()
        _sd_vae = vae.state_dict()
        for k in sd["state_dict_vae"]:
            _sd_vae[k] = sd["state_dict_vae"][k]
        vae.load_state_dict(_sd_vae)
        unet.to(self._device)
        vae.to(self._device)
        unet.eval()
        vae.eval()
        self.unet, self.vae = unet, vae
        self.vae.decoder.gamma = 1
        self.timesteps = torch.tensor([999], device=self._device).long()


    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=None):
        # encode the text prompt
        caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
                                        padding="max_length", truncation=True, return_tensors="pt").input_ids.to(self._device)
        caption_enc = self.text_encoder(caption_tokens)[0]
        if deterministic:
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc).sample
            x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample.clamp(-1, 1)
        else:
            # scale the lora weights based on the r value
            self.unet.set_adapters(["default"], weights=[r])
            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            # combine the input and noise
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r
            unet_output = self.unet(unet_input, self.timesteps, encoder_hidden_states=caption_enc).sample
            self.unet.conv_in.r = None
            x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            self.vae.decoder.gamma = r
            output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample.clamp(-1, 1)
        return output_image
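For reference, a minimal standalone sketch of the stochastic path, mirroring what the plugin in __init__.py does (the file path, prompt, and CUDA device are assumptions; the plugin first crops dimensions to multiples of 8 with to8):

    import torch
    from PIL import Image
    import torchvision.transforms.functional as F
    from sd_tools.plugins.pix2pix.pix2pix_turbo import Pix2Pix_Turbo

    device = "cuda"
    model = Pix2Pix_Turbo("sketch_to_image_stochastic", device=device)
    # threshold the sketch: dark strokes become 1.0, white paper 0.0
    c_t = (F.to_tensor(Image.open("sketch.png").convert("RGB")) < 0.5).unsqueeze(0).float().to(device)
    _, _, H, W = c_t.shape
    with torch.no_grad():
        images = model(c_t, "a cozy cabin in the woods", deterministic=False, r=0.4,
                       noise_map=torch.randn((1, 4, H // 8, W // 8), device=device))
    # outputs are in [-1, 1]; rescale to [0, 1] before saving
    out = images[0].cpu() * 0.5 + 0.5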
17 changes: 17 additions & 0 deletions src/sd_tools/plugins/utils.py
@@ -39,3 +39,20 @@ def load_images(locations: List[str]) -> List[Image.Image]:

def remove_none(input_dict):
    return {k: v for k, v in input_dict.items() if v is not None}


def to8(image: Image.Image) -> Image.Image:
    # resize so both dimensions are multiples of 8, as the latent models require
    return image.resize((image.width - image.width % 8, image.height - image.height % 8), Image.LANCZOS)


def canny_from_pil(image, low_threshold=100, high_threshold=200):
    import cv2
    import numpy as np
    image = cv2.Canny(np.array(image), low_threshold, high_threshold)[:, :, None]
    return Image.fromarray(np.concatenate([image, image, image], axis=2))


class Object:
    def __init__(self, kvs):
        for key, value in kvs.items():
            setattr(self, key, value)


def obj(**kvs):
    return Object(kvs)
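For context, the pix2pix plugin combines these new helpers roughly like this (a minimal sketch; the file name is illustrative):

    from PIL import Image
    from sd_tools.plugins.utils import to8, obj

    img = to8(Image.open("sketch.png"))   # both dimensions now divisible by 8
    result = obj(images=[img])            # attribute-style access to the kwargs
    print(result.images[0].size)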
2 changes: 2 additions & 0 deletions src/sd_tools/sd.py
@@ -9,6 +9,7 @@
from .plugins.model import PluginModel
from .plugins.pipe import PluginPipe
from .plugins.inpainting import PluginInpainting
from .plugins.pix2pix import PluginSketchToImage
from .plugins.steps import PluginSteps
from .plugins.output import PluginOutput
from .plugins.lora import PluginLora
@@ -65,6 +66,7 @@ def main():
        PluginIPAdaptorFaceIDPortrait(ctx),
        PluginIPCompositionAdapter(ctx),
        # PluginImage(ctx),
        PluginSketchToImage(ctx),
        PluginScheduler(ctx),
        PluginLora(ctx),
        PluginHTTP(ctx),
