diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs
index 2b8385ac2d89c..abf068db0cca8 100644
--- a/datafusion-cli/src/exec.rs
+++ b/datafusion-cli/src/exec.rs
@@ -269,7 +269,7 @@ impl StatementExecutor {
         let options = task_ctx.session_config().options();

         // Track memory usage for the query result if it's bounded
-        let mut reservation =
+        let reservation =
             MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool());

         if physical_plan.boundedness().is_unbounded() {
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index 712f4389f5852..b6c606ff467f9 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2186,7 +2186,7 @@ mod tests {

         // configure with same memory / disk manager
         let memory_pool = ctx1.runtime_env().memory_pool.clone();
-        let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
+        let reservation = MemoryConsumer::new("test").register(&memory_pool);
         reservation.grow(100);

         let disk_manager = ctx1.runtime_env().disk_manager.clone();
diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs
index 6635c9072dd97..d59b42ed15d15 100644
--- a/datafusion/datasource-parquet/src/file_format.rs
+++ b/datafusion/datasource-parquet/src/file_format.rs
@@ -1360,7 +1360,7 @@ impl FileSink for ParquetSink {
                 parquet_props.clone(),
             )
             .await?;
-            let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]"))
+            let reservation = MemoryConsumer::new(format!("ParquetSink[{path}]"))
                 .register(context.memory_pool());
             file_write_tasks.spawn(async move {
                 while let Some(batch) = rx.recv().await {
@@ -1465,7 +1465,7 @@ impl DataSink for ParquetSink {
 async fn column_serializer_task(
     mut rx: Receiver<ArrowLeafColumn>,
     mut writer: ArrowColumnWriter,
-    mut reservation: MemoryReservation,
+    reservation: MemoryReservation,
 ) -> Result<(ArrowColumnWriter, MemoryReservation)> {
     while let Some(col) = rx.recv().await {
         writer.write(&col)?;
@@ -1550,7 +1550,7 @@ fn spawn_rg_join_and_finalize_task(
     rg_rows: usize,
     pool: &Arc<dyn MemoryPool>,
 ) -> SpawnedTask<RBStreamSerializeResult> {
-    let mut rg_reservation =
+    let rg_reservation =
         MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool);

     SpawnedTask::spawn(async move {
@@ -1682,12 +1682,12 @@ async fn concatenate_parallel_row_groups(
     mut object_store_writer: Box<dyn AsyncWrite + Send + Unpin>,
     pool: Arc<dyn MemoryPool>,
 ) -> Result<FileMetaData> {
-    let mut file_reservation =
+    let file_reservation =
         MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool);

     while let Some(task) = serialize_rx.recv().await {
         let result = task.join_unwind().await;
-        let (serialized_columns, mut rg_reservation, _cnt) =
+        let (serialized_columns, rg_reservation, _cnt) =
             result.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;

         let mut rg_out = parquet_writer.next_row_group()?;
diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs
index fbf9ce41da8fe..b23eede2a054e 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -18,7 +18,7 @@
 //! [`MemoryPool`] for memory management during query execution, [`proxy`] for
 //! help with allocation accounting.
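+//!
+//! A minimal usage sketch of the reservation API after this change (hedged:
+//! the pool type and sizes are illustrative, borrowed from the tests below,
+//! not a prescribed setup):
+//!
+//! ```rust,ignore
+//! use std::sync::Arc;
+//! use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};
+//!
+//! let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));
+//! // `register` now returns a reservation that no longer needs to be `mut`:
+//! let reservation = MemoryConsumer::new("example").register(&pool);
+//! reservation.try_grow(512).unwrap();  // reserve 512 bytes from the pool
+//! assert_eq!(reservation.size(), 512);
+//! reservation.shrink(112);             // hand 112 bytes back
+//! assert_eq!(reservation.free(), 400); // release the remainder
+//! ```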
-use datafusion_common::{Result, internal_err};
+use datafusion_common::{Result, internal_datafusion_err};
 use std::hash::{Hash, Hasher};
 use std::{cmp::Ordering, sync::Arc, sync::atomic};
@@ -322,7 +322,7 @@ impl MemoryConsumer {
                 pool: Arc::clone(pool),
                 consumer: self,
             }),
-            size: 0,
+            size: atomic::AtomicUsize::new(0),
         }
     }
 }
@@ -351,13 +351,13 @@ impl Drop for SharedRegistration {
 #[derive(Debug)]
 pub struct MemoryReservation {
     registration: Arc<SharedRegistration>,
-    size: usize,
+    size: atomic::AtomicUsize,
 }

 impl MemoryReservation {
     /// Returns the size of this reservation in bytes
     pub fn size(&self) -> usize {
-        self.size
+        self.size.load(atomic::Ordering::Relaxed)
     }

     /// Returns [MemoryConsumer] for this [MemoryReservation]
@@ -367,8 +367,8 @@ impl MemoryReservation {
     /// Frees all bytes from this reservation back to the underlying
     /// pool, returning the number of bytes freed.
-    pub fn free(&mut self) -> usize {
-        let size = self.size;
+    pub fn free(&self) -> usize {
+        let size = self.size.load(atomic::Ordering::Relaxed);
         if size != 0 {
             self.shrink(size)
         }
@@ -380,60 +380,62 @@ impl MemoryReservation {
     /// # Panics
     ///
     /// Panics if `capacity` exceeds [`Self::size`]
-    pub fn shrink(&mut self, capacity: usize) {
-        let new_size = self.size.checked_sub(capacity).unwrap();
+    pub fn shrink(&self, capacity: usize) {
+        // Keep the documented panic on underflow: a wrapping `fetch_sub` here
+        // would silently corrupt both this counter and the pool accounting.
+        self.size
+            .fetch_update(
+                atomic::Ordering::Relaxed,
+                atomic::Ordering::Relaxed,
+                |prev| prev.checked_sub(capacity),
+            )
+            .expect("shrink capacity exceeds reservation size");
         self.registration.pool.shrink(self, capacity);
-        self.size = new_size
     }

     /// Tries to free `capacity` bytes from this reservation
     /// if `capacity` does not exceed [`Self::size`]
     /// Returns new reservation size
     /// or error if shrinking capacity is more than allocated size
-    pub fn try_shrink(&mut self, capacity: usize) -> Result<usize> {
-        if let Some(new_size) = self.size.checked_sub(capacity) {
-            self.registration.pool.shrink(self, capacity);
-            self.size = new_size;
-            Ok(new_size)
-        } else {
-            internal_err!(
-                "Cannot free the capacity {capacity} out of allocated size {}",
-                self.size
-            )
-        }
+    pub fn try_shrink(&self, capacity: usize) -> Result<usize> {
+        let updated = self.size.fetch_update(
+            atomic::Ordering::Relaxed,
+            atomic::Ordering::Relaxed,
+            |prev| prev.checked_sub(capacity),
+        );
+        match updated {
+            Ok(prev) => {
+                // The local counter was decremented; give the bytes back to the pool too.
+                self.registration.pool.shrink(self, capacity);
+                // `fetch_update` returns the previous value; report the new size.
+                Ok(prev - capacity)
+            }
+            Err(prev) => Err(internal_datafusion_err!(
+                "Cannot free the capacity {capacity} out of allocated size {prev}"
+            )),
+        }
     }

     /// Sets the size of this reservation to `capacity`
-    pub fn resize(&mut self, capacity: usize) {
-        match capacity.cmp(&self.size) {
-            Ordering::Greater => self.grow(capacity - self.size),
-            Ordering::Less => self.shrink(self.size - capacity),
+    pub fn resize(&self, capacity: usize) {
+        let size = self.size.load(atomic::Ordering::Relaxed);
+        match capacity.cmp(&size) {
+            Ordering::Greater => self.grow(capacity - size),
+            Ordering::Less => self.shrink(size - capacity),
             _ => {}
         }
     }

     /// Try to set the size of this reservation to `capacity`
-    pub fn try_resize(&mut self, capacity: usize) -> Result<()> {
-        match capacity.cmp(&self.size) {
-            Ordering::Greater => self.try_grow(capacity - self.size)?,
-            Ordering::Less => self.shrink(self.size - capacity),
+    pub fn try_resize(&self, capacity: usize) -> Result<()> {
+        let size = self.size.load(atomic::Ordering::Relaxed);
+        match capacity.cmp(&size) {
+            Ordering::Greater => self.try_grow(capacity - size)?,
+            Ordering::Less => self.shrink(size - capacity),
             _ => {}
         };
         Ok(())
     }

     /// Increase the size of this reservation by `capacity` bytes
-    pub fn grow(&mut self, capacity: usize) {
+    pub fn grow(&self, capacity: usize) {
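+        // Note on ordering: the pool is charged first and the local counter is
+        // bumped after, so `size()` can briefly lag `pool.reserved()` under
+        // concurrent use. `Relaxed` suffices because this counter is a
+        // per-reservation total and implies no cross-thread synchronization.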
self.registration.pool.grow(self, capacity); - self.size += capacity; + self.size.fetch_add(capacity, atomic::Ordering::Relaxed); } /// Try to increase the size of this reservation by `capacity` /// bytes, returning error if there is insufficient capacity left /// in the pool. - pub fn try_grow(&mut self, capacity: usize) -> Result<()> { + pub fn try_grow(&self, capacity: usize) -> Result<()> { self.registration.pool.try_grow(self, capacity)?; - self.size += capacity; + self.size.fetch_add(capacity, atomic::Ordering::Relaxed); Ok(()) } @@ -447,10 +449,16 @@ impl MemoryReservation { /// # Panics /// /// Panics if `capacity` exceeds [`Self::size`] - pub fn split(&mut self, capacity: usize) -> MemoryReservation { - self.size = self.size.checked_sub(capacity).unwrap(); + pub fn split(&self, capacity: usize) -> MemoryReservation { + self.size + .fetch_update( + atomic::Ordering::Relaxed, + atomic::Ordering::Relaxed, + |prev| prev.checked_sub(capacity), + ) + .unwrap(); Self { - size: capacity, + size: atomic::AtomicUsize::new(capacity), registration: Arc::clone(&self.registration), } } @@ -458,7 +466,7 @@ impl MemoryReservation { /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] pub fn new_empty(&self) -> Self { Self { - size: 0, + size: atomic::AtomicUsize::new(0), registration: Arc::clone(&self.registration), } } @@ -466,7 +474,7 @@ impl MemoryReservation { /// Splits off all the bytes from this [`MemoryReservation`] into /// a new [`MemoryReservation`] with the same [`MemoryConsumer`] pub fn take(&mut self) -> MemoryReservation { - self.split(self.size) + self.split(self.size.load(atomic::Ordering::Relaxed)) } } @@ -492,7 +500,7 @@ mod tests { #[test] fn test_memory_pool_underflow() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut a1 = MemoryConsumer::new("a1").register(&pool); + let a1 = MemoryConsumer::new("a1").register(&pool); assert_eq!(pool.reserved(), 0); a1.grow(100); @@ -507,7 +515,7 @@ mod tests { a1.try_grow(30).unwrap(); assert_eq!(pool.reserved(), 30); - let mut a2 = MemoryConsumer::new("a2").register(&pool); + let a2 = MemoryConsumer::new("a2").register(&pool); a2.try_grow(25).unwrap_err(); assert_eq!(pool.reserved(), 30); @@ -521,7 +529,7 @@ mod tests { #[test] fn test_split() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); assert_eq!(r1.size(), 20); @@ -542,10 +550,10 @@ mod tests { #[test] fn test_new_empty() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); - let mut r2 = r1.new_empty(); + let r2 = r1.new_empty(); r2.try_grow(5).unwrap(); assert_eq!(r1.size(), 20); @@ -559,7 +567,7 @@ mod tests { let mut r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); - let mut r2 = r1.take(); + let r2 = r1.take(); r2.try_grow(5).unwrap(); assert_eq!(r1.size(), 0); diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index bf74b5f6f4c6b..b10270851cc06 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -212,7 +212,7 @@ impl MemoryPool for FairSpillPool { .checked_div(state.num_spill) .unwrap_or(spill_available); - if reservation.size + additional > available { + if reservation.size() + additional > available { return 
Err(insufficient_capacity_err( reservation, additional, @@ -264,7 +264,7 @@ fn insufficient_capacity_err( "Failed to allocate additional {} for {} with {} already allocated for this reservation - {} remain available for the total pool", human_readable_size(additional), reservation.registration.consumer.name, - human_readable_size(reservation.size), + human_readable_size(reservation.size()), human_readable_size(available) ) } @@ -526,12 +526,12 @@ mod tests { fn test_fair() { let pool = Arc::new(FairSpillPool::new(100)) as _; - let mut r1 = MemoryConsumer::new("unspillable").register(&pool); + let r1 = MemoryConsumer::new("unspillable").register(&pool); // Can grow beyond capacity of pool r1.grow(2000); assert_eq!(pool.reserved(), 2000); - let mut r2 = MemoryConsumer::new("r2") + let r2 = MemoryConsumer::new("r2") .with_can_spill(true) .register(&pool); // Can grow beyond capacity of pool @@ -563,7 +563,7 @@ mod tests { assert_eq!(r2.size(), 10); assert_eq!(pool.reserved(), 30); - let mut r3 = MemoryConsumer::new("r3") + let r3 = MemoryConsumer::new("r3") .with_can_spill(true) .register(&pool); @@ -584,7 +584,7 @@ mod tests { r1.free(); assert_eq!(pool.reserved(), 80); - let mut r4 = MemoryConsumer::new("s4").register(&pool); + let r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 30.0 B for s4 with 0.0 B already allocated for this reservation - 20.0 B remain available for the total pool"); } @@ -601,18 +601,18 @@ mod tests { // Test: use all the different interfaces to change reservation size // set r1=50, using grow and shrink - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.grow(50); r1.grow(20); r1.shrink(20); // set r2=15 using try_grow - let mut r2 = MemoryConsumer::new("r2").register(&pool); + let r2 = MemoryConsumer::new("r2").register(&pool); r2.try_grow(15) .expect("should succeed in memory allotment for r2"); // set r3=20 using try_resize - let mut r3 = MemoryConsumer::new("r3").register(&pool); + let r3 = MemoryConsumer::new("r3").register(&pool); r3.try_resize(25) .expect("should succeed in memory allotment for r3"); r3.try_resize(20) @@ -620,12 +620,12 @@ mod tests { // set r4=10 // this should not be reported in top 3 - let mut r4 = MemoryConsumer::new("r4").register(&pool); + let r4 = MemoryConsumer::new("r4").register(&pool); r4.grow(10); // Test: reports if new reservation causes error // using the previously set sizes for other consumers - let mut r5 = MemoryConsumer::new("r5").register(&pool); + let r5 = MemoryConsumer::new("r5").register(&pool); let res = r5.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -650,7 +650,7 @@ mod tests { let same_name = "foo"; // Test: see error message when no consumers recorded yet - let mut r0 = MemoryConsumer::new(same_name).register(&pool); + let r0 = MemoryConsumer::new(same_name).register(&pool); let res = r0.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -665,7 +665,7 @@ mod tests { r0.grow(10); // make r0=10, pool available=90 let new_consumer_same_name = MemoryConsumer::new(same_name); - let mut r1 = new_consumer_same_name.register(&pool); + let r1 = new_consumer_same_name.register(&pool); // TODO: the insufficient_capacity_err() message is per reservation, not per consumer. 
// a followup PR will clarify this message "0 bytes already allocated for this reservation"
        let res = r1.try_grow(150);
@@ -695,7 +695,7 @@ mod tests {
         // will be recognized as different in the TrackConsumersPool
         let consumer_with_same_name_but_different_hash =
             MemoryConsumer::new(same_name).with_can_spill(true);
-        let mut r2 = consumer_with_same_name_but_different_hash.register(&pool);
+        let r2 = consumer_with_same_name_but_different_hash.register(&pool);
         let res = r2.try_grow(150);
         assert!(res.is_err());
         let error = res.unwrap_err().strip_backtrace();
@@ -714,10 +714,10 @@ mod tests {
         // Baseline: see the 2 memory consumers
         let setting = make_settings();
         let _bound = setting.bind_to_scope();
-        let mut r0 = MemoryConsumer::new("r0").register(&pool);
+        let r0 = MemoryConsumer::new("r0").register(&pool);
         r0.grow(10);
         let r1_consumer = MemoryConsumer::new("r1");
-        let mut r1 = r1_consumer.register(&pool);
+        let r1 = r1_consumer.register(&pool);
         r1.grow(20);

         let res = r0.try_grow(150);
@@ -791,13 +791,13 @@ mod tests {
             .downcast::<TrackConsumersPool<GreedyMemoryPool>>()
             .unwrap();
         // set r1=20
-        let mut r1 = MemoryConsumer::new("r1").register(&pool);
+        let r1 = MemoryConsumer::new("r1").register(&pool);
         r1.grow(20);
         // set r2=15
-        let mut r2 = MemoryConsumer::new("r2").register(&pool);
+        let r2 = MemoryConsumer::new("r2").register(&pool);
         r2.grow(15);
         // set r3=45
-        let mut r3 = MemoryConsumer::new("r3").register(&pool);
+        let r3 = MemoryConsumer::new("r3").register(&pool);
         r3.grow(45);

         let downcasted = upcasted
diff --git a/datafusion/physical-plan/src/buffer.rs b/datafusion/physical-plan/src/buffer.rs
new file mode 100644
index 0000000000000..01b3b76354852
--- /dev/null
+++ b/datafusion/physical-plan/src/buffer.rs
@@ -0,0 +1,675 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`BufferExec`] decouples production and consumption of messages by buffering the input in the
+//! background up to a certain capacity.
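+//!
+//! A hedged construction sketch (the child plan, task context, and capacity are
+//! placeholders; only the [`BufferExec::new`] API defined in this file is assumed):
+//!
+//! ```rust,ignore
+//! use std::sync::Arc;
+//! // let child: Arc<dyn ExecutionPlan> = ...; // any upstream operator
+//! // Buffer up to 8 MiB of record batches per partition in the background:
+//! let buffered = BufferExec::new(Arc::clone(&child), 8 * 1024 * 1024);
+//! // Executing a partition spawns a task that eagerly polls `child` into a queue:
+//! // let stream = buffered.execute(0, task_ctx)?;
+//! ```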
+
+use crate::execution_plan::{CardinalityEffect, SchedulingType};
+use crate::filter_pushdown::{
+    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
+    FilterPushdownPropagation,
+};
+use crate::projection::ProjectionExec;
+use crate::stream::RecordBatchStreamAdapter;
+use crate::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult,
+};
+use arrow::array::RecordBatch;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{Result, Statistics, internal_err, plan_err};
+use datafusion_common_runtime::SpawnedTask;
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_physical_expr_common::metrics::{
+    ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
+};
+use datafusion_physical_expr_common::physical_expr::{
+    PhysicalExpr, is_dynamic_physical_expr,
+};
+use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
+use futures::{Stream, StreamExt, TryStreamExt};
+use pin_project_lite::pin_project;
+use std::any::Any;
+use std::fmt;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::task::{Context, Poll};
+use tokio::sync::mpsc::UnboundedReceiver;
+use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};
+
+/// Decouples production and consumption of record batches with an internal queue per partition,
+/// eagerly filling up the capacity of the queues even before any message is requested.
+///
+/// ```text
+/// ┌───────────────────────────┐
+/// │        BufferExec         │
+/// │                           │
+/// │┌────── Partition 0 ──────┐│
+/// ││  ┌────┐        ┌────┐   ││      ┌────┐
+/// ──background poll──▶│    │    │    ├┼┼──────▶│    │
+/// ││  └────┘        └────┘   ││      └────┘
+/// │└─────────────────────────┘│
+/// │┌────── Partition 1 ──────┐│
+/// ││  ┌────┐  ┌────┐  ┌────┐ ││      ┌────┐
+/// ──background poll─▶│    │    │    │  │    ├┼┼──▶│    │
+/// ││  └────┘  └────┘  └────┘ ││      └────┘
+/// │└─────────────────────────┘│
+/// │                           │
+/// │            ...            │
+/// │                           │
+/// │┌────── Partition N ──────┐│
+/// ││                 ┌────┐  ││      ┌────┐
+/// ──background poll──────────▶│    ├┼┼──────▶│    │
+/// ││                 └────┘  ││      └────┘
+/// │└─────────────────────────┘│
+/// └───────────────────────────┘
+/// ```
+///
+/// The capacity is provided in bytes, and for each buffered record batch it will take into account
+/// the size reported by [RecordBatch::get_array_memory_size].
+///
+/// This is useful for operators that conditionally start polling one of their children only after
+/// another child has finished: it allows performing some early work and accumulating batches in
+/// memory so that they can be served immediately when requested.
+#[derive(Debug, Clone)]
+pub struct BufferExec {
+    input: Arc<dyn ExecutionPlan>,
+    properties: PlanProperties,
+    capacity: usize,
+    wait_first_poll: bool,
+    metrics: ExecutionPlanMetricsSet,
+}
+
+impl BufferExec {
+    /// Builds a new [BufferExec] with the provided capacity in bytes.
+    pub fn new(input: Arc<dyn ExecutionPlan>, capacity: usize) -> Self {
+        let properties = input
+            .properties()
+            .clone()
+            .with_scheduling_type(SchedulingType::Cooperative);
+
+        Self {
+            input,
+            properties,
+            capacity,
+            wait_first_poll: false,
+            metrics: ExecutionPlanMetricsSet::new(),
+        }
+    }
+
+    /// Returns the input [ExecutionPlan] of this [BufferExec].
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.input
+    }
+
+    /// Returns the per-partition capacity in bytes for this [BufferExec].
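+    ///
+    /// Accounting sketch (illustrative numbers, not an API guarantee): each
+    /// queued batch consumes [RecordBatch::get_array_memory_size] bytes of this
+    /// budget, and a single batch larger than the whole budget still passes
+    /// through, because the background task caps its semaphore request at
+    /// `capacity` instead of deadlocking on an unsatisfiable acquire.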
pub fn capacity(&self) -> usize {
+        self.capacity
+    }
+}
+
+impl DisplayAs for BufferExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "BufferExec: capacity={}", self.capacity)
+            }
+            DisplayFormatType::TreeRender => {
+                writeln!(f, "capacity={}", self.capacity)
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for BufferExec {
+    fn name(&self) -> &str {
+        "BufferExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        vec![true]
+    }
+
+    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
+        vec![false]
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        mut children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if children.len() != 1 {
+            return plan_err!("BufferExec can only have one child");
+        }
+        Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity)))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]"))
+            .register(context.memory_pool());
+        let in_stream = self.input.execute(partition, context)?;
+
+        // Set up the metrics for the stream.
+        let curr_mem_in = Arc::new(AtomicUsize::new(0));
+        let curr_mem_out = Arc::clone(&curr_mem_in);
+        let mut max_mem_in = 0;
+        let max_mem = MetricBuilder::new(&self.metrics).gauge("max_mem_used", partition);
+
+        let curr_queued_in = Arc::new(AtomicUsize::new(0));
+        let curr_queued_out = Arc::clone(&curr_queued_in);
+        let mut max_queued_in = 0;
+        let max_queued = MetricBuilder::new(&self.metrics).gauge("max_queued", partition);
+
+        // Capture metrics when an element is queued on the stream.
+        let in_stream = in_stream.inspect_ok(move |v| {
+            let size = v.get_array_memory_size();
+            let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size;
+            if curr_size > max_mem_in {
+                max_mem_in = curr_size;
+                max_mem.set(max_mem_in);
+            }
+
+            let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1;
+            if curr_queued > max_queued_in {
+                max_queued_in = curr_queued;
+                max_queued.set(max_queued_in);
+            }
+        });
+        // Buffer the input.
+        let out_stream = MemoryBufferedStream::new(
+            in_stream,
+            self.capacity,
+            mem_reservation,
+            self.wait_first_poll,
+        );
+        // Record in the metrics that when an element gets out, some memory gets freed.
+        let out_stream = out_stream.inspect_ok(move |v| {
+            curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed);
+            curr_queued_out.fetch_sub(1, Ordering::Relaxed);
+        });
+
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            self.schema(),
+            out_stream,
+        )))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+        self.input.partition_statistics(partition)
+    }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        self.input.supports_limit_pushdown()
+    }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
+
+    fn try_swapping_with_projection(
+        &self,
+        projection: &ProjectionExec,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        match self.input.try_swapping_with_projection(projection)? {
+            Some(new_input) => Ok(Some(
+                Arc::new(self.clone()).with_new_children(vec![new_input])?,
+            )),
+            None => Ok(None),
+        }
+    }
+
+    fn gather_filters_for_pushdown(
+        &self,
+        _phase: FilterPushdownPhase,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+        _config: &ConfigOptions,
+    ) -> Result<FilterDescription> {
+        FilterDescription::from_children(parent_filters, &self.children())
+    }
+
+    fn handle_child_pushdown_result(
+        &self,
+        _phase: FilterPushdownPhase,
+        child_pushdown_result: ChildPushdownResult,
+        _config: &ConfigOptions,
+    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
+        // If a dynamic filter is being pushed down through this node, we don't want to
+        // buffer eagerly: we prefer to give the dynamic filter a chance to be populated
+        // rather than polling data while the filter is still empty.
+        let has_dynamic_filter = child_pushdown_result
+            .parent_filters
+            .iter()
+            .any(|v| is_dynamic_physical_expr(&v.filter));
+        if has_dynamic_filter {
+            let mut new_self = self.clone();
+            new_self.wait_first_poll = true;
+            let mut result = FilterPushdownPropagation::if_all(child_pushdown_result);
+            result.updated_node = Some(Arc::new(new_self) as _);
+            Ok(result)
+        } else {
+            Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
+        }
+    }
+
+    fn try_pushdown_sort(
+        &self,
+        order: &[PhysicalSortExpr],
+    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
+        // BufferExec is transparent for sort ordering - it preserves its input order.
+        // Delegate to the child and wrap the result in a new BufferExec.
+        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
+            Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc<dyn ExecutionPlan>)
+        })
+    }
+}
+
+/// Represents anything that occupies a capacity in a [MemoryBufferedStream].
+pub trait SizedMessage {
+    fn size(&self) -> usize;
+}
+
+impl SizedMessage for RecordBatch {
+    fn size(&self) -> usize {
+        self.get_array_memory_size()
+    }
+}
+
+pin_project! {
+/// Decouples production and consumption of messages in a stream with an internal queue, eagerly
+/// filling it up to the specified maximum capacity even before any message is requested.
+///
+/// Allows each message to have a different size, which is taken into account for determining if
+/// the queue is full or not.
+pub struct MemoryBufferedStream<T: SizedMessage> {
+    task: SpawnedTask<()>,
+    batch_rx: UnboundedReceiver<Result<(T, OwnedSemaphorePermit)>>,
+    memory_reservation: Arc<MemoryReservation>,
+    first_poll_notify: Option<Arc<Notify>>,
+}}
+
+impl<T: SizedMessage + Send + 'static> MemoryBufferedStream<T> {
+    /// Builds a new [MemoryBufferedStream] with the provided capacity and event handler.
+    ///
+    /// This immediately spawns a Tokio task that will start consumption of the input stream.
+    pub fn new(
+        mut input: impl Stream<Item = Result<T>> + Unpin + Send + 'static,
+        capacity: usize,
+        memory_reservation: MemoryReservation,
+        wait_first_poll: bool,
+    ) -> Self {
+        let semaphore = Arc::new(Semaphore::new(capacity));
+        let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel();
+
+        let mut first_poll_notify = None;
+        if wait_first_poll {
+            first_poll_notify = Some(Arc::new(Notify::new()));
+        }
+        let mut first_poll_notify_clone = first_poll_notify.clone();
+
+        let memory_reservation = Arc::new(memory_reservation);
+        let memory_reservation_clone = Arc::clone(&memory_reservation);
+        let task = SpawnedTask::spawn(async move {
+            if let Some(first_poll_notify) = first_poll_notify_clone.take() {
+                first_poll_notify.notified_owned().await;
+            }
+            while let Some(item_or_err) = input.next().await {
+                let item = match item_or_err {
+                    Ok(batch) => batch,
+                    Err(err) => {
+                        let _ = batch_tx.send(Err(err)); // If the send fails it means the channel was closed, which is fine.
+                        break;
+                    }
+                };
+
+                let size = item.size();
+                if let Err(err) = memory_reservation.try_grow(size) {
+                    let _ = batch_tx.send(Err(err)); // If the send fails it means the channel was closed, which is fine.
+                    break;
+                }
+
+                // We need to cap the number of permits at the minimum between the semaphore
+                // capacity and the actual size of the message. If at any point we tried to
+                // acquire more permits than the capacity of the semaphore, the stream would
+                // deadlock.
+                let capped_size = size.min(capacity) as u32;
+
+                let semaphore = Arc::clone(&semaphore);
+                let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else {
+                    let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. This is a bug in DataFusion, please report it!"));
+                    break;
+                };
+
+                if batch_tx.send(Ok((item, permit))).is_err() {
+                    break; // stream was closed
+                };
+            }
+        });
+
+        Self {
+            task,
+            batch_rx,
+            memory_reservation: memory_reservation_clone,
+            first_poll_notify,
+        }
+    }
+
+    /// Returns the number of queued messages.
+    pub fn messages_queued(&self) -> usize {
+        self.batch_rx.len()
+    }
+}
+
+impl<T: SizedMessage> Stream for MemoryBufferedStream<T> {
+    type Item = Result<T>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let self_project = self.project();
+        if let Some(first_poll_notify) = self_project.first_poll_notify.take() {
+            first_poll_notify.notify_one();
+        }
+        match self_project.batch_rx.poll_recv(cx) {
+            Poll::Ready(Some(Ok((item, _semaphore_permit)))) => {
+                self_project.memory_reservation.shrink(item.size());
+                Poll::Ready(Some(Ok(item)))
+            }
+            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        if self.batch_rx.is_closed() {
+            let len = self.batch_rx.len();
+            (len, Some(len))
+        } else {
+            (self.batch_rx.len(), None)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_common::{DataFusionError, assert_contains};
+    use datafusion_execution::memory_pool::{
+        GreedyMemoryPool, MemoryPool, UnboundedMemoryPool,
+    };
+    use std::error::Error;
+    use std::fmt::Debug;
+    use std::sync::Arc;
+    use std::time::Duration;
+    use tokio::time::timeout;
+
+    #[tokio::test]
+    async fn buffers_only_some_messages() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
+        let (_, res) = memory_pool_and_reservation();
+
+        let buffered = MemoryBufferedStream::new(input, 4, res, false);
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 2);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn respects_wait_for_first_poll() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
+        let (_, res) = memory_pool_and_reservation();
+
+        let buffered = MemoryBufferedStream::new(input, 4, res, true);
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 0);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn yields_all_messages() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
+        let (_, res) = memory_pool_and_reservation();
+
+        let mut buffered = MemoryBufferedStream::new(input, 10, res, false);
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 4);
+
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        finished(&mut buffered).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn yields_all_messages_after_first_poll() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
+        let (_, res) = memory_pool_and_reservation();
+
+        let mut buffered = MemoryBufferedStream::new(input, 10, res, true);
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 0);
+
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        finished(&mut buffered).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn yields_first_msg_even_if_big() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([25, 1, 2, 3]).map(Ok);
+        let (_, res) = memory_pool_and_reservation();
+
+        let mut buffered = MemoryBufferedStream::new(input, 10, res, false);
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 1);
+        pull_ok_msg(&mut buffered).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn memory_pool_kills_stream() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
+        let (_, res) = bounded_memory_pool_and_reservation(7);
+
+        let mut buffered = MemoryBufferedStream::new(input, 10, res, false);
+        wait_for_buffering().await;
+
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        let msg = pull_err_msg(&mut buffered).await?;
+
+        assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B");
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn memory_pool_does_not_kill_stream() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
+        let (_, res) = bounded_memory_pool_and_reservation(7);
+
+        let mut buffered = MemoryBufferedStream::new(input, 3, res, false);
+        wait_for_buffering().await;
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        finished(&mut buffered).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([3, 3, 3, 3]).map(Ok);
+        let (_, res) = memory_pool_and_reservation();
+
+        let mut buffered = MemoryBufferedStream::new(input, 2, res, false);
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 1);
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 1);
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 1);
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 1);
+        pull_ok_msg(&mut buffered).await?;
+
+        wait_for_buffering().await;
+        finished(&mut buffered).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn errors_get_propagated() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(|v| {
+            if v == 3 {
+                return internal_err!("Error on 3");
+            }
+            Ok(v)
+        });
+        let (_, res) = memory_pool_and_reservation();
+
+        let mut buffered = MemoryBufferedStream::new(input, 10, res, false);
+        wait_for_buffering().await;
+
+        pull_ok_msg(&mut buffered).await?;
+        pull_ok_msg(&mut buffered).await?;
+        pull_err_msg(&mut buffered).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn memory_gets_released_if_stream_drops() -> Result<(), Box<dyn Error>> {
+        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
+        let (pool, res) = memory_pool_and_reservation();
+
+        let mut buffered = MemoryBufferedStream::new(input, 10, res, false);
+        wait_for_buffering().await;
+        assert_eq!(buffered.messages_queued(), 4);
+        assert_eq!(pool.reserved(), 10);
+
+        pull_ok_msg(&mut buffered).await?;
+        assert_eq!(buffered.messages_queued(), 3);
+        assert_eq!(pool.reserved(), 9);
+
+        pull_ok_msg(&mut buffered).await?;
+        assert_eq!(buffered.messages_queued(), 2);
+        assert_eq!(pool.reserved(), 7);
+
+        drop(buffered);
+        assert_eq!(pool.reserved(), 0);
+        Ok(())
+    }
+
+    fn memory_pool_and_reservation() -> (Arc<dyn MemoryPool>, MemoryReservation) {
+        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
+        let reservation = MemoryConsumer::new("test").register(&pool);
+        (pool, reservation)
+    }
+
+    fn bounded_memory_pool_and_reservation(
+        size: usize,
+    ) -> (Arc<dyn MemoryPool>, MemoryReservation) {
+        let pool = Arc::new(GreedyMemoryPool::new(size)) as _;
+        let reservation = MemoryConsumer::new("test").register(&pool);
+        (pool, reservation)
+    }
+
+    async fn wait_for_buffering() {
+        // We do not have control over the spawned task, so the best we can do is to yield some
+        // cycles to the tokio runtime and let the task make progress on its own.
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    async fn pull_ok_msg(
+        buffered: &mut MemoryBufferedStream<usize>,
+    ) -> Result<usize, Box<dyn Error>> {
+        Ok(timeout(Duration::from_millis(1), buffered.next())
+            .await?
+            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
+    }
+
+    async fn pull_err_msg(
+        buffered: &mut MemoryBufferedStream<usize>,
+    ) -> Result<DataFusionError, Box<dyn Error>> {
+        Ok(timeout(Duration::from_millis(1), buffered.next())
+            .await?
+            .map(|v| match v {
+                Ok(v) => internal_err!(
+                    "Stream should not have failed, but succeeded with {v:?}"
+                ),
+                Err(err) => Ok(err),
+            })
+            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
+    }
+
+    async fn finished(
+        buffered: &mut MemoryBufferedStream<usize>,
+    ) -> Result<(), Box<dyn Error>> {
+        match timeout(Duration::from_millis(1), buffered.next())
+            .await?
+ .is_none() + { + true => Ok(()), + false => internal_err!("Stream should have finished")?, + } + } + + impl SizedMessage for usize { + fn size(&self) -> usize { + *self + } + } +} diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 4f32b6176ec39..7ada14be66543 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -206,7 +206,7 @@ async fn load_left_input( let (batches, _metrics, reservation) = stream .try_fold( (Vec::new(), metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 44637321a7e35..b57f9132253bf 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -682,10 +682,10 @@ async fn collect_left_input( let schema = stream.schema(); // Load all batches and count the rows - let (batches, metrics, mut reservation) = stream + let (batches, metrics, reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 508be2e3984f4..d7ece845e943c 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -620,7 +620,7 @@ async fn build_buffered_data( // Combine batches and record number of rows let initial = (Vec::new(), 0, metrics, reservation); - let (batches, num_rows, metrics, mut reservation) = buffered + let (batches, num_rows, metrics, reservation) = buffered .try_fold(initial, |mut acc, batch| async { let batch_size = get_record_batch_memory_size(&batch); acc.3.try_grow(batch_size)?; diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index ec8e154caec91..ba870ba65cb92 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -65,6 +65,7 @@ mod visitor; pub mod aggregates; pub mod analyze; pub mod async_func; +pub mod buffer; pub mod coalesce; pub mod coalesce_batches; pub mod coalesce_partitions; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 3e8fdf1f3ed7e..a8361f7b2941e 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -709,7 +709,7 @@ impl ExternalSorter { &self, batch: RecordBatch, metrics: &BaselineMetrics, - mut reservation: MemoryReservation, + reservation: MemoryReservation, ) -> Result { assert_eq!( get_reserved_bytes_for_record_batch(&batch)?, @@ -736,7 +736,7 @@ impl ExternalSorter { .then({ move |batches| async move { match batches { - Ok((schema, sorted_batches, mut reservation)) => { + Ok((schema, sorted_batches, reservation)) => { // Calculate the total size of sorted batches let total_sorted_size: usize = sorted_batches .iter() diff --git 
a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index a510f44e4f4df..779511a865b6a 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -180,7 +180,7 @@ impl RowCursorStream { self.rows.save(stream_idx, &rows); // track the memory in the newly created Rows. - let mut rows_reservation = self.reservation.new_empty(); + let rows_reservation = self.reservation.new_empty(); rows_reservation.try_grow(rows.size())?; Ok(RowValues::new(rows, rows_reservation)) } @@ -246,7 +246,7 @@ impl FieldCursorStream { let array = value.into_array(batch.num_rows())?; let size_in_mem = array.get_buffer_memory_size(); let array = array.as_any().downcast_ref::().expect("field values"); - let mut array_reservation = self.reservation.new_empty(); + let array_reservation = self.reservation.new_empty(); array_reservation.try_grow(size_in_mem)?; Ok(ArrayValues::new( self.sort.options, diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 80c2233d05db6..4b7e707fccedd 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -1005,7 +1005,7 @@ mod test { .build_arc() .unwrap(); - let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -1071,7 +1071,7 @@ mod test { .build_arc() .unwrap(); - let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index f1b9e3e88d123..1313909adbba2 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -283,7 +283,7 @@ mod tests { assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; - let mut reservation = MemoryConsumer::new("test_work_table").register(&pool); + let reservation = MemoryConsumer::new("test_work_table").register(&pool); // Update batch to work_table let array: ArrayRef = Arc::new((0..5).collect::()); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 5f590560c4675..50332e23d7ee4 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -749,6 +749,7 @@ message PhysicalPlanNode { SortMergeJoinExecNode sort_merge_join = 34; MemoryScanExecNode memory_scan = 35; AsyncFuncExecNode async_func = 36; + BufferExecNode buffer = 37; } } @@ -1413,3 +1414,8 @@ message AsyncFuncExecNode { repeated PhysicalExprNode async_exprs = 2; repeated string async_expr_names = 3; } + +message BufferExecNode { + PhysicalPlanNode input = 1; + uint64 capacity = 2; +} \ No newline at end of file diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f6d364f269b48..a6e411744ce0f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1838,6 +1838,118 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { deserializer.deserialize_struct("datafusion.BinaryExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for BufferExecNode { + #[allow(deprecated)] + 
fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.capacity != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.BufferExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.capacity != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("capacity", ToString::to_string(&self.capacity).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for BufferExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "capacity", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Capacity, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "capacity" => Ok(GeneratedField::Capacity), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = BufferExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.BufferExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut capacity__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Capacity => { + if capacity__.is_some() { + return Err(serde::de::Error::duplicate_field("capacity")); + } + capacity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(BufferExecNode { + input: input__, + capacity: capacity__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.BufferExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CaseNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17509,6 +17621,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::AsyncFunc(v) => { struct_ser.serialize_field("asyncFunc", v)?; } + physical_plan_node::PhysicalPlanType::Buffer(v) => { + struct_ser.serialize_field("buffer", v)?; + } } } struct_ser.end() @@ -17576,6 +17691,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "memoryScan", "async_func", "asyncFunc", + "buffer", ]; #[allow(clippy::enum_variant_names)] @@ -17615,6 +17731,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SortMergeJoin, MemoryScan, AsyncFunc, + Buffer, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17671,6 +17788,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), + "buffer" => Ok(GeneratedField::Buffer), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17936,6 +18054,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("asyncFunc")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AsyncFunc) +; + } + GeneratedField::Buffer => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("buffer")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Buffer) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c1afd73ec3c52..bdd1e5eeb989f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1076,7 +1076,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37" )] pub physical_plan_type: ::core::option::Option, } @@ -1156,6 +1156,8 @@ pub mod physical_plan_node { MemoryScan(super::MemoryScanExecNode), #[prost(message, tag = "36")] AsyncFunc(::prost::alloc::boxed::Box), + #[prost(message, tag = "37")] + Buffer(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -2136,6 +2138,13 @@ pub struct AsyncFuncExecNode { #[prost(string, repeated, tag = "3")] pub async_expr_names: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct 
BufferExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(uint64, tag = "2")] + pub capacity: u64, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 45868df4ced6c..5c1a70b7e4dec 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -103,6 +103,7 @@ use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, Wind use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_plan::async_func::AsyncFuncExec; +use datafusion_physical_plan::buffer::BufferExec; use prost::Message; use prost::bytes::BufMut; @@ -255,6 +256,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { PhysicalPlanType::AsyncFunc(async_func) => { self.try_into_async_func_physical_plan(async_func, ctx, extension_codec) } + PhysicalPlanType::Buffer(buffer) => { + self.try_into_buffer_physical_plan(buffer, ctx, extension_codec) + } } } @@ -473,6 +477,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_buffer_exec( + exec, + extension_codec, + ); + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -2028,6 +2039,18 @@ impl protobuf::PhysicalPlanNode { Ok(Arc::new(AsyncFuncExec::try_new(async_exprs, input)?)) } + fn try_into_buffer_physical_plan( + &self, + buffer: &protobuf::BufferExecNode, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&buffer.input, ctx, extension_codec)?; + + Ok(Arc::new(BufferExec::new(input, buffer.capacity as usize))) + } + fn try_from_explain_exec( exec: &ExplainExec, _extension_codec: &dyn PhysicalExtensionCodec, @@ -3309,6 +3332,25 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_buffer_exec( + exec: &BufferExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + Arc::clone(exec.input()), + extension_codec, + )?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Buffer(Box::new( + protobuf::BufferExecNode { + input: Some(Box::new(input)), + capacity: exec.capacity() as u64, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone {