Skip to content

Commit f3be04a

Browse files
authored
update_test (#2156)
* update_test * update * update
1 parent 00291e1 commit f3be04a

File tree

3 files changed

+55
-119
lines changed

3 files changed

+55
-119
lines changed

op_tests/test_moe.py

Lines changed: 28 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def calculateTensorsSize(*args):
367367
"--token",
368368
type=int,
369369
nargs="*",
370-
default=None,
370+
default=[128],
371371
help="""Token Num.
372372
e.g.: -m 128""",
373373
)
@@ -376,7 +376,7 @@ def calculateTensorsSize(*args):
376376
"--hidden_dim",
377377
type=int,
378378
nargs="*",
379-
default=None,
379+
default=[4096],
380380
help="""Hidden states dim.
381381
e.g.: -hd 4096""",
382382
)
@@ -385,7 +385,7 @@ def calculateTensorsSize(*args):
385385
"--inter_dim",
386386
type=int,
387387
nargs="*",
388-
default=None,
388+
default=[1024],
389389
help="""Intermediate dim.
390390
e.g.: -id 1024""",
391391
)
@@ -410,7 +410,7 @@ def calculateTensorsSize(*args):
410410
parser.add_argument(
411411
"-a",
412412
"--activation",
413-
type=str,
413+
type=dtypes.str2ActivationType,
414414
choices=[
415415
"silu",
416416
"gelu",
@@ -424,21 +424,16 @@ def calculateTensorsSize(*args):
424424

425425
args = parser.parse_args()
426426

427-
args.activation = {"gelu": ActivationType.Gelu, "silu": ActivationType.Silu}[
428-
args.activation
429-
]
430427

431428
for test in args.test:
432429
print(f"\nRunning test: {test}")
433430
if test == "test_fmoe_16_bit":
434431
print("test test_fmoe 16 bit")
435432
print("\ng1u0 no quant")
436433
for dtype in args.dtype:
437-
for m in [128, 256] if args.token is None else args.token:
438-
for hdim in (
439-
[4096, 8192] if args.hidden_dim is None else args.hidden_dim
440-
):
441-
for idim in [1024] if args.inter_dim is None else args.inter_dim:
434+
for m in args.token:
435+
for hdim in args.hidden_dim:
436+
for idim in args.inter_dim:
442437
expert = 32 if args.expert is None else args.expert
443438
topk = 5 if args.topk is None else args.topk
444439
test_fmoe(
@@ -453,11 +448,9 @@ def calculateTensorsSize(*args):
453448
)
454449
elif test == "g1u1_no_quant":
455450
for dtype in args.dtype:
456-
for m in [128, 256] if args.token is None else args.token:
457-
for hdim in (
458-
[4096, 8192] if args.hidden_dim is None else args.hidden_dim
459-
):
460-
for idim in [1024] if args.inter_dim is None else args.inter_dim:
451+
for m in args.token:
452+
for hdim in args.hidden_dim:
453+
for idim in args.inter_dim:
461454
expert = 32 if args.expert is None else args.expert
462455
topk = 5 if args.topk is None else args.topk
463456
test_fmoe(
@@ -473,11 +466,9 @@ def calculateTensorsSize(*args):
473466
)
474467
elif test == "g1u1_int8quant":
475468
for dtype in args.dtype:
476-
for m in [128, 256] if args.token is None else args.token:
477-
for hdim in (
478-
[4096, 8192] if args.hidden_dim is None else args.hidden_dim
479-
):
480-
for idim in [1024] if args.inter_dim is None else args.inter_dim:
469+
for m in args.token:
470+
for hdim in args.hidden_dim:
471+
for idim in args.inter_dim:
481472
expert = 32 if args.expert is None else args.expert
482473
topk = 5 if args.topk is None else args.topk
483474
test_fmoe(
@@ -495,11 +486,9 @@ def calculateTensorsSize(*args):
495486

496487
elif test == "g1u1_fp8quant":
497488
for dtype in args.dtype:
498-
for m in [128, 256] if args.token is None else args.token:
499-
for hdim in (
500-
[4096, 8192] if args.hidden_dim is None else args.hidden_dim
501-
):
502-
for idim in [1024] if args.inter_dim is None else args.inter_dim:
489+
for m in args.token:
490+
for hdim in args.hidden_dim:
491+
for idim in args.inter_dim:
503492
expert = 32 if args.expert is None else args.expert
504493
topk = 5 if args.topk is None else args.topk
505494
test_fmoe(
@@ -518,13 +507,9 @@ def calculateTensorsSize(*args):
518507

519508
elif test == "g1u0_int8smoothquant":
520509
for dtype in args.dtype:
521-
for m in [128] if args.token is None else args.token:
522-
for hdim in (
523-
[4096, 6144, 8192] if args.hidden_dim is None else args.hidden_dim
524-
):
525-
for idim in (
526-
[512, 1024] if args.inter_dim is None else args.inter_dim
527-
):
510+
for m in args.token:
511+
for hdim in args.hidden_dim:
512+
for idim in args.inter_dim:
528513
expert = 32 if args.expert is None else args.expert
529514
topk = 5 if args.topk is None else args.topk
530515
test_fmoe(
@@ -541,15 +526,9 @@ def calculateTensorsSize(*args):
541526

542527
elif test == "g1u1_int8smoothquant":
543528
for dtype in args.dtype:
544-
for m in [128] if args.token is None else args.token:
545-
for hdim in (
546-
[4096, 6144, 8192] if args.hidden_dim is None else args.hidden_dim
547-
):
548-
for idim in (
549-
[512, 1024, 1280, 1536]
550-
if args.inter_dim is None
551-
else args.inter_dim
552-
):
529+
for m in args.token:
530+
for hdim in args.hidden_dim:
531+
for idim in args.inter_dim:
553532
expert = 32 if args.expert is None else args.expert
554533
topk = 5 if args.topk is None else args.topk
555534
test_fmoe(
@@ -566,13 +545,9 @@ def calculateTensorsSize(*args):
566545

567546
elif test == "g1u1_fp8smoothquant":
568547
for dtype in args.dtype:
569-
for m in [128] if args.token is None else args.token:
570-
for hdim in (
571-
[4096, 6144, 8192] if args.hidden_dim is None else args.hidden_dim
572-
):
573-
for idim in (
574-
[512, 1024, 1280] if args.inter_dim is None else args.inter_dim
575-
):
548+
for m in args.token:
549+
for hdim in args.hidden_dim:
550+
for idim in args.inter_dim:
576551
expert = 32 if args.expert is None else args.expert
577552
topk = 5 if args.topk is None else args.topk
578553
test_fmoe(
@@ -588,13 +563,9 @@ def calculateTensorsSize(*args):
588563
)
589564
elif test == "g1u1_int4":
590565
for dtype in args.dtype:
591-
for m in [32, 128] if args.token is None else args.token:
592-
for hdim in (
593-
[4096, 6144] if args.hidden_dim is None else args.hidden_dim
594-
):
595-
for idim in (
596-
[1024, 4096] if args.inter_dim is None else args.inter_dim
597-
):
566+
for m in args.token:
567+
for hdim in args.hidden_dim:
568+
for idim in args.inter_dim:
598569
expert = 8 if args.expert is None else args.expert
599570
topk = 3 if args.topk is None else args.topk
600571
test_fmoe(

op_tests/test_pa.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -974,9 +974,6 @@ def test_paged_attention(
974974

975975

976976
df = []
977-
l_num_heads = [(4, 1), (8, 1), (32, 8)]
978-
l_ctx_len = [7, 26, 57, 66, 109, 128, 257, 282, 4097]
979-
l_dtype = ["fp16", "bf16"]
980977

981978
parser = argparse.ArgumentParser(
982979
formatter_class=argparse.RawTextHelpFormatter,
@@ -985,53 +982,35 @@ def test_paged_attention(
985982
parser.add_argument(
986983
"-d",
987984
"--dtype",
988-
type=str,
989-
choices=l_dtype,
990-
nargs="?",
991-
const=None,
992-
default=None,
985+
type=dtypes.str2Dtype,
986+
nargs="*",
987+
default=[dtypes.d_dtypes["fp16"], dtypes.d_dtypes["bf16"]],
993988
help="""Data type.
994989
e.g.: -d bf16""",
995990
)
996-
997991
parser.add_argument(
998992
"-n",
999993
"--num_heads",
1000994
type=dtypes.str2tuple,
1001-
choices=l_num_heads,
1002-
nargs="?",
1003-
const=None,
1004-
default=None,
995+
nargs="*",
996+
default=[(4, 1), (8, 1), (32, 8)],
1005997
help="""Number of heads (num_query_heads, num_kv_heads)
1006998
e.g.: -n 4,1""",
1007999
)
1008-
10091000
parser.add_argument(
10101001
"-c",
10111002
"--ctx_len",
10121003
type=int,
1013-
choices=l_ctx_len,
1014-
nargs="?",
1015-
const=None,
1016-
default=None,
1004+
nargs="*",
1005+
default=[7, 26, 57, 66, 109, 128, 257, 282, 4097],
10171006
help="""Context length.
10181007
e.g. -c 128""",
10191008
)
1020-
10211009
args = parser.parse_args()
1022-
if args.dtype is None:
1023-
l_dtype = [dtypes.d_dtypes[key] for key in l_dtype]
1024-
else:
1025-
l_dtype = [dtypes.d_dtypes[args.dtype]]
1026-
if args.num_heads is not None:
1027-
l_num_heads = [args.num_heads]
1028-
if args.ctx_len is not None:
1029-
l_ctx_len = [args.ctx_len]
1030-
1031-
1032-
for num_heads in l_num_heads:
1033-
for ctx_len in l_ctx_len:
1034-
for dtype in l_dtype:
1010+
1011+
for num_heads in args.num_heads:
1012+
for ctx_len in args.ctx_len:
1013+
for dtype in args.dtype:
10351014
ret = test_paged_attention(
10361015
ctx_len, 128, num_heads, 128, False, 16, dtype, "auto", 0, "cuda:0"
10371016
)

op_tests/test_quant.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from aiter import get_hip_quant, get_torch_quant, get_triton_quant
1313
import itertools
1414
import argparse
15+
import pandas as pd
1516

1617
torch.set_default_device("cuda")
1718

@@ -64,30 +65,16 @@ def test_quant(m, n, q_type, q_dtype, h_dtype):
6465
return ret
6566

6667

67-
d_quant = {
68-
"fp8_tensor": (aiter.QuantType.per_Tensor, dtypes.fp8),
69-
"fp8_token": (aiter.QuantType.per_Token, dtypes.fp8),
70-
"fp8_1x128": (aiter.QuantType.per_1x128, dtypes.fp8),
71-
"i8_token": (aiter.QuantType.per_Token, dtypes.i8),
72-
# 'fp4x2-1x32': (aiter.QuantType.per_1x32, dtypes.fp4x2),
73-
}
74-
list_dtype = ["fp16", "bf16"]
75-
l_n = [4096, 8192]
76-
l_m = [1, 2, 16, 32, 64, 128, 192, 256, 512, 1024, 16384, 163840]
77-
import pandas as pd
78-
7968
parser = argparse.ArgumentParser(
8069
formatter_class=argparse.RawTextHelpFormatter,
8170
description="config input of test",
8271
)
8372
parser.add_argument(
8473
"-d",
8574
"--dtype",
86-
type=str,
87-
choices=list_dtype,
88-
nargs="?",
89-
const=None,
90-
default=None,
75+
type=dtypes.str2Dtype,
76+
nargs="*",
77+
default=[dtypes.d_dtypes["fp16"], dtypes.d_dtypes["bf16"]],
9178
help="""Data type.
9279
e.g.: -d bf16""",
9380
)
@@ -96,7 +83,7 @@ def test_quant(m, n, q_type, q_dtype, h_dtype):
9683
"--n",
9784
type=int,
9885
nargs="*",
99-
default=None,
86+
default=[4096, 8192],
10087
help="""N of mnk.
10188
e.g.: -n 1024""",
10289
)
@@ -105,10 +92,17 @@ def test_quant(m, n, q_type, q_dtype, h_dtype):
10592
"--m",
10693
type=int,
10794
nargs="*",
108-
default=None,
95+
default=[1, 2, 16, 32, 64, 128, 192, 256, 512, 1024, 16384, 163840],
10996
help="""M of mnk.
11097
e.g.: -m 32""",
11198
)
99+
d_quant = {
100+
"fp8_tensor": (aiter.QuantType.per_Tensor, dtypes.fp8),
101+
"fp8_token": (aiter.QuantType.per_Token, dtypes.fp8),
102+
"fp8_1x128": (aiter.QuantType.per_1x128, dtypes.fp8),
103+
"i8_token": (aiter.QuantType.per_Token, dtypes.i8),
104+
# 'fp4x2-1x32': (aiter.QuantType.per_1x32, dtypes.fp4x2),
105+
}
112106
parser.add_argument(
113107
"-q",
114108
"--quant",
@@ -121,23 +115,15 @@ def test_quant(m, n, q_type, q_dtype, h_dtype):
121115
)
122116

123117
args = parser.parse_args()
124-
if args.dtype is None:
125-
list_dtype = [dtypes.d_dtypes[key] for key in list_dtype]
126-
else:
127-
list_dtype = [dtypes.d_dtypes[args.dtype]]
128118
list_quant = [d_quant[key] for key in args.quant]
129-
if args.n is not None:
130-
l_n = args.n
131-
if args.m is not None:
132-
l_m = args.m
133119

134120
for (
135121
(q_type, q_dtype),
136122
h_dtype,
137-
) in itertools.product(list_quant, list_dtype):
123+
) in itertools.product(list_quant, args.dtype):
138124
df = []
139-
for n in l_n:
140-
for m in l_m:
125+
for n in args.n:
126+
for m in args.m:
141127
ret = test_quant(m, n, q_type, q_dtype, h_dtype)
142128
df.append(ret)
143129
df = pd.DataFrame(df)

0 commit comments

Comments
 (0)