1- from ctypes import POINTER , Structure , c_int32 , c_void_p , c_uint64
1+ from ctypes import POINTER , Structure , c_int32 , c_void_p , c_uint64 , c_int64
22import ctypes
33import sys
44import os
55
6- # 调整路径以导入 operatorspy 模块
76sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." , ".." )))
87from operatorspy import (
98 open_lib ,
2322
2423class Inplace (Enum ):
2524 OUT_OF_PLACE = auto ()
26- # 对于 concat 算子,通常不支持 in-place 操作,因此这里只保留 OUT_OF_PLACE
27- # 你可以根据实际需求扩展其他选项
28- # INPLACE_A = auto()
29- # INPLACE_B = auto()
30-
3125
3226class ConcatDescriptor (Structure ):
3327 _fields_ = [("device" , c_int32 ),]
@@ -37,7 +31,6 @@ class ConcatDescriptor(Structure):
3731
3832
3933def concat_py (* tensors , dim = 0 ):
40- """使用 PyTorch 进行拼接的辅助函数"""
4134 return torch .cat (tensors , dim = dim )
4235
4336
@@ -58,72 +51,52 @@ def test(
5851 f"Testing Concat on { torch_device } with output_shape:{ c_shape } , input_shapes:{ input_shapes } , axis:{ axis } , dtype:{ tensor_dtype } , inplace: { inplace .name } "
5952 )
6053
61- # 创建输入张量
6254 inputs = [torch .rand (shape , dtype = tensor_dtype ).to (torch_device ) for shape in input_shapes ]
6355
6456 for idx , tensor in enumerate (inputs ):
6557 print (f"Input { idx } :" )
6658 print (tensor )
6759 print ("-" * 50 )
6860
69- # 创建输出张量
7061 if inplace == Inplace .OUT_OF_PLACE :
7162 c = torch .zeros (c_shape , dtype = tensor_dtype ).to (torch_device )
7263 else :
73- # 对于 concat,通常不支持 in-place 操作,因此这里简化为 OUT_OF_PLACE
7464 c = torch .zeros (c_shape , dtype = tensor_dtype ).to (torch_device )
7565
76- # 使用 PyTorch 进行拼接,作为参考答案
7766 ans = concat_py (* inputs , dim = axis )
78-
79- print ("ans:" ,ans )
80- print ("-" * 50 )
8167
82- # 将张量转换为 infiniop 所需的格式
8368 input_tensors = [to_tensor (t , lib ) for t in inputs ]
8469 c_tensor = to_tensor (c , lib ) if inplace == Inplace .OUT_OF_PLACE else to_tensor (c , lib )
8570
86- # 创建 Concat 描述符
8771 descriptor = infiniopConcatDescriptor_t ()
88-
89- # 准备输入描述符数组
72+
9073 num_inputs = len (input_tensors )
9174 input_desc_array_type = infiniopTensorDescriptor_t * num_inputs
9275 input_desc_array = input_desc_array_type (* [t .descriptor for t in input_tensors ])
9376
94- # 创建描述符
9577 check_error (
9678 lib .infiniopCreateConcatDescriptor (
9779 handle ,
9880 ctypes .byref (descriptor ),
99- c_tensor .descriptor , # 使用 c_tensor 的描述符
100- input_desc_array , # 输入张量描述符数组
81+ c_tensor .descriptor ,
82+ input_desc_array ,
10183 c_uint64 (num_inputs ),
102- c_uint64 (axis ),
84+ c_int64 (axis ),
10385 )
10486 )
10587
106- print ("c1:" ,c )
107- print ("-" * 50 )
108-
109- # 执行拼接操作
11088 input_data_ptrs = (c_void_p * num_inputs )(* [t .data for t in input_tensors ])
11189 check_error (
11290 lib .infiniopConcat (
11391 descriptor ,
11492 c_tensor .data ,
11593 ctypes .cast (input_data_ptrs , POINTER (c_void_p )),
116- None # 假设不需要流
94+ None
11795 )
11896 )
119-
120- print ("c2:" ,c )
121- print ("-" * 50 )
12297
123- # 验证结果
124- assert torch .allclose (c , ans , atol = 0 , rtol = 1e-5 ), "Concat result does not match PyTorch's result."
98+ assert torch .allclose (c , ans , atol = 0 , rtol = 0 ), "Concat result does not match PyTorch's result."
12599
126- # 销毁描述符
127100 check_error (lib .infiniopDestroyConcatDescriptor (descriptor ))
128101
129102
@@ -154,53 +127,85 @@ def test_bang(lib, test_cases):
154127
155128
156129if __name__ == "__main__" :
157- # 定义测试用例
130+
158131 test_cases = [
159- # (output_shape, axis, input_shapes, inplace)
160-
161- ((6 , 3 ), 0 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ),
162- # ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
163- # ((3, 7), 1, [(3, 2), (3, 4), (3,1)], Inplace.OUT_OF_PLACE),
164- # ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
165- # ((1, 1), 0, [(1, 1)], Inplace.OUT_OF_PLACE),
166- # ((4, 5, 6), 0, [(1, 5, 6), (3, 5, 6)], Inplace.OUT_OF_PLACE),
167- # ((2, 3, 6), 2, [(2, 3, 2), (2, 3, 4)], Inplace.OUT_OF_PLACE),
168-
169- # 添加更多测试用例以覆盖不同的维度和拼接轴
170- # ((2, 10, 3), 1, [(2, 5, 3), (2, 2, 3),(2,3,3)], Inplace.OUT_OF_PLACE), # 拼接沿第二维
132+
133+ ((6 ,), 0 , [(2 ,), (4 ,)], Inplace .OUT_OF_PLACE ),
134+
135+ ((6 , 3 ), 0 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ),
136+ ((3 , 6 ), 1 , [(3 , 2 ), (3 , 4 )], Inplace .OUT_OF_PLACE ),
137+ ((3 , 7 ), 1 , [(3 , 2 ), (3 , 4 ), (3 , 1 )], Inplace .OUT_OF_PLACE ),
138+ ((3 , 3 , 10 ), 2 , [(3 , 3 , 4 ), (3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
139+
140+ ((4 , 3 , 6 ), 0 , [(3 , 3 , 6 ), (1 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
141+ ((2 , 6 , 3 ), 1 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
142+ ((2 , 3 , 6 ), 2 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
143+
144+ ((4 , 3 , 5 , 6 ), 0 , [(1 , 3 , 5 , 6 ), (3 , 3 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
145+ ((2 , 5 , 5 , 6 ), 1 , [(2 , 3 , 5 , 6 ), (2 , 2 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
146+ ((2 , 3 , 5 , 6 ), 2 , [(2 , 3 , 2 , 6 ), (2 , 3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
147+ ((2 , 3 , 5 , 6 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 )], Inplace .OUT_OF_PLACE ),
148+ ((2 , 3 , 5 , 15 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 ), (2 , 3 , 5 , 9 )], Inplace .OUT_OF_PLACE ),
149+
150+ ((4 , 2 , 3 , 4 , 5 ), 0 , [(1 , 2 , 3 , 4 , 5 ), (3 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
151+ ((2 , 4 , 3 , 2 , 5 ), 1 , [(2 , 2 , 3 , 2 , 5 ), (2 , 2 , 3 , 2 , 5 )], Inplace .OUT_OF_PLACE ),
152+ ((1 , 2 , 4 , 4 , 5 ), 2 , [(1 , 2 , 2 , 4 , 5 ), (1 , 2 , 2 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
153+ ((1 , 2 , 3 , 8 , 5 ), 3 , [(1 , 2 , 3 , 4 , 5 ), (1 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
154+ ((1 , 2 , 3 , 4 , 5 ), 4 , [(1 , 2 , 3 , 4 , 3 ), (1 , 2 , 3 , 4 , 2 )], Inplace .OUT_OF_PLACE ),
155+ ((4 , 14 , 3 , 4 , 5 ), 1 , [(4 , 3 , 3 , 4 , 5 ), (4 , 5 , 3 , 4 , 5 ), (4 , 6 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
156+
157+
158+ ((6 ,), - 1 , [(2 ,), (4 ,)], Inplace .OUT_OF_PLACE ),
159+
160+ ((6 , 3 ), - 2 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ),
161+ ((3 , 6 ), - 1 , [(3 , 2 ), (3 , 4 )], Inplace .OUT_OF_PLACE ),
162+ ((3 , 7 ), - 1 , [(3 , 2 ), (3 , 4 ), (3 , 1 )], Inplace .OUT_OF_PLACE ),
163+ ((3 , 3 , 10 ), - 1 , [(3 , 3 , 4 ), (3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
164+
165+ ((4 , 3 , 6 ), - 3 , [(3 , 3 , 6 ), (1 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
166+ ((2 , 6 , 3 ), - 2 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
167+ ((2 , 3 , 6 ), - 1 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
168+
169+ ((4 , 3 , 5 , 6 ), - 4 , [(1 , 3 , 5 , 6 ), (3 , 3 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
170+ ((2 , 5 , 5 , 6 ), - 3 , [(2 , 3 , 5 , 6 ), (2 , 2 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
171+ ((2 , 3 , 5 , 6 ), - 2 , [(2 , 3 , 2 , 6 ), (2 , 3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
172+ ((2 , 3 , 5 , 6 ), - 1 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 )], Inplace .OUT_OF_PLACE ),
173+ ((2 , 3 , 5 , 15 ), - 1 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 ), (2 , 3 , 5 , 9 )], Inplace .OUT_OF_PLACE ),
174+
175+ ((4 , 2 , 3 , 4 , 5 ), - 5 , [(1 , 2 , 3 , 4 , 5 ), (3 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
176+ ((2 , 4 , 3 , 2 , 5 ), - 4 , [(2 , 2 , 3 , 2 , 5 ), (2 , 2 , 3 , 2 , 5 )], Inplace .OUT_OF_PLACE ),
177+ ((1 , 2 , 4 , 4 , 5 ), - 3 , [(1 , 2 , 2 , 4 , 5 ), (1 , 2 , 2 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
178+ ((1 , 2 , 3 , 8 , 5 ), - 2 , [(1 , 2 , 3 , 4 , 5 ), (1 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
179+ ((1 , 2 , 3 , 4 , 5 ), - 1 , [(1 , 2 , 3 , 4 , 3 ), (1 , 2 , 3 , 4 , 2 )], Inplace .OUT_OF_PLACE ),
180+ ((4 , 14 , 3 , 4 , 5 ), - 4 , [(4 , 3 , 3 , 4 , 5 ), (4 , 5 , 3 , 4 , 5 ), (4 , 6 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
171181 ]
172182
173183 args = get_args ()
174184 lib = open_lib ()
175185
176- # 绑定 C++ 函数
177- # 创建 Concat 描述符
178186 lib .infiniopCreateConcatDescriptor .restype = c_int32
179187 lib .infiniopCreateConcatDescriptor .argtypes = [
180188 infiniopHandle_t ,
181189 POINTER (infiniopConcatDescriptor_t ),
182- infiniopTensorDescriptor_t , # 输出张量描述符
183- POINTER (infiniopTensorDescriptor_t ), # 输入张量描述符数组
184- c_uint64 , # 输入张量数量
185- c_uint64 , # 拼接轴
190+ infiniopTensorDescriptor_t ,
191+ POINTER (infiniopTensorDescriptor_t ),
192+ c_uint64 , # nums_input
193+ c_int64 , # axis
186194 ]
187195
188- # 执行 Concat
189196 lib .infiniopConcat .restype = c_int32
190197 lib .infiniopConcat .argtypes = [
191198 infiniopConcatDescriptor_t ,
192- c_void_p , # 输出数据指针
193- POINTER (c_void_p ), # 输入数据指针数组
194- c_void_p , # 流(假设为 NULL)
199+ c_void_p ,
200+ POINTER (c_void_p ),
201+ c_void_p ,
195202 ]
196203
197- # 销毁 Concat 描述符
198204 lib .infiniopDestroyConcatDescriptor .restype = c_int32
199205 lib .infiniopDestroyConcatDescriptor .argtypes = [
200206 infiniopConcatDescriptor_t ,
201207 ]
202208
203- # 根据命令行参数执行测试
204209 if args .cpu :
205210 test_cpu (lib , test_cases )
206211 if args .cuda :
0 commit comments