We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent eff088b commit f406b88Copy full SHA for f406b88
advanced_source/semi_structured_sparse.py
@@ -54,6 +54,8 @@
54
from torch.utils.benchmark import Timer
55
SparseSemiStructuredTensor._FORCE_CUTLASS = True
56
57
+torch.set_default_device("cuda:0")
58
+
59
# mask Linear weight to be 2:4 sparse
60
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
61
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
0 commit comments