[WIP] Mimi 4-bit quant on transformer and 8-bit on conv #9882

Open · wants to merge 2 commits into base: main
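The "4-bit quant on transformer" half of the title corresponds to the 8da4w mode configured in the new Mimi test below (8-bit dynamic activations, 4-bit group-wise weights, group size 64); the "8-bit on conv" half does not appear to be exercised in this diff yet. As context only, here is a minimal sketch of what 8da4w quantization looks like on a toy module, assuming torchao's Int8DynActInt4WeightQuantizer API; this is an illustration, not code from this PR.

import torch
import torch.nn as nn
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Toy stand-in for the transformer linears; the group size must divide the linear width.
model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))

# 8-bit dynamic activations + 4-bit group-wise weights ("8da4w" in the test config below).
quantizer = Int8DynActInt4WeightQuantizer(groupsize=64)
model = quantizer.quantize(model)

out = model(torch.randn(1, 512))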
24 changes: 5 additions & 19 deletions examples/models/llama/export_llama_lib.py
@@ -588,25 +588,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
)

# At this point, the model is loaded in the default fp32.

# Checkpoint dtype should be lower or equal precision to the dtype override.
# override dtype
checkpoint_dtype = edge_manager.model.checkpoint_dtype
if not (
checkpoint_dtype == dtype_override.to_torch_dtype()
or (
checkpoint_dtype == torch.float16
and dtype_override.to_torch_dtype() == torch.float32
)
or (
checkpoint_dtype == torch.bfloat16
and dtype_override.to_torch_dtype() == torch.float32
)
):
logging.warning(
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
)

edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
edge_manager.to_dtype(dtype_override)

# We want to quantize (in the source transforms) the weights of the model
# in the checkpoint dtype.
@@ -1076,7 +1060,9 @@ def _load_llama_model(
model.vocab_size,
metadata_str,
),
args=args,
qnn=args.qnn,
export_only=args.export_only,
output_name=args.output_name,
)


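The call-site change above (and the matching one in export_llava.py below) narrows the blanket args namespace down to explicit keyword arguments (qnn, export_only, output_name). A toy sketch of that pattern, with hypothetical names, purely to make the intent explicit:

import argparse

# Pass only the fields a helper needs instead of threading the whole argparse Namespace through.
def export_model(qnn: bool, export_only: bool, output_name: str) -> str:
    backend = "qnn" if qnn else "xnnpack"
    suffix = " (export only)" if export_only else ""
    return f"{output_name}.pte via {backend}{suffix}"

args = argparse.Namespace(qnn=False, export_only=False, output_name="output")
print(export_model(qnn=args.qnn, export_only=args.export_only, output_name=args.output_name))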
4 changes: 3 additions & 1 deletion examples/models/llava/export_llava.py
@@ -92,7 +92,9 @@ def forward(self, input_pos, embeddings):
use_kv_cache=True,
example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
dynamic_shapes=dynamic_shapes,
args=llava.text_model_args,
qnn=llava.text_model_args.qnn,
export_only=llava.text_model_args.export_only,
output_name=llava.text_model_args.output_name,
)

dtype_override = DType.fp32
135 changes: 134 additions & 1 deletion examples/models/moshi/mimi/test_mimi.py
@@ -2,6 +2,8 @@
import os
import random
import unittest
from functools import partial
from executorch.examples.models.llama.source_transformation.quantize import quantize

import numpy as np
import requests
@@ -22,13 +24,23 @@
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export, ExportedProgram
from torch.utils._pytree import tree_flatten
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from omegaconf import OmegaConf
from pathlib import Path
from executorch.examples.models.llama.export_llama_lib import (
_get_source_transforms,
get_quantizer_and_quant_params,
)
from executorch.extension.llm.export.partitioner_lib import (
get_xnnpack_partitioner,
)
import logging

proxies = {
"http": "http://fwdproxy:8080",
"https": "http://fwdproxy:8080",
}


def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
assert x.shape == y.shape, "Tensor shapes do not match"
x = x.float()
@@ -219,6 +231,127 @@ def forward(self, x):
print(f"SQNR: {sqnr}")
torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)

def test_exported_decoder_hybrid_quant(self):
class MimiDecode(nn.Module):
def __init__(self, mimi: nn.Module):
super().__init__()
self.mimi_model = mimi

def forward(self, x):
return self.mimi_model.decode(x)

sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
chunk = sample_pcm[..., 0:pcm_chunk_size]
input = self.mimi.encode(chunk)

mimi_decode = MimiDecode(self.mimi)
eager_output = mimi_decode(input)

config_dict = {
"model": "llama3",
"checkpoint": None,
"max_seq_length": 128,
"dtype_override": "fp32",
"use_kv_cache": False,
"generate_full_logits": False,
"enable_dynamic_shape": False,
"verbose": True,
"qnn": False,
"export_only": False,
"output_name": "output",
"output_dir": "/tmp",
"xnnpack_extended_ops": True,
"quantization": {
"mode": "8da4w",
"group_size": 64,
}
}

# Create a DictConfig object from the dictionary
config = OmegaConf.create(config_dict)

edge_manager = LLMEdgeManager(
model=mimi_decode,
modelname=config.model,
max_seq_len=config.max_seq_length,
dtype=config.dtype_override,
use_kv_cache=config.use_kv_cache,
generate_full_logits=config.generate_full_logits,
example_inputs=(input,),
example_kwarg_inputs=None,
dynamic_shapes=None,
enable_dynamic_shape=config.enable_dynamic_shape,
verbose=config.verbose,
)

dtype_override = DType[config.dtype_override]
edge_manager.to_dtype(dtype_override)

transforms = []

# Linear 8da4w.
# TODO: look into decode_latent as an "embedding" layer
if config.quantization.mode:
quant_args = {
"group_size": config.quantization.group_size,
}
transforms.append(
partial(
quantize,
**quant_args,
qmode=config.quantization.mode,
computation_dtype=dtype_override,
checkpoint_dtype=None,
checkpoint_path=(Path(path) if (path := config.checkpoint) is not None else None),
)
)

llm_manager = edge_manager.source_transform(transforms)
# llm_manager = edge_manager
builder_exported = llm_manager.export()
builder_exported.run_canonical_optimizations()


# Lower to xnnpack
partitioners = []

# Order matters here: dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))


if config.xnnpack_extended_ops:
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
)

for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

# builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
builder = builder_exported.to_edge_transform_and_lower(
partitioners
)
print_delegation_info(builder.edge_manager.exported_program().graph_module)

builder = builder.to_executorch()

builder.save_to_pte("mimi_4bit")

#
#
# llm_manager = edge_manager.source_transform(
# _get_source_transforms(
# modelname=config.model, dtype_override=dtype_override, args=args
# )
# )
# builder_exported = edge_manager.set_output_dir(args.output_dir).export()
# builder_exported.run_canonical_optimizations()
#
# pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
# args
# )


if __name__ == "__main__":
unittest.main()
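The new test can presumably be run with, e.g., python -m pytest examples/models/moshi/mimi/test_mimi.py -k test_exported_decoder_hybrid_quant, assuming pytest and the Moshi dependencies are installed. Once it has produced mimi_4bit.pte, the artifact can be smoke-tested from Python. A hedged sketch, assuming the ExecuTorch portable_lib pybindings are built; the input shape below is a placeholder, not taken from this PR, and should come from a real mimi.encode() call as in the test:

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Path produced by builder.save_to_pte("mimi_4bit") above; adjust if saved elsewhere.
module = _load_for_executorch("mimi_4bit.pte")

# Placeholder code tensor; use a real mimi.encode() output matching the export-time example input.
codes = torch.zeros((1, 8, 1), dtype=torch.int64)
pcm_out = module.forward([codes])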
20 changes: 19 additions & 1 deletion extension/llm/export/builder.py
@@ -147,7 +147,25 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
assert not dtype_override or isinstance(
dtype_override, DType
), "Override dtype needs to be of type <DType>"
if dtype_override is not None and dtype_override != self.dtype:

# Checkpoint dtype should be of equal or lower precision than the dtype override.
# Not every wrapped model sets checkpoint_dtype (e.g. the Mimi decoder in the new test), so fall back to None.
checkpoint_dtype = getattr(self.model, "checkpoint_dtype", None)
if dtype_override is not None and checkpoint_dtype is not None and not (
checkpoint_dtype == dtype_override.to_torch_dtype()
or (
checkpoint_dtype == torch.float16
and dtype_override.to_torch_dtype() == torch.float32
)
or (
checkpoint_dtype == torch.bfloat16
and dtype_override.to_torch_dtype() == torch.float32
)
):
logging.warning(
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
)

if dtype_override is not None and dtype_override != self.dtype:
torch_dtype = dtype_override.to_torch_dtype()
logging.info(f"model.to {torch_dtype}")
self.model = self.model.to(dtype=torch_dtype)
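For reference, the precision rule the new block enforces can be read as a small predicate: the override must keep the checkpoint dtype or up-cast fp16/bf16 to fp32, otherwise the warning fires. A standalone illustration with a hypothetical helper name, not part of the diff:

import torch

# Promotions the check treats as safe: keep the dtype, or up-cast fp16/bf16 to fp32.
_SAFE_PROMOTIONS = {
    (torch.float16, torch.float32),
    (torch.bfloat16, torch.float32),
}

def is_safe_dtype_override(checkpoint_dtype: torch.dtype, override_dtype: torch.dtype) -> bool:
    """True when the override keeps or raises precision relative to the checkpoint."""
    return checkpoint_dtype == override_dtype or (checkpoint_dtype, override_dtype) in _SAFE_PROMOTIONS

assert is_safe_dtype_override(torch.bfloat16, torch.float32)      # up-cast: fine
assert not is_safe_dtype_override(torch.float32, torch.float16)   # down-cast: warning fires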