Skip to content

Commit 675e3c4

Browse files
fix casting,includes and equal methods with relevant tests
1 parent 470961b commit 675e3c4

File tree

5 files changed

+79
-22
lines changed

5 files changed

+79
-22
lines changed

cpp/src/arrow/array/statistics.cc

+39-3
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
// This empty .cc file is for embedding not inlined symbols in
19-
// arrow::ArrayStatistics into libarrow.
20-
2118
#include "arrow/array/statistics.h"
2219

20+
#include <memory>
21+
#include <type_traits>
22+
#include <variant>
23+
2324
#include "arrow/scalar.h"
25+
#include "arrow/type.h"
2426

2527
namespace arrow {
2628

@@ -56,4 +58,38 @@ const std::shared_ptr<DataType>& ArrayStatistics::ValueToArrowType(
5658
} visitor{array_type};
5759
return std::visit(visitor, value.value());
5860
}
61+
namespace {
62+
bool ValueTypeEquality(const std::optional<ArrayStatistics::ValueType>& first,
63+
const std::optional<ArrayStatistics::ValueType>& second) {
64+
if (first == second) {
65+
return true;
66+
}
67+
if (!first || !second) {
68+
return false;
69+
}
70+
71+
return std::visit(
72+
[](auto& v1, auto& v2) {
73+
if constexpr (std::is_same_v<std::decay_t<decltype(v1)>,
74+
std::shared_ptr<Scalar>> &&
75+
std::is_same_v<std::decay_t<decltype(v2)>,
76+
std::shared_ptr<Scalar>>) {
77+
if (!v1 || !v2) {
78+
// both null case is handled in std::optional and return true
79+
return false;
80+
}
81+
return v1->Equals(*v2);
82+
}
83+
return false;
84+
},
85+
first.value(), second.value());
86+
}
87+
} // namespace
88+
bool ArrayStatistics::Equals(const ArrayStatistics& other) const {
89+
return null_count == other.null_count && distinct_count == other.distinct_count &&
90+
is_min_exact == other.is_min_exact && is_max_exact == other.is_max_exact &&
91+
ValueTypeEquality(this->max, other.max) &&
92+
ValueTypeEquality(this->min, other.min);
93+
}
94+
5995
} // namespace arrow

cpp/src/arrow/array/statistics.h

+4-7
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
#pragma once
1919

2020
#include <cstdint>
21+
#include <memory>
2122
#include <optional>
2223
#include <string>
2324
#include <variant>
2425

25-
#include "arrow/type.h"
2626
#include "arrow/util/visibility.h"
2727

2828
namespace arrow {
29-
29+
class DataType;
30+
struct Scalar;
3031
/// \class ArrayStatistics
3132
/// \brief Statistics for an Array
3233
///
@@ -100,11 +101,7 @@ struct ARROW_EXPORT ArrayStatistics {
100101
bool is_max_exact = false;
101102

102103
/// \brief Check two statistics for equality
103-
bool Equals(const ArrayStatistics& other) const {
104-
return null_count == other.null_count && distinct_count == other.distinct_count &&
105-
min == other.min && is_min_exact == other.is_min_exact && max == other.max &&
106-
is_max_exact == other.is_max_exact;
107-
}
104+
bool Equals(const ArrayStatistics& other) const;
108105

109106
/// \brief Check two statistics for equality
110107
bool operator==(const ArrayStatistics& other) const { return Equals(other); }

cpp/src/arrow/array/statistics_test.cc

+24
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <gtest/gtest.h>
1919

2020
#include "arrow/array/statistics.h"
21+
#include "arrow/scalar.h"
22+
#include "arrow/testing/gtest_util.h"
2123

2224
namespace arrow {
2325

@@ -91,6 +93,28 @@ TEST(ArrayStatisticsTest, TestEquality) {
9193
ASSERT_NE(statistics1, statistics2);
9294
statistics2.max = static_cast<int64_t>(-29);
9395
ASSERT_EQ(statistics1, statistics2);
96+
statistics2.max = static_cast<int64_t>(2);
97+
ASSERT_NE(statistics1, statistics2);
98+
99+
statistics1.max = std::nullopt;
100+
statistics2.max = std::nullopt;
101+
// check the state of both of them are std::nullopt
102+
ASSERT_EQ(statistics1.max, statistics2.max);
103+
// the state of one of them is std::nullopt
104+
statistics1.max = std::shared_ptr<Scalar>();
105+
ASSERT_NE(statistics1, statistics2);
106+
// the state of both of them are nullptr
107+
statistics2.max = std::shared_ptr<Scalar>();
108+
ASSERT_EQ(statistics1, statistics2);
109+
ASSERT_OK_AND_ASSIGN(statistics1.max, MakeScalar(int64(), 5));
110+
// the state of one of them is nullptr
111+
ASSERT_NE(statistics1, statistics2);
112+
// the state of one of them has different type
113+
statistics2.max = static_cast<int64_t>(10);
114+
ASSERT_NE(statistics1.max, statistics2.max);
115+
ASSERT_OK_AND_ASSIGN(statistics2.max, MakeScalar(int64(), 5));
116+
// the state of both of them are equal
117+
ASSERT_EQ(statistics1, statistics2);
94118

95119
statistics1.is_max_exact = true;
96120
ASSERT_NE(statistics1, statistics2);

cpp/src/arrow/record_batch.cc

+5-5
Original file line numberDiff line numberDiff line change
@@ -520,17 +520,17 @@ Status EnumerateStatistics(const RecordBatch& record_batch, OnStatistics on_stat
520520
RETURN_NOT_OK(on_statistics(statistics));
521521
statistics.start_new_column = false;
522522

523-
const auto& schema = ExtractArrayDataAndType(record_batch);
524-
auto num_fields = static_cast<int64_t>(schema.size());
523+
const auto& array_statistics_and_type_vector = ExtractArrayDataAndType(record_batch);
524+
auto num_fields = static_cast<int64_t>(array_statistics_and_type_vector.size());
525525
for (int64_t nth_column = 0; nth_column < num_fields; ++nth_column) {
526-
const auto& type = schema[nth_column].second;
527-
auto column_statistics = schema[nth_column].first;
526+
const auto& type = array_statistics_and_type_vector[nth_column].second;
527+
auto column_statistics = array_statistics_and_type_vector[nth_column].first;
528528
if (!column_statistics) {
529529
continue;
530530
}
531531

532532
statistics.start_new_column = true;
533-
statistics.nth_column = nth_column;
533+
statistics.nth_column = static_cast<int32_t>(nth_column);
534534
if (column_statistics->null_count.has_value()) {
535535
statistics.nth_statistics++;
536536
statistics.key = ARROW_STATISTICS_KEY_NULL_COUNT_EXACT;

cpp/src/arrow/record_batch_test.cc

+7-7
Original file line numberDiff line numberDiff line change
@@ -1474,16 +1474,16 @@ TEST_F(TestRecordBatch, MakeStatisticsArrayNestedType) {
14741474
statistics_struct->null_count = 0;
14751475
auto struct_array_data = struct_array->data();
14761476
auto statistics_struct_child_a = std::make_shared<ArrayStatistics>();
1477-
statistics_struct_child_a->min = 1;
1477+
statistics_struct_child_a->min = int64_t{1};
14781478
struct_array_data->statistics = statistics_struct;
14791479
struct_array_data->child_data[0]->statistics = statistics_struct_child_a;
14801480
auto array_c = ArrayFromJSON(int64(), R"([11,12,13,14,15])");
14811481
array_c->data()->statistics = std::make_shared<ArrayStatistics>();
1482-
array_c->data()->statistics->max = 15;
1482+
array_c->data()->statistics->max = int64_t{15};
14831483
auto array_d = ArrayFromJSON(int64(), R"([16,17,18,19,20])");
14841484
auto nested_child = struct_nested_stat->data()->child_data[0];
14851485
nested_child->statistics = std::make_shared<ArrayStatistics>();
1486-
nested_child->statistics->max = 5;
1486+
nested_child->statistics->max = int64_t{5};
14871487
nested_child->statistics->is_max_exact = true;
14881488

14891489
auto rb_schema =
@@ -1492,8 +1492,8 @@ TEST_F(TestRecordBatch, MakeStatisticsArrayNestedType) {
14921492
auto rb = RecordBatch::Make(rb_schema, 5,
14931493
{struct_array, array_c, array_d, struct_nested_stat});
14941494

1495-
auto expected_scalar = std::static_pointer_cast<Scalar>(std::shared_ptr<StructScalar>(
1496-
new StructScalar({MakeScalar(int64_t{5}), MakeScalar(int64_t{10})}, struct_type)));
1495+
auto expected_scalar = internal::checked_pointer_cast<StructScalar>(
1496+
ScalarFromJSON(struct_type, R"([5,10])"));
14971497
auto a = ArrayStatistics::ValueType{std::static_pointer_cast<Scalar>(expected_scalar)};
14981498

14991499
ASSERT_OK_AND_ASSIGN(
@@ -1536,8 +1536,8 @@ TEST_F(TestRecordBatch, MakeStatisticsArrayNestedNestedType) {
15361536
StructArray::Make({struct_nested_0, struct_nested_1},
15371537
{field("struct_nested_0", struct_nested_0->type()),
15381538
field("struct_nested_1", struct_nested_1->type())}));
1539-
auto expected_scalar = std::static_pointer_cast<Scalar>(std::shared_ptr<StructScalar>(
1540-
new StructScalar({MakeScalar(int32_t{5}), MakeScalar(int32_t{10})}, struct_type)));
1539+
auto expected_scalar = internal::checked_pointer_cast<StructScalar>(
1540+
ScalarFromJSON(struct_type, R"([5,10])"));
15411541
auto rb_schema = schema({field("struct", struct_parent->type())});
15421542
auto rb = RecordBatch::Make(rb_schema, 5, {struct_parent});
15431543

0 commit comments

Comments
 (0)