Commit
Showing 7 changed files with 259 additions and 1 deletion.
.gitignore
@@ -5,3 +5,4 @@
.DS_Store
*.safetensors
/dist/
+*.pkl
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "sd_tools"
-version = "1.0.1"
+version = "1.0.2"
authors = [{ name = "Feng Zhou", email = "[email protected]" }]
description = "command line tool for stable diffusion"
license = { file = "LICENSE" }
@@ -0,0 +1,45 @@
import os
import torch
from argparse import ArgumentParser
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as F
from .pix2pix_turbo import Pix2Pix_Turbo
from ..utils import obj, to8
from ..base import PluginBase


class PluginSketchToImage(PluginBase):

    def setup_args(self, parser: ArgumentParser):
        group = parser.add_argument_group("Sketch to Image")
        group.add_argument("--sketch", type=str, help='Path to sketch image')
        group.add_argument('--sketch-gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')

    def setup_pipe(self):
        if not self.ctx.args.sketch:
            return

        if not os.path.exists(self.ctx.args.sketch):
            image = Image.new("RGB", (self.ctx.pipe_opts.width, self.ctx.pipe_opts.height), (255, 255, 255))
            image.save(self.ctx.args.sketch)
            input(f"Open {self.ctx.args.sketch}, draw, close the image editor, then press Enter to continue ")

        model = Pix2Pix_Turbo('sketch_to_image_stochastic', device=self.ctx.device)

        def run(**kwargs):
            with torch.no_grad():
                image_t = F.to_tensor(to8(Image.open(self.ctx.args.sketch).convert('RGB'))) < 0.5
                c_t = image_t.unsqueeze(0).to(self.ctx.device).float()
                B, C, H, W = c_t.shape
                images = model(
                    c_t,
                    kwargs['prompt'],
                    deterministic=False,
                    r=self.ctx.args.sketch_gamma,
                    noise_map=torch.randn((1, 4, H // 8, W // 8), device=c_t.device),
                )

            return obj(images=[transforms.ToPILImage()(images[0].cpu() * 0.5 + 0.5)])

        self.ctx.pipe = run
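
For orientation, a minimal sketch of how this plugin's pipe could be driven. The SimpleNamespace context, the way it is attached, and the prompt keyword are assumptions made for illustration; the plugin framework's real wiring is not part of this diff.

# Hypothetical driver for PluginSketchToImage; the framework's real context
# object is not shown in this diff, so this SimpleNamespace stand-in and the
# attribute wiring below are assumptions, not the package's actual API.
from types import SimpleNamespace

ctx = SimpleNamespace(
    args=SimpleNamespace(sketch="sketch.png", sketch_gamma=0.4),
    pipe_opts=SimpleNamespace(width=512, height=512),
    device="cuda",
    pipe=None,
)

plugin = PluginSketchToImage()  # assumes PluginBase() takes no arguments
plugin.ctx = ctx                # assumes the framework attaches ctx like this
plugin.setup_pipe()             # creates a blank sketch.png if missing, loads the model
result = ctx.pipe(prompt="a watercolor cottage in a forest")
result.images[0].save("out.png")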
model.py
@@ -0,0 +1,52 @@
# adapted from https://github.com/GaParmar/img2img-turbo/blob/b1add3ec17f59ec94f6839b39b9eb17984c5dd65/src/model.py
from diffusers import DDPMScheduler


def make_1step_sched(device):
    noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device=device)
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.to(device)
    return noise_scheduler_1step


def my_vae_encoder_fwd(self, sample):
    sample = self.conv_in(sample)
    l_blocks = []
    # down
    for down_block in self.down_blocks:
        l_blocks.append(sample)
        sample = down_block(sample)
    # middle
    sample = self.mid_block(sample)
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    self.current_down_blocks = l_blocks
    return sample


def my_vae_decoder_fwd(self, sample, latent_embeds=None):
    sample = self.conv_in(sample)
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # middle
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # up
        for idx, up_block in enumerate(self.up_blocks):
            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
            # add skip
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for idx, up_block in enumerate(self.up_blocks):
            sample = up_block(sample, latent_embeds)
    # post-process
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample
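
Both forwards above are grafted onto stock diffusers modules through descriptor binding, `fn.__get__(instance, cls)`, as pix2pix_turbo.py does below. A self-contained sketch of that binding pattern, with toy names:

# Standalone illustration of the __get__ method-binding trick used for the
# VAE forwards: a plain function becomes a bound method of one instance.
class Codec:
    def __init__(self, name):
        self.name = name

    def forward(self):
        return f"stock forward for {self.name}"


def my_forward(self):
    # Replacement forward; `self` is the Codec instance it was bound to.
    return f"patched forward for {self.name}"


codec = Codec("vae.encoder")
print(codec.forward())                            # stock forward for vae.encoder
codec.forward = my_forward.__get__(codec, Codec)  # per-instance rebinding
print(codec.forward())                            # patched forward for vae.encoder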
pix2pix_turbo.py
@@ -0,0 +1,141 @@
# adapted from https://github.com/GaParmar/img2img-turbo/blob/b1add3ec17f59ec94f6839b39b9eb17984c5dd65/src/pix2pix_turbo.py
import os
import requests
import sys
import pdb
import copy
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig
from .model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd


class TwinConv(torch.nn.Module):
    def __init__(self, convin_pretrained, convin_curr):
        super(TwinConv, self).__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * self.r


class AutoencoderKLPeft(AutoencoderKL, PeftAdapterMixin):
    pass


class Pix2Pix_Turbo(torch.nn.Module):
    def __init__(self, name, ckpt_folder="checkpoints", device=None):
        super().__init__()
        self._device = device
        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder", variant='fp16').to(self._device)
        self.sched = make_1step_sched(device)

        vae = AutoencoderKLPeft.from_pretrained("stabilityai/sd-turbo", subfolder="vae", variant='fp16')
        unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet", variant='fp16')

        if name == "edge_to_image":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                with open(outf, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location=self._device)
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])

        if name == "sketch_to_image_stochastic":
            # download from url
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                with open(outf, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location=self._device)
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
            convin_pretrained = copy.deepcopy(unet.conv_in)
            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)

        vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self._device)
        vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
        vae.decoder.ignore_skip = False
        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        unet.add_adapter(unet_lora_config)
        _sd_unet = unet.state_dict()
        for k in sd["state_dict_unet"]: _sd_unet[k] = sd["state_dict_unet"][k]
        unet.load_state_dict(_sd_unet)
        # unet.enable_xformers_memory_efficient_attention()
        _sd_vae = vae.state_dict()
        for k in sd["state_dict_vae"]: _sd_vae[k] = sd["state_dict_vae"][k]
        vae.load_state_dict(_sd_vae)
        unet.to(self._device)
        vae.to(self._device)
        unet.eval()
        vae.eval()
        self.unet, self.vae = unet, vae
        self.vae.decoder.gamma = 1
        self.timesteps = torch.tensor([999], device=self._device).long()


    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=None):
        # encode the text prompt
        caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
                                        padding="max_length", truncation=True, return_tensors="pt").input_ids.to(self._device)
        caption_enc = self.text_encoder(caption_tokens)[0]
        if deterministic:
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc).sample
            x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample.clamp(-1, 1)
        else:
            # scale the lora weights based on the r value
            self.unet.set_adapters(["default"], weights=[r])
            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            # combine the input and noise
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r
            unet_output = self.unet(unet_input, self.timesteps, encoder_hidden_states=caption_enc).sample
            self.unet.conv_in.r = None
            x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            self.vae.decoder.gamma = r
            output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample.clamp(-1, 1)
        return output_image
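
Finally, a minimal usage sketch of Pix2Pix_Turbo in its stochastic mode, mirroring what the plugin does. The file names and prompt are placeholders, the package's to8 resize helper is skipped, and a CUDA device plus network access for the sd-turbo weights and the LoRA checkpoint are assumed.

# Minimal stochastic-mode sketch for Pix2Pix_Turbo; paths, prompt, and the
# CUDA device are assumptions for illustration (the plugin above also resizes
# the sketch with the package's to8 helper, omitted here).
import torch
import torchvision.transforms.functional as F
from PIL import Image

device = "cuda"
model = Pix2Pix_Turbo("sketch_to_image_stochastic", device=device)

# Binarize the sketch: dark strokes become the control signal, shape (1, 3, H, W).
sketch = F.to_tensor(Image.open("sketch.png").convert("RGB")) < 0.5
c_t = sketch.unsqueeze(0).float().to(device)
_, _, H, W = c_t.shape

with torch.no_grad():
    out = model(
        c_t,
        "a watercolor cottage in a forest",
        deterministic=False,   # stochastic branch: blend latent with noise
        r=0.4,                 # interpolation strength (the plugin's --sketch-gamma)
        noise_map=torch.randn((1, 4, H // 8, W // 8), device=device),
    )

# Output is in [-1, 1]; map back to [0, 1] before converting to a PIL image.
F.to_pil_image(out[0].cpu() * 0.5 + 0.5).save("result.png")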