Skip to content

Commit

Permalink
feat: SDXS & YOSO
Browse files Browse the repository at this point in the history
  • Loading branch information
zweifisch committed Mar 27, 2024
1 parent f386386 commit 8b8ebc1
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 19 deletions.
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ generating images from the command line and web UI

![web](https://github.com/zweifisch/sd-tools/assets/447862/3855dbd1-65ba-4721-ad44-0af6d79eb0c0)


https://github.com/zweifisch/sd-tools/assets/447862/e39bee8d-f7b4-41d9-b1ec-640f341835b5

```shell
Expand Down Expand Up @@ -390,3 +389,19 @@ sd 'Arnold Schwarzenegger' \
sd 'ethereal fantasy concept art of cat, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy' \
--sketch sketch.png -i -o output/preview.webp
```

### SDXS

[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627)

```shell
sd --model zweifisch/sdxs-512-0.9-fp16 'chihuahua' --steps 1 --count 10
```

### YOSO

[You Only Sample Once](https://www.arxiv.org/abs/2403.12931)

```shell
sd --model 'SG161222/Realistic_Vision_V6.0_B1_noVAE' 'chihuahua' --steps 2 --no-fp16 --no-safetensor --yoso 1 --count 10 --cfg 1
```
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "sd_tools"
version = "1.2.0"
version = "1.3.0"
authors = [{ name = "Feng Zhou", email = "[email protected]" }]
description = "command line tool for stable diffusion"
license = { file = "LICENSE" }
Expand Down Expand Up @@ -38,3 +38,4 @@ sdxl = "sd_tools.sdxl:main"
sd = "sd_tools.sd:main"
pt2st = "sd_tools.misc:pt2st"
st-inspect = "sd_tools.misc:st_inspect"
sd-to-fp16 = "sd_tools.misc:to_fp16"
12 changes: 12 additions & 0 deletions src/sd_tools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,15 @@ def st_inspect():
for k, v in files.items():
if not v.get('shape'): continue
print(k, reduce(op.mul, v['shape'], 1), 'x'.join(map(str, v['shape'])))

def to_fp16():

parser = ArgumentParser('save model as fp16')
parser.add_argument('model', type=str)
parser.add_argument('output', type=str)
args = parser.parse_args()

from diffusers import AutoPipelineForText2Image
pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.float16)

pipe.save_pretrained(args.output, variant='fp16')
17 changes: 17 additions & 0 deletions src/sd_tools/plugins/dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from argparse import ArgumentParser
from .base import PluginBase

class PluginDPO(PluginBase):

def setup_args(self, parser: ArgumentParser):
parser.add_argument('--dpo', type=float)

def setup_pipe(self):
if not self.ctx.args.dpo:
return

self.ctx.pipe.load_lora_weights(
'radames/sdxl-turbo-DPO-LoRA',
adapter_name="dpo",
)
self.ctx.loras.append(("dpo", self.ctx.args.dpo))
29 changes: 26 additions & 3 deletions src/sd_tools/plugins/http/assets/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@
<label for="size" class="text-sm font-medium text-gray-700"
>size:</label
>
<div class="block flex w-full flex-wrap items-center gap-4">
<div class="flex w-full flex-wrap items-center gap-4">
<input
type="number"
step="8"
name="width"
min="8"
max="3200"
value="512"
class="flex grow rounded-md border border-gray-300 px-3 px-3 py-2 py-2 shadow-sm focus:border-black focus:outline-none sm:text-sm"
class="flex grow rounded-md border border-gray-300 px-3 py-2 shadow-sm focus:border-black focus:outline-none sm:text-sm"
/>
X
<input
Expand All @@ -80,7 +80,7 @@
min="8"
max="3200"
value="512"
class="flex grow rounded-md border border-gray-300 px-3 px-3 py-2 py-2 shadow-sm focus:border-black focus:outline-none sm:text-sm"
class="flex grow rounded-md border border-gray-300 px-3 py-2 shadow-sm focus:border-black focus:outline-none sm:text-sm"
/>
<div class="block">
<span
Expand Down Expand Up @@ -114,6 +114,29 @@
</div>
</div>
</div>
<div>
<label
for="scheduler"
class="block text-sm font-medium text-gray-700"
>scheduler:</label
>
<select name="scheduler">
<option value=""></option>
<option value="DPM++ 2M">DPM++ 2M</option>
<option value="DPM++ 2M Karras">DPM++ 2M Karras</option>
<option value="DPM++ 2M SDE">DPM++ 2M SDE</option>
<option value="DPM++ 2M SDE Karras">
DPM++ 2M SDE Karras
</option>
<option value="DPM++ SDE">DPM++ SDE</option>
<option value="DPM++ SDE Karras">
DPM++ SDE Karras
</option>
<option value="Euler">Euler</option>
<option value="Euler a">Euler a</option>
<option value="UniPC">UniPC</option>
</select>
</div>
<div>
<label
for="prompt"
Expand Down
14 changes: 0 additions & 14 deletions src/sd_tools/plugins/http/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,3 @@ def do_GET(req): self.dispatch(req)
host, port = interface.split(':') if ':' in interface else ('127.0.0.1', interface)
server = HTTPServer((host, int(port)), Handler)
server.serve_forever()

class HttpHandler(BaseHTTPRequestHandler):

handler: List[Tuple[str, str, Callable]] = []

def GET(self, path: str):
def wrapper(handler):
self.handler.append(('GET', path, handler))
return wrapper

def POST(self, path: str):
def wrapper(handler):
self.handler.append(('POST', path, handler))
return wrapper
7 changes: 7 additions & 0 deletions src/sd_tools/plugins/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,10 @@ def setup_pipe(self):

(Scheduler, config) = scheduler_alias[scheduler]
self.ctx.pipe.scheduler = Scheduler.from_config(self.ctx.pipe.scheduler.config, **config)

def pre_pipe(self):
if not 'scheduler' in self.ctx.pipe_opts_otg:
return

(Scheduler, config) = scheduler_alias[self.ctx.pipe_opts_otg.pop('scheduler', None)]
self.ctx.pipe.scheduler = Scheduler.from_config(self.ctx.pipe.scheduler.config, **config)
18 changes: 18 additions & 0 deletions src/sd_tools/plugins/wrong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from argparse import ArgumentParser
from .base import PluginBase

class PluginWrong(PluginBase):

def setup_args(self, parser: ArgumentParser):
parser.add_argument('--wrong', type=float)

def setup_pipe(self):
if not self.ctx.args.wrong:
return

self.ctx.pipe.load_lora_weights(
'minimaxir/sdxl-wrong-lora',
adapter_name="wrong",
)
self.ctx.loras.append(("wrong", self.ctx.args.wrong))
self.ctx.pipe_opts.negative_prompt = 'wrong,' + self.ctx.pipe_opts.negative_prompt
19 changes: 19 additions & 0 deletions src/sd_tools/plugins/yoso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from argparse import ArgumentParser
from diffusers import LCMScheduler
from .base import PluginBase

class PluginYOSO(PluginBase):

def setup_args(self, parser: ArgumentParser):
parser.add_argument('--yoso', type=float)

def setup_pipe(self):
if not self.ctx.args.yoso:
return

self.ctx.pipe.scheduler = LCMScheduler.from_config(self.ctx.pipe.scheduler.config)
self.ctx.pipe.load_lora_weights(
'Luo-Yihong/yoso_sd1.5_lora',
adapter_name="yoso",
)
self.ctx.loras.append(("yoso", self.ctx.args.yoso))
2 changes: 2 additions & 0 deletions src/sd_tools/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .plugins.ipa_faceid_portrait import PluginIPAdaptorFaceIDPortrait
from .plugins.device import PluginDevice
from .plugins.lcm import PluginLCM
from .plugins.yoso import PluginYOSO
from .plugins.offline import PluginOffline
from .plugins.safetensor import PluginSafetensor
from .plugins.http import PluginHTTP
Expand Down Expand Up @@ -55,6 +56,7 @@ def main():
PluginSteps(ctx),
PluginSeed(ctx),
PluginLCM(ctx),
PluginYOSO(ctx),
PluginOutput(ctx),
# PluginCanny(ctx),
# PluginPose(ctx),
Expand Down
4 changes: 4 additions & 0 deletions src/sd_tools/sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from .plugins.cfg import PluginCFG
from .plugins.size import PluginSize
from .plugins.prompt import PluginPrompt
from .plugins.wrong import PluginWrong
from .plugins.modelXL import PluginModelXL
from .plugins.pipe import PluginPipe
from .plugins.inpainting import PluginInpainting
from .plugins.ip_composition_xl import PluginIPCompositionAdapterXL
from .plugins.tcd import PluginTCD
from .plugins.lcm import PluginLCM
from .plugins.dpo import PluginDPO
from .plugins.steps import PluginSteps
from .plugins.output import PluginOutput
from .plugins.lora import PluginLora
Expand Down Expand Up @@ -65,6 +67,8 @@ def main():
PluginPrompt(ctx),
PluginSteps(ctx),
PluginSeed(ctx),
PluginWrong(ctx),
PluginDPO(ctx),
PluginTCD(ctx),
PluginLCM(ctx),
PluginLightning(ctx),
Expand Down
9 changes: 9 additions & 0 deletions tailwind.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/** @type {import('tailwindcss').Config} */
module.exports = {
content: [],
theme: {
extend: {},
},
plugins: [],
}

0 comments on commit 8b8ebc1

Please sign in to comment.