From a69419e7c17b1f4883e666f01357ad2f191d5542 Mon Sep 17 00:00:00 2001 From: Yang Xiufeng Date: Tue, 22 Oct 2024 21:13:06 +0800 Subject: [PATCH] feat: support query forward. --- core/Cargo.toml | 1 + core/src/client.rs | 45 +++++++++++++++++++++++++++++++++++------- core/src/response.rs | 1 + core/src/session.rs | 2 ++ driver/src/rest_api.rs | 26 +++++++++++++++--------- 5 files changed, 59 insertions(+), 16 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index 9ce5561a..b8a3f263 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -31,6 +31,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.34", features = ["macros"] } tokio-retry = "0.3" tokio-util = { version = "0.7", features = ["io-util"] } +parking_lot = "0.12.3" url = { version = "2.5", default-features = false } uuid = { version = "1.6", features = ["v4"] } diff --git a/core/src/client.rs b/core/src/client.rs index 15d247c1..3a5776ff 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -31,15 +31,17 @@ use url::Url; use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth}; use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader}; +use crate::session::SessionState; use crate::stage::StageLocation; use crate::{ error::{Error, Result}, - request::{PaginationConfig, QueryRequest, SessionState, StageAttachmentConfig}, + request::{PaginationConfig, QueryRequest, StageAttachmentConfig}, response::{QueryError, QueryResponse}, }; const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID"; const HEADER_TENANT: &str = "X-DATABEND-TENANT"; +const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE"; const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE"; const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME"; const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT"; @@ -76,6 +78,7 @@ pub struct APIClient { tls_ca_file: Option, presign: PresignMode, + last_node_id: Arc>>, } impl APIClient { @@ -283,6 +286,13 @@ impl APIClient { } } + pub fn set_last_node_id(&self, node_id: String) { + *self.last_node_id.lock() = Some(node_id) + } + pub fn last_node_id(&self) -> Option { + self.last_node_id.lock().clone() + } + pub fn handle_warnings(&self, resp: &QueryResponse) { if let Some(warnings) = &resp.warnings { for w in warnings { @@ -297,12 +307,18 @@ impl APIClient { self.route_hint.next(); } let session_state = self.session_state().await; + let need_sticky = session_state.need_sticky.unwrap_or(false); let req = QueryRequest::new(sql) .with_pagination(self.make_pagination()) .with_session(Some(session_state)); let endpoint = self.endpoint.join("v1/query")?; let query_id = self.gen_query_id(); - let headers = self.make_headers(&query_id).await?; + let mut headers = self.make_headers(&query_id).await?; + if need_sticky { + if let Some(node_id) = self.last_node_id() { + headers.insert(HEADER_STICKY_NODE, node_id.parse()?); + } + } let mut builder = self.cli.post(endpoint.clone()).json(&req); builder = self.auth.wrap(builder).await?; let mut resp = builder.headers(headers.clone()).send().await?; @@ -344,7 +360,12 @@ impl APIClient { Ok(result) } - pub async fn query_page(&self, query_id: &str, next_uri: &str) -> Result { + pub async fn query_page( + &self, + query_id: &str, + next_uri: &str, + node_id: &str, + ) -> Result { info!("query page: {}", next_uri); let endpoint = self.endpoint.join(next_uri)?; let headers = self.make_headers(query_id).await?; @@ -354,6 +375,7 @@ impl APIClient { builder = self.auth.wrap(builder).await?; builder .headers(headers.clone()) + .header(HEADER_STICKY_NODE, node_id) .timeout(self.page_request_timeout) .send() .await @@ -410,12 +432,14 @@ impl APIClient { pub async fn wait_for_query(&self, resp: QueryResponse) -> Result { info!("wait for query: {}", resp.id); + let node_id = resp.node_id.clone(); + self.set_last_node_id(node_id.clone()); if let Some(next_uri) = &resp.next_uri { let schema = resp.schema; let mut data = resp.data; - let mut resp = self.query_page(&resp.id, next_uri).await?; + let mut resp = self.query_page(&resp.id, next_uri, &node_id).await?; while let Some(next_uri) = &resp.next_uri { - resp = self.query_page(&resp.id, next_uri).await?; + resp = self.query_page(&resp.id, next_uri, &node_id).await?; data.append(&mut resp.data); } resp.schema = schema; @@ -487,6 +511,8 @@ impl APIClient { sql, file_format_options, copy_options ); let session_state = self.session_state().await; + let need_sticky = session_state.need_sticky.unwrap_or(false); + let stage_attachment = Some(StageAttachmentConfig { location: stage, file_format_options: Some(file_format_options), @@ -498,8 +524,12 @@ impl APIClient { .with_stage_attachment(stage_attachment); let endpoint = self.endpoint.join("v1/query")?; let query_id = self.gen_query_id(); - let headers = self.make_headers(&query_id).await?; - + let mut headers = self.make_headers(&query_id).await?; + if need_sticky { + if let Some(node_id) = self.last_node_id() { + headers.insert(HEADER_STICKY_NODE, node_id.parse()?); + } + } let mut builder = self.cli.post(endpoint.clone()).json(&req); builder = self.auth.wrap(builder).await?; let mut resp = builder.headers(headers.clone()).send().await?; @@ -626,6 +656,7 @@ impl Default for APIClient { tls_ca_file: None, presign: PresignMode::Auto, route_hint: Arc::new(RouteHintGenerator::new()), + last_node_id: Arc::new(Default::default()), } } } diff --git a/core/src/response.rs b/core/src/response.rs index 7f5ce65d..d89c82c4 100644 --- a/core/src/response.rs +++ b/core/src/response.rs @@ -55,6 +55,7 @@ pub struct SchemaField { #[derive(Deserialize, Debug)] pub struct QueryResponse { pub id: String, + pub node_id: String, pub session_id: Option, pub session: Option, pub schema: Vec, diff --git a/core/src/session.rs b/core/src/session.rs index ab78f2ee..139ea0e2 100644 --- a/core/src/session.rs +++ b/core/src/session.rs @@ -27,6 +27,8 @@ pub struct SessionState { pub secondary_roles: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub txn_state: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub need_sticky: Option, // hide fields of no interest (but need to send back to server in next query) #[serde(flatten)] diff --git a/driver/src/rest_api.rs b/driver/src/rest_api.rs index 0a2a8764..d589ba60 100644 --- a/driver/src/rest_api.rs +++ b/driver/src/rest_api.rs @@ -58,8 +58,12 @@ impl Connection for RestAPIConnection { async fn exec(&self, sql: &str) -> Result { info!("exec: {}", sql); let mut resp = self.client.start_query(sql).await?; + let node_id = resp.node_id.clone(); while let Some(next_uri) = resp.next_uri { - resp = self.client.query_page(&resp.id, &next_uri).await?; + resp = self + .client + .query_page(&resp.id, &next_uri, &node_id) + .await?; } Ok(resp.stats.progresses.write_progress.rows as i64) } @@ -201,14 +205,16 @@ impl<'o> RestAPIConnection { Ok(Self { client }) } - async fn wait_for_schema(&self, pre: QueryResponse) -> Result { - if !pre.data.is_empty() || !pre.schema.is_empty() { - return Ok(pre); + async fn wait_for_schema(&self, resp: QueryResponse) -> Result { + if !resp.data.is_empty() || !resp.schema.is_empty() { + return Ok(resp); } - let mut result = pre; - // preserve schema since it is no included in the final response + let node_id = resp.node_id.clone(); + self.client.set_last_node_id(node_id.clone()); + let mut result = resp; + // preserve schema since it is not included in the final response while let Some(next_uri) = result.next_uri { - result = self.client.query_page(&result.id, &next_uri).await?; + result = self.client.query_page(&result.id, &next_uri, &node_id).await?; if !result.data.is_empty() || !result.schema.is_empty() { break; } @@ -240,6 +246,7 @@ pub struct RestAPIRows { data: VecDeque>>, stats: Option, query_id: String, + node_id: String, next_uri: Option, next_page: Option, } @@ -250,6 +257,7 @@ impl RestAPIRows { let rows = Self { client, query_id: resp.id, + node_id: resp.node_id, next_uri: resp.next_uri, schema: Arc::new(schema.clone()), data: resp.data.into(), @@ -278,7 +286,6 @@ impl Stream for RestAPIRows { if self.schema.fields().is_empty() { self.schema = Arc::new(resp.schema.try_into()?); } - self.query_id = resp.id; self.next_uri = resp.next_uri; self.next_page = None; self.stats = Some(ServerStats::from(resp.stats)); @@ -295,9 +302,10 @@ impl Stream for RestAPIRows { let client = self.client.clone(); let next_uri = next_uri.clone(); let query_id = self.query_id.clone(); + let node_id = self.node_id.clone(); self.next_page = Some(Box::pin(async move { client - .query_page(&query_id, &next_uri) + .query_page(&query_id, &next_uri, &node_id) .await .map_err(|e| e.into()) }));