Skip to content

Commit 73dc0ca

Browse files
committed
add tests for minor functions
1 parent 6566b00 commit 73dc0ca

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

pytorch_sparse_utils/misc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import re
33
from typing import TYPE_CHECKING, overload
4+
import warnings
45

56
import torch
67
from torch import Tensor
@@ -95,9 +96,12 @@ def unpack_sparse_tensors(batch: dict[str, Tensor]) -> dict[str, Tensor]:
9596
sparse torch.Tensor format
9697
"""
9798
if _pytorch_atleast_2_5:
98-
raise DeprecationWarning(
99-
"`unpack_sparse_tensors` is no longer needed as of Pytorch 2.5",
100-
"which added native support for pinned_memory=True for sparse tensors",
99+
warnings.warn(
100+
(
101+
"`unpack_sparse_tensors` is no longer needed as of Pytorch 2.5"
102+
"which added native support for pinned_memory=True for sparse tensors"
103+
),
104+
DeprecationWarning,
101105
)
102106
prefixes_indices = [
103107
match[0]

tests/test_misc.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
import pytest
22
import torch
3+
from torch import Tensor
4+
from contextlib import ExitStack
5+
6+
from pytorch_sparse_utils.misc import prod, unpack_sparse_tensors, _pytorch_atleast_2_5
7+
from pytorch_sparse_utils.validation import (
8+
validate_atleast_nd,
9+
validate_dim_size,
10+
validate_nd,
11+
)
12+
from . import random_sparse_tensor
313

4-
from pytorch_sparse_utils.validation import validate_nd, validate_dim_size, validate_atleast_nd
514

615
@pytest.mark.cpu_and_cuda
716
class TestValidate:
@@ -10,7 +19,7 @@ def test_validate_nd(self, device):
1019
validate_nd(tensor, 3)
1120
with pytest.raises(
1221
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
13-
match="Expected tensor to be 4D"
22+
match="Expected tensor to be 4D",
1423
):
1524
validate_nd(tensor, 4)
1625

@@ -19,7 +28,7 @@ def test_validate_at_least_nd(self, device):
1928
validate_atleast_nd(tensor, 3)
2029
with pytest.raises(
2130
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
22-
match="Expected tensor to have at least"
31+
match="Expected tensor to have at least",
2332
):
2433
validate_atleast_nd(tensor, 4)
2534

@@ -28,6 +37,43 @@ def test_validate_dim_size(self, device):
2837
validate_dim_size(tensor, dim=0, expected_size=3)
2938
with pytest.raises(
3039
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
31-
match=r"Expected tensor to have shape\[0\]=4"
40+
match=r"Expected tensor to have shape\[0\]=4",
3241
):
3342
validate_dim_size(tensor, dim=0, expected_size=4)
43+
44+
45+
@pytest.mark.cpu_and_cuda
46+
def test_prod(device):
47+
test_list = [1, 2, 3]
48+
result_list = prod(test_list)
49+
assert result_list == 6
50+
assert isinstance(result_list, int)
51+
52+
test_tensor = torch.tensor(test_list, device=device)
53+
result_tensor = prod(test_tensor)
54+
assert isinstance(result_tensor, Tensor)
55+
assert result_tensor == 6
56+
57+
58+
@pytest.mark.cpu_and_cuda
59+
def test_unpack_sparse_tensors(device):
60+
sparse_tensor = random_sparse_tensor(
61+
[8, 16, 16], [4, 32], 0.5, seed=0, device=device
62+
).coalesce()
63+
64+
batch_dict = {
65+
"X_indices": sparse_tensor.indices(),
66+
"X_values": sparse_tensor.values(),
67+
"X_shape": torch.tensor(sparse_tensor.shape, device=device),
68+
}
69+
70+
if _pytorch_atleast_2_5:
71+
with pytest.warns(DeprecationWarning, match="is no longer needed"):
72+
unpacked = unpack_sparse_tensors(batch_dict)
73+
else:
74+
unpacked = unpack_sparse_tensors(batch_dict)
75+
76+
X = unpacked["X"]
77+
assert torch.equal(X.indices(), sparse_tensor.indices())
78+
assert torch.equal(X.values(), sparse_tensor.values())
79+
assert X.shape == sparse_tensor.shape

0 commit comments

Comments
 (0)