diff --git a/snowflake-api/Cargo.toml b/snowflake-api/Cargo.toml index 206b2b7..9614553 100644 --- a/snowflake-api/Cargo.toml +++ b/snowflake-api/Cargo.toml @@ -19,7 +19,7 @@ default = ["cert-auth"] polars = ["dep:polars-core", "dep:polars-io"] [dependencies] -arrow = "53" +arrow = "54.2.1" async-trait = "0.1" base64 = "0.22" bytes = "1" @@ -51,6 +51,7 @@ polars-io = { version = ">=0.32", features = [ glob = { version = "0.3" } object_store = { version = "0.11", features = ["aws"] } tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +flate2 = "1.0.34" [dev-dependencies] anyhow = "1" diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index e7087e1..4fbd6b4 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -8,6 +8,10 @@ use thiserror::Error; use url::Url; use uuid::Uuid; +use std::io::Read; +use flate2::bufread::GzDecoder; +use bytes; + #[derive(Error, Debug)] pub enum ConnectionError { #[error(transparent)] @@ -183,7 +187,7 @@ impl Connection { for (k, v) in headers { header_map.insert( HeaderName::from_bytes(k.as_bytes()).unwrap(), - HeaderValue::from_bytes(v.as_bytes()).unwrap(), + HeaderValue::from_bytes(v.as_bytes())?, ); } let bytes = self @@ -193,7 +197,22 @@ impl Connection { .send() .await? .bytes() - .await?; - Ok(bytes) + .await; + + match bytes { + Ok(bytes) => { + // convert from gzip to Bytes + let mut gz = GzDecoder::new(&bytes[..]); + let mut decoded_bytes = Vec::new(); + gz.read_to_end(&mut decoded_bytes).expect("Failed to decode bytes"); + let decoded = bytes::Bytes::copy_from_slice(&decoded_bytes); + + Ok(decoded) + } + Err(e) => { + Err(ConnectionError::RequestError(e)) + } + } + } } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 2f4789e..77732d9 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -403,7 +403,7 @@ impl SnowflakeApi { let resp = self .run_sql::(sql, QueryType::JsonQuery) .await?; - log::debug!("Got PUT response: {:?}", resp); + log::trace!("Got PUT response: {:?}", resp); match resp { ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse), @@ -433,16 +433,27 @@ impl SnowflakeApi { let resp = self .run_sql::(sql, QueryType::ArrowQuery) .await?; - log::debug!("Got query response: {:?}", resp); + + log::trace!("Got query response: {:?}", resp); let resp = match resp { // processable response - ExecResponse::Query(qr) => Ok(qr), - ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse), - ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( - e.data.error_code, - e.message.unwrap_or_default(), - )), + ExecResponse::Query(qr) => { + log::info!("Got a response: OK"); + Ok(qr) + }, + ExecResponse::PutGet(_) => { + log::info!("Got a response: Unexpected PUT response"); + Err(SnowflakeApiError::UnexpectedResponse) + }, + ExecResponse::Error(e) => + { + log::error!("Got a response: Error - {:?}", e); + Err(SnowflakeApiError::ApiError( + e.data.error_code, + e.message.unwrap_or_default(), + )) + }, }?; // if response was empty, base64 data is empty string diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs index 3c2f497..e4c6272 100644 --- a/snowflake-api/src/responses.rs +++ b/snowflake-api/src/responses.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; #[allow(clippy::large_enum_variant)] -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum ExecResponse { Query(QueryExecResponse), @@ -14,7 +14,7 @@ pub enum ExecResponse { // todo: add close session response, which should be just empty? // FIXME: dead_code #[allow(clippy::large_enum_variant, dead_code)] -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum AuthResponse { Login(LoginResponse), @@ -24,7 +24,7 @@ pub enum AuthResponse { Error(AuthErrorResponse), } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct BaseRestResponse { // null for auth pub code: Option, @@ -43,7 +43,7 @@ pub type RenewSessionResponse = BaseRestResponse; // Data should be always `null` on successful close session response pub type CloseSessionResponse = BaseRestResponse>; -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ExecErrorResponseData { pub age: i64, @@ -59,7 +59,7 @@ pub struct ExecErrorResponseData { pub sql_state: String, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] // FIXME: dead_code #[allow(dead_code)] @@ -68,13 +68,13 @@ pub struct AuthErrorResponseData { pub error_code: Option, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct NameValueParameter { pub name: String, pub value: serde_json::Value, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] // FIXME #[allow(dead_code)] @@ -90,7 +90,7 @@ pub struct LoginResponseData { pub validity_in_seconds: i64, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] // FIXME: dead_code #[allow(dead_code)] @@ -101,7 +101,7 @@ pub struct SessionInfo { pub role_name: String, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] // FIXME: dead_code #[allow(dead_code)] @@ -111,7 +111,7 @@ pub struct AuthenticatorResponseData { pub proof_key: String, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] // FIXME: dead_code #[allow(dead_code)] @@ -123,7 +123,7 @@ pub struct RenewSessionResponseData { pub session_id: i64, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct QueryExecResponseData { pub parameters: Vec, @@ -163,7 +163,7 @@ pub struct QueryExecResponseData { // `sendResultTime`, `queryResultFormat`, `queryContext` also exist } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct ExecResponseRowType { pub name: String, #[serde(rename = "byteLength")] @@ -178,7 +178,7 @@ pub struct ExecResponseRowType { } // fixme: is it good idea to keep this as an enum if more types could be added in future? -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum SnowflakeType { Fixed, @@ -196,7 +196,7 @@ pub enum SnowflakeType { Array, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ExecResponseChunk { pub url: String, @@ -204,7 +204,7 @@ pub struct ExecResponseChunk { pub uncompressed_size: i64, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PutGetResponseData { // `kind`, `operation` are present in Go implementation, but not in .NET @@ -233,14 +233,14 @@ pub struct PutGetResponseData { pub statement_type_id: Option, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "UPPERCASE")] pub enum CommandType { Upload, Download, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum PutGetStageInfo { Aws(AwsPutGetStageInfo), @@ -248,7 +248,7 @@ pub enum PutGetStageInfo { Gcs(GcsPutGetStageInfo), } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct AwsPutGetStageInfo { pub location_type: String, @@ -259,7 +259,7 @@ pub struct AwsPutGetStageInfo { pub end_point: Option, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub struct AwsCredentials { pub aws_key_id: String, @@ -269,7 +269,7 @@ pub struct AwsCredentials { pub aws_key: String, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GcsPutGetStageInfo { pub location_type: String, @@ -279,13 +279,13 @@ pub struct GcsPutGetStageInfo { pub presigned_url: String, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub struct GcsCredentials { pub gcs_access_token: String, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct AzurePutGetStageInfo { pub location_type: String, @@ -294,20 +294,20 @@ pub struct AzurePutGetStageInfo { pub creds: AzureCredentials, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub struct AzureCredentials { pub azure_sas_token: String, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum EncryptionMaterialVariant { Single(PutGetEncryptionMaterial), Multiple(Vec), } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PutGetEncryptionMaterial { // base64 encoded diff --git a/snowflake-api/src/session.rs b/snowflake-api/src/session.rs index 90acaaf..6ac6c1e 100644 --- a/snowflake-api/src/session.rs +++ b/snowflake-api/src/session.rs @@ -338,10 +338,11 @@ impl Session { body, ) .await?; - log::debug!("Auth response: {:?}", resp); + log::trace!("Auth response: {:?}", resp); match resp { AuthResponse::Login(lr) => { + log::debug!("Authenticated successfully"); let session_token = AuthToken::new(&lr.data.token, lr.data.validity_in_seconds); let master_token = AuthToken::new(&lr.data.master_token, lr.data.master_validity_in_seconds);