Skip to content

Commit 2764dc9

Browse files
apply suggested changes
1 parent 0677728 commit 2764dc9

File tree

4 files changed

+152
-110
lines changed

4 files changed

+152
-110
lines changed

cpp/src/arrow/compute/api_scalar.h

-2
Original file line number | Diff line number | Diff line change
@@ -272,8 +272,6 @@ class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {
272272

273273
/// Regular expression with named capture fields
274274
std::string pattern;
275-
276-
/// Shows the matched string
277275
};
278276
/// Options for IsIn and IndexIn functions
279277
class ARROW_EXPORT SetLookupOptions : public FunctionOptions {

cpp/src/arrow/compute/kernels/scalar_string_ascii.cc

+90-85
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@ namespace compute {
4343
namespace internal {
4444

4545
namespace {
46-
4746
// ----------------------------------------------------------------------
4847
// re2 utilities
4948

@@ -2185,9 +2184,34 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) {
21852184

21862185
using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>;
21872186

2187+
struct BaseExtractRegexData {
2188+
Status Init() {
2189+
RETURN_NOT_OK(RegexStatus(*regex));
2190+
const int group_count = regex->NumberOfCapturingGroups();
2191+
const auto& name_map = regex->CapturingGroupNames();
2192+
group_names.reserve(group_count);
2193+
2194+
for (int i = 0; i < group_count; i++) {
2195+
auto item = name_map.find(i + 1); // re2 starts counting from 1
2196+
if (item == name_map.end()) {
2197+
// XXX should we instead just create fields with an empty name?
2198+
return Status::Invalid("Regular expression contains unnamed groups");
2199+
}
2200+
group_names.emplace_back(item->second);
2201+
}
2202+
return Status::OK();
2203+
}
2204+
int64_t num_groups() const { return static_cast<int64_t>(group_names.size()); }
2205+
std::unique_ptr<RE2> regex;
2206+
std::vector<std::string> group_names;
2207+
2208+
protected:
2209+
explicit BaseExtractRegexData(const std::string& pattern, bool is_utf8 = true)
2210+
: regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
2211+
};
2212+
21882213
// TODO cache this once per ExtractRegexOptions
2189-
class ExtractRegexData {
2190-
public:
2214+
struct ExtractRegexData : public BaseExtractRegexData {
21912215
static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
21922216
bool is_utf8 = true) {
21932217
ExtractRegexData data(options.pattern, is_utf8);
@@ -2197,50 +2221,24 @@ class ExtractRegexData {
21972221

21982222
Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
21992223
const DataType* input_type = types[0].type;
2200-
// as mentioned here
2201-
// https://arrow.apache.org/docs/developers/cpp/development.html#code-style-linting-and-ci
2202-
// nullptr should not be used
2203-
if (input_type == NULLPTR) {
2224+
if (input_type == nullptr) {
22042225
// No input type specified
2205-
return NULLPTR;
2226+
return nullptr;
22062227
}
22072228
// Input type is either [Large]Binary or [Large]String and is also the type
22082229
// of each field in the output struct type.
22092230
DCHECK(is_base_binary_like(input_type->id()));
22102231
FieldVector fields;
2211-
fields.reserve(group_names_.size());
2232+
fields.reserve(num_groups());
22122233
std::shared_ptr<DataType> owned_type = input_type->GetSharedPtr();
2213-
std::transform(group_names_.begin(), group_names_.end(), std::back_inserter(fields),
2234+
std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields),
22142235
[&](const std::string& name) { return field(name, owned_type); });
2215-
return struct_(fields);
2236+
return struct_(std::move(fields));
22162237
}
2217-
int64_t num_group() const { return group_names_.size(); }
2218-
std::shared_ptr<RE2> regex() const { return regex_; }
22192238

2220-
protected:
2239+
private:
22212240
explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true)
2222-
: regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {}
2223-
2224-
Status Init() {
2225-
RETURN_NOT_OK(RegexStatus(*regex_));
2226-
2227-
const int group_count = regex_->NumberOfCapturingGroups();
2228-
const auto& name_map = regex_->CapturingGroupNames();
2229-
group_names_.reserve(group_count);
2230-
2231-
for (int i = 0; i < group_count; i++) {
2232-
auto item = name_map.find(i + 1); // re2 starts counting from 1
2233-
if (item == name_map.end()) {
2234-
// XXX should we instead just create fields with an empty name?
2235-
return Status::Invalid("Regular expression contains unnamed groups");
2236-
}
2237-
group_names_.emplace_back(item->second);
2238-
}
2239-
return Status::OK();
2240-
}
2241-
2242-
std::shared_ptr<RE2> regex_;
2243-
std::vector<std::string> group_names_;
2241+
: BaseExtractRegexData(pattern, is_utf8) {}
22442242
};
22452243

22462244
Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
@@ -2251,17 +2249,17 @@ Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
22512249
}
22522250

22532251
struct ExtractRegexBase {
2254-
const ExtractRegexData& data;
2252+
const BaseExtractRegexData& data;
22552253
const int group_count;
22562254
std::vector<re2::StringPiece> found_values;
22572255
std::vector<RE2::Arg> args;
22582256
std::vector<const RE2::Arg*> args_pointers;
22592257
const RE2::Arg** args_pointers_start;
22602258
const RE2::Arg* null_arg = nullptr;
22612259

2262-
explicit ExtractRegexBase(const ExtractRegexData& data)
2260+
explicit ExtractRegexBase(const BaseExtractRegexData& data)
22632261
: data(data),
2264-
group_count(static_cast<int>(data.num_group())),
2262+
group_count(static_cast<int>(data.num_groups())),
22652263
found_values(group_count) {
22662264
args.reserve(group_count);
22672265
args_pointers.reserve(group_count);
@@ -2276,7 +2274,7 @@ struct ExtractRegexBase {
22762274
}
22772275

22782276
bool Match(std::string_view s) {
2279-
return RE2::PartialMatchN(ToStringPiece(s), *data.regex(), args_pointers_start,
2277+
return RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start,
22802278
group_count);
22812279
}
22822280
};
@@ -2291,18 +2289,16 @@ struct ExtractRegex : public ExtractRegexBase {
22912289
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
22922290
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
22932291
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8));
2294-
return ExtractRegex{data}.Extract(ctx, batch, out);
2292+
return ExtractRegex(data).Extract(ctx, batch, out);
22952293
}
22962294

22972295
Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
2298-
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
2299-
DCHECK_NE(out->array_data(), NULLPTR);
2296+
DCHECK_NE(out->array_data(), nullptr);
23002297
std::shared_ptr<DataType> type = out->array_data()->type;
2301-
DCHECK_NE(type, NULLPTR);
2302-
2303-
std::unique_ptr<ArrayBuilder> array_builder;
2304-
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder));
2305-
StructBuilder* struct_builder = checked_cast<StructBuilder*>(array_builder.get());
2298+
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ArrayBuilder> array_builder,
2299+
MakeBuilder(type, ctx->memory_pool()));
2300+
auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(array_builder));
2301+
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length));
23062302

23072303
std::vector<BuilderType*> field_builders;
23082304
field_builders.reserve(group_count);
@@ -2357,82 +2353,83 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) {
23572353
}
23582354
DCHECK_OK(registry->AddFunction(std::move(func)));
23592355
}
2360-
class ExtractRegexSpanData : public ExtractRegexData {
2361-
public:
2362-
static Result<ExtractRegexSpanData> Make(const std::string& pattern) {
2363-
auto data = ExtractRegexSpanData(pattern, true);
2356+
struct ExtractRegexSpanData : public BaseExtractRegexData {
2357+
static Result<ExtractRegexSpanData> Make(const std::string& pattern,
2358+
bool is_utf8 = true) {
2359+
auto data = ExtractRegexSpanData(pattern, is_utf8);
23642360
ARROW_RETURN_NOT_OK(data.Init());
23652361
return data;
23662362
}
23672363

23682364
Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
23692365
const DataType* input_type = types[0].type;
2370-
if (input_type == NULLPTR) {
2371-
return NULLPTR;
2366+
if (input_type == nullptr) {
2367+
return nullptr;
23722368
}
23732369
DCHECK(is_base_binary_like(input_type->id()));
2374-
const size_t field_count = group_names_.size();
2370+
const size_t field_count = num_groups();
23752371
FieldVector fields;
23762372
fields.reserve(field_count);
2377-
const auto owned_type = input_type->GetSharedPtr();
2378-
for (const auto& group_name : group_names_) {
2379-
auto type = is_binary_like(owned_type->id()) ? int32() : int64();
2373+
auto index_type = is_binary_like(input_type->id()) ? int32() : int64();
2374+
for (const auto& group_name : group_names) {
23802375
// size list is 2 as every span contains position and length
2381-
fields.push_back(field(group_name + "_span", fixed_size_list(type, 2)));
2376+
fields.push_back(field(group_name, fixed_size_list(index_type, 2)));
23822377
}
2383-
return struct_(fields);
2378+
return struct_(std::move(fields));
23842379
}
23852380

23862381
private:
23872382
ExtractRegexSpanData(const std::string& pattern, const bool is_utf8)
2388-
: ExtractRegexData(pattern, is_utf8) {}
2383+
: BaseExtractRegexData(pattern, is_utf8) {}
23892384
};
23902385

23912386
template <typename Type>
23922387
struct ExtractRegexSpan : ExtractRegexBase {
23932388
using ArrayType = typename TypeTraits<Type>::ArrayType;
23942389
using BuilderType = typename TypeTraits<Type>::BuilderType;
2390+
using offset_type = typename Type::offset_type;
2391+
using OffsetBuilderType =
2392+
typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::BuilderType;
2393+
using OffsetCType =
2394+
typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::CType;
2395+
23952396
using ExtractRegexBase::ExtractRegexBase;
23962397

23972398
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
23982399
auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx);
2399-
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexSpanData::Make(options.pattern));
2400+
ARROW_ASSIGN_OR_RAISE(auto data,
2401+
ExtractRegexSpanData::Make(options.pattern, Type::is_utf8));
24002402
return ExtractRegexSpan{data}.Extract(ctx, batch, out);
24012403
}
24022404
Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
2403-
DCHECK_NE(out->array_data(), NULLPTR);
2405+
DCHECK_NE(out->array_data(), nullptr);
24042406
std::shared_ptr<DataType> out_type = out->array_data()->type;
2405-
DCHECK_NE(out_type, NULLPTR);
2406-
std::unique_ptr<ArrayBuilder> out_builder;
2407-
ARROW_RETURN_NOT_OK(
2408-
MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder));
2407+
ARROW_ASSIGN_OR_RAISE(auto out_builder, MakeBuilder(out_type, ctx->memory_pool()));
24092408
auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(out_builder));
2409+
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length));
24102410
std::vector<FixedSizeListBuilder*> span_builders;
2411-
std::vector<ArrayBuilder*> array_builders;
2411+
std::vector<OffsetBuilderType*> array_builders;
24122412
span_builders.reserve(group_count);
24132413
array_builders.reserve(group_count);
24142414
for (int i = 0; i < group_count; i++) {
24152415
span_builders.push_back(
24162416
checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i)));
2417-
array_builders.push_back(span_builders[i]->value_builder());
2417+
array_builders.push_back(
2418+
checked_cast<OffsetBuilderType*>(span_builders.back()->value_builder()));
2419+
RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].array.length));
2420+
RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].array.length));
24182421
}
2422+
24192423
auto visit_null = [&]() { return struct_builder->AppendNull(); };
24202424
auto visit_value = [&](std::string_view element) -> Status {
24212425
if (Match(element)) {
24222426
for (int i = 0; i < group_count; i++) {
24232427
// https://github.com/google/re2/issues/24#issuecomment-97653183
2424-
if (found_values[i].data() != NULLPTR) {
2428+
if (found_values[i].data() != nullptr) {
24252429
int64_t begin = found_values[i].data() - element.data();
24262430
int64_t size = found_values[i].size();
2427-
if (is_binary_like(batch.GetTypes()[0].id())) {
2428-
ARROW_RETURN_NOT_OK(checked_cast<Int32Builder*>(array_builders[i])
2429-
->AppendValues({static_cast<int32_t>(begin),
2430-
static_cast<int32_t>(size)}));
2431-
} else {
2432-
ARROW_RETURN_NOT_OK(checked_cast<Int64Builder*>(array_builders[i])
2433-
->AppendValues({begin, size}));
2434-
}
2435-
2431+
array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(begin));
2432+
array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(size));
24362433
ARROW_RETURN_NOT_OK(span_builders[i]->Append());
24372434
} else {
24382435
ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull());
@@ -2448,25 +2445,33 @@ struct ExtractRegexSpan : ExtractRegexBase {
24482445
VisitArraySpanInline<Type>(batch[0].array, visit_value, visit_null));
24492446

24502447
ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish());
2451-
out->value = out_array->data();
2448+
out->value = std::move(out_array->data());
24522449
return Status::OK();
24532450
}
24542451
};
24552452

2456-
const FunctionDoc extract_regex_doc_span(
2457-
"likes extract_regex; however, it contains the position and length of results", "",
2453+
const FunctionDoc extract_regex_span_doc(
2454+
"Extract substrings captured by a regex pattern and Save the result in the form of "
2455+
"(offset,length)",
2456+
"For each string in strings, match the regular expression and, if\n"
2457+
"successful, emit a struct with field names and values coming from the\n"
2458+
"regular expression's named capture groups, which are stored in a form of a\n "
2459+
"fixed_size_list(offset, length). If the input is null or the regular \n"
2460+
"expression Fails matching, a null output value is emitted.\n"
2461+
"Regular expression matching is done using the Google RE2 library.",
24582462
{"strings"}, "ExtractRegexSpanOptions", true);
24592463

2460-
Result<TypeHolder> resolver(KernelContext* ctx, const std::vector<TypeHolder>& types) {
2464+
Result<TypeHolder> ResolveExtractRegexSpanOutputType(
2465+
KernelContext* ctx, const std::vector<TypeHolder>& types) {
24612466
auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(*ctx->state());
24622467
ARROW_ASSIGN_OR_RAISE(auto span, ExtractRegexSpanData::Make(options.pattern));
24632468
return span.ResolveOutputType(types);
24642469
}
24652470

24662471
void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) {
24672472
auto func = std::make_shared<ScalarFunction>("extract_regex_span", Arity::Unary(),
2468-
extract_regex_doc_span);
2469-
OutputType output_type(resolver);
2473+
extract_regex_span_doc);
2474+
OutputType output_type(ResolveExtractRegexSpanOutputType);
24702475
for (const auto& type : BaseBinaryTypes()) {
24712476
ScalarKernel kernel({type}, output_type,
24722477
GenerateVarBinaryToVarBinary<ExtractRegexSpan>(type),

cpp/src/arrow/compute/kernels/scalar_string_test.cc

+49-18
Original file line number | Diff line number | Diff line change
@@ -1960,28 +1960,59 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) {
19601960
R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])",
19611961
&options);
19621962
}
1963-
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSapn) {
1964-
ExtractRegexSpanOptions options{"(?P<letter>[ab])(?P<digit>\\d)"};
1963+
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) {
1964+
ExtractRegexSpanOptions options{"(?P<letter>[ab]+)(?P<digit>\\d+)"};
19651965
auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64();
1966-
auto out_type = struct_({field("letter_span", fixed_size_list(type_fixe_size_list, 2)),
1967-
field("digit_span", fixed_size_list(type_fixe_size_list, 2))});
1966+
auto out_type = struct_({field("letter", fixed_size_list(type_fixe_size_list, 2)),
1967+
field("digit", fixed_size_list(type_fixe_size_list, 2))});
19681968
this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options);
1969+
this->CheckUnary("extract_regex_span", R"([ null,"123ab","cd123ab","cd123abef"])",
1970+
out_type, R"([null,null,null,null])", &options);
19691971
this->CheckUnary(
1970-
"extract_regex_span", R"(["a1", "b2", "c3", null])", out_type,
1971-
R"([{"letter_span":[0,1], "digit_span":[1,1]}, {"letter_span":[0,1], "digit_span":[1,1]}, null, null])",
1972-
&options);
1973-
this->CheckUnary(
1974-
"extract_regex_span", R"(["a1", "c3", null, "b2"])", out_type,
1975-
R"([{"letter_span":[0,1], "digit_span": [1,1]}, null, null, {"letter_span":[0,1], "digit_span":[1,1]}])",
1976-
&options);
1977-
this->CheckUnary(
1978-
"extract_regex_span", R"(["a1", "b2"])", out_type,
1979-
R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [0,1], "digit_span": [1,1]}])",
1980-
&options);
1981-
this->CheckUnary(
1982-
"extract_regex_span", R"(["a1", "zb3z"])", out_type,
1983-
R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [1,1], "digit_span": [2,1]}])",
1972+
"extract_regex_span",
1973+
R"(["a1", "b2", "c3", null,"123ab","abb12","abc13","cedbb15","cedaabb125efg"])",
1974+
out_type,
1975+
R"([{"letter":[0,1], "digit":[1,1]},
1976+
{"letter":[0,1], "digit":[1,1]},
1977+
null,
1978+
null,
1979+
null,
1980+
{"letter":[0,3], "digit":[3,2]},
1981+
null,
1982+
{"letter":[3,2], "digit":[5,2]},
1983+
{"letter":[3,4], "digit":[7,3]}])",
19841984
&options);
1985+
this->CheckUnary("extract_regex_span", R"([ "a3","b2","cdaa123","cdab123ef"])",
1986+
out_type,
1987+
R"([{"letter":[0,1], "digit":[1,1]},
1988+
{"letter":[0,1], "digit":[1,1]},
1989+
{"letter":[2,2], "digit":[4,3]},
1990+
{"letter":[2,2], "digit":[4,3]}])",
1991+
&options);
1992+
}
1993+
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanCaptureOption) {
1994+
ExtractRegexSpanOptions options{"(?P<foo>foo)?(?P<digit>\\d+)?"};
1995+
auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64();
1996+
auto out_type = struct_({field("foo", fixed_size_list(type_fixe_size_list, 2)),
1997+
field("digit", fixed_size_list(type_fixe_size_list, 2))});
1998+
this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options);
1999+
this->CheckUnary("extract_regex_span", R"(["foo","foo123","abcfoo123","abc",null])",
2000+
out_type,
2001+
R"([{"foo":[0,3],"digit":null},
2002+
{"foo":[0,3],"digit":[3,3]},
2003+
{"foo":null,"digit":null},
2004+
{"foo":null,"digit":null},
2005+
null])",
2006+
&options);
2007+
options = ExtractRegexSpanOptions{"(?P<foo>foo)(?P<digit>\\d+)?"};
2008+
this->CheckUnary("extract_regex_span", R"(["foo123","foo","123","abc","abcfoo"])",
2009+
out_type,
2010+
R"([{"foo":[0,3],"digit":[3,3]},
2011+
{"foo":[0,3],"digit":null},
2012+
null,
2013+
null,
2014+
{"foo":[3,3],"digit":null}])",
2015+
&options);
19852016
}
19862017

19872018
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) {

0 commit comments

Comments
 (0)