diff --git a/src/lib.rs b/src/lib.rs index 6df4a50..fda89e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,8 @@ #[macro_use] extern crate static_assertions; -use self::{error::Result, http_client::HttpClient}; +use self::{error::Result, http_client::HttpClient, sql::ser}; +use ::serde::Serialize; use std::{collections::HashMap, fmt::Display, sync::Arc}; pub use self::{compression::Compression, row::Row}; @@ -160,6 +161,12 @@ impl Client { self } + pub fn with_param(self, name: &str, value: impl Serialize) -> Result { + let mut param = String::from(""); + ser::write_param(&mut param, &value)?; + Ok(self.with_option(format!("param_{name}"), param)) + } + /// Used to specify a header that will be passed to all queries. /// /// # Example diff --git a/src/query.rs b/src/query.rs index 926fdb5..fd4e2d9 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,5 +1,5 @@ use hyper::{header::CONTENT_LENGTH, Method, Request}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use url::Url; @@ -10,7 +10,7 @@ use crate::{ request_body::RequestBody, response::Response, row::Row, - sql::{Bind, SqlBuilder}, + sql::{ser, Bind, SqlBuilder}, Client, }; @@ -195,6 +195,16 @@ impl Query { self.client.add_option(name, value); self } + + pub fn with_param(mut self, name: &str, value: impl Serialize) -> Self { + let mut param = String::from(""); + if let Err(err) = ser::write_param(&mut param, &value) { + self.sql = SqlBuilder::Failed(format!("invalid param: {err}")); + self + } else { + self.with_option(format!("param_{name}"), param) + } + } } /// A cursor that emits rows. diff --git a/src/sql/bind.rs b/src/sql/bind.rs index 88a0d00..c4885b0 100644 --- a/src/sql/bind.rs +++ b/src/sql/bind.rs @@ -8,13 +8,13 @@ use super::{escape, ser}; #[sealed] pub trait Bind { #[doc(hidden)] - fn write(&self, dst: impl fmt::Write) -> Result<(), String>; + fn write(&self, dst: &mut impl fmt::Write) -> Result<(), String>; } #[sealed] impl Bind for S { #[inline] - fn write(&self, mut dst: impl fmt::Write) -> Result<(), String> { + fn write(&self, mut dst: &mut impl fmt::Write) -> Result<(), String> { ser::write_arg(&mut dst, self) } } @@ -26,7 +26,7 @@ pub struct Identifier<'a>(pub &'a str); #[sealed] impl<'a> Bind for Identifier<'a> { #[inline] - fn write(&self, dst: impl fmt::Write) -> Result<(), String> { + fn write(&self, dst: &mut impl fmt::Write) -> Result<(), String> { escape::identifier(self.0, dst).map_err(|err| err.to_string()) } } diff --git a/src/sql/escape.rs b/src/sql/escape.rs index e43cde6..fae1706 100644 --- a/src/sql/escape.rs +++ b/src/sql/escape.rs @@ -1,35 +1,31 @@ use std::fmt; +// Trust clickhouse-connect https://github.com/ClickHouse/clickhouse-connect/blob/5d85563410f3ec378cb199ec51d75e033211392c/clickhouse_connect/driver/binding.py#L15 + // See https://clickhouse.tech/docs/en/sql-reference/syntax/#syntax-string-literal -pub(crate) fn string(src: &str, dst: impl fmt::Write) -> fmt::Result { - escape(src, dst, '\'') +pub(crate) fn string(src: &str, dst: &mut impl fmt::Write) -> fmt::Result { + dst.write_char('\'')?; + escape(src, dst)?; + dst.write_char('\'') } // See https://clickhouse.tech/docs/en/sql-reference/syntax/#syntax-identifiers -pub(crate) fn identifier(src: &str, dst: impl fmt::Write) -> fmt::Result { - escape(src, dst, '`') +pub(crate) fn identifier(src: &str, dst: &mut impl fmt::Write) -> fmt::Result { + dst.write_char('\'')?; + escape(src, dst)?; + dst.write_char('\'') } -fn escape(src: &str, mut dst: impl fmt::Write, ch: char) -> fmt::Result { - dst.write_char(ch)?; - - // TODO: escape newlines? - for (idx, part) in src.split(ch).enumerate() { - if idx > 0 { - dst.write_char('\\')?; - dst.write_char(ch)?; - } - - for (idx, part) in part.split('\\').enumerate() { - if idx > 0 { - dst.write_str("\\\\")?; - } - - dst.write_str(part)?; - } +pub(crate) fn escape(src: &str, dst: &mut impl fmt::Write) -> fmt::Result { + const REPLACE: &[char] = &['\\', '\'', '`', '\t', '\n']; + let mut rest = src; + while let Some(nextidx) = rest.find(REPLACE) { + let (before, after) = rest.split_at(nextidx); + rest = after; + dst.write_str(before)?; + dst.write_char('\\')?; } - - dst.write_char(ch) + dst.write_str(rest) } #[test] diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 7417be7..66330f6 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier}; mod bind; pub(crate) mod escape; -mod ser; +pub(crate) mod ser; #[derive(Debug, Clone)] pub(crate) enum SqlBuilder { diff --git a/src/sql/ser.rs b/src/sql/ser.rs index 00ea606..6cbdb22 100644 --- a/src/sql/ser.rs +++ b/src/sql/ser.rs @@ -8,23 +8,23 @@ use thiserror::Error; use super::escape; -// === SqlSerializerError === +// === SerializerError === #[derive(Debug, Error)] -enum SqlSerializerError { +enum SerializerError { #[error("{0} is unsupported")] Unsupported(&'static str), #[error("{0}")] Custom(String), } -impl ser::Error for SqlSerializerError { +impl ser::Error for SerializerError { fn custom(msg: T) -> Self { Self::Custom(msg.to_string()) } } -impl From for SqlSerializerError { +impl From for SerializerError { fn from(err: fmt::Error) -> Self { Self::Custom(err.to_string()) } @@ -32,8 +32,8 @@ impl From for SqlSerializerError { // === SqlSerializer === -type Result = std::result::Result; -type Impossible = ser::Impossible<(), SqlSerializerError>; +type Result = std::result::Result; +type Impossible = ser::Impossible<(), SerializerError>; struct SqlSerializer<'a, W> { writer: &'a mut W, @@ -43,7 +43,7 @@ macro_rules! unsupported { ($ser_method:ident($ty:ty) -> $ret:ty, $($other:tt)*) => { #[inline] fn $ser_method(self, _v: $ty) -> $ret { - Err(SqlSerializerError::Unsupported(stringify!($ser_method))) + Err(SerializerError::Unsupported(stringify!($ser_method))) } unsupported!($($other)*); }; @@ -53,7 +53,7 @@ macro_rules! unsupported { ($ser_method:ident, $($other:tt)*) => { #[inline] fn $ser_method(self) -> Result { - Err(SqlSerializerError::Unsupported(stringify!($ser_method))) + Err(SerializerError::Unsupported(stringify!($ser_method))) } unsupported!($($other)*); }; @@ -73,7 +73,7 @@ macro_rules! forward_to_display { } impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); type SerializeMap = Impossible; type SerializeSeq = SqlListSerializer<'a, W>; @@ -177,12 +177,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _value: &T, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_newtype_variant")) + Err(SerializerError::Unsupported("serialize_newtype_variant")) } #[inline] fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { - Err(SqlSerializerError::Unsupported("serialize_tuple_struct")) + Err(SerializerError::Unsupported("serialize_tuple_struct")) } #[inline] @@ -193,12 +193,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _len: usize, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_tuple_variant")) + Err(SerializerError::Unsupported("serialize_tuple_variant")) } #[inline] fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - Err(SqlSerializerError::Unsupported("serialize_struct")) + Err(SerializerError::Unsupported("serialize_struct")) } #[inline] @@ -209,7 +209,7 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _len: usize, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_struct_variant")) + Err(SerializerError::Unsupported("serialize_struct_variant")) } #[inline] @@ -227,7 +227,7 @@ struct SqlListSerializer<'a, W> { } impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); #[inline] @@ -254,7 +254,7 @@ impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> { } impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); #[inline] @@ -271,6 +271,159 @@ impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> { } } +// === ParamSerializer === + +struct ParamSerializer<'a, W> { + writer: &'a mut W, +} + +impl<'a, W: Write> Serializer for ParamSerializer<'a, W> { + type Error = SerializerError; + type Ok = (); + type SerializeMap = Impossible; + type SerializeSeq = SqlListSerializer<'a, W>; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; + type SerializeTuple = SqlListSerializer<'a, W>; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + + unsupported!( + serialize_map(Option) -> Result, + serialize_bytes(&[u8]), + serialize_unit, + serialize_unit_struct(&'static str), + ); + + forward_to_display!( + serialize_i8(i8), + serialize_i16(i16), + serialize_i32(i32), + serialize_i64(i64), + serialize_i128(i128), + serialize_u8(u8), + serialize_u16(u16), + serialize_u32(u32), + serialize_u64(u64), + serialize_u128(u128), + serialize_f32(f32), + serialize_f64(f64), + serialize_bool(bool), + ); + + #[inline] + fn serialize_char(self, value: char) -> Result { + let mut tmp = [0u8; 4]; + self.serialize_str(value.encode_utf8(&mut tmp)) + } + + #[inline] + fn serialize_str(self, value: &str) -> Result { + // ClickHouse expects strings in params to be unquoted until inside a nested type + // nested types go through serialize_seq which'll quote strings + Ok(escape::escape(value, self.writer)?) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result> { + self.writer.write_char('[')?; + Ok(SqlListSerializer { + writer: self.writer, + has_items: false, + closing_char: ']', + }) + } + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result> { + self.writer.write_char('(')?; + Ok(SqlListSerializer { + writer: self.writer, + has_items: false, + closing_char: ')', + }) + } + + #[inline] + fn serialize_some(self, _value: &T) -> Result { + _value.serialize(self) + } + + #[inline] + fn serialize_none(self) -> std::result::Result { + self.writer.write_str("NULL")?; + Ok(()) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + escape::string(variant, self.writer)?; + Ok(()) + } + + #[inline] + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result { + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result { + Err(SerializerError::Unsupported("serialize_newtype_variant")) + } + + #[inline] + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(SerializerError::Unsupported("serialize_tuple_struct")) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(SerializerError::Unsupported("serialize_tuple_variant")) + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(SerializerError::Unsupported("serialize_struct")) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(SerializerError::Unsupported("serialize_struct_variant")) + } + + #[inline] + fn is_human_readable(&self) -> bool { + true + } +} + // === Public API === pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> { @@ -279,6 +432,12 @@ pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Resu .map_err(|err| err.to_string()) } +pub(crate) fn write_param(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> { + value + .serialize(ParamSerializer { writer }) + .map_err(|err| err.to_string()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/it/query.rs b/tests/it/query.rs index 80e0158..a3bc4ef 100644 --- a/tests/it/query.rs +++ b/tests/it/query.rs @@ -85,6 +85,45 @@ async fn fetch_one_and_optional() { assert_eq!(got_string, "bar"); } +#[tokio::test] +async fn server_side_param() { + let client = prepare_database!() + .with_param("val1", 42) + .expect("failed to bind 42"); + + let result = client + .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result") + .with_param("val2", 144) + .fetch_one::() + .await + .expect("failed to fetch u64"); + assert_eq!(result, 186); + + let result = client + .query("SELECT {val1: String} AS result") + .with_param("val1", "string") + .fetch_one::() + .await + .expect("failed to fetch string"); + assert_eq!(result, "string"); + + let result = client + .query("SELECT {val1: String} AS result") + .with_param("val1", "\x01\x02\x03\\ \"\'") + .fetch_one::() + .await + .expect("failed to fetch string"); + assert_eq!(result, "\x01\x02\x03\\ \"\'"); + + let result = client + .query("SELECT {val1: Array(String)} AS result") + .with_param("val1", vec!["a", "bc"]) + .fetch_one::>() + .await + .expect("failed to fetch string"); + assert_eq!(result, &["a", "bc"]); +} + // See #19. #[tokio::test] async fn long_query() {