Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/common/callback_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use futures::Stream;
use pin_project::{pin_project, pinned_drop};
use std::fmt::Display;
use std::pin::Pin;
use std::task::{Context, Poll};

/// The reason why the stream ended:
/// - [CallbackStreamEndReason::Finished] if it finished gracefully
/// - [CallbackStreamEndReason::Aborted] if it was abandoned.
#[derive(Debug)]
pub enum CallbackStreamEndReason {
/// The stream finished gracefully.
Finished,
/// The stream was abandoned.
Aborted,
}

impl Display for CallbackStreamEndReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}

/// Stream that executes a callback when it is fully consumed or gets cancelled.
#[pin_project(PinnedDrop)]
pub struct CallbackStream<S, F>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read through the implementation and usage. I makes sense and a very nice fix.
My only question is about the name. Is it CallBack a common name for this kind of thing? The name does not spark immediate meaning of the work but if it is only me, it is fine to keep the name as-is

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'd say "callback" is pretty common for this in the software industry. A callback is a function whose execution gets deferred until a certain event happens. In this case, when the stream finishes.

where
S: Stream,
F: FnOnce(CallbackStreamEndReason),
{
#[pin]
stream: S,
callback: Option<F>,
}

impl<S, F> Stream for CallbackStream<S, F>
where
S: Stream,
F: FnOnce(CallbackStreamEndReason),
{
type Item = S::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();

match this.stream.poll_next(cx) {
Poll::Ready(None) => {
// Stream is fully consumed, execute the callback
if let Some(callback) = this.callback.take() {
callback(CallbackStreamEndReason::Finished);
}
Poll::Ready(None)
}
other => other,
}
}
}

#[pinned_drop]
impl<S, F> PinnedDrop for CallbackStream<S, F>
where
S: Stream,
F: FnOnce(CallbackStreamEndReason),
{
fn drop(self: Pin<&mut Self>) {
let this = self.project();
if let Some(callback) = this.callback.take() {
callback(CallbackStreamEndReason::Aborted);
}
}
}

/// Wrap a stream with a callback that will be executed when the stream is fully
/// consumed or gets canceled.
pub fn with_callback<S, F>(stream: S, callback: F) -> CallbackStream<S, F>
where
S: Stream,
F: FnOnce(CallbackStreamEndReason) + Send + 'static,
{
CallbackStream {
stream,
callback: Some(callback),
}
}
2 changes: 2 additions & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod callback_stream;
mod composed_extension_codec;
mod partitioning;
#[allow(unused)]
pub mod ttl_map;

pub(crate) use callback_stream::with_callback;
pub(crate) use composed_extension_codec::ComposedPhysicalExtensionCodec;
pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};
14 changes: 13 additions & 1 deletion src/flight_service/do_get.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::common::with_callback;
use crate::config_extension_ext::ContextGrpcMetadata;
use crate::execution_plans::{DistributedTaskContext, StageExec};
use crate::flight_service::service::ArrowFlightEndpoint;
Expand All @@ -11,6 +12,7 @@ use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use datafusion::common::exec_datafusion_err;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
Expand Down Expand Up @@ -126,7 +128,17 @@ impl ArrowFlightEndpoint {
.execute(doget.target_partition as usize, session_state.task_ctx())
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

Ok(record_batch_stream_to_response(stream))
let schema = stream.schema();
let stream = with_callback(stream, move |_| {
// We need to hold a reference to the plan for at least as long as the stream is
// execution. Some plans might store state necessary for the stream to work, and
// dropping the plan early could drop this state too soon.
let _ = stage.plan;
});

Ok(record_batch_stream_to_response(Box::pin(
RecordBatchStreamAdapter::new(schema, stream),
)))
}
}

Expand Down
Loading