From abafd2d0cb8dde32ffa990dc30fb97a5581688ec Mon Sep 17 00:00:00 2001 From: ion-elgreco <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 24 Mar 2024 17:03:27 +0100 Subject: [PATCH] adhere to protocol --- crates/core/src/kernel/arrow/mod.rs | 50 +++++++++++++---- crates/core/src/kernel/expressions/scalars.rs | 3 +- crates/core/src/kernel/models/schema.rs | 55 ++++++++++++++++++- python/deltalake/_internal.pyi | 2 +- python/src/schema.rs | 12 +++- python/tests/test_schema.py | 2 +- python/tests/test_writer.py | 16 ++++++ 7 files changed, 122 insertions(+), 18 deletions(-) diff --git a/crates/core/src/kernel/arrow/mod.rs b/crates/core/src/kernel/arrow/mod.rs index 45d6432e1d..648ad16bbc 100644 --- a/crates/core/src/kernel/arrow/mod.rs +++ b/crates/core/src/kernel/arrow/mod.rs @@ -8,7 +8,10 @@ use arrow_schema::{ }; use lazy_static::lazy_static; -use super::{ActionType, ArrayType, DataType, MapType, PrimitiveType, StructField, StructType}; +use super::{ + ActionType, ArrayType, DataType, MapType, PrimitiveType, StructField, StructType, + DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE, +}; pub(crate) mod extract; pub(crate) mod json; @@ -118,14 +121,12 @@ impl TryFrom<&DataType> for ArrowDataType { PrimitiveType::Boolean => Ok(ArrowDataType::Boolean), PrimitiveType::Binary => Ok(ArrowDataType::Binary), PrimitiveType::Decimal(precision, scale) => { - if precision <= &38 { + if precision <= &DECIMAL_MAX_PRECISION && scale <= &DECIMAL_MAX_SCALE { Ok(ArrowDataType::Decimal128(*precision, *scale)) - } else if precision <= &76 { - Ok(ArrowDataType::Decimal256(*precision, *scale)) } else { - Err(ArrowError::SchemaError(format!( - "Precision too large to be represented in Arrow: {}", - precision + Err(ArrowError::CastError(format!( + "Precision/scale can not be larger than 38 ({},{})", + precision, scale ))) } } @@ -214,9 +215,12 @@ impl TryFrom<&ArrowDataType> for DataType { ArrowDataType::Decimal128(p, s) => { Ok(DataType::Primitive(PrimitiveType::Decimal(*p, *s))) } - ArrowDataType::Decimal256(p, s) => { - Ok(DataType::Primitive(PrimitiveType::Decimal(*p, *s))) - } + ArrowDataType::Decimal256(p, s) => DataType::decimal(*p, *s).map_err(|_| { + ArrowError::SchemaError(format!( + "Invalid data type for Delta Lake: decimal({},{})", + p, s + )) + }), ArrowDataType::Date32 => Ok(DataType::Primitive(PrimitiveType::Date)), ArrowDataType::Date64 => Ok(DataType::Primitive(PrimitiveType::Date)), ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => { @@ -773,6 +777,32 @@ mod tests { ); } + #[test] + fn test_arrow_from_delta_decimal_type_invalid_precision() { + let precision = 39; + let scale = 2; + assert!(matches!( + >::try_from(&ArrowDataType::Decimal256( + precision, scale + )) + .unwrap_err(), + _ + )); + } + + #[test] + fn test_arrow_from_delta_decimal_type_invalid_scale() { + let precision = 2; + let scale = 39; + assert!(matches!( + >::try_from(&ArrowDataType::Decimal256( + precision, scale + )) + .unwrap_err(), + _ + )); + } + #[test] fn test_arrow_from_delta_timestamp_type() { let timestamp_field = DataType::Primitive(PrimitiveType::Timestamp); diff --git a/crates/core/src/kernel/expressions/scalars.rs b/crates/core/src/kernel/expressions/scalars.rs index d29cccb022..eccc16e1a4 100644 --- a/crates/core/src/kernel/expressions/scalars.rs +++ b/crates/core/src/kernel/expressions/scalars.rs @@ -63,7 +63,8 @@ impl Scalar { Self::TimestampNtz(_) => DataType::Primitive(PrimitiveType::TimestampNtz), Self::Date(_) => DataType::Primitive(PrimitiveType::Date), Self::Binary(_) => DataType::Primitive(PrimitiveType::Binary), - Self::Decimal(_, precision, scale) => DataType::decimal(*precision, *scale), + // Unwrapping should be safe, since the scalar should never have values that are unsupported + Self::Decimal(_, precision, scale) => DataType::decimal(*precision, *scale).unwrap(), Self::Null(data_type) => data_type.clone(), Self::Struct(_, fields) => DataType::struct_type(fields.clone()), } diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 84e5967f12..7415bea970 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -11,6 +11,7 @@ use serde_json::Value; use crate::kernel::error::Error; use crate::kernel::DataCheck; +use crate::protocol::ProtocolError; /// Type alias for a top level schema pub type Schema = StructType; @@ -467,6 +468,12 @@ fn default_true() -> bool { true } +/// The maximum precision for [PrimitiveType::Decimal] values +pub const DECIMAL_MAX_PRECISION: u8 = 38; + +/// The maximum scale for [PrimitiveType::Decimal] values +pub const DECIMAL_MAX_SCALE: i8 = 38; + #[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] #[serde(rename_all = "snake_case")] /// Primitive types supported by Delta @@ -538,7 +545,12 @@ where .ok_or_else(|| { serde::de::Error::custom(format!("Invalid scale in decimal: {}", str_value)) })?; - + if precision > DECIMAL_MAX_PRECISION || scale > DECIMAL_MAX_SCALE { + return Err(serde::de::Error::custom(format!( + "Precision or scale is larger than 38: {}, {}", + precision, scale + ))); + } Ok((precision, scale)) } @@ -613,8 +625,16 @@ impl DataType { pub const TIMESTAMP: Self = DataType::Primitive(PrimitiveType::Timestamp); pub const TIMESTAMPNTZ: Self = DataType::Primitive(PrimitiveType::TimestampNtz); - pub fn decimal(precision: u8, scale: i8) -> Self { - DataType::Primitive(PrimitiveType::Decimal(precision, scale)) + pub fn decimal(precision: u8, scale: i8) -> Result { + if precision > DECIMAL_MAX_PRECISION || scale > DECIMAL_MAX_SCALE { + return Err(ProtocolError::InvalidField(format!( + "decimal({},{})", + precision, scale + ))); + } + Ok(DataType::Primitive(PrimitiveType::Decimal( + precision, scale, + ))) } pub fn struct_type(fields: Vec) -> Self { @@ -749,6 +769,35 @@ mod tests { ); } + #[test] + fn test_invalid_decimal() { + let data = r#" + { + "name": "a", + "type": "decimal(39, 10)", + "nullable": false, + "metadata": {} + } + "#; + assert!(matches!( + serde_json::from_str::(data).unwrap_err(), + _ + )); + + let data = r#" + { + "name": "a", + "type": "decimal(10, 39)", + "nullable": false, + "metadata": {} + } + "#; + assert!(matches!( + serde_json::from_str::(data).unwrap_err(), + _ + )); + } + #[test] fn test_field_metadata() { let data = r#" diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index b16d468571..5f2f4634b2 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -233,7 +233,7 @@ class PrimitiveType: * "date", * "timestamp", * "timestampNtz", - * "decimal(, )" + * "decimal(, )" Max: decimal(38,38) Args: data_type: string representation of the data type diff --git a/python/src/schema.rs b/python/src/schema.rs index 0d0823cbff..edd0c9fecb 100644 --- a/python/src/schema.rs +++ b/python/src/schema.rs @@ -84,8 +84,16 @@ impl PrimitiveType { #[new] #[pyo3(signature = (data_type))] fn new(data_type: String) -> PyResult { - let data_type: DeltaPrimitve = serde_json::from_str(&format!("\"{data_type}\"")) - .map_err(|_| PyValueError::new_err(format!("invalid type string: {data_type}")))?; + let data_type: DeltaPrimitve = + serde_json::from_str(&format!("\"{data_type}\"")).map_err(|_| { + if data_type.starts_with("decimal") { + PyValueError::new_err(format!( + "invalid type string: {data_type}, precision/scale can't be larger than 38" + )) + } else { + PyValueError::new_err(format!("invalid type string: {data_type}")) + } + })?; Ok(Self { inner_type: data_type, diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 4d70c720dd..23198d9ef3 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -76,7 +76,7 @@ def test_primitive_delta_types(): "decimal(10,2)", ] - invalid_types = ["int", "decimal", "decimal()"] + invalid_types = ["int", "decimal", "decimal()", "decimal(39,1)", "decimal(1,39)"] for data_type in valid_types: delta_type = PrimitiveType(data_type) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 96903f0824..9baab32d9a 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -1391,6 +1391,22 @@ def test_issue_1651_roundtrip_timestamp(tmp_path: pathlib.Path): assert dataset.count_rows() == 1 +@pytest.mark.parametrize("engine", ["rust", "pyarrow"]) +def test_invalid_decimals(tmp_path: pathlib.Path, engine): + import re + from decimal import Decimal + + data = pa.table( + {"x": pa.array([Decimal("10000000000000000000000000000000000000.0")])} + ) + + with pytest.raises( + SchemaMismatchError, + match=re.escape("Invalid data type for Delta Lake: decimal(39,1)"), + ): + write_deltalake(table_or_uri=tmp_path, mode="append", data=data, engine=engine) + + def test_float_values(tmp_path: pathlib.Path): data = pa.table( {