Skip to content

Commit a0248a9

Browse files
findepi and alamb authored
Fix PartialOrd for ScalarUDF (#17182)
Before this change, `PartialOrd` could return `Some(Equal)` for two functions that are not equal in the `PartialEq` sense. This is a violation of the `PartialOrd` contract. It was possible, for example, when:
- two functions have the same name but are of different types (e.g. when someone constructs a DataFusion LogicalPlan and mixes DataFusion built-in functions with their own);
- a function has a parameter (e.g. the "safe" attribute of a CAST). The parameter is honored by the `PartialEq` comparison but was transparent to the `PartialOrd` ordering.
The fix is to consult `PartialEq` inside the `partial_cmp` implementation: if the ordering considers two instances equal but they are not equal in the `PartialEq` sense, they are treated as incomparable (`partial_cmp` returns `None`). Co-authored-by: Andrew Lamb <[email protected]>
1 parent 02a7472 commit a0248a9

File tree

1 file changed

+63
-17
lines changed

1 file changed

+63
-17
lines changed

datafusion/expr/src/udf.rs

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,31 @@ impl PartialEq for ScalarUDF {
6767
}
6868
}
6969

70-
// TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `ScalarUDF` and it should be
71-
// Manual implementation based on `ScalarUDFImpl::equals`
7270
impl PartialOrd for ScalarUDF {
7371
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
74-
match self.name().partial_cmp(other.name()) {
75-
Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
76-
cmp => cmp,
72+
let mut cmp = self.name().cmp(other.name());
73+
if cmp == Ordering::Equal {
74+
cmp = self.signature().partial_cmp(other.signature())?;
7775
}
76+
if cmp == Ordering::Equal {
77+
cmp = self.aliases().partial_cmp(other.aliases())?;
78+
}
79+
// Contract for PartialOrd and PartialEq consistency requires that
80+
// a == b if and only if partial_cmp(a, b) == Some(Equal).
81+
if cmp == Ordering::Equal && self != other {
82+
// Functions may have other properties besides name and signature
83+
// that differentiate two instances (e.g. type, or arbitrary parameters).
84+
// We cannot return Some(Equal) in such case.
85+
return None;
86+
}
87+
debug_assert!(
88+
cmp == Ordering::Equal || self != other,
89+
"Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
90+
The functions compare as equal, but they are not equal based on general properties that \
91+
the PartialOrd implementation observes,",
92+
self.name(), other.name()
93+
);
94+
Some(cmp)
7895
}
7996
}
8097

@@ -942,23 +959,26 @@ The following regular expression functions are supported:"#,
942959
#[cfg(test)]
943960
mod tests {
944961
use super::*;
962+
use datafusion_expr_common::signature::Volatility;
945963
use std::hash::DefaultHasher;
946964

947965
#[derive(Debug, PartialEq, Eq, Hash)]
948966
struct TestScalarUDFImpl {
967+
name: &'static str,
949968
field: &'static str,
969+
signature: Signature,
950970
}
951971
impl ScalarUDFImpl for TestScalarUDFImpl {
952972
fn as_any(&self) -> &dyn Any {
953973
self
954974
}
955975

956976
fn name(&self) -> &str {
957-
"TestScalarUDFImpl"
977+
self.name
958978
}
959979

960980
fn signature(&self) -> &Signature {
961-
unimplemented!()
981+
&self.signature
962982
}
963983

964984
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
@@ -970,17 +990,43 @@ mod tests {
970990
}
971991
}
972992

993+
// PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
994+
// must be consistent, so they are tested together.
973995
#[test]
974-
fn test_partial_eq() {
975-
let a1 = ScalarUDF::from(TestScalarUDFImpl { field: "a" });
976-
let a2 = ScalarUDF::from(TestScalarUDFImpl { field: "a" });
977-
let b = ScalarUDF::from(TestScalarUDFImpl { field: "b" });
978-
let eq = a1 == a2;
979-
assert!(eq);
980-
assert_eq!(a1, a2);
981-
assert_eq!(hash(&a1), hash(&a2));
982-
assert_ne!(a1, b);
983-
assert_ne!(a2, b);
996+
fn test_partial_eq_hash_and_partial_ord() {
997+
// A parameterized function
998+
let f = test_func("foo", "a");
999+
1000+
// Same as `f`, but a different instance
1001+
let f2 = test_func("foo", "a");
1002+
assert_eq!(f, f2);
1003+
assert_eq!(hash(&f), hash(&f2));
1004+
assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));
1005+
1006+
// Different parameter
1007+
let b = test_func("foo", "b");
1008+
assert_ne!(f, b);
1009+
assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
1010+
assert_eq!(f.partial_cmp(&b), None);
1011+
1012+
// Different name
1013+
let o = test_func("other", "a");
1014+
assert_ne!(f, o);
1015+
assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
1016+
assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));
1017+
1018+
// Different name and parameter
1019+
assert_ne!(b, o);
1020+
assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test
1021+
assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
1022+
}
1023+
1024+
fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
1025+
ScalarUDF::from(TestScalarUDFImpl {
1026+
name,
1027+
field: parameter,
1028+
signature: Signature::any(1, Volatility::Immutable),
1029+
})
9841030
}
9851031

9861032
fn hash<T: Hash>(value: &T) -> u64 {

0 commit comments

Comments
 (0)