Skip to content

Commit

Permalink
fix(python, rust): check timestamp_ntz in nested fields, add check_ca…
Browse files Browse the repository at this point in the history
…n_writestamp_ntz in pyarrow writer (#2443)

# Description
The nested fields weren't checked, which meant you could get a
timestampNtz in your schema but not have the reader/writer features set.
This check is now done recursively.
  • Loading branch information
ion-elgreco authored Apr 23, 2024
1 parent da6ed7b commit 12979dd
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 14 deletions.
8 changes: 2 additions & 6 deletions crates/core/src/operations/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,12 @@ impl CreateBuilder {
)
};

let contains_timestampntz = &self
.columns
.iter()
.any(|f| f.data_type() == &DataType::TIMESTAMPNTZ);

let contains_timestampntz = PROTOCOL.contains_timestampntz(&self.columns);
// TODO configure more permissive versions based on configuration. Also how should this ideally be handled?
// We set the lowest protocol we can, and if subsequent writes use newer features we update metadata?

let (min_reader_version, min_writer_version, writer_features, reader_features) =
if *contains_timestampntz {
if contains_timestampntz {
let mut converted_writer_features = self
.configuration
.keys()
Expand Down
27 changes: 21 additions & 6 deletions crates/core/src/operations/transaction/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use lazy_static::lazy_static;
use once_cell::sync::Lazy;

use super::{TableReference, TransactionError};
use crate::kernel::{Action, DataType, EagerSnapshot, ReaderFeatures, Schema, WriterFeatures};
use crate::kernel::{
Action, DataType, EagerSnapshot, ReaderFeatures, Schema, StructField, WriterFeatures,
};
use crate::protocol::DeltaOperation;
use crate::table::state::DeltaTableState;

Expand Down Expand Up @@ -77,17 +79,30 @@ impl ProtocolChecker {
Ok(())
}

/// checks if table contains timestamp_ntz in any field including nested fields.
pub fn contains_timestampntz(&self, fields: &Vec<StructField>) -> bool {
fn check_vec_fields(fields: &Vec<StructField>) -> bool {
fields.iter().any(|f| _check_type(f.data_type()))
}

fn _check_type(dtype: &DataType) -> bool {
match dtype {
&DataType::TIMESTAMPNTZ => true,
DataType::Array(inner) => _check_type(inner.element_type()),
DataType::Struct(inner) => check_vec_fields(inner.fields()),
_ => false,
}
}
check_vec_fields(fields)
}

/// Check can write_timestamp_ntz
pub fn check_can_write_timestamp_ntz(
&self,
snapshot: &DeltaTableState,
schema: &Schema,
) -> Result<(), TransactionError> {
let contains_timestampntz = schema
.fields()
.iter()
.any(|f| f.data_type() == &DataType::TIMESTAMPNTZ);

let contains_timestampntz = self.contains_timestampntz(schema.fields());
let required_features: Option<&HashSet<WriterFeatures>> =
match snapshot.protocol().min_writer_version {
0..=6 => None,
Expand Down
2 changes: 1 addition & 1 deletion python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "deltalake-python"
version = "0.17.0"
version = "0.17.1"
authors = ["Qingping Hou <[email protected]>", "Will Jones <[email protected]>"]
homepage = "https://github.com/delta-io/delta-rs"
license = "Apache-2.0"
Expand Down
1 change: 1 addition & 0 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class RawDeltaTable:
custom_metadata: Optional[Dict[str, str]],
) -> None: ...
def cleanup_metadata(self) -> None: ...
def check_can_write_timestamp_ntz(self, schema: pyarrow.Schema) -> None: ...

def rust_core_version() -> str: ...
def write_new_deltalake(
Expand Down
1 change: 1 addition & 0 deletions python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def visitor(written_file: Any) -> None:
# We don't currently provide a way to set invariants
# (and maybe never will), so only enforce if already exist.
table_protocol = table.protocol()
table._table.check_can_write_timestamp_ntz(schema)
if (
table_protocol.min_writer_version > MAX_SUPPORTED_PYARROW_WRITER_VERSION
or table_protocol.min_writer_version
Expand Down
13 changes: 12 additions & 1 deletion python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use deltalake::operations::filesystem_check::FileSystemCheckBuilder;
use deltalake::operations::merge::MergeBuilder;
use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType};
use deltalake::operations::restore::RestoreBuilder;
use deltalake::operations::transaction::{CommitBuilder, CommitProperties};
use deltalake::operations::transaction::{CommitBuilder, CommitProperties, PROTOCOL};
use deltalake::operations::update::UpdateBuilder;
use deltalake::operations::vacuum::VacuumBuilder;
use deltalake::parquet::basic::Compression;
Expand Down Expand Up @@ -175,6 +175,17 @@ impl RawDeltaTable {
))
}

pub fn check_can_write_timestamp_ntz(&self, schema: PyArrowType<ArrowSchema>) -> PyResult<()> {
let schema: StructType = (&schema.0).try_into().map_err(PythonError::from)?;
Ok(PROTOCOL
.check_can_write_timestamp_ntz(
self._table.snapshot().map_err(PythonError::from)?,
&schema,
)
.map_err(|e| DeltaTableError::Generic(e.to_string()))
.map_err(PythonError::from)?)
}

pub fn load_version(&mut self, version: i64) -> PyResult<()> {
Ok(rt()
.block_on(self._table.load_version(version))
Expand Down
42 changes: 42 additions & 0 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,3 +1528,45 @@ def test_rust_decimal_cast(tmp_path: pathlib.Path):
write_deltalake(
tmp_path, data, mode="append", schema_mode="merge", engine="rust"
)


@pytest.mark.parametrize(
"array",
[
pa.array([[datetime(2010, 1, 1)]]),
pa.array([{"foo": datetime(2010, 1, 1)}]),
pa.array([{"foo": [[datetime(2010, 1, 1)]]}]),
pa.array([{"foo": [[{"foo": datetime(2010, 1, 1)}]]}]),
],
)
def test_write_timestamp_ntz_nested(tmp_path: pathlib.Path, array: pa.array):
data = pa.table({"x": array})
write_deltalake(tmp_path, data, mode="append", engine="rust")

dt = DeltaTable(tmp_path)

protocol = dt.protocol()
assert protocol.min_reader_version == 3
assert protocol.min_writer_version == 7
assert protocol.reader_features == ["timestampNtz"]
assert protocol.writer_features == ["timestampNtz"]


def test_write_timestamp_ntz_on_table_with_features_not_enabled(tmp_path: pathlib.Path):
data = pa.table({"x": pa.array(["foo"])})
write_deltalake(tmp_path, data, mode="append", engine="pyarrow")

dt = DeltaTable(tmp_path)

protocol = dt.protocol()
assert protocol.min_reader_version == 1
assert protocol.min_writer_version == 2

data = pa.table({"x": pa.array([datetime(2010, 1, 1)])})
with pytest.raises(
DeltaError,
match="Generic DeltaTable error: Writer features must be specified for writerversion >= 7, please specify: TimestampWithoutTimezone",
):
write_deltalake(
tmp_path, data, mode="overwrite", engine="pyarrow", schema_mode="overwrite"
)

0 comments on commit 12979dd

Please sign in to comment.