diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs
index e58fd22664..6afa90fdd9 100644
--- a/crates/core/src/operations/merge/mod.rs
+++ b/crates/core/src/operations/merge/mod.rs
@@ -30,9 +30,11 @@
 use std::collections::HashMap;
 use std::fmt::Debug;
+use std::ops::Deref;
 use std::sync::Arc;
 use std::time::Instant;
 
+use arrow_schema::SchemaBuilder;
 use async_trait::async_trait;
 use datafusion::datasource::provider_as_source;
 use datafusion::error::Result as DataFusionResult;
@@ -75,12 +77,16 @@
 use crate::delta_datafusion::{
     register_store, DataFusionMixins, DeltaColumn, DeltaScan, DeltaScanConfigBuilder,
     DeltaSessionConfig, DeltaTableProvider,
 };
-use crate::kernel::{Action, DataCheck, StructTypeExt};
+
+use crate::kernel::{Action, DataCheck, Metadata, StructTypeExt};
 use crate::logstore::LogStoreRef;
+use crate::operations::cast::merge_schema::merge_arrow_schema;
 use crate::operations::cdc::*;
 use crate::operations::merge::barrier::find_node;
 use crate::operations::transaction::CommitBuilder;
-use crate::operations::write::{write_execution_plan, write_execution_plan_cdc, WriterStatsConfig};
+use crate::operations::write::{
+    write_execution_plan, write_execution_plan_cdc, SchemaMode, WriterStatsConfig,
+};
 use crate::protocol::{DeltaOperation, MergePredicate};
 use crate::table::state::DeltaTableState;
 use crate::table::GeneratedColumn;
@@ -128,6 +134,8 @@ pub struct MergeBuilder {
     snapshot: DeltaTableState,
     /// The source data
     source: DataFrame,
+    /// Schema evolution mode; only `SchemaMode::Merge` is supported for MERGE
+    schema_mode: Option<SchemaMode>,
     /// Delta object store for handling data files
     log_store: LogStoreRef,
     /// Datafusion session state relevant for executing the input plan
@@ -170,6 +178,7 @@ impl MergeBuilder {
             state: None,
             commit_properties: CommitProperties::default(),
             writer_properties: None,
+            schema_mode: None,
             match_operations: Vec::new(),
             not_match_operations: Vec::new(),
             not_match_source_operations: Vec::new(),
@@ -362,6 +371,11 @@ impl MergeBuilder {
         self.target_alias = Some(alias.to_string());
         self
     }
+    /// Set the schema write mode; `SchemaMode::Merge` enables schema evolution
+    pub fn with_schema_mode(mut self, schema_mode: SchemaMode) -> Self {
+        self.schema_mode = Some(schema_mode);
+        self
+    }
 
     /// The Datafusion session state to use
     pub fn with_session_state(mut self, state: SessionState) -> Self {
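Reviewer note: a minimal usage sketch of the new builder option, mirroring the tests further down in this patch (the table and column names are hypothetical):

```rust
// `table` is an existing DeltaTable; `source` is a DataFusion DataFrame whose
// schema carries an extra `inserted_by` column not present in the target.
let (table, _metrics) = DeltaOps(table)
    .merge(source, col("target.id").eq(col("source.id")))
    .with_source_alias("source")
    .with_target_alias("target")
    .with_schema_mode(SchemaMode::Merge) // evolve the target schema during MERGE
    .when_matched_update(|update| update.update("inserted_by", col("source.inserted_by")))
    .unwrap()
    .await
    .unwrap();
```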
@@ -705,6 +719,7 @@ async fn execute(
     _safe_cast: bool,
     source_alias: Option<String>,
     target_alias: Option<String>,
+    schema_mode: Option<SchemaMode>,
     match_operations: Vec<MergeOperation>,
     not_match_target_operations: Vec<MergeOperation>,
     not_match_source_operations: Vec<MergeOperation>,
@@ -844,7 +859,7 @@ async fn execute(
     let scan_config = DeltaScanConfigBuilder::default()
         .with_file_column(true)
         .with_parquet_pushdown(false)
-        .with_schema(snapshot.input_schema()?)
+        .with_schema(snapshot.input_schema()?.clone())
         .build(&snapshot)?;
 
     let target_provider = Arc::new(DeltaTableProvider::try_new(
@@ -853,12 +868,15 @@ async fn execute(
         scan_config.clone(),
     )?);
 
+    // One possible way to make this work is to pretend the target dataframe has the merged
+    // schema when merge schema mode is selected; the rest of the process then stays the same.
     let target_provider = provider_as_source(target_provider);
     let target =
         LogicalPlanBuilder::scan(target_name.clone(), target_provider.clone(), None)?.build()?;
 
     let source_schema = source.schema();
     let target_schema = target.schema();
+
     let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?;
 
     let predicate = match predicate {
@@ -938,6 +956,50 @@ async fn execute(
         .map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias))
         .collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
 
+    // merge_arrow_schema is used to tell whether the two schemas can be merged, but we use the
+    // operation statements to pick the new columns. This avoids the side effect of adding
+    // unnecessary columns: e.g. with `target.id = source.ID`, "ID" will not be added, since "id"
+    // exists in the target and the end user intended it to be "id".
+    let mut new_schema = None;
+    let mut actions: Vec<Action> = vec![];
+    if matches!(schema_mode, Some(SchemaMode::Merge)) {
+        let merge_schema = merge_arrow_schema(
+            snapshot.input_schema()?,
+            source_schema.inner().clone(),
+            false,
+        )?;
+
+        let mut schema_builder = SchemaBuilder::from(merge_schema.deref());
+
+        modify_schema(
+            &mut schema_builder,
+            target_schema,
+            source_schema,
+            &match_operations,
+        )?;
+
+        modify_schema(
+            &mut schema_builder,
+            target_schema,
+            source_schema,
+            &not_match_source_operations,
+        )?;
+
+        modify_schema(
+            &mut schema_builder,
+            target_schema,
+            source_schema,
+            &not_match_target_operations,
+        )?;
+        let schema = Arc::new(schema_builder.finish());
+        new_schema = Some(schema.clone());
+        let schema_action = Action::Metadata(Metadata::try_new(
+            schema.try_into()?,
+            current_metadata.partition_columns.clone(),
+            snapshot.metadata().configuration.clone(),
+        )?);
+
+        actions.push(schema_action);
+    }
+
     let matched = col(SOURCE_COLUMN)
         .is_true()
         .and(col(TARGET_COLUMN).is_true());
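Reviewer note: the evolution rule above, shown in isolation — fields referenced by a merge operation but missing from the target are appended as nullable. A standalone sketch with plain Arrow types (the `SchemaBuilder` API is the only assumption; it matches the usage above):

```rust
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaBuilder};

// Target has only (id); a merge operation writes `inserted_by`, which exists
// only on the source, so it is appended to the evolved schema as nullable.
let target = Schema::new(vec![Field::new("id", DataType::Utf8, true)]);
let mut builder = SchemaBuilder::from(&target);
builder.push(Arc::new(Field::new("inserted_by", DataType::Utf8, true)));
let evolved = builder.finish();
assert_eq!(evolved.fields().len(), 2);
```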
@@ -1046,9 +1108,17 @@ async fn execute(
     let projection = join.with_column(OPERATION_COLUMN, case)?;
 
     let mut new_columns = vec![];
+
+    let mut null_target_columns = vec![];
     let mut write_projection = Vec::new();
-    for delta_field in snapshot.schema().fields() {
+
+    let schema = if let Some(schema) = new_schema {
+        schema.try_into()?
+    } else {
+        snapshot.schema().clone()
+    };
+
+    for delta_field in schema.fields() {
         let mut when_expr = Vec::with_capacity(operations_size);
         let mut then_expr = Vec::with_capacity(operations_size);
 
         let qualifier = match &target_alias {
             Some(alias) => Some(TableReference::Bare {
                 table: alias.to_owned().into(),
             }),
             None => TableReference::none(),
         };
+
+        let source_qualifier = match &source_alias {
+            Some(alias) => Some(TableReference::Bare {
+                table: alias.to_owned().into(),
+            }),
+            None => TableReference::none(),
+        };
         let name = delta_field.name();
-        let column = Column::new(qualifier.clone(), name);
+
+        // Check whether the column already exists in the target table. Columns introduced by
+        // schema evolution exist only on the source side, so read them via the source qualifier
+        // and register a typed null column for the target side.
+        let column = if snapshot.schema().index_of(name).is_none() {
+            let null_column = cast(
+                lit(ScalarValue::Null).alias(name),
+                delta_field.data_type().try_into()?,
+            );
+            null_target_columns.push(null_column);
+            Column::new(source_qualifier.clone(), name)
+        } else {
+            Column::new(qualifier.clone(), name)
+        };
 
         for (idx, (operations, _)) in ops.iter().enumerate() {
             let op = operations
@@ -1268,17 +1356,20 @@ async fn execute(
 
     // Extra select_columns is required so that before and after have same schema order
     // DataFusion doesn't have UnionByName yet, see https://github.com/apache/datafusion/issues/12650
+
+    let mut select_columns = target_schema
+        .columns()
+        .iter()
+        .filter(|c| c.name != crate::delta_datafusion::PATH_COLUMN)
+        .map(|c| Expr::Column(c.clone()))
+        .collect_vec();
+    // In case of columns added by schema evolution, append them as null columns under the
+    // target qualifier.
+    select_columns.extend(null_target_columns);
+
     let mut before = cdc_projection
         .clone()
         .filter(col(crate::delta_datafusion::PATH_COLUMN).is_not_null())?
-        .select(
-            target_schema
-                .columns()
-                .iter()
-                .filter(|c| c.name != crate::delta_datafusion::PATH_COLUMN)
-                .map(|c| Expr::Column(c.clone()))
-                .collect_vec(),
-        )?
+        .select(select_columns)?
         .select_columns(
             &after
                 .schema()
                 .columns()
                 .iter()
                 .filter(|c| c.name != crate::delta_datafusion::PATH_COLUMN)
                 .map(|c| c.name.as_str())
                 .collect::<Vec<_>>(),
         )?;
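Reviewer note: the typed-null trick used for evolved columns, shown standalone (a sketch; the `cast`/`lit` helpers from the DataFusion prelude are the only assumptions):

```rust
use arrow_schema::DataType;
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;

// A column that exists only on the source side is represented on the target
// side as a NULL literal cast to the evolved field's type, so both sides of
// the plan agree on the schema.
let null_col: Expr = cast(lit(ScalarValue::Null), DataType::Utf8).alias("inserted_by");
```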
@@ -1308,6 +1399,7 @@ async fn execute(
     )?;
 
     let merge_final = &project.into_unoptimized_plan();
+
     let write = state.create_physical_plan(merge_final).await?;
 
     let err = || DeltaTableError::Generic("Unable to locate expected metric node".into());
@@ -1340,6 +1432,7 @@ async fn execute(
         None,
     )
     .await?;
+    add_actions.extend(actions);
 
     if should_cdc && !change_data.is_empty() {
         let mut df = change_data
@@ -1444,6 +1537,25 @@ async fn execute(
     Ok((commit.snapshot(), metrics))
 }
 
+fn modify_schema(
+    ending_schema: &mut SchemaBuilder,
+    target_schema: &DFSchema,
+    source_schema: &DFSchema,
+    operations: &[MergeOperation],
+) -> DeltaResult<()> {
+    for columns in operations
+        .iter()
+        .filter(|ops| matches!(ops.r#type, OperationType::Update | OperationType::Insert))
+        .flat_map(|ops| ops.operations.keys())
+    {
+        if target_schema.field_from_column(columns).is_err() {
+            let new_fields = source_schema.field_with_unqualified_name(columns.name())?;
+            ending_schema.push(new_fields.to_owned().with_nullable(true));
+        }
+    }
+    Ok(())
+}
+
 fn remove_table_alias(expr: Expr, table_alias: &str) -> Expr {
     expr.transform(&|expr| match expr {
         Expr::Column(c) => match c.relation {
@@ -1498,6 +1610,7 @@ impl std::future::IntoFuture for MergeBuilder {
                 this.safe_cast,
                 this.source_alias,
                 this.target_alias,
+                this.schema_mode,
                 this.match_operations,
                 this.not_match_operations,
                 this.not_match_source_operations,
@@ -1525,6 +1638,7 @@ mod tests {
     use crate::kernel::StructField;
     use crate::operations::load_cdf::collect_batches;
     use crate::operations::merge::filter::generalize_filter;
+    use crate::operations::write::SchemaMode;
     use crate::operations::DeltaOps;
     use crate::protocol::*;
     use crate::writer::test_utils::datafusion::get_data;
@@ -1715,6 +1829,132 @@ mod tests {
         assert_merge(table, metrics).await;
     }
 
+    #[tokio::test]
+    async fn test_merge_schema_evolution_simple_update() {
+        let (table, _) = setup().await;
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", ArrowDataType::Utf8, true),
+            Field::new("value", ArrowDataType::Int32, true),
+            Field::new("modified", ArrowDataType::Utf8, true),
+            Field::new("inserted_by", ArrowDataType::Utf8, true),
+        ]));
+        let ctx = SessionContext::new();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
+                Arc::new(arrow::array::Int32Array::from(vec![50, 200, 30])),
+                Arc::new(arrow::array::StringArray::from(vec![
+                    "2021-02-02",
+                    "2023-07-04",
+                    "2023-07-04",
+                ])),
+                Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
+            ],
+        )
+        .unwrap();
+        let source = ctx.read_batch(batch).unwrap();
+
+        let (table, _) = DeltaOps(table)
+            .merge(source, col("target.id").eq(col("source.id")))
+            .with_source_alias("source")
+            .with_target_alias("target")
+            .with_schema_mode(SchemaMode::Merge)
+            .when_matched_update(|update| {
+                update
+                    .update("value", col("source.value").add(lit(1)))
+                    .update("modified", col("source.modified"))
+                    .update("inserted_by", col("source.inserted_by"))
+            })
+            .unwrap()
+            .await
+            .unwrap();
+
+        let commit_info = table.history(None).await.unwrap();
+        let last_commit = &commit_info[0];
+        let parameters = last_commit.operation_parameters.clone().unwrap();
+        assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
+        let expected = vec![
+            "+----+-------+------------+-------------+",
+            "| id | value | modified   | inserted_by |",
+            "+----+-------+------------+-------------+",
+            "| A  | 1     | 2021-02-01 |             |",
+            "| B  | 51    | 2021-02-02 | B1          |",
+            "| C  | 201   | 2023-07-04 | C1          |",
+            "| D  | 100   | 2021-02-02 |             |",
+            "+----+-------+------------+-------------+",
+        ];
+        let actual = get_data(&table).await;
+        assert_batches_sorted_eq!(&expected, &actual);
+    }
+
+    #[tokio::test]
+    async fn test_merge_schema_evolution_simple_insert() {
+        let (table, _) = setup().await;
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", ArrowDataType::Utf8, true),
+            Field::new("value", ArrowDataType::Int32, true),
+            Field::new("modified", ArrowDataType::Utf8, true),
+            Field::new("inserted_by", ArrowDataType::Utf8, true),
+        ]));
+        let ctx = SessionContext::new();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
+                Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
+                Arc::new(arrow::array::StringArray::from(vec![
+                    "2021-02-02",
+                    "2023-07-04",
+                    "2023-07-04",
+                ])),
+                Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
+            ],
+        )
+        .unwrap();
+        let source = ctx.read_batch(batch).unwrap();
+
+        let (table, _) = DeltaOps(table)
+            .merge(source, col("target.id").eq(col("source.id")))
+            .with_source_alias("source")
+            .with_target_alias("target")
+            .with_schema_mode(SchemaMode::Merge)
+            .when_not_matched_insert(|insert| {
+                insert
+                    .set("id", col("source.id"))
+                    .set("value", col("source.value"))
+                    .set("modified", col("source.modified"))
+                    .set("inserted_by", "source.inserted_by")
+            })
+            .unwrap()
+            .await
+            .unwrap();
+
+        let commit_info = table.history(None).await.unwrap();
+        let last_commit = &commit_info[0];
+        let parameters = last_commit.operation_parameters.clone().unwrap();
+        assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
+        assert_eq!(
+            parameters["notMatchedPredicates"],
+            json!(r#"[{"actionType":"insert"}]"#)
+        );
+        let expected = vec![
+            "+----+-------+------------+-------------+",
+            "| id | value | modified   | inserted_by |",
+            "+----+-------+------------+-------------+",
+            "| A  | 1     | 2021-02-01 |             |",
+            "| B  | 10    | 2021-02-01 |             |",
+            "| C  | 10    | 2021-02-02 |             |",
+            "| D  | 100   | 2021-02-02 |             |",
+            "| X  | 30    | 2023-07-04 | X1          |",
+            "+----+-------+------------+-------------+",
+        ];
+        let actual = get_data(&table).await;
+        assert_batches_sorted_eq!(&expected, &actual);
+    }
+
     #[tokio::test]
     async fn test_merge_str() {
         // Validate that users can use string predicates
@@ -2581,6 +2821,90 @@ mod tests {
         assert_batches_sorted_eq!(&expected, &actual);
     }
 
+    #[tokio::test]
+    async fn test_empty_table_schema_evo_merge() {
+        let schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", ArrowDataType::Utf8, true),
+            Field::new("value", ArrowDataType::Int32, true),
+            Field::new("modified", ArrowDataType::Utf8, true),
+            Field::new("inserted_by", ArrowDataType::Utf8, true),
+        ]));
+        let table = setup_table(Some(vec!["modified"])).await;
+
+        assert_eq!(table.version(), 0);
+        assert_eq!(table.get_files_count(), 0);
+
+        let ctx = SessionContext::new();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
+                Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
+                Arc::new(arrow::array::StringArray::from(vec![
+                    "2021-02-02",
+                    "2023-07-04",
+                    "2023-07-04",
+                ])),
+                Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
+            ],
+        )
+        .unwrap();
+        let source = ctx.read_batch(batch).unwrap();
+
+        let (table, metrics) = DeltaOps(table)
+            .merge(
+                source,
+                col("target.id")
+                    .eq(col("source.id"))
.and(col("target.modified").eq(lit("2021-02-02"))), + ) + .with_schema_mode(crate::operations::write::SchemaMode::Merge) + .with_source_alias("source") + .with_target_alias("target") + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + .set("inserted_by", col("source.inserted_by")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 1); + assert!(table.get_files_count() >= 2); + assert!(metrics.num_target_files_added >= 2); + assert_eq!(metrics.num_target_files_removed, 0); + assert_eq!(metrics.num_target_rows_copied, 0); + assert_eq!(metrics.num_target_rows_updated, 0); + assert_eq!(metrics.num_target_rows_inserted, 3); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_source_rows, 3); + + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[0]; + let parameters = last_commit.operation_parameters.clone().unwrap(); + + assert_eq!( + parameters["predicate"], + json!("id BETWEEN 'B' AND 'X' AND modified = '2021-02-02'") + ); + + let expected = vec![ + "+----+-------+-------------+------------+", + "| id | value | inserted_by | modified |", + "+----+-------+-------------+------------+", + "| B | 10 | B1 | 2021-02-02 |", + "| C | 20 | C1 | 2023-07-04 |", + "| X | 30 | X1 | 2023-07-04 |", + "+----+-------+-------------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_merge_case_sensitive() { let schema = vec![ @@ -3332,6 +3656,115 @@ mod tests { ], &batches } } + #[tokio::test] + async fn test_merge_cdc_enabled_simple_with_schema_merge() { + // Manually creating the desired table with the right minimum CDC features + use crate::kernel::Protocol; + use crate::operations::merge::Action; + + let schema = get_delta_schema(); + + let actions = vec![Action::Protocol(Protocol::new(1, 4))]; + let table: DeltaTable = DeltaOps::new_in_memory() + .create() + .with_columns(schema.fields().cloned()) + .with_actions(actions) + .with_configuration_property(TableProperty::EnableChangeDataFeed, Some("true")) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let schema = get_arrow_schema(&None); + let table = write_data(table, &schema).await; + + assert_eq!(table.version(), 1); + assert_eq!(table.get_files_count(), 1); + let source = merge_source(schema); + let source = source.with_column("inserted_by", lit("new_value")).unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge(source, col("target.id").eq(col("source.id"))) + .with_source_alias("source") + .with_target_alias("target") + .with_schema_mode(SchemaMode::Merge) + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_by_source_update(|update| { + update + .predicate(col("target.value").eq(lit(1))) + .update("value", col("target.value") + lit(1)) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + .set("inserted_by", col("source.inserted_by")) + }) + .unwrap() + .await + .unwrap(); + + let expected = vec![ + "+----+-------+------------+-------------+", + "| id | value | modified | inserted_by |", + "+----+-------+------------+-------------+", + "| A | 2 | 2021-02-01 | |", + "| B | 10 | 
+            "| C  | 20    | 2023-07-04 | new_value   |",
+            "| D  | 100   | 2021-02-02 |             |",
+            "| X  | 30    | 2023-07-04 | new_value   |",
+            "+----+-------+------------+-------------+",
+        ];
+        let actual = get_data(&table).await;
+        assert_batches_sorted_eq!(&expected, &actual);
+
+        let ctx = SessionContext::new();
+        let table = DeltaOps(table)
+            .load_cdf()
+            .with_session_ctx(ctx.clone())
+            .with_starting_version(0)
+            .build()
+            .await
+            .expect("Failed to load CDF");
+
+        let mut batches = collect_batches(
+            table.properties().output_partitioning().partition_count(),
+            table,
+            ctx,
+        )
+        .await
+        .expect("Failed to collect batches");
+
+        let _ = arrow::util::pretty::print_batches(&batches);
+
+        // The batches will contain a current _commit_timestamp which shouldn't be checked,
+        // so drop that column before comparing.
+        let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(6)).collect();
+
+        assert_batches_sorted_eq! {[
+            "+----+-------+------------+-------------+------------------+-----------------+",
+            "| id | value | modified   | inserted_by | _change_type     | _commit_version |",
+            "+----+-------+------------+-------------+------------------+-----------------+",
+            "| A  | 1     | 2021-02-01 |             | insert           | 1               |",
+            "| A  | 1     | 2021-02-01 |             | update_preimage  | 2               |",
+            "| A  | 2     | 2021-02-01 |             | update_postimage | 2               |",
+            "| B  | 10    | 2021-02-01 |             | insert           | 1               |",
+            "| B  | 10    | 2021-02-01 |             | update_preimage  | 2               |",
+            "| B  | 10    | 2021-02-02 | new_value   | update_postimage | 2               |",
+            "| C  | 10    | 2021-02-02 |             | insert           | 1               |",
+            "| C  | 10    | 2021-02-02 |             | update_preimage  | 2               |",
+            "| C  | 20    | 2023-07-04 | new_value   | update_postimage | 2               |",
+            "| D  | 100   | 2021-02-02 |             | insert           | 1               |",
+            "| X  | 30    | 2023-07-04 | new_value   | insert           | 2               |",
+            "+----+-------+------------+-------------+------------------+-----------------+",
+        ], &batches }
+    }
+
     #[tokio::test]
     async fn test_merge_cdc_enabled_delete() {
         // Manually creating the desired table with the right minimum CDC features
diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index f19c685118..1186ebb2b5 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -192,6 +192,7 @@ class RawDeltaTable:
         predicate: str,
         source_alias: Optional[str],
         target_alias: Optional[str],
+        schema_mode: Optional[str],
         writer_properties: Optional[WriterProperties],
         commit_properties: Optional[CommitProperties],
         post_commithook_properties: Optional[PostCommitHookProperties],
@@ -285,6 +286,7 @@ def get_num_idx_cols_and_stats_columns(
 class PyMergeBuilder:
     source_alias: str
     target_alias: str
+    schema_mode: Optional[str]
     arrow_schema: pyarrow.Schema
 
     def when_matched_update(
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index f8357c3700..bcef2b8ef2 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -972,6 +972,7 @@ def merge(
         predicate: str,
         source_alias: Optional[str] = None,
         target_alias: Optional[str] = None,
+        schema_mode: Optional[str] = None,
         error_on_type_mismatch: bool = True,
         writer_properties: Optional[WriterProperties] = None,
         large_dtypes: Optional[bool] = None,
@@ -1055,6 +1056,7 @@ def merge(
             predicate=predicate,
             source_alias=source_alias,
             target_alias=target_alias,
+            schema_mode=schema_mode,
             safe_cast=not error_on_type_mismatch,
             writer_properties=writer_properties,
             commit_properties=commit_properties,
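Reviewer note: a short usage sketch of the new Python parameter (the table path and column names are hypothetical):

```python
import pyarrow as pa
from deltalake import DeltaTable

dt = DeltaTable("path/to/table")  # hypothetical existing table
source = pa.table({"id": pa.array(["1"]), "customer": pa.array(["john"])})

# schema_mode="merge" lets this MERGE add the new `customer` column to the
# target schema instead of failing on the schema mismatch.
(
    dt.merge(
        source,
        predicate="t.id = s.id",
        source_alias="s",
        target_alias="t",
        schema_mode="merge",
    )
    .when_matched_update({"customer": "s.customer"})
    .execute()
)
```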
diff --git a/python/src/lib.rs b/python/src/lib.rs
index b91874616d..015b5bd176 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -894,6 +894,7 @@ impl RawDeltaTable {
         predicate,
         source_alias = None,
         target_alias = None,
+        schema_mode = None,
         safe_cast = false,
         writer_properties = None,
         post_commithook_properties = None,
@@ -906,6 +907,7 @@ impl RawDeltaTable {
         predicate: String,
         source_alias: Option<String>,
         target_alias: Option<String>,
+        schema_mode: Option<String>,
         safe_cast: bool,
         writer_properties: Option<PyWriterProperties>,
         post_commithook_properties: Option<PyPostCommitHookProperties>,
@@ -926,6 +928,7 @@ impl RawDeltaTable {
             predicate,
             source_alias,
             target_alias,
+            schema_mode,
             safe_cast,
             writer_properties,
             post_commithook_properties,
diff --git a/python/src/merge.rs b/python/src/merge.rs
index a2ff75a6d1..914abbecd1 100644
--- a/python/src/merge.rs
+++ b/python/src/merge.rs
@@ -7,12 +7,14 @@
 use deltalake::datafusion::datasource::MemTable;
 use deltalake::datafusion::prelude::SessionContext;
 use deltalake::logstore::LogStoreRef;
 use deltalake::operations::merge::MergeBuilder;
+use deltalake::operations::write::SchemaMode;
 use deltalake::operations::CustomExecuteHandler;
 use deltalake::table::state::DeltaTableState;
 use deltalake::{DeltaResult, DeltaTable};
 use pyo3::prelude::*;
 use std::collections::HashMap;
 use std::future::IntoFuture;
+use std::str::FromStr;
 use std::sync::Arc;
 
 use crate::error::PythonError;
@@ -29,6 +31,8 @@ pub(crate) struct PyMergeBuilder {
     source_alias: Option<String>,
     #[pyo3(get)]
     target_alias: Option<String>,
+    #[pyo3(get)]
+    schema_mode: Option<String>,
     arrow_schema: Arc<ArrowSchema>,
 }
 
@@ -41,6 +45,7 @@ impl PyMergeBuilder {
         predicate: String,
         source_alias: Option<String>,
         target_alias: Option<String>,
+        schema_mode: Option<String>,
         safe_cast: bool,
         writer_properties: Option<PyWriterProperties>,
         post_commithook_properties: Option<PyPostCommitHookProperties>,
@@ -65,6 +70,10 @@ impl PyMergeBuilder {
             cmd = cmd.with_target_alias(trgt_alias);
         }
 
+        if let Some(sch_mode) = &schema_mode {
+            cmd = cmd.with_schema_mode(SchemaMode::from_str(sch_mode)?);
+        }
+
         if let Some(writer_props) = writer_properties {
             cmd = cmd.with_writer_properties(set_writer_properties(writer_props)?);
         }
@@ -83,6 +92,7 @@ impl PyMergeBuilder {
             _builder: Some(cmd),
             source_alias,
             target_alias,
+            schema_mode,
             arrow_schema: schema,
         })
     }
diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py
index 69eb73ebc6..138d9ef8e8 100644
--- a/python/tests/test_merge.py
+++ b/python/tests/test_merge.py
@@ -129,6 +129,48 @@ def test_merge_when_matched_update_wo_predicate(
     assert result == expected
 
 
+def test_merge_when_matched_update_wo_predicate_with_schema_evolution(
+    tmp_path: pathlib.Path, sample_table: pa.Table
+):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    source_table = pa.table(
+        {
+            "id": pa.array(["4", "5"]),
+            "price": pa.array([10, 100], pa.int64()),
+            "sold": pa.array([10, 20], pa.int32()),
+            "customer": pa.array(["john", "doe"]),
+        }
+    )
+
+    dt.merge(
+        source=source_table,
+        predicate="t.id = s.id",
+        source_alias="s",
+        target_alias="t",
+        schema_mode="merge",
+    ).when_matched_update(
+        {"price": "s.price", "sold": "s.sold+int'10'", "customer": "s.customer"}
+    ).execute()
+
+    expected = pa.table(
+        {
+            "id": pa.array(["1", "2", "3", "4", "5"]),
+            "price": pa.array([0, 1, 2, 10, 100], pa.int64()),
+            "sold": pa.array([0, 1, 2, 20, 30], pa.int32()),
+            "deleted": pa.array([False] * 5),
+            "customer": pa.array([None, None, None, "john", "doe"]),
+        }
+    )
+    result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "MERGE"
+    assert result == expected
+
+
 def test_merge_when_matched_update_all_wo_predicate(
     tmp_path: pathlib.Path, sample_table: pa.Table
 ):
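Reviewer note: a quick way to sanity-check the evolution from Python after a merge like the test above (sketch; assumes the `dt` table and the `customer` column from that test):

```python
# The evolved column should now be part of the Delta table schema.
schema = dt.schema()
assert "customer" in [field.name for field in schema.fields]
```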
@@ -339,6 +381,56 @@ def test_merge_when_not_matched_insert_with_predicate(
     assert result == expected
 
 
+def test_merge_when_not_matched_insert_with_predicate_schema_evolution(
+    tmp_path: pathlib.Path, sample_table: pa.Table
+):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    source_table = pa.table(
+        {
+            "id": pa.array(["6", "10"]),
+            "price": pa.array([10, 100], pa.int64()),
+            "sold": pa.array([10, 20], pa.int32()),
+            "customer": pa.array(["john", "doe"]),
+            "deleted": pa.array([False, False]),
+        }
+    )
+
+    dt.merge(
+        source=source_table,
+        source_alias="source",
+        target_alias="target",
+        schema_mode="merge",
+        predicate="target.id = source.id",
+    ).when_not_matched_insert(
+        updates={
+            "id": "source.id",
+            "price": "source.price",
+            "sold": "source.sold",
+            "customer": "source.customer",
+            "deleted": "False",
+        },
+        predicate="source.price < 50",
+    ).execute()
+
+    expected = pa.table(
+        {
+            "id": pa.array(["1", "2", "3", "4", "5", "6"]),
+            "price": pa.array([0, 1, 2, 3, 4, 10], pa.int64()),
+            "sold": pa.array([0, 1, 2, 3, 4, 10], pa.int32()),
+            "deleted": pa.array([False] * 6),
+            "customer": pa.array([None, None, None, None, None, "john"]),
+        }
+    )
+    result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "MERGE"
+    assert result == expected
+
+
 def test_merge_when_not_matched_insert_all_with_predicate(
     tmp_path: pathlib.Path, sample_table: pa.Table
 ):
@@ -417,6 +509,47 @@ def test_merge_when_not_matched_insert_all_with_exclude(
     assert result == expected
 
 
+def test_merge_when_not_matched_insert_all_with_exclude_and_with_schema_evo(
+    tmp_path: pathlib.Path, sample_table: pa.Table
+):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    source_table = pa.table(
+        {
+            "id": pa.array(["6", "9"]),
+            "price": pa.array([10, 100], pa.int64()),
+            "sold": pa.array([10, 20], pa.int32()),
+            "deleted": pa.array([None, None], pa.bool_()),
+            "customer": pa.array(["john", "doe"]),
+        }
+    )
+
+    dt.merge(
+        source=source_table,
+        source_alias="source",
+        target_alias="target",
+        schema_mode="merge",
+        predicate="target.id = source.id",
+    ).when_not_matched_insert_all(except_cols=["sold"]).execute()
+
+    expected = pa.table(
+        {
+            "id": pa.array(["1", "2", "3", "4", "5", "6", "9"]),
+            "price": pa.array([0, 1, 2, 3, 4, 10, 100], pa.int64()),
+            "sold": pa.array([0, 1, 2, 3, 4, None, None], pa.int32()),
+            "deleted": pa.array([False, False, False, False, False, None, None]),
+            "customer": pa.array([None, None, None, None, None, "john", "doe"]),
+        }
+    )
+    result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "MERGE"
+    assert result == expected
+
+
 def test_merge_when_not_matched_insert_all_with_predicate_special_column_names(
     tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table
 ):