Add streaming support for arrow batches #44

Open
wants to merge 3 commits into base: main
Changes from 1 commit
Add streaming support for arrow batches
sgrebnov committed May 8, 2024
commit d77e97b179eabc3807b2ef83528a570737a01d06
1 change: 1 addition & 0 deletions snowflake-api/Cargo.toml
@@ -21,6 +21,7 @@ polars = ["dep:polars-core", "dep:polars-io"]
[dependencies]
arrow = "51"
async-trait = "0.1"
async-stream = "0.3.5"
base64 = "0.22"
bytes = "1"
futures = "0.3"
122 changes: 114 additions & 8 deletions snowflake-api/src/lib.rs
@@ -15,8 +15,11 @@ clippy::missing_panics_doc

use std::fmt::{Display, Formatter};
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use async_stream::stream;

use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::record_batch::RecordBatch;
@@ -27,7 +30,7 @@ use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use responses::ExecResponse;
use responses::{ExecResponse, QueryExecResponseData};
use session::{AuthError, Session};

use crate::connection::QueryType;
@@ -36,6 +39,8 @@ use crate::requests::ExecRequest;
use crate::responses::{ExecResponseRowType, SnowflakeType};
use crate::session::AuthError::MissingEnvArgument;

use futures::{future, Stream, StreamExt};

pub mod connection;
#[cfg(feature = "polars")]
mod polars;
@@ -98,6 +103,8 @@ pub enum SnowflakeApiError {
GlobError(#[from] glob::GlobError),
}

const MAX_CHUNK_DOWNLOAD_WORKERS: usize = 10;

/// Even if Arrow is specified as the return type, non-select queries
/// will return a JSON array of arrays: `[[42, "answer"], [43, "non-answer"]]`.
pub struct JsonResult {
@@ -144,24 +151,38 @@ pub enum QueryResult {
Empty,
}

pub type BytesStream = Pin<Box<dyn Stream<Item = Result<bytes::Bytes, SnowflakeApiError>> + Send>>;
pub type RecordBatchStream = Pin<Box<dyn Stream<Item = Result<RecordBatch, ArrowError>> + Send>>;

/// Raw query result
/// Can be transformed into [`QueryResult`]
pub enum RawQueryResult {
/// Arrow IPC chunks
/// see: <https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc>
Bytes(Vec<Bytes>),
Stream(BytesStream),
/// Json payload is deserialized,
/// as it's already part of the REST response
Json(JsonResult),
Empty,
}

impl RawQueryResult {
- pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
+ pub async fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
match self {
RawQueryResult::Bytes(bytes) => {
Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
}
RawQueryResult::Stream(bytes_stream) => {
let arrow_records_stream = Self::to_record_batches_stream(bytes_stream);
let arrow_records = arrow_records_stream
.collect::<Vec<Result<RecordBatch, ArrowError>>>()
.await;

return Ok(QueryResult::Arrow(
arrow_records.into_iter().map(Result::unwrap).collect(),
));
}
RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
RawQueryResult::Empty => Ok(QueryResult::Empty),
}
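Note that `map(Result::unwrap)` above panics on the first malformed batch. A hedged alternative (a sketch only; `collect_batches` is a hypothetical helper using the types defined in this file) would propagate the error with `futures::TryStreamExt`:

```rust
use futures::TryStreamExt;

// Sketch: short-circuit on the first ArrowError instead of panicking.
async fn collect_batches(stream: RecordBatchStream) -> Result<QueryResult, ArrowError> {
    let batches: Vec<RecordBatch> = stream.try_collect().await?;
    Ok(QueryResult::Arrow(batches))
}
```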
@@ -176,6 +197,24 @@ impl RawQueryResult {
Ok(res)
}

fn to_record_batches_stream(bytes_stream: BytesStream) -> RecordBatchStream {
let batch_stream = bytes_stream.flat_map(|bytes_result| match bytes_result {
Ok(bytes) => match Self::bytes_to_batches(bytes) {
Ok(batches) => futures::stream::iter(batches.into_iter().map(Ok)).boxed(),
Err(e) => futures::stream::once(async move { Err(ArrowError::from(e)) }).boxed(),
},
Err(e) => futures::stream::once(async move {
Err(ArrowError::ParseError(format!(
"Unable to parse RecordBatch due to error in bytes stream: {}",
e.to_string()
)))
})
.boxed(),
});

Box::pin(batch_stream)
}

fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
let record_batches = StreamReader::try_new_unbuffered(bytes.reader(), None)?;
record_batches.into_iter().collect()
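`bytes_to_batches` expects each downloaded chunk to be a complete Arrow IPC stream. A standalone round-trip sketch of that format (illustrative only, not part of the PR):

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("answer", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![42, 43])) as ArrayRef],
    )?;

    // Encode one batch as an IPC stream...
    let mut buf = Vec::new();
    let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
    writer.write(&batch)?;
    writer.finish()?;
    drop(writer);

    // ...and decode it back, as `bytes_to_batches` does for each chunk.
    let reader = StreamReader::try_new(Cursor::new(buf), None)?;
    let batches = reader.collect::<Result<Vec<_>, _>>()?;
    assert_eq!(batches[0].num_rows(), 2);
    Ok(())
}
```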
@@ -380,10 +419,23 @@ impl SnowflakeApi {
/// If the statement is PUT, the file will be uploaded to Snowflake-managed storage.
pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
let raw = self.exec_raw(sql).await?;
- let res = raw.deserialize_arrow()?;
+ let res = raw.deserialize_arrow().await?;
Ok(res)
}

/// Executes a single query against the API and returns a stream of `RecordBatch`es.
pub async fn exec_streamed(&self, sql: &str) -> Result<RecordBatchStream, SnowflakeApiError> {
let raw = self.exec_arrow_raw(sql, true).await?;
match raw {
RawQueryResult::Empty => Ok(Box::pin(futures::stream::empty())),
RawQueryResult::Stream(bytes_stream) => {
let arrow_stream = RawQueryResult::to_record_batches_stream(bytes_stream);
Ok(arrow_stream)
}
_ => Err(SnowflakeApiError::UnexpectedResponse),
}
}
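A hedged sketch of how a caller might consume the new method (account, credentials, and query are placeholders; `with_password_auth` is the constructor from the crate's README, and a tokio runtime is assumed):

```rust
use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api = snowflake_api::SnowflakeApi::with_password_auth(
        "ACCOUNT", Some("WAREHOUSE"), Some("DATABASE"), Some("SCHEMA"),
        "USERNAME", Some("ROLE"), "PASSWORD",
    )?;

    // Batches are yielded as chunks finish downloading, instead of
    // buffering the whole result set in memory first.
    let mut batches = api.exec_streamed("SELECT * FROM some_big_table").await?;
    while let Some(batch) = batches.next().await {
        println!("rows in batch: {}", batch?.num_rows());
    }
    Ok(())
}
```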

/// Executes a single query against the API.
/// If the statement is PUT, the file will be uploaded to Snowflake-managed storage.
/// Returns raw bytes of the Arrow response.
@@ -395,7 +447,7 @@ impl SnowflakeApi {
log::info!("Detected PUT query");
self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
} else {
- self.exec_arrow_raw(sql).await
+ self.exec_arrow_raw(sql, false).await
}
}

@@ -429,8 +481,12 @@ impl SnowflakeApi {
.await
}

- async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
- let resp = self
+ async fn exec_arrow_raw(
+     &self,
+     sql: &str,
+     enable_streaming: bool,
+ ) -> Result<RawQueryResult, SnowflakeApiError> {
+ let mut resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
log::debug!("Got query response: {:?}", resp);
@@ -459,14 +515,19 @@ impl SnowflakeApi {
value,
schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
}))
- } else if let Some(base64) = resp.data.rowset_base64 {
-     // fixme: is it possible to give streaming interface?
+ } else if resp.data.rowset_base64.is_some() {
+     if enable_streaming {
+         return Ok(self.chunks_to_bytes_stream(&resp.data));
+     }

let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
self.connection
.get_chunk(&chunk.url, &resp.data.chunk_headers)
}))
.await?;

+ let base64 = resp.data.rowset_base64.unwrap_or_default();

// fixme: should base64 chunk go first?
// fixme: if response is chunked is it both base64 + chunks or just chunks?
if !base64.is_empty() {
@@ -510,4 +571,49 @@ impl SnowflakeApi {

Ok(resp)
}

fn chunks_to_bytes_stream(&self, data: &QueryExecResponseData) -> RawQueryResult {
let chunk_urls = data
.chunks
.iter()
.map(|chunk| chunk.url.clone())
.collect::<Vec<String>>();
let chunk_headers = data.chunk_headers.clone();
let connection = self.connection.clone();
let base64 = data.rowset_base64.clone().unwrap_or_default();

let stream = stream! {

let chunks_iter = chunk_urls.chunks(MAX_CHUNK_DOWNLOAD_WORKERS);

for chunk in chunks_iter {
let futures_batch = chunk.iter().map(|chunk_url| {
let headers = chunk_headers.clone();
let connection_clone = connection.clone();
async move {
connection_clone.get_chunk(chunk_url, &headers).await.map_err(SnowflakeApiError::from)
}
}).collect::<Vec<_>>();

let results = future::join_all(futures_batch).await;
for result in results {
yield result;
}
}

if !base64.is_empty() {
log::debug!("Got base64 encoded response");
match base64::engine::general_purpose::STANDARD.decode(&base64) {
Ok(bytes) => {
yield Ok(Bytes::from(bytes));
}
Err(e) => {
yield Err(SnowflakeApiError::from(e));
}
}
}
};

RawQueryResult::Stream(Box::pin(stream))
}
}
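One design note on `chunks_to_bytes_stream`: downloading in fixed windows of `MAX_CHUNK_DOWNLOAD_WORKERS` via `join_all` means the slowest chunk in a window delays the start of the next window. A hedged, order-preserving alternative using `StreamExt::buffered` (`fetch_chunk` is a stand-in for `connection.get_chunk`, not a real API in this crate):

```rust
use futures::{stream, Stream, StreamExt};

// Stand-in for `connection.get_chunk(...)`; illustrative only.
async fn fetch_chunk(url: String) -> Result<bytes::Bytes, String> {
    let _ = url; // a real implementation would issue the HTTP request here
    Ok(bytes::Bytes::new())
}

// Keeps up to 10 downloads in flight at all times and yields the
// results in the original chunk order.
fn chunk_stream(urls: Vec<String>) -> impl Stream<Item = Result<bytes::Bytes, String>> {
    stream::iter(urls).map(fetch_chunk).buffered(10)
}
```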
1 change: 1 addition & 0 deletions snowflake-api/src/polars.rs
@@ -26,6 +26,7 @@ impl RawQueryResult {
RawQueryResult::Bytes(bytes) => dataframe_from_bytes(bytes),
RawQueryResult::Json(json) => dataframe_from_json(&json),
RawQueryResult::Empty => Ok(DataFrame::empty()),
RawQueryResult::Stream(_) => todo!(),
}
}
}
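For the `todo!()` arm above, one possible follow-up (a sketch only; it assumes an async variant of this conversion and that the error of the existing `dataframe_from_bytes` converts into `SnowflakeApiError`) is to drain the stream and reuse the byte-based path:

```rust
use futures::TryStreamExt;

// Sketch: drain the byte stream, then delegate to the existing
// `dataframe_from_bytes` helper. This needs an async entry point,
// which is why the synchronous `to_polars` leaves the arm as `todo!()`.
async fn dataframe_from_stream(stream: BytesStream) -> Result<DataFrame, SnowflakeApiError> {
    let chunks: Vec<Bytes> = stream.try_collect().await?;
    Ok(dataframe_from_bytes(chunks)?)
}
```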