Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support query forward. #487

Merged
merged 4 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cli/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ CARGO_TARGET_DIR=${CARGO_TARGET_DIR:-./target}
DATABEND_USER=${DATABEND_USER:-root}
DATABEND_PASSWORD=${DATABEND_PASSWORD:-}
DATABEND_HOST=${DATABEND_HOST:-localhost}
DATABEND_PORT=${DATABEND_PORT:-8000}

TEST_HANDLER=$1

Expand All @@ -32,7 +33,7 @@ case $TEST_HANDLER in
;;
"http")
echo "==> Testing REST API handler"
export BENDSQL_DSN="databend+http://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:8000/?sslmode=disable&presign=on"
export BENDSQL_DSN="databend+http://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:${DATABEND_PORT}/?sslmode=disable&presign=on"
;;
*)
echo "Usage: $0 [flight|http]"
Expand Down
1 change: 1 addition & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] }
tokio = { version = "1.34", features = ["macros"] }
tokio-retry = "0.3"
tokio-util = { version = "0.7", features = ["io-util"] }
parking_lot = "0.12.3"
url = { version = "2.5", default-features = false }
uuid = { version = "1.6", features = ["v4"] }

Expand Down
45 changes: 38 additions & 7 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ use url::Url;

use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
use crate::session::SessionState;
use crate::stage::StageLocation;
use crate::{
error::{Error, Result},
request::{PaginationConfig, QueryRequest, SessionState, StageAttachmentConfig},
request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
response::{QueryError, QueryResponse},
};

const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
Expand Down Expand Up @@ -76,6 +78,7 @@ pub struct APIClient {
tls_ca_file: Option<String>,

presign: PresignMode,
last_node_id: Arc<parking_lot::Mutex<Option<String>>>,
}

impl APIClient {
Expand Down Expand Up @@ -283,6 +286,13 @@ impl APIClient {
}
}

/// Records `node_id` as the last node id known to this client.
///
/// The value is kept behind a mutex, so it can be updated through a shared
/// `&self` reference; any previously stored id is replaced.
pub fn set_last_node_id(&self, node_id: String) {
    let mut slot = self.last_node_id.lock();
    *slot = Some(node_id);
}
/// Returns an owned copy of the last recorded node id.
///
/// Yields `None` when no id is currently stored.
pub fn last_node_id(&self) -> Option<String> {
    let current = self.last_node_id.lock();
    Option::clone(&current)
}

pub fn handle_warnings(&self, resp: &QueryResponse) {
if let Some(warnings) = &resp.warnings {
for w in warnings {
Expand All @@ -297,12 +307,18 @@ impl APIClient {
self.route_hint.next();
}
let session_state = self.session_state().await;
let need_sticky = session_state.need_sticky.unwrap_or(false);
let req = QueryRequest::new(sql)
.with_pagination(self.make_pagination())
.with_session(Some(session_state));
let endpoint = self.endpoint.join("v1/query")?;
let query_id = self.gen_query_id();
let headers = self.make_headers(&query_id).await?;
let mut headers = self.make_headers(&query_id).await?;
if need_sticky {
if let Some(node_id) = self.last_node_id() {
headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
}
}
let mut builder = self.cli.post(endpoint.clone()).json(&req);
builder = self.auth.wrap(builder).await?;
let mut resp = builder.headers(headers.clone()).send().await?;
Expand Down Expand Up @@ -344,7 +360,12 @@ impl APIClient {
Ok(result)
}

pub async fn query_page(&self, query_id: &str, next_uri: &str) -> Result<QueryResponse> {
pub async fn query_page(
&self,
query_id: &str,
next_uri: &str,
node_id: &str,
) -> Result<QueryResponse> {
info!("query page: {}", next_uri);
let endpoint = self.endpoint.join(next_uri)?;
let headers = self.make_headers(query_id).await?;
Expand All @@ -354,6 +375,7 @@ impl APIClient {
builder = self.auth.wrap(builder).await?;
builder
.headers(headers.clone())
.header(HEADER_STICKY_NODE, node_id)
.timeout(self.page_request_timeout)
.send()
.await
Expand Down Expand Up @@ -410,12 +432,14 @@ impl APIClient {

pub async fn wait_for_query(&self, resp: QueryResponse) -> Result<QueryResponse> {
info!("wait for query: {}", resp.id);
let node_id = resp.node_id.clone();
self.set_last_node_id(node_id.clone());
if let Some(next_uri) = &resp.next_uri {
let schema = resp.schema;
let mut data = resp.data;
let mut resp = self.query_page(&resp.id, next_uri).await?;
let mut resp = self.query_page(&resp.id, next_uri, &node_id).await?;
while let Some(next_uri) = &resp.next_uri {
resp = self.query_page(&resp.id, next_uri).await?;
resp = self.query_page(&resp.id, next_uri, &node_id).await?;
data.append(&mut resp.data);
}
resp.schema = schema;
Expand Down Expand Up @@ -487,6 +511,8 @@ impl APIClient {
sql, file_format_options, copy_options
);
let session_state = self.session_state().await;
let need_sticky = session_state.need_sticky.unwrap_or(false);

let stage_attachment = Some(StageAttachmentConfig {
location: stage,
file_format_options: Some(file_format_options),
Expand All @@ -498,8 +524,12 @@ impl APIClient {
.with_stage_attachment(stage_attachment);
let endpoint = self.endpoint.join("v1/query")?;
let query_id = self.gen_query_id();
let headers = self.make_headers(&query_id).await?;

let mut headers = self.make_headers(&query_id).await?;
if need_sticky {
if let Some(node_id) = self.last_node_id() {
headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
}
}
let mut builder = self.cli.post(endpoint.clone()).json(&req);
builder = self.auth.wrap(builder).await?;
let mut resp = builder.headers(headers.clone()).send().await?;
Expand Down Expand Up @@ -626,6 +656,7 @@ impl Default for APIClient {
tls_ca_file: None,
presign: PresignMode::Auto,
route_hint: Arc::new(RouteHintGenerator::new()),
last_node_id: Arc::new(Default::default()),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod error;
pub mod presign;
pub mod request;
pub mod response;
pub mod session;
pub mod stage;

pub use client::APIClient;
51 changes: 7 additions & 44 deletions core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};
use std::collections::BTreeMap;

use crate::session::SessionState;
use serde::{Deserialize, Serialize};

/// Server identity record: an `id` and a `start_time`, both carried as strings.
///
/// Derives serde `Serialize`/`Deserialize`, so the field names are the
/// serialized keys.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub struct ServerInfo {
// NOTE(review): presumably the identifier of the server/node — confirm against the server API.
pub id: String,
// NOTE(review): kept as an opaque string; the timestamp format is not visible here.
pub start_time: String,
}
#[derive(Deserialize, Serialize, Debug, Default, Clone)]
pub struct SessionState {
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub settings: Option<BTreeMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub secondary_roles: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub txn_state: Option<String>,

// hide fields of no interest (but need to send back to server in next query)
#[serde(flatten)]
additional_fields: HashMap<String, serde_json::Value>,
}

impl SessionState {
pub fn with_settings(mut self, settings: Option<BTreeMap<String, String>>) -> Self {
self.settings = settings;
self
}

pub fn with_database(mut self, database: Option<String>) -> Self {
self.database = database;
self
}

pub fn with_role(mut self, role: Option<String>) -> Self {
self.role = role;
self
}
}

#[derive(Serialize, Debug)]
pub struct QueryRequest<'a> {
Expand Down Expand Up @@ -122,14 +90,9 @@ mod test {
#[test]
fn build_request() -> Result<()> {
let req = QueryRequest::new("select 1")
.with_session(Some(SessionState {
database: Some("default".to_string()),
settings: Some(BTreeMap::new()),
role: None,
secondary_roles: None,
txn_state: None,
additional_fields: Default::default(),
}))
.with_session(Some(
SessionState::default().with_database(Some("default".to_string())),
))
.with_pagination(Some(PaginationConfig {
wait_time_secs: Some(1),
max_rows_in_buffer: Some(1),
Expand All @@ -142,7 +105,7 @@ mod test {
}));
assert_eq!(
serde_json::to_string(&req)?,
r#"{"session":{"database":"default","settings":{}},"sql":"select 1","pagination":{"wait_time_secs":1,"max_rows_in_buffer":1,"max_rows_per_page":1},"stage_attachment":{"location":"@~/my_location"}}"#
r#"{"session":{"database":"default"},"sql":"select 1","pagination":{"wait_time_secs":1,"max_rows_in_buffer":1,"max_rows_per_page":1},"stage_attachment":{"location":"@~/my_location"}}"#
);
Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion core/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use serde::Deserialize;

use crate::request::SessionState;
use crate::session::SessionState;

#[derive(Deserialize, Debug)]
pub struct QueryError {
Expand Down Expand Up @@ -55,6 +55,7 @@ pub struct SchemaField {
#[derive(Deserialize, Debug)]
pub struct QueryResponse {
pub id: String,
pub node_id: String,
pub session_id: Option<String>,
pub session: Option<SessionState>,
pub schema: Vec<SchemaField>,
Expand Down
53 changes: 53 additions & 0 deletions core/src/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};

/// Session state carried in query requests and returned in query responses.
///
/// Every `Option` field is left out of the serialized form when it is `None`.
/// Keys that do not match a named field are collected into
/// `additional_fields` and written back out on serialization.
#[derive(Deserialize, Serialize, Debug, Default, Clone)]
pub struct SessionState {
/// Database name for the session, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<String>,
/// Session settings as string key/value pairs, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub settings: Option<BTreeMap<String, String>>,
/// Role name for the session, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
/// Secondary role names, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub secondary_roles: Option<Vec<String>>,
/// Transaction state, kept as an opaque string.
// NOTE(review): the set of valid values is not visible here — confirm against the server API.
#[serde(skip_serializing_if = "Option::is_none")]
pub txn_state: Option<String>,
/// Whether follow-up queries should be sent to the same node.
// NOTE(review): the client reads this with `unwrap_or(false)` before adding
// the sticky-node header; the exact server-side meaning should be confirmed.
#[serde(skip_serializing_if = "Option::is_none")]
pub need_sticky: Option<bool>,

// Fields this client does not interpret; they are retained so they can be
// sent back to the server with the next query.
#[serde(flatten)]
additional_fields: HashMap<String, serde_json::Value>,
}

impl SessionState {
    /// Consumes `self` and returns it with `settings` replaced by the given value.
    pub fn with_settings(self, settings: Option<BTreeMap<String, String>>) -> Self {
        Self { settings, ..self }
    }

    /// Consumes `self` and returns it with `database` replaced by the given value.
    pub fn with_database(self, database: Option<String>) -> Self {
        Self { database, ..self }
    }

    /// Consumes `self` and returns it with `role` replaced by the given value.
    pub fn with_role(self, role: Option<String>) -> Self {
        Self { role, ..self }
    }
}
29 changes: 20 additions & 9 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,12 @@ impl Connection for RestAPIConnection {
async fn exec(&self, sql: &str) -> Result<i64> {
info!("exec: {}", sql);
let mut resp = self.client.start_query(sql).await?;
let node_id = resp.node_id.clone();
while let Some(next_uri) = resp.next_uri {
resp = self.client.query_page(&resp.id, &next_uri).await?;
resp = self
.client
.query_page(&resp.id, &next_uri, &node_id)
.await?;
}
Ok(resp.stats.progresses.write_progress.rows as i64)
}
Expand Down Expand Up @@ -201,14 +205,19 @@ impl<'o> RestAPIConnection {
Ok(Self { client })
}

async fn wait_for_schema(&self, pre: QueryResponse) -> Result<QueryResponse> {
if !pre.data.is_empty() || !pre.schema.is_empty() {
return Ok(pre);
async fn wait_for_schema(&self, resp: QueryResponse) -> Result<QueryResponse> {
if !resp.data.is_empty() || !resp.schema.is_empty() {
return Ok(resp);
}
let mut result = pre;
// preserve schema since it is no included in the final response
let node_id = resp.node_id.clone();
self.client.set_last_node_id(node_id.clone());
let mut result = resp;
// preserve schema since it is not included in the final response
while let Some(next_uri) = result.next_uri {
result = self.client.query_page(&result.id, &next_uri).await?;
result = self
.client
.query_page(&result.id, &next_uri, &node_id)
.await?;
if !result.data.is_empty() || !result.schema.is_empty() {
break;
}
Expand Down Expand Up @@ -240,6 +249,7 @@ pub struct RestAPIRows {
data: VecDeque<Vec<Option<String>>>,
stats: Option<ServerStats>,
query_id: String,
node_id: String,
next_uri: Option<String>,
next_page: Option<PageFut>,
}
Expand All @@ -250,6 +260,7 @@ impl RestAPIRows {
let rows = Self {
client,
query_id: resp.id,
node_id: resp.node_id,
next_uri: resp.next_uri,
schema: Arc::new(schema.clone()),
data: resp.data.into(),
Expand Down Expand Up @@ -278,7 +289,6 @@ impl Stream for RestAPIRows {
if self.schema.fields().is_empty() {
self.schema = Arc::new(resp.schema.try_into()?);
}
self.query_id = resp.id;
self.next_uri = resp.next_uri;
self.next_page = None;
self.stats = Some(ServerStats::from(resp.stats));
Expand All @@ -295,9 +305,10 @@ impl Stream for RestAPIRows {
let client = self.client.clone();
let next_uri = next_uri.clone();
let query_id = self.query_id.clone();
let node_id = self.node_id.clone();
self.next_page = Some(Box::pin(async move {
client
.query_page(&query_id, &next_uri)
.query_page(&query_id, &next_uri, &node_id)
.await
.map_err(|e| e.into())
}));
Expand Down