Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "pumps"
description = "Eager streams for Rust"
keywords = ["async", "pipeline"]
version = "0.0.4"
version = "0.1.0"
edition = "2021"
license-file = "LICENSE-MIT"
homepage = "https://github.com/alexpusch/pumps_rs"
Expand Down
86 changes: 86 additions & 0 deletions src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ use tokio::{
use crate::{
concurrency::Concurrency,
pumps::{
and_then::AndThenPump,
catch::CatchPump,
filter_map::FilterMapPump,
flatten::{FlattenConcurrency, FlattenPump},
flatten_iter::FlattenIterPump,
map::MapPump,
map_err::MapErrPump,
map_ok::MapOkPump,
try_filter_map::TryFilterMapPump,
},
Pump,
};
Expand Down Expand Up @@ -659,6 +661,90 @@ impl<OutOk, OutErr> Pipeline<Result<OutOk, OutErr>> {
concurrency,
})
}

/// Applies the provided async function to the success value for each item in a [Pipeline] of `Results<T, E>`.
/// The function returns a `Result` which will be the new output.
///
/// # Example
/// ```rust
/// use pumps::{Pipeline, Concurrency};
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let (mut output, h) = Pipeline::from_iter(vec![Ok(1), Err("error"), Ok(2)])
/// .and_then(|x| async move { Ok(x * 2) }, Concurrency::serial())
/// .build();
///
/// assert_eq!(output.recv().await, Some(Ok(2)));
/// assert_eq!(output.recv().await, Some(Err("error")));
/// assert_eq!(output.recv().await, Some(Ok(4)));
/// assert_eq!(output.recv().await, None);
/// # });
/// ```
pub fn and_then<F, Fut, T>(
self,
map_fn: F,
concurrency: Concurrency,
) -> Pipeline<Result<T, OutErr>>
where
F: Fn(OutOk) -> Fut + Send + 'static,
Fut: Future<Output = Result<T, OutErr>> + Send,
T: Send + 'static,
OutErr: Send + 'static,
OutOk: Send + 'static,
{
self.pump(AndThenPump {
map_fn,
concurrency,
})
}

/// Apply the provided async filter map function on the success value for each item in a [Pipeline] of `Results<T, E>`.
/// The Future returned from the filter map function is executed concurrently according to the [Concurrency] configuration.
/// If the invocation results in `Ok(None)`, the triggering item will be removed from the pipeline.
/// If the invocation results in `Ok(Some(val))`, the result will be sent as output.
/// If the invocation results in `Err(e)` or the input was `Err(e)`, the error will be sent as output.
///
/// # Example
/// ```rust
/// use pumps::{Pipeline, Concurrency};
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let (mut output, h) = Pipeline::from_iter(vec![Ok(1), Ok(2), Ok(3), Err("input error")])
/// .try_filter_map(|x| async move {
/// if x % 2 == 0 {
/// Ok(Some(x * 2))
/// } else if x % 3 == 0 {
/// Ok(None)
/// } else {
/// Err("odd and not divisible by 3")
/// }
/// }, Concurrency::serial())
/// .build();
///
/// assert_eq!(output.recv().await, Some(Err("odd and not divisible by 3"))); // 1 -> Err
/// assert_eq!(output.recv().await, Some(Ok(4))); // 2 -> Ok(4)
/// // 3 -> None (filtered)
/// assert_eq!(output.recv().await, Some(Err("input error"))); // input error -> Err
/// assert_eq!(output.recv().await, None);
/// # });
/// ```
pub fn try_filter_map<F, Fut, T>(
self,
map_fn: F,
concurrency: Concurrency,
) -> Pipeline<Result<T, OutErr>>
where
F: FnMut(OutOk) -> Fut + Send + 'static,
Fut: Future<Output = Result<Option<T>, OutErr>> + Send,
T: Send + 'static,
OutErr: Send + 'static,
OutOk: Send + 'static,
{
self.pump(TryFilterMapPump {
map_fn,
concurrency,
})
}
}

#[cfg(test)]
Expand Down
85 changes: 85 additions & 0 deletions src/pumps/and_then.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::future::Future;

use tokio::{sync::mpsc::Receiver, task::JoinHandle};

use crate::{concurrency::Concurrency, concurrency_base, Pump};

pub struct AndThenPump<F> {
pub(crate) map_fn: F,
pub(crate) concurrency: Concurrency,
}

impl<InOk, InErr, OutOk, F, Fut> Pump<Result<InOk, InErr>, Result<OutOk, InErr>> for AndThenPump<F>
where
F: FnMut(InOk) -> Fut + Send + 'static,
Fut: Future<Output = Result<OutOk, InErr>> + Send,
InOk: Send + 'static,
InErr: Send + 'static,
OutOk: Send + 'static,
{
fn spawn(
mut self,
mut input_receiver: Receiver<Result<InOk, InErr>>,
) -> (Receiver<Result<OutOk, InErr>>, JoinHandle<()>) {
concurrency_base! {
input_receiver = input_receiver;
concurrency = self.concurrency;

on_input(input, in_progress) => {
match input {
Ok(input) => {
let fut = (self.map_fn)(input);
in_progress.push_back(fut);
},
Err(e) => {
if let Err(_e) = output_sender.send(Err(e)).await {
break;
}
}
}
},
on_progress(output, output_sender) => {
if let Err(_e) = output_sender.send(output).await {
break;
}
}
}
}
}

#[cfg(test)]
mod tests {
use crate::Pipeline;
use tokio::sync::mpsc;

#[tokio::test]
async fn and_then_works() {
let (input_sender, input_receiver) = mpsc::channel(100);

let (mut output_receiver, join_handle) = Pipeline::from(input_receiver)
.and_then(
|x| async move {
if x % 2 == 0 {
Ok(x)
} else {
Err("odd")
}
},
Default::default(),
)
.build();

input_sender.send(Ok(1)).await.unwrap();
input_sender.send(Err("oh no")).await.unwrap();
input_sender.send(Ok(2)).await.unwrap();

assert_eq!(output_receiver.recv().await, Some(Err("odd")));
assert_eq!(output_receiver.recv().await, Some(Err("oh no")));
assert_eq!(output_receiver.recv().await, Some(Ok(2)));

drop(input_sender);
assert_eq!(output_receiver.recv().await, None);

assert!(matches!(join_handle.await, Ok(())));
}
}
2 changes: 2 additions & 0 deletions src/pumps/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub(crate) mod and_then;
pub(crate) mod backpressure;
pub(crate) mod backpressure_with_relief_valve;
pub(crate) mod batch;
Expand All @@ -13,3 +14,4 @@ pub(crate) mod map_ok;
pub(crate) mod pump;
pub(crate) mod skip;
pub(crate) mod take;
pub(crate) mod try_filter_map;
102 changes: 102 additions & 0 deletions src/pumps/try_filter_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use std::future::Future;

use tokio::{sync::mpsc::Receiver, task::JoinHandle};

use crate::{concurrency::Concurrency, concurrency_base, Pump};

pub struct TryFilterMapPump<F> {
pub(crate) map_fn: F,
pub(crate) concurrency: Concurrency,
}

impl<InOk, InErr, OutOk, F, Fut> Pump<Result<InOk, InErr>, Result<OutOk, InErr>>
for TryFilterMapPump<F>
where
F: FnMut(InOk) -> Fut + Send + 'static,
Fut: Future<Output = Result<Option<OutOk>, InErr>> + Send,
InOk: Send + 'static,
InErr: Send + 'static,
OutOk: Send + 'static,
{
fn spawn(
mut self,
mut input_receiver: Receiver<Result<InOk, InErr>>,
) -> (Receiver<Result<OutOk, InErr>>, JoinHandle<()>) {
concurrency_base! {
input_receiver = input_receiver;
concurrency = self.concurrency;

on_input(input, in_progress) => {
match input {
Ok(input) => {
let fut = (self.map_fn)(input);
in_progress.push_back(fut);
},
Err(e) => {
if let Err(_e) = output_sender.send(Err(e)).await {
break;
}
}
}
},
on_progress(output, output_sender) => {
match output {
Ok(Some(output)) => {
if let Err(_e) = output_sender.send(Ok(output)).await {
break;
}
},
Ok(None) => {},
Err(e) => {
if let Err(_e) = output_sender.send(Err(e)).await {
break;
}
}
}
}
}
}
}

#[cfg(test)]
mod tests {
use crate::Pipeline;
use tokio::sync::mpsc;

#[tokio::test]
async fn try_filter_map_works() {
let (input_sender, input_receiver) = mpsc::channel(100);

let (mut output_receiver, join_handle) = Pipeline::from(input_receiver)
.try_filter_map(
|x| async move {
if x % 2 == 0 {
Ok(Some(x))
} else if x % 3 == 0 {
Ok(None)
} else {
Err("odd and not divisible by 3")
}
},
Default::default(),
)
.build();

input_sender.send(Ok(2)).await.unwrap(); // driven by map_fn: Ok(Some(2)) -> Ok(2)
input_sender.send(Ok(3)).await.unwrap(); // driven by map_fn: Ok(None) -> filtered
input_sender.send(Ok(5)).await.unwrap(); // driven by map_fn: Err -> Err
input_sender.send(Err("input error")).await.unwrap(); // pass through: Err

assert_eq!(output_receiver.recv().await, Some(Ok(2)));
assert_eq!(
output_receiver.recv().await,
Some(Err("odd and not divisible by 3"))
);
assert_eq!(output_receiver.recv().await, Some(Err("input error")));

drop(input_sender);
assert_eq!(output_receiver.recv().await, None);

assert!(matches!(join_handle.await, Ok(())));
}
}