Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/execution_plans/common.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use arrow::array::RecordBatch;
use datafusion::common::runtime::SpawnedTask;
use datafusion::common::{DataFusionError, plan_err};
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool};
use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::{ExecutionPlan, PlanProperties};
use futures::{Stream, StreamExt};
use std::borrow::Borrow;
use std::sync::Arc;
use tokio_stream::wrappers::UnboundedReceiverStream;

pub(super) fn require_one_child<L, T>(
children: L,
Expand Down Expand Up @@ -40,3 +45,104 @@ pub(super) fn scale_partitioning(
Partitioning::UnknownPartitioning(p) => Partitioning::UnknownPartitioning(f(*p)),
}
}

/// Drives every stream in `inner` concurrently, each one on its own spawned task, and funnels
/// whatever they produce into a single shared unbounded queue. The receiving side of that queue
/// is returned as a stream, so items coming from different inputs are interleaved with no
/// ordering guarantee between them.
///
/// Each queued item is paired with a memory reservation registered against `pool`. For `Ok`
/// items the reservation is grown by [`MemoryFootPrint::get_memory_size`]; `Err` items carry a
/// reservation that was never grown. The reservation is dropped as soon as the item is yielded
/// by the returned stream (or when the stream itself is dropped with items still queued).
// FIXME: None of this should be needed: a plain tokio::stream::select_all ought to be enough
//  for consuming all the messages. Unfortunately, doing that can deadlock the stream on the
//  server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228).
//  Making these channels bounded also ends up in deadlocks (learned it the hard way).
//  Until the root cause is understood, this is a good enough solution.
pub(super) fn spawn_select_all<T, El, Err>(
    inner: Vec<T>,
    pool: Arc<dyn MemoryPool>,
) -> impl Stream<Item = Result<El, Err>>
where
    T: Stream<Item = Result<El, Err>> + Send + Unpin + 'static,
    El: MemoryFootPrint + Send + 'static,
    Err: Send + 'static,
{
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();

    // One forwarding task per input stream, spawned in the order the inputs were given.
    let tasks: Vec<_> = inner
        .into_iter()
        .map(|mut input| {
            let sender = sender.clone();
            let pool = Arc::clone(&pool);
            let consumer = MemoryConsumer::new("NetworkBoundary");

            SpawnedTask::spawn(async move {
                while let Some(item) = input.next().await {
                    // A fresh reservation per item, so that it can travel through the queue
                    // alongside the item it accounts for.
                    let mut reservation = consumer.clone_with_new_id().register(&pool);
                    if let Ok(element) = item.as_ref() {
                        reservation.grow(element.get_memory_size());
                    }

                    // A failed send means the receiving stream is gone: stop forwarding.
                    if sender.send((item, reservation)).is_err() {
                        break;
                    }
                }
            })
        })
        .collect();

    UnboundedReceiverStream::new(receiver).map(move |(item, _reservation)| {
        // Referencing `tasks` forces the closure to own the task handles, tying their
        // lifetime to the lifetime of the returned stream.
        let _ = &tasks;
        item
    })
}

/// Implemented by values that can report how much memory they should be accounted for.
///
/// `spawn_select_all` requires its stream elements to implement this trait, and uses the
/// reported size to grow the memory reservation attached to each queued element.
pub(super) trait MemoryFootPrint {
/// Returns the memory size attributed to `self` (presumably in bytes, as the value is fed
/// directly into a memory reservation).
fn get_memory_size(&self) -> usize;
}

impl MemoryFootPrint for RecordBatch {
fn get_memory_size(&self) -> usize {
self.get_array_memory_size()
}
}

#[cfg(test)]
mod tests {
    use super::{MemoryFootPrint, spawn_select_all};
    use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool};
    use std::error::Error;
    use std::sync::Arc;
    use tokio_stream::StreamExt;

    /// Test-only footprint: a `usize` accounts for exactly its own value, which makes the
    /// expected pool totals below easy to work out by hand.
    impl MemoryFootPrint for usize {
        fn get_memory_size(&self) -> usize {
            *self
        }
    }

    /// Checks that queued items hold their reservation until they are yielded, and that
    /// dropping the stream releases whatever is still queued.
    #[tokio::test]
    async fn memory_reservation() -> Result<(), Box<dyn Error>> {
        let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());

        let first = futures::stream::iter(vec![Ok::<usize, String>(1), Ok(2), Ok(3)]);
        let second = futures::stream::iter(vec![Ok(4), Ok(5)]);
        let mut merged = spawn_select_all(vec![first, second], Arc::clone(&pool));

        // Yield to the runtime so the spawned tasks can push every item into the queue.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        // All five items are queued: 1 + 2 + 3 + 4 + 5.
        assert_eq!(pool.reserved(), 15);

        for expected in [1, 2, 3] {
            let actual = merged.next().await.unwrap()?;
            assert_eq!(expected, actual)
        }

        // Items 1, 2 and 3 were yielded, only 4 + 5 remain accounted for.
        assert_eq!(pool.reserved(), 9);

        drop(merged);

        // Dropping the stream drops the still-queued items along with their reservations.
        assert_eq!(pool.reserved(), 0);

        Ok(())
    }
}
11 changes: 7 additions & 4 deletions src/execution_plans/network_coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::channel_resolver_ext::get_distributed_channel_resolver;
use crate::config_extension_ext::ContextGrpcMetadata;
use crate::distributed_planner::{InputStageInfo, NetworkBoundary, limit_tasks_err};
use crate::execution_plans::common::{require_one_child, scale_partitioning_props};
use crate::execution_plans::common::{
require_one_child, scale_partitioning_props, spawn_select_all,
};
use crate::flight_service::DoGet;
use crate::metrics::MetricsCollectingStream;
use crate::metrics::proto::MetricsSetProto;
Expand All @@ -18,7 +20,7 @@ use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::{TryFutureExt, TryStreamExt};
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use http::Extensions;
use prost::Message;
use std::any::Any;
Expand Down Expand Up @@ -319,11 +321,12 @@ impl ExecutionPlan for NetworkCoalesceExec {
.map_err(map_flight_to_datafusion_error),
)
}
.try_flatten_stream();
.try_flatten_stream()
.boxed();

Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
stream,
spawn_select_all(vec![stream], Arc::clone(context.memory_pool())),
)))
}
}
4 changes: 2 additions & 2 deletions src/execution_plans/network_shuffle.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::channel_resolver_ext::get_distributed_channel_resolver;
use crate::config_extension_ext::ContextGrpcMetadata;
use crate::execution_plans::common::{require_one_child, scale_partitioning};
use crate::execution_plans::common::{require_one_child, scale_partitioning, spawn_select_all};
use crate::flight_service::DoGet;
use crate::metrics::MetricsCollectingStream;
use crate::metrics::proto::MetricsSetProto;
Expand Down Expand Up @@ -372,7 +372,7 @@ impl ExecutionPlan for NetworkShuffleExec {

Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::select_all(stream),
spawn_select_all(stream.collect(), Arc::clone(context.memory_pool())),
)))
}
}