Skip to content

Commit

Permalink
updated unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: JustinRush80 <[email protected]>
  • Loading branch information
JustinRush80 committed Jan 23, 2025
1 parent 07e105c commit 7c88080
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 6 deletions.
115 changes: 111 additions & 4 deletions crates/core/src/operations/merge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1661,6 +1661,7 @@ mod tests {
use datafusion_expr::expr::Placeholder;
use datafusion_expr::lit;
use datafusion_expr::Expr;
use delta_kernel::schema::StructType;
use itertools::Itertools;
use regex::Regex;
use serde_json::json;
Expand Down Expand Up @@ -1831,7 +1832,7 @@ mod tests {
assert_merge(table, metrics).await;
}
#[tokio::test]
async fn test_merge_with_schema_mode_no_change_of_schema() {
async fn test_merge_with_schema_merge_no_change_of_schema() {
let (table, _) = setup().await;

let schema = Arc::new(ArrowSchema::new(vec![
Expand Down Expand Up @@ -1921,6 +1922,97 @@ mod tests {
assert_merge(after_table, metrics).await;
}

/// Merges a source batch that carries an extra `nested` struct column into the
/// target table with schema merging enabled, inserting unmatched rows.
///
/// Verifies two things:
/// * the commit entry read back at version 2 contains a `Metadata` action, and
/// * the resulting data shows the inserted row `X` with `nested = {count: 1}`
///   while every other row has an empty `nested` value.
#[tokio::test]
async fn test_merge_with_schema_merge_and_struct() {
    let (table, _) = setup().await;

    // Layout of the struct column that exists only in the source.
    let nested_schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "count",
        ArrowDataType::Int64,
        true,
    )]));

    // Source schema: the three scalar columns plus the `nested` struct column.
    let schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", ArrowDataType::Utf8, true),
        Field::new("value", ArrowDataType::Int32, true),
        Field::new("modified", ArrowDataType::Utf8, true),
        Field::new(
            "nested",
            ArrowDataType::Struct(nested_schema.fields().clone()),
            true,
        ),
    ]));

    // A single source row: id "X" with a populated struct value.
    let count_array = arrow::array::Int64Array::from(vec![Some(1)]);
    let id_array = arrow::array::StringArray::from(vec![Some("X")]);
    let value_array = arrow::array::Int32Array::from(vec![Some(1)]);
    let modified_array = arrow::array::StringArray::from(vec![Some("2021-02-02")]);

    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id_array),
            Arc::new(value_array),
            Arc::new(modified_array),
            Arc::new(arrow::array::StructArray::from(
                RecordBatch::try_new(nested_schema, vec![Arc::new(count_array)]).unwrap(),
            )),
        ],
    )
    .unwrap();

    let ctx = SessionContext::new();
    let source = ctx.read_batch(batch).unwrap();

    // `table` is shadowed by the merge result, so it can be moved into
    // `DeltaOps` directly instead of being cloned.
    let (table, _) = DeltaOps(table)
        .merge(source, col("target.id").eq(col("source.id")))
        .with_source_alias("source")
        .with_target_alias("target")
        .with_merge_schema(true)
        .when_not_matched_insert(|insert| {
            insert
                .set("id", col("source.id"))
                .set("value", col("source.value"))
                .set("modified", col("source.modified"))
                .set("nested", col("source.nested"))
        })
        .unwrap()
        .await
        .unwrap();

    // NOTE(review): assumes `setup()` leaves the table at version 1 so that the
    // merge above is commit 2 — confirm against `setup()`.
    let commit_bytes = table
        .log_store
        .read_commit_entry(2)
        .await
        .unwrap()
        .expect("failed to get snapshot bytes");
    let actions = crate::logstore::get_actions(2, commit_bytes)
        .await
        .unwrap();

    let has_metadata_action = actions
        .iter()
        .any(|action| matches!(action, Action::Metadata(_)));
    assert!(
        has_metadata_action,
        "merge commit is expected to contain a Metadata action"
    );

    // Cell padding must match the column widths given by the border lines,
    // otherwise the line-by-line comparison below cannot succeed.
    let expected = vec![
        "+----+-------+------------+------------+",
        "| id | value | modified   | nested     |",
        "+----+-------+------------+------------+",
        "| A  | 1     | 2021-02-01 |            |",
        "| B  | 10    | 2021-02-01 |            |",
        "| C  | 10    | 2021-02-02 |            |",
        "| D  | 100   | 2021-02-02 |            |",
        "| X  | 1     | 2021-02-02 | {count: 1} |",
        "+----+-------+------------+------------+",
    ];
    let actual = get_data(&table).await;

    assert_batches_sorted_eq!(&expected, &actual);
}

#[tokio::test]
async fn test_merge_schema_evolution_simple_update() {
let (table, _) = setup().await;
Expand Down Expand Up @@ -1964,7 +2056,7 @@ mod tests {
.unwrap();

let commit_info = table.history(None).await.unwrap();
dbg!(&commit_info);

let last_commit = &commit_info[0];
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
Expand All @@ -1979,6 +2071,8 @@ mod tests {
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = Arc::clone(&schema).try_into().unwrap();
assert_eq!(&expected_schema_struct, table.schema().unwrap());
assert_batches_sorted_eq!(&expected, &actual);
}

Expand Down Expand Up @@ -2045,6 +2139,8 @@ mod tests {
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = Arc::clone(&schema).try_into().unwrap();
assert_eq!(&expected_schema_struct, table.schema().unwrap());
assert_batches_sorted_eq!(&expected, &actual);
}

Expand Down Expand Up @@ -3050,7 +3146,7 @@ mod tests {
}

#[tokio::test]
async fn test_empty_table_schema_evo_merge() {
async fn test_empty_table_with_schema_merge() {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Expand Down Expand Up @@ -3130,6 +3226,8 @@ mod tests {
"+----+-------+-------------+------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = Arc::clone(&schema).try_into().unwrap();
assert_eq!(&expected_schema_struct, table.schema().unwrap());
assert_batches_sorted_eq!(&expected, &actual);
}

Expand Down Expand Up @@ -3903,6 +4001,13 @@ mod tests {
assert_eq!(table.version(), 0);

let schema = get_arrow_schema(&None);

let source_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", ArrowDataType::Utf8, true),
Field::new("value", ArrowDataType::Int32, true),
Field::new("modified", ArrowDataType::Utf8, true),
Field::new("inserted_by", ArrowDataType::Utf8, true),
]));
let table = write_data(table, &schema).await;

assert_eq!(table.version(), 1);
Expand All @@ -3911,7 +4016,7 @@ mod tests {
let source = source.with_column("inserted_by", lit("new_value")).unwrap();

let (table, _) = DeltaOps(table)
.merge(source, col("target.id").eq(col("source.id")))
.merge(source.clone(), col("target.id").eq(col("source.id")))
.with_source_alias("source")
.with_target_alias("target")
.with_merge_schema(true)
Expand Down Expand Up @@ -3950,6 +4055,8 @@ mod tests {
"+----+-------+------------+-------------+",
];
let actual = get_data(&table).await;
let expected_schema_struct: StructType = source_schema.try_into().unwrap();
assert_eq!(&expected_schema_struct, table.schema().unwrap());
assert_batches_sorted_eq!(&expected, &actual);

let ctx = SessionContext::new();
Expand Down
2 changes: 0 additions & 2 deletions python/src/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ use deltalake::datafusion::datasource::MemTable;
use deltalake::datafusion::prelude::SessionContext;
use deltalake::logstore::LogStoreRef;
use deltalake::operations::merge::MergeBuilder;
use deltalake::operations::write::SchemaMode;
use deltalake::operations::CustomExecuteHandler;
use deltalake::table::state::DeltaTableState;
use deltalake::{DeltaResult, DeltaTable};
use pyo3::prelude::*;
use std::collections::HashMap;
use std::future::IntoFuture;
use std::str::FromStr;
use std::sync::Arc;

use crate::error::PythonError;
Expand Down
16 changes: 16 additions & 0 deletions python/tests/test_generated_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,22 @@ def test_merge_with_gc(table_with_gc: DeltaTable, data_without_gc):
assert table_with_gc.to_pyarrow_table() == expected_data


def test_merge_with_g_during_schema_evolution(
    table_with_gc: DeltaTable, data_without_gc
):
    """Merge `data_without_gc` into `table_with_gc` with `merge_schema=True`.

    Unmatched source rows are inserted via `when_not_matched_insert_all()`.
    The resulting table is expected to hold `id = [1, 2]` and `gc = [5, 5]`,
    both as int32 columns.
    """
    # NOTE(review): presumably `gc` is the generated column supplied by the
    # table fixture rather than by the source data — confirm against fixtures.
    merge_builder = table_with_gc.merge(
        data_without_gc,
        predicate="s.id = t.id",
        source_alias="s",
        target_alias="t",
        merge_schema=True,
    )
    merge_builder.when_not_matched_insert_all().execute()

    expected_schema = pa.schema(
        [
            pa.field("id", pa.int32()),
            pa.field("gc", pa.int32()),
        ]
    )
    expected_data = pa.Table.from_pydict(
        {"id": [1, 2], "gc": [5, 5]}, schema=expected_schema
    )
    assert table_with_gc.to_pyarrow_table() == expected_data


def test_merge_with_gc_invalid(table_with_gc: DeltaTable, invalid_gc_data):
import re

Expand Down
3 changes: 3 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def test_merge_when_matched_update_wo_predicate_with_schema_evolution(
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result.schema == expected.schema
assert result == expected


Expand Down Expand Up @@ -428,6 +429,7 @@ def test_merge_when_not_matched_insert_with_predicate_schema_evolution(
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result.schema == expected.schema
assert result == expected


Expand Down Expand Up @@ -547,6 +549,7 @@ def test_merge_when_not_matched_insert_all_with_exclude_and_with_schema_evo(
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result.schema == expected.schema
assert result == expected


Expand Down

0 comments on commit 7c88080

Please sign in to comment.