Skip to content

Commit 9d25304

Browse files
committed
add cpu concat
1 parent f74eca1 commit 9d25304

File tree

5 files changed

+645
-0
lines changed

5 files changed

+645
-0
lines changed

include/ops/concat/concat.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef CONCAT_H
#define CONCAT_H

#include <stdint.h> /* uint64_t; keeps this header self-contained */

#include "../../export.h"
#include "../../operators.h"

/* Descriptor for the Concat operator. */
typedef struct ConcatDescriptor {
    Device device; /* Device type (e.g. DevCpu, DevNvGpu) */
    uint64_t axis; /* Concatenation axis (0-based) */
} ConcatDescriptor;

typedef ConcatDescriptor *infiniopConcatDescriptor_t;

/*
 * Create a Concat descriptor.
 *
 *   handle     - library handle the descriptor is created for
 *   desc_ptr   - out: receives the newly created descriptor
 *   y          - descriptor of the output tensor
 *   x          - array of `num_inputs` input tensor descriptors
 *   num_inputs - number of entries in `x`
 *   axis       - axis along which the inputs are concatenated (0-based)
 *
 * Returns an infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopCreateConcatDescriptor(infiniopHandle_t handle,
                                                             infiniopConcatDescriptor_t *desc_ptr,
                                                             infiniopTensorDescriptor_t y,
                                                             infiniopTensorDescriptor_t *x,
                                                             uint64_t num_inputs,
                                                             uint64_t axis);

/*
 * Run the Concat operation.
 *
 *   desc   - descriptor obtained from infiniopCreateConcatDescriptor
 *   y      - output data pointer
 *   x      - array of input data pointers
 *            (presumably one per input given at descriptor creation, in the
 *            same order - TODO confirm against the implementation)
 *   stream - execution stream; the accompanying Python test passes NULL
 *
 * Returns an infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc,
                                             void *y,
                                             void const **x,
                                             void *stream);

/* Destroy a Concat descriptor. */
__C __export infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc);

#endif

operatorspy/tests/concat.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
2+
import ctypes
3+
import sys
4+
import os
5+
6+
# 调整路径以导入 operatorspy 模块
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8+
from operatorspy import (
9+
open_lib,
10+
to_tensor,
11+
DeviceEnum,
12+
infiniopHandle_t,
13+
infiniopTensorDescriptor_t,
14+
create_handle,
15+
destroy_handle,
16+
check_error,
17+
)
18+
19+
from operatorspy.tests.test_utils import get_args
20+
from enum import Enum, auto
21+
import torch
22+
23+
24+
class Inplace(Enum):
    """Output-placement modes used by the concat test cases."""

    OUT_OF_PLACE = auto()
    # Concat normally has no in-place form, so only OUT_OF_PLACE is kept.
    # Extend with further options if they are ever needed, e.g.:
    # INPLACE_A = auto()
    # INPLACE_B = auto()
30+
31+
32+
class ConcatDescriptor(Structure):
    """ctypes stand-in for the C ``ConcatDescriptor`` struct.

    This script only passes the descriptor around by pointer and never reads
    its members from Python.
    """

    # NOTE(review): the C header also declares a `uint64_t axis` member after
    # `device`; it is not mirrored here. Confirm the layout before ever
    # dereferencing this structure from Python.
    _fields_ = [("device", c_int32),]


# Pointer type matching the C `infiniopConcatDescriptor_t` typedef.
infiniopConcatDescriptor_t = POINTER(ConcatDescriptor)
37+
38+
39+
def concat_py(*tensors, dim=0):
    """Return the given tensors concatenated along ``dim`` using ``torch.cat``.

    Used as the reference implementation the library result is checked against.
    """
    return torch.cat(tensors, dim=dim)
42+
43+
44+
def test(
    lib,
    handle,
    torch_device,
    c_shape,
    axis,
    input_shapes,
    tensor_dtype=torch.float32,
    inplace=Inplace.OUT_OF_PLACE,
):
    """Run one concat case through the library and compare it with ``torch.cat``.

    Args:
        lib: Loaded operator library exposing the ``infiniop*Concat*`` functions.
        handle: Library handle passed to descriptor creation.
        torch_device: Device string the tensors are moved to (e.g. ``"cpu"``).
        c_shape: Shape of the output tensor.
        axis: Axis along which the inputs are concatenated.
        input_shapes: One shape per input tensor.
        tensor_dtype: dtype used for the inputs and the output.
        inplace: Placement mode; only used in the log line, because every mode
            allocates a fresh zero-filled output.

    Raises:
        AssertionError: If the library output differs from the PyTorch result.
    """
    print(
        f"Testing Concat on {torch_device} with output_shape:{c_shape}, input_shapes:{input_shapes}, axis:{axis}, dtype:{tensor_dtype}, inplace: {inplace.name}"
    )

    # Random input tensors, one per requested shape.
    inputs = [torch.rand(shape, dtype=tensor_dtype).to(torch_device) for shape in input_shapes]

    for idx, tensor in enumerate(inputs):
        print(f"Input {idx}:")
        print(tensor)
        print("-" * 50)

    # Zero-filled output buffer. Concat has no in-place variant, so the
    # allocation is the same for every `inplace` value.
    c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device)

    # Reference answer computed with PyTorch.
    ans = concat_py(*inputs, dim=axis)

    print("ans:", ans)
    print("-" * 50)

    # Wrap the torch tensors in the form the library expects.
    input_tensors = [to_tensor(t, lib) for t in inputs]
    c_tensor = to_tensor(c, lib)

    descriptor = infiniopConcatDescriptor_t()

    # C array holding the input tensor descriptors.
    num_inputs = len(input_tensors)
    input_desc_array_type = infiniopTensorDescriptor_t * num_inputs
    input_desc_array = input_desc_array_type(*[t.descriptor for t in input_tensors])

    check_error(
        lib.infiniopCreateConcatDescriptor(
            handle,
            ctypes.byref(descriptor),
            c_tensor.descriptor,  # output tensor descriptor
            input_desc_array,  # array of input tensor descriptors
            c_uint64(num_inputs),
            c_uint64(axis),
        )
    )

    # From here on the descriptor exists; release it even if the run or the
    # comparison below fails.
    try:
        print("c1:", c)
        print("-" * 50)

        # C array holding the input data pointers.
        input_data_ptrs = (c_void_p * num_inputs)(*[t.data for t in input_tensors])
        check_error(
            lib.infiniopConcat(
                descriptor,
                c_tensor.data,
                ctypes.cast(input_data_ptrs, POINTER(c_void_p)),
                None,  # no stream is passed
            )
        )

        print("c2:", c)
        print("-" * 50)

        # Compare against the reference answer.
        assert torch.allclose(c, ans, atol=0, rtol=1e-5), "Concat result does not match PyTorch's result."
    finally:
        check_error(lib.infiniopDestroyConcatDescriptor(descriptor))
128+
129+
130+
def test_cpu(lib, test_cases):
    """Run every concat test case on the CPU device.

    Each entry of ``test_cases`` is ``(output_shape, axis, input_shapes, inplace)``.
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    try:
        for c_shape, axis, input_shapes, inplace in test_cases:
            test(lib, handle, "cpu", c_shape, axis, input_shapes, inplace=inplace)
    finally:
        # Release the handle even when a case fails.
        destroy_handle(lib, handle)
136+
137+
138+
def test_cuda(lib, test_cases):
    """Run every concat test case on the CUDA device.

    Each entry of ``test_cases`` is ``(output_shape, axis, input_shapes, inplace)``.
    """
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    try:
        for c_shape, axis, input_shapes, inplace in test_cases:
            test(lib, handle, "cuda", c_shape, axis, input_shapes, inplace=inplace)
    finally:
        # Release the handle even when a case fails.
        destroy_handle(lib, handle)
144+
145+
146+
def test_bang(lib, test_cases):
    """Run every concat test case on the BANG device (torch device ``"mlu"``).

    Each entry of ``test_cases`` is ``(output_shape, axis, input_shapes, inplace)``.
    """
    # Imported for its side effects only; presumably this makes the "mlu"
    # torch device available - TODO confirm.
    import torch_mlu

    device = DeviceEnum.DEVICE_BANG
    handle = create_handle(lib, device)
    try:
        for c_shape, axis, input_shapes, inplace in test_cases:
            test(lib, handle, "mlu", c_shape, axis, input_shapes, inplace=inplace)
    finally:
        # Release the handle even when a case fails.
        destroy_handle(lib, handle)
154+
155+
156+
if __name__ == "__main__":
    # Test cases: (output_shape, axis, input_shapes, inplace)
    test_cases = [
        ((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE),
        # ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
        # ((3, 7), 1, [(3, 2), (3, 4), (3,1)], Inplace.OUT_OF_PLACE),
        # ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
        # ((1, 1), 0, [(1, 1)], Inplace.OUT_OF_PLACE),
        # ((4, 5, 6), 0, [(1, 5, 6), (3, 5, 6)], Inplace.OUT_OF_PLACE),
        # ((2, 3, 6), 2, [(2, 3, 2), (2, 3, 4)], Inplace.OUT_OF_PLACE),

        # More cases covering other ranks and concatenation axes:
        # ((2, 10, 3), 1, [(2, 5, 3), (2, 2, 3),(2,3,3)], Inplace.OUT_OF_PLACE),  # concat along the second dimension
    ]

    args = get_args()
    lib = open_lib()

    # Declare the C function signatures so ctypes converts arguments correctly.
    # Create a Concat descriptor.
    lib.infiniopCreateConcatDescriptor.restype = c_int32
    lib.infiniopCreateConcatDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopConcatDescriptor_t),
        infiniopTensorDescriptor_t,  # output tensor descriptor
        POINTER(infiniopTensorDescriptor_t),  # array of input tensor descriptors
        c_uint64,  # number of inputs
        c_uint64,  # concatenation axis
    ]

    # Run Concat.
    lib.infiniopConcat.restype = c_int32
    lib.infiniopConcat.argtypes = [
        infiniopConcatDescriptor_t,
        c_void_p,  # output data pointer
        POINTER(c_void_p),  # array of input data pointers
        c_void_p,  # stream (the test passes NULL)
    ]

    # Destroy a Concat descriptor.
    lib.infiniopDestroyConcatDescriptor.restype = c_int32
    lib.infiniopDestroyConcatDescriptor.argtypes = [
        infiniopConcatDescriptor_t,
    ]

    # Run the backends selected on the command line; default to CPU when
    # none is selected.
    if args.cpu:
        test_cpu(lib, test_cases)
    if args.cuda:
        test_cuda(lib, test_cases)
    if args.bang:
        test_bang(lib, test_cases)
    if not (args.cpu or args.cuda or args.bang):
        test_cpu(lib, test_cases)

    print("\033[92mConcat Test passed!\033[0m")
214+
215+
216+
217+

0 commit comments

Comments
 (0)