Skip to content

Commit

Permalink
feat: support query forward.
Browse files Browse the repository at this point in the history
  • Loading branch information
youngsofun committed Oct 22, 2024
1 parent 0581d57 commit a69419e
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 16 deletions.
1 change: 1 addition & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] }
tokio = { version = "1.34", features = ["macros"] }
tokio-retry = "0.3"
tokio-util = { version = "0.7", features = ["io-util"] }
parking_lot = "0.12.3"
url = { version = "2.5", default-features = false }
uuid = { version = "1.6", features = ["v4"] }

Expand Down
45 changes: 38 additions & 7 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ use url::Url;

use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
use crate::session::SessionState;
use crate::stage::StageLocation;
use crate::{
error::{Error, Result},
request::{PaginationConfig, QueryRequest, SessionState, StageAttachmentConfig},
request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
response::{QueryError, QueryResponse},
};

const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
Expand Down Expand Up @@ -76,6 +78,7 @@ pub struct APIClient {
tls_ca_file: Option<String>,

presign: PresignMode,
last_node_id: Arc<parking_lot::Mutex<Option<String>>>,
}

impl APIClient {
Expand Down Expand Up @@ -283,6 +286,13 @@ impl APIClient {
}
}

pub fn set_last_node_id(&self, node_id: String) {
*self.last_node_id.lock() = Some(node_id)
}
pub fn last_node_id(&self) -> Option<String> {
self.last_node_id.lock().clone()
}

pub fn handle_warnings(&self, resp: &QueryResponse) {
if let Some(warnings) = &resp.warnings {
for w in warnings {
Expand All @@ -297,12 +307,18 @@ impl APIClient {
self.route_hint.next();
}
let session_state = self.session_state().await;
let need_sticky = session_state.need_sticky.unwrap_or(false);
let req = QueryRequest::new(sql)
.with_pagination(self.make_pagination())
.with_session(Some(session_state));
let endpoint = self.endpoint.join("v1/query")?;
let query_id = self.gen_query_id();
let headers = self.make_headers(&query_id).await?;
let mut headers = self.make_headers(&query_id).await?;
if need_sticky {
if let Some(node_id) = self.last_node_id() {
headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
}
}
let mut builder = self.cli.post(endpoint.clone()).json(&req);
builder = self.auth.wrap(builder).await?;
let mut resp = builder.headers(headers.clone()).send().await?;
Expand Down Expand Up @@ -344,7 +360,12 @@ impl APIClient {
Ok(result)
}

pub async fn query_page(&self, query_id: &str, next_uri: &str) -> Result<QueryResponse> {
pub async fn query_page(
&self,
query_id: &str,
next_uri: &str,
node_id: &str,
) -> Result<QueryResponse> {
info!("query page: {}", next_uri);
let endpoint = self.endpoint.join(next_uri)?;
let headers = self.make_headers(query_id).await?;
Expand All @@ -354,6 +375,7 @@ impl APIClient {
builder = self.auth.wrap(builder).await?;
builder
.headers(headers.clone())
.header(HEADER_STICKY_NODE, node_id)
.timeout(self.page_request_timeout)
.send()
.await
Expand Down Expand Up @@ -410,12 +432,14 @@ impl APIClient {

pub async fn wait_for_query(&self, resp: QueryResponse) -> Result<QueryResponse> {
info!("wait for query: {}", resp.id);
let node_id = resp.node_id.clone();
self.set_last_node_id(node_id.clone());
if let Some(next_uri) = &resp.next_uri {
let schema = resp.schema;
let mut data = resp.data;
let mut resp = self.query_page(&resp.id, next_uri).await?;
let mut resp = self.query_page(&resp.id, next_uri, &node_id).await?;
while let Some(next_uri) = &resp.next_uri {
resp = self.query_page(&resp.id, next_uri).await?;
resp = self.query_page(&resp.id, next_uri, &node_id).await?;
data.append(&mut resp.data);
}
resp.schema = schema;
Expand Down Expand Up @@ -487,6 +511,8 @@ impl APIClient {
sql, file_format_options, copy_options
);
let session_state = self.session_state().await;
let need_sticky = session_state.need_sticky.unwrap_or(false);

let stage_attachment = Some(StageAttachmentConfig {
location: stage,
file_format_options: Some(file_format_options),
Expand All @@ -498,8 +524,12 @@ impl APIClient {
.with_stage_attachment(stage_attachment);
let endpoint = self.endpoint.join("v1/query")?;
let query_id = self.gen_query_id();
let headers = self.make_headers(&query_id).await?;

let mut headers = self.make_headers(&query_id).await?;
if need_sticky {
if let Some(node_id) = self.last_node_id() {
headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
}
}
let mut builder = self.cli.post(endpoint.clone()).json(&req);
builder = self.auth.wrap(builder).await?;
let mut resp = builder.headers(headers.clone()).send().await?;
Expand Down Expand Up @@ -626,6 +656,7 @@ impl Default for APIClient {
tls_ca_file: None,
presign: PresignMode::Auto,
route_hint: Arc::new(RouteHintGenerator::new()),
last_node_id: Arc::new(Default::default()),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions core/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub struct SchemaField {
#[derive(Deserialize, Debug)]
pub struct QueryResponse {
pub id: String,
pub node_id: String,
pub session_id: Option<String>,
pub session: Option<SessionState>,
pub schema: Vec<SchemaField>,
Expand Down
2 changes: 2 additions & 0 deletions core/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub struct SessionState {
pub secondary_roles: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub txn_state: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub need_sticky: Option<bool>,

// hide fields of no interest (but need to send back to server in next query)
#[serde(flatten)]
Expand Down
26 changes: 17 additions & 9 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,12 @@ impl Connection for RestAPIConnection {
async fn exec(&self, sql: &str) -> Result<i64> {
info!("exec: {}", sql);
let mut resp = self.client.start_query(sql).await?;
let node_id = resp.node_id.clone();
while let Some(next_uri) = resp.next_uri {
resp = self.client.query_page(&resp.id, &next_uri).await?;
resp = self
.client
.query_page(&resp.id, &next_uri, &node_id)
.await?;
}
Ok(resp.stats.progresses.write_progress.rows as i64)
}
Expand Down Expand Up @@ -201,14 +205,16 @@ impl<'o> RestAPIConnection {
Ok(Self { client })
}

async fn wait_for_schema(&self, pre: QueryResponse) -> Result<QueryResponse> {
if !pre.data.is_empty() || !pre.schema.is_empty() {
return Ok(pre);
async fn wait_for_schema(&self, resp: QueryResponse) -> Result<QueryResponse> {
if !resp.data.is_empty() || !resp.schema.is_empty() {
return Ok(resp);
}
let mut result = pre;
// preserve schema since it is no included in the final response
let node_id = resp.node_id.clone();
self.client.set_last_node_id(node_id.clone());
let mut result = resp;
// preserve schema since it is not included in the final response
while let Some(next_uri) = result.next_uri {
result = self.client.query_page(&result.id, &next_uri).await?;
result = self.client.query_page(&result.id, &next_uri, &node_id).await?;
if !result.data.is_empty() || !result.schema.is_empty() {
break;
}
Expand Down Expand Up @@ -240,6 +246,7 @@ pub struct RestAPIRows {
data: VecDeque<Vec<Option<String>>>,
stats: Option<ServerStats>,
query_id: String,
node_id: String,
next_uri: Option<String>,
next_page: Option<PageFut>,
}
Expand All @@ -250,6 +257,7 @@ impl RestAPIRows {
let rows = Self {
client,
query_id: resp.id,
node_id: resp.node_id,
next_uri: resp.next_uri,
schema: Arc::new(schema.clone()),
data: resp.data.into(),
Expand Down Expand Up @@ -278,7 +286,6 @@ impl Stream for RestAPIRows {
if self.schema.fields().is_empty() {
self.schema = Arc::new(resp.schema.try_into()?);
}
self.query_id = resp.id;
self.next_uri = resp.next_uri;
self.next_page = None;
self.stats = Some(ServerStats::from(resp.stats));
Expand All @@ -295,9 +302,10 @@ impl Stream for RestAPIRows {
let client = self.client.clone();
let next_uri = next_uri.clone();
let query_id = self.query_id.clone();
let node_id = self.node_id.clone();
self.next_page = Some(Box::pin(async move {
client
.query_page(&query_id, &next_uri)
.query_page(&query_id, &next_uri, &node_id)
.await
.map_err(|e| e.into())
}));
Expand Down

0 comments on commit a69419e

Please sign in to comment.