|
17 | 17 | * Zeros: N/A
18 | 18 | """
19 | 19 |
|
20 | | -from typing import Callable, Dict, Union |
| 20 | +from typing import Union |
21 | 21 |
|
22 | 22 | import torch |
23 | 23 | from torch.distributed._tensor import DTensor |
|
57 | 57 | triton_f6_e3m2_to_scaled_bf16, |
58 | 58 | unpack_uint4, |
59 | 59 | ) |
| 60 | +from torchao.utils import TorchAOBaseTensor |
60 | 61 |
|
61 | 62 | # TODO(later): read from somewhere else? |
62 | 63 | SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23 |
@@ -448,7 +449,17 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
448 | 449 | return new_size |
449 | 450 |
|
450 | 451 |
|
451 | | -class MXTensor(torch.Tensor): |
| 452 | +class MXTensor(TorchAOBaseTensor): |
| 453 | + tensor_data_names = ["qdata", "_scale_e8m0"] |
| 454 | + tensor_attribute_names = [ |
| 455 | + "_elem_dtype", |
| 456 | + "_block_size", |
| 457 | + "_orig_dtype", |
| 458 | + "_use_fp4_custom_triton_dequant_kernel", |
| 459 | + "_gemm_kernel_choice", |
| 460 | + "_pack_fp6", |
| 461 | + ] |
| 462 | + |
452 | 463 | def __new__( |
453 | 464 | cls, |
454 | 465 | qdata, |
@@ -610,97 +621,5 @@ def to_mx(
610 | 621 | pack_fp6, |
611 | 622 | ) |
612 | 623 |
|
613 | | - def __tensor_flatten__(self): |
614 | | - ctx = { |
615 | | - "_elem_dtype": self._elem_dtype, |
616 | | - "_block_size": self._block_size, |
617 | | - "_orig_dtype": self._orig_dtype, |
618 | | - "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel, |
619 | | - "_gemm_kernel_choice": self._gemm_kernel_choice, |
620 | | - "_pack_fp6": self._pack_fp6, |
621 | | - } |
622 | | - return ["qdata", "_scale_e8m0"], ctx |
623 | | - |
624 | | - @staticmethod |
625 | | - def __tensor_unflatten__( |
626 | | - inner_tensors: Dict, |
627 | | - metadata, |
628 | | - outer_size, |
629 | | - outer_stride, |
630 | | - ): |
631 | | - return MXTensor( |
632 | | - inner_tensors["qdata"], |
633 | | - inner_tensors["_scale_e8m0"], |
634 | | - metadata["_elem_dtype"], |
635 | | - metadata["_block_size"], |
636 | | - metadata["_orig_dtype"], |
637 | | - metadata["_use_fp4_custom_triton_dequant_kernel"], |
638 | | - metadata["_gemm_kernel_choice"], |
639 | | - metadata["_pack_fp6"], |
640 | | - ) |
641 | | - |
642 | | - def _apply_fn_to_data(self, fn: Callable): |
643 | | - """Applies a fn to all tensor components stored on this class""" |
644 | | - tensor_names, ctx = self.__tensor_flatten__() |
645 | | - |
646 | | - # Apply the function to each tensor component |
647 | | - new_tensors = {} |
648 | | - for name in tensor_names: |
649 | | - new_tensors[name] = fn(getattr(self, name)) |
650 | | - |
651 | | - return self.__class__.__tensor_unflatten__( |
652 | | - new_tensors, |
653 | | - ctx, |
654 | | - None, # outer_size parameter |
655 | | - None, # outer_stride parameter |
656 | | - ) |
657 | | - |
658 | 624 | # Do not force the MXTensor type on the returned tensor |
659 | 625 | __torch_function__ = torch._C._disabled_torch_function_impl |
660 | | - |
661 | | - @classmethod |
662 | | - def _same_metadata(cls, self: "MXTensor", src: "MXTensor") -> bool: |
663 | | - checks = [ |
664 | | - (isinstance(self, MXTensor), "self is not MXTensor"), |
665 | | - (isinstance(src, MXTensor), "src is not MXTensor"), |
666 | | - ( |
667 | | - self._elem_dtype == src._elem_dtype, |
668 | | - f"elem_dtype: {self._elem_dtype} != {src._elem_dtype}", |
669 | | - ), |
670 | | - ( |
671 | | - self._block_size == src._block_size, |
672 | | - f"block_size: {self._block_size} != {src._block_size}", |
673 | | - ), |
674 | | - ( |
675 | | - self._orig_dtype == src._orig_dtype, |
676 | | - f"orig_dtype: {self._orig_dtype} != {src._orig_dtype}", |
677 | | - ), |
678 | | - ( |
679 | | - self._use_fp4_custom_triton_dequant_kernel |
680 | | - == src._use_fp4_custom_triton_dequant_kernel, |
681 | | - "use_fp4_custom_triton_dequant_kernel mismatch", |
682 | | - ), |
683 | | - ( |
684 | | - self._gemm_kernel_choice == src._gemm_kernel_choice, |
685 | | - f"gemm_kernel_choice: {self._gemm_kernel_choice} != {src._gemm_kernel_choice}", |
686 | | - ), |
687 | | - ( |
688 | | - self._pack_fp6 == src._pack_fp6, |
689 | | - f"pack_fp6: {self._pack_fp6} != {src._pack_fp6}", |
690 | | - ), |
691 | | - ( |
692 | | - self._scale_e8m0.shape == src._scale_e8m0.shape, |
693 | | - f"scale_e8m0.shape: {self._scale_e8m0.shape} != {src._scale_e8m0.shape}", |
694 | | - ), |
695 | | - ( |
696 | | - self.qdata.shape == src.qdata.shape, |
697 | | - f"data.shape: {self.qdata.shape} != {src.qdata.shape}", |
698 | | - ), |
699 | | - ] |
700 | | - |
701 | | - for condition, error_msg in checks: |
702 | | - if not condition: |
703 | | - raise ValueError(f"Metadata mismatch: {error_msg}") |
704 | | - return False |
705 | | - |
706 | | - return True |
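
For reference, the pattern this change moves to can be summarized as: instead of hand-writing `__tensor_flatten__`, `__tensor_unflatten__`, `_apply_fn_to_data`, and `_same_metadata`, the subclass declares its inner tensors and metadata once as class attributes and lets `TorchAOBaseTensor` supply that boilerplate. The sketch below is illustrative only; `MyQuantTensor` and its fields are hypothetical, and it assumes (as this diff suggests) that `TorchAOBaseTensor` derives flatten/unflatten and related helpers from `tensor_data_names` / `tensor_attribute_names`.

    # Minimal sketch, not part of this PR.
    import torch
    from torchao.utils import TorchAOBaseTensor

    class MyQuantTensor(TorchAOBaseTensor):
        # Inner tensors that the base class is assumed to flatten/unflatten.
        tensor_data_names = ["qdata", "scale"]
        # Non-tensor metadata carried alongside the inner tensors.
        tensor_attribute_names = ["block_size", "orig_dtype"]

        def __new__(cls, qdata, scale, block_size, orig_dtype):
            # Wrapper subclass: logical shape/dtype come from metadata,
            # actual storage lives in the inner tensors.
            return torch.Tensor._make_wrapper_subclass(
                cls, qdata.shape, dtype=orig_dtype
            )

        def __init__(self, qdata, scale, block_size, orig_dtype):
            self.qdata = qdata
            self.scale = scale
            self.block_size = block_size
            self.orig_dtype = orig_dtype

Declaring the component names once keeps flatten/unflatten, `_apply_fn_to_data`-style traversal, and metadata comparison in sync automatically when fields are added or renamed, which is exactly the duplication the removed methods above had to maintain by hand.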