Skip to content

Commit

Permalink
feat(derive): use rename_all rule for container when naming fields (#117
Browse files Browse the repository at this point in the history
)
  • Loading branch information
v3xro authored Sep 5, 2024
1 parent ae6f92c commit 8c55a9d
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 23 deletions.
30 changes: 18 additions & 12 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
use proc_macro2::TokenStream;
use quote::quote;
use serde_derive_internals::{
attr::{Default as SerdeDefault, Field},
attr::{Container, Default as SerdeDefault, Field},
Ctxt,
};
use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields};

fn column_names(data: &DataStruct) -> TokenStream {
fn column_names(data: &DataStruct, cx: &Ctxt, container: &Container) -> TokenStream {
match &data.fields {
Fields::Named(fields) => {
let cx = Ctxt::new();
let rename_rule = container.rename_all_rules().deserialize;
let column_names_iter = fields
.named
.iter()
.enumerate()
.map(|(index, field)| Field::from_ast(&cx, index, field, None, &SerdeDefault::None))
.map(|(index, field)| Field::from_ast(cx, index, field, None, &SerdeDefault::None))
.filter(|field| !field.skip_serializing() && !field.skip_deserializing())
.map(|field| field.name().serialize_name().to_string());
.map(|field| {
rename_rule
.apply_to_field(field.name().serialize_name())
.to_string()
});

let tokens = quote! {
quote! {
&[#( #column_names_iter,)*]
};

// TODO: do something more clever?
let _ = cx.check();
tokens
}
}
Fields::Unnamed(_) => {
quote! { &[] }
Expand All @@ -39,13 +39,19 @@ fn column_names(data: &DataStruct) -> TokenStream {
#[proc_macro_derive(Row)]
pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);

let cx = Ctxt::new();
let container = Container::from_ast(&cx, &input);
let name = input.ident;

let column_names = match &input.data {
Data::Struct(data) => column_names(data),
Data::Struct(data) => column_names(data, &cx, &container),
Data::Enum(_) | Data::Union(_) => panic!("`Row` can be derived only for structs"),
};

// TODO: do something more clever?
let _ = cx.check().expect("derive context error");

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

// TODO: replace `clickhouse` with `::clickhouse` here.
Expand Down
57 changes: 53 additions & 4 deletions tests/it/insert.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
use crate::{create_simple_table, fetch_simple_rows, flush_query_log, SimpleRow};
use crate::{create_simple_table, fetch_rows, flush_query_log, SimpleRow};
use clickhouse::{sql::Identifier, Client, Row};
use serde::{Deserialize, Serialize};

#[derive(Debug, Row, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct RenameRow {
#[serde(rename = "fix_id")]
pub fix_id: i64,
#[serde(rename = "extComplexId")]
pub complex_id: String,
pub ext_float: f64,
}

async fn create_rename_table(client: &Client, table_name: &str) {
client
.query("CREATE TABLE ?(fixId UInt64, extComplexId String, extFloat Float64) ENGINE = MergeTree ORDER BY fixId")
.bind(Identifier(table_name))
.execute()
.await
.unwrap();
}

#[tokio::test]
async fn keeps_client_options() {
Expand Down Expand Up @@ -49,7 +70,7 @@ async fn keeps_client_options() {
format!("should contain {client_setting_name} = {client_setting_value} (from the client options)")
);

let rows = fetch_simple_rows(&client, table_name).await;
let rows = fetch_rows::<SimpleRow>(&client, table_name).await;
assert_eq!(rows, vec!(row))
}

Expand Down Expand Up @@ -96,7 +117,7 @@ async fn overrides_client_options() {
format!("should contain {setting_name} = {override_value} (from the insert options)")
);

let rows = fetch_simple_rows(&client, table_name).await;
let rows = fetch_rows::<SimpleRow>(&client, table_name).await;
assert_eq!(rows, vec!(row))
}

Expand All @@ -117,6 +138,34 @@ async fn empty_insert() {

insert.end().await.unwrap();

let rows = fetch_simple_rows(&client, table_name).await;
let rows = fetch_rows::<SimpleRow>(&client, table_name).await;
assert!(rows.is_empty())
}

#[tokio::test]
async fn rename_insert() {
let table_name = "insert_rename";
let query_id = uuid::Uuid::new_v4().to_string();

let client = prepare_database!();
create_rename_table(&client, table_name).await;

let row = RenameRow {
fix_id: 42,
complex_id: String::from("foo"),
ext_float: 0.5,
};

let mut insert = client
.insert(table_name)
.unwrap()
.with_option("query_id", &query_id);

insert.write(&row).await.unwrap();
insert.end().await.unwrap();

flush_query_log(&client).await;

let rows = fetch_rows::<RenameRow>(&client, table_name).await;
assert_eq!(rows, vec!(row))
}
6 changes: 3 additions & 3 deletions tests/it/inserter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use serde::Serialize;

use clickhouse::{inserter::Quantities, Client, Row};

use crate::{create_simple_table, fetch_simple_rows, flush_query_log, SimpleRow};
use crate::{create_simple_table, fetch_rows, flush_query_log, SimpleRow};

#[derive(Debug, Row, Serialize)]
struct MyRow {
Expand Down Expand Up @@ -236,7 +236,7 @@ async fn keeps_client_options() {
format!("should contain {client_setting_name} = {client_setting_value} (from the client options)")
);

let rows = fetch_simple_rows(&client, table_name).await;
let rows = fetch_rows::<SimpleRow>(&client, table_name).await;
assert_eq!(rows, vec!(row))
}

Expand Down Expand Up @@ -284,6 +284,6 @@ async fn overrides_client_options() {
format!("should contain {setting_name} = {override_value} (from the inserter options)")
);

let rows = fetch_simple_rows(&client, table_name).await;
let rows = fetch_rows::<SimpleRow>(&client, table_name).await;
assert_eq!(rows, vec!(row))
}
10 changes: 6 additions & 4 deletions tests/it/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use clickhouse::sql::Identifier;
use clickhouse::{sql, Client};
use clickhouse_derive::Row;
use clickhouse::{sql, Client, Row};
use serde::{Deserialize, Serialize};

macro_rules! prepare_database {
Expand Down Expand Up @@ -40,11 +39,14 @@ async fn create_simple_table(client: &Client, table_name: &str) {
.unwrap();
}

async fn fetch_simple_rows(client: &Client, table_name: &str) -> Vec<SimpleRow> {
async fn fetch_rows<T>(client: &Client, table_name: &str) -> Vec<T>
where
T: Row + for<'b> Deserialize<'b>,
{
client
.query("SELECT ?fields FROM ?")
.bind(Identifier(table_name))
.fetch_all::<SimpleRow>()
.fetch_all::<T>()
.await
.unwrap()
}
Expand Down

0 comments on commit 8c55a9d

Please sign in to comment.