diff --git a/cli/test.sh b/cli/test.sh index f4e4e3f1..0e0f3085 100755 --- a/cli/test.sh +++ b/cli/test.sh @@ -20,6 +20,7 @@ CARGO_TARGET_DIR=${CARGO_TARGET_DIR:-./target} DATABEND_USER=${DATABEND_USER:-root} DATABEND_PASSWORD=${DATABEND_PASSWORD:-} DATABEND_HOST=${DATABEND_HOST:-localhost} +DATABEND_PORT=${DATABEND_PORT:-8000} TEST_HANDLER=$1 @@ -32,7 +33,7 @@ case $TEST_HANDLER in ;; "http") echo "==> Testing REST API handler" - export BENDSQL_DSN="databend+http://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:8000/?sslmode=disable&presign=on" + export BENDSQL_DSN="databend+http://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:${DATABEND_PORT}/?sslmode=disable&presign=on" ;; *) echo "Usage: $0 [flight|http]" diff --git a/core/Cargo.toml b/core/Cargo.toml index 9ce5561a..b8a3f263 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -31,6 +31,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.34", features = ["macros"] } tokio-retry = "0.3" tokio-util = { version = "0.7", features = ["io-util"] } +parking_lot = "0.12.3" url = { version = "2.5", default-features = false } uuid = { version = "1.6", features = ["v4"] } diff --git a/core/src/client.rs b/core/src/client.rs index 15d247c1..3a5776ff 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -31,15 +31,17 @@ use url::Url; use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth}; use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader}; +use crate::session::SessionState; use crate::stage::StageLocation; use crate::{ error::{Error, Result}, - request::{PaginationConfig, QueryRequest, SessionState, StageAttachmentConfig}, + request::{PaginationConfig, QueryRequest, StageAttachmentConfig}, response::{QueryError, QueryResponse}, }; const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID"; const HEADER_TENANT: &str = "X-DATABEND-TENANT"; +const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE"; const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE"; const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME"; const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT"; @@ -76,6 +78,7 @@ pub struct APIClient { tls_ca_file: Option, presign: PresignMode, + last_node_id: Arc>>, } impl APIClient { @@ -283,6 +286,13 @@ impl APIClient { } } + pub fn set_last_node_id(&self, node_id: String) { + *self.last_node_id.lock() = Some(node_id) + } + pub fn last_node_id(&self) -> Option { + self.last_node_id.lock().clone() + } + pub fn handle_warnings(&self, resp: &QueryResponse) { if let Some(warnings) = &resp.warnings { for w in warnings { @@ -297,12 +307,18 @@ impl APIClient { self.route_hint.next(); } let session_state = self.session_state().await; + let need_sticky = session_state.need_sticky.unwrap_or(false); let req = QueryRequest::new(sql) .with_pagination(self.make_pagination()) .with_session(Some(session_state)); let endpoint = self.endpoint.join("v1/query")?; let query_id = self.gen_query_id(); - let headers = self.make_headers(&query_id).await?; + let mut headers = self.make_headers(&query_id).await?; + if need_sticky { + if let Some(node_id) = self.last_node_id() { + headers.insert(HEADER_STICKY_NODE, node_id.parse()?); + } + } let mut builder = self.cli.post(endpoint.clone()).json(&req); builder = self.auth.wrap(builder).await?; let mut resp = builder.headers(headers.clone()).send().await?; @@ -344,7 +360,12 @@ impl APIClient { Ok(result) } - pub async fn query_page(&self, query_id: &str, next_uri: &str) -> Result { + pub async fn query_page( + &self, + query_id: &str, + next_uri: &str, + node_id: &str, + ) -> Result { info!("query page: {}", next_uri); let endpoint = self.endpoint.join(next_uri)?; let headers = self.make_headers(query_id).await?; @@ -354,6 +375,7 @@ impl APIClient { builder = self.auth.wrap(builder).await?; builder .headers(headers.clone()) + .header(HEADER_STICKY_NODE, node_id) .timeout(self.page_request_timeout) .send() .await @@ -410,12 +432,14 @@ impl APIClient { pub async fn wait_for_query(&self, resp: QueryResponse) -> Result { info!("wait for query: {}", resp.id); + let node_id = resp.node_id.clone(); + self.set_last_node_id(node_id.clone()); if let Some(next_uri) = &resp.next_uri { let schema = resp.schema; let mut data = resp.data; - let mut resp = self.query_page(&resp.id, next_uri).await?; + let mut resp = self.query_page(&resp.id, next_uri, &node_id).await?; while let Some(next_uri) = &resp.next_uri { - resp = self.query_page(&resp.id, next_uri).await?; + resp = self.query_page(&resp.id, next_uri, &node_id).await?; data.append(&mut resp.data); } resp.schema = schema; @@ -487,6 +511,8 @@ impl APIClient { sql, file_format_options, copy_options ); let session_state = self.session_state().await; + let need_sticky = session_state.need_sticky.unwrap_or(false); + let stage_attachment = Some(StageAttachmentConfig { location: stage, file_format_options: Some(file_format_options), @@ -498,8 +524,12 @@ impl APIClient { .with_stage_attachment(stage_attachment); let endpoint = self.endpoint.join("v1/query")?; let query_id = self.gen_query_id(); - let headers = self.make_headers(&query_id).await?; - + let mut headers = self.make_headers(&query_id).await?; + if need_sticky { + if let Some(node_id) = self.last_node_id() { + headers.insert(HEADER_STICKY_NODE, node_id.parse()?); + } + } let mut builder = self.cli.post(endpoint.clone()).json(&req); builder = self.auth.wrap(builder).await?; let mut resp = builder.headers(headers.clone()).send().await?; @@ -626,6 +656,7 @@ impl Default for APIClient { tls_ca_file: None, presign: PresignMode::Auto, route_hint: Arc::new(RouteHintGenerator::new()), + last_node_id: Arc::new(Default::default()), } } } diff --git a/core/src/lib.rs b/core/src/lib.rs index e2652c74..cff92894 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -19,6 +19,7 @@ pub mod error; pub mod presign; pub mod request; pub mod response; +pub mod session; pub mod stage; pub use client::APIClient; diff --git a/core/src/request.rs b/core/src/request.rs index 90f12ef2..65d0585e 100644 --- a/core/src/request.rs +++ b/core/src/request.rs @@ -12,48 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeMap, HashMap}; +use std::collections::BTreeMap; +use crate::session::SessionState; use serde::{Deserialize, Serialize}; + #[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] pub struct ServerInfo { pub id: String, pub start_time: String, } -#[derive(Deserialize, Serialize, Debug, Default, Clone)] -pub struct SessionState { - #[serde(skip_serializing_if = "Option::is_none")] - pub database: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub settings: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub role: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub secondary_roles: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub txn_state: Option, - - // hide fields of no interest (but need to send back to server in next query) - #[serde(flatten)] - additional_fields: HashMap, -} - -impl SessionState { - pub fn with_settings(mut self, settings: Option>) -> Self { - self.settings = settings; - self - } - - pub fn with_database(mut self, database: Option) -> Self { - self.database = database; - self - } - - pub fn with_role(mut self, role: Option) -> Self { - self.role = role; - self - } -} #[derive(Serialize, Debug)] pub struct QueryRequest<'a> { @@ -122,14 +90,9 @@ mod test { #[test] fn build_request() -> Result<()> { let req = QueryRequest::new("select 1") - .with_session(Some(SessionState { - database: Some("default".to_string()), - settings: Some(BTreeMap::new()), - role: None, - secondary_roles: None, - txn_state: None, - additional_fields: Default::default(), - })) + .with_session(Some( + SessionState::default().with_database(Some("default".to_string())), + )) .with_pagination(Some(PaginationConfig { wait_time_secs: Some(1), max_rows_in_buffer: Some(1), @@ -142,7 +105,7 @@ mod test { })); assert_eq!( serde_json::to_string(&req)?, - r#"{"session":{"database":"default","settings":{}},"sql":"select 1","pagination":{"wait_time_secs":1,"max_rows_in_buffer":1,"max_rows_per_page":1},"stage_attachment":{"location":"@~/my_location"}}"# + r#"{"session":{"database":"default"},"sql":"select 1","pagination":{"wait_time_secs":1,"max_rows_in_buffer":1,"max_rows_per_page":1},"stage_attachment":{"location":"@~/my_location"}}"# ); Ok(()) } diff --git a/core/src/response.rs b/core/src/response.rs index 4b5fdecb..d89c82c4 100644 --- a/core/src/response.rs +++ b/core/src/response.rs @@ -14,7 +14,7 @@ use serde::Deserialize; -use crate::request::SessionState; +use crate::session::SessionState; #[derive(Deserialize, Debug)] pub struct QueryError { @@ -55,6 +55,7 @@ pub struct SchemaField { #[derive(Deserialize, Debug)] pub struct QueryResponse { pub id: String, + pub node_id: String, pub session_id: Option, pub session: Option, pub schema: Vec, diff --git a/core/src/session.rs b/core/src/session.rs new file mode 100644 index 00000000..139ea0e2 --- /dev/null +++ b/core/src/session.rs @@ -0,0 +1,53 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashMap}; + +#[derive(Deserialize, Serialize, Debug, Default, Clone)] +pub struct SessionState { + #[serde(skip_serializing_if = "Option::is_none")] + pub database: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub settings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub secondary_roles: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub txn_state: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub need_sticky: Option, + + // hide fields of no interest (but need to send back to server in next query) + #[serde(flatten)] + additional_fields: HashMap, +} + +impl SessionState { + pub fn with_settings(mut self, settings: Option>) -> Self { + self.settings = settings; + self + } + + pub fn with_database(mut self, database: Option) -> Self { + self.database = database; + self + } + + pub fn with_role(mut self, role: Option) -> Self { + self.role = role; + self + } +} diff --git a/driver/src/rest_api.rs b/driver/src/rest_api.rs index 0a2a8764..e9f5b3b6 100644 --- a/driver/src/rest_api.rs +++ b/driver/src/rest_api.rs @@ -58,8 +58,12 @@ impl Connection for RestAPIConnection { async fn exec(&self, sql: &str) -> Result { info!("exec: {}", sql); let mut resp = self.client.start_query(sql).await?; + let node_id = resp.node_id.clone(); while let Some(next_uri) = resp.next_uri { - resp = self.client.query_page(&resp.id, &next_uri).await?; + resp = self + .client + .query_page(&resp.id, &next_uri, &node_id) + .await?; } Ok(resp.stats.progresses.write_progress.rows as i64) } @@ -201,14 +205,19 @@ impl<'o> RestAPIConnection { Ok(Self { client }) } - async fn wait_for_schema(&self, pre: QueryResponse) -> Result { - if !pre.data.is_empty() || !pre.schema.is_empty() { - return Ok(pre); + async fn wait_for_schema(&self, resp: QueryResponse) -> Result { + if !resp.data.is_empty() || !resp.schema.is_empty() { + return Ok(resp); } - let mut result = pre; - // preserve schema since it is no included in the final response + let node_id = resp.node_id.clone(); + self.client.set_last_node_id(node_id.clone()); + let mut result = resp; + // preserve schema since it is not included in the final response while let Some(next_uri) = result.next_uri { - result = self.client.query_page(&result.id, &next_uri).await?; + result = self + .client + .query_page(&result.id, &next_uri, &node_id) + .await?; if !result.data.is_empty() || !result.schema.is_empty() { break; } @@ -240,6 +249,7 @@ pub struct RestAPIRows { data: VecDeque>>, stats: Option, query_id: String, + node_id: String, next_uri: Option, next_page: Option, } @@ -250,6 +260,7 @@ impl RestAPIRows { let rows = Self { client, query_id: resp.id, + node_id: resp.node_id, next_uri: resp.next_uri, schema: Arc::new(schema.clone()), data: resp.data.into(), @@ -278,7 +289,6 @@ impl Stream for RestAPIRows { if self.schema.fields().is_empty() { self.schema = Arc::new(resp.schema.try_into()?); } - self.query_id = resp.id; self.next_uri = resp.next_uri; self.next_page = None; self.stats = Some(ServerStats::from(resp.stats)); @@ -295,9 +305,10 @@ impl Stream for RestAPIRows { let client = self.client.clone(); let next_uri = next_uri.clone(); let query_id = self.query_id.clone(); + let node_id = self.node_id.clone(); self.next_page = Some(Box::pin(async move { client - .query_page(&query_id, &next_uri) + .query_page(&query_id, &next_uri, &node_id) .await .map_err(|e| e.into()) }));