Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions litert_torch/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,16 @@ def layout_optimize_partitioner(self, value: str) -> None:
os.environ["LAYOUT_OPTIMIZE_PARTITIONER"] = str(value).upper()

@property
def lazy_constant_numel_threshold(self) -> int:
def resource_constant_numel_threshold(self) -> int:
"""The threshold for the number of elements in a constant to be eligible to be lazily loaded during lightweight conversion."""
default = 1024 * 1024 # 1MB
return _get_int_env_var("LAZY_CONSTANT_NUMEL_THRESHOLD", default=default)
return _get_int_env_var(
"RESOURCE_CONSTANT_NUMEL_THRESHOLD", default=default
)

@lazy_constant_numel_threshold.setter
def lazy_constant_numel_threshold(self, value: int) -> None:
os.environ["LAZY_CONSTANT_NUMEL_THRESHOLD"] = str(value)

@property
def lazy_constant_getter_chunk_size(self) -> int:
"""The chunk size for the lazy constant getter during lightweight conversion."""
default = 32 * 1024 * 1024 # 32MB
return _get_int_env_var("LAZY_CONSTANT_GETTER_CHUNK_SIZE", default=default)

@lazy_constant_getter_chunk_size.setter
def lazy_constant_getter_chunk_size(self, value: int) -> None:
os.environ["LAZY_CONSTANT_GETTER_CHUNK_SIZE"] = str(value)
@resource_constant_numel_threshold.setter
def resource_constant_numel_threshold(self, value: int) -> None:
os.environ["RESOURCE_CONSTANT_NUMEL_THRESHOLD"] = str(value)

@property
def show_progress(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion litert_torch/_convert/litert_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def exported_programs_to_flatbuffer(

ir_context = backend.export_utils.create_ir_context()
cross_program_inline_consts_ctx = inline_consts_lib.InlineConstsContext(
enable_lazy_constants=lightweight_conversion,
enable_resource_constants=lightweight_conversion,
)

lowered_programs = []
Expand Down
84 changes: 20 additions & 64 deletions litert_torch/backend/inline_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ai_edge_litert.mlir.dialects import arith
import numpy as np
import torch
from ai_edge_litert.mlir._mlir_libs import converter_api_ext

config = _config.config

Expand All @@ -33,7 +32,7 @@
class InlineConstsContext(lowerings.context.LoweringContextPlugin):
"""The context object for inlining constants."""

enable_lazy_constants: bool = False
enable_resource_constants: bool = False
constant_cache: dict[int, ir.Attribute] = dataclasses.field(
default_factory=dict
)
Expand Down Expand Up @@ -63,37 +62,12 @@ def _tensor_fingerprint(tensor: torch.Tensor) -> int:
def _tensor_to_mlir_compatible_array(tensor: torch.Tensor) -> np.ndarray:
"""Converts a tensor to a numpy array that is compatible with MLIR contiguity and endianness."""
if hasattr(tensor, 'detach'):
arr = tensor.detach().cpu().numpy()
arr = tensor.contiguous().detach().cpu().numpy()
else:
arr = np.array(tensor)

if arr.dtype == bool or arr.dtype == np.bool_:
# packbits returns uint8; bitorder='little' is crucial for MLIR
packed = np.packbits(arr, axis=None, bitorder='little')
return packed

target_dtype = {
# Floating point
np.float16: '<f2',
np.float32: '<f4',
np.float64: '<f8',
# Signed Integers
np.int8: '<i1',
np.int16: '<i2',
np.int32: '<i4',
np.int64: '<i8',
# Unsigned Integers
np.uint8: '<u1',
np.uint16: '<u2',
np.uint32: '<u4',
np.uint64: '<u8',
}.get(arr.dtype.type)

if target_dtype is None:
raise TypeError(f'Unsupported dtype for MLIR conversion: {arr.dtype}')

# Ensure C-contiguity and the specific bit-width/endianness
return np.ascontiguousarray(arr, dtype=target_dtype)
# Ensure C-contiguity
return np.ascontiguousarray(arr)


def _get_tensor_uniform_value(tensor: torch.Tensor):
Expand Down Expand Up @@ -134,7 +108,7 @@ def _clamp_inf_values(tensor: torch.Tensor):
"""Clamps a tensor to the min/max value for float tensors."""
if torch.is_floating_point(tensor):
info = torch.finfo(tensor.dtype)
tensor = torch.clamp(tensor, info.min, info.max)
tensor.clamp_(min=info.min, max=info.max)
return tensor


Expand All @@ -161,6 +135,7 @@ def tensor_lowering_placeholder_lowering(
):
"""Lower the placeholder function to a constant op."""
const_ctx = InlineConstsContext.get(lctx)
x = x.contiguous().detach().cpu()

x_fingerprint = _tensor_fingerprint(x)
elty = lowering_utils.torch_dtype_to_ir_element_type(x.dtype)
Expand All @@ -171,52 +146,33 @@ def tensor_lowering_placeholder_lowering(
if cached_attr is not None:
return _build_const(cached_attr, tensor_type)

use_lazy_attr = const_ctx.enable_lazy_constants
if x.dtype not in [torch.float32]:
use_lazy_attr = False
use_resource_attr = const_ctx.enable_resource_constants
if x.dtype not in [torch.float32, torch.int32]:
use_resource_attr = False

# If the tensor is too small, just use a dense elements attr.
if x.numel() * x.element_size() < config.lazy_constant_numel_threshold:
use_lazy_attr = False
if x.numel() * x.element_size() < config.resource_constant_numel_threshold:
use_resource_attr = False

# If not using lazy attr, clamp inf values to the min/max value of the
# tensor's dtype. Otherwise, rely on the bytes getter to clamp values
# lazily.
if not use_lazy_attr:
x = _clamp_inf_values(x)
x = _clamp_inf_values(x)

# If the tensor is uniform, use a splat constant.
uniform_value = _get_tensor_uniform_value(x)
if uniform_value is not None:
use_lazy_attr = False
use_resource_attr = False

if uniform_value is not None:
attr = lowering_utils.splat_attr(
uniform_value,
tensor_type.element_type,
tensor_type.shape,
)
elif use_lazy_attr:

def chunk_iterator_factory():
nonlocal x
element_size = x.element_size()
elements_per_chunk = (
config.lazy_constant_getter_chunk_size // element_size
)

# x.view(-1) is a metadata-only operation (0 bytes allocated)
flat_x = x.view(-1)
numel = flat_x.numel()

for i in range(0, numel, elements_per_chunk):
chunk = flat_x[i : i + elements_per_chunk]
chunk = _clamp_inf_values(chunk)
chunk_data = _tensor_to_mlir_compatible_array(chunk).tobytes()
yield chunk_data

attr = converter_api_ext.get_py_chunked_callback_resource_attr(
tensor_type, chunk_iterator_factory
elif use_resource_attr:
arr = _tensor_to_mlir_compatible_array(x)
attr = ir.DenseResourceElementsAttr.get_from_buffer(
memoryview(arr),
f'TENSOR_{x_fingerprint}',
tensor_type,
)
else:
arr = _tensor_to_mlir_compatible_array(x)
Expand All @@ -227,7 +183,7 @@ def chunk_iterator_factory():


def inline_consts(exported_program: torch.export.ExportedProgram) -> None:
"""Inlines exported program's constant inputs by replacing with lazy_tensor_placeholder."""
"""Inlines exported program's constant inputs by replacing with resource_tensor_placeholder."""
flat_user_inputs, _ = exported_program._get_flat_args_with_check(
*exported_program.example_inputs
)
Expand Down
2 changes: 1 addition & 1 deletion litert_torch/generative/export_hf/core/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def export_text_prefill_decode_model(
start_time = time.perf_counter()

print('Converting model...')
lrt_model = converter.convert(strict_export=False)
lrt_model = converter.convert(lightweight_conversion=True, strict_export=False)
print('Converting model done.')

lrt_model = mu_pass_lib.update_model(lrt_model)
Expand Down
Loading