Skip to content

Commit 389c770

Browse files
wweicWei Chen
andauthored
perf: optimize string tensor deserialization with high performance c++ implementation (#416)
* perf: optimize string tensor deserialization with high performance c++ implementation * Address PR comments --------- Co-authored-by: Wei Chen <[email protected]>
1 parent 8b5a055 commit 389c770

File tree

1 file changed

+68
-16
lines changed

1 file changed

+68
-16
lines changed

src/pb_tensor.cc

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,70 @@ namespace py = pybind11;
4141
typedef SSIZE_T ssize_t;
4242
#endif
4343

44+
#include <cstdint>
45+
#include <cstring>
46+
#include <vector>
47+
4448
namespace triton { namespace backend { namespace python {
4549

4650
#ifdef TRITON_PB_STUB
51+
py::array
52+
deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size)
53+
{
54+
if (data_size == 0) {
55+
py::module numpy = py::module::import("numpy");
56+
return numpy.attr("empty")(0, py::dtype("object"));
57+
}
58+
59+
// First pass: count the number of strings and calculate total size
60+
size_t offset = 0;
61+
size_t num_strings = 0;
62+
size_t total_string_size = 0;
63+
64+
while (offset < data_size) {
65+
if (offset + 4 > data_size) {
66+
throw PythonBackendException(
67+
"Invalid bytes tensor data: incomplete length field");
68+
}
69+
70+
// Read 4-byte length (little-endian)
71+
uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
72+
offset += 4;
73+
74+
if (offset + length > data_size) {
75+
throw PythonBackendException(
76+
"Invalid bytes tensor data: string extends beyond buffer");
77+
}
78+
79+
num_strings++;
80+
total_string_size += length;
81+
offset += length;
82+
}
83+
84+
// Create numpy array of objects using pybind11's numpy module
85+
py::module numpy = py::module::import("numpy");
86+
py::array result = numpy.attr("empty")(num_strings, py::dtype("object"));
87+
auto result_ptr = static_cast<PyObject**>(result.request().ptr);
88+
89+
// Second pass: extract strings
90+
offset = 0;
91+
size_t string_index = 0;
92+
93+
while (offset < data_size) {
94+
uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
95+
offset += 4;
96+
97+
// Create Python bytes object using pybind11
98+
py::bytes bytes_obj(reinterpret_cast<const char*>(data + offset), length);
99+
Py_INCREF(bytes_obj.ptr()); // Increment reference count
100+
result_ptr[string_index] = bytes_obj.ptr();
101+
string_index++;
102+
offset += length;
103+
}
104+
105+
return result;
106+
}
107+
47108
PbTensor::PbTensor(const std::string& name, py::array& numpy_array)
48109
: name_(name)
49110
{
@@ -160,14 +221,9 @@ PbTensor::PbTensor(
160221
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
161222
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
162223
} else {
163-
py::object numpy_array = py::array(
164-
triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size},
165-
(void*)memory_ptr_);
166-
py::module triton_pb_utils =
167-
py::module::import("triton_python_backend_utils");
168-
numpy_array_ =
169-
triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
170-
.attr("reshape")(dims);
224+
py::object numpy_array = deserialize_bytes_tensor_cpp(
225+
static_cast<const uint8_t*>(memory_ptr_), byte_size_);
226+
numpy_array_ = numpy_array.attr("reshape")(dims_);
171227
}
172228
} else {
173229
numpy_array_ = py::none();
@@ -234,6 +290,7 @@ delete_unused_dltensor(PyObject* dlp)
234290
}
235291
}
236292

293+
237294
std::shared_ptr<PbTensor>
238295
PbTensor::FromNumpy(const std::string& name, py::array& numpy_array)
239296
{
@@ -668,14 +725,9 @@ PbTensor::PbTensor(
668725
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
669726
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
670727
} else {
671-
py::object numpy_array = py::array(
672-
triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size_},
673-
(void*)memory_ptr_);
674-
py::module triton_pb_utils =
675-
py::module::import("triton_python_backend_utils");
676-
numpy_array_ =
677-
triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
678-
.attr("reshape")(dims_);
728+
py::object numpy_array = deserialize_bytes_tensor_cpp(
729+
static_cast<const uint8_t*>(memory_ptr_), byte_size_);
730+
numpy_array_ = numpy_array.attr("reshape")(dims_);
679731
}
680732
} else {
681733
numpy_array_ = py::none();

0 commit comments

Comments
 (0)