Skip to content

Commit 31dd0bb

Browse files
committed
GH-45739: [C++][Python] Fix crash when calling hash_pivot_wider without options
1 parent b3d218c commit 31dd0bb

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

cpp/src/arrow/compute/kernels/hash_aggregate_pivot.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -452,9 +452,11 @@ const FunctionDoc hash_pivot_doc{
452452
} // namespace
453453

454454
void RegisterHashAggregatePivot(FunctionRegistry* registry) {
455+
static const auto default_pivot_options = PivotWiderOptions::Defaults();
456+
455457
{
456-
auto func = std::make_shared<HashAggregateFunction>("hash_pivot_wider",
457-
Arity::Ternary(), hash_pivot_doc);
458+
auto func = std::make_shared<HashAggregateFunction>(
459+
"hash_pivot_wider", Arity::Ternary(), hash_pivot_doc, &default_pivot_options);
458460
for (auto key_type : BaseBinaryTypes()) {
459461
// Anything that scatter() (i.e. take()) accepts can be passed as values
460462
auto sig = KernelSignature::Make(

python/pyarrow/tests/test_table.py

+26
Original file line numberDiff line numberDiff line change
@@ -2975,6 +2975,32 @@ def test_table_group_by_first():
29752975
assert result.equals(expected)
29762976

29772977

2978+
@pytest.mark.acero
2979+
def test_table_group_by_pivot_wider():
2980+
table = pa.table({'group': [1, 2, 3, 1, 2, 3],
2981+
'key': ['h', 'h', 'h', 'w', 'w', 'w'],
2982+
'value': [10, 20, 30, 40, 50, 60]})
2983+
2984+
with pytest.raises(ValueError, match='accepts 3 arguments but 2 passed'):
2985+
table.group_by("group").aggregate([("key", "pivot_wider")])
2986+
2987+
# GH-45739: calling hash_pivot_wider without options shouldn't crash
2988+
# (even though it's not very useful as key_names=[])
2989+
result = table.group_by("group").aggregate([(("key", "value"), "pivot_wider")])
2990+
expected = pa.table({'group': [1, 2, 3],
2991+
'key_value_pivot_wider': [{}, {}, {}]})
2992+
assert result.equals(expected)
2993+
2994+
options = pc.PivotWiderOptions(key_names=('h', 'w'))
2995+
result = table.group_by("group").aggregate(
2996+
[(("key", "value"), "pivot_wider", options)])
2997+
expected = pa.table(
2998+
{'group': [1, 2, 3],
2999+
'key_value_pivot_wider': [
3000+
{'h': 10, 'w': 40}, {'h': 20, 'w': 50}, {'h': 30, 'w': 60}]})
3001+
assert result.equals(expected)
3002+
3003+
29783004
def test_table_to_recordbatchreader():
29793005
table = pa.Table.from_pydict({'x': [1, 2, 3]})
29803006
reader = table.to_reader()

0 commit comments

Comments
 (0)