@@ -367,7 +367,7 @@ def calculateTensorsSize(*args):
367367 "--token" ,
368368 type = int ,
369369 nargs = "*" ,
370- default = None ,
370+ default = [ 128 ] ,
371371 help = """Token Num.
372372 e.g.: -m 128""" ,
373373)
@@ -376,7 +376,7 @@ def calculateTensorsSize(*args):
376376 "--hidden_dim" ,
377377 type = int ,
378378 nargs = "*" ,
379- default = None ,
379+ default = [ 4096 ] ,
380380 help = """Hidden states dim.
381381 e.g.: -hd 4096""" ,
382382)
@@ -385,7 +385,7 @@ def calculateTensorsSize(*args):
385385 "--inter_dim" ,
386386 type = int ,
387387 nargs = "*" ,
388- default = None ,
388+ default = [ 1024 ] ,
389389 help = """Intermediate dim.
390390 e.g.: -id 1024""" ,
391391)
@@ -410,7 +410,7 @@ def calculateTensorsSize(*args):
410410parser .add_argument (
411411 "-a" ,
412412 "--activation" ,
413- type = str ,
413+ type = dtypes . str2ActivationType ,
414414 choices = [
415415 "silu" ,
416416 "gelu" ,
@@ -424,21 +424,16 @@ def calculateTensorsSize(*args):
424424
425425args = parser .parse_args ()
426426
427- args .activation = {"gelu" : ActivationType .Gelu , "silu" : ActivationType .Silu }[
428- args .activation
429- ]
430427
431428for test in args .test :
432429 print (f"\n Running test: { test } " )
433430 if test == "test_fmoe_16_bit" :
434431 print ("test test_fmoe 16 bit" )
435432 print ("\n g1u0 no quant" )
436433 for dtype in args .dtype :
437- for m in [128 , 256 ] if args .token is None else args .token :
438- for hdim in (
439- [4096 , 8192 ] if args .hidden_dim is None else args .hidden_dim
440- ):
441- for idim in [1024 ] if args .inter_dim is None else args .inter_dim :
434+ for m in args .token :
435+ for hdim in args .hidden_dim :
436+ for idim in args .inter_dim :
442437 expert = 32 if args .expert is None else args .expert
443438 topk = 5 if args .topk is None else args .topk
444439 test_fmoe (
@@ -453,11 +448,9 @@ def calculateTensorsSize(*args):
453448 )
454449 elif test == "g1u1_no_quant" :
455450 for dtype in args .dtype :
456- for m in [128 , 256 ] if args .token is None else args .token :
457- for hdim in (
458- [4096 , 8192 ] if args .hidden_dim is None else args .hidden_dim
459- ):
460- for idim in [1024 ] if args .inter_dim is None else args .inter_dim :
451+ for m in args .token :
452+ for hdim in args .hidden_dim :
453+ for idim in args .inter_dim :
461454 expert = 32 if args .expert is None else args .expert
462455 topk = 5 if args .topk is None else args .topk
463456 test_fmoe (
@@ -473,11 +466,9 @@ def calculateTensorsSize(*args):
473466 )
474467 elif test == "g1u1_int8quant" :
475468 for dtype in args .dtype :
476- for m in [128 , 256 ] if args .token is None else args .token :
477- for hdim in (
478- [4096 , 8192 ] if args .hidden_dim is None else args .hidden_dim
479- ):
480- for idim in [1024 ] if args .inter_dim is None else args .inter_dim :
469+ for m in args .token :
470+ for hdim in args .hidden_dim :
471+ for idim in args .inter_dim :
481472 expert = 32 if args .expert is None else args .expert
482473 topk = 5 if args .topk is None else args .topk
483474 test_fmoe (
@@ -495,11 +486,9 @@ def calculateTensorsSize(*args):
495486
496487 elif test == "g1u1_fp8quant" :
497488 for dtype in args .dtype :
498- for m in [128 , 256 ] if args .token is None else args .token :
499- for hdim in (
500- [4096 , 8192 ] if args .hidden_dim is None else args .hidden_dim
501- ):
502- for idim in [1024 ] if args .inter_dim is None else args .inter_dim :
489+ for m in args .token :
490+ for hdim in args .hidden_dim :
491+ for idim in args .inter_dim :
503492 expert = 32 if args .expert is None else args .expert
504493 topk = 5 if args .topk is None else args .topk
505494 test_fmoe (
@@ -518,13 +507,9 @@ def calculateTensorsSize(*args):
518507
519508 elif test == "g1u0_int8smoothquant" :
520509 for dtype in args .dtype :
521- for m in [128 ] if args .token is None else args .token :
522- for hdim in (
523- [4096 , 6144 , 8192 ] if args .hidden_dim is None else args .hidden_dim
524- ):
525- for idim in (
526- [512 , 1024 ] if args .inter_dim is None else args .inter_dim
527- ):
510+ for m in args .token :
511+ for hdim in args .hidden_dim :
512+ for idim in args .inter_dim :
528513 expert = 32 if args .expert is None else args .expert
529514 topk = 5 if args .topk is None else args .topk
530515 test_fmoe (
@@ -541,15 +526,9 @@ def calculateTensorsSize(*args):
541526
542527 elif test == "g1u1_int8smoothquant" :
543528 for dtype in args .dtype :
544- for m in [128 ] if args .token is None else args .token :
545- for hdim in (
546- [4096 , 6144 , 8192 ] if args .hidden_dim is None else args .hidden_dim
547- ):
548- for idim in (
549- [512 , 1024 , 1280 , 1536 ]
550- if args .inter_dim is None
551- else args .inter_dim
552- ):
529+ for m in args .token :
530+ for hdim in args .hidden_dim :
531+ for idim in args .inter_dim :
553532 expert = 32 if args .expert is None else args .expert
554533 topk = 5 if args .topk is None else args .topk
555534 test_fmoe (
@@ -566,13 +545,9 @@ def calculateTensorsSize(*args):
566545
567546 elif test == "g1u1_fp8smoothquant" :
568547 for dtype in args .dtype :
569- for m in [128 ] if args .token is None else args .token :
570- for hdim in (
571- [4096 , 6144 , 8192 ] if args .hidden_dim is None else args .hidden_dim
572- ):
573- for idim in (
574- [512 , 1024 , 1280 ] if args .inter_dim is None else args .inter_dim
575- ):
548+ for m in args .token :
549+ for hdim in args .hidden_dim :
550+ for idim in args .inter_dim :
576551 expert = 32 if args .expert is None else args .expert
577552 topk = 5 if args .topk is None else args .topk
578553 test_fmoe (
@@ -588,13 +563,9 @@ def calculateTensorsSize(*args):
588563 )
589564 elif test == "g1u1_int4" :
590565 for dtype in args .dtype :
591- for m in [32 , 128 ] if args .token is None else args .token :
592- for hdim in (
593- [4096 , 6144 ] if args .hidden_dim is None else args .hidden_dim
594- ):
595- for idim in (
596- [1024 , 4096 ] if args .inter_dim is None else args .inter_dim
597- ):
566+ for m in args .token :
567+ for hdim in args .hidden_dim :
568+ for idim in args .inter_dim :
598569 expert = 8 if args .expert is None else args .expert
599570 topk = 3 if args .topk is None else args .topk
600571 test_fmoe (
0 commit comments