
Commit 1a20585

mxtensor: inherit from TorchAOBaseTensor (#2805)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 249d95b commit 1a20585

File tree

2 files changed, +14 −94 lines changed


torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 13 additions & 94 deletions
@@ -17,7 +17,7 @@
 * Zeros: N/A
 """

-from typing import Callable, Dict, Union
+from typing import Union

 import torch
 from torch.distributed._tensor import DTensor
@@ -57,6 +57,7 @@
     triton_f6_e3m2_to_scaled_bf16,
     unpack_uint4,
 )
+from torchao.utils import TorchAOBaseTensor

 # TODO(later): read from somewhere else?
 SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
@@ -448,7 +449,17 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
     return new_size


-class MXTensor(torch.Tensor):
+class MXTensor(TorchAOBaseTensor):
+    tensor_data_names = ["qdata", "_scale_e8m0"]
+    tensor_attribute_names = [
+        "_elem_dtype",
+        "_block_size",
+        "_orig_dtype",
+        "_use_fp4_custom_triton_dequant_kernel",
+        "_gemm_kernel_choice",
+        "_pack_fp6",
+    ]
+
     def __new__(
         cls,
         qdata,
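The two class variables added above are the declarative replacement for per-class subclass plumbing: by listing the inner tensor names and the metadata attribute names, the subclass lets the base class handle flattening and reconstruction, which is why the hand-written __tensor_flatten__ / __tensor_unflatten__ are deleted further down in this diff. The following is a minimal, self-contained sketch of that mechanism; _SubclassSketch and _ToyQuantTensor are illustrative stand-ins, not torchao's actual TorchAOBaseTensor, and only the names tensor_data_names, tensor_attribute_names, qdata, and _scale_e8m0 come from the diff itself.

# Illustrative sketch only: _SubclassSketch and _ToyQuantTensor are NOT torchao
# code. They show how tensor_data_names / tensor_attribute_names can drive a
# generic __tensor_flatten__ / __tensor_unflatten__.
import torch


class _SubclassSketch(torch.Tensor):
    tensor_data_names = []       # names of inner tensor attributes
    tensor_attribute_names = []  # names of non-tensor metadata attributes

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError("sketch only, no op support")

    def __tensor_flatten__(self):
        # the metadata context is everything listed in tensor_attribute_names
        ctx = {name: getattr(self, name) for name in self.tensor_attribute_names}
        return self.tensor_data_names, ctx

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride):
        # rebuild from data tensors + metadata, passed positionally in declared order
        args = [inner_tensors[n] for n in cls.tensor_data_names]
        args += [metadata[n] for n in cls.tensor_attribute_names]
        return cls(*args)


class _ToyQuantTensor(_SubclassSketch):
    tensor_data_names = ["qdata", "scale"]
    tensor_attribute_names = ["block_size"]

    def __new__(cls, qdata, scale, block_size):
        self = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=torch.bfloat16)
        self.qdata = qdata
        self.scale = scale
        self.block_size = block_size
        return self


t = _ToyQuantTensor(torch.zeros(4, 8, dtype=torch.uint8), torch.ones(4), 32)
names, ctx = t.__tensor_flatten__()
assert names == ["qdata", "scale"] and ctx == {"block_size": 32}

With this pattern, a subclass only declares what its data and metadata are; the protocol methods below become pure boilerplate, which is exactly what the next hunk removes from MXTensor.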
@@ -610,97 +621,5 @@ def to_mx(
             pack_fp6,
         )

-    def __tensor_flatten__(self):
-        ctx = {
-            "_elem_dtype": self._elem_dtype,
-            "_block_size": self._block_size,
-            "_orig_dtype": self._orig_dtype,
-            "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
-            "_gemm_kernel_choice": self._gemm_kernel_choice,
-            "_pack_fp6": self._pack_fp6,
-        }
-        return ["qdata", "_scale_e8m0"], ctx
-
-    @staticmethod
-    def __tensor_unflatten__(
-        inner_tensors: Dict,
-        metadata,
-        outer_size,
-        outer_stride,
-    ):
-        return MXTensor(
-            inner_tensors["qdata"],
-            inner_tensors["_scale_e8m0"],
-            metadata["_elem_dtype"],
-            metadata["_block_size"],
-            metadata["_orig_dtype"],
-            metadata["_use_fp4_custom_triton_dequant_kernel"],
-            metadata["_gemm_kernel_choice"],
-            metadata["_pack_fp6"],
-        )
-
-    def _apply_fn_to_data(self, fn: Callable):
-        """Applies a fn to all tensor components stored on this class"""
-        tensor_names, ctx = self.__tensor_flatten__()
-
-        # Apply the function to each tensor component
-        new_tensors = {}
-        for name in tensor_names:
-            new_tensors[name] = fn(getattr(self, name))
-
-        return self.__class__.__tensor_unflatten__(
-            new_tensors,
-            ctx,
-            None,  # outer_size parameter
-            None,  # outer_stride parameter
-        )
-
     # Do not force the MXTensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
-
-    @classmethod
-    def _same_metadata(cls, self: "MXTensor", src: "MXTensor") -> bool:
-        checks = [
-            (isinstance(self, MXTensor), "self is not MXTensor"),
-            (isinstance(src, MXTensor), "src is not MXTensor"),
-            (
-                self._elem_dtype == src._elem_dtype,
-                f"elem_dtype: {self._elem_dtype} != {src._elem_dtype}",
-            ),
-            (
-                self._block_size == src._block_size,
-                f"block_size: {self._block_size} != {src._block_size}",
-            ),
-            (
-                self._orig_dtype == src._orig_dtype,
-                f"orig_dtype: {self._orig_dtype} != {src._orig_dtype}",
-            ),
-            (
-                self._use_fp4_custom_triton_dequant_kernel
-                == src._use_fp4_custom_triton_dequant_kernel,
-                "use_fp4_custom_triton_dequant_kernel mismatch",
-            ),
-            (
-                self._gemm_kernel_choice == src._gemm_kernel_choice,
-                f"gemm_kernel_choice: {self._gemm_kernel_choice} != {src._gemm_kernel_choice}",
-            ),
-            (
-                self._pack_fp6 == src._pack_fp6,
-                f"pack_fp6: {self._pack_fp6} != {src._pack_fp6}",
-            ),
-            (
-                self._scale_e8m0.shape == src._scale_e8m0.shape,
-                f"scale_e8m0.shape: {self._scale_e8m0.shape} != {src._scale_e8m0.shape}",
-            ),
-            (
-                self.qdata.shape == src.qdata.shape,
-                f"data.shape: {self.qdata.shape} != {src.qdata.shape}",
-            ),
-        ]
-
-        for condition, error_msg in checks:
-            if not condition:
-                raise ValueError(f"Metadata mismatch: {error_msg}")
-                return False
-
-        return True
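The deleted _apply_fn_to_data is the same kind of boilerplate: once __tensor_flatten__ and __tensor_unflatten__ are derived from the declared class variables, applying a function to every inner tensor and rebuilding the subclass needs no per-class code. A hedged sketch of such a generic helper, written against the illustrative _SubclassSketch protocol above rather than torchao's actual TorchAOBaseTensor internals:

# Illustrative sketch, not torchao code: a generic _apply_fn_to_data that works
# for any subclass exposing the __tensor_flatten__ / __tensor_unflatten__ pair
# derived from tensor_data_names / tensor_attribute_names.
from typing import Callable


def apply_fn_to_data_sketch(t, fn: Callable):
    # collect the declared inner tensor names and the metadata context
    tensor_names, ctx = t.__tensor_flatten__()
    # map fn over every inner tensor, leaving metadata untouched
    new_tensors = {name: fn(getattr(t, name)) for name in tensor_names}
    # rebuild a new instance of the same subclass from the transformed tensors
    return t.__class__.__tensor_unflatten__(new_tensors, ctx, None, None)


# Example against the toy subclass defined earlier:
#   t2 = apply_fn_to_data_sketch(t, lambda x: x.clone())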

torchao/utils.py

Lines changed: 1 addition & 0 deletions
@@ -738,6 +738,7 @@ class variables to define to simplify implmentation of tensor subclasses:
 `_apply_fn_to_data`: takes a function (Tensor -> Tensor), applies function to all tensor data and
    recreate a new subclassed Tensor with the transformed tensor data
 `__repr__`: the string representation of the subclassed tensor instance
+`_same_metadata`: returns whether the metadata is the same between two instances of cls
 torch ops: torch.Tensor.contiguous
 aten ops: aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)

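The single added docstring line records that TorchAOBaseTensor now also supplies _same_metadata, which is what allows the hand-written MXTensor._same_metadata above to be deleted. A hedged sketch of how such a generic check can be derived from the declared class variables (illustrative only, not necessarily the exact torchao implementation):

# Illustrative sketch, not torchao code: a generic metadata equality check
# driven by tensor_attribute_names (attribute values must match) and
# tensor_data_names (inner tensor shapes must match), mirroring the
# hand-written checks removed from MXTensor in this commit.
def same_metadata_sketch(self, src) -> bool:
    cls = type(self)
    # both operands must be instances of the same subclass
    if not isinstance(src, cls):
        return False
    # every declared metadata attribute must compare equal
    for name in cls.tensor_attribute_names:
        if getattr(self, name) != getattr(src, name):
            return False
    # every declared inner tensor must have the same shape
    for name in cls.tensor_data_names:
        if getattr(self, name).shape != getattr(src, name).shape:
            return False
    return True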
