@@ -318,7 +318,7 @@ class Tensor4DPermuteBMM0213RowMajor : public PermuteBase {
318
318
LongIndex operator ()(MatrixCoord coord) const {
319
319
320
320
// The batch index for BMM
321
- Index BMM_batch_idx = blockIdx.z ;
321
+ auto BMM_batch_idx = Index ( blockIdx.z ) ;
322
322
323
323
// [i,j,k,l] -> [i,k,j,l]
324
324
Index l = coord.column ();
@@ -381,7 +381,7 @@ class Tensor4DPermuteBMM0213RowMajorInverse : public PermuteBase {
381
381
LongIndex operator ()(MatrixCoord coord) const {
382
382
383
383
// The batch index for BMM
384
- Index BMM_batch_idx = blockIdx.z ;
384
+ auto BMM_batch_idx = Index ( blockIdx.z ) ;
385
385
386
386
// The following assumes grouping [(D0)->batch, (D2)->row, (D1,D3)->col]
387
387
Index l = coord.column () % D3_;
@@ -453,7 +453,7 @@ class Tensor4DPermuteBMM0321ColumnMajor : public PermuteBase {
453
453
CUTLASS_HOST_DEVICE
454
454
LongIndex operator ()(MatrixCoord coord) const {
455
455
456
- Index BMM_batch_idx = blockIdx.z ;
456
+ auto BMM_batch_idx = Index ( blockIdx.z ) ;
457
457
458
458
// [i,j,k,l] -> [i,k,j,l]
459
459
Index l = coord.column ();
@@ -514,7 +514,7 @@ class Tensor4DPermuteBMM0321ColumnMajorInverse : public PermuteBase {
514
514
CUTLASS_HOST_DEVICE
515
515
LongIndex operator ()(MatrixCoord coord) const {
516
516
517
- Index BMM_batch_idx = blockIdx.z ;
517
+ auto BMM_batch_idx = Index ( blockIdx.z ) ;
518
518
519
519
// The following assumes grouping [(D0)->batch, (D1,D2)->row, (D3)->col]
520
520
Index l = coord.column ();
0 commit comments