Skip to content

Commit c8c6e73

Browse files
committed
Add support for async query response
1 parent 38dcb9e commit c8c6e73

File tree

3 files changed

+83
-19
lines changed

3 files changed

+83
-19
lines changed

snowflake-api/src/connection.rs

+35-16
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ pub enum ConnectionError {
2929
/// Container for query parameters
3030
/// This API has different endpoints and MIME types for different requests
3131
struct QueryContext {
32-
path: &'static str,
32+
path: String,
3333
accept_mime: &'static str,
34+
method: reqwest::Method
3435
}
3536

3637
pub enum QueryType {
@@ -39,30 +40,40 @@ pub enum QueryType {
3940
CloseSession,
4041
JsonQuery,
4142
ArrowQuery,
43+
ArrowQueryResult(String),
4244
}
43-
4445
impl QueryType {
45-
const fn query_context(&self) -> QueryContext {
46+
fn query_context(&self) -> QueryContext {
4647
match self {
4748
Self::LoginRequest => QueryContext {
48-
path: "session/v1/login-request",
49+
path: "session/v1/login-request".to_string(),
4950
accept_mime: "application/json",
51+
method: reqwest::Method::POST,
5052
},
5153
Self::TokenRequest => QueryContext {
52-
path: "/session/token-request",
54+
path: "/session/token-request".to_string(),
5355
accept_mime: "application/snowflake",
56+
method: reqwest::Method::POST,
5457
},
5558
Self::CloseSession => QueryContext {
56-
path: "session",
59+
path: "session".to_string(),
5760
accept_mime: "application/snowflake",
61+
method: reqwest::Method::POST,
5862
},
5963
Self::JsonQuery => QueryContext {
60-
path: "queries/v1/query-request",
64+
path: "queries/v1/query-request".to_string(),
6165
accept_mime: "application/json",
66+
method: reqwest::Method::POST,
6267
},
6368
Self::ArrowQuery => QueryContext {
64-
path: "queries/v1/query-request",
69+
path: "queries/v1/query-request".to_string(),
70+
accept_mime: "application/snowflake",
71+
method: reqwest::Method::POST,
72+
},
73+
Self::ArrowQueryResult(query_result_url) => QueryContext {
74+
path: query_result_url.to_string(),
6575
accept_mime: "application/snowflake",
76+
method: reqwest::Method::GET,
6677
},
6778
}
6879
}
@@ -163,14 +174,22 @@ impl Connection {
163174
}
164175

165176
// todo: persist client to use connection polling
166-
let resp = self
167-
.client
168-
.post(url)
169-
.headers(headers)
170-
.json(&body)
171-
.send()
172-
.await?;
173-
177+
let resp = match context.method {
178+
reqwest::Method::POST => self
179+
.client
180+
.post(url)
181+
.headers(headers)
182+
.json(&body)
183+
.send()
184+
.await?,
185+
reqwest::Method::GET => self
186+
.client
187+
.get(url)
188+
.headers(headers)
189+
.send()
190+
.await?,
191+
_ => panic!("Unsupported method"),
192+
};
174193
Ok(resp.json::<R>().await?)
175194
}
176195

snowflake-api/src/lib.rs

+38-2
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ impl SnowflakeApi {
407407

408408
match resp {
409409
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
410+
ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse),
410411
ExecResponse::PutGet(pg) => put::put(pg).await,
411412
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
412413
e.data.error_code,
@@ -430,14 +431,21 @@ impl SnowflakeApi {
430431
}
431432

432433
async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
433-
let resp = self
434+
let mut resp = self
434435
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
435436
.await?;
436437
log::debug!("Got query response: {:?}", resp);
437438

439+
if let ExecResponse::QueryAsync(data) = &resp {
440+
log::debug!("Got async exec response");
441+
resp = self.get_async_exec_result(&data.data.get_result_url).await?;
442+
log::debug!("Got result for async exec: {:?}", resp);
443+
}
444+
438445
let resp = match resp {
439446
// processable response
440447
ExecResponse::Query(qr) => Ok(qr),
448+
ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse),
441449
ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
442450
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
443451
e.data.error_code,
@@ -504,10 +512,38 @@ impl SnowflakeApi {
504512
&self.account_identifier,
505513
&[],
506514
Some(&parts.session_token_auth_header),
507-
body,
515+
Some(body),
508516
)
509517
.await?;
510518

511519
Ok(resp)
512520
}
521+
522+
pub async fn get_async_exec_result(&self, query_result_url: &String) -> Result<ExecResponse, SnowflakeApiError>{
523+
log::debug!("Getting async exec result: {}", query_result_url);
524+
525+
let mut delay = 1; // Initial delay of 1 second
526+
527+
loop {
528+
let parts = self.session.get_token().await?;
529+
let resp = self
530+
.connection
531+
.request::<ExecResponse>(
532+
QueryType::ArrowQueryResult(query_result_url.to_string()),
533+
&self.account_identifier,
534+
&[],
535+
Some(&parts.session_token_auth_header),
536+
serde_json::Value::default()
537+
)
538+
.await?;
539+
540+
if let ExecResponse::QueryAsync(_) = &resp {
541+
// simple exponential retry with a maximum wait time of 5 seconds
542+
tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
543+
delay = (delay * 2).min(5); // cap delay to 5 seconds
544+
} else {
545+
return Ok(resp);
546+
}
547+
};
548+
}
513549
}

snowflake-api/src/responses.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use serde::Deserialize;
77
#[serde(untagged)]
88
pub enum ExecResponse {
99
Query(QueryExecResponse),
10+
QueryAsync(QueryAsyncExecResponse),
1011
PutGet(PutGetExecResponse),
1112
Error(ExecErrorResponse),
1213
}
@@ -34,6 +35,7 @@ pub struct BaseRestResponse<D> {
3435

3536
pub type PutGetExecResponse = BaseRestResponse<PutGetResponseData>;
3637
pub type QueryExecResponse = BaseRestResponse<QueryExecResponseData>;
38+
pub type QueryAsyncExecResponse = BaseRestResponse<QueryAsyncExecResponseData>;
3739
pub type ExecErrorResponse = BaseRestResponse<ExecErrorResponseData>;
3840
pub type AuthErrorResponse = BaseRestResponse<AuthErrorResponseData>;
3941
pub type AuthenticatorResponse = BaseRestResponse<AuthenticatorResponseData>;
@@ -54,7 +56,7 @@ pub struct ExecErrorResponseData {
5456
pub pos: Option<i64>,
5557

5658
// fixme: only valid for exec query response error? present in any exec query response?
57-
pub query_id: String,
59+
pub query_id: Option<String>,
5860
pub sql_state: String,
5961
}
6062

@@ -151,6 +153,13 @@ pub struct QueryExecResponseData {
151153
// `sendResultTime`, `queryResultFormat`, `queryContext` also exist
152154
}
153155

156+
#[derive(Deserialize, Debug)]
157+
#[serde(rename_all = "camelCase")]
158+
pub struct QueryAsyncExecResponseData {
159+
pub query_id: String,
160+
pub get_result_url: String,
161+
}
162+
154163
#[derive(Deserialize, Debug)]
155164
pub struct ExecResponseRowType {
156165
pub name: String,

0 commit comments

Comments
 (0)