11import pytest
22import torch
3+ from torch import Tensor
4+ from contextlib import ExitStack
5+
6+ from pytorch_sparse_utils .misc import prod , unpack_sparse_tensors , _pytorch_atleast_2_5
7+ from pytorch_sparse_utils .validation import (
8+ validate_atleast_nd ,
9+ validate_dim_size ,
10+ validate_nd ,
11+ )
12+ from . import random_sparse_tensor
313
4- from pytorch_sparse_utils .validation import validate_nd , validate_dim_size , validate_atleast_nd
514
615@pytest .mark .cpu_and_cuda
716class TestValidate :
@@ -10,7 +19,7 @@ def test_validate_nd(self, device):
1019 validate_nd (tensor , 3 )
1120 with pytest .raises (
1221 (ValueError , torch .jit .Error ), # pyright: ignore[reportArgumentType]
13- match = "Expected tensor to be 4D"
22+ match = "Expected tensor to be 4D" ,
1423 ):
1524 validate_nd (tensor , 4 )
1625
@@ -19,7 +28,7 @@ def test_validate_at_least_nd(self, device):
1928 validate_atleast_nd (tensor , 3 )
2029 with pytest .raises (
2130 (ValueError , torch .jit .Error ), # pyright: ignore[reportArgumentType]
22- match = "Expected tensor to have at least"
31+ match = "Expected tensor to have at least" ,
2332 ):
2433 validate_atleast_nd (tensor , 4 )
2534
@@ -28,6 +37,43 @@ def test_validate_dim_size(self, device):
2837 validate_dim_size (tensor , dim = 0 , expected_size = 3 )
2938 with pytest .raises (
3039 (ValueError , torch .jit .Error ), # pyright: ignore[reportArgumentType]
31- match = r"Expected tensor to have shape\[0\]=4"
40+ match = r"Expected tensor to have shape\[0\]=4" ,
3241 ):
3342 validate_dim_size (tensor , dim = 0 , expected_size = 4 )
43+
44+
45+ @pytest .mark .cpu_and_cuda
46+ def test_prod (device ):
47+ test_list = [1 , 2 , 3 ]
48+ result_list = prod (test_list )
49+ assert result_list == 6
50+ assert isinstance (result_list , int )
51+
52+ test_tensor = torch .tensor (test_list , device = device )
53+ result_tensor = prod (test_tensor )
54+ assert isinstance (result_tensor , Tensor )
55+ assert result_tensor == 6
56+
57+
58+ @pytest .mark .cpu_and_cuda
59+ def test_unpack_sparse_tensors (device ):
60+ sparse_tensor = random_sparse_tensor (
61+ [8 , 16 , 16 ], [4 , 32 ], 0.5 , seed = 0 , device = device
62+ ).coalesce ()
63+
64+ batch_dict = {
65+ "X_indices" : sparse_tensor .indices (),
66+ "X_values" : sparse_tensor .values (),
67+ "X_shape" : torch .tensor (sparse_tensor .shape , device = device ),
68+ }
69+
70+ if _pytorch_atleast_2_5 :
71+ with pytest .warns (DeprecationWarning , match = "is no longer needed" ):
72+ unpacked = unpack_sparse_tensors (batch_dict )
73+ else :
74+ unpacked = unpack_sparse_tensors (batch_dict )
75+
76+ X = unpacked ["X" ]
77+ assert torch .equal (X .indices (), sparse_tensor .indices ())
78+ assert torch .equal (X .values (), sparse_tensor .values ())
79+ assert X .shape == sparse_tensor .shape
0 commit comments