Skip to content

Commit f406b88

Browse files
committed
Fixes #3303
1 parent eff088b commit f406b88

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

advanced_source/semi_structured_sparse.py

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
from torch.utils.benchmark import Timer
5555
SparseSemiStructuredTensor._FORCE_CUTLASS = True
5656

57+
torch.set_default_device("cuda:0")
58+
5759
# mask Linear weight to be 2:4 sparse
5860
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
5961
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()

0 commit comments

Comments
 (0)