Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 79 additions & 140 deletions litert_torch/generative/export_hf/core/mu/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ai_edge_litert.tools.model_utils.dialect import tfl


def fp32_predicate(op):
def default_fp32_predicate(op):
"""Returns true if the op should be kept in fp32."""
if isinstance(op, stablehlo.CompositeOp):
if "odml.rms_norm" == op.composite_name:
Expand Down Expand Up @@ -54,9 +54,17 @@ def fp32_predicate(op):
return False


def is_float(value: irdl.SSAValue):
  """Returns true if the value is a ranked tensor of f16 or f32 elements."""
  value_type = value.type
  # Only ranked tensors are candidates; any other type is reported as
  # non-float.
  return isinstance(
      value_type, mlir.RankedTensorType
  ) and value_type.elty in ("f16", "f32")


def convert_model_to_fp16(
path: str | pathlib.Path,
fp32_op_predicate: Callable[[irdl.Operation], bool] | None = None,
fp32_op_predicate: (
Callable[[irdl.Operation], bool] | None
) = default_fp32_predicate,
) -> bytes:
if isinstance(path, str):
path = pathlib.Path(path)
Expand All @@ -69,157 +77,88 @@ def convert_model_to_fp16(

def convert_to_fp16(
module: mlir.ModuleOp,
fp32_op_predicate: Callable[[irdl.Operation], bool] | None = None,
fp32_op_predicate: (
Callable[[irdl.Operation], bool] | None
) = default_fp32_predicate,
) -> None:
"""Converts the model to fp16."""
args_to_cast = []
args_to_update = []
ops_to_cast = []
ops_to_update = []
funcs_to_update = set()
fp32_ops = set()
visited = set()

def _walk(original_op):
def is_nested_fp32_op(op):
while op:
if isinstance(op, irdl.Operation) and op in fp32_ops:
return True
op = op.parent
return False

def collect_fp32_ops(original_op):
orig_fp32_ops_len = len(fp32_ops)

for op in original_op.walk():
if op not in visited:
visited.add(op)
else:
continue

if (
op.parent
and isinstance(op.parent, irdl.Block)
and op.parent.parent
and isinstance(op.parent.parent, irdl.Region)
and op.parent.parent.parent
and isinstance(op.parent.parent.parent, func.FuncOp)
and op.parent.parent.parent in fp32_ops
):
continue

if op == original_op:
continue

if isinstance(op, func.ReturnOp):
continue

if fp32_op_predicate and fp32_op_predicate(op):
if is_nested_fp32_op(op):
fp32_ops.add(op)
elif fp32_op_predicate and fp32_op_predicate(op):
fp32_ops.add(op)

if isinstance(op, stablehlo.CompositeOp):
fp32_ops.add(op.decomposition_func)
elif isinstance(op, tfl.SelectV2Op):
if isinstance(op.operands[2].op, tfl.ConstOp):
fp32_ops.add(op.operands[2].op)
continue

if op in fp32_ops:
continue

if isinstance(op, func.FuncOp):
funcs_to_update.add(op)

for arg in op.body.block.args:
if not isinstance(arg.type, mlir.RankedTensorType):
continue

if arg.type.elty != "f32":
continue

args_to_cast.append(arg)

_walk(op)

elif isinstance(op, tfl.ConstOp):
should_add = False
for result in op.results:
if (
isinstance(result.type, mlir.RankedTensorType)
and result.type.elty == "f32"
):
should_add = True
break
if should_add:
ops_to_cast.append(op)

elif isinstance(op, stablehlo.CompositeOp):
funcs_to_update.add(op.decomposition_func)

for arg in op.decomposition_func.body.block.args:
if not isinstance(arg.type, mlir.RankedTensorType):
continue

if arg.type.elty != "f32":
continue

args_to_update.append(arg)

_walk(op.decomposition_func)
ops_to_update.append(op)

else:
ops_to_update.append(op)

_walk(module)

for arg in args_to_cast:
arg.type = mlir.RankedTensorType(arg.type.shape, "f16")
for use in arg.uses.copy():
if use.operation in fp32_ops:
with mu.OpBuildingContext(use.operation, insert_before=True):
cast = tfl.cast(arg, "f32")
use.operation.operands[use.index] = cast

for arg in args_to_update:
arg.type = mlir.RankedTensorType(arg.type.shape, "f16")

for op in ops_to_cast:
for result in op.results:
for use in result.uses.copy():
# Skip if the use is in a fp32 op. Used for constant tensors.
if use.operation in fp32_ops:
continue
with mu.OpBuildingContext(use.operation, insert_before=True):
cast = tfl.cast(result, "f16")
use.operation.operands[use.index] = cast

for op in ops_to_update:
for result in op.results:
if not isinstance(result.type, mlir.RankedTensorType):
continue
if result.type.elty != "f32":
continue

result.type = mlir.RankedTensorType(result.type.shape, "f16")

for func_op in funcs_to_update:
func_op.update_function_type()

for op in fp32_ops:
for i, operand in enumerate(op.operands):
if (
isinstance(operand.type, mlir.RankedTensorType)
and operand.type.elty == "f16"
):
with mu.OpBuildingContext(op, insert_before=True):
cast = tfl.cast(operand, "f32")
op.operands[i] = cast

for result in op.results:
if (
not isinstance(result.type, mlir.RankedTensorType)
or result.type.elty != "f32"
):
continue

for use in result.uses.copy():
if use.operation not in fp32_ops:
with mu.OpBuildingContext(use.operation, insert_before=True):
cast = tfl.cast(result, "f16")
use.operation.operands[use.index] = cast

return orig_fp32_ops_len != len(fp32_ops)

# Recursively collect fp32 ops until convergence.
while collect_fp32_ops(module):
continue

for op in module.walk():
# Do not change cast ops.
if isinstance(op, tfl.CastOp):
continue

# fp32 op
if op in fp32_ops:
for i, x in enumerate(op.operands):
if is_float(x):
with mu.OpBuildingContext(op, insert_before=True):
op.operands[i] = tfl.cast(x, "f32")
continue

# fp16 op
if isinstance(op, func.FuncOp):
for arg in op.body.block.args:
if is_float(arg):
arg.type = mlir.RankedTensorType(arg.type.shape, "f16")
else:
has_float_operand = False
for i, x in enumerate(op.operands):
if is_float(x):
has_float_operand = True
with mu.OpBuildingContext(op, insert_before=True):
op.operands[i] = tfl.cast(x, "f16")

for x in op.results:
if is_float(x):
if not has_float_operand:
# Assumption: if the op has no float operands but has a float
# result, the result type is determined by the op semantics or
# attributes. In such a case, we need to insert a cast op to convert
# the result to fp16 and rely on cleanups to propagate the type
# change.
with mu.OpBuildingContext(op, insert_after=True):
cast = tfl.cast(x, "f16")
x.replace_by(cast)
cast.owner.operands[0] = x
else:
# Otherwise, the result type is determined by the input operand
# types. We can directly update the result type to fp16.
x.type = mlir.RankedTensorType(x.type.shape, "f16")

# Update function types with new argument/result types.
for op in module.walk():
if isinstance(op, func.FuncOp):
op.update_function_type()

module.cleanup()
# Canonicalize, CSE, constant folding, etc.
module.cleanup()
2 changes: 1 addition & 1 deletion litert_torch/generative/export_hf/core/mu/mu_pass_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,6 @@ def apply_mixed_precision(
mu_module, ctx = _litert_model_to_model_utils(model)
with ctx:
print("Applying mixed precision to model...")
mixed_precision.convert_to_fp16(mu_module, mixed_precision.fp32_predicate)
mixed_precision.convert_to_fp16(mu_module)

return _model_utils_to_litert_model(mu_module, ctx)
Loading