Skip to content

Commit 703a5ac

Browse files
committed
Nits + Python bindings
1 parent 6765968 commit 703a5ac

File tree

9 files changed

+76
-20
lines changed

9 files changed

+76
-20
lines changed

cpp/src/arrow/compute/api_scalar.cc

+1
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
#include "arrow/type.h"
3131
#include "arrow/util/checked_cast.h"
3232
#include "arrow/util/logging.h"
33+
3334
namespace arrow {
3435

3536
namespace internal {

cpp/src/arrow/compute/api_scalar.h

+2
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,7 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
264264
/// Regular expression with named capture fields
265265
std::string pattern;
266266
};
267+
267268
class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {
268269
public:
269270
explicit ExtractRegexSpanOptions(std::string pattern);
@@ -273,6 +274,7 @@ class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {
273274
/// Regular expression with named capture fields
274275
std::string pattern;
275276
};
277+
276278
/// Options for IsIn and IndexIn functions
277279
class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
278280
public:

cpp/src/arrow/compute/kernels/scalar_string_ascii.cc

+28-17
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ namespace compute {
4343
namespace internal {
4444

4545
namespace {
46+
4647
// ----------------------------------------------------------------------
4748
// re2 utilities
4849

@@ -2201,7 +2202,9 @@ struct BaseExtractRegexData {
22012202
}
22022203
return Status::OK();
22032204
}
2205+
22042206
int64_t num_groups() const { return static_cast<int64_t>(group_names.size()); }
2207+
22052208
std::unique_ptr<RE2> regex;
22062209
std::vector<std::string> group_names;
22072210

@@ -2297,14 +2300,15 @@ struct ExtractRegex : public ExtractRegexBase {
22972300
std::shared_ptr<DataType> type = out->array_data()->type;
22982301
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ArrayBuilder> array_builder,
22992302
MakeBuilder(type, ctx->memory_pool()));
2300-
auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(array_builder));
2301-
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length));
2303+
StructBuilder* struct_builder = checked_cast<StructBuilder*>(array_builder.get());
2304+
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].length()));
23022305

23032306
std::vector<BuilderType*> field_builders;
23042307
field_builders.reserve(group_count);
23052308
for (int i = 0; i < group_count; i++) {
23062309
field_builders.push_back(
23072310
checked_cast<BuilderType*>(struct_builder->field_builder(i)));
2311+
RETURN_NOT_OK(field_builders.back()->Reserve(batch[0].length()));
23082312
}
23092313

23102314
auto visit_null = [&]() { return struct_builder->AppendNull(); };
@@ -2353,6 +2357,7 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) {
23532357
}
23542358
DCHECK_OK(registry->AddFunction(std::move(func)));
23552359
}
2360+
23562361
struct ExtractRegexSpanData : public BaseExtractRegexData {
23572362
static Result<ExtractRegexSpanData> Make(const std::string& pattern,
23582363
bool is_utf8 = true) {
@@ -2367,12 +2372,11 @@ struct ExtractRegexSpanData : public BaseExtractRegexData {
23672372
return nullptr;
23682373
}
23692374
DCHECK(is_base_binary_like(input_type->id()));
2370-
const size_t field_count = num_groups();
23712375
FieldVector fields;
2372-
fields.reserve(field_count);
2376+
fields.reserve(num_groups());
23732377
auto index_type = is_binary_like(input_type->id()) ? int32() : int64();
23742378
for (const auto& group_name : group_names) {
2375-
// size list is 2 as every span contains position and length
2379+
// list size is 2 as every span contains position and length
23762380
fields.push_back(field(group_name, fixed_size_list(index_type, 2)));
23772381
}
23782382
return struct_(std::move(fields));
@@ -2401,12 +2405,14 @@ struct ExtractRegexSpan : ExtractRegexBase {
24012405
ExtractRegexSpanData::Make(options.pattern, Type::is_utf8));
24022406
return ExtractRegexSpan{data}.Extract(ctx, batch, out);
24032407
}
2408+
24042409
Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
24052410
DCHECK_NE(out->array_data(), nullptr);
24062411
std::shared_ptr<DataType> out_type = out->array_data()->type;
24072412
ARROW_ASSIGN_OR_RAISE(auto out_builder, MakeBuilder(out_type, ctx->memory_pool()));
2408-
auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(out_builder));
2413+
StructBuilder* struct_builder = checked_cast<StructBuilder*>(out_builder.get());
24092414
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length));
2415+
24102416
std::vector<FixedSizeListBuilder*> span_builders;
24112417
std::vector<OffsetBuilderType*> array_builders;
24122418
span_builders.reserve(group_count);
@@ -2416,8 +2422,8 @@ struct ExtractRegexSpan : ExtractRegexBase {
24162422
checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i)));
24172423
array_builders.push_back(
24182424
checked_cast<OffsetBuilderType*>(span_builders.back()->value_builder()));
2419-
RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].array.length));
2420-
RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].array.length));
2425+
RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].length()));
2426+
RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].length()));
24212427
}
24222428

24232429
auto visit_null = [&]() { return struct_builder->AppendNull(); };
@@ -2451,15 +2457,20 @@ struct ExtractRegexSpan : ExtractRegexBase {
24512457
};
24522458

24532459
const FunctionDoc extract_regex_span_doc(
2454-
"Extract substrings captured by a regex pattern and Save the result in the form of "
2455-
"(offset,length)",
2456-
"For each string in strings, match the regular expression and, if\n"
2457-
"successful, emit a struct with field names and values coming from the\n"
2458-
"regular expression's named capture groups, which are stored in a form of a\n "
2459-
"fixed_size_list(offset, length). If the input is null or the regular \n"
2460-
"expression Fails matching, a null output value is emitted.\n"
2461-
"Regular expression matching is done using the Google RE2 library.",
2462-
{"strings"}, "ExtractRegexSpanOptions", true);
2460+
"Extract string spans captured by a regex pattern",
2461+
("For each string in strings, match the regular expression and, if\n"
2462+
"successful, emit a struct with field names and values coming from the\n"
2463+
"regular expression's named capture groups. Each struct field value\n"
2464+
"will be a fixed_size_list(offset_type, 2) where offset_type is int32\n"
2465+
"or int64, depending on the input string type. The two elements in\n"
2466+
"each fixed-size list are the index and the length of the substring\n"
2467+
"matched by the corresponding named capture group.\n"
2468+
"\n"
2469+
"If the input is null or the regular expression fails matching,\n"
2470+
"a null output value is emitted.\n"
2471+
"\n"
2472+
"Regular expression matching is done using the Google RE2 library."),
2473+
{"strings"}, "ExtractRegexSpanOptions", /*options_required=*/true);
24632474

24642475
Result<TypeHolder> ResolveExtractRegexSpanOutputType(
24652476
KernelContext* ctx, const std::vector<TypeHolder>& types) {

cpp/src/arrow/compute/kernels/scalar_string_test.cc

+7-2
Original file line number | Diff line number | Diff line change
@@ -314,7 +314,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8Regex) {
314314
this->MakeArray({"\xfc\x40", "this \xfc\x40 that \xfc\x40"}),
315315
this->MakeArray({"bazz", "this bazz that \xfc\x40"}), &options);
316316
}
317-
// TODO the following test is broken
317+
// TODO the following test is broken (GH-45735)
318318
{
319319
ExtractRegexOptions options("(?P<letter>[\\xfc])(?P<digit>\\d)");
320320
auto null_bitmap = std::make_shared<Buffer>("0");
@@ -371,7 +371,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8WithNullRegex) {
371371
this->template MakeArray<std::string>({{"\x00\x40", 2}}),
372372
this->type(), R"(["bazz"])", &options);
373373
}
374-
// TODO the following test is broken
374+
// TODO the following test is broken (GH-45735)
375375
{
376376
ExtractRegexOptions options("(?P<null>[\\x00])(?P<digit>\\d)");
377377
auto null_bitmap = std::make_shared<Buffer>("0");
@@ -1960,6 +1960,7 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) {
19601960
R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])",
19611961
&options);
19621962
}
1963+
19631964
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) {
19641965
ExtractRegexSpanOptions options{"(?P<letter>[ab]+)(?P<digit>\\d+)"};
19651966
auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64();
@@ -1990,6 +1991,7 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) {
19901991
{"letter":[2,2], "digit":[4,3]}])",
19911992
&options);
19921993
}
1994+
19931995
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanCaptureOption) {
19941996
ExtractRegexSpanOptions options{"(?P<foo>foo)?(?P<digit>\\d+)?"};
19951997
auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64();
@@ -2022,13 +2024,15 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) {
20222024
this->CheckUnary("extract_regex", R"(["oofoo", "bar", null])", type,
20232025
R"([{}, null, null])", &options);
20242026
}
2027+
20252028
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoCapture) {
20262029
// XXX Should we accept this or is it a user error?
20272030
ExtractRegexSpanOptions options{"foo"};
20282031
auto type = struct_({});
20292032
this->CheckUnary("extract_regex_span", R"(["oofoo", "bar", null])", type,
20302033
R"([{}, null, null])", &options);
20312034
}
2035+
20322036
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoOptions) {
20332037
Datum input = ArrayFromJSON(this->type(), "[]");
20342038
ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input}));
@@ -2051,6 +2055,7 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) {
20512055
Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"),
20522056
CallFunction("extract_regex", {input}, &options));
20532057
}
2058+
20542059
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanInvalid) {
20552060
Datum input = ArrayFromJSON(this->type(), "[]");
20562061
ExtractRegexSpanOptions options{"invalid["};

docs/source/cpp/compute.rst

+2-1
Original file line number | Diff line number | Diff line change
@@ -1140,7 +1140,8 @@ String component extraction
11401140
library. The output struct field names refer to the named capture groups,
11411141
e.g. 'letter' and 'digit' for the regular expression
11421142
``(?P<letter>[ab])(?P<digit>\\d)``.
1143-
* \(2) Extract the offset and length of substrings defined by a regular expression
1143+
1144+
* \(2) Extract the offset and length of substrings defined by a regular expression
11441145
using the Google RE2 library. The output struct field names refer to the named
11451146
capture groups, e.g. 'letter' and 'digit' for the regular expression
11461147
``(?P<letter>[ab])(?P<digit>\\d)``. Each output struct field is a fixed size list

python/pyarrow/_compute.pyx

+19
Original file line number | Diff line number | Diff line change
@@ -1222,6 +1222,25 @@ class ExtractRegexOptions(_ExtractRegexOptions):
12221222
self._set_options(pattern)
12231223

12241224

1225+
cdef class _ExtractRegexSpanOptions(FunctionOptions):
1226+
def _set_options(self, pattern):
1227+
self.wrapped.reset(new CExtractRegexSpanOptions(tobytes(pattern)))
1228+
1229+
1230+
class ExtractRegexSpanOptions(_ExtractRegexSpanOptions):
1231+
"""
1232+
Options for the `extract_regex_span` function.
1233+
1234+
Parameters
1235+
----------
1236+
pattern : str
1237+
Regular expression with named capture fields.
1238+
"""
1239+
1240+
def __init__(self, pattern):
1241+
self._set_options(pattern)
1242+
1243+
12251244
cdef class _SliceOptions(FunctionOptions):
12261245
def _set_options(self, start, stop, step):
12271246
self.wrapped.reset(new CSliceOptions(start, stop, step))

python/pyarrow/compute.py

+1
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
RunEndEncodeOptions,
4141
ElementWiseAggregateOptions,
4242
ExtractRegexOptions,
43+
ExtractRegexSpanOptions,
4344
FilterOptions,
4445
IndexOptions,
4546
JoinOptions,

python/pyarrow/includes/libarrow.pxd

+5
Original file line number | Diff line number | Diff line change
@@ -2500,6 +2500,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
25002500
CExtractRegexOptions(c_string pattern)
25012501
c_string pattern
25022502

2503+
cdef cppclass CExtractRegexSpanOptions \
2504+
"arrow::compute::ExtractRegexSpanOptions"(CFunctionOptions):
2505+
CExtractRegexSpanOptions(c_string pattern)
2506+
c_string pattern
2507+
25032508
cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions):
25042509
CCastOptions()
25052510
CCastOptions(c_bool safe)

python/pyarrow/tests/test_compute.py

+11
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,7 @@ def test_option_class_equality(request):
152152
pc.RunEndEncodeOptions(),
153153
pc.ElementWiseAggregateOptions(skip_nulls=True),
154154
pc.ExtractRegexOptions("pattern"),
155+
pc.ExtractRegexSpanOptions("pattern"),
155156
pc.FilterOptions(),
156157
pc.IndexOptions(pa.scalar(1)),
157158
pc.JoinOptions(),
@@ -1092,6 +1093,16 @@ def test_extract_regex():
10921093
assert struct.tolist() == expected
10931094

10941095

1096+
def test_extract_regex_span():
1097+
ar = pa.array(['a1', 'zb234z'])
1098+
expected = [{'letter': [0, 1], 'digit': [1, 1]},
1099+
{'letter': [1, 1], 'digit': [2, 3]}]
1100+
struct = pc.extract_regex_span(ar, pattern=r'(?P<letter>[ab])(?P<digit>\d+)')
1101+
assert struct.tolist() == expected
1102+
struct = pc.extract_regex_span(ar, r'(?P<letter>[ab])(?P<digit>\d+)')
1103+
assert struct.tolist() == expected
1104+
1105+
10951106
def test_binary_join():
10961107
ar_list = pa.array([['foo', 'bar'], None, []])
10971108
expected = pa.array(['foo-bar', None, ''])

0 commit comments

Comments
 (0)