Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade deps march 2025 #63

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
10 changes: 6 additions & 4 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@ default = ["cert-auth"]
polars = ["dep:polars-core", "dep:polars-io"]

[dependencies]
arrow = "53"
arrow = "54.2.1"
async-trait = "0.1"
base64 = "0.22"
bytes = "1"
futures = "0.3"
futures-util = "0.3"
log = "0.4"
regex = "1"
reqwest = { version = "0.12", default-features = false, features = [
reqwest = { version = "=0.12.12", default-features = false, features = [
"gzip",
"json",
"rustls-tls",
"stream",
] }
reqwest-middleware = { version = "0.3", features = ["json"] }
reqwest-middleware = { version = "0.3.3", features = ["json"] }
reqwest-retry = "0.6"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
Expand All @@ -54,7 +56,7 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] }

[dev-dependencies]
anyhow = "1"
arrow = { version = "53", features = ["prettyprint"] }
arrow = { version = "54.2.1", features = ["prettyprint"] }
clap = { version = "4", features = ["derive"] }
pretty_env_logger = "0.5"
tokio = { version = "1.35", features = ["macros", "rt-multi-thread"] }
2 changes: 1 addition & 1 deletion snowflake-api/examples/polars/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async fn main() -> Result<()> {
}

async fn run_and_print(api: &SnowflakeApi, sql: &str) -> Result<()> {
let res = api.exec_raw(sql).await?;
let res = api.exec_raw(sql, false).await?;

let df = DataFrame::try_from(res)?;
// alternatively, you can use the `try_into` method on the response
Expand Down
77 changes: 53 additions & 24 deletions snowflake-api/examples/run_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ extern crate snowflake_api;

use anyhow::Result;
use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use clap::{ArgAction, Parser};
use futures_util::StreamExt;
use std::fs;

use snowflake_api::{QueryResult, SnowflakeApi};
use snowflake_api::{responses::ExecResponse, QueryResult, RawQueryResult, SnowflakeApi};

#[derive(clap::ValueEnum, Clone, Debug)]
enum Output {
Expand Down Expand Up @@ -56,6 +57,12 @@ struct Args {
#[arg(long)]
#[arg(value_enum, default_value_t = Output::Arrow)]
output: Output,

#[arg(long)]
host: Option<String>,

#[clap(long, action = ArgAction::Set)]
stream: bool,
}

#[tokio::main]
Expand Down Expand Up @@ -89,30 +96,52 @@ async fn main() -> Result<()> {
_ => {
panic!("Either private key path or password must be set")
}
};

match args.output {
Output::Arrow => {
let res = api.exec(&args.sql).await?;
match res {
QueryResult::Arrow(a) => {
println!("{}", pretty_format_batches(&a).unwrap());
}
QueryResult::Json(j) => {
println!("{j}");
}
QueryResult::Empty => {
println!("Query finished successfully")
}
}
// add optional host
.with_host(args.host);

if args.stream {
let resp = api.exec_raw(&args.sql, true).await?;

if let RawQueryResult::Stream(mut bytes_stream) = resp {
let mut chunks = vec![];
while let Some(bytes) = bytes_stream.next().await {
chunks.push(bytes?);
}

let bytes = chunks.into_iter().flatten().collect::<Vec<u8>>();
let resp = serde_json::from_slice::<ExecResponse>(&bytes).unwrap();
let raw_query_result = api.parse_arrow_raw_response(resp).await.unwrap();
let batches = raw_query_result.deserialize_arrow().unwrap();

if let QueryResult::Arrow(a) = batches {
println!("{}", pretty_format_batches(&a).unwrap());
}
}
Output::Json => {
let res = api.exec_json(&args.sql).await?;
println!("{res}");
}
Output::Query => {
let res = api.exec_response(&args.sql).await?;
println!("{:?}", res);
} else {
match args.output {
Output::Arrow => {
let res = api.exec(&args.sql).await?;
match res {
QueryResult::Arrow(a) => {
println!("{}", pretty_format_batches(&a).unwrap());
}
QueryResult::Json(j) => {
println!("{j}");
}
QueryResult::Empty => {
println!("Query finished successfully")
}
}
}
Output::Json => {
let res = api.exec_json(&args.sql).await?;
println!("{res}");
}
Output::Query => {
let res = api.exec_response(&args.sql).await?;
println!("{:?}", res);
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion snowflake-api/examples/tracing/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ version = "0.1.0"

[dependencies]
anyhow = "1"
arrow = { version = "53", features = ["prettyprint"] }
arrow = { version = "54.2.1", features = ["prettyprint"] }
dotenv = "0.15"
snowflake-api = { path = "../../../snowflake-api" }

Expand Down
44 changes: 34 additions & 10 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
use reqwest::Response;
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
Expand Down Expand Up @@ -113,17 +114,15 @@ impl Connection {
.with(RetryTransientMiddleware::new_with_policy(retry_policy)))
}

/// Perform request of given query type with extra body or parameters
// todo: implement soft error handling
// todo: is there better way to not repeat myself?
pub async fn request<R: serde::de::DeserializeOwned>(
pub async fn send_request(
&self,
query_type: QueryType,
account_identifier: &str,
extra_get_params: &[(&str, &str)],
auth: Option<&str>,
body: impl serde::Serialize,
) -> Result<R, ConnectionError> {
host: Option<&str>,
) -> Result<Response, ConnectionError> {
let context = query_type.query_context();

let request_id = Uuid::new_v4();
Expand All @@ -144,10 +143,10 @@ impl Connection {
];
get_params.extend_from_slice(extra_get_params);

let url = format!(
"https://{}.snowflakecomputing.com/{}",
&account_identifier, context.path
);
let base_url = host
.map(str::to_string)
.unwrap_or_else(|| format!("https://{}.snowflakecomputing.com", &account_identifier));
let url = format!("{base_url}/{}", context.path);
let url = Url::parse_with_params(&url, get_params)?;

let mut headers = HeaderMap::new();
Expand All @@ -162,7 +161,6 @@ impl Connection {
headers.append(header::AUTHORIZATION, auth_val);
}

// todo: persist client to use connection polling
let resp = self
.client
.post(url)
Expand All @@ -171,6 +169,32 @@ impl Connection {
.send()
.await?;

Ok(resp)
}

/// Perform request of given query type with extra body or parameters
// todo: implement soft error handling
// todo: is there better way to not repeat myself?
pub async fn request<R: serde::de::DeserializeOwned>(
&self,
query_type: QueryType,
account_identifier: &str,
extra_get_params: &[(&str, &str)],
auth: Option<&str>,
body: impl serde::Serialize,
host: Option<&str>,
) -> Result<R, ConnectionError> {
let resp = self
.send_request(
query_type,
account_identifier,
extra_get_params,
auth,
body,
host,
)
.await?;

Ok(resp.json::<R>().await?)
}

Expand Down
Loading