diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py
index af2fa3c74ee..80c7c2a51df 100644
--- a/examples/apple/coreml/llama/export.py
+++ b/examples/apple/coreml/llama/export.py
@@ -28,6 +28,7 @@
 
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
 from torchao.utils import unwrap_tensor_subclass
 
 
@@ -77,7 +78,7 @@ def main() -> None:
     parser.add_argument(
         "--coreml-quantize",
         default=None,
-        choices=["b4w", "c4w"],
-        help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)",
+        choices=["b4w", "c4w", "custom"],
+        help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight), custom (for channelwise 4 bit weight with hand-tuned linear splits)",
     )
     parser.add_argument(
@@ -118,6 +119,8 @@ def main() -> None:
     model.eval()
     model.to(float_dtype)
 
+    print("MODEL", model)
+
     if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
             model,
@@ -163,6 +166,34 @@ def main() -> None:
                 granularity=PerAxis(0),
             ),
         )
+    elif export_args.coreml_quantize == "custom":
+        replace_linear_with_split_linear(
+            model,
+            out_target_split_size=2048,
+            out_max_splits=4,
+            in_target_split_size=1,
+            in_max_splits=1,
+            fqn_filter=lambda fqn: any(fqn.endswith(suffix) for suffix in ["w1", "w3"]),
+        )
+        replace_linear_with_split_linear(
+            model,
+            out_target_split_size=2048,
+            out_max_splits=1,
+            in_target_split_size=2048,
+            in_max_splits=4,
+            fqn_filter=lambda fqn: any(fqn.endswith(suffix) for suffix in ["w2"]),
+        )
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+            lambda m, fqn: (
+                isinstance(m, torch.nn.Linear)
+                and any(fqn.endswith(suffix) for suffix in ["wq", "wk", "wv", "wo", "output", "w1", "w3", "w2"])
+            ),
+        )
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
@@ -199,6 +230,10 @@ def main() -> None:
     print("Exported program")
     print(ep)
 
+    # ep = ep.run_decompositions({})
+    # mlprogram = ct.convert(ep, minimum_deployment_target=ct.target.iOS18)
+    # mlprogram.save("model.mlpackage")
+
     edge_manager = to_edge_transform_and_lower(
         ep,
         partitioner=[partitioner],
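Reviewer note (not part of the patch): the "custom" recipe above splits each FFN projection into smaller linears before channelwise 4-bit quantization. A minimal sketch of the intended equivalence, using illustrative shapes (4096 -> 8192 with out_target_split_size=2048, out_max_splits=4, as requested for w1/w3); the real rewrite is done by SplitLinearModule in examples/apple/coreml/llama/utils.py:

import torch

# Full projection vs. four output-dim splits of 2048.
full = torch.nn.Linear(4096, 8192, bias=False)
splits = torch.nn.ModuleList(
    torch.nn.Linear(4096, 2048, bias=False) for _ in range(4)
)
# nn.Linear stores weight as (out_features, in_features), so chunk along dim 0.
for chunk, small in zip(full.weight.chunk(4, dim=0), splits):
    small.weight.data.copy_(chunk)

x = torch.randn(1, 4096)
# Concatenating the split outputs reproduces the unsplit projection.
torch.testing.assert_close(torch.cat([m(x) for m in splits], dim=-1), full(x))
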
diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py
index ae98c327b45..7133b4079d9 100644
--- a/examples/apple/coreml/llama/llama_transformer.py
+++ b/examples/apple/coreml/llama/llama_transformer.py
@@ -167,6 +167,40 @@ def forward(self, x):
         output = self._norm(x)
         return output * self.weight
+
+class CoreMLRMSNormV2(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return torch.nn.functional.rms_norm(x, normalized_shape=[self.dim], weight=self.weight, eps=self.eps)
+
+_RMS_NORM = CoreMLRMSNormV2
 
 
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
@@ -327,8 +361,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         if self.use_qk_norm:
             q_norm_dim = self.head_dim
             k_norm_dim = self.head_dim
-            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
-            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+            self.q_norm_fn = _RMS_NORM(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = _RMS_NORM(k_norm_dim, eps=args.norm_eps)
 
     def forward(
         self,
@@ -364,6 +398,7 @@ def forward(
             k = torch.concat([k_cache, k], dim=2)
             v = torch.concat([v_cache, v], dim=2)
 
+        # TODO: I'm pretty sure the MB version of SDPA does not require this repeat_interleave.
         # grouped multiquery attention: expand out keys and values
         if self.n_rep > 1:
             k = k.repeat_interleave(self.n_rep, dim=1)
@@ -388,8 +423,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps)
+        self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps)
 
     def forward(
         self,
@@ -422,7 +457,7 @@ def __init__(self, params: ModelArgs):
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params, self.rope))
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.norm = _RMS_NORM(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py
index 1e5a842fed5..06dd858ce9c 100644
--- a/examples/apple/coreml/llama/utils.py
+++ b/examples/apple/coreml/llama/utils.py
@@ -91,10 +91,15 @@ def forward(self, x):
 
 
 def replace_linear_with_split_linear(
-    model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1
+    model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1, fqn_filter=None,
 ):
+    if fqn_filter is None:
+        fqn_filter = lambda fqn: True
+
     for name, module in model.named_children():
-        if isinstance(module, torch.nn.Linear):
+        should_split = isinstance(module, torch.nn.Linear) and fqn_filter(name)
+        print("TESTING", name, "WILL SPLIT", should_split)
+        if should_split:
             assert module.bias is None, "SplitLinearModule does not support bias"
             new_module = SplitLinearModule(
                 module.in_features,
@@ -113,4 +118,5 @@ def replace_linear_with_split_linear(
             out_max_splits,
             in_target_split_size,
             in_max_splits,
+            fqn_filter,
         )
diff --git a/examples/apple/coreml/scripts/extract_coreml_models.py b/examples/apple/coreml/scripts/extract_coreml_models.py
index b3778a22625..593a270186b 100644
--- a/examples/apple/coreml/scripts/extract_coreml_models.py
+++ b/examples/apple/coreml/scripts/extract_coreml_models.py
@@ -21,7 +21,7 @@
 
 
 def extract_coreml_models(pte_data: bytes):
-    program = deserialize_pte_binary(pte_data)
+    program = deserialize_pte_binary(pte_data).program
     delegates: List[BackendDelegate] = sum(
         [execution_plan.delegates for execution_plan in program.execution_plan], []
     )
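Reviewer note (not part of the patch): a quick sanity check that torch.nn.functional.rms_norm, which CoreMLRMSNormV2 delegates to (available in torch >= 2.4), matches the decomposed RMSNorm math the existing class implements; shapes and eps below are illustrative:

import torch
import torch.nn.functional as F

dim, eps = 64, 1e-6
x = torch.randn(2, 8, dim)
weight = torch.ones(dim)

# Decomposed form: x * rsqrt(mean(x^2) + eps) * weight.
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
fused = F.rms_norm(x, normalized_shape=[dim], weight=weight, eps=eps)
torch.testing.assert_close(fused, manual)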