Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 132 additions & 2 deletions src/aggregation/agg_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::collections::HashSet;

use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
Expand Down Expand Up @@ -78,9 +79,26 @@ impl TryFrom<AggregationForDeserialization> for Aggregation {

fn try_from(value: AggregationForDeserialization) -> serde_json::Result<Self> {
let AggregationForDeserialization {
aggs_remaining_json,
sub_aggregation,
mut aggs_remaining_json,
mut sub_aggregation,
} = value;

// Extract nested "aggs" from inside the aggregation variant
// E.g., {"terms": {"field": "x", "aggs": {...}}} -> extract "aggs" from inside "terms"
if let Some(obj) = aggs_remaining_json.as_object_mut() {
if let Some(nested) =
obj.values_mut()
.filter_map(Value::as_object_mut)
.find_map(|config| {
config
.remove("aggs")
.and_then(|nested| serde_json::from_value::<Aggregations>(nested).ok())
})
{
sub_aggregation = nested;
}
}

let agg: AggregationVariants = serde_json::from_value(aggs_remaining_json)?;
Ok(Aggregation {
agg,
Expand Down Expand Up @@ -386,4 +404,116 @@ mod tests {
.collect()
)
}

#[test]
fn test_nested_aggs_deserialization() {
// Test that nested "aggs" fields are properly deserialized
let json = serde_json::json!({
"terms": {
"field": "category",
"aggs": {
"brand_breakdown": {
"terms": {
"field": "brand"
}
}
}
}
});

let agg: Aggregation = serde_json::from_value(json).unwrap();

// Verify the main aggregation was deserialized
assert!(matches!(agg.agg, AggregationVariants::Terms(_)));
if let AggregationVariants::Terms(terms) = &agg.agg {
assert_eq!(terms.field, "category");
}

// Verify nested aggregation was extracted
assert_eq!(agg.sub_aggregation.len(), 1);
assert!(agg.sub_aggregation.contains_key("brand_breakdown"));

// Verify the nested aggregation structure
let brand_agg = agg.sub_aggregation.get("brand_breakdown").unwrap();
assert!(matches!(brand_agg.agg, AggregationVariants::Terms(_)));
if let AggregationVariants::Terms(terms) = &brand_agg.agg {
assert_eq!(terms.field, "brand");
}
}

#[test]
fn test_triple_nested_aggs_deserialization() {
// Test triple-nested aggregations
let json = serde_json::json!({
"terms": {
"field": "category",
"aggs": {
"brand_breakdown": {
"terms": {
"field": "brand",
"aggs": {
"rating_breakdown": {
"terms": {
"field": "rating"
}
}
}
}
}
}
}
});

let agg: Aggregation = serde_json::from_value(json).unwrap();

// Verify first level
assert_eq!(agg.sub_aggregation.len(), 1);
let brand_agg = agg.sub_aggregation.get("brand_breakdown").unwrap();

// Verify second level
assert_eq!(brand_agg.sub_aggregation.len(), 1);
let rating_agg = brand_agg.sub_aggregation.get("rating_breakdown").unwrap();

// Verify third level
assert!(matches!(rating_agg.agg, AggregationVariants::Terms(_)));
if let AggregationVariants::Terms(terms) = &rating_agg.agg {
assert_eq!(terms.field, "rating");
}
}

#[test]
fn test_nested_aggs_with_metrics() {
// Test nested aggregations with metric sub-aggregations
let json = serde_json::json!({
"terms": {
"field": "category",
"aggs": {
"avg_price": {
"avg": {
"field": "price"
}
},
"max_rating": {
"max": {
"field": "rating"
}
}
}
}
});

let agg: Aggregation = serde_json::from_value(json).unwrap();

// Verify nested aggregations were extracted
assert_eq!(agg.sub_aggregation.len(), 2);
assert!(agg.sub_aggregation.contains_key("avg_price"));
assert!(agg.sub_aggregation.contains_key("max_rating"));

// Verify the metric types
let avg_agg = agg.sub_aggregation.get("avg_price").unwrap();
assert!(matches!(avg_agg.agg, AggregationVariants::Average(_)));

let max_agg = agg.sub_aggregation.get("max_rating").unwrap();
assert!(matches!(max_agg.agg, AggregationVariants::Max(_)));
}
}
Loading