@@ -41,9 +41,70 @@ namespace py = pybind11;
41
41
typedef SSIZE_T ssize_t ;
42
42
#endif
43
43
44
+ #include < cstdint>
45
+ #include < cstring>
46
+ #include < vector>
47
+
44
48
namespace triton { namespace backend { namespace python {
45
49
46
50
#ifdef TRITON_PB_STUB
51
+ py::array
52
+ deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size)
53
+ {
54
+ if (data_size == 0 ) {
55
+ py::module numpy = py::module::import (" numpy" );
56
+ return numpy.attr (" empty" )(0 , py::dtype (" object" ));
57
+ }
58
+
59
+ // First pass: count the number of strings and calculate total size
60
+ size_t offset = 0 ;
61
+ size_t num_strings = 0 ;
62
+ size_t total_string_size = 0 ;
63
+
64
+ while (offset < data_size) {
65
+ if (offset + 4 > data_size) {
66
+ throw PythonBackendException (
67
+ " Invalid bytes tensor data: incomplete length field" );
68
+ }
69
+
70
+ // Read 4-byte length (little-endian)
71
+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
72
+ offset += 4 ;
73
+
74
+ if (offset + length > data_size) {
75
+ throw PythonBackendException (
76
+ " Invalid bytes tensor data: string extends beyond buffer" );
77
+ }
78
+
79
+ num_strings++;
80
+ total_string_size += length;
81
+ offset += length;
82
+ }
83
+
84
+ // Create numpy array of objects using pybind11's numpy module
85
+ py::module numpy = py::module::import (" numpy" );
86
+ py::array result = numpy.attr (" empty" )(num_strings, py::dtype (" object" ));
87
+ auto result_ptr = static_cast <PyObject**>(result.request ().ptr );
88
+
89
+ // Second pass: extract strings
90
+ offset = 0 ;
91
+ size_t string_index = 0 ;
92
+
93
+ while (offset < data_size) {
94
+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
95
+ offset += 4 ;
96
+
97
+ // Create Python bytes object using pybind11
98
+ py::bytes bytes_obj (reinterpret_cast <const char *>(data + offset), length);
99
+ Py_INCREF (bytes_obj.ptr ()); // Increment reference count
100
+ result_ptr[string_index] = bytes_obj.ptr ();
101
+ string_index++;
102
+ offset += length;
103
+ }
104
+
105
+ return result;
106
+ }
107
+
47
108
PbTensor::PbTensor (const std::string& name, py::array& numpy_array)
48
109
: name_(name)
49
110
{
@@ -160,14 +221,9 @@ PbTensor::PbTensor(
160
221
py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
161
222
numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
162
223
} else {
163
- py::object numpy_array = py::array (
164
- triton_to_pybind_dtype (TRITONSERVER_TYPE_UINT8), {byte_size},
165
- (void *)memory_ptr_);
166
- py::module triton_pb_utils =
167
- py::module::import (" triton_python_backend_utils" );
168
- numpy_array_ =
169
- triton_pb_utils.attr (" deserialize_bytes_tensor" )(numpy_array)
170
- .attr (" reshape" )(dims);
224
+ py::object numpy_array = deserialize_bytes_tensor_cpp (
225
+ static_cast <const uint8_t *>(memory_ptr_), byte_size_);
226
+ numpy_array_ = numpy_array.attr (" reshape" )(dims_);
171
227
}
172
228
} else {
173
229
numpy_array_ = py::none ();
@@ -234,6 +290,7 @@ delete_unused_dltensor(PyObject* dlp)
234
290
}
235
291
}
236
292
293
+
237
294
std::shared_ptr<PbTensor>
238
295
PbTensor::FromNumpy (const std::string& name, py::array& numpy_array)
239
296
{
@@ -668,14 +725,9 @@ PbTensor::PbTensor(
668
725
py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
669
726
numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
670
727
} else {
671
- py::object numpy_array = py::array (
672
- triton_to_pybind_dtype (TRITONSERVER_TYPE_UINT8), {byte_size_},
673
- (void *)memory_ptr_);
674
- py::module triton_pb_utils =
675
- py::module::import (" triton_python_backend_utils" );
676
- numpy_array_ =
677
- triton_pb_utils.attr (" deserialize_bytes_tensor" )(numpy_array)
678
- .attr (" reshape" )(dims_);
728
+ py::object numpy_array = deserialize_bytes_tensor_cpp (
729
+ static_cast <const uint8_t *>(memory_ptr_), byte_size_);
730
+ numpy_array_ = numpy_array.attr (" reshape" )(dims_);
679
731
}
680
732
} else {
681
733
numpy_array_ = py::none ();
0 commit comments