@@ -30,11 +30,10 @@ class GatherDescriptor(Structure):
3030
3131infiniopGatherDescriptor_t = POINTER (GatherDescriptor )
3232
33- def gather (input , indices , axis ):
34- np_input = input .numpy ()
35- np_indices = indices .numpy ()
36- np_output = np .take (np_input , np_indices , axis = axis )
37- return torch .from_numpy (np_output )
def gather(x, indices, axis=0):
    """Reference gather: select entries of ``x`` along ``axis`` with ``indices``.

    Equivalent to ``numpy.take(x, indices, axis=axis)``: the ``axis``
    dimension of ``x`` is replaced by the dimensions of ``indices``.
    Negative ``axis`` values count from the last dimension, exactly as
    Python list indexing does.
    """
    # Full-slice every dimension, then swap in the index tensor on `axis`;
    # advanced indexing with the resulting tuple performs the gather.
    selector = [slice(None) for _ in range(x.ndim)]
    selector[axis] = indices
    return x[tuple(selector)]
3837
3938def tuple_to_void_p (py_tuple : Tuple ):
4039 array = ctypes .c_int64 * len (py_tuple )
@@ -55,16 +54,19 @@ def test(
5554 tensor_dtype = torch .float16
5655):
5756 print (
58- f"Testing clip on { torch_device } with x_shape:{ x_shape } dtype:{ tensor_dtype } "
57+ f"Testing gather on { torch_device } with x_shape:{ x_shape } dtype:{ tensor_dtype } "
5958 )
6059 x = torch .randn (x_shape , dtype = tensor_dtype , device = torch_device )
61- if len (x .shape ) == 2 :
62- indices = torch .tensor (2 , dtype = torch .int64 , device = torch_device )
63- elif len (x .shape ) == 3 :
64- indices = torch .tensor ([[0 , 1 ], [1 , 2 ]], dtype = torch .int64 , device = torch_device )
60+ if isinstance (indices_shape , int ):
61+ indices_shape_tuple = (indices_shape ,)
62+ else :
63+ indices_shape_tuple = tuple (indices_shape )
64+ indices = torch .randint (0 , x .shape [axis ], indices_shape_tuple ,
65+ device = torch_device ).type (torch .int64 )
6566 dst = torch .randn (inferShape (x_shape , indices .shape , axis ), dtype = tensor_dtype , device = torch_device )
67+
6668 ans = gather (x , indices , axis )
67- axis = axis
69+
6870 x_tensor = to_tensor (x , lib )
6971 indices_tensor = to_tensor (indices , lib )
7072 dst_tensor = to_tensor (dst , lib )
@@ -106,25 +108,35 @@ def test(
106108 )
107109 elapsed = (time .time () - start_time ) / NUM_ITERATIONS
108110 print (f"lib time: { elapsed :10f} " )
109- print (f"pytorch ans: { ans } " )
110- print (f"lib ans: { dst } " )
111+ ans = ans .to (torch_device )
111112 assert torch .allclose (dst , ans , atol = 0 , rtol = 0 )
112113 check_error (lib .infiniopDestroyGatherDescriptor (descriptor ))
113114
def test_cpu(lib, test_cases):
    """Run every gather test case against the CPU backend of `lib`."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CPU)
    for case in test_cases:
        x_shape, indices_shape, axis, dtype = case
        test(lib, handle, "cpu", x_shape, indices_shape, axis, tensor_dtype=dtype)
    destroy_handle(lib, handle)
121+
def test_cuda(lib, test_cases):
    """Run every gather test case against the CUDA backend of `lib`."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CUDA)
    for case in test_cases:
        x_shape, indices_shape, axis, dtype = case
        test(lib, handle, "cuda", x_shape, indices_shape, axis, tensor_dtype=dtype)
    destroy_handle(lib, handle)
122128
123129
124130if __name__ == "__main__" :
125131 test_cases = [
126- ((3 , 4 ), (2 ), 0 ),
127- ((2 , 3 , 4 ), (2 , 2 ), 1 ),
132+ ((3 , 4 ), (2 ), 0 , torch .float32 ),
133+ ((64 , 64 ), (64 , 64 ), 0 , torch .float32 ),
134+ ((64 , 64 ), (64 , 64 ), 1 , torch .float32 ),
135+ ((2 , 3 , 4 ), (2 , 2 ), 1 , torch .float32 ),
136+ ((64 , 64 ), (64 , 64 ), 0 , torch .float16 ),
137+ ((64 , 64 ), (64 , 64 ), 1 , torch .float16 ),
138+ ((8 , 8 , 8 , 8 , 8 ), (8 , 8 ), 0 , torch .float16 ),
139+ ((8 , 8 , 8 , 8 , 8 ), (8 , 8 ), 2 , torch .float16 ),
128140 ]
129141 args = get_args ()
130142 lib = open_lib ()
@@ -144,5 +156,8 @@ def test_cpu(lib, test_cases):
144156 ]
145157 lib .infiniopDestroyGatherDescriptor .restype = c_int32
146158 lib .infiniopDestroyGatherDescriptor .argtypes = [infiniopGatherDescriptor_t ]
147- test_cpu (lib , test_cases )
159+ if args .cuda :
160+ test_cuda (lib , test_cases )
161+ if args .cpu :
162+ test_cpu (lib , test_cases )
148163 print ("All tests passed!" )
0 commit comments