Skip to content

Commit b42385d

Browse files
author
Min Yang
committed
feat: add cutlass group gemm support
1 parent cae1c43 commit b42385d

File tree

15 files changed

+1380
-24
lines changed

15 files changed

+1380
-24
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
[submodule "3rdparty/cudnn-frontend"]
55
path = 3rdparty/cudnn-frontend
66
url = https://github.com/NVIDIA/cudnn-frontend.git
7+
[submodule "transformer_engine/common/gemm/cutlass"]
8+
path = transformer_engine/common/gemm/cutlass
9+
url = https://github.com/NVIDIA/cutlass.git

configs/tests_new_nt_cutlass.json

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
{
2+
"configs": [
3+
{
4+
"group_config": [
5+
[
6+
4096,
7+
768,
8+
2048
9+
],
10+
[
11+
4096,
12+
768,
13+
2048
14+
],
15+
[
16+
4096,
17+
768,
18+
2048
19+
],
20+
[
21+
4096,
22+
768,
23+
2048
24+
],
25+
[
26+
4096,
27+
768,
28+
2048
29+
],
30+
[
31+
4096,
32+
768,
33+
2048
34+
],
35+
[
36+
4096,
37+
768,
38+
2048
39+
],
40+
[
41+
4096,
42+
768,
43+
2048
44+
]
45+
],
46+
"gemm_type": "cutlass",
47+
"check_performance": true,
48+
"transa": true,
49+
"transb": false
50+
},
51+
{
52+
"group_config": [
53+
[
54+
2048,
55+
768,
56+
2048
57+
],
58+
[
59+
2048,
60+
768,
61+
2048
62+
],
63+
[
64+
2048,
65+
768,
66+
2048
67+
],
68+
[
69+
2048,
70+
768,
71+
2048
72+
],
73+
[
74+
2048,
75+
768,
76+
2048
77+
],
78+
[
79+
2048,
80+
768,
81+
2048
82+
],
83+
[
84+
2048,
85+
768,
86+
2048
87+
],
88+
[
89+
2048,
90+
768,
91+
2048
92+
],
93+
[
94+
2048,
95+
768,
96+
2048
97+
],
98+
[
99+
2048,
100+
768,
101+
2048
102+
],
103+
[
104+
2048,
105+
768,
106+
2048
107+
],
108+
[
109+
2048,
110+
768,
111+
2048
112+
],
113+
[
114+
2048,
115+
768,
116+
2048
117+
],
118+
[
119+
2048,
120+
768,
121+
2048
122+
],
123+
[
124+
2048,
125+
768,
126+
2048
127+
],
128+
[
129+
2048,
130+
768,
131+
2048
132+
]
133+
],
134+
"gemm_type": "cutlass",
135+
"check_performance": true,
136+
"transa": true,
137+
"transb": false
138+
}
139+
]
140+
}

configs/tests_new_nt_te.json

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
{
2+
"configs": [
3+
{
4+
"group_config": [
5+
[
6+
4096,
7+
768,
8+
2048
9+
],
10+
[
11+
4096,
12+
768,
13+
2048
14+
],
15+
[
16+
4096,
17+
768,
18+
2048
19+
],
20+
[
21+
4096,
22+
768,
23+
2048
24+
],
25+
[
26+
4096,
27+
768,
28+
2048
29+
],
30+
[
31+
4096,
32+
768,
33+
2048
34+
],
35+
[
36+
4096,
37+
768,
38+
2048
39+
],
40+
[
41+
4096,
42+
768,
43+
2048
44+
]
45+
],
46+
"gemm_type": "te",
47+
"check_performance": true,
48+
"transa": true,
49+
"transb": false
50+
},
51+
{
52+
"group_config": [
53+
[
54+
2048,
55+
768,
56+
2048
57+
],
58+
[
59+
2048,
60+
768,
61+
2048
62+
],
63+
[
64+
2048,
65+
768,
66+
2048
67+
],
68+
[
69+
2048,
70+
768,
71+
2048
72+
],
73+
[
74+
2048,
75+
768,
76+
2048
77+
],
78+
[
79+
2048,
80+
768,
81+
2048
82+
],
83+
[
84+
2048,
85+
768,
86+
2048
87+
],
88+
[
89+
2048,
90+
768,
91+
2048
92+
],
93+
[
94+
2048,
95+
768,
96+
2048
97+
],
98+
[
99+
2048,
100+
768,
101+
2048
102+
],
103+
[
104+
2048,
105+
768,
106+
2048
107+
],
108+
[
109+
2048,
110+
768,
111+
2048
112+
],
113+
[
114+
2048,
115+
768,
116+
2048
117+
],
118+
[
119+
2048,
120+
768,
121+
2048
122+
],
123+
[
124+
2048,
125+
768,
126+
2048
127+
],
128+
[
129+
2048,
130+
768,
131+
2048
132+
]
133+
],
134+
"gemm_type": "te",
135+
"check_performance": true,
136+
"transa": true,
137+
"transb": false
138+
}
139+
]
140+
}

0 commit comments

Comments
 (0)