Skip to content

Commit 14147f6

Browse files
Test fix
1 parent 941681d commit 14147f6

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

tests/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
def get_available_devices(no_cpu=False):
2222
if "BNB_TEST_DEVICE" in os.environ:
2323
# If the environment variable is set, use it directly.
24-
return [d for d in os.environ["BNB_TEST_DEVICE"] if d.lower() != "cpu"]
24+
device = os.environ["BNB_TEST_DEVICE"]
25+
return [] if no_cpu and device == "cpu" else [device]
2526

2627
devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []
2728

tests/test_optim.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def rm_path(path):
170170
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
171171
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
172172
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
173+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
173174
def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
174175
if optim_name.startswith("paged_") and sys.platform == "win32":
175176
pytest.skip("Paged optimizers can have issues on Windows.")
@@ -250,6 +251,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
250251
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
251252
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
252253
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
254+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
253255
def test_global_config(dim1, dim2, gtype, device):
254256
if dim1 == 1 and dim2 == 1:
255257
return
@@ -306,6 +308,7 @@ def test_global_config(dim1, dim2, gtype, device):
306308
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
307309
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
308310
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
311+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
309312
def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
310313
torch.set_printoptions(precision=6)
311314

0 commit comments

Comments
 (0)