diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 51425bf..bd5675a 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,30 +1,30 @@ use proc_macro2::TokenStream; use quote::quote; use serde_derive_internals::{ - attr::{Default as SerdeDefault, Field}, + attr::{Container, Default as SerdeDefault, Field}, Ctxt, }; use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields}; -fn column_names(data: &DataStruct) -> TokenStream { +fn column_names(data: &DataStruct, cx: &Ctxt, container: &Container) -> TokenStream { match &data.fields { Fields::Named(fields) => { - let cx = Ctxt::new(); + let rename_rule = container.rename_all_rules().deserialize; let column_names_iter = fields .named .iter() .enumerate() - .map(|(index, field)| Field::from_ast(&cx, index, field, None, &SerdeDefault::None)) + .map(|(index, field)| Field::from_ast(cx, index, field, None, &SerdeDefault::None)) .filter(|field| !field.skip_serializing() && !field.skip_deserializing()) - .map(|field| field.name().serialize_name().to_string()); + .map(|field| { + rename_rule + .apply_to_field(field.name().serialize_name()) + .to_string() + }); - let tokens = quote! { + quote! { &[#( #column_names_iter,)*] - }; - - // TODO: do something more clever? - let _ = cx.check(); - tokens + } } Fields::Unnamed(_) => { quote! { &[] } @@ -39,13 +39,19 @@ fn column_names(data: &DataStruct) -> TokenStream { #[proc_macro_derive(Row)] pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); + + let cx = Ctxt::new(); + let container = Container::from_ast(&cx, &input); let name = input.ident; let column_names = match &input.data { - Data::Struct(data) => column_names(data), + Data::Struct(data) => column_names(data, &cx, &container), Data::Enum(_) | Data::Union(_) => panic!("`Row` can be derived only for structs"), }; + // TODO: do something more clever? + let _ = cx.check().expect("derive context error"); + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); // TODO: replace `clickhouse` with `::clickhouse` here. diff --git a/tests/it/insert.rs b/tests/it/insert.rs index 742aa7d..09b0ea8 100644 --- a/tests/it/insert.rs +++ b/tests/it/insert.rs @@ -1,4 +1,25 @@ -use crate::{create_simple_table, fetch_simple_rows, flush_query_log, SimpleRow}; +use crate::{create_simple_table, fetch_rows, flush_query_log, SimpleRow}; +use clickhouse::{sql::Identifier, Client, Row}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Row, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +struct RenameRow { + #[serde(rename = "fix_id")] + pub fix_id: i64, + #[serde(rename = "extComplexId")] + pub complex_id: String, + pub ext_float: f64, +} + +async fn create_rename_table(client: &Client, table_name: &str) { + client + .query("CREATE TABLE ?(fixId UInt64, extComplexId String, extFloat Float64) ENGINE = MergeTree ORDER BY fixId") + .bind(Identifier(table_name)) + .execute() + .await + .unwrap(); +} #[tokio::test] async fn keeps_client_options() { @@ -49,7 +70,7 @@ async fn keeps_client_options() { format!("should contain {client_setting_name} = {client_setting_value} (from the client options)") ); - let rows = fetch_simple_rows(&client, table_name).await; + let rows = fetch_rows::(&client, table_name).await; assert_eq!(rows, vec!(row)) } @@ -96,7 +117,7 @@ async fn overrides_client_options() { format!("should contain {setting_name} = {override_value} (from the insert options)") ); - let rows = fetch_simple_rows(&client, table_name).await; + let rows = fetch_rows::(&client, table_name).await; assert_eq!(rows, vec!(row)) } @@ -117,6 +138,34 @@ async fn empty_insert() { insert.end().await.unwrap(); - let rows = fetch_simple_rows(&client, table_name).await; + let rows = fetch_rows::(&client, table_name).await; assert!(rows.is_empty()) } + +#[tokio::test] +async fn rename_insert() { + let table_name = "insert_rename"; + let query_id = uuid::Uuid::new_v4().to_string(); + + let client = prepare_database!(); + create_rename_table(&client, table_name).await; + + let row = RenameRow { + fix_id: 42, + complex_id: String::from("foo"), + ext_float: 0.5, + }; + + let mut insert = client + .insert(table_name) + .unwrap() + .with_option("query_id", &query_id); + + insert.write(&row).await.unwrap(); + insert.end().await.unwrap(); + + flush_query_log(&client).await; + + let rows = fetch_rows::(&client, table_name).await; + assert_eq!(rows, vec!(row)) +} diff --git a/tests/it/inserter.rs b/tests/it/inserter.rs index a7d5ba3..766b30f 100644 --- a/tests/it/inserter.rs +++ b/tests/it/inserter.rs @@ -6,7 +6,7 @@ use serde::Serialize; use clickhouse::{inserter::Quantities, Client, Row}; -use crate::{create_simple_table, fetch_simple_rows, flush_query_log, SimpleRow}; +use crate::{create_simple_table, fetch_rows, flush_query_log, SimpleRow}; #[derive(Debug, Row, Serialize)] struct MyRow { @@ -236,7 +236,7 @@ async fn keeps_client_options() { format!("should contain {client_setting_name} = {client_setting_value} (from the client options)") ); - let rows = fetch_simple_rows(&client, table_name).await; + let rows = fetch_rows::(&client, table_name).await; assert_eq!(rows, vec!(row)) } @@ -284,6 +284,6 @@ async fn overrides_client_options() { format!("should contain {setting_name} = {override_value} (from the inserter options)") ); - let rows = fetch_simple_rows(&client, table_name).await; + let rows = fetch_rows::(&client, table_name).await; assert_eq!(rows, vec!(row)) } diff --git a/tests/it/main.rs b/tests/it/main.rs index 9a8c11d..001cbd6 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -1,6 +1,5 @@ use clickhouse::sql::Identifier; -use clickhouse::{sql, Client}; -use clickhouse_derive::Row; +use clickhouse::{sql, Client, Row}; use serde::{Deserialize, Serialize}; macro_rules! prepare_database { @@ -40,11 +39,14 @@ async fn create_simple_table(client: &Client, table_name: &str) { .unwrap(); } -async fn fetch_simple_rows(client: &Client, table_name: &str) -> Vec { +async fn fetch_rows(client: &Client, table_name: &str) -> Vec +where + T: Row + for<'b> Deserialize<'b>, +{ client .query("SELECT ?fields FROM ?") .bind(Identifier(table_name)) - .fetch_all::() + .fetch_all::() .await .unwrap() }