
Commit 81c03d9 (parent: aa3d84d)

[major] add Nunchaku convert scripts; update SVDQuant evaluation results; update SpinQuant results


93 files changed (+100516 −98164 lines)

README.md (+31 −22)
@@ -17,7 +17,9 @@
 </p>

 ## News
-- [2025/01] 🎉 [**SVDQuant**](https://arxiv.org/abs/2411.05007) has been accepted to ICLR 2025!
+- [2025/02] 🎉 [**QServe**](https://arxiv.org/abs/2405.04532) has been accepted to MLSys 2025!
+- [2025/01] 🎉 [**SVDQuant**](https://arxiv.org/abs/2411.05007) has been accepted to ICLR 2025 (Spotlight)!
+- [2024/12] 🎉 [**QServe**](https://github.com/mit-han-lab/qserve) has been integrated into NVIDIA [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama)!
 - [2024/11] 🔥 Our latest **W4A4** diffusion model quantization work, the [**SVDQuant**](https://arxiv.org/abs/2411.05007) algorithm and the [**Nunchaku**](https://github.com/mit-han-lab/nunchaku) system, is publicly released! Check our [paper](http://arxiv.org/abs/2411.05007)!
 - [2024/05] 🔥 Our latest **W4A8KV4** LLM quantization work, the **QoQ** algorithm and the **QServe** system, is publicly released! **QoQ** is short for *quattuor-octō-quattuor*, which is 4-8-4 in Latin. Check our [paper](https://arxiv.org/abs/2405.04532)!

@@ -72,24 +74,30 @@ Diffusion models have been proven highly effective at generating high-quality images

 Below is the quality and similarity evaluated with 5000 samples from the MJHQ-30K dataset. IR means ImageReward. Our 4-bit results outperform other 4-bit baselines, effectively preserving the visual quality of 16-bit models.

-| Model | Precision | Method | FID ($\downarrow$) | IR ($\uparrow$) | LPIPS ($\downarrow$) | PSNR ($\uparrow$) |
-|----------------------------|-----------|---------|--------------------|-----------------|----------------------|-------------------|
-| FLUX.1-dev (50 Steps) | BF16 | -- | 20.3 | 0.953 | -- | -- |
-| | INT W8A8 | Ours | 20.4 | 0.948 | 0.089 | 27.0 |
-| | W4A16 | NF4 | 20.6 | 0.910 | 0.272 | 19.5 |
-| | INT W4A4 | Ours | **19.86** | 0.932 | 0.254 | 20.1 |
-| | FP W4A4 | Ours | 21.0 | **0.933** | **0.247** | **20.2** |
-| FLUX.1-schnell (4 Steps) | BF16 | -- | 19.2 | 0.938 | -- | -- |
-| | INT W8A8 | Ours | 19.2 | 0.966 | 0.120 | 22.9 |
-| | W4A16 | NF4 | 18.9 | 0.943 | 0.257 | 18.2 |
-| | INT W4A4 | Ours | **18.4** | **0.969** | 0.292 | 17.5 |
-| | FP W4A4 | Ours | 19.9 | 0.956 | 0.279 | 17.5 |
-| | FP16 | -- | 16.6 | 0.944 | -- | -- |
-| PixArt-Sigma (20 Steps) | INT W8A8 | ViDiT-Q | 15.7 | 0.944 | 0.137 | 22.5 |
-| | INT W8A8 | Ours | 16.3 | **0.955** | **0.109** | **23.7** |
-| | INT W4A8 | ViDiT-Q | 37.3 | 0.573 | 0.611 | 12.0 |
-| | INT W4A4 | Ours | 20.1 | 0.898 | 0.394 | 16.2 |
-| | FP W4A4 | Ours | **18.3** | **0.946** | **0.326** | **17.4** |
+| Model | Precision | Method | FID ($\downarrow$) | IR ($\uparrow$) | LPIPS ($\downarrow$) | PSNR ($\uparrow$) |
+|----------------------------|-----------|-----------|--------------------|-----------------|----------------------|-------------------|
+| FLUX.1-dev (50 Steps) | BF16 | -- | 20.3 | 0.953 | -- | -- |
+| | W4A16 | NF4 | 20.6 | 0.910 | 0.272 | 19.5 |
+| | INT W4A4 | | 20.2 | 0.908 | 0.322 | 18.5 |
+| | INT W4A4 | Ours | 19.9 | 0.935 | 0.223 | 21.0 |
+| | NVFP4 | | 20.3 | 0.961 | 0.345 | 16.3 |
+| | NVFP4 | Ours | 20.3 | 0.942 | 0.205 | 21.5 |
+| FLUX.1-schnell (4 Steps) | BF16 | -- | 19.2 | 0.938 | -- | -- |
+| | W4A16 | NF4 | 18.9 | 0.943 | 0.257 | 18.2 |
+| | INT W4A4 | | 18.1 | 0.962 | 0.345 | 16.3 |
+| | INT W4A4 | Ours | 18.3 | 0.951 | 0.257 | 18.3 |
+| | NVFP4 | | 19.0 | 0.952 | 0.276 | 17.6 |
+| | NVFP4 | Ours | 18.9 | 0.964 | 0.229 | 19.0 |
+| SANA-1.6b (20 Steps) | BF16 | -- | 20.6 | 0.952 | -- | -- |
+| | INT W4A4 | | 20.5 | 0.894 | 0.339 | 15.3 |
+| | INT W4A4 | Ours | 19.3 | 0.935 | 0.220 | 17.8 |
+| | NVFP4 | | 19.7 | 0.929 | 0.236 | 17.4 |
+| | NVFP4 | Ours | 20.2 | 0.941 | 0.176 | 19.0 |
+| PixArt-Sigma (20 Steps) | FP16 | -- | 16.6 | 0.944 | -- | -- |
+| | INT W4A8 | ViDiT-Q | 37.3 | 0.573 | 0.611 | 12.0 |
+| | INT W4A4 | Ours | 19.2 | 0.878 | 0.323 | 17.6 |
+| | NVFP4 | | 31.8 | 0.660 | 0.517 | 14.8 |
+| | NVFP4 | Ours | 16.6 | 0.940 | 0.271 | 18.5 |

 ### QServe: W4A8KV4 Quantization for Efficient LLM Serving
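For readers reproducing the similarity columns: LPIPS and PSNR compare each 4-bit sample against the matching 16-bit reference image. Below is a minimal, hypothetical sketch of that comparison using the `lpips` package; the paths and loader are assumptions, not the repository's actual evaluation script.

```python
# Hypothetical sketch: LPIPS and PSNR between a 16-bit reference image and
# the matching 4-bit sample. Paths and the `lpips` package are assumptions.
import lpips  # pip install lpips
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

lpips_fn = lpips.LPIPS(net="alex")  # LPIPS expects inputs scaled to [-1, 1]

def load(path: str) -> torch.Tensor:
    # [1, 3, H, W] float tensor in [0, 1]
    return to_tensor(Image.open(path).convert("RGB")).unsqueeze(0)

def similarity(ref_path: str, quant_path: str) -> tuple[float, float]:
    ref, quant = load(ref_path), load(quant_path)
    mse = torch.mean((ref - quant) ** 2)
    psnr = (10 * torch.log10(1.0 / mse)).item()  # peak signal is 1.0 here
    dist = lpips_fn(ref * 2 - 1, quant * 2 - 1).item()
    return psnr, dist

print(similarity("bf16/0001.png", "w4a4/0001.png"))  # illustrative paths
```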

@@ -111,10 +119,11 @@ Below is the WikiText2 perplexity evaluated with 2048 sequence length. The lower
 | SmoothQuant | W8A8 | 3.23 | 6.38 | 3.14 | 6.28 | 5.54 | 4.95 | 3.36 | 5.73 | 5.13 | 4.23 | 5.29 | 4.69 |
 | GPTQ-R | W4A16 g128 | 3.46 | 6.64 | 3.42 | 6.56 | 5.63 | 4.99 | 3.43 | 5.83 | 5.20 | 4.22 | 5.39 | 4.68 |
 | AWQ | W4A16 g128 | 3.22 | 6.60 | 3.20 | 6.54 | 5.60 | 4.97 | 3.41 | 5.78 | 5.19 | 4.21 | 5.37 | 4.67 |
-| QuaRot | W4A4 | 5.97 | 8.32 | 6.75 | 8.33 | 6.19 | 5.45 | 3.83 | 6.34 | 5.58 | 4.64 | 5.77 | NaN |
+| QuaRot | W4A4 | 5.97 | 8.32 | 6.75 | 8.33 | 6.19 | 5.45 | 3.83 | 6.34 | 5.58 | 4.64 | 5.77 | - |
+| SpinQuant | W4A4 | 4.80 | 7.42 | 6.27 | 7.37 | 5.96 | 5.24 | 3.71 | 6.14 | 5.39 | 4.56 | - | - |
 | Atom | W4A4 g128 | - | - | 4.33 | 7.78 | 6.12 | 5.31 | 3.73 | 6.25 | 5.52 | 4.61 | 5.76 | 4.97 |
-| QoQ | W4A8KV4 | 3.69 | 6.91 | 3.65 | 6.84 | 5.75 | 5.11 | 3.51 | 5.92 | 5.27 | 4.32 | 5.45 | 4.73 |
-| QoQ | W4A8KV4 g128 | 3.54 | 6.80 | 3.51 | 6.73 | 5.68 | 5.05 | 3.46 | 5.88 | 5.23 | 4.27 | 5.41 | 4.73 |
+| QoQ | W4A8KV4 | 3.68 | 6.87 | 3.65 | 6.81 | 5.75 | 5.11 | 3.50 | 5.92 | 5.27 | 4.31 | 5.44 | 4.73 |
+| QoQ | W4A8KV4 g128 | 3.51 | 6.77 | 3.50 | 6.70 | 5.67 | 5.06 | 3.46 | 5.88 | 5.23 | 4.27 | 5.41 | 4.73 |

 \* SmoothQuant is evaluated with per-tensor static KV cache quantization.
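The perplexity protocol above is the standard fixed-length evaluation: concatenate the WikiText2 test split, slice it into 2048-token windows, and exponentiate the average per-token negative log-likelihood. A minimal sketch, assuming a Hugging Face causal LM (the checkpoint name is illustrative, not the repository's harness):

```python
# Minimal WikiText2 perplexity sketch: non-overlapping 2048-token windows,
# mean token NLL exponentiated. The checkpoint name is illustrative only.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Llama-2-7b-hf"  # assumption: any causal LM checkpoint
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16).cuda().eval()

text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
ids = tok(text, return_tensors="pt").input_ids
seqlen = 2048

nlls = []
for i in range(ids.shape[1] // seqlen):
    chunk = ids[:, i * seqlen : (i + 1) * seqlen].cuda()
    with torch.no_grad():
        loss = model(chunk, labels=chunk).loss  # mean NLL over shifted tokens
    nlls.append(loss.float() * seqlen)
ppl = torch.exp(torch.stack(nlls).sum() / (len(nlls) * seqlen))
print(f"WikiText2 perplexity (2048-token context): {ppl.item():.2f}")
```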

deepcompressor/app/diffusion/dataset/calib.py (+11 −10)
@@ -11,11 +11,19 @@
 import torch.utils.data
 from diffusers.models.attention import JointTransformerBlock
 from diffusers.models.attention_processor import Attention
-from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+from diffusers.models.transformers.transformer_flux import (
+    FluxSingleTransformerBlock,
+    FluxTransformerBlock,
+)
 from omniconfig import configclass

-from deepcompressor.data.cache import IOTensorsCache, ModuleForwardInput, TensorCache, TensorsCache
-from deepcompressor.data.utils.reshape import AttentionInputReshapeFn, LinearReshapeFn, ReshapeFn
+from deepcompressor.data.cache import (
+    IOTensorsCache,
+    ModuleForwardInput,
+    TensorCache,
+    TensorsCache,
+)
+from deepcompressor.data.utils.reshape import AttentionInputReshapeFn, LinearReshapeFn
 from deepcompressor.dataset.action import CacheAction, ConcatCacheAction
 from deepcompressor.dataset.cache import BaseCalibCacheLoader
 from deepcompressor.dataset.config import BaseDataLoaderConfig
@@ -113,9 +121,6 @@ def info(
             encoder_hidden_states_cache.reshape = AttentionInputReshapeFn(encoder_channels_dim)
         else:
             assert encoder_hidden_states_cache.channels_dim == encoder_channels_dim
-        if tensors["image_rotary_emb"] is None:
-            tensors.pop("image_rotary_emb")
-            cache.tensors.pop("image_rotary_emb")
         hidden_states, hidden_states_cache = tensors["hidden_states"], cache.tensors["hidden_states"]
         channels_dim = 1 if hidden_states.dim() == 4 else -1
         if hidden_states_cache.channels_dim is None:
@@ -163,7 +168,6 @@ def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
                 OrderedDict(
                     hidden_states=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
                     temb=TensorCache(channels_dim=1, reshape=LinearReshapeFn()),
-                    image_rotary_emb=TensorCache(channels_dim=1, reshape=ReshapeFn()),
                 )
             ),
             outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
@@ -174,7 +178,6 @@ def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
                 OrderedDict(
                     hidden_states=TensorCache(channels_dim=None, reshape=None),
                     encoder_hidden_states=TensorCache(channels_dim=None, reshape=None),
-                    image_rotary_emb=TensorCache(channels_dim=1, reshape=ReshapeFn()),
                 ),
             ),
             outputs=TensorCache(channels_dim=None, reshape=None),
@@ -211,7 +214,6 @@ def _convert_layer_inputs(
         kwargs = {k: v for k, v in kwargs.items()}  # noqa: C416
         if "res_hidden_states_tuple" in kwargs:
             kwargs["res_hidden_states_tuple"] = None
-            # tree_map(lambda x: x.detach().cpu(), kwargs["res_hidden_states_tuple"])
         if "hidden_states" in kwargs:
             hidden_states = kwargs.pop("hidden_states")
             assert len(args) == 0, f"Invalid args: {args}"
@@ -333,7 +335,6 @@ def iter_layer_activations(  # noqa: C901
             layer_kwargs.pop("hidden_states", None)
             layer_kwargs.pop("encoder_hidden_states", None)
             layer_kwargs.pop("temb", None)
-            layer_kwargs.pop("image_rotary_emb", None)
             layer_struct = layer_structs[layer_idx]
             if isinstance(layer_struct, DiffusionBlockStruct):
                 assert layer_struct.name == layer_name
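Context for the `calib.py` changes: each `TensorCache` entry records which dimension of a cached tensor holds channels and how to flatten it into calibration samples, and this commit drops the `image_rotary_emb` entries from those caches. As a rough illustration of the underlying pattern, hook-based activation caching, here is a self-contained sketch (not the repository's actual `TensorCache`/`IOTensorsCache` classes):

```python
# Rough illustration of hook-based activation caching for calibration.
# deepcompressor's TensorCache/IOTensorsCache are richer; this is a sketch.
import torch
from torch import nn

cached: list[torch.Tensor] = []

def cache_inputs(module: nn.Module, args: tuple) -> None:
    x = args[0].detach()
    # flatten to [num_samples, channels]; channels sit on the last dim here,
    # which is what channels_dim=-1 with a linear-style reshape expresses
    cached.append(x.reshape(-1, x.shape[-1]).cpu())

layer = nn.Linear(64, 64)
handle = layer.register_forward_pre_hook(cache_inputs)
layer(torch.randn(2, 16, 64))
handle.remove()
print(cached[0].shape)  # torch.Size([32, 64])
```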

deepcompressor/app/diffusion/dataset/collect/calib.py (+52 −28)
@@ -2,18 +2,19 @@
 """Collect calibration dataset."""

 import os
-import random
 from dataclasses import dataclass

+import datasets
 import torch
-import torch.nn as nn
-import yaml
 from omniconfig import configclass
-from tqdm import trange
+from torch import nn
+from tqdm import tqdm

 from deepcompressor.app.diffusion.config import DiffusionPtqRunConfig
 from deepcompressor.utils.common import hash_str_to_int, tree_map

+from ...utils import get_control
+from ..data import get_dataset
 from .utils import CollectHook


@@ -22,7 +23,7 @@ def process(x: torch.Tensor) -> torch.Tensor:
     return torch.from_numpy(x.float().numpy()).to(dtype)


-def collect(config: DiffusionPtqRunConfig, filenames: list[str], dataset: dict[str, str]):
+def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
     samples_dirpath = os.path.join(config.output.root, "samples")
     caches_dirpath = os.path.join(config.output.root, "caches")
     os.makedirs(samples_dirpath, exist_ok=True)
@@ -35,25 +36,48 @@ def collect(config: DiffusionPtqRunConfig, filenames: list[str], dataset: dict[s
     model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)

     batch_size = config.eval.batch_size
-    print(f"In total {len(filenames)} samples")
+    print(f"In total {len(dataset)} samples")
     print(f"Evaluating with batch size {batch_size}")
     pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
-    num_batches = (len(filenames) + batch_size - 1) // batch_size
-    for i in trange(num_batches, desc="Images", leave=False, dynamic_ncols=True, position=0):
-        batch = filenames[i * batch_size : (i + 1) * batch_size]
-        prompts = [dataset[name] for name in batch]
-        seeds = [hash_str_to_int(name) for name in batch]
+    for batch in tqdm(
+        dataset.iter(batch_size=batch_size, drop_last_batch=False),
+        desc="Data",
+        leave=False,
+        dynamic_ncols=True,
+        total=(len(dataset) + batch_size - 1) // batch_size,
+    ):
+        filenames = batch["filename"]
+        prompts = batch["prompt"]
+        seeds = [hash_str_to_int(name) for name in filenames]
         generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
-        images = pipeline(prompts, generator=generators, **config.eval.get_pipeline_kwargs()).images
-        if len(caches) == batch_size * config.eval.num_steps:
-            num_guidances = 1
-        elif len(caches) == 2 * batch_size * config.eval.num_steps:
-            num_guidances = 2
-        else:
-            raise ValueError(f"Unexpected number of caches: {len(caches)} != {batch_size} * {config.eval.num_steps}")
-        for j, (filename, image) in enumerate(zip(batch, images, strict=True)):
+        pipeline_kwargs = config.eval.get_pipeline_kwargs()
+
+        task = config.pipeline.task
+        control_root = config.eval.control_root
+        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
+            controls = get_control(
+                task,
+                batch["image"],
+                names=batch["filename"],
+                data_root=os.path.join(
+                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
+                ),
+            )
+            if task == "inpainting":
+                pipeline_kwargs["image"] = controls[0]
+                pipeline_kwargs["mask_image"] = controls[1]
+            else:
+                pipeline_kwargs["control_image"] = controls
+
+        result_images = pipeline(prompts, generator=generators, **pipeline_kwargs).images
+        num_guidances = (len(caches) // batch_size) // config.eval.num_steps
+        num_steps = len(caches) // (batch_size * num_guidances)
+        assert (
+            len(caches) == batch_size * num_steps * num_guidances
+        ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {num_steps} * {num_guidances}"
+        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
             image.save(os.path.join(samples_dirpath, f"{filename}.png"))
-            for s in range(config.eval.num_steps):
+            for s in range(num_steps):
                 for g in range(num_guidances):
                     c = caches[s * batch_size * num_guidances + g * batch_size + j]
                     c["filename"] = filename
@@ -82,7 +106,7 @@ class CollectConfig:

     root: str = "datasets"
     dataset_name: str = "qdiff"
-    prompt_path: str = "prompts/qdiff.yaml"
+    data_path: str = "prompts/qdiff.yaml"
     num_samples: int = 128


@@ -109,13 +133,13 @@
     )
     print(f"Saving caches to {collect_dirpath}")

-    dataset = yaml.safe_load(open(collect_config.prompt_path, "r"))
-    filenames = list(dataset.keys())
-    if collect_config.num_samples > 0:
-        random.Random(0).shuffle(filenames)
-        filenames = filenames[: collect_config.num_samples]
-    filenames = sorted(filenames)
+    dataset = get_dataset(
+        collect_config.data_path,
+        max_dataset_size=collect_config.num_samples,
+        return_gt=ptq_config.pipeline.task in ["canny-to-image"],
+        repeat=1,
+    )

     ptq_config.output.root = collect_dirpath
     os.makedirs(ptq_config.output.root, exist_ok=True)
-    collect(ptq_config, filenames=filenames, dataset=dataset)
+    collect(ptq_config, dataset=dataset)
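The rewritten bookkeeping in `collect()` no longer hard-codes one or two guidance branches: the forward hook fires once per model call, so `len(caches) == batch_size * num_steps * num_guidances`, and both factors can be inferred after sampling. For example, batch size 4 with 50 steps under classifier-free guidance (two forwards per step) yields 400 caches, giving `num_guidances = (400 // 4) // 50 = 2`. A small self-contained check of the step-major cache indexing used in the loop:

```python
# Sanity check of the cache layout assumed by the loop above: caches are
# ordered step-major, then guidance branch, then sample within the batch.
batch_size, num_steps, num_guidances = 4, 50, 2
num_caches = batch_size * num_steps * num_guidances  # 400 per batch

def cache_index(step: int, guidance: int, sample: int) -> int:
    return step * batch_size * num_guidances + guidance * batch_size + sample

seen = {
    cache_index(s, g, j)
    for s in range(num_steps)
    for g in range(num_guidances)
    for j in range(batch_size)
}
assert seen == set(range(num_caches))  # every cache is visited exactly once
```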

deepcompressor/app/diffusion/dataset/collect/utils.py (+6 −3)
@@ -6,8 +6,11 @@

 import torch
 import torch.nn as nn
-from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from diffusers.models.transformers import (
+    FluxTransformer2DModel,
+    PixArtTransformer2DModel,
+    SanaTransformer2DModel,
+)
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

 from deepcompressor.utils.common import tree_map, tree_split
@@ -51,7 +54,7 @@ def __call__(
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
             timesteps = timesteps.expand(sample.shape[0])
             input_kwargs["timestep"] = timesteps
-        elif isinstance(module, PixArtTransformer2DModel):
+        elif isinstance(module, (PixArtTransformer2DModel, SanaTransformer2DModel)):
             new_args.append(input_kwargs.pop("hidden_states"))
         elif isinstance(module, FluxTransformer2DModel):
             new_args.append(input_kwargs.pop("hidden_states"))
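`CollectHook` is registered with `with_kwargs=True`, so it receives both the positional and keyword inputs of every model forward. The sketch below illustrates that hook signature with a toy recorder; it is not the actual `CollectHook`, which additionally normalizes inputs per model class as this diff shows.

```python
# Stripped-down forward hook in the style of CollectHook: with_kwargs=True
# delivers both positional and keyword inputs of each forward call.
import torch
from torch import nn

class RecordingHook:
    def __init__(self, caches: list[dict]):
        self.caches = caches

    def __call__(self, module: nn.Module, args: tuple, kwargs: dict, output) -> None:
        record = {f"arg_{i}": a for i, a in enumerate(args)}
        record.update(kwargs)
        # detach and move tensors to CPU so caches survive many sampling steps
        self.caches.append(
            {k: v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in record.items()}
        )

caches: list[dict] = []
layer = nn.Linear(8, 8)
handle = layer.register_forward_hook(RecordingHook(caches), with_kwargs=True)
layer(torch.randn(1, 8))
handle.remove()
print(list(caches[0].keys()))  # ['arg_0']
```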
