diff --git a/common/src/bounds.rs b/common/src/bounds.rs index af477ff493..e568cfb159 100644 --- a/common/src/bounds.rs +++ b/common/src/bounds.rs @@ -44,6 +44,18 @@ impl BoundsRange { } } + pub fn transform_inner_res( + &self, + transform_lower: impl Fn(&T) -> io::Result>, + transform_upper: impl Fn(&T) -> io::Result>, + ) -> io::Result> { + let range = BoundsRange { + lower_bound: transform_bound_inner_res(&self.lower_bound, &transform_lower)?, + upper_bound: transform_bound_inner_res(&self.upper_bound, &transform_upper)?, + }; + Ok(range) + } + /// Returns the first set inner value pub fn get_inner(&self) -> Option<&T> { inner_bound(&self.lower_bound).or(inner_bound(&self.upper_bound)) diff --git a/src/query/range_query/range_query_fastfield.rs b/src/query/range_query/range_query_fastfield.rs index b17694cfaf..0ade02c182 100644 --- a/src/query/range_query/range_query_fastfield.rs +++ b/src/query/range_query/range_query_fastfield.rs @@ -2,6 +2,7 @@ //! We use this variant only if the fastfield exists, otherwise the default in `range_query` is //! used, which uses the term dictionary + postings. +use std::io; use std::net::Ipv6Addr; use std::ops::{Bound, RangeInclusive}; @@ -256,62 +257,17 @@ fn search_on_json_numerical_field( .numerical_type() .unwrap_or_else(|| panic!("internal error: couldn't cast to numerical_type: {col_type:?}")); - let bounds = match typ.numerical_type().unwrap() { - NumericalType::I64 => { - let bounds = bounds.map_bound(|term| term.as_i64().unwrap()); - match actual_column_type { - NumericalType::I64 => bounds.map_bound(|&term| term.to_u64()), - NumericalType::U64 => { - bounds.transform_inner( - |&val| { - if val < 0 { - return TransformBound::NewBound(Bound::Unbounded); - } - TransformBound::Existing(val as u64) - }, - |&val| { - if val < 0 { - // no hits case - return TransformBound::NewBound(Bound::Excluded(0)); - } - TransformBound::Existing(val as u64) - }, - ) - } - NumericalType::F64 => bounds.map_bound(|&term| (term as f64).to_u64()), - } - } - NumericalType::U64 => { - let bounds = bounds.map_bound(|term| term.as_u64().unwrap()); - match actual_column_type { - NumericalType::U64 => bounds.map_bound(|&term| term.to_u64()), - NumericalType::I64 => { - bounds.transform_inner( - |&val| { - if val > i64::MAX as u64 { - // Actual no hits case - return TransformBound::NewBound(Bound::Excluded(i64::MAX as u64)); - } - TransformBound::Existing((val as i64).to_u64()) - }, - |&val| { - if val > i64::MAX as u64 { - return TransformBound::NewBound(Bound::Unbounded); - } - TransformBound::Existing((val as i64).to_u64()) - }, - ) - } - NumericalType::F64 => bounds.map_bound(|&term| (term as f64).to_u64()), - } - } + let bounds = match actual_column_type { + NumericalType::I64 => bounds.transform_inner_res::<_, io::Error>( + transform_to_i64::, + transform_to_i64::, + )?, + NumericalType::U64 => bounds.transform_inner_res::<_, std::io::Error>( + transform_to_u64::, + transform_to_u64::, + )?, NumericalType::F64 => { - let bounds = bounds.map_bound(|term| term.as_f64().unwrap()); - match actual_column_type { - NumericalType::U64 => transform_from_f64_bounds::(&bounds), - NumericalType::I64 => transform_from_f64_bounds::(&bounds), - NumericalType::F64 => bounds.map_bound(|&term| term.to_u64()), - } + bounds.transform_inner_res::<_, std::io::Error>(transform_to_f64, transform_to_f64)? } }; search_on_u64_ff( @@ -356,40 +312,126 @@ impl IntType for u64 { } } -fn transform_from_f64_bounds( - bounds: &BoundsRange, -) -> BoundsRange { - bounds.transform_inner( - |&lower_bound| { - if lower_bound < T::min().to_f64() { - return TransformBound::NewBound(Bound::Unbounded); - } - if lower_bound > T::max().to_f64() { - // no hits case - return TransformBound::NewBound(Bound::Excluded(u64::MAX)); - } +fn transform_from_f64( + val: f64, +) -> TransformBound { + if val < T::min().to_f64() { + return TransformBound::NewBound(Bound::Unbounded); + } + if val > T::max().to_f64() { + // no hits case + if IS_LOWER_BOUND { + return TransformBound::NewBound(Bound::Excluded(u64::MAX)); + } else { + return TransformBound::NewBound(Bound::Included(u64::MAX)); + } + } + if val.fract() == 0.0 { + TransformBound::Existing(T::from_f64(val).to_u64()) + } else { + TransformBound::NewBound(Bound::Included(T::from_f64(val.trunc()).to_u64())) + } +} - if lower_bound.fract() == 0.0 { - TransformBound::Existing(T::from_f64(lower_bound).to_u64()) - } else { - TransformBound::NewBound(Bound::Included(T::from_f64(lower_bound.trunc()).to_u64())) - } - }, - |&upper_bound| { - if upper_bound < T::min().to_f64() { - return TransformBound::NewBound(Bound::Unbounded); - } - if upper_bound > T::max().to_f64() { - // no hits case - return TransformBound::NewBound(Bound::Included(u64::MAX)); +fn transform_to_i64( + term: &ValueBytes>, +) -> io::Result> { + let val = match term.typ() { + Type::I64 => { + let val = term.as_i64().unwrap(); + val.to_u64() + } + Type::U64 => { + let val = term.as_u64().unwrap(); + if val > i64::MAX as u64 { + if IS_LOWER_BOUND { + // Actual no hits case + return Ok(TransformBound::NewBound(Bound::Excluded(i64::MAX as u64))); + } else { + return Ok(TransformBound::NewBound(Bound::Unbounded)); + } } - if upper_bound.fract() == 0.0 { - TransformBound::Existing(T::from_f64(upper_bound).to_u64()) - } else { - TransformBound::NewBound(Bound::Included(T::from_f64(upper_bound.trunc()).to_u64())) + (val as i64).to_u64() + } + Type::F64 => { + let val = term.as_f64().unwrap(); + return Ok(transform_from_f64::(val)); + } + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "Expected term with u64, i64 or f64, but got {:?}", + term.typ() + ), + )); + } + }; + Ok(TransformBound::Existing(val)) +} + +fn transform_to_u64( + term: &ValueBytes>, +) -> io::Result> { + let val = match term.typ() { + Type::I64 => { + let val = term.as_i64().unwrap(); + if val < 0 { + if IS_LOWER_BOUND { + return Ok(TransformBound::NewBound(Bound::Unbounded)); + } else { + // Actual no hits case + return Ok(TransformBound::NewBound(Bound::Excluded(0))); + } } - }, - ) + (val as u64).to_u64() + } + Type::U64 => { + let val = term.as_u64().unwrap(); + val.to_u64() + } + Type::F64 => { + let val = term.as_f64().unwrap(); + return Ok(transform_from_f64::(val)); + } + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "Expected term with u64, i64 or f64, but got {:?}", + term.typ() + ), + )); + } + }; + Ok(TransformBound::Existing(val)) +} + +fn transform_to_f64(term: &ValueBytes>) -> io::Result> { + let val = match term.typ() { + Type::I64 => { + let val = term.as_i64().unwrap(); + (val as f64).to_u64() + } + Type::U64 => { + let val = term.as_u64().unwrap(); + (val as f64).to_u64() + } + Type::F64 => { + let val = term.as_f64().unwrap(); + val.to_u64() + } + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "Expected term with u64, i64 or f64, but got {:?}", + term.typ() + ), + )); + } + }; + Ok(TransformBound::Existing(val)) } fn search_on_u64_ff( @@ -970,6 +1012,22 @@ mod tests { 1 ); + // i64 and f64 on f64 field + assert_eq!( + count(RangeQuery::new( + Bound::Included(get_json_term(json_field, "id_f64", 10i64)), + Bound::Included(get_json_term(json_field, "id_f64", 12.1f64)), + )), + 1 + ); + assert_eq!( + count(RangeQuery::new( + Bound::Included(get_json_term(json_field, "id_f64", 10.1f64)), + Bound::Included(get_json_term(json_field, "id_f64", 1100i64)), + )), + 2 + ); + let reader = index.reader().unwrap(); let searcher = reader.searcher(); let query_parser = QueryParser::for_index(&index, vec![json_field]);