Skip to content

Commit ad47864

Browse files
committed
Use nanoarrow C++ helpers and iterate stream in accumulations
1 parent 2325b24 commit ad47864

File tree

1 file changed: 70 additions (+70) and 63 deletions (-63).

pandas/_libs/arrow_string_accumulations.cc

+70-63
Original file line numberDiff line numberDiff line change
@@ -30,68 +30,77 @@ static auto ReleaseArrowSchema(void *ptr) noexcept -> void {
3030
delete schema;
3131
}
3232

33-
static auto CumSum(const struct ArrowArrayView *array_view,
33+
template <size_t OffsetSize>
34+
static auto CumSum(struct ArrowArrayStream *array_stream,
3435
struct ArrowArray *out, bool skipna) {
3536
bool seen_na = false;
3637
std::stringstream ss{};
3738

38-
for (int64_t i = 0; i < array_view->length; i++) {
39-
const bool isna = ArrowArrayViewIsNull(array_view, i);
40-
if (!skipna && (seen_na || isna)) {
41-
seen_na = true;
42-
ArrowArrayAppendNull(out, 1);
43-
} else {
44-
if (!isna) {
45-
const auto std_sv = ArrowArrayViewGetStringUnsafe(array_view, i);
46-
ss << std::string_view{std_sv.data,
47-
static_cast<size_t>(std_sv.size_bytes)};
39+
int i = 0;
40+
nanoarrow::ViewArrayStream array_stream_view(array_stream);
41+
for (const auto &array : array_stream_view) {
42+
for (const auto &sv : nanoarrow::ViewArrayAsBytes<OffsetSize>(&array)) {
43+
++i;
44+
if (i == 50) {
45+
}
46+
if ((!sv || seen_na) && !skipna) {
47+
seen_na = true;
48+
ArrowArrayAppendNull(out, 1);
49+
} else {
50+
if (sv) {
51+
ss << std::string_view{(*sv).data,
52+
static_cast<size_t>((*sv).size_bytes)};
53+
}
54+
const auto str = ss.str();
55+
const ArrowStringView asv{str.c_str(),
56+
static_cast<int64_t>(str.size())};
57+
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, asv));
4858
}
49-
const auto str = ss.str();
50-
const ArrowStringView asv{str.c_str(), static_cast<int64_t>(str.size())};
51-
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, asv));
5259
}
5360
}
5461
}
5562

63+
// TODO: not every compiler in CI appears to support C++20 concepts yet;
// re-enable the MinOrMaxOp constraint below once they all do.
5664
// template <typename T>
5765
// concept MinOrMaxOp =
5866
// std::same_as<T, std::less<>> || std::same_as<T, std::greater<>>;
5967

60-
template <auto Op>
68+
template <size_t OffsetSize, auto Op>
6169
// requires MinOrMaxOp<decltype(Op)>
62-
static auto CumMinOrMax(const struct ArrowArrayView *array_view,
70+
static auto CumMinOrMax(struct ArrowArrayStream *array_stream,
6371
struct ArrowArray *out, bool skipna) {
6472
bool seen_na = false;
6573
std::optional<std::string> current_str{};
6674

67-
for (int64_t i = 0; i < array_view->length; i++) {
68-
const bool isna = ArrowArrayViewIsNull(array_view, i);
69-
if (!skipna && (seen_na || isna)) {
70-
seen_na = true;
71-
ArrowArrayAppendNull(out, 1);
72-
} else {
73-
if (!isna || current_str) {
74-
if (!isna) {
75-
const auto asv = ArrowArrayViewGetStringUnsafe(array_view, i);
76-
const nb::str pyval{asv.data, static_cast<size_t>(asv.size_bytes)};
77-
78-
if (current_str) {
79-
const nb::str pycurrent{current_str->data(), current_str->size()};
80-
if (Op(pyval, pycurrent)) {
81-
current_str =
82-
std::string{asv.data, static_cast<size_t>(asv.size_bytes)};
75+
nanoarrow::ViewArrayStream array_stream_view(array_stream);
76+
for (const auto &array : array_stream_view) {
77+
for (const auto &sv : nanoarrow::ViewArrayAsBytes<OffsetSize>(&array)) {
78+
if ((!sv || seen_na) && !skipna) {
79+
seen_na = true;
80+
ArrowArrayAppendNull(out, 1);
81+
} else {
82+
if (sv || current_str) {
83+
if (sv) {
84+
const nb::str pyval{(*sv).data,
85+
static_cast<size_t>((*sv).size_bytes)};
86+
if (current_str) {
87+
const nb::str pycurrent{current_str->data(), current_str->size()};
88+
if (Op(pyval, pycurrent)) {
89+
current_str = std::string{
90+
(*sv).data, static_cast<size_t>((*sv).size_bytes)};
91+
}
92+
} else {
93+
current_str = std::string{(*sv).data,
94+
static_cast<size_t>((*sv).size_bytes)};
8395
}
84-
} else {
85-
current_str =
86-
std::string{asv.data, static_cast<size_t>(asv.size_bytes)};
8796
}
88-
}
8997

90-
struct ArrowStringView out_sv{
91-
current_str->data(), static_cast<int64_t>(current_str->size())};
92-
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, out_sv));
93-
} else {
94-
ArrowArrayAppendEmpty(out, 1);
98+
struct ArrowStringView out_sv{
99+
current_str->data(), static_cast<int64_t>(current_str->size())};
100+
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, out_sv));
101+
} else {
102+
ArrowArrayAppendEmpty(out, 1);
103+
}
95104
}
96105
}
97106
}
@@ -131,7 +140,6 @@ class ArrowStringAccumulation {
131140
switch (schema_view.type) {
132141
case NANOARROW_TYPE_STRING:
133142
case NANOARROW_TYPE_LARGE_STRING:
134-
case NANOARROW_TYPE_STRING_VIEW:
135143
break;
136144
default:
137145
const auto error_message =
@@ -159,30 +167,29 @@ class ArrowStringAccumulation {
159167

160168
NANOARROW_THROW_NOT_OK(ArrowArrayStartAppending(uarray_out.get()));
161169

162-
nanoarrow::UniqueArray chunk{};
163-
int errcode{};
164-
165-
while ((errcode = ArrowArrayStreamGetNext(stream_.get(), chunk.get(),
166-
nullptr) == 0) &&
167-
chunk->release != nullptr) {
168-
struct ArrowArrayView array_view{};
169-
NANOARROW_THROW_NOT_OK(
170-
ArrowArrayViewInitFromSchema(&array_view, schema_.get(), nullptr));
171-
172-
NANOARROW_THROW_NOT_OK(
173-
ArrowArrayViewSetArray(&array_view, chunk.get(), nullptr));
174-
175-
if (accumulation_ == "cumsum") {
176-
CumSum(&array_view, uarray_out.get(), skipna_);
177-
} else if (accumulation_ == "cummin") {
178-
CumMinOrMax<std::less{}>(&array_view, uarray_out.get(), skipna_);
179-
} else if (accumulation_ == "cummax") {
180-
CumMinOrMax<std::greater{}>(&array_view, uarray_out.get(), skipna_);
170+
if (accumulation_ == "cumsum") {
171+
if (schema_view.type == NANOARROW_TYPE_STRING) {
172+
CumSum<32>(stream_.get(), uarray_out.get(), skipna_);
181173
} else {
182-
throw std::runtime_error("Unexpected branch");
174+
CumSum<64>(stream_.get(), uarray_out.get(), skipna_);
183175
}
184176

185-
chunk.reset();
177+
} else if (accumulation_ == "cummin") {
178+
if (schema_view.type == NANOARROW_TYPE_STRING) {
179+
CumMinOrMax<32, std::less{}>(stream_.get(), uarray_out.get(), skipna_);
180+
} else {
181+
CumMinOrMax<64, std::less{}>(stream_.get(), uarray_out.get(), skipna_);
182+
}
183+
} else if (accumulation_ == "cummax") {
184+
if (schema_view.type == NANOARROW_TYPE_STRING) {
185+
CumMinOrMax<32, std::greater{}>(stream_.get(), uarray_out.get(),
186+
skipna_);
187+
} else {
188+
CumMinOrMax<64, std::greater{}>(stream_.get(), uarray_out.get(),
189+
skipna_);
190+
}
191+
} else {
192+
throw std::runtime_error("Unexpected branch");
186193
}
187194

188195
NANOARROW_THROW_NOT_OK(

0 commit comments

Comments
 (0)