Skip to content

Commit 2c90c6c

Browse files
fix equal and the relevant tests
1 parent c04b586 commit 2c90c6c

File tree

4 files changed

+65
-9
lines changed

4 files changed

+65
-9
lines changed

cpp/src/arrow/array/statistics.cc

+36
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "arrow/array/statistics.h"
2222

2323
#include <memory>
24+
#include <type_traits>
2425
#include <variant>
2526

2627
#include "arrow/scalar.h"
@@ -59,4 +60,39 @@ const std::shared_ptr<DataType>& ArrayStatistics::ValueToArrowType(
5960
} visitor{array_type};
6061
return std::visit(visitor, value.value());
6162
}
63+
namespace {
64+
bool ValueTypeEquality(const std::optional<ArrayStatistics::ValueType>& first,
65+
const std::optional<ArrayStatistics::ValueType>& second) {
66+
if (first == second) {
67+
return true;
68+
}
69+
if (!first || !second) {
70+
return false;
71+
}
72+
// std::shared_ptr<Scalar> must be compared separately: its operator== compares
73+
// pointer identity only, not the Scalar values being pointed to
74+
return std::visit(
75+
[](auto& v1, auto& v2) {
76+
if constexpr (std::is_same_v<std::decay_t<decltype(v1)>,
77+
std::shared_ptr<Scalar>> &&
78+
std::is_same_v<std::decay_t<decltype(v2)>,
79+
std::shared_ptr<Scalar>>) {
80+
if (!v1 || !v2) {
81+
// the both-null case was already handled by the optional comparison above, which returned true
82+
return false;
83+
}
84+
return v1->Equals(*v2);
85+
}
86+
return false;
87+
},
88+
first.value(), second.value());
89+
}
90+
} // namespace
91+
bool ArrayStatistics::Equals(const ArrayStatistics& other) const {
92+
return null_count == other.null_count && distinct_count == other.distinct_count &&
93+
is_min_exact == other.is_min_exact && is_max_exact == other.is_max_exact &&
94+
ValueTypeEquality(this->max, other.max) &&
95+
ValueTypeEquality(this->min, other.min);
96+
}
97+
6298
} // namespace arrow

cpp/src/arrow/array/statistics.h

+1-5
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,7 @@ struct ARROW_EXPORT ArrayStatistics {
101101
bool is_max_exact = false;
102102

103103
/// \brief Check two statistics for equality
104-
bool Equals(const ArrayStatistics& other) const {
105-
return null_count == other.null_count && distinct_count == other.distinct_count &&
106-
min == other.min && is_min_exact == other.is_min_exact && max == other.max &&
107-
is_max_exact == other.is_max_exact;
108-
}
104+
bool Equals(const ArrayStatistics& other) const;
109105

110106
/// \brief Check two statistics for equality
111107
bool operator==(const ArrayStatistics& other) const { return Equals(other); }

cpp/src/arrow/array/statistics_test.cc

+24
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <gtest/gtest.h>
1919

2020
#include "arrow/array/statistics.h"
21+
#include "arrow/scalar.h"
22+
#include "arrow/testing/gtest_util.h"
2123

2224
namespace arrow {
2325

@@ -91,6 +93,28 @@ TEST(ArrayStatisticsTest, TestEquality) {
9193
ASSERT_NE(statistics1, statistics2);
9294
statistics2.max = static_cast<int64_t>(-29);
9395
ASSERT_EQ(statistics1, statistics2);
96+
statistics2.max = static_cast<int64_t>(2);
97+
ASSERT_NE(statistics1, statistics2);
98+
99+
statistics1.max = std::nullopt;
100+
statistics2.max = std::nullopt;
101+
// both max values are std::nullopt
102+
ASSERT_EQ(statistics1.max, statistics2.max);
103+
// only one max is std::nullopt; the other holds a null std::shared_ptr<Scalar>
104+
statistics1.max = std::shared_ptr<Scalar>();
105+
ASSERT_NE(statistics1, statistics2);
106+
// both max values hold a null std::shared_ptr<Scalar>
107+
statistics2.max = std::shared_ptr<Scalar>();
108+
ASSERT_EQ(statistics1, statistics2);
109+
ASSERT_OK_AND_ASSIGN(statistics1.max, MakeScalar(int64(), 5));
110+
// one max holds a Scalar while the other holds a null std::shared_ptr<Scalar>
111+
ASSERT_NE(statistics1, statistics2);
112+
// the two max values hold different variant alternatives (Scalar vs. int64_t)
113+
statistics2.max = static_cast<int64_t>(10);
114+
ASSERT_NE(statistics1.max, statistics2.max);
115+
ASSERT_OK_AND_ASSIGN(statistics2.max, MakeScalar(int64(), 5));
116+
// both max values hold equal Scalars
117+
ASSERT_EQ(statistics1, statistics2);
94118

95119
statistics1.is_max_exact = true;
96120
ASSERT_NE(statistics1, statistics2);

cpp/src/arrow/record_batch_test.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -1497,8 +1497,8 @@ TEST_F(TestRecordBatch, MakeStatisticsArrayNestedType) {
14971497
auto rb = RecordBatch::Make(rb_schema, 5,
14981498
{struct_array, array_c, array_d, struct_nested_stat});
14991499

1500-
auto expected_scalar = std::static_pointer_cast<Scalar>(std::shared_ptr<StructScalar>(
1501-
new StructScalar({MakeScalar(int64_t{5}), MakeScalar(int64_t{10})}, struct_type)));
1500+
auto expected_scalar = internal::checked_pointer_cast<StructScalar>(
1501+
ScalarFromJSON(struct_type, R"([5,10])"));
15021502
auto a = ArrayStatistics::ValueType{std::static_pointer_cast<Scalar>(expected_scalar)};
15031503

15041504
ASSERT_OK_AND_ASSIGN(
@@ -1541,8 +1541,8 @@ TEST_F(TestRecordBatch, MakeStatisticsArrayNestedNestedType) {
15411541
StructArray::Make({struct_nested_0, struct_nested_1},
15421542
{field("struct_nested_0", struct_nested_0->type()),
15431543
field("struct_nested_1", struct_nested_1->type())}));
1544-
auto expected_scalar = std::static_pointer_cast<Scalar>(std::shared_ptr<StructScalar>(
1545-
new StructScalar({MakeScalar(int32_t{5}), MakeScalar(int32_t{10})}, struct_type)));
1544+
auto expected_scalar = internal::checked_pointer_cast<StructScalar>(
1545+
ScalarFromJSON(struct_type, R"([5,10])"));
15461546
auto rb_schema = schema({field("struct", struct_parent->type())});
15471547
auto rb = RecordBatch::Make(rb_schema, 5, {struct_parent});
15481548
ASSERT_OK_AND_ASSIGN(auto rb_stat, rb->MakeStatisticsArray())

0 commit comments

Comments
 (0)