Skip to content

Commit

Permalink
spanner: Use bigdecimal instead of SpannerNumeric (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshidan authored Jul 5, 2023
1 parent 8dddd76 commit 09444a7
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 45 deletions.
13 changes: 7 additions & 6 deletions spanner-derive/tests/test.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use google_cloud_spanner::bigdecimal::{BigDecimal, Zero};
use serde::{Deserialize, Serialize};
use serial_test::serial;
use std::str::FromStr;
use time::{Date, OffsetDateTime};

use google_cloud_spanner::client::{Client, ClientConfig, Error};
use google_cloud_spanner::mutation::insert_struct;
use google_cloud_spanner::reader::AsyncIterator;
use google_cloud_spanner::statement::Statement;
use google_cloud_spanner::value::SpannerNumeric;
use google_cloud_spanner_derive::{Query, Table};

#[derive(Table, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
Expand Down Expand Up @@ -61,8 +62,8 @@ pub struct User {
pub nullable_bool: Option<bool>,
pub not_null_byte_array: Vec<u8>,
pub nullable_byte_array: Option<Vec<u8>>,
pub not_null_numeric: SpannerNumeric,
pub nullable_numeric: Option<SpannerNumeric>,
pub not_null_numeric: BigDecimal,
pub nullable_numeric: Option<BigDecimal>,
pub not_null_timestamp: OffsetDateTime,
pub nullable_timestamp: Option<OffsetDateTime>,
pub not_null_date: Date,
Expand All @@ -89,7 +90,7 @@ impl Default for User {
nullable_bool: Default::default(),
not_null_byte_array: Default::default(),
nullable_byte_array: Default::default(),
not_null_numeric: Default::default(),
not_null_numeric: BigDecimal::zero(),
nullable_numeric: Default::default(),
nullable_timestamp: Default::default(),
nullable_date: Default::default(),
Expand Down Expand Up @@ -122,7 +123,7 @@ async fn test_table_derive() -> Result<(), Error> {
let user_id = format!("user{now}");
let user = User {
user_id: user_id.clone(),
not_null_numeric: SpannerNumeric::new("-99999999999999999999999999999.999999999"),
not_null_numeric: BigDecimal::from_str("-99999999999999999999999999999.999999999").unwrap(),
..Default::default()
};
client.apply(vec![insert_struct("User", user)]).await?;
Expand All @@ -134,7 +135,7 @@ async fn test_table_derive() -> Result<(), Error> {
if let Some(row) = reader.next().await? {
let v: User = row.try_into()?;
assert_eq!(v.user_id, user_id);
assert_eq!(v.not_null_numeric.as_str(), "-99999999999999999999999999999.999999999");
assert_eq!(&v.not_null_numeric.to_string(), "-99999999999999999999999999999.999999999");
assert!(v.updated_at.unix_timestamp() >= now);
let json_string = serde_json::to_string(&v).unwrap();
let des = serde_json::from_str::<User>(json_string.as_str()).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion spanner/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "google-cloud-spanner"
version = "0.21.0"
version = "0.22.0"
authors = ["yoshidan <[email protected]>"]
edition = "2021"
repository = "https://github.com/yoshidan/google-cloud-rust/tree/main/spanner"
Expand All @@ -21,6 +21,7 @@ parking_lot = "0.12"
base64 = "0.21"
serde = { version = "1.0", optional = true, features = ["derive"] }
tokio-util = "0.7"
bigdecimal = { version="0.3", features=["serde"] }

google-cloud-token = { version = "0.1.1", path = "../foundation/token" }
google-cloud-longrunning= { version = "0.15.0", path = "../foundation/longrunning" }
Expand Down
1 change: 1 addition & 0 deletions spanner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,4 @@ pub mod transaction;
pub mod transaction_ro;
pub mod transaction_rw;
pub mod value;
pub use bigdecimal;
43 changes: 40 additions & 3 deletions spanner/src/row.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::{BTreeMap, HashMap};
use std::num::ParseIntError;
use std::str::FromStr;
use std::sync::Arc;

use base64::prelude::*;
Expand All @@ -13,7 +14,8 @@ use time::{Date, OffsetDateTime};
use google_cloud_googleapis::spanner::v1::struct_type::Field;
use google_cloud_googleapis::spanner::v1::StructType;

use crate::value::{CommitTimestamp, SpannerNumeric};
use crate::bigdecimal::{BigDecimal, ParseBigDecimalError};
use crate::value::CommitTimestamp;

#[derive(Clone)]
pub struct Row {
Expand Down Expand Up @@ -46,6 +48,8 @@ pub enum Error {
InvalidStructColumnIndex(usize),
#[error("No column found in struct: name={0}")]
NoColumnFoundInStruct(String),
#[error("Failed to parse as BigDecimal field={0}")]
BigDecimalParseError(String, #[source] ParseBigDecimalError),
}

impl Row {
Expand Down Expand Up @@ -204,10 +208,12 @@ impl TryFromValue for Vec<u8> {
}
}

impl TryFromValue for SpannerNumeric {
impl TryFromValue for BigDecimal {
fn try_from(item: &Value, field: &Field) -> Result<Self, Error> {
match as_ref(item, field)? {
Kind::StringValue(s) => Ok(SpannerNumeric::new(s.to_string())),
Kind::StringValue(s) => {
Ok(BigDecimal::from_str(s).map_err(|e| Error::BigDecimalParseError(field.name.to_string(), e))?)
}
v => kind_to_error(v, field),
}
}
Expand Down Expand Up @@ -313,13 +319,16 @@ pub fn kind_to_error<'a, T>(v: &'a value::Kind, field: &'a Field) -> Result<T, E
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::ops::Add;
use std::str::FromStr;
use std::sync::Arc;

use prost_types::Value;
use time::OffsetDateTime;

use google_cloud_googleapis::spanner::v1::struct_type::Field;

use crate::bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive, Zero};
use crate::row::{Error, Row, Struct as RowStruct, TryFromStruct};
use crate::statement::{Kinds, ToKind, ToStruct, Types};
use crate::value::CommitTimestamp;
Expand All @@ -328,6 +337,7 @@ mod tests {
pub struct_field: String,
pub struct_field_time: OffsetDateTime,
pub commit_timestamp: CommitTimestamp,
pub big_decimal: BigDecimal,
}

impl TryFromStruct for TestStruct {
Expand All @@ -336,6 +346,7 @@ mod tests {
struct_field: s.column_by_name("struct_field")?,
struct_field_time: s.column_by_name("struct_field_time")?,
commit_timestamp: s.column_by_name("commit_timestamp")?,
big_decimal: s.column_by_name("big_decimal")?,
})
}
}
Expand All @@ -347,6 +358,7 @@ mod tests {
("struct_field_time", self.struct_field_time.to_kind()),
// value from DB is timestamp. it's not string 'spanner.commit_timestamp()'.
("commit_timestamp", OffsetDateTime::from(self.commit_timestamp).to_kind()),
("big_decimal", self.big_decimal.to_kind()),
]
}

Expand All @@ -355,6 +367,7 @@ mod tests {
("struct_field", String::get_type()),
("struct_field_time", OffsetDateTime::get_type()),
("commit_timestamp", CommitTimestamp::get_type()),
("big_decimal", BigDecimal::get_type()),
]
}
}
Expand All @@ -365,6 +378,7 @@ mod tests {
index.insert("value".to_string(), 0);
index.insert("array".to_string(), 1);
index.insert("struct".to_string(), 2);
index.insert("decimal".to_string(), 3);

let now = OffsetDateTime::now_utc();
let row = Row {
Expand All @@ -382,6 +396,10 @@ mod tests {
name: "struct".to_string(),
r#type: Some(Vec::<TestStruct>::get_type()),
},
Field {
name: "decimal".to_string(),
r#type: Some(BigDecimal::get_type()),
},
]),
values: vec![
Value {
Expand All @@ -399,29 +417,48 @@ mod tests {
struct_field: "aaa".to_string(),
struct_field_time: now,
commit_timestamp: CommitTimestamp { timestamp: now },
big_decimal: BigDecimal::from_str("-99999999999999999999999999999.999999999").unwrap(),
},
TestStruct {
struct_field: "bbb".to_string(),
struct_field_time: now,
commit_timestamp: CommitTimestamp { timestamp: now },
big_decimal: BigDecimal::from_str("99999999999999999999999999999.999999999").unwrap(),
},
]
.to_kind(),
),
},
Value {
kind: Some(BigDecimal::from_f64(100.999999999999).unwrap().to_kind()),
},
],
};

let value = row.column_by_name::<String>("value").unwrap();
let array = row.column_by_name::<Vec<i64>>("array").unwrap();
let struct_data = row.column_by_name::<Vec<TestStruct>>("struct").unwrap();
let decimal = row.column_by_name::<BigDecimal>("decimal").unwrap();
assert_eq!(value, "aaa");
assert_eq!(array[0], 10);
assert_eq!(array[1], 100);
assert_eq!(decimal.to_f64().unwrap(), 100.999999999999);
assert_eq!(struct_data[0].struct_field, "aaa");
assert_eq!(struct_data[0].struct_field_time, now);
assert_eq!(
struct_data[0].big_decimal,
BigDecimal::from_str("-99999999999999999999999999999.999999999").unwrap()
);
assert_eq!(struct_data[1].struct_field, "bbb");
assert_eq!(struct_data[1].struct_field_time, now);
assert_eq!(struct_data[1].commit_timestamp.timestamp, now);
assert_eq!(
struct_data[1].big_decimal,
BigDecimal::from_str("99999999999999999999999999999.999999999").unwrap()
);
assert_eq!(
struct_data[1].big_decimal.clone().add(&struct_data[0].big_decimal),
BigDecimal::zero()
);
}
}
7 changes: 4 additions & 3 deletions spanner/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use time::{Date, OffsetDateTime};
use google_cloud_googleapis::spanner::v1::struct_type::Field;
use google_cloud_googleapis::spanner::v1::{StructType, Type, TypeAnnotationCode, TypeCode};

use crate::value::{CommitTimestamp, SpannerNumeric};
use crate::bigdecimal::BigDecimal;
use crate::value::CommitTimestamp;

/// A Statement is a SQL query with named parameters.
///
Expand Down Expand Up @@ -196,9 +197,9 @@ impl ToKind for Vec<u8> {
}
}

impl ToKind for SpannerNumeric {
impl ToKind for BigDecimal {
fn to_kind(&self) -> Kind {
self.as_str().to_string().to_kind()
self.to_string().to_kind()
}
fn get_type() -> Type {
single_type(TypeCode::Numeric)
Expand Down
23 changes: 0 additions & 23 deletions spanner/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,6 @@ use std::time::Duration;
use google_cloud_googleapis::spanner::v1::transaction_options::read_only::TimestampBound as InternalTimestampBound;
use google_cloud_googleapis::spanner::v1::transaction_options::ReadOnly;

/// https://cloud.google.com/spanner/docs/storing-numeric-data#precision_of_numeric_types
/// -99999999999999999999999999999.999999999~99999999999999999999999999999.999999999
/// TODO https://github.com/paupino/rust-decimal/issues/135
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SpannerNumeric(String);

impl Default for SpannerNumeric {
fn default() -> Self {
Self::new("0")
}
}

impl SpannerNumeric {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}

pub fn as_str(&self) -> &str {
self.0.as_str()
}
}

#[derive(Clone, PartialEq, Eq)]
pub struct Timestamp {
/// Represents seconds of UTC time since Unix epoch
Expand Down
17 changes: 10 additions & 7 deletions spanner/tests/common.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::str::FromStr;
use time::{Date, OffsetDateTime};

use google_cloud_gax::conn::Environment;
Expand All @@ -11,7 +12,9 @@ use google_cloud_spanner::row::{Error as RowError, Row, Struct, TryFromStruct};
use google_cloud_spanner::session::SessionConfig;
use google_cloud_spanner::statement::Statement;
use google_cloud_spanner::transaction_ro::BatchReadOnlyTransaction;
use google_cloud_spanner::value::{CommitTimestamp, SpannerNumeric};
use google_cloud_spanner::value::CommitTimestamp;

use google_cloud_spanner::bigdecimal::BigDecimal;

pub const DATABASE: &str = "projects/local-project/instances/test-instance/databases/local-database";

Expand Down Expand Up @@ -113,8 +116,8 @@ pub fn create_user_mutation(user_id: &str, now: &OffsetDateTime) -> Mutation {
&None::<bool>,
&vec![1_u8],
&None::<Vec<u8>>,
&SpannerNumeric::new("100.24"),
&Some(SpannerNumeric::new("1000.42342")),
&BigDecimal::from_str("-99999999999999999999999999999.999999999").unwrap(),
&Some(BigDecimal::from_str("99999999999999999999999999999.999999999").unwrap()),
now,
&Some(*now),
&now.date(),
Expand Down Expand Up @@ -165,10 +168,10 @@ pub fn assert_user_row(row: &Row, source_user_id: &str, now: &OffsetDateTime, co
assert_eq!(not_null_byte_array.pop().unwrap(), 1_u8);
let nullable_byte_array = row.column_by_name::<Option<Vec<u8>>>("NullableByteArray").unwrap();
assert_eq!(nullable_byte_array, None);
let not_null_decimal = row.column_by_name::<SpannerNumeric>("NotNullNumeric").unwrap();
assert_eq!(not_null_decimal.as_str(), "100.24");
let nullable_decimal = row.column_by_name::<Option<SpannerNumeric>>("NullableNumeric").unwrap();
assert_eq!(nullable_decimal.unwrap().as_str(), "1000.42342");
let not_null_decimal = row.column_by_name::<BigDecimal>("NotNullNumeric").unwrap();
assert_eq!(not_null_decimal.to_string(), "-99999999999999999999999999999.999999999");
let nullable_decimal = row.column_by_name::<Option<BigDecimal>>("NullableNumeric").unwrap();
assert_eq!(nullable_decimal.unwrap().to_string(), "99999999999999999999999999999.999999999");
let not_null_ts = row.column_by_name::<OffsetDateTime>("NotNullTimestamp").unwrap();
assert_eq!(not_null_ts.to_string(), now.to_string());
let nullable_ts = row
Expand Down
44 changes: 42 additions & 2 deletions spanner/tests/transaction_ro_test.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use bigdecimal::BigDecimal;
use std::collections::HashMap;

use serial_test::serial;
use time::OffsetDateTime;

use common::*;
use google_cloud_spanner::key::Key;
use google_cloud_spanner::reader::AsyncIterator;
use google_cloud_spanner::row::Row;
use google_cloud_spanner::statement::Statement;
use google_cloud_spanner::transaction_ro::ReadOnlyTransaction;
Expand Down Expand Up @@ -100,8 +102,8 @@ async fn test_complex_query() {
let mut tx = data_client.read_only_transaction().await.unwrap();
let mut stmt = Statement::new(
"SELECT *,
ARRAY(SELECT AS STRUCT * FROM UserItem WHERE UserId = p.UserId) as UserItem,
ARRAY(SELECT AS STRUCT * FROM UserCharacter WHERE UserId = p.UserId) as UserCharacter,
ARRAY(SELECT AS STRUCT * FROM UserItem WHERE UserId = p.UserId ORDER BY ItemID) as UserItem,
ARRAY(SELECT AS STRUCT * FROM UserCharacter WHERE UserId = p.UserId ORDER BY CharacterID) as UserCharacter,
FROM User p WHERE UserId = @UserId;
",
);
Expand Down Expand Up @@ -320,3 +322,41 @@ async fn test_read_multi_row() {
.unwrap();
assert_eq!(2, all_rows(row).await.unwrap().len());
}

#[tokio::test]
#[serial]
async fn test_big_decimal() {
let client = create_data_client().await;
let mut tx = client.read_only_transaction().await.unwrap();
let stmt = Statement::new(
"SELECT
cast(\"-99999999999999999999999999999.999999999\" as numeric),
cast(\"-99999999999999999999999999999\" as numeric),
cast(\"-0.999999999\" as numeric),
cast(\"0\" as numeric),
cast(\"0.999999999\" as numeric),
cast(\"99999999999999999999999999999\" as numeric),
cast(\"99999999999999999999999999999.999999999\" as numeric)",
);
let mut iter = tx.query(stmt).await.unwrap();
let row = iter.next().await.unwrap().unwrap();
assert_eq!(
"-99999999999999999999999999999.999999999",
row.column::<BigDecimal>(0).unwrap().to_string()
);
assert_eq!(
"-99999999999999999999999999999",
row.column::<BigDecimal>(1).unwrap().to_string()
);
assert_eq!("-0.999999999", row.column::<BigDecimal>(2).unwrap().to_string());
assert_eq!("0", row.column::<BigDecimal>(3).unwrap().to_string());
assert_eq!("0.999999999", row.column::<BigDecimal>(4).unwrap().to_string());
assert_eq!(
"99999999999999999999999999999",
row.column::<BigDecimal>(5).unwrap().to_string()
);
assert_eq!(
"99999999999999999999999999999.999999999",
row.column::<BigDecimal>(6).unwrap().to_string()
);
}

0 comments on commit 09444a7

Please sign in to comment.