
fixing torchao rocm ci test #2789


Merged: 3 commits, Aug 18, 2025

Changes from all commits
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.sparsity.sparse_api import apply_fake_sparsity
+from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
 )
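The hunk above pulls skip_if_rocm from torchao.testing.utils. As a minimal sketch of what such a helper can look like, assuming ROCm builds are detected via torch.version.hip (the actual torchao implementation may differ):

import unittest

import torch


def skip_if_rocm(message="skipped on ROCm"):
    # Hypothetical sketch: torch.version.hip is a version string on ROCm
    # builds of PyTorch and None on CUDA/CPU builds.
    is_rocm = torch.version.hip is not None
    # unittest.skipIf returns a decorator that works on both test methods
    # and TestCase classes.
    return unittest.skipIf(is_rocm, message)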
@@ -38,6 +39,7 @@ class TestInt4MarlinSparseTensor(TestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

+    @skip_if_rocm("ROCm enablement in progress")
Contributor:
Will it also fail in test_to_device? Maybe we need to add it there too? (Or move the skip to the class itself.)
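If the skip moved to the class as suggested, a single decorator would cover every test method, test_to_device included. A sketch under the assumption that skip_if_rocm wraps unittest.skipIf and therefore also applies at class level:

# Hypothetical class-level placement of the same skip; one decorator
# covers test_linear, test_to_device, and test_module_path alike.
@skip_if_rocm("ROCm enablement in progress")
class TestInt4MarlinSparseTensor(TestCase):
    ...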

     @parametrize("config", [BF16_ACT_CONFIG])
     @parametrize(
         "sizes",

@@ -65,6 +67,7 @@ def test_linear(self, config, sizes):
         quantized_and_compiled = compiled_linear(input)
         self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

+    @skip_if_rocm("ROCm enablement in progress")
     @unittest.skip("Fix later")
     @parametrize("config", [BF16_ACT_CONFIG])
     def test_to_device(self, config):

@@ -81,6 +84,7 @@ def test_to_device(self, config):
         quantize_(linear, config)
         linear.to(device)

+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("config", [BF16_ACT_CONFIG])
     def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
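For reference, the @parametrize decorators in this diff follow PyTorch's test parametrization pattern. A self-contained sketch, assuming the helpers from torch.testing._internal.common_utils that torchao tests commonly build on (ExampleParametrizedTest and its sizes are stand-ins, not code from this PR):

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleParametrizedTest(TestCase):
    # Each value in the list expands into its own generated test case.
    @parametrize("size", [128, 256])
    def test_linear_shape(self, size):
        linear = torch.nn.Linear(size, size, dtype=torch.bfloat16)
        self.assertEqual(linear.weight.shape, (size, size))


# Required so the @parametrize-decorated methods expand into real tests.
instantiate_parametrized_tests(ExampleParametrizedTest)

if __name__ == "__main__":
    run_tests()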