Commit 6f01bc8

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: 432877e

File tree: 2 files changed (+15, -14 lines)

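The changes below are characteristic of Black's "magic trailing comma" behavior (pre-commit.ci runs whatever hooks the repository's .pre-commit-config.yaml declares; that the formatter is Black is an assumption inferred from the reformatting style). When the last element of a bracketed construct carries a trailing comma, Black keeps the construct exploded one element per line; without one, anything that fits the line limit is collapsed onto a single line. A minimal sketch:

# Stub so this snippet runs standalone; illustrative only.
def gemm(a, b, accumulate=False):
    return (a, b, accumulate)

# No trailing comma: Black collapses a short call onto one line.
out = gemm(1, 2, accumulate=True)

# "Magic" trailing comma after the last argument: Black keeps the
# exploded one-argument-per-line form and preserves the comma.
out = gemm(
    1,
    2,
    accumulate=True,
)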

tests/pytorch/test_group_gemm.py

Lines changed: 14 additions & 13 deletions
@@ -63,9 +63,7 @@ def generate_inputs_outputs(self) -> Tuple[List[torch.Tensor], List[torch.Tensor
             ref_C_list.append(ref_C)
         return A_list, B_list, C_list, ref_C_list
 
-    def test_grouped_gemm(
-        self, atol=1e-2, rtol=1e-2, check_accuracy=True, check_performance=False
-    ):
+    def test_grouped_gemm(self, atol=1e-2, rtol=1e-2, check_accuracy=True, check_performance=False):
 
         WARM_ITERS = 10
         ITERS = 1000
@@ -91,11 +89,14 @@ def test_grouped_gemm(
             get_multi_stream_cublas_workspace(),
             layout=layout,
             m_splits=self.m_splits,
-            accumulate=self.accumulate
+            accumulate=self.accumulate,
         )
         torch.cuda.synchronize()
 
-        print(f'\n=== Accuracy Testing with Layout:{layout} GemmType:{os.getenv("NVTE_USE_CUTLASS_GROUPGEMM", "0")}')
+        print(
+            "\n=== Accuracy Testing with"
+            f" Layout:{layout} GemmType:{os.getenv('NVTE_USE_CUTLASS_GROUPGEMM', '0')}"
+        )
         if check_accuracy:
 
             alpha = 1.0
@@ -140,7 +141,7 @@ def test_grouped_gemm(
             get_multi_stream_cublas_workspace(),
             layout=layout,
             m_splits=self.m_splits,
-            accumulate=self.accumulate
+            accumulate=self.accumulate,
         )
 
         torch.cuda.synchronize()
@@ -154,7 +155,7 @@ def test_grouped_gemm(
                 get_multi_stream_cublas_workspace(),
                 layout=layout,
                 m_splits=self.m_splits,
-                accumulate=self.accumulate
+                accumulate=self.accumulate,
             )
         torch.cuda.synchronize()
         end_time = time.perf_counter()
@@ -193,12 +194,12 @@ def run_grouped_gemm(group_config, check_performance, transa, transb, accumulate
                 [4096, 768, 2048],
                 [4096, 768, 2048],
                 [4096, 768, 2048],
-                [4096, 768, 2048]
+                [4096, 768, 2048],
             ],
             "accumulate": False,
             "check_performance": True,
             "transa": False,
-            "transb": True
+            "transb": True,
         },
         {
             "group_config": [
@@ -217,16 +218,16 @@ def run_grouped_gemm(group_config, check_performance, transa, transb, accumulate
                 [2048, 768, 2048],
                 [2048, 768, 2048],
                 [2048, 768, 2048],
-                [2048, 768, 2048]
+                [2048, 768, 2048],
             ],
             "accumulate": False,
             "check_performance": True,
             "transa": False,
-            "transb": True
-        }
+            "transb": True,
+        },
     ]
 }
-
+
 for i, case in enumerate(config_data["configs"]):
     group_config = [tuple(x) for x in case["group_config"]]
     accumulate = case.get("accumulate", False)
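
For context on what the test above verifies: each entry of group_config describes one independent matmul, and accumulate controls whether results are added into the existing outputs. A minimal pure-PyTorch sketch of those semantics, not the real kernel (grouped_gemm_reference is a hypothetical helper, the (m, k, n) reading of the config triples is an assumption, and fp8/layout handling is omitted):

import torch
from typing import List

def grouped_gemm_reference(
    A_list: List[torch.Tensor],
    B_list: List[torch.Tensor],
    C_list: List[torch.Tensor],
    accumulate: bool = False,
) -> List[torch.Tensor]:
    # One independent matmul per group; optionally accumulate into C.
    return [C + A @ B if accumulate else A @ B for A, B, C in zip(A_list, B_list, C_list)]

# Shapes mirror the [4096, 768, 2048] triples in the first config,
# read as (m, k, n) -- an assumption about the ordering.
A = [torch.randn(4096, 768) for _ in range(4)]
B = [torch.randn(768, 2048) for _ in range(4)]
C = [torch.zeros(4096, 2048) for _ in range(4)]
ref_C_list = grouped_gemm_reference(A, B, C, accumulate=False)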

transformer_engine/pytorch/cpp_extensions/gemm.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def general_grouped_gemm(
     use_bias: bool = False,
     use_split_accumulator: bool = False,
     D_dtype: Optional[tex.DType] = None,
-    single_output=False
+    single_output=False,
 ) -> Tuple[List[torch.Tensor], ...]:
     """
     TN layout Grouped GEMM with fp8 inputs.
