From 878a7f6c99ef5c8d2f8c0af8cf023a5162ef1bc8 Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Mon, 24 Mar 2025 07:27:02 -0700
Subject: [PATCH 1/2] Refactor LLMEdgeManager's to_dtype

---
 examples/models/llama/export_llama_lib.py | 20 ++------------------
 extension/llm/export/builder.py           | 20 +++++++++++++++++++-
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 37a4e6952d8..ce1eff93b70 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -588,25 +588,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     # At this point, the model is loaded in the default fp32.
-
-    # Checkpoint dtype should be lower or equal precision to the dtype override.
+    # Override the dtype. The checkpoint-dtype precision check now happens inside to_dtype().
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
-    if not (
-        checkpoint_dtype == dtype_override.to_torch_dtype()
-        or (
-            checkpoint_dtype == torch.float16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-        or (
-            checkpoint_dtype == torch.bfloat16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-    ):
-        logging.warning(
-            f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
-        )
-
-    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.to_dtype(dtype_override)
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 751e2d16175..155a38287f8 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -147,7 +147,25 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
         assert not dtype_override or isinstance(
             dtype_override, DType
         ), "Override dtype needs to be of type <DType>"
-        if dtype_override is not None and dtype_override != self.dtype:
+
+        # Checkpoint dtype should be lower or equal precision to the dtype override.
+        checkpoint_dtype = getattr(self.model, "checkpoint_dtype", None)
+        if checkpoint_dtype is not None and dtype_override is not None and not (
+            checkpoint_dtype == dtype_override.to_torch_dtype()
+            or (
+                checkpoint_dtype == torch.float16
+                and dtype_override.to_torch_dtype() == torch.float32
+            )
+            or (
+                checkpoint_dtype == torch.bfloat16
+                and dtype_override.to_torch_dtype() == torch.float32
+            )
+        ):
+            logging.warning(
+                f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
+            )
+
+        if dtype_override is not None and dtype_override != self.dtype:
             torch_dtype = dtype_override.to_torch_dtype()
             logging.info(f"model.to {torch_dtype}")
             self.model = self.model.to(dtype=torch_dtype)

From 8b22cadd4aaf330751c18142dd84dbd69060fc66 Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Sun, 23 Mar 2025 17:31:22 -0700
Subject: [PATCH 2/2] Apply hybrid quantization on Mimi

---
 examples/models/llama/export_llama_lib.py |   4 +-
 examples/models/llava/export_llava.py     |   4 +-
 examples/models/moshi/mimi/test_mimi.py   | 135 +++++++++++++++++++++
 3 files changed, 140 insertions(+), 3 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index ce1eff93b70..335d522d7ac 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1060,7 +1060,9 @@ def _load_llama_model(
             model.vocab_size,
             metadata_str,
         ),
-        args=args,
+        qnn=args.qnn,
+        export_only=args.export_only,
+        output_name=args.output_name,
     )
 
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index 64def112908..7ff94866323 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -92,7 +92,9 @@ def forward(self, input_pos, embeddings):
         use_kv_cache=True,
         example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
         dynamic_shapes=dynamic_shapes,
-        args=llava.text_model_args,
+        qnn=llava.text_model_args.qnn,
+        export_only=llava.text_model_args.export_only,
+        output_name=llava.text_model_args.output_name,
     )
 
     dtype_override = DType.fp32
diff --git a/examples/models/moshi/mimi/test_mimi.py b/examples/models/moshi/mimi/test_mimi.py
index 8160b5df79c..609f1c53320 100644
--- a/examples/models/moshi/mimi/test_mimi.py
+++ b/examples/models/moshi/mimi/test_mimi.py
@@ -2,6 +2,8 @@
 import os
 import random
 import unittest
+from functools import partial
+from executorch.examples.models.llama.source_transformation.quantize import quantize
 
 import numpy as np
 import requests
@@ -22,13 +24,23 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.export import export, ExportedProgram
 from torch.utils._pytree import tree_flatten
+from executorch.extension.llm.export.builder import DType, LLMEdgeManager
+from omegaconf import OmegaConf
+from pathlib import Path
+from executorch.examples.models.llama.export_llama_lib import (
+    _get_source_transforms,
+    get_quantizer_and_quant_params,
+)
+from executorch.extension.llm.export.partitioner_lib import (
+    get_xnnpack_partitioner,
+)
+import logging
 
 proxies = {
     "http": "http://fwdproxy:8080",
     "https": "http://fwdproxy:8080",
 }
 
-
 def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
     assert x.shape == y.shape, "Tensor shapes do not match"
     x = x.float()
@@ -219,6 +231,127 @@ def forward(self, x):
         print(f"SQNR: {sqnr}")
         torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)
 
+    def test_exported_decoder_hybrid_quant(self):
+        class MimiDecode(nn.Module):
+            def __init__(self, mimi: nn.Module):
+                super().__init__()
+                self.mimi_model = mimi
+
+            def forward(self, x):
+                return self.mimi_model.decode(x)
+
+        sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
+        pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
+        chunk = sample_pcm[..., 0:pcm_chunk_size]
+        input = self.mimi.encode(chunk)
+
+        mimi_decode = MimiDecode(self.mimi)
+        eager_output = mimi_decode(input)
+
+        config_dict = {
+            "model": "llama3",
+            "checkpoint": None,
+            "max_seq_length": 128,
+            "dtype_override": "fp32",
"use_kv_cache": False, + "generate_full_logits": False, + "enable_dynamic_shape": False, + "verbose": True, + "qnn": False, + "export_only": False, + "output_name": "output", + "output_dir": "/tmp", + "xnnpack_extended_ops": True, + "quantization": { + "mode": "8da4w", + "group_size": 64, + } + } + + # Create a DictConfig object from the dictionary + config = OmegaConf.create(config_dict) + + edge_manager = LLMEdgeManager( + model=mimi_decode, + modelname=config.model, + max_seq_len=config.max_seq_length, + dtype=config.dtype_override, + use_kv_cache=config.use_kv_cache, + generate_full_logits=config.generate_full_logits, + example_inputs=(input,), + example_kwarg_inputs=None, + dynamic_shapes=None, + enable_dynamic_shape=config.enable_dynamic_shape, + verbose=config.verbose, + ) + + dtype_override = DType[config.dtype_override] + edge_manager.to_dtype(dtype_override) + + transforms = [] + + # Linear 8da4w. + # TODO: look into decode_latent as an "embedding" layer + if config.quantization.mode: + quant_args = { + "group_size": config.quantization.group_size, + } + transforms.append( + partial( + quantize, + **quant_args, + qmode=config.quantization.mode, + computation_dtype=dtype_override, + checkpoint_dtype=None, + checkpoint_path=(Path(path) if (path := config.checkpoint) is not None else None), + ) + ) + + llm_manager = edge_manager.source_transform(transforms) + # llm_manager = edge_manager + builder_exported = llm_manager.export() + builder_exported.run_canonical_optimizations() + + + # Lower to xnnpack + partitioners = [] + + # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled + partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)) + + + if config.xnnpack_extended_ops: + partitioners.append( + get_xnnpack_partitioner(dynamic_quant_only_partitioner=False) + ) + + for partitioner in partitioners: + logging.info(f"--> {partitioner.__class__.__name__}") + + # builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( + builder = builder_exported.to_edge_transform_and_lower( + partitioners + ) + print_delegation_info(builder.edge_manager.exported_program().graph_module) + + builder = builder.to_executorch() + + builder.save_to_pte("mimi_4bit") + + # + # + # llm_manager = edge_manager.source_transform( + # _get_source_transforms( + # modelname=config.model, dtype_override=dtype_override, args=args + # ) + # ) + # builder_exported = edge_manager.set_output_dir(args.output_dir).export() + # builder_exported.run_canonical_optimizations() + # + # pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + # args + # ) + if __name__ == "__main__": unittest.main()