@@ -73,9 +73,9 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     # verify that if data.shape is (M, K) then scale.shape is (M, K // block_size)
     prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
     if elem_dtype is torch.float4_e2m1fn_x2:
-        assert data_mx._data.shape == (*prev_dims, K // 2)
+        assert data_mx.qdata.shape == (*prev_dims, K // 2)
     else:
-        assert data_mx._data.shape == (*prev_dims, K)
+        assert data_mx.qdata.shape == (*prev_dims, K)
     assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
 
 
@@ -148,8 +148,8 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    assert torch.isnan(data_mx._data[0])
-    assert torch.all(data_mx._data[1:] == 0)
+    assert torch.isnan(data_mx.qdata[0])
+    assert torch.all(data_mx.qdata[1:] == 0)
     # fp32 denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -170,7 +170,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -191,7 +191,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 some denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -222,7 +222,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 some denorm
     # fmt: off
     data_hp = torch.tensor(
@@ -253,7 +253,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # zero
     data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
     ground_truth_scale = torch.tensor([0], dtype=torch.uint8).view(torch.float8_e8m0fnu)
@@ -264,7 +264,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 normal
     # fmt: off
     data_hp = torch.tensor(
@@ -295,7 +295,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 normal
     # fmt: off
     data_hp = torch.tensor(
@@ -326,7 +326,7 @@ def test_to_mx_rceil():
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
     torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
-    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -382,8 +382,8 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0,
         data_bits,
+        scale_e8m0,
         elem_dtype,
         block_size,
         torch.float,
@@ -473,7 +473,7 @@ def test_fp6_packing(elem_dtype, pack_fp6):
     else:
         expected_packed_shape = x.shape
 
-    assert x_mx._data.shape == expected_packed_shape
+    assert x_mx.qdata.shape == expected_packed_shape
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -505,14 +505,14 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         atol=0,
         rtol=0,
     )
-    torch.testing.assert_close(x_mx._data, x_mx_c._data, atol=0, rtol=0)
+    torch.testing.assert_close(x_mx.qdata, x_mx_c.qdata, atol=0, rtol=0)
 
     to_dtype_c = torch.compile(to_dtype, fullgraph=True)
 
     use_fp4_custom_triton_dequant_kernel = False
     pack_fp6 = False
     x_mx_dq = to_dtype(
-        x_mx._data,
+        x_mx.qdata,
         x_mx._scale_e8m0,
         x_mx._elem_dtype,
         x_mx._block_size,
@@ -521,7 +521,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
-        x_mx_c._data,
+        x_mx_c.qdata,
         x_mx_c._scale_e8m0,
         x_mx_c._elem_dtype,
         x_mx_c._block_size,