Skip to content

Commit

Permalink
adhere to protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco authored and rtyler committed Mar 25, 2024
1 parent f56d8c9 commit abafd2d
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 18 deletions.
50 changes: 40 additions & 10 deletions crates/core/src/kernel/arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ use arrow_schema::{
};
use lazy_static::lazy_static;

use super::{ActionType, ArrayType, DataType, MapType, PrimitiveType, StructField, StructType};
use super::{
ActionType, ArrayType, DataType, MapType, PrimitiveType, StructField, StructType,
DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE,
};

pub(crate) mod extract;
pub(crate) mod json;
Expand Down Expand Up @@ -118,14 +121,12 @@ impl TryFrom<&DataType> for ArrowDataType {
PrimitiveType::Boolean => Ok(ArrowDataType::Boolean),
PrimitiveType::Binary => Ok(ArrowDataType::Binary),
PrimitiveType::Decimal(precision, scale) => {
if precision <= &38 {
if precision <= &DECIMAL_MAX_PRECISION && scale <= &DECIMAL_MAX_SCALE {
Ok(ArrowDataType::Decimal128(*precision, *scale))
} else if precision <= &76 {
Ok(ArrowDataType::Decimal256(*precision, *scale))
} else {
Err(ArrowError::SchemaError(format!(
"Precision too large to be represented in Arrow: {}",
precision
Err(ArrowError::CastError(format!(
"Precision/scale can not be larger than 38 ({},{})",
precision, scale
)))
}
}
Expand Down Expand Up @@ -214,9 +215,12 @@ impl TryFrom<&ArrowDataType> for DataType {
ArrowDataType::Decimal128(p, s) => {
Ok(DataType::Primitive(PrimitiveType::Decimal(*p, *s)))
}
ArrowDataType::Decimal256(p, s) => {
Ok(DataType::Primitive(PrimitiveType::Decimal(*p, *s)))
}
ArrowDataType::Decimal256(p, s) => DataType::decimal(*p, *s).map_err(|_| {
ArrowError::SchemaError(format!(
"Invalid data type for Delta Lake: decimal({},{})",
p, s
))
}),
ArrowDataType::Date32 => Ok(DataType::Primitive(PrimitiveType::Date)),
ArrowDataType::Date64 => Ok(DataType::Primitive(PrimitiveType::Date)),
ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => {
Expand Down Expand Up @@ -773,6 +777,32 @@ mod tests {
);
}

/// Converting an Arrow `Decimal256` whose precision is above the supported
/// maximum into a Delta `DataType` must fail.
#[test]
fn test_arrow_from_delta_decimal_type_invalid_precision() {
    // 39 is one past the largest precision the conversion accepts (38).
    let precision = 39;
    let scale = 2;
    let converted = <DataType as TryFrom<&ArrowDataType>>::try_from(
        &ArrowDataType::Decimal256(precision, scale),
    );
    // Assert on the `Result` directly: the previous `matches!(.., _)` form used a
    // wildcard pattern that is always true, so it asserted nothing by itself.
    assert!(
        converted.is_err(),
        "decimal({},{}) should be rejected",
        precision,
        scale
    );
}

/// Converting an Arrow `Decimal256` whose scale is above the supported
/// maximum into a Delta `DataType` must fail.
#[test]
fn test_arrow_from_delta_decimal_type_invalid_scale() {
    // The precision is valid; only the scale (39) is past the accepted maximum (38).
    let precision = 2;
    let scale = 39;
    let converted = <DataType as TryFrom<&ArrowDataType>>::try_from(
        &ArrowDataType::Decimal256(precision, scale),
    );
    // Assert on the `Result` directly: the previous `matches!(.., _)` form used a
    // wildcard pattern that is always true, so it asserted nothing by itself.
    assert!(
        converted.is_err(),
        "decimal({},{}) should be rejected",
        precision,
        scale
    );
}

#[test]
fn test_arrow_from_delta_timestamp_type() {
let timestamp_field = DataType::Primitive(PrimitiveType::Timestamp);
Expand Down
3 changes: 2 additions & 1 deletion crates/core/src/kernel/expressions/scalars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ impl Scalar {
Self::TimestampNtz(_) => DataType::Primitive(PrimitiveType::TimestampNtz),
Self::Date(_) => DataType::Primitive(PrimitiveType::Date),
Self::Binary(_) => DataType::Primitive(PrimitiveType::Binary),
Self::Decimal(_, precision, scale) => DataType::decimal(*precision, *scale),
// Unwrapping should be safe, since the scalar should never have values that are unsupported
Self::Decimal(_, precision, scale) => DataType::decimal(*precision, *scale).unwrap(),
Self::Null(data_type) => data_type.clone(),
Self::Struct(_, fields) => DataType::struct_type(fields.clone()),
}
Expand Down
55 changes: 52 additions & 3 deletions crates/core/src/kernel/models/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use serde_json::Value;

use crate::kernel::error::Error;
use crate::kernel::DataCheck;
use crate::protocol::ProtocolError;

/// Type alias for a top level schema
pub type Schema = StructType;
Expand Down Expand Up @@ -467,6 +468,12 @@ fn default_true() -> bool {
true
}

/// The maximum precision for [PrimitiveType::Decimal] values.
///
/// The bound is inclusive: the validation in this file only rejects a
/// precision that is strictly greater than this constant.
pub const DECIMAL_MAX_PRECISION: u8 = 38;

/// The maximum scale for [PrimitiveType::Decimal] values.
///
/// The bound is inclusive: the validation in this file only rejects a
/// scale that is strictly greater than this constant.
pub const DECIMAL_MAX_SCALE: i8 = 38;

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)]
#[serde(rename_all = "snake_case")]
/// Primitive types supported by Delta
Expand Down Expand Up @@ -538,7 +545,12 @@ where
.ok_or_else(|| {
serde::de::Error::custom(format!("Invalid scale in decimal: {}", str_value))
})?;

if precision > DECIMAL_MAX_PRECISION || scale > DECIMAL_MAX_SCALE {
return Err(serde::de::Error::custom(format!(
"Precision or scale is larger than 38: {}, {}",
precision, scale
)));
}
Ok((precision, scale))
}

Expand Down Expand Up @@ -613,8 +625,16 @@ impl DataType {
pub const TIMESTAMP: Self = DataType::Primitive(PrimitiveType::Timestamp);
pub const TIMESTAMPNTZ: Self = DataType::Primitive(PrimitiveType::TimestampNtz);

pub fn decimal(precision: u8, scale: i8) -> Self {
DataType::Primitive(PrimitiveType::Decimal(precision, scale))
/// Creates a decimal [`DataType`] with the given `precision` and `scale`.
///
/// # Errors
///
/// Returns [`ProtocolError::InvalidField`] carrying the text
/// `decimal(<precision>,<scale>)` when `precision` is greater than
/// [`DECIMAL_MAX_PRECISION`] or `scale` is greater than [`DECIMAL_MAX_SCALE`].
pub fn decimal(precision: u8, scale: i8) -> Result<Self, ProtocolError> {
    // Both limits are inclusive upper bounds.
    let within_limits = precision <= DECIMAL_MAX_PRECISION && scale <= DECIMAL_MAX_SCALE;
    if within_limits {
        Ok(DataType::Primitive(PrimitiveType::Decimal(precision, scale)))
    } else {
        Err(ProtocolError::InvalidField(format!(
            "decimal({},{})",
            precision, scale
        )))
    }
}

pub fn struct_type(fields: Vec<StructField>) -> Self {
Expand Down Expand Up @@ -749,6 +769,35 @@ mod tests {
);
}

/// Deserializing a struct field whose decimal type has a precision or a scale
/// above the supported maximum must fail.
#[test]
fn test_invalid_decimal() {
    // Precision (39) is above the accepted maximum.
    let data = r#"
{
"name": "a",
"type": "decimal(39, 10)",
"nullable": false,
"metadata": {}
}
"#;
    // Assert on the `Result` directly: the previous `matches!(.., _)` form used a
    // wildcard pattern that is always true, so it asserted nothing by itself.
    assert!(
        serde_json::from_str::<StructField>(data).is_err(),
        "decimal(39, 10) should be rejected"
    );

    // Scale (39) is above the accepted maximum.
    let data = r#"
{
"name": "a",
"type": "decimal(10, 39)",
"nullable": false,
"metadata": {}
}
"#;
    assert!(
        serde_json::from_str::<StructField>(data).is_err(),
        "decimal(10, 39) should be rejected"
    );
}

#[test]
fn test_field_metadata() {
let data = r#"
Expand Down
2 changes: 1 addition & 1 deletion python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class PrimitiveType:
* "date",
* "timestamp",
* "timestampNtz",
* "decimal(<precision>, <scale>)"
* "decimal(<precision>, <scale>)" Max: decimal(38,38)
Args:
data_type: string representation of the data type
Expand Down
12 changes: 10 additions & 2 deletions python/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,16 @@ impl PrimitiveType {
#[new]
#[pyo3(signature = (data_type))]
fn new(data_type: String) -> PyResult<Self> {
let data_type: DeltaPrimitve = serde_json::from_str(&format!("\"{data_type}\""))
.map_err(|_| PyValueError::new_err(format!("invalid type string: {data_type}")))?;
let data_type: DeltaPrimitve =
serde_json::from_str(&format!("\"{data_type}\"")).map_err(|_| {
if data_type.starts_with("decimal") {
PyValueError::new_err(format!(
"invalid type string: {data_type}, precision/scale can't be larger than 38"
))
} else {
PyValueError::new_err(format!("invalid type string: {data_type}"))
}
})?;

Ok(Self {
inner_type: data_type,
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_primitive_delta_types():
"decimal(10,2)",
]

invalid_types = ["int", "decimal", "decimal()"]
invalid_types = ["int", "decimal", "decimal()", "decimal(39,1)", "decimal(1,39)"]

for data_type in valid_types:
delta_type = PrimitiveType(data_type)
Expand Down
16 changes: 16 additions & 0 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,22 @@ def test_issue_1651_roundtrip_timestamp(tmp_path: pathlib.Path):
assert dataset.count_rows() == 1


@pytest.mark.parametrize("engine", ["rust", "pyarrow"])
def test_invalid_decimals(tmp_path: pathlib.Path, engine):
    """Writing a decimal column that is too wide must raise SchemaMismatchError.

    Runs once per writer engine; both are expected to report the same
    "Invalid data type for Delta Lake" message.
    """
    import re
    from decimal import Decimal

    # Per the asserted message below, this value is expected to be typed as decimal(39,1).
    oversized_value = Decimal("10000000000000000000000000000000000000.0")
    data = pa.table({"x": pa.array([oversized_value])})

    # The message contains parentheses, so escape it before using it as a regex.
    expected_message = re.escape("Invalid data type for Delta Lake: decimal(39,1)")

    with pytest.raises(SchemaMismatchError, match=expected_message):
        write_deltalake(table_or_uri=tmp_path, mode="append", data=data, engine=engine)


def test_float_values(tmp_path: pathlib.Path):
data = pa.table(
{
Expand Down

0 comments on commit abafd2d

Please sign in to comment.