Skip to content

Commit ce6182f

Browse files
committed
[CUTLASS] Fix AOT
1 parent 32cf9e7 commit ce6182f

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/runtime/contrib/cutlass/fp16_group_gemm.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <cuda_fp16.h>
2121
#include <float.h>
22+
#include <tvm/ffi/extra/c_env_api.h>
2223
#include <tvm/ffi/function.h>
2324
#include <tvm/runtime/ndarray.h>
2425

@@ -36,7 +37,8 @@ void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDAr
3637
NDArray out) {
3738
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
3839
// Recommended size is 4MB.
39-
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id));
40+
cudaStream_t stream =
41+
static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id));
4042
CHECK_EQ(x->ndim, 2);
4143
CHECK_EQ(weight->ndim, 3);
4244
CHECK_EQ(indptr->ndim, 1);
@@ -47,7 +49,6 @@ void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDAr
4749
int k = weight->shape[2];
4850
float alpha = 1.0f;
4951
float beta = 0.0f;
50-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
5152

5253
if (DataType(x->dtype) == DataType::Float(16)) {
5354
CHECK(DataType(weight->dtype) == DataType::Float(16));

src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <cuda_fp16.h>
2121
#include <float.h>
22+
#include <tvm/ffi/extra/c_env_api.h>
2223
#include <tvm/ffi/function.h>
2324
#include <tvm/ffi/reflection/registry.h>
2425
#include <tvm/runtime/ndarray.h>

0 commit comments

Comments
 (0)