-
Notifications
You must be signed in to change notification settings - Fork 952
Return valid for all-nulls in reduce() with nunique include-nulls aggregation #19196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
rapids-bot
merged 4 commits into
rapidsai:branch-25.08
from
davidwendt:reduce-nunique-all-nulls
Jun 26, 2025
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
89e548e
Return valid for all-nulls in reduce() with nunique include-nulls agg
davidwendt 2915fcb
update doxygen for empty/all-null case
davidwendt 1b820bf
Merge branch 'branch-25.08' into reduce-nunique-all-nulls
davidwendt 5097b1d
added gtest for this change
davidwendt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,128 +38,170 @@ | |
namespace cudf { | ||
namespace reduction { | ||
namespace detail { | ||
struct reduce_dispatch_functor { | ||
column_view const col; | ||
data_type output_dtype; | ||
std::optional<std::reference_wrapper<scalar const>> init; | ||
rmm::device_async_resource_ref mr; | ||
rmm::cuda_stream_view stream; | ||
|
||
reduce_dispatch_functor(column_view col, | ||
data_type output_dtype, | ||
std::optional<std::reference_wrapper<scalar const>> init, | ||
rmm::cuda_stream_view stream, | ||
rmm::device_async_resource_ref mr) | ||
: col(std::move(col)), output_dtype(output_dtype), init(init), mr(mr), stream(stream) | ||
{ | ||
namespace { | ||
|
||
/**
 * @brief Runs the reduction selected by `agg.kind` on the input column
 *
 * Each `case` forwards to the matching reduction routine. Aggregations that carry
 * extra parameters (ddof, quantiles, null handling, ...) are downcast to their
 * concrete aggregation type and bound by const reference, so the aggregation
 * object is never copied.
 *
 * MEDIAN and QUANTILE build their intermediate results (sorted indices and the
 * quantile column) with the current device resource; only the element extracted
 * for the return value is allocated with `mr`.
 *
 * An unsupported `agg.kind` is reported through `CUDF_FAIL`. QUANTILE with a
 * number of quantiles other than one, TDIGEST/MERGE_TDIGEST with a non-STRUCT
 * `output_dtype`, and HOST_UDF with an instance that is not a `reduce_host_udf`
 * are reported through `CUDF_EXPECTS`.
 *
 * @param agg The reduction operation to perform
 * @param col The input column
 * @param output_dtype The output data type, forwarded to the reductions that accept one
 * @param init Optional initial value, forwarded to SUM, PRODUCT, MIN, MAX, ANY, ALL and HOST_UDF
 * @param stream The CUDA stream to use
 * @param mr The memory resource to use for the returned scalar
 * @return The scalar produced by the selected reduction
 */
std::unique_ptr<scalar> reduce_aggregate_impl(
  reduce_aggregation const& agg,
  column_view col,
  data_type output_dtype,
  std::optional<std::reference_wrapper<scalar const>> init,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr)
{
  switch (agg.kind) {
    case aggregation::SUM: return sum(col, output_dtype, init, stream, mr);
    case aggregation::PRODUCT: return product(col, output_dtype, init, stream, mr);
    case aggregation::MIN: return min(col, output_dtype, init, stream, mr);
    case aggregation::MAX: return max(col, output_dtype, init, stream, mr);
    case aggregation::ANY: return any(col, output_dtype, init, stream, mr);
    case aggregation::ALL: return all(col, output_dtype, init, stream, mr);
    case aggregation::HISTOGRAM: return histogram(col, stream, mr);
    case aggregation::MERGE_HISTOGRAM: return merge_histogram(col, stream, mr);
    case aggregation::SUM_OF_SQUARES: return sum_of_squares(col, output_dtype, stream, mr);
    case aggregation::MEAN: return mean(col, output_dtype, stream, mr);
    case aggregation::VARIANCE: {
      auto const& var_agg = static_cast<cudf::detail::var_aggregation const&>(agg);
      return variance(col, output_dtype, var_agg._ddof, stream, mr);
    }
    case aggregation::STD: {
      auto const& std_agg = static_cast<cudf::detail::std_aggregation const&>(agg);
      return standard_deviation(col, output_dtype, std_agg._ddof, stream, mr);
    }
    case aggregation::MEDIAN: {
      // Temporaries use the current device resource; only the result uses `mr`.
      auto current_mr = cudf::get_current_device_resource_ref();
      auto sorted_indices =
        cudf::detail::sorted_order(table_view{{col}}, {}, {null_order::AFTER}, stream, current_mr);
      // Nulls are ordered last, so the first (size - null_count) indices are kept
      // (presumably exactly the non-null rows -- confirm against sorted_order).
      auto valid_sorted_indices =
        cudf::detail::split(*sorted_indices, {col.size() - col.null_count()}, stream)[0];
      auto col_ptr = cudf::detail::quantile(
        col, {0.5}, interpolation::LINEAR, valid_sorted_indices, true, stream, current_mr);
      return cudf::detail::get_element(*col_ptr, 0, stream, mr);
    }
    case aggregation::QUANTILE: {
      auto const& quantile_agg = static_cast<cudf::detail::quantile_aggregation const&>(agg);
      CUDF_EXPECTS(quantile_agg._quantiles.size() == 1,
                   "Reduction quantile accepts only one quantile value");
      // Temporaries use the current device resource; only the result uses `mr`.
      auto current_mr = cudf::get_current_device_resource_ref();
      auto sorted_indices =
        cudf::detail::sorted_order(table_view{{col}}, {}, {null_order::AFTER}, stream, current_mr);
      auto valid_sorted_indices =
        cudf::detail::split(*sorted_indices, {col.size() - col.null_count()}, stream)[0];

      auto col_ptr = cudf::detail::quantile(col,
                                            quantile_agg._quantiles,
                                            quantile_agg._interpolation,
                                            valid_sorted_indices,
                                            true,
                                            stream,
                                            current_mr);
      return cudf::detail::get_element(*col_ptr, 0, stream, mr);
    }
    case aggregation::NUNIQUE: {
      auto const& nunique_agg = static_cast<cudf::detail::nunique_aggregation const&>(agg);
      return cudf::make_fixed_width_scalar(
        cudf::detail::distinct_count(
          col, nunique_agg._null_handling, nan_policy::NAN_IS_VALID, stream),
        stream,
        mr);
    }
    case aggregation::NTH_ELEMENT: {
      auto const& nth_agg = static_cast<cudf::detail::nth_element_aggregation const&>(agg);
      return nth_element(col, nth_agg._n, nth_agg._null_handling, stream, mr);
    }
    case aggregation::COLLECT_LIST: {
      auto const& col_agg = static_cast<cudf::detail::collect_list_aggregation const&>(agg);
      return collect_list(col, col_agg._null_handling, stream, mr);
    }
    case aggregation::COLLECT_SET: {
      auto const& col_agg = static_cast<cudf::detail::collect_set_aggregation const&>(agg);
      return collect_set(
        col, col_agg._null_handling, col_agg._nulls_equal, col_agg._nans_equal, stream, mr);
    }
    case aggregation::MERGE_LISTS: {
      return merge_lists(col, stream, mr);
    }
    case aggregation::MERGE_SETS: {
      auto const& col_agg = static_cast<cudf::detail::merge_sets_aggregation const&>(agg);
      return merge_sets(col, col_agg._nulls_equal, col_agg._nans_equal, stream, mr);
    }
    case aggregation::TDIGEST: {
      CUDF_EXPECTS(output_dtype.id() == type_id::STRUCT,
                   "Tdigest aggregations expect output type to be STRUCT");
      auto const& td_agg = static_cast<cudf::detail::tdigest_aggregation const&>(agg);
      return tdigest::detail::reduce_tdigest(col, td_agg.max_centroids, stream, mr);
    }
    case aggregation::MERGE_TDIGEST: {
      CUDF_EXPECTS(output_dtype.id() == type_id::STRUCT,
                   "Tdigest aggregations expect output type to be STRUCT");
      auto const& td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
      return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
    }
    case aggregation::HOST_UDF: {
      // dynamic_cast is used here (unlike the other cases) so that a UDF of the
      // wrong flavor is detected and reported instead of being invoked.
      auto const& udf_base_ptr =
        dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
      auto const udf_ptr = dynamic_cast<reduce_host_udf const*>(udf_base_ptr.get());
      CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for reduction.");
      return (*udf_ptr)(col, output_dtype, init, stream, mr);
    }  // case aggregation::HOST_UDF
    case aggregation::BITWISE_AGG: {
      auto const& bitwise_agg = static_cast<cudf::detail::bitwise_aggregation const&>(agg);
      return bitwise_reduction(bitwise_agg.bit_op, col, stream, mr);
    }
    default: CUDF_FAIL("Unsupported reduction operator");
  }
}
|
||
template <aggregation::Kind k> | ||
std::unique_ptr<scalar> operator()(reduce_aggregation const& agg) | ||
{ | ||
switch (k) { | ||
case aggregation::SUM: return sum(col, output_dtype, init, stream, mr); | ||
case aggregation::PRODUCT: return product(col, output_dtype, init, stream, mr); | ||
case aggregation::MIN: return min(col, output_dtype, init, stream, mr); | ||
case aggregation::MAX: return max(col, output_dtype, init, stream, mr); | ||
case aggregation::ANY: return any(col, output_dtype, init, stream, mr); | ||
case aggregation::ALL: return all(col, output_dtype, init, stream, mr); | ||
case aggregation::HISTOGRAM: return histogram(col, stream, mr); | ||
case aggregation::MERGE_HISTOGRAM: return merge_histogram(col, stream, mr); | ||
case aggregation::SUM_OF_SQUARES: return sum_of_squares(col, output_dtype, stream, mr); | ||
case aggregation::MEAN: return mean(col, output_dtype, stream, mr); | ||
case aggregation::VARIANCE: { | ||
auto var_agg = static_cast<cudf::detail::var_aggregation const&>(agg); | ||
return variance(col, output_dtype, var_agg._ddof, stream, mr); | ||
} | ||
case aggregation::STD: { | ||
auto var_agg = static_cast<cudf::detail::std_aggregation const&>(agg); | ||
return standard_deviation(col, output_dtype, var_agg._ddof, stream, mr); | ||
} | ||
case aggregation::MEDIAN: { | ||
auto current_mr = cudf::get_current_device_resource_ref(); | ||
auto sorted_indices = cudf::detail::sorted_order( | ||
table_view{{col}}, {}, {null_order::AFTER}, stream, current_mr); | ||
auto valid_sorted_indices = | ||
cudf::detail::split(*sorted_indices, {col.size() - col.null_count()}, stream)[0]; | ||
auto col_ptr = cudf::detail::quantile( | ||
col, {0.5}, interpolation::LINEAR, valid_sorted_indices, true, stream, current_mr); | ||
return cudf::detail::get_element(*col_ptr, 0, stream, mr); | ||
} | ||
case aggregation::QUANTILE: { | ||
auto quantile_agg = static_cast<cudf::detail::quantile_aggregation const&>(agg); | ||
CUDF_EXPECTS(quantile_agg._quantiles.size() == 1, | ||
"Reduction quantile accepts only one quantile value"); | ||
auto current_mr = cudf::get_current_device_resource_ref(); | ||
auto sorted_indices = cudf::detail::sorted_order( | ||
table_view{{col}}, {}, {null_order::AFTER}, stream, current_mr); | ||
auto valid_sorted_indices = | ||
cudf::detail::split(*sorted_indices, {col.size() - col.null_count()}, stream)[0]; | ||
|
||
auto col_ptr = cudf::detail::quantile(col, | ||
quantile_agg._quantiles, | ||
quantile_agg._interpolation, | ||
valid_sorted_indices, | ||
true, | ||
stream, | ||
current_mr); | ||
return cudf::detail::get_element(*col_ptr, 0, stream, mr); | ||
} | ||
case aggregation::NUNIQUE: { | ||
auto nunique_agg = static_cast<cudf::detail::nunique_aggregation const&>(agg); | ||
return cudf::make_fixed_width_scalar( | ||
cudf::detail::distinct_count( | ||
col, nunique_agg._null_handling, nan_policy::NAN_IS_VALID, stream), | ||
stream, | ||
mr); | ||
} | ||
case aggregation::NTH_ELEMENT: { | ||
auto nth_agg = static_cast<cudf::detail::nth_element_aggregation const&>(agg); | ||
return nth_element(col, nth_agg._n, nth_agg._null_handling, stream, mr); | ||
} | ||
case aggregation::COLLECT_LIST: { | ||
auto col_agg = static_cast<cudf::detail::collect_list_aggregation const&>(agg); | ||
return collect_list(col, col_agg._null_handling, stream, mr); | ||
} | ||
case aggregation::COLLECT_SET: { | ||
auto col_agg = static_cast<cudf::detail::collect_set_aggregation const&>(agg); | ||
return collect_set( | ||
col, col_agg._null_handling, col_agg._nulls_equal, col_agg._nans_equal, stream, mr); | ||
} | ||
case aggregation::MERGE_LISTS: { | ||
return merge_lists(col, stream, mr); | ||
} | ||
case aggregation::MERGE_SETS: { | ||
auto col_agg = static_cast<cudf::detail::merge_sets_aggregation const&>(agg); | ||
return merge_sets(col, col_agg._nulls_equal, col_agg._nans_equal, stream, mr); | ||
} | ||
case aggregation::TDIGEST: { | ||
CUDF_EXPECTS(output_dtype.id() == type_id::STRUCT, | ||
"Tdigest aggregations expect output type to be STRUCT"); | ||
auto td_agg = static_cast<cudf::detail::tdigest_aggregation const&>(agg); | ||
return tdigest::detail::reduce_tdigest(col, td_agg.max_centroids, stream, mr); | ||
} | ||
case aggregation::MERGE_TDIGEST: { | ||
CUDF_EXPECTS(output_dtype.id() == type_id::STRUCT, | ||
"Tdigest aggregations expect output type to be STRUCT"); | ||
auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg); | ||
return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr); | ||
} | ||
case aggregation::HOST_UDF: { | ||
auto const& udf_base_ptr = | ||
dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr; | ||
auto const udf_ptr = dynamic_cast<reduce_host_udf const*>(udf_base_ptr.get()); | ||
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for reduction."); | ||
return (*udf_ptr)(col, output_dtype, init, stream, mr); | ||
} // case aggregation::HOST_UDF | ||
case aggregation::BITWISE_AGG: { | ||
auto const bitwise_agg = static_cast<cudf::detail::bitwise_aggregation const&>(agg); | ||
return bitwise_reduction(bitwise_agg.bit_op, col, stream, mr); | ||
} | ||
default: CUDF_FAIL("Unsupported reduction operator"); | ||
/** | ||
* @brief Specialized implementation for empty or all-null input | ||
* | ||
* This implementation is used to handle the case where the input column is empty or all null. | ||
* It returns a scalar with the appropriate value for the reduction operation. | ||
* | ||
* @param agg The reduction operation to perform | ||
* @param col The input column | ||
* @param output_dtype The output data type | ||
* @param stream The CUDA stream to use | ||
* @param mr The memory resource to use | ||
* @return A scalar with the appropriate value for the reduction operation | ||
*/ | ||
std::unique_ptr<scalar> reduce_no_data_impl(reduce_aggregation const& agg, | ||
column_view col, | ||
data_type output_dtype, | ||
rmm::cuda_stream_view stream, | ||
rmm::device_async_resource_ref mr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function just consolidates the many cases that were in the if-empty-all-nulls statement in the |
||
{ | ||
switch (agg.kind) { | ||
case aggregation::TDIGEST: [[fallthrough]]; | ||
case aggregation::MERGE_TDIGEST: return tdigest::detail::make_empty_tdigest_scalar(stream, mr); | ||
case aggregation::HISTOGRAM: | ||
return std::make_unique<list_scalar>( | ||
std::move(*reduction::detail::make_empty_histogram_like(col)), true, stream, mr); | ||
case aggregation::MERGE_HISTOGRAM: | ||
return std::make_unique<list_scalar>( | ||
std::move(*reduction::detail::make_empty_histogram_like(col.child(0))), true, stream, mr); | ||
case aggregation::COLLECT_LIST: [[fallthrough]]; | ||
case aggregation::COLLECT_SET: { | ||
auto scalar = make_list_scalar(empty_like(col)->view(), stream, mr); | ||
scalar->set_valid_async(false, stream); | ||
return scalar; | ||
} | ||
case aggregation::ANY: [[fallthrough]]; | ||
case aggregation::ALL: { | ||
return std::make_unique<numeric_scalar<bool>>(agg.kind == aggregation::ALL, true, stream, mr); | ||
} | ||
case aggregation::NUNIQUE: { | ||
auto nunique_agg = static_cast<cudf::detail::nunique_aggregation const&>(agg); | ||
auto valid = !col.is_empty() && (nunique_agg._null_handling == cudf::null_policy::INCLUDE); | ||
return std::make_unique<numeric_scalar<size_type>>(!col.is_empty(), valid, stream, mr); | ||
} | ||
default: { | ||
return cudf::is_nested(output_dtype) | ||
? make_empty_scalar_like(col, stream, mr) | ||
: make_default_constructed_scalar(output_dtype, stream, mr); | ||
} | ||
} | ||
}; | ||
} | ||
} // namespace | ||
|
||
std::unique_ptr<scalar> reduce(column_view const& col, | ||
reduce_aggregation const& agg, | ||
|
@@ -181,40 +223,9 @@ std::unique_ptr<scalar> reduce(column_view const& col, | |
} | ||
|
||
// Returns default scalar if input column is empty or all null | ||
if (col.size() <= col.null_count()) { | ||
if (agg.kind == aggregation::TDIGEST || agg.kind == aggregation::MERGE_TDIGEST) { | ||
return tdigest::detail::make_empty_tdigest_scalar(stream, mr); | ||
} | ||
|
||
if (agg.kind == aggregation::HISTOGRAM) { | ||
return std::make_unique<list_scalar>( | ||
std::move(*reduction::detail::make_empty_histogram_like(col)), true, stream, mr); | ||
} | ||
if (agg.kind == aggregation::MERGE_HISTOGRAM) { | ||
return std::make_unique<list_scalar>( | ||
std::move(*reduction::detail::make_empty_histogram_like(col.child(0))), true, stream, mr); | ||
} | ||
|
||
if (agg.kind == aggregation::COLLECT_LIST || agg.kind == aggregation::COLLECT_SET) { | ||
auto scalar = make_list_scalar(empty_like(col)->view(), stream, mr); | ||
scalar->set_valid_async(false, stream); | ||
return scalar; | ||
} | ||
|
||
// `make_default_constructed_scalar` does not support nested type. | ||
if (cudf::is_nested(output_dtype)) { return make_empty_scalar_like(col, stream, mr); } | ||
|
||
auto result = make_default_constructed_scalar(output_dtype, stream, mr); | ||
if (agg.kind == aggregation::ANY || agg.kind == aggregation::ALL) { | ||
// empty input should return false for ANY and return true for ALL | ||
dynamic_cast<numeric_scalar<bool>*>(result.get()) | ||
->set_value(agg.kind == aggregation::ALL, stream); | ||
} | ||
return result; | ||
} | ||
|
||
return cudf::detail::aggregation_dispatcher( | ||
agg.kind, reduce_dispatch_functor{col, output_dtype, init, stream, mr}, agg); | ||
return (col.size() == col.null_count()) | ||
? reduce_no_data_impl(agg, col, output_dtype, stream, mr) | ||
: reduce_aggregate_impl(agg, col, output_dtype, init, stream, mr); | ||
} | ||
} // namespace detail | ||
} // namespace reduction | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code has not changed. The functor `operator()` was simply changed to a regular function call, so the change just moved the code logic to the left.