Skip to content

Commit 8a05be3

Browse files
committed
Add carray_int64_t_to_tuple function
1 parent 59d1b6b commit 8a05be3

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from typing import Any, Optional
1212
import numpy
1313

1414
from cuda.core.experimental._utils.cuda_utils import handle_return, driver
15+
from cuda.core.experimental._utils cimport cuda_utils
1516

1617

1718
# TODO(leofang): support NumPy structured dtypes
@@ -213,13 +214,9 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
213214
buf.ptr = <intptr_t>(dl_tensor.data)
214215

215216
# Construct shape and strides tuples using the Python/C API for speed
216-
buf.shape = cpython.PyTuple_New(dl_tensor.ndim)
217-
for i in range(dl_tensor.ndim):
218-
cpython.PyTuple_SET_ITEM(buf.shape, i, cpython.PyLong_FromLong(dl_tensor.shape[i]))
217+
buf.shape = cuda_utils.carray_int64_t_to_tuple(dl_tensor.shape, dl_tensor.ndim)
219218
if dl_tensor.strides:
220-
buf.strides = cpython.PyTuple_New(dl_tensor.ndim)
221-
for i in range(dl_tensor.ndim):
222-
cpython.PyTuple_SET_ITEM(buf.strides, i, cpython.PyLong_FromLong(dl_tensor.strides[i]))
219+
buf.strides = cuda_utils.carray_int64_t_to_tuple(dl_tensor.strides, dl_tensor.ndim)
223220
else:
224221
# C-order
225222
buf.strides = None

cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,19 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
6+
cimport cpython
7+
cimport libc.stdint
8+
9+
510
cpdef int _check_driver_error(error) except?-1
611
cpdef int _check_runtime_error(error) except?-1
712
cpdef int _check_nvrtc_error(error) except?-1
813
cpdef check_or_create_options(type cls, options, str options_description=*, bint keep_none=*)
14+
15+
16+
cdef inline tuple carray_int64_t_to_tuple(libc.stdint.int64_t *ptr, int length):
    """Build a Python tuple of ints from a C array of ``int64_t``.

    ``ptr`` must point to at least ``length`` readable elements. The tuple
    is filled through the Python/C API for speed.
    """
    cdef tuple result = cpython.PyTuple_New(length)
    cdef int i
    cdef object item
    for i in range(length):
        item = cpython.PyLong_FromLongLong(ptr[i])
        # PyTuple_SET_ITEM steals a reference, but Cython still owns (and
        # will release) the reference held by ``item``. Without this extra
        # INCREF the tuple would end up holding an under-counted object.
        cpython.Py_INCREF(item)
        cpython.PyTuple_SET_ITEM(result, i, item)
    return result

0 commit comments

Comments
 (0)