Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions datafusion/ffi/src/udaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use datafusion::{
};
use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator};
use std::hash::{Hash, Hasher};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{ffi::c_void, sync::Arc};

use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
Expand Down Expand Up @@ -141,6 +141,9 @@ pub struct FFI_AggregateUDF {
/// Release the memory of the private data when it is no longer being used.
pub release: unsafe extern "C" fn(udaf: &mut Self),

/// Hash value for the UDAF used for equality comparison.
pub hash_value: u64,

/// Internal data. This is only to be accessed by the provider of the udaf.
/// A [`ForeignAggregateUDF`] should never attempt to access this data.
pub private_data: *mut c_void,
Expand Down Expand Up @@ -339,6 +342,10 @@ impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
let is_nullable = udaf.is_nullable();
let volatility = udaf.signature().volatility.into();

let mut hasher = DefaultHasher::new();
udaf.hash(&mut hasher);
let hash_value = hasher.finish();

let private_data = Box::new(AggregateUDFPrivateData { udaf });

Self {
Expand All @@ -357,6 +364,7 @@ impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
coerce_types: coerce_types_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
hash_value,
private_data: Box::into_raw(private_data) as *mut c_void,
}
}
Expand Down Expand Up @@ -386,14 +394,20 @@ unsafe impl Sync for ForeignAggregateUDF {}

impl PartialEq for ForeignAggregateUDF {
fn eq(&self, other: &Self) -> bool {
// FFI_AggregateUDF cannot be compared, so identity equality is the best we can do.
std::ptr::eq(self, other)
let Self {
signature,
aliases,
udaf,
} = self;
signature == &other.signature
&& aliases == &other.aliases
&& udaf.hash_value == other.udaf.hash_value
}
}
impl Eq for ForeignAggregateUDF {}
impl Hash for ForeignAggregateUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
std::ptr::hash(self, state)
self.udaf.hash_value.hash(state);
}
}

Expand Down Expand Up @@ -740,4 +754,20 @@ mod tests {
test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement);
test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
}

/// Checks `PartialEq` on the values returned by `create_test_foreign_udaf`:
/// two wrappers built from `Sum` must compare equal, and a wrapper built from
/// `Count` must compare unequal to them.
#[test]
fn test_eq() -> Result<()> {
    use datafusion::functions_aggregate::count::Count;

    // Same implementation wrapped twice => equal.
    let first_sum = create_test_foreign_udaf(Sum::new())?;
    let second_sum = create_test_foreign_udaf(Sum::new())?;
    assert_eq!(first_sum, second_sum);

    // Different implementation => not equal.
    let count = create_test_foreign_udaf(Count::new())?;
    assert_ne!(first_sum, count);

    Ok(())
}
}
48 changes: 36 additions & 12 deletions datafusion/ffi/src/udf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use datafusion::{
use return_type_args::{
FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
};
use std::hash::{Hash, Hasher};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{ffi::c_void, sync::Arc};

pub mod return_type_args;
Expand Down Expand Up @@ -111,6 +111,9 @@ pub struct FFI_ScalarUDF {
/// Release the memory of the private data when it is no longer being used.
pub release: unsafe extern "C" fn(udf: &mut Self),

/// Hash value for the UDF used for equality comparison.
pub hash_value: u64,

/// Internal data. This is only to be accessed by the provider of the udf.
/// A [`ForeignScalarUDF`] should never attempt to access this data.
pub private_data: *mut c_void,
Expand Down Expand Up @@ -248,6 +251,9 @@ impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
let volatility = udf.signature().volatility.into();
let short_circuits = udf.short_circuits();
let mut hasher = DefaultHasher::new();
udf.hash(&mut hasher);
let hash_value = hasher.finish();

let private_data = Box::new(ScalarUDFPrivateData { udf });

Expand All @@ -262,6 +268,7 @@ impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
coerce_types: coerce_types_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
hash_value,
private_data: Box::into_raw(private_data) as *mut c_void,
}
}
Expand Down Expand Up @@ -300,24 +307,15 @@ impl PartialEq for ForeignScalarUDF {
} = self;
name == &other.name
&& aliases == &other.aliases
&& std::ptr::eq(udf, &other.udf)
&& signature == &other.signature
&& udf.hash_value == other.udf.hash_value
}
}
impl Eq for ForeignScalarUDF {}

impl Hash for ForeignScalarUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
let Self {
name,
aliases,
udf,
signature,
} = self;
name.hash(state);
aliases.hash(state);
std::ptr::hash(udf, state);
signature.hash(state);
self.udf.hash_value.hash(state);
}
}

Expand Down Expand Up @@ -463,4 +461,30 @@ mod tests {

Ok(())
}

fn create_test_foreign_udf(
original_udf: impl ScalarUDFImpl + 'static,
) -> Result<ScalarUDF> {
let original_udf = Arc::new(ScalarUDF::from(original_udf));
let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into();
let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?;
Ok(foreign_udf.into())
}

/// Checks `PartialEq` on the `ScalarUDF`s returned by
/// `create_test_foreign_udf`: two wrappers built from `AbsFunc` must compare
/// equal, and a wrapper built from `GcdFunc` must compare unequal to them.
#[test]
fn test_eq() -> Result<()> {
    // Test that identical UDFs are equal
    let abs_udf1 =
        create_test_foreign_udf(datafusion::functions::math::abs::AbsFunc::new())?;
    let abs_udf2 =
        create_test_foreign_udf(datafusion::functions::math::abs::AbsFunc::new())?;
    assert_eq!(abs_udf1, abs_udf2);

    // Test that different UDFs are not equal. The binding is named after the
    // function it actually wraps (`GcdFunc`), not `sqrt`.
    let gcd_udf =
        create_test_foreign_udf(datafusion::functions::math::gcd::GcdFunc::new())?;
    assert_ne!(abs_udf1, gcd_udf);

    Ok(())
}
}
45 changes: 35 additions & 10 deletions datafusion/ffi/src/udwf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator};
use partition_evaluator_args::{
FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
};
use std::hash::{Hash, Hasher};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{ffi::c_void, sync::Arc};

mod partition_evaluator;
Expand Down Expand Up @@ -99,6 +99,9 @@ pub struct FFI_WindowUDF {
/// Release the memory of the private data when it is no longer being used.
pub release: unsafe extern "C" fn(udf: &mut Self),

/// Hash value for the UDWF used for equality comparison.
pub hash_value: u64,

/// Internal data. This is only to be accessed by the provider of the udf.
/// A [`ForeignWindowUDF`] should never attempt to access this data.
pub private_data: *mut c_void,
Expand Down Expand Up @@ -177,12 +180,6 @@ unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
}

unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
// let private_data = udf.private_data as *const WindowUDFPrivateData;
// let udf_data = &(*private_data);

// let private_data = Box::new(WindowUDFPrivateData {
// udf: Arc::clone(&udf_data.udf),
// });
let private_data = Box::new(WindowUDFPrivateData {
udf: Arc::clone(udwf.inner()),
});
Expand All @@ -197,6 +194,7 @@ unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
field: field_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
hash_value: udwf.hash_value,
private_data: Box::into_raw(private_data) as *mut c_void,
}
}
Expand All @@ -214,6 +212,10 @@ impl From<Arc<WindowUDF>> for FFI_WindowUDF {
let volatility = udf.signature().volatility.into();
let sort_options = udf.sort_options().map(|v| (&v).into()).into();

let mut hasher = DefaultHasher::new();
udf.hash(&mut hasher);
let hash_value = hasher.finish();

let private_data = Box::new(WindowUDFPrivateData { udf });

Self {
Expand All @@ -226,6 +228,7 @@ impl From<Arc<WindowUDF>> for FFI_WindowUDF {
field: field_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
hash_value,
private_data: Box::into_raw(private_data) as *mut c_void,
}
}
Expand Down Expand Up @@ -256,14 +259,22 @@ unsafe impl Sync for ForeignWindowUDF {}

impl PartialEq for ForeignWindowUDF {
fn eq(&self, other: &Self) -> bool {
// FFI_WindowUDF cannot be compared, so identity equality is the best we can do.
std::ptr::eq(self, other)
let Self {
name,
aliases,
udf,
signature,
} = self;
name == &other.name
&& aliases == &other.aliases
&& signature == &other.signature
&& udf.hash_value == other.udf.hash_value
}
}
impl Eq for ForeignWindowUDF {}
impl Hash for ForeignWindowUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
std::ptr::hash(self, state)
self.udf.hash_value.hash(state);
}
}

Expand Down Expand Up @@ -443,4 +454,18 @@ mod tests {

Ok(())
}

/// Checks `PartialEq` on the values returned by `create_test_foreign_udwf`:
/// two wrappers built from `WindowShift::lag()` must compare equal, and a
/// wrapper built from `WindowShift::lead()` must compare unequal to them.
#[test]
fn test_eq() -> datafusion::common::Result<()> {
    // Same window function wrapped twice => equal (hash-based comparison).
    let first_lag = create_test_foreign_udwf(WindowShift::lag())?;
    let second_lag = create_test_foreign_udwf(WindowShift::lag())?;
    assert_eq!(first_lag, second_lag);

    // A different window function => not equal.
    let lead = create_test_foreign_udwf(WindowShift::lead())?;
    assert_ne!(first_lag, lead);

    Ok(())
}
}