diff --git a/snowflake-api/Cargo.toml b/snowflake-api/Cargo.toml index 206b2b7..91edb4d 100644 --- a/snowflake-api/Cargo.toml +++ b/snowflake-api/Cargo.toml @@ -19,19 +19,21 @@ default = ["cert-auth"] polars = ["dep:polars-core", "dep:polars-io"] [dependencies] -arrow = "53" +arrow = "54.2.1" async-trait = "0.1" base64 = "0.22" bytes = "1" futures = "0.3" +futures-util = "0.3" log = "0.4" regex = "1" -reqwest = { version = "0.12", default-features = false, features = [ +reqwest = { version = "=0.12.12", default-features = false, features = [ "gzip", "json", "rustls-tls", + "stream", ] } -reqwest-middleware = { version = "0.3", features = ["json"] } +reqwest-middleware = { version = "0.3.3", features = ["json"] } reqwest-retry = "0.6" serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -54,7 +56,7 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] } [dev-dependencies] anyhow = "1" -arrow = { version = "53", features = ["prettyprint"] } +arrow = { version = "54.2.1", features = ["prettyprint"] } clap = { version = "4", features = ["derive"] } pretty_env_logger = "0.5" tokio = { version = "1.35", features = ["macros", "rt-multi-thread"] } diff --git a/snowflake-api/examples/polars/src/main.rs b/snowflake-api/examples/polars/src/main.rs index 635212b..fd4b57c 100644 --- a/snowflake-api/examples/polars/src/main.rs +++ b/snowflake-api/examples/polars/src/main.rs @@ -32,7 +32,7 @@ async fn main() -> Result<()> { } async fn run_and_print(api: &SnowflakeApi, sql: &str) -> Result<()> { - let res = api.exec_raw(sql).await?; + let res = api.exec_raw(sql, false).await?; let df = DataFrame::try_from(res)?; // alternatively, you can use the `try_into` method on the response diff --git a/snowflake-api/examples/run_sql.rs b/snowflake-api/examples/run_sql.rs index 18ec8a9..829b347 100644 --- a/snowflake-api/examples/run_sql.rs +++ b/snowflake-api/examples/run_sql.rs @@ -2,10 +2,11 @@ extern crate snowflake_api; use anyhow::Result; use arrow::util::pretty::pretty_format_batches; -use clap::Parser; +use clap::{ArgAction, Parser}; +use futures_util::StreamExt; use std::fs; -use snowflake_api::{QueryResult, SnowflakeApi}; +use snowflake_api::{responses::ExecResponse, QueryResult, RawQueryResult, SnowflakeApi}; #[derive(clap::ValueEnum, Clone, Debug)] enum Output { @@ -56,6 +57,12 @@ struct Args { #[arg(long)] #[arg(value_enum, default_value_t = Output::Arrow)] output: Output, + + #[arg(long)] + host: Option, + + #[clap(long, action = ArgAction::Set)] + stream: bool, } #[tokio::main] @@ -89,30 +96,52 @@ async fn main() -> Result<()> { _ => { panic!("Either private key path or password must be set") } - }; - - match args.output { - Output::Arrow => { - let res = api.exec(&args.sql).await?; - match res { - QueryResult::Arrow(a) => { - println!("{}", pretty_format_batches(&a).unwrap()); - } - QueryResult::Json(j) => { - println!("{j}"); - } - QueryResult::Empty => { - println!("Query finished successfully") - } + } + // add optional host + .with_host(args.host); + + if args.stream { + let resp = api.exec_raw(&args.sql, true).await?; + + if let RawQueryResult::Stream(mut bytes_stream) = resp { + let mut chunks = vec![]; + while let Some(bytes) = bytes_stream.next().await { + chunks.push(bytes?); + } + + let bytes = chunks.into_iter().flatten().collect::>(); + let resp = serde_json::from_slice::(&bytes).unwrap(); + let raw_query_result = api.parse_arrow_raw_response(resp).await.unwrap(); + let batches = raw_query_result.deserialize_arrow().unwrap(); + + if let QueryResult::Arrow(a) = batches { + println!("{}", pretty_format_batches(&a).unwrap()); } } - Output::Json => { - let res = api.exec_json(&args.sql).await?; - println!("{res}"); - } - Output::Query => { - let res = api.exec_response(&args.sql).await?; - println!("{:?}", res); + } else { + match args.output { + Output::Arrow => { + let res = api.exec(&args.sql).await?; + match res { + QueryResult::Arrow(a) => { + println!("{}", pretty_format_batches(&a).unwrap()); + } + QueryResult::Json(j) => { + println!("{j}"); + } + QueryResult::Empty => { + println!("Query finished successfully") + } + } + } + Output::Json => { + let res = api.exec_json(&args.sql).await?; + println!("{res}"); + } + Output::Query => { + let res = api.exec_response(&args.sql).await?; + println!("{:?}", res); + } } } diff --git a/snowflake-api/examples/tracing/Cargo.toml b/snowflake-api/examples/tracing/Cargo.toml index 738daa2..206e3d0 100644 --- a/snowflake-api/examples/tracing/Cargo.toml +++ b/snowflake-api/examples/tracing/Cargo.toml @@ -5,7 +5,7 @@ version = "0.1.0" [dependencies] anyhow = "1" -arrow = { version = "53", features = ["prettyprint"] } +arrow = { version = "54.2.1", features = ["prettyprint"] } dotenv = "0.15" snowflake-api = { path = "../../../snowflake-api" } diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index e7087e1..9dd6805 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -1,4 +1,5 @@ use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue}; +use reqwest::Response; use reqwest_middleware::ClientWithMiddleware; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::RetryTransientMiddleware; @@ -113,17 +114,15 @@ impl Connection { .with(RetryTransientMiddleware::new_with_policy(retry_policy))) } - /// Perform request of given query type with extra body or parameters - // todo: implement soft error handling - // todo: is there better way to not repeat myself? - pub async fn request( + pub async fn send_request( &self, query_type: QueryType, account_identifier: &str, extra_get_params: &[(&str, &str)], auth: Option<&str>, body: impl serde::Serialize, - ) -> Result { + host: Option<&str>, + ) -> Result { let context = query_type.query_context(); let request_id = Uuid::new_v4(); @@ -144,10 +143,10 @@ impl Connection { ]; get_params.extend_from_slice(extra_get_params); - let url = format!( - "https://{}.snowflakecomputing.com/{}", - &account_identifier, context.path - ); + let base_url = host + .map(str::to_string) + .unwrap_or_else(|| format!("https://{}.snowflakecomputing.com", &account_identifier)); + let url = format!("{base_url}/{}", context.path); let url = Url::parse_with_params(&url, get_params)?; let mut headers = HeaderMap::new(); @@ -162,7 +161,6 @@ impl Connection { headers.append(header::AUTHORIZATION, auth_val); } - // todo: persist client to use connection polling let resp = self .client .post(url) @@ -171,6 +169,32 @@ impl Connection { .send() .await?; + Ok(resp) + } + + /// Perform request of given query type with extra body or parameters + // todo: implement soft error handling + // todo: is there better way to not repeat myself? + pub async fn request( + &self, + query_type: QueryType, + account_identifier: &str, + extra_get_params: &[(&str, &str)], + auth: Option<&str>, + body: impl serde::Serialize, + host: Option<&str>, + ) -> Result { + let resp = self + .send_request( + query_type, + account_identifier, + extra_get_params, + auth, + body, + host, + ) + .await?; + Ok(resp.json::().await?) } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 2f4789e..4177103 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -15,6 +15,7 @@ clippy::missing_panics_doc use std::fmt::{Display, Formatter}; use std::io; +use std::pin::Pin; use std::sync::Arc; use arrow::error::ArrowError; @@ -23,6 +24,7 @@ use arrow::record_batch::RecordBatch; use base64::Engine; use bytes::{Buf, Bytes}; use futures::future::try_join_all; +use futures::Stream; use regex::Regex; use reqwest_middleware::ClientWithMiddleware; use thiserror::Error; @@ -41,7 +43,7 @@ pub mod connection; mod polars; mod put; mod requests; -mod responses; +pub mod responses; mod session; #[derive(Error, Debug)] @@ -150,6 +152,9 @@ pub enum RawQueryResult { /// Arrow IPC chunks /// see: Bytes(Vec), + Stream( + Pin> + std::marker::Send>>, + ), /// Json payload is deserialized, /// as it's already a part of REST response Json(JsonResult), @@ -162,12 +167,13 @@ impl RawQueryResult { RawQueryResult::Bytes(bytes) => { Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow) } + RawQueryResult::Stream(_) => unimplemented!(), RawQueryResult::Json(j) => Ok(QueryResult::Json(j)), RawQueryResult::Empty => Ok(QueryResult::Empty), } } - fn flat_bytes_to_batches(bytes: Vec) -> Result, ArrowError> { + pub fn flat_bytes_to_batches(bytes: Vec) -> Result, ArrowError> { let mut res = vec![]; for b in bytes { let mut batches = Self::bytes_to_batches(b)?; @@ -235,11 +241,16 @@ pub struct CertificateArgs { pub struct SnowflakeApiBuilder { pub auth: AuthArgs, client: Option, + host: Option, } impl SnowflakeApiBuilder { pub fn new(auth: AuthArgs) -> Self { - Self { auth, client: None } + Self { + auth, + client: None, + host: None, + } } pub fn with_client(mut self, client: ClientWithMiddleware) -> Self { @@ -247,6 +258,11 @@ impl SnowflakeApiBuilder { self } + pub fn with_host(mut self, uri: &str) -> Self { + self.host = Some(uri.to_string()); + self + } + pub fn build(self) -> Result { let connection = match self.client { Some(client) => Arc::new(Connection::new_with_middware(client)), @@ -263,6 +279,7 @@ impl SnowflakeApiBuilder { &self.auth.username, self.auth.role.as_deref(), &args.password, + self.host.as_deref(), ), AuthType::Certificate(args) => Session::cert_auth( Arc::clone(&connection), @@ -273,16 +290,16 @@ impl SnowflakeApiBuilder { &self.auth.username, self.auth.role.as_deref(), &args.private_key_pem, + self.host.as_deref(), ), }; let account_identifier = self.auth.account_identifier.to_uppercase(); - Ok(SnowflakeApi::new( - Arc::clone(&connection), - session, - account_identifier, - )) + Ok( + SnowflakeApi::new(Arc::clone(&connection), session, account_identifier) + .with_host(self.host), + ) } } @@ -291,6 +308,7 @@ pub struct SnowflakeApi { connection: Arc, session: Session, account_identifier: String, + host: Option, } impl SnowflakeApi { @@ -300,8 +318,16 @@ impl SnowflakeApi { connection, session, account_identifier, + host: None, } } + + pub fn with_host(mut self, host: Option) -> Self { + self.host = host.to_owned(); + self.session = self.session.with_host(host); + self + } + /// Initialize object with password auth. Authentication happens on the first request. pub fn with_password_auth( account_identifier: &str, @@ -323,6 +349,7 @@ impl SnowflakeApi { username, role, password, + None, ); let account_identifier = account_identifier.to_uppercase(); @@ -354,6 +381,7 @@ impl SnowflakeApi { username, role, private_key_pem, + None, ); let account_identifier = account_identifier.to_uppercase(); @@ -379,7 +407,7 @@ impl SnowflakeApi { /// Execute a single query against API. /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage pub async fn exec(&self, sql: &str) -> Result { - let raw = self.exec_raw(sql).await?; + let raw = self.exec_raw(sql, false).await?; let res = raw.deserialize_arrow()?; Ok(res) } @@ -387,7 +415,11 @@ impl SnowflakeApi { /// Executes a single query against API. /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage /// Returns raw bytes in the Arrow response - pub async fn exec_raw(&self, sql: &str) -> Result { + pub async fn exec_raw( + &self, + sql: &str, + stream: bool, + ) -> Result { let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap(); // put commands go through a different flow and result is side-effect @@ -395,7 +427,7 @@ impl SnowflakeApi { log::info!("Detected PUT query"); self.exec_put(sql).await.map(|()| RawQueryResult::Empty) } else { - self.exec_arrow_raw(sql).await + self.exec_arrow_raw(sql, stream).await } } @@ -429,12 +461,27 @@ impl SnowflakeApi { .await } - async fn exec_arrow_raw(&self, sql: &str) -> Result { + async fn exec_arrow_raw( + &self, + sql: &str, + stream: bool, + ) -> Result { + if stream { + let bytes_stream = self.run_sql_stream(sql, QueryType::ArrowQuery).await?; + return Ok(RawQueryResult::Stream(Box::pin(bytes_stream))); + } + let resp = self .run_sql::(sql, QueryType::ArrowQuery) .await?; - log::debug!("Got query response: {:?}", resp); + self.parse_arrow_raw_response(resp).await + } + + pub async fn parse_arrow_raw_response( + &self, + resp: ExecResponse, + ) -> Result { let resp = match resp { // processable response ExecResponse::Query(qr) => Ok(qr), @@ -505,9 +552,42 @@ impl SnowflakeApi { &[], Some(&parts.session_token_auth_header), body, + self.host.as_deref(), ) .await?; Ok(resp) } + + async fn run_sql_stream( + &self, + sql_text: &str, + query_type: QueryType, + ) -> Result>, SnowflakeApiError> + { + log::debug!("Executing: {}", sql_text); + + let parts = self.session.get_token().await?; + + let body = ExecRequest { + sql_text: sql_text.to_string(), + async_exec: false, + sequence_id: parts.sequence_id, + is_internal: false, + }; + + let resp = self + .connection + .send_request( + query_type, + &self.account_identifier, + &[], + Some(&parts.session_token_auth_header), + body, + self.host.as_deref(), + ) + .await?; + + Ok(resp.bytes_stream()) + } } diff --git a/snowflake-api/src/polars.rs b/snowflake-api/src/polars.rs index 9640504..4c818f6 100644 --- a/snowflake-api/src/polars.rs +++ b/snowflake-api/src/polars.rs @@ -25,6 +25,7 @@ impl RawQueryResult { pub fn to_polars(self) -> Result { match self { RawQueryResult::Bytes(bytes) => dataframe_from_bytes(bytes), + RawQueryResult::Stream(_bytes_stream) => todo!(), RawQueryResult::Json(json) => dataframe_from_json(&json), RawQueryResult::Empty => Ok(DataFrame::empty()), } diff --git a/snowflake-api/src/session.rs b/snowflake-api/src/session.rs index 90acaaf..77376e8 100644 --- a/snowflake-api/src/session.rs +++ b/snowflake-api/src/session.rs @@ -118,6 +118,7 @@ pub struct Session { auth_tokens: Mutex>, auth_type: AuthType, account_identifier: String, + host: Option, warehouse: Option, database: Option, @@ -145,6 +146,7 @@ impl Session { username: &str, role: Option<&str>, private_key_pem: &str, + host: Option<&str>, ) -> Self { // uppercase everything as this is the convention let account_identifier = account_identifier.to_uppercase(); @@ -168,6 +170,7 @@ impl Session { role, schema, password: None, + host: host.map(str::to_string), } } @@ -183,6 +186,7 @@ impl Session { username: &str, role: Option<&str>, password: &str, + host: Option<&str>, ) -> Self { let account_identifier = account_identifier.to_uppercase(); @@ -205,9 +209,15 @@ impl Session { password, schema, private_key_pem: None, + host: host.map(str::to_string), } } + pub fn with_host(mut self, host: Option) -> Self { + self.host = host; + self + } + /// Get cached token or request a new one if old one has expired. pub async fn get_token(&self) -> Result { let mut auth_tokens = self.auth_tokens.lock().await; @@ -260,6 +270,7 @@ impl Session { &[("delete", "true")], Some(&tokens.session_token.auth_header()), serde_json::Value::default(), + self.host.as_deref(), ) .await?; @@ -336,6 +347,7 @@ impl Session { &get_params, None, body, + self.host.as_deref(), ) .await?; log::debug!("Auth response: {:?}", resp); @@ -396,6 +408,7 @@ impl Session { &[], Some(&auth), body, + self.host.as_deref(), ) .await?;