
Commit 249d95b

mxtensor: make data argument first and rename to qdata (#2804)
Update [ghstack-poisoned]
1 parent af2cf1e commit 249d95b
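
In short: the MXTensor constructor now takes the quantized payload as its first positional argument (ahead of the scale), and the attribute holding that payload is renamed from `_data` to `qdata`. A minimal sketch of the new accessor, assuming the `torchao.prototype.mx_formats.mx_tensor` import path and that `MXTensor.to_mx` runs on CPU (neither is confirmed by this page):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(32, 64, dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)  # elem_dtype, block_size=32

# The quantized payload is now exposed as `qdata` (previously `_data`);
# the e8m0 scale is still `_scale_e8m0`, one scale per 32-element block.
assert x_mx.qdata.shape == (32, 64)
assert x_mx._scale_e8m0.shape == (32, 2)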

File tree

6 files changed: 56 additions (+), 58 deletions (−)


test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ def run_around_tests():
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
 @pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
-def test_inference_workflow(elem_dtype, bias: bool, compile: bool):
+def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
     """
     Smoke test for inference compile
     """

test/prototype/mx_formats/test_mx_mm.py

Lines changed: 2 additions & 2 deletions

@@ -38,8 +38,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     a_mx = MXTensor.to_mx(a, fmt, 32)
     b_mx = MXTensor.to_mx(b, fmt, 32)
 
-    a_data = a_mx._data
-    b_data = b_mx._data
+    a_data = a_mx.qdata
+    b_data = b_mx.qdata
     assert b_data.is_contiguous()
     b_data = b_data.transpose(-1, -2)

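The matmul test now reads the packed payloads through qdata before handing them to the low-level gemm. A hedged sketch of the same access pattern (shapes and the fp8 format are illustrative; import path assumed):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

a = torch.randn(128, 256, dtype=torch.bfloat16)
b = torch.randn(256, 128, dtype=torch.bfloat16)
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# As in run_matrix_test above: pull the payloads via qdata, then put b
# into the column-major layout the downstream mx gemm expects.
a_data, b_data = a_mx.qdata, b_mx.qdata
assert b_data.is_contiguous()
b_data = b_data.transpose(-1, -2)
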
test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 16 additions & 16 deletions

@@ -73,9 +73,9 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     # verify that if data.shape is (M, K) then scale.shape is (M, K // block_size)
     prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
     if elem_dtype is torch.float4_e2m1fn_x2:
-        assert data_mx._data.shape == (*prev_dims, K // 2)
+        assert data_mx.qdata.shape == (*prev_dims, K // 2)
     else:
-        assert data_mx._data.shape == (*prev_dims, K)
+        assert data_mx.qdata.shape == (*prev_dims, K)
     assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
 
 
@@ -148,8 +148,8 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    assert torch.isnan(data_mx._data[0])
-    assert torch.all(data_mx._data[1:] == 0)
+    assert torch.isnan(data_mx.qdata[0])
+    assert torch.all(data_mx.qdata[1:] == 0)
     # fp32 denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -170,7 +170,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -191,7 +191,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 some denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -222,7 +222,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 some denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -253,7 +253,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # zero
     data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
     ground_truth_scale = torch.tensor([0], dtype=torch.uint8).view(torch.float8_e8m0fnu)
@@ -264,7 +264,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 normal
     # fmt: off
     data_hp = torch.tensor(
@@ -295,7 +295,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 normal
     # fmt: off
     data_hp = torch.tensor(
@@ -326,7 +326,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -382,8 +382,8 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0,
         data_bits,
+        scale_e8m0,
         elem_dtype,
         block_size,
         torch.float,
@@ -473,7 +473,7 @@ def test_fp6_packing(elem_dtype, pack_fp6):
     else:
         expected_packed_shape = x.shape
 
-    assert x_mx._data.shape == expected_packed_shape
+    assert x_mx.qdata.shape == expected_packed_shape
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -505,14 +505,14 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         atol=0,
         rtol=0,
     )
-    torch.testing.assert_close(x_mx._data, x_mx_c._data, atol=0, rtol=0)
+    torch.testing.assert_close(x_mx.qdata, x_mx_c.qdata, atol=0, rtol=0)
 
     to_dtype_c = torch.compile(to_dtype, fullgraph=True)
 
     use_fp4_custom_triton_dequant_kernel = False
     pack_fp6 = False
     x_mx_dq = to_dtype(
-        x_mx._data,
+        x_mx.qdata,
         x_mx._scale_e8m0,
         x_mx._elem_dtype,
         x_mx._block_size,
@@ -521,7 +521,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         pack_fp6,
    )
     x_mx_c_dq = to_dtype_c(
-        x_mx_c._data,
+        x_mx_c.qdata,
         x_mx_c._scale_e8m0,
         x_mx_c._elem_dtype,
         x_mx_c._block_size,
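
One detail the shape checks at the top of this file encode: fp4 packs two elements per byte, so qdata's last dimension is half the logical width while the scale still has one entry per block. A sketch (CUDA device and fp4 support assumed):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor  # path assumed

x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float4_e2m1fn_x2, 32)

# Packed payload: K // 2 bytes per row; scale: K // block_size per row.
assert x_mx.qdata.shape == (32, 32)
assert x_mx._scale_e8m0.shape == (32, 2)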

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 2 additions & 2 deletions

@@ -60,8 +60,8 @@ def _to_mxfp8_dim1_kernel_wrapper(
         a_data_local = a_data.to_local()
         a_scale_local = a_scale.to_local()
         inner = MXTensor(
-            a_scale_local,
             a_data_local.t(),
+            a_scale_local,
             elem_dtype,
             block_size,
             hp_dtype,
@@ -79,8 +79,8 @@ def _to_mxfp8_dim1_kernel_wrapper(
         )
     else:
         mx_tensor = MXTensor(
-            a_scale,
             a_data.t(),
+            a_scale,
             elem_dtype,
             block_size,
             hp_dtype,
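
Both call sites here flip the first two constructor arguments: payload first, scale second. A sketch of rebuilding an MXTensor from an existing one under the new order; the three trailing arguments mirror private attributes used elsewhere in mx_formats and are an assumption about the full signature, not confirmed by this diff:

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor  # path assumed

x_mx = MXTensor.to_mx(
    torch.randn(32, 64, dtype=torch.bfloat16), torch.float8_e4m3fn, 32
)
y_mx = MXTensor(
    x_mx.qdata,        # quantized payload now comes first (was second)
    x_mx._scale_e8m0,  # scale now comes second (was first)
    x_mx._elem_dtype,
    x_mx._block_size,
    x_mx._orig_dtype,
    x_mx._use_fp4_custom_triton_dequant_kernel,  # assumed attribute
    x_mx._gemm_kernel_choice,                    # assumed attribute
    x_mx._pack_fp6,                              # assumed attribute
)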

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 12 additions & 12 deletions

@@ -91,8 +91,8 @@ def _addmm_mx_dispatch(
     if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
-        assert a._data.is_contiguous()
-        assert b._data.t().is_contiguous()
+        assert a.qdata.is_contiguous()
+        assert b.qdata.t().is_contiguous()
         assert a._block_size == 32, f"Invalid block size {a._block_size}"
         assert b._block_size == 32, f"Invalid block size {b._block_size}"
 
@@ -108,8 +108,8 @@ def _addmm_mx_dispatch(
         )
 
         res = torch._scaled_mm(
-            a._data,
-            b._data,
+            a.qdata,
+            b.qdata,
             a_scale_block.view(torch.float8_e8m0fnu),
             b_scale_block.view(torch.float8_e8m0fnu),
             bias=bias,
@@ -121,7 +121,7 @@ def _addmm_mx_dispatch(
         assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
         # FP4 operations
         res = torchao.ops.mx_fp4_bf16(
-            a._data, b._data, a_scale_block, b_scale_block
+            a.qdata, b.qdata, a_scale_block, b_scale_block
         )
         # TODO add optional bias to kernel
         if bias is not None:
@@ -171,8 +171,8 @@ def mx_t(func, types, args, kwargs):
     # For now, only transpose(input, 0, 1) is supported.
     old = args[0]
     new = MXTensor(
+        old.qdata.t(),
         old._scale_e8m0,
-        old._data.t(),
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
@@ -205,7 +205,7 @@ def unwrap(x):
 
 @implements([aten.view.default])
 def mx_view_op(func, types, args, kwargs):
-    data = args[0]._data
+    data = args[0].qdata
     new_size = args[1]
     if args[0]._elem_dtype == torch.float4_e2m1fn_x2:
         # special case fp4 as we pack two elements per byte
@@ -215,8 +215,8 @@ def mx_view_op(func, types, args, kwargs):
         new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
     new_data = func(data, new_size, *args[2:], **kwargs)
     return MXTensor(
-        args[0]._scale_e8m0,
         new_data,
+        args[0]._scale_e8m0,
         args[0]._elem_dtype,
         args[0]._block_size,
         args[0]._orig_dtype,
@@ -241,7 +241,7 @@ def mx_slice(func, types, args, kwargs):
     if dim == 0:
         # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now
         sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step)
-        sliced_data = aten.slice.Tensor(x._data, dim, start, end, step).unsqueeze(-1)
+        sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step).unsqueeze(-1)
     elif dim == 1:
         # Slicing along reduciton dim
         if start is not None:
@@ -256,7 +256,7 @@ def mx_slice(func, types, args, kwargs):
                 f"End index {end} must be a multiple of block_size {x._block_size}"
             )
 
-        sliced_data = aten.slice.Tensor(x._data, dim, start, end, step)
+        sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step)
 
         # Calculate which scale elements to keep
        start_block = 0 if start is None else start // x._block_size
@@ -276,8 +276,8 @@ def mx_slice(func, types, args, kwargs):
         args,
         kwargs,
         MXTensor(
-            sliced_scale,
             sliced_data,
+            sliced_scale,
             x._elem_dtype,
             x._block_size,
             x._orig_dtype,
@@ -330,8 +330,8 @@ def autocast_to_copy(func, types, args, kwargs):
     # If dtype is specified, create a new MXTensor with the requested dtype
     if dtype is not None:
         res = MXTensor(
+            tensor.qdata,
             tensor._scale_e8m0,
-            tensor._data,
             tensor._elem_dtype,
             tensor._block_size,
             dtype,
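
Since mx_t now builds the transposed result from old.qdata.t(), a plain .t() on an MXTensor transposes the packed payload directly, with no dequantize/requantize round trip. A quick sketch (import path assumed):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_mx = MXTensor.to_mx(
    torch.randn(32, 64, dtype=torch.bfloat16), torch.float8_e4m3fn, 32
)

# aten.t.default dispatches to mx_t above.
x_t = x_mx.t()
assert x_t.qdata.shape == (64, 32)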
