
Commit 3f8e82b

trying to add quantization to Flux
1 parent 9e390da commit 3f8e82b

File tree

4 files changed (+171, -6 lines)


Diff for: examples/apps/flux-quantization.py

+156
@@ -0,0 +1,156 @@
# %%
# Import the following libraries
# -----------------------------
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export
from transformers import AutoModelForCausalLM

# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
# model = AutoModelForCausalLM.from_pretrained("/home/other/quantization/quantized_flux.pt")

# %%
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float32,
)
pipe.to(DEVICE).to(torch.float32)
# Store the config and transformer backbone
config = pipe.transformer.config
# global backbone
backbone = pipe.transformer
backbone.eval()

def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# %%
# Quantization


def do_calibrate(
    pipe,
    prompt: str,
) -> None:
    """
    Run calibration steps on the pipeline using the given prompt.
    """
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]


def forward_loop(mod):
    # Switch the pipeline's backbone, run calibration
    pipe.transformer = mod
    do_calibrate(
        pipe=pipe,
        prompt="test",
    )


# mto.restore(backbone, "/home/other/quantization/quantized_flux.pt")
ptq_config = mtq.INT8_DEFAULT_CFG
backbone = mtq.quantize(backbone, ptq_config, forward_loop)
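
# Not exercised in this script, but the calibrated quantizer state could be saved at this point
# so that the ``mto.restore`` call commented out above has something to load on a later run.
# This is a sketch: ``mto.save`` is ModelOpt's state-saving helper, and the path is just the
# placeholder path reused from the restore line.
# mto.save(backbone, "/home/other/quantization/quantized_flux.pt")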

# %%
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes and a ``batch_size=2``
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_

batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These particular min/max values for the img_id input are recommended by torch dynamo during export of the model.
# To see this recommendation, you can try exporting with min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}

# This will create an exported program which is going to be compiled with Torch-TensorRT
with export_torch_mode():
    ep = _export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        # dynamic_shapes=dynamic_shapes,
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )
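
# Optional sanity check (a sketch, not required by the flow above): printing the
# ExportedProgram shows whether the ``tensorrt.quantize_op`` calls recorded during
# quantization appear in the exported graph before compilation.
# print(ep)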

trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.int8},
    truncate_double=True,
    min_block_size=1,
    # use_fp32_acc=True,
    # use_explicit_typing=True,
    debug=False,
    use_python_runtime=True,
    immutable_weights=True,
    offload_module_to_cpu=False,
)

del ep
pipe.transformer = trt_gm
pipe.transformer.config = config


# %%
trt_gm.device = torch.device(DEVICE)
# Generate images from the pipeline, which now runs the TensorRT-compiled backbone
for _ in range(2):
    generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
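
A quick way to inspect the calibration result before the export step (a sketch using the ``mtq`` and ``backbone`` objects from the script above; ``print_quant_summary`` is ModelOpt's quantizer-summary helper and is not used in this commit):

mtq.print_quant_summary(backbone)  # prints each quantizer ModelOpt inserted, including its calibrated state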

Diff for: py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+3-1
@@ -597,7 +597,9 @@ def aten_ops_neg(
     )
 else:

-    @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True
+    )
     def aten_ops_quantize_op(
         ctx: ConversionContext,
         target: Target,

Diff for: py/torch_tensorrt/dynamo/conversion/impl/quantize.py

+7-4
@@ -48,16 +48,19 @@ def quantize(
     scale = torch.divide(amax, max_bound)
     scale = get_trt_tensor(ctx, scale, name + "_scale")
     # Add Q node
-    quantize_layer = ctx.net.add_quantize(input_tensor, scale)
     if num_bits == 8 and exponent_bits == 0:
-        quantize_layer.set_output_type(0, trt.DataType.INT8)
+        dtype = trt.DataType.INT8
     elif num_bits == 8 and exponent_bits == 4:
-        quantize_layer.set_output_type(0, trt.DataType.FP8)
+        dtype = trt.DataType.FP8
+
+    quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)

     set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
     q_output = quantize_layer.get_output(0)
     # Add DQ node
-    dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+    dequantize_layer = ctx.net.add_dequantize(
+        q_output, scale, output_type=input_tensor.dtype
+    )
     set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
     if num_bits == 8 and exponent_bits == 0:
         dequantize_layer.precision = trt.DataType.INT8

Diff for: py/torch_tensorrt/dynamo/utils.py

+5-1
@@ -419,7 +419,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
     """
     Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64
     """
-    if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
+    if isinstance(tensor, (torch.Tensor, FakeTensor)):
+        return tensor.dtype
+    elif isinstance(tensor, (int, float, bool)):
         return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
@@ -791,6 +793,8 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
                     output_dtypes.append(dtype.float32)
                 else:
                     output_dtypes.append(dtype._from(output_meta.dtype))
+            elif isinstance(output_meta, torch.SymInt):
+                output_dtypes.append(dtype.int64)
         elif "tensor_meta" in output.meta:
             output_meta = output.meta["tensor_meta"]
             output_dtypes.append(dtype._from(output_meta.dtype))
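
A minimal sketch of the behavior the first hunk targets (``unwrap_tensor_dtype`` and its module path are taken from the file above): real tensors and FakeTensors now report their dtype directly instead of being round-tripped through ``torch.tensor``, while plain Python scalars keep the old path.

import torch
from torch_tensorrt.dynamo.utils import unwrap_tensor_dtype

print(unwrap_tensor_dtype(torch.ones(2, dtype=torch.float16)))  # torch.float16, read off the tensor itself
print(unwrap_tensor_dtype(3))                                   # torch.int64, via torch.tensor(3).dtype
print(unwrap_tensor_dtype(True))                                # torch.bool

The second hunk applies the same idea to graph outputs: a ``torch.SymInt`` output is now mapped explicitly to ``dtype.int64``.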
