diff --git a/Cargo.lock b/Cargo.lock index d594d73..7150e6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2231,6 +2231,7 @@ name = "pg_parquet" version = "0.1.0" dependencies = [ "arrow", + "arrow-cast", "arrow-schema", "aws-config", "aws-credential-types", diff --git a/Cargo.toml b/Cargo.toml index b5a372b..76a9e1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ pg_test = [] [dependencies] arrow = {version = "53", default-features = false} +arrow-cast = {version = "53", default-features = false} arrow-schema = {version = "53", default-features = false} aws-config = { version = "1.5", default-features = false, features = ["rustls"]} aws-credential-types = {version = "1.2", default-features = false} diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs index ec7c9ce..4596a90 100644 --- a/src/arrow_parquet/arrow_to_pg.rs +++ b/src/arrow_parquet/arrow_to_pg.rs @@ -1,31 +1,28 @@ +use std::ops::Deref; + use arrow::array::{ Array, ArrayData, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, ListArray, MapArray, StringArray, StructArray, Time64MicrosecondArray, TimestampMicrosecondArray, UInt32Array, }; -use arrow_schema::Fields; +use arrow_schema::{DataType, FieldRef, Fields, TimeUnit}; use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone}, - pg_sys::{ - Datum, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, - INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, - }, + pg_sys::{Datum, Oid, CHAROID, NUMERICOID, TEXTOID, TIMEOID}, prelude::PgHeapTuple, AllocatedByRust, AnyNumeric, IntoDatum, PgTupleDesc, }; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CopyOperation, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, geometry::{is_postgis_geometry_type, Geometry}, map::{is_map_type, Map}, - pg_arrow_type_conversions::{ - extract_precision_and_scale_from_numeric_typmod, should_write_numeric_as_text, - }, + pg_arrow_type_conversions::extract_precision_and_scale_from_numeric_typmod, }, }; @@ -57,12 +54,11 @@ pub(crate) trait ArrowArrayToPgType: From { #[derive(Clone)] pub(crate) struct ArrowToPgAttributeContext { name: String, + data_type: DataType, + needs_cast: bool, typoid: Oid, typmod: i32, - is_array: bool, - is_composite: bool, is_geometry: bool, - is_map: bool, attribute_contexts: Option>, attribute_tupledesc: Option>, precision: Option, @@ -70,12 +66,20 @@ pub(crate) struct ArrowToPgAttributeContext { } impl ArrowToPgAttributeContext { - pub(crate) fn new(name: &str, typoid: Oid, typmod: i32, fields: Fields) -> Self { - let field = fields - .iter() - .find(|field| field.name() == name) - .unwrap_or_else(|| panic!("failed to find field {}", name)) - .clone(); + pub(crate) fn new( + name: &str, + typoid: Oid, + typmod: i32, + field: FieldRef, + cast_to_type: Option, + ) -> Self { + let needs_cast = cast_to_type.is_some(); + + let data_type = if let Some(cast_to_type) = &cast_to_type { + cast_to_type.clone() + } else { + field.data_type().clone() + }; let is_array = is_array_type(typoid); let is_composite; @@ -147,9 +151,11 @@ impl ArrowToPgAttributeContext { _ => unreachable!(), }; + // we only cast the top-level attributes, which already covers the nested attributes Some(collect_arrow_to_pg_attribute_contexts( attribute_tupledesc, &fields, + None, )) } else { None @@ -157,12 +163,11 @@ impl ArrowToPgAttributeContext { Self { name: name.to_string(), + data_type, + needs_cast, typoid: attribute_typoid, typmod, - is_array, - is_composite, is_geometry, - is_map, attribute_contexts, attribute_tupledesc, scale, @@ -173,27 +178,49 @@ impl ArrowToPgAttributeContext { pub(crate) fn name(&self) -> &str { &self.name } + + pub(crate) fn needs_cast(&self) -> bool { + self.needs_cast + } + + pub(crate) fn data_type(&self) -> &DataType { + &self.data_type + } } pub(crate) fn collect_arrow_to_pg_attribute_contexts( tupledesc: &PgTupleDesc, fields: &Fields, + cast_to_types: Option>>, ) -> Vec { - // parquet file does not contain generated columns. PG will handle them. - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); + let attributes = collect_attributes_for(CopyOperation::CopyFrom, tupledesc); + let mut attribute_contexts = vec![]; - for attribute in attributes { + for (idx, attribute) in attributes.iter().enumerate() { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); let attribute_typmod = attribute.type_mod(); + let field = fields + .iter() + .find(|field| field.name() == attribute_name) + .unwrap_or_else(|| panic!("failed to find field {}", attribute_name)) + .clone(); + + let cast_to_type = if let Some(cast_to_types) = cast_to_types.as_ref() { + debug_assert!(cast_to_types.len() == attributes.len()); + cast_to_types.get(idx).cloned().expect("cast_to_type null") + } else { + None + }; + let attribute_context = ArrowToPgAttributeContext::new( attribute_name, attribute_typoid, attribute_typmod, - fields.clone(), + field, + cast_to_type, ); attribute_contexts.push(attribute_context); @@ -206,7 +233,7 @@ pub(crate) fn to_pg_datum( attribute_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - if attribute_context.is_array { + if matches!(attribute_array.data_type(), DataType::List(_)) { to_pg_array_datum(attribute_array, attribute_context) } else { to_pg_nonarray_datum(attribute_array, attribute_context) @@ -227,43 +254,34 @@ fn to_pg_nonarray_datum( primitive_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - match attribute_context.typoid { - FLOAT4OID => { + match attribute_context.data_type() { + DataType::Float32 => { to_pg_datum!(Float32Array, f32, primitive_array, attribute_context) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!(Float64Array, f64, primitive_array, attribute_context) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, i16, primitive_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, i32, primitive_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, i64, primitive_array, attribute_context) } - BOOLOID => { - to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) - } - CHAROID => { - to_pg_datum!(StringArray, i8, primitive_array, attribute_context) - } - TEXTOID => { - to_pg_datum!(StringArray, String, primitive_array, attribute_context) - } - BYTEAOID => { - to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) - } - OIDOID => { + DataType::UInt32 => { to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Boolean => { + to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) + } + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(StringArray, String, primitive_array, attribute_context) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -272,72 +290,74 @@ fn to_pg_nonarray_datum( primitive_array, attribute_context ) - } else { - to_pg_datum!( - Decimal128Array, - AnyNumeric, - primitive_array, - attribute_context - ) } } - DATEOID => { - to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) + } else { + to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) + } } - TIMEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Time64MicrosecondArray, - Time, + Decimal128Array, + AnyNumeric, primitive_array, attribute_context ) } - TIMETZOID => { + DataType::Date32 => { + to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + } + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Time, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + TimeWithTimeZone, + primitive_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - TimeWithTimeZone, + TimestampMicrosecondArray, + Timestamp, primitive_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Timestamp, + TimestampWithTimeZone, primitive_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - TimestampWithTimeZone, + StructArray, + PgHeapTuple, primitive_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Map, primitive_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - PgHeapTuple, - primitive_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Map, primitive_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - FallbackToText, - primitive_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } @@ -354,8 +374,13 @@ fn to_pg_array_datum( let list_array = list_array.value(0).to_data(); - match attribute_context.typoid { - FLOAT4OID => { + let element_field = match attribute_context.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + match element_field.data_type() { + DataType::Float32 => { to_pg_datum!( Float32Array, Vec>, @@ -363,7 +388,7 @@ fn to_pg_array_datum( attribute_context ) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!( Float64Array, Vec>, @@ -371,16 +396,19 @@ fn to_pg_array_datum( attribute_context ) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) } - BOOLOID => { + DataType::UInt32 => { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } + DataType::Boolean => { to_pg_datum!( BooleanArray, Vec>, @@ -388,34 +416,17 @@ fn to_pg_array_datum( attribute_context ) } - CHAROID => { - to_pg_datum!(StringArray, Vec>, list_array, attribute_context) - } - TEXTOID => { - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } - BYTEAOID => { - to_pg_datum!( - BinaryArray, - Vec>>, - list_array, - attribute_context - ) - } - OIDOID => { - to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) - } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + StringArray, + Vec>, + list_array, + attribute_context + ) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -424,82 +435,89 @@ fn to_pg_array_datum( list_array, attribute_context ) + } + } + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!( + BinaryArray, + Vec>, + list_array, + attribute_context + ) } else { to_pg_datum!( - Decimal128Array, - Vec>, + BinaryArray, + Vec>>, list_array, attribute_context ) } } - DATEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Date32Array, - Vec>, + Decimal128Array, + Vec>, list_array, attribute_context ) } - TIMEOID => { + DataType::Date32 => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + Date32Array, + Vec>, list_array, attribute_context ) } - TIMETZOID => { + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + TimestampMicrosecondArray, + Vec>, list_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Vec>, + Vec>, list_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - Vec>, + StructArray, + Vec>>, list_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Vec>, list_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - Vec>>, - list_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Vec>, list_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!( - BinaryArray, - Vec>, - list_array, - attribute_context - ) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index a3cd53b..808d506 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use arrow::array::RecordBatch; +use arrow_cast::{cast_with_options, CastOptions}; use futures::StreamExt; use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream}; use pgrx::{ @@ -12,13 +15,15 @@ use url::Url; use crate::{ arrow_parquet::arrow_to_pg::to_pg_datum, - pgrx_utils::collect_valid_attributes, + pgrx_utils::{collect_attributes_for, CopyOperation}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; use super::{ arrow_to_pg::{collect_arrow_to_pg_attribute_contexts, ArrowToPgAttributeContext}, - schema_parser::ensure_arrow_schema_match_tupledesc, + schema_parser::{ + ensure_arrow_schema_match_tupledesc_schema, parse_arrow_schema_from_tupledesc, + }, uri_utils::{parquet_reader_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }; @@ -41,12 +46,28 @@ impl ParquetReaderContext { let parquet_reader = parquet_reader_from_uri(&uri); - let schema = parquet_reader.schema(); - ensure_arrow_schema_match_tupledesc(schema.clone(), tupledesc); + let parquet_file_schema = parquet_reader.schema(); - let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); + let tupledesc_schema = + parse_arrow_schema_from_tupledesc(tupledesc, CopyOperation::CopyFrom); + + let tupledesc_schema = Arc::new(tupledesc_schema); - let attribute_contexts = collect_arrow_to_pg_attribute_contexts(tupledesc, &schema.fields); + // Ensure that the arrow schema matches the tupledesc. + // Gets cast_to_types for each attribute if a cast is needed for the attribute's columnar array + // to match the expected columnar array for its tupledesc type. + let cast_to_types = ensure_arrow_schema_match_tupledesc_schema( + parquet_file_schema.clone(), + tupledesc_schema.clone(), + ); + + let attribute_contexts = collect_arrow_to_pg_attribute_contexts( + tupledesc, + &tupledesc_schema.fields, + Some(cast_to_types), + ); + + let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); ParquetReaderContext { buffer: Vec::new(), @@ -65,8 +86,7 @@ impl ParquetReaderContext { unsafe { let mut binary_out_funcs = vec![]; - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); + let attributes = collect_attributes_for(CopyOperation::CopyFrom, tupledesc); for att in attributes.iter() { let typoid = att.type_oid(); @@ -94,11 +114,25 @@ impl ParquetReaderContext { for attribute_context in attribute_contexts { let name = attribute_context.name(); - let column = record_batch + let column_array = record_batch .column_by_name(name) .unwrap_or_else(|| panic!("column {} not found", name)); - let datum = to_pg_datum(column.to_data(), attribute_context); + let datum = if attribute_context.needs_cast() { + // should fail instead of returning None if the cast fails at runtime + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + let casted_column_array = + cast_with_options(&column_array, attribute_context.data_type(), &cast_options) + .unwrap_or_else(|e| panic!("failed to cast column {}: {}", name, e)); + + to_pg_datum(casted_column_array.to_data(), attribute_context) + } else { + to_pg_datum(column_array.to_data(), attribute_context) + }; datums.push(datum); } diff --git a/src/arrow_parquet/parquet_writer.rs b/src/arrow_parquet/parquet_writer.rs index 7c12009..af35b75 100644 --- a/src/arrow_parquet/parquet_writer.rs +++ b/src/arrow_parquet/parquet_writer.rs @@ -15,6 +15,7 @@ use crate::{ schema_parser::parse_arrow_schema_from_tupledesc, uri_utils::{parquet_writer_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }, + pgrx_utils::CopyOperation, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; @@ -57,7 +58,7 @@ impl ParquetWriterContext { .set_created_by("pg_parquet".to_string()) .build(); - let schema = parse_arrow_schema_from_tupledesc(tupledesc); + let schema = parse_arrow_schema_from_tupledesc(tupledesc, CopyOperation::CopyTo); let schema = Arc::new(schema); let parquet_writer = parquet_writer_from_uri(&uri, schema.clone(), writer_props); diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs index 40cc03c..152bf05 100644 --- a/src/arrow_parquet/pg_to_arrow.rs +++ b/src/arrow_parquet/pg_to_arrow.rs @@ -15,8 +15,8 @@ use pgrx::{ use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CopyOperation, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, @@ -169,8 +169,8 @@ pub(crate) fn collect_pg_to_arrow_attribute_contexts( tupledesc: &PgTupleDesc, fields: &Fields, ) -> Vec { - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); + let attributes = collect_attributes_for(CopyOperation::CopyTo, tupledesc); + let mut attribute_contexts = vec![]; for attribute in attributes { diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index 8dd79cf..3380060 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use arrow::datatypes::{Field, Fields, Schema}; -use arrow_schema::FieldRef; +use arrow_cast::can_cast_types; +use arrow_schema::{DataType, FieldRef}; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, @@ -11,8 +12,8 @@ use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc}; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CopyOperation, }, type_compat::{ geometry::is_postgis_geometry_type, @@ -23,8 +24,11 @@ use crate::{ }, }; -pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> String { - let arrow_schema = parse_arrow_schema_from_tupledesc(tupledesc); +pub(crate) fn parquet_schema_string_from_tupledesc( + tupledesc: &PgTupleDesc, + copy_operation: CopyOperation, +) -> String { + let arrow_schema = parse_arrow_schema_from_tupledesc(tupledesc, copy_operation); let parquet_schema = arrow_to_parquet_schema(&arrow_schema) .unwrap_or_else(|e| panic!("failed to convert arrow schema to parquet schema: {}", e)); @@ -33,13 +37,15 @@ pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> S String::from_utf8(buf).unwrap_or_else(|e| panic!("failed to convert schema to string: {}", e)) } -pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Schema { +pub(crate) fn parse_arrow_schema_from_tupledesc( + tupledesc: &PgTupleDesc, + copy_operation: CopyOperation, +) -> Schema { let mut field_id = 0; let mut struct_attribute_fields = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); + let attributes = collect_attributes_for(copy_operation, tupledesc); for attribute in attributes { let attribute_name = attribute.name(); @@ -48,13 +54,19 @@ pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Sche let field = if is_composite_type(attribute_typoid) { let attribute_tupledesc = tuple_desc(attribute_typoid, attribute_typmod); - parse_struct_schema(attribute_tupledesc, attribute_name, &mut field_id) + parse_struct_schema( + attribute_tupledesc, + attribute_name, + copy_operation, + &mut field_id, + ) } else if is_map_type(attribute_typoid) { let attribute_base_elem_typoid = domain_array_base_elem_typoid(attribute_typoid); parse_map_schema( attribute_base_elem_typoid, attribute_typmod, attribute_name, + copy_operation, &mut field_id, ) } else if is_array_type(attribute_typoid) { @@ -63,6 +75,7 @@ pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Sche attribute_element_typoid, attribute_typmod, attribute_name, + copy_operation, &mut field_id, ) } else { @@ -70,6 +83,7 @@ pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Sche attribute_typoid, attribute_typmod, attribute_name, + copy_operation, &mut field_id, ) }; @@ -80,7 +94,12 @@ pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Sche Schema::new(Fields::from(struct_attribute_fields)) } -fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i32) -> Arc { +fn parse_struct_schema( + tupledesc: PgTupleDesc, + elem_name: &str, + copy_operation: CopyOperation, + field_id: &mut i32, +) -> Arc { check_for_interrupts!(); let metadata = HashMap::from_iter(vec![( @@ -92,8 +111,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i let mut child_fields: Vec> = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(&tupledesc, include_generated_columns); + let attributes = collect_attributes_for(copy_operation, &tupledesc); for attribute in attributes { if attribute.is_dropped() { @@ -106,13 +124,19 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i let child_field = if is_composite_type(attribute_oid) { let attribute_tupledesc = tuple_desc(attribute_oid, attribute_typmod); - parse_struct_schema(attribute_tupledesc, attribute_name, field_id) + parse_struct_schema( + attribute_tupledesc, + attribute_name, + copy_operation, + field_id, + ) } else if is_map_type(attribute_oid) { let attribute_base_elem_typoid = domain_array_base_elem_typoid(attribute_oid); parse_map_schema( attribute_base_elem_typoid, attribute_typmod, attribute_name, + copy_operation, field_id, ) } else if is_array_type(attribute_oid) { @@ -121,25 +145,40 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i attribute_element_typoid, attribute_typmod, attribute_name, + copy_operation, field_id, ) } else { - parse_primitive_schema(attribute_oid, attribute_typmod, attribute_name, field_id) + parse_primitive_schema( + attribute_oid, + attribute_typmod, + attribute_name, + copy_operation, + field_id, + ) }; child_fields.push(child_field); } + let nullable = true; + Field::new( elem_name, arrow::datatypes::DataType::Struct(Fields::from(child_fields)), - true, + nullable, ) .with_metadata(metadata) .into() } -fn parse_list_schema(typoid: Oid, typmod: i32, array_name: &str, field_id: &mut i32) -> Arc { +fn parse_list_schema( + typoid: Oid, + typmod: i32, + array_name: &str, + copy_operation: CopyOperation, + field_id: &mut i32, +) -> Arc { check_for_interrupts!(); let list_metadata = HashMap::from_iter(vec![( @@ -151,24 +190,38 @@ fn parse_list_schema(typoid: Oid, typmod: i32, array_name: &str, field_id: &mut let elem_field = if is_composite_type(typoid) { let tupledesc = tuple_desc(typoid, typmod); - parse_struct_schema(tupledesc, array_name, field_id) + parse_struct_schema(tupledesc, array_name, copy_operation, field_id) } else if is_map_type(typoid) { let base_elem_typoid = domain_array_base_elem_typoid(typoid); - parse_map_schema(base_elem_typoid, typmod, array_name, field_id) + parse_map_schema( + base_elem_typoid, + typmod, + array_name, + copy_operation, + field_id, + ) } else { - parse_primitive_schema(typoid, typmod, array_name, field_id) + parse_primitive_schema(typoid, typmod, array_name, copy_operation, field_id) }; + let nullable = true; + Field::new( array_name, arrow::datatypes::DataType::List(elem_field), - true, + nullable, ) .with_metadata(list_metadata) .into() } -fn parse_map_schema(typoid: Oid, typmod: i32, map_name: &str, field_id: &mut i32) -> Arc { +fn parse_map_schema( + typoid: Oid, + typmod: i32, + map_name: &str, + copy_operation: CopyOperation, + field_id: &mut i32, +) -> Arc { let map_metadata = HashMap::from_iter(vec![( PARQUET_FIELD_ID_META_KEY.into(), field_id.to_string(), @@ -177,13 +230,18 @@ fn parse_map_schema(typoid: Oid, typmod: i32, map_name: &str, field_id: &mut i32 *field_id += 1; let tupledesc = tuple_desc(typoid, typmod); - let entries_field = parse_struct_schema(tupledesc, map_name, field_id); + + let entries_field = parse_struct_schema(tupledesc, map_name, copy_operation, field_id); let entries_field = adjust_map_entries_field(entries_field); + let keys_sorted = false; + + let nullable = true; + Field::new( map_name, - arrow::datatypes::DataType::Map(entries_field, false), - true, + arrow::datatypes::DataType::Map(entries_field, keys_sorted), + nullable, ) .with_metadata(map_metadata) .into() @@ -193,6 +251,7 @@ fn parse_primitive_schema( typoid: Oid, typmod: i32, elem_name: &str, + _copy_operation: CopyOperation, field_id: &mut i32, ) -> Arc { check_for_interrupts!(); @@ -204,31 +263,33 @@ fn parse_primitive_schema( *field_id += 1; + let nullable = true; + let field = match typoid { - FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, true), - FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, true), - BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, true), - INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, true), - INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, true), - INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, true), + FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, nullable), + FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, nullable), + BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, nullable), + INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, nullable), + INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, nullable), + INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, nullable), NUMERICOID => { let (precision, scale) = extract_precision_and_scale_from_numeric_typmod(typmod); if should_write_numeric_as_text(precision) { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } else { Field::new( elem_name, arrow::datatypes::DataType::Decimal128(precision as _, scale as _), - true, + nullable, ) } } - DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, true), + DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, nullable), TIMESTAMPOID => Field::new( elem_name, arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - true, + nullable, ), TIMESTAMPTZOID => Field::new( elem_name, @@ -236,31 +297,31 @@ fn parse_primitive_schema( arrow::datatypes::TimeUnit::Microsecond, Some("+00:00".into()), ), - true, + nullable, ), TIMEOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ), TIMETZOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ) .with_metadata(HashMap::from_iter(vec![( "adjusted_to_utc".into(), "true".into(), )])), - CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, true), - OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, true), + CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable), + OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, nullable), _ => { if is_postgis_geometry_type(typoid) { - Field::new(elem_name, arrow::datatypes::DataType::Binary, true) + Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable) } else { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } } }; @@ -289,60 +350,94 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { let key_field = fields.find("key").expect("expected key field").1; let value_field = fields.find("val").expect("expected val field").1; - not_nullable_key_field = - Field::new(key_field.name(), key_field.data_type().clone(), false) - .with_metadata(key_field.metadata().clone()); + let key_nullable = false; + + not_nullable_key_field = Field::new( + key_field.name(), + key_field.data_type().clone(), + key_nullable, + ) + .with_metadata(key_field.metadata().clone()); + + let value_nullable = true; - nullable_value_field = - Field::new(value_field.name(), value_field.data_type().clone(), true) - .with_metadata(value_field.metadata().clone()); + nullable_value_field = Field::new( + value_field.name(), + value_field.data_type().clone(), + value_nullable, + ) + .with_metadata(value_field.metadata().clone()); } _ => { panic!("expected struct data type for map entries") } }; + let entries_nullable = false; + let entries_field = Field::new( name, arrow::datatypes::DataType::Struct(Fields::from(vec![ not_nullable_key_field, nullable_value_field, ])), - false, + entries_nullable, ) .with_metadata(metadata); Arc::new(entries_field) } -pub(crate) fn ensure_arrow_schema_match_tupledesc( - file_schema: Arc, - tupledesc: &PgTupleDesc, -) { - let table_schema = parse_arrow_schema_from_tupledesc(tupledesc); +// ensure_arrow_schema_match_tupledesc_schema throws an error if the arrow schema does not match the table schema. +// If the arrow schema is castable to the table schema, it returns a vector of Option to cast to +// for each field. +pub(crate) fn ensure_arrow_schema_match_tupledesc_schema( + arrow_schema: Arc, + tupledesc_schema: Arc, +) -> Vec> { + let mut cast_to_types = Vec::new(); - for table_schema_field in table_schema.fields().iter() { - let table_schema_field_name = table_schema_field.name(); - let table_schema_field_type = table_schema_field.data_type(); + for tupledesc_field in tupledesc_schema.fields().iter() { + let field_name = tupledesc_field.name(); - let file_schema_field = file_schema.column_with_name(table_schema_field_name); + let arrow_field = arrow_schema.column_with_name(field_name); - if let Some(file_schema_field) = file_schema_field { - let file_schema_field_type = file_schema_field.1.data_type(); + if arrow_field.is_none() { + panic!("column \"{}\" is not found in parquet file", field_name); + } - if file_schema_field_type != table_schema_field_type { - panic!( - "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"", - table_schema_field_name, - table_schema_field_type, - file_schema_field_type, - ); - } - } else { + let (_, arrow_field) = arrow_field.unwrap(); + let arrow_field = Arc::new(arrow_field.clone()); + + let from_type = arrow_field.data_type(); + let to_type = tupledesc_field.data_type(); + + // no cast needed + if from_type == to_type { + cast_to_types.push(None); + continue; + } + + // struct types are not castable since arrow supports casting struct fields by field position + // instead of field name, which is not the intended behavior for pg_parquet + if matches!(from_type, DataType::Struct(_)) || !can_cast_types(from_type, to_type) { panic!( - "column \"{}\" is not found in parquet file", - table_schema_field_name + "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but parquet file had \"{}\"", + field_name, + to_type, + from_type, ); } + + pgrx::debug2!( + "column \"{}\" is being cast from \"{}\" to \"{}\"", + field_name, + from_type, + to_type + ); + + cast_to_types.push(Some(to_type.clone())); } + + cast_to_types } diff --git a/src/lib.rs b/src/lib.rs index 57584bb..98b299f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,11 @@ pub extern "C" fn _PG_init() { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { + use std::fs::File; use std::io::Write; use std::marker::PhantomData; + use std::sync::Arc; + use std::vec; use std::{collections::HashMap, fmt::Debug}; use crate::arrow_parquet::compression::PgParquetCompression; @@ -48,6 +51,14 @@ mod tests { use crate::type_compat::pg_arrow_type_conversions::{ DEFAULT_UNBOUNDED_NUMERIC_PRECISION, DEFAULT_UNBOUNDED_NUMERIC_SCALE, }; + use arrow::array::{ + ArrayRef, Float32Array, Int16Array, Int32Array, LargeBinaryArray, LargeStringArray, + ListArray, RecordBatch, StructArray, UInt16Array, UInt32Array, UInt64Array, + }; + use arrow::buffer::{OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::UInt16Type; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use parquet::arrow::ArrowWriter; use pgrx::pg_sys::Oid; use pgrx::{ composite_type, @@ -340,6 +351,14 @@ mod tests { Spi::get_one(&query).unwrap().unwrap() } + fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) { + let file = File::create("/tmp/test.parquet").unwrap(); + let mut writer = ArrowWriter::try_new(file, schema, None).unwrap(); + + writer.write(&record_batch).unwrap(); + writer.close().unwrap(); + } + #[pg_test] fn test_int2() { let test_table = TestTable::::new("int2".into()); @@ -1391,6 +1410,340 @@ mod tests { Spi::run("DROP TYPE dog;").unwrap(); } + #[pg_test] + fn test_coerce_primitive_types() { + // INT16 => {int, bigint} + let x_nullable = false; + let y_nullable = true; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int16, x_nullable), + Field::new("y", DataType::Int16, y_nullable), + ])); + + let x = Arc::new(Int16Array::from(vec![1])); + let y = Arc::new(Int16Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // INT32 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)])); + + let x = Arc::new(Int32Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT32 => {double} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, true)])); + + let x = Arc::new(Float32Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x double precision)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value as f32, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT16 => {smallint, int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt16, true), + Field::new("y", DataType::UInt16, true), + Field::new("z", DataType::UInt16, true), + ])); + + let x = Arc::new(UInt16Array::from(vec![1])); + let y = Arc::new(UInt16Array::from(vec![2])); + let z = Arc::new(UInt16Array::from(vec![3])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y, z]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x smallint, y int, z bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = + Spi::get_three::("SELECT x, y, z FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2), Some(3))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT32 => {int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt32, true), + Field::new("y", DataType::UInt32, true), + ])); + + let x = Arc::new(UInt32Array::from(vec![1])); + let y = Arc::new(UInt32Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT64 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::UInt64, true)])); + + let x = Arc::new(UInt64Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeUtf8 => {text} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeUtf8, + true, + )])); + + let x = Arc::new(LargeStringArray::from(vec!["test"])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x text)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "test"); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeBinary => {bytea} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeBinary, + true, + )])); + + let x = Arc::new(LargeBinaryArray::from(vec!["abc".as_bytes()])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bytea)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::>("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "abc".as_bytes()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + fn test_coerce_list_types() { + let x_nullable = false; + let field_x = Field::new( + "x", + DataType::List(Field::new("item", DataType::UInt16, false).into()), + x_nullable, + ); + + let x = Arc::new(UInt16Array::from(vec![1, 2])); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 2])); + let x = Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::UInt16, false)), + offsets, + x, + None, + )); + + let y_nullable = true; + let field_y = Field::new( + "y", + DataType::List(Field::new("item", DataType::UInt16, true).into()), + y_nullable, + ); + + let y = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(3), Some(4)]), + ])); + + let schema = Arc::new(Schema::new(vec![field_x, field_y])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int[], y bigint[])"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::>, Vec>>( + "SELECT x, y FROM test_table LIMIT 1", + ) + .unwrap(); + assert_eq!( + value, + (Some(vec![Some(1), Some(2)]), Some(vec![Some(3), Some(4)])) + ); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + #[should_panic(expected = "type mismatch for column \"x\" between table and parquet file.")] + fn test_coerce_struct_types_not_supported() { + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Struct( + vec![ + Field::new("a", DataType::UInt16, false), + Field::new("b", DataType::UInt16, false), + ] + .into(), + ), + false, + )])); + + let a: ArrayRef = Arc::new(UInt16Array::from(vec![Some(1)])); + let b: ArrayRef = Arc::new(UInt16Array::from(vec![Some(2)])); + + let x = Arc::new(StructArray::try_from(vec![("a", a), ("b", b)]).unwrap()); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_type = "CREATE TYPE test_type AS (a int, b bigint)"; + Spi::run(create_type).unwrap(); + + let create_table = "CREATE TABLE test_table (x test_type)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + + #[pg_test] + #[should_panic(expected = "violates not-null constraint")] + fn test_copy_not_null_table() { + let create_table = "CREATE TABLE test_table (x int NOT NULL)"; + Spi::run(create_table).unwrap(); + + // first copy non-null value to file + let copy_to = "COPY (SELECT 1 as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_one::("SELECT x FROM test_table") + .unwrap() + .unwrap(); + assert_eq!(result, 1); + + // then copy null value to file + let copy_to = "COPY (SELECT NULL::int as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + // this should panic + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + + #[pg_test] + fn test_table_with_different_field_position() { + let copy_to = "COPY (SELECT 1 as x, 'hello' as y) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (y text, x int)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_two::<&str, i32>("SELECT y, x FROM test_table LIMIT 1").unwrap(); + assert_eq!(result, (Some("hello"), Some(1))); + } + #[pg_test] fn test_copy_with_empty_options() { let test_table = TestTable::::new("int4".into()) diff --git a/src/parquet_copy_hook/copy_to_dest_receiver.rs b/src/parquet_copy_hook/copy_to_dest_receiver.rs index 2042281..a69b2ce 100644 --- a/src/parquet_copy_hook/copy_to_dest_receiver.rs +++ b/src/parquet_copy_hook/copy_to_dest_receiver.rs @@ -18,7 +18,7 @@ use crate::{ schema_parser::parquet_schema_string_from_tupledesc, uri_utils::parse_uri, }, - pgrx_utils::collect_valid_attributes, + pgrx_utils::{collect_attributes_for, CopyOperation}, }; #[repr(C)] @@ -120,7 +120,7 @@ impl CopyToParquetDestReceiver { pgrx::debug2!( "schema for tuples: {}", - parquet_schema_string_from_tupledesc(&tupledesc) + parquet_schema_string_from_tupledesc(&tupledesc, CopyOperation::CopyTo) ); let current_parquet_writer_context = @@ -179,8 +179,7 @@ extern "C" fn copy_startup(dest: *mut DestReceiver, _operation: i32, tupledesc: let tupledesc = unsafe { BlessTupleDesc(tupledesc) }; let tupledesc = unsafe { PgTupleDesc::from_pg(tupledesc) }; - let include_generated_columns = true; - let attributes = collect_valid_attributes(&tupledesc, include_generated_columns); + let attributes = collect_attributes_for(CopyOperation::CopyTo, &tupledesc); // update the parquet dest receiver's missing fields parquet_dest.tupledesc = tupledesc.as_ptr(); diff --git a/src/pgrx_utils.rs b/src/pgrx_utils.rs index 7d34da0..224dc35 100644 --- a/src/pgrx_utils.rs +++ b/src/pgrx_utils.rs @@ -8,12 +8,23 @@ use pgrx::{ PgTupleDesc, }; -// collect_valid_attributes collects not-dropped attributes from the tuple descriptor. -// If include_generated_columns is false, it will skip generated columns. -pub(crate) fn collect_valid_attributes( +#[derive(Debug, Clone, Copy)] +pub(crate) enum CopyOperation { + CopyFrom, + CopyTo, +} + +// collect_attributes_for collects not-dropped attributes from the tuple descriptor. +// If copy_operation is CopyTo, it also collects generated columns. Otherwise, it does not. +pub(crate) fn collect_attributes_for( + copy_operation: CopyOperation, tupdesc: &PgTupleDesc, - include_generated_columns: bool, ) -> Vec { + let include_generated_columns = match copy_operation { + CopyOperation::CopyFrom => false, + CopyOperation::CopyTo => true, + }; + let mut attributes = vec![]; let mut attributes_set = HashSet::<&str>::new();