diff --git a/litert_torch/generative/export_hf/core/mu/mixed_precision.py b/litert_torch/generative/export_hf/core/mu/mixed_precision.py index 65ab65ee..d0925fb7 100644 --- a/litert_torch/generative/export_hf/core/mu/mixed_precision.py +++ b/litert_torch/generative/export_hf/core/mu/mixed_precision.py @@ -26,7 +26,7 @@ from ai_edge_litert.tools.model_utils.dialect import tfl -def fp32_predicate(op): +def default_fp32_predicate(op): """Returns true if the op should be kept in fp32.""" if isinstance(op, stablehlo.CompositeOp): if "odml.rms_norm" == op.composite_name: @@ -54,9 +54,17 @@ def fp32_predicate(op): return False +def is_float(value: irdl.SSAValue): + if not isinstance(value.type, mlir.RankedTensorType): + return False + return value.type.elty in ("f16", "f32") + + def convert_model_to_fp16( path: str | pathlib.Path, - fp32_op_predicate: Callable[[irdl.Operation], bool] | None = None, + fp32_op_predicate: ( + Callable[[irdl.Operation], bool] | None + ) = default_fp32_predicate, ) -> bytes: if isinstance(path, str): path = pathlib.Path(path) @@ -69,157 +77,88 @@ def convert_model_to_fp16( def convert_to_fp16( module: mlir.ModuleOp, - fp32_op_predicate: Callable[[irdl.Operation], bool] | None = None, + fp32_op_predicate: ( + Callable[[irdl.Operation], bool] | None + ) = default_fp32_predicate, ) -> None: """Converts the model to fp16.""" - args_to_cast = [] - args_to_update = [] - ops_to_cast = [] - ops_to_update = [] - funcs_to_update = set() fp32_ops = set() - visited = set() - def _walk(original_op): + def is_nested_fp32_op(op): + while op: + if isinstance(op, irdl.Operation) and op in fp32_ops: + return True + op = op.parent + return False + + def collect_fp32_ops(original_op): + orig_fp32_ops_len = len(fp32_ops) for op in original_op.walk(): - if op not in visited: - visited.add(op) - else: - continue - - if ( - op.parent - and isinstance(op.parent, irdl.Block) - and op.parent.parent - and isinstance(op.parent.parent, irdl.Region) - and op.parent.parent.parent - and isinstance(op.parent.parent.parent, func.FuncOp) - and op.parent.parent.parent in fp32_ops - ): - continue - - if op == original_op: - continue - - if isinstance(op, func.ReturnOp): - continue - - if fp32_op_predicate and fp32_op_predicate(op): + if is_nested_fp32_op(op): fp32_ops.add(op) + elif fp32_op_predicate and fp32_op_predicate(op): + fp32_ops.add(op) + if isinstance(op, stablehlo.CompositeOp): fp32_ops.add(op.decomposition_func) elif isinstance(op, tfl.SelectV2Op): if isinstance(op.operands[2].op, tfl.ConstOp): fp32_ops.add(op.operands[2].op) - continue - - if op in fp32_ops: - continue - - if isinstance(op, func.FuncOp): - funcs_to_update.add(op) - - for arg in op.body.block.args: - if not isinstance(arg.type, mlir.RankedTensorType): - continue - - if arg.type.elty != "f32": - continue - - args_to_cast.append(arg) - - _walk(op) - - elif isinstance(op, tfl.ConstOp): - should_add = False - for result in op.results: - if ( - isinstance(result.type, mlir.RankedTensorType) - and result.type.elty == "f32" - ): - should_add = True - break - if should_add: - ops_to_cast.append(op) - - elif isinstance(op, stablehlo.CompositeOp): - funcs_to_update.add(op.decomposition_func) - - for arg in op.decomposition_func.body.block.args: - if not isinstance(arg.type, mlir.RankedTensorType): - continue - - if arg.type.elty != "f32": - continue - - args_to_update.append(arg) - - _walk(op.decomposition_func) - ops_to_update.append(op) - - else: - ops_to_update.append(op) - - _walk(module) - - for arg in args_to_cast: - arg.type = mlir.RankedTensorType(arg.type.shape, "f16") - for use in arg.uses.copy(): - if use.operation in fp32_ops: - with mu.OpBuildingContext(use.operation, insert_before=True): - cast = tfl.cast(arg, "f32") - use.operation.operands[use.index] = cast - - for arg in args_to_update: - arg.type = mlir.RankedTensorType(arg.type.shape, "f16") - - for op in ops_to_cast: - for result in op.results: - for use in result.uses.copy(): - # Skip if the use is in a fp32 op. Used for constant tensors. - if use.operation in fp32_ops: - continue - with mu.OpBuildingContext(use.operation, insert_before=True): - cast = tfl.cast(result, "f16") - use.operation.operands[use.index] = cast - - for op in ops_to_update: - for result in op.results: - if not isinstance(result.type, mlir.RankedTensorType): - continue - if result.type.elty != "f32": - continue - - result.type = mlir.RankedTensorType(result.type.shape, "f16") - - for func_op in funcs_to_update: - func_op.update_function_type() - - for op in fp32_ops: - for i, operand in enumerate(op.operands): - if ( - isinstance(operand.type, mlir.RankedTensorType) - and operand.type.elty == "f16" - ): - with mu.OpBuildingContext(op, insert_before=True): - cast = tfl.cast(operand, "f32") - op.operands[i] = cast - - for result in op.results: - if ( - not isinstance(result.type, mlir.RankedTensorType) - or result.type.elty != "f32" - ): - continue - - for use in result.uses.copy(): - if use.operation not in fp32_ops: - with mu.OpBuildingContext(use.operation, insert_before=True): - cast = tfl.cast(result, "f16") - use.operation.operands[use.index] = cast + return orig_fp32_ops_len != len(fp32_ops) + + # Recursively collect fp32 ops until convergence. + while collect_fp32_ops(module): + continue + + for op in module.walk(): + # Do not change cast ops. + if isinstance(op, tfl.CastOp): + continue + + # fp32 op + if op in fp32_ops: + for i, x in enumerate(op.operands): + if is_float(x): + with mu.OpBuildingContext(op, insert_before=True): + op.operands[i] = tfl.cast(x, "f32") + continue + + # fp16 op + if isinstance(op, func.FuncOp): + for arg in op.body.block.args: + if is_float(arg): + arg.type = mlir.RankedTensorType(arg.type.shape, "f16") + else: + has_float_operand = False + for i, x in enumerate(op.operands): + if is_float(x): + has_float_operand = True + with mu.OpBuildingContext(op, insert_before=True): + op.operands[i] = tfl.cast(x, "f16") + + for x in op.results: + if is_float(x): + if not has_float_operand: + # Assumption: if the op has no float operands but has a float + # result, the result type is determined by the op semantic or + # attributes. In such case, we need to insert a cast op to convert + # the result to fp16 and rely on cleanups to propagate the type + # change. + with mu.OpBuildingContext(op, insert_after=True): + cast = tfl.cast(x, "f16") + x.replace_by(cast) + cast.owner.operands[0] = x + else: + # Otherwise, the result type is determined by the input operand + # types. We can directly update the result type to fp16. + x.type = mlir.RankedTensorType(x.type.shape, "f16") + + # Update function types with new argument/result types. + for op in module.walk(): if isinstance(op, func.FuncOp): op.update_function_type() - module.cleanup() + # Canonicalize, CSE, constant folding, etc. + module.cleanup() diff --git a/litert_torch/generative/export_hf/core/mu/mu_pass_lib.py b/litert_torch/generative/export_hf/core/mu/mu_pass_lib.py index fad92223..5f020891 100644 --- a/litert_torch/generative/export_hf/core/mu/mu_pass_lib.py +++ b/litert_torch/generative/export_hf/core/mu/mu_pass_lib.py @@ -146,6 +146,6 @@ def apply_mixed_precision( mu_module, ctx = _litert_model_to_model_utils(model) with ctx: print("Applying mixed precision to model...") - mixed_precision.convert_to_fp16(mu_module, mixed_precision.fp32_predicate) + mixed_precision.convert_to_fp16(mu_module) return _model_utils_to_litert_model(mu_module, ctx)