From c8af064b9839ce88de1de9b4f61cb8c62dbafa36 Mon Sep 17 00:00:00 2001
From: Aykut Bozkurt
Date: Mon, 11 Nov 2024 14:50:37 +0300
Subject: [PATCH] Cast types on read

`COPY FROM parquet` is too strict when matching the Postgres tupledesc schema
to the parquet file schema. For example, an `INT32` column in the parquet schema
cannot be read into a Postgres column of type `int64`. We can avoid this by
casting the arrow array to the array expected by the tupledesc schema, when such
a cast is possible. We make use of the `arrow-cast` crate, which belongs to the
same project as `arrow`. Its public API lets us check whether a cast between two
arrow types is possible and perform that cast.

To make sure the cast is valid, we do 2 checks:
1. arrow-cast allows the cast from the "arrow type in the parquet file" to the
   "arrow type in the schema generated for the tupledesc",
2. the cast is meaningful in Postgres. We check that there is an explicit cast
   from the "Postgres type that corresponds to the arrow type in the parquet
   file" to the "Postgres type in the tupledesc".

With that we can cast between many castable types, as shown below:
- INT16 => INT32
- UINT32 => INT64
- FLOAT32 => FLOAT64
- LargeUtf8 => UTF8
- LargeBinary => Binary
- Struct, Array, and Map with castable fields, e.g. [UINT16] => [INT64]
  or struct {'x': UINT16} => struct {'x': INT64}

**NOTE**: Struct fields must match by name and position to be cast.

Closes #67.
---
 Cargo.lock                                    |   1 +
 Cargo.toml                                    |   1 +
 src/arrow_parquet/arrow_to_pg.rs              | 395 ++++----
 src/arrow_parquet/arrow_to_pg/timestamptz.rs  |  19 +-
 src/arrow_parquet/parquet_reader.rs           |  70 +-
 src/arrow_parquet/parquet_writer.rs           |  17 +-
 src/arrow_parquet/pg_to_arrow.rs              |  20 +-
 src/arrow_parquet/schema_parser.rs            | 289 ++++--
 src/lib.rs                                    | 959 ++++++++++++++++++
 .../copy_to_dest_receiver.rs                  |  11 +-
 src/pgrx_utils.rs                             |  20 +-
 src/type_compat/pg_arrow_type_conversions.rs  |   8 +-
 12 files changed, 1518 insertions(+), 292 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index d594d73..7150e6c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2231,6 +2231,7 @@ name = "pg_parquet"
 version = "0.1.0"
 dependencies = [
  "arrow",
+ "arrow-cast",
  "arrow-schema",
  "aws-config",
  "aws-credential-types",
diff --git a/Cargo.toml b/Cargo.toml
index b5a372b..76a9e1c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,6 +21,7 @@ pg_test = []
 
 [dependencies]
 arrow = {version = "53", default-features = false}
+arrow-cast = {version = "53", default-features = false}
 arrow-schema = {version = "53", default-features = false}
 aws-config = { version = "1.5", default-features = false, features = ["rustls"]}
 aws-credential-types = {version = "1.2", default-features = false}
diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs
index ec7c9ce..079ee08 100644
--- a/src/arrow_parquet/arrow_to_pg.rs
+++ b/src/arrow_parquet/arrow_to_pg.rs
@@ -3,29 +3,23 @@ use arrow::array::{
     Float64Array, Int16Array, Int32Array, Int64Array, ListArray, MapArray, StringArray,
     StructArray, Time64MicrosecondArray, TimestampMicrosecondArray, UInt32Array,
 };
-use arrow_schema::Fields;
+use arrow_schema::{DataType, FieldRef, Fields, TimeUnit};
 use pgrx::{
     datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone},
-    pg_sys::{
-        Datum, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID,
-        INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID,
-    },
+    pg_sys::{Datum, FormData_pg_attribute, Oid, CHAROID, TEXTOID, TIMEOID},
     prelude::PgHeapTuple,
     AllocatedByRust, AnyNumeric, IntoDatum, 
PgTupleDesc, }; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, geometry::{is_postgis_geometry_type, Geometry}, map::{is_map_type, Map}, - pg_arrow_type_conversions::{ - extract_precision_and_scale_from_numeric_typmod, should_write_numeric_as_text, - }, }, }; @@ -57,25 +51,33 @@ pub(crate) trait ArrowArrayToPgType: From { #[derive(Clone)] pub(crate) struct ArrowToPgAttributeContext { name: String, + data_type: DataType, + needs_cast: bool, typoid: Oid, typmod: i32, - is_array: bool, - is_composite: bool, is_geometry: bool, - is_map: bool, attribute_contexts: Option>, attribute_tupledesc: Option>, precision: Option, scale: Option, + timezone: Option, } impl ArrowToPgAttributeContext { - pub(crate) fn new(name: &str, typoid: Oid, typmod: i32, fields: Fields) -> Self { - let field = fields - .iter() - .find(|field| field.name() == name) - .unwrap_or_else(|| panic!("failed to find field {}", name)) - .clone(); + pub(crate) fn new( + name: &str, + typoid: Oid, + typmod: i32, + field: FieldRef, + cast_to_type: Option, + ) -> Self { + let needs_cast = cast_to_type.is_some(); + + let data_type = if let Some(cast_to_type) = &cast_to_type { + cast_to_type.clone() + } else { + field.data_type().clone() + }; let is_array = is_array_type(typoid); let is_composite; @@ -123,16 +125,29 @@ impl ArrowToPgAttributeContext { None }; - let precision; - let scale; - if attribute_typoid == NUMERICOID { - let (p, s) = extract_precision_and_scale_from_numeric_typmod(typmod); - precision = Some(p); - scale = Some(s); - } else { - precision = None; - scale = None; - } + let (precision, scale) = match &data_type { + DataType::Decimal128(p, s) => (Some(*p as _), Some(*s as _)), + DataType::List(field) => { + if let DataType::Decimal128(p, s) = field.data_type() { + (Some(*p as _), Some(*s as _)) + } else { + (None, None) + } + } + _ => (None, None), + }; + + let timezone = match &data_type { + DataType::Timestamp(_, Some(timezone)) => Some(timezone.to_string()), + DataType::List(field) => { + if let DataType::Timestamp(_, Some(timezone)) = field.data_type() { + Some(timezone.to_string()) + } else { + None + } + } + _ => None, + }; // for composite and map types, recursively collect attribute contexts let attribute_contexts = if let Some(attribute_tupledesc) = &attribute_tupledesc { @@ -147,9 +162,16 @@ impl ArrowToPgAttributeContext { _ => unreachable!(), }; + let attributes = + collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc); + + // we only cast the top-level attributes, which already covers the nested attributes + let cast_to_types = None; + Some(collect_arrow_to_pg_attribute_contexts( - attribute_tupledesc, + &attributes, &fields, + cast_to_types, )) } else { None @@ -157,43 +179,63 @@ impl ArrowToPgAttributeContext { Self { name: name.to_string(), + data_type, + needs_cast, typoid: attribute_typoid, typmod, - is_array, - is_composite, is_geometry, - is_map, attribute_contexts, attribute_tupledesc, scale, precision, + timezone, } } pub(crate) fn name(&self) -> &str { &self.name } + + pub(crate) fn needs_cast(&self) -> bool { + self.needs_cast + } + + pub(crate) fn data_type(&self) -> &DataType { + &self.data_type + } } pub(crate) fn 
collect_arrow_to_pg_attribute_contexts( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], fields: &Fields, + cast_to_types: Option>>, ) -> Vec { - // parquet file does not contain generated columns. PG will handle them. - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); let mut attribute_contexts = vec![]; - for attribute in attributes { + for (idx, attribute) in attributes.iter().enumerate() { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); let attribute_typmod = attribute.type_mod(); + let field = fields + .iter() + .find(|field| field.name() == attribute_name) + .unwrap_or_else(|| panic!("failed to find field {}", attribute_name)) + .clone(); + + let cast_to_type = if let Some(cast_to_types) = cast_to_types.as_ref() { + debug_assert!(cast_to_types.len() == attributes.len()); + cast_to_types.get(idx).cloned().expect("cast_to_type null") + } else { + None + }; + let attribute_context = ArrowToPgAttributeContext::new( attribute_name, attribute_typoid, attribute_typmod, - fields.clone(), + field, + cast_to_type, ); attribute_contexts.push(attribute_context); @@ -206,7 +248,7 @@ pub(crate) fn to_pg_datum( attribute_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - if attribute_context.is_array { + if matches!(attribute_array.data_type(), DataType::List(_)) { to_pg_array_datum(attribute_array, attribute_context) } else { to_pg_nonarray_datum(attribute_array, attribute_context) @@ -227,43 +269,34 @@ fn to_pg_nonarray_datum( primitive_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - match attribute_context.typoid { - FLOAT4OID => { + match attribute_context.data_type() { + DataType::Float32 => { to_pg_datum!(Float32Array, f32, primitive_array, attribute_context) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!(Float64Array, f64, primitive_array, attribute_context) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, i16, primitive_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, i32, primitive_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, i64, primitive_array, attribute_context) } - BOOLOID => { - to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) - } - CHAROID => { - to_pg_datum!(StringArray, i8, primitive_array, attribute_context) - } - TEXTOID => { - to_pg_datum!(StringArray, String, primitive_array, attribute_context) - } - BYTEAOID => { - to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) - } - OIDOID => { + DataType::UInt32 => { to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Boolean => { + to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) + } + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(StringArray, String, primitive_array, attribute_context) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -272,72 +305,72 @@ fn to_pg_nonarray_datum( primitive_array, attribute_context ) - } else { - to_pg_datum!( - Decimal128Array, - AnyNumeric, - 
primitive_array, - attribute_context - ) } } - DATEOID => { - to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) + } else { + to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) + } } - TIMEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Time64MicrosecondArray, - Time, + Decimal128Array, + AnyNumeric, primitive_array, attribute_context ) } - TIMETZOID => { + DataType::Date32 => { + to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + } + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Time, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + TimeWithTimeZone, + primitive_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - TimeWithTimeZone, + TimestampMicrosecondArray, + Timestamp, primitive_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { to_pg_datum!( TimestampMicrosecondArray, - Timestamp, + TimestampWithTimeZone, primitive_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - TimestampWithTimeZone, + StructArray, + PgHeapTuple, primitive_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Map, primitive_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - PgHeapTuple, - primitive_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Map, primitive_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - FallbackToText, - primitive_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } @@ -354,8 +387,13 @@ fn to_pg_array_datum( let list_array = list_array.value(0).to_data(); - match attribute_context.typoid { - FLOAT4OID => { + let element_field = match attribute_context.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + match element_field.data_type() { + DataType::Float32 => { to_pg_datum!( Float32Array, Vec>, @@ -363,7 +401,7 @@ fn to_pg_array_datum( attribute_context ) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!( Float64Array, Vec>, @@ -371,16 +409,19 @@ fn to_pg_array_datum( attribute_context ) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) } - BOOLOID => { + DataType::UInt32 => { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } + DataType::Boolean => { to_pg_datum!( BooleanArray, Vec>, @@ -388,34 +429,17 @@ fn to_pg_array_datum( attribute_context ) } - CHAROID => { - to_pg_datum!(StringArray, Vec>, list_array, attribute_context) - } - TEXTOID => { - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } - BYTEAOID => { - 
to_pg_datum!( - BinaryArray, - Vec>>, - list_array, - attribute_context - ) - } - OIDOID => { - to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) - } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + StringArray, + Vec>, + list_array, + attribute_context + ) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -424,82 +448,87 @@ fn to_pg_array_datum( list_array, attribute_context ) + } + } + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!( + BinaryArray, + Vec>, + list_array, + attribute_context + ) } else { to_pg_datum!( - Decimal128Array, - Vec>, + BinaryArray, + Vec>>, list_array, attribute_context ) } } - DATEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Date32Array, - Vec>, + Decimal128Array, + Vec>, list_array, attribute_context ) } - TIMEOID => { + DataType::Date32 => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + Date32Array, + Vec>, list_array, attribute_context ) } - TIMETZOID => { + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + TimestampMicrosecondArray, + Vec>, list_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { to_pg_datum!( TimestampMicrosecondArray, - Vec>, + Vec>, list_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - Vec>, + StructArray, + Vec>>, list_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Vec>, list_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - Vec>>, - list_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Vec>, list_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!( - BinaryArray, - Vec>, - list_array, - attribute_context - ) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } diff --git a/src/arrow_parquet/arrow_to_pg/timestamptz.rs b/src/arrow_parquet/arrow_to_pg/timestamptz.rs index 81bf5f9..a5a4736 100644 --- a/src/arrow_parquet/arrow_to_pg/timestamptz.rs +++ b/src/arrow_parquet/arrow_to_pg/timestamptz.rs @@ -7,11 +7,16 @@ use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; // Timestamptz impl ArrowArrayToPgType for TimestampMicrosecondArray { - fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + fn to_pg_type(self, context: &ArrowToPgAttributeContext) -> Option { if self.is_null(0) { None } else { - Some(i64_to_timestamptz(self.value(0))) + let timezone = context + .timezone + .as_ref() + .expect("timezone is required for timestamptz"); + + 
Some(i64_to_timestamptz(self.value(0), timezone)) } } } @@ -20,11 +25,17 @@ impl ArrowArrayToPgType for TimestampMicrosecondArray { impl ArrowArrayToPgType>> for TimestampMicrosecondArray { fn to_pg_type( self, - _context: &ArrowToPgAttributeContext, + context: &ArrowToPgAttributeContext, ) -> Option>> { let mut vals = vec![]; + + let timezone = context + .timezone + .as_ref() + .expect("timezone is required for timestamptz[]"); + for val in self.iter() { - let val = val.map(i64_to_timestamptz); + let val = val.map(|v| i64_to_timestamptz(v, timezone)); vals.push(val); } Some(vals) diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index a3cd53b..6db958a 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -1,24 +1,32 @@ +use std::sync::Arc; + use arrow::array::RecordBatch; +use arrow_cast::{cast_with_options, CastOptions}; use futures::StreamExt; use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream}; use pgrx::{ check_for_interrupts, pg_sys::{ - fmgr_info, getTypeBinaryOutputInfo, varlena, Datum, FmgrInfo, InvalidOid, SendFunctionCall, + fmgr_info, getTypeBinaryOutputInfo, varlena, Datum, FmgrInfo, FormData_pg_attribute, + InvalidOid, SendFunctionCall, }, vardata_any, varsize_any_exhdr, void_mut_ptr, AllocatedByPostgres, PgBox, PgTupleDesc, }; use url::Url; use crate::{ - arrow_parquet::arrow_to_pg::to_pg_datum, - pgrx_utils::collect_valid_attributes, + arrow_parquet::{ + arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes, + }, + pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; use super::{ arrow_to_pg::{collect_arrow_to_pg_attribute_contexts, ArrowToPgAttributeContext}, - schema_parser::ensure_arrow_schema_match_tupledesc, + schema_parser::{ + ensure_arrow_schema_match_tupledesc_schema, parse_arrow_schema_from_attributes, + }, uri_utils::{parquet_reader_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }; @@ -41,12 +49,35 @@ impl ParquetReaderContext { let parquet_reader = parquet_reader_from_uri(&uri); - let schema = parquet_reader.schema(); - ensure_arrow_schema_match_tupledesc(schema.clone(), tupledesc); + let parquet_file_schema = parquet_reader.schema(); + + let attributes = collect_attributes_for(CollectAttributesFor::CopyFrom, tupledesc); + + pgrx::debug2!( + "schema for tuples: {}", + parquet_schema_string_from_attributes(&attributes) + ); - let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); + let tupledesc_schema = parse_arrow_schema_from_attributes(&attributes); - let attribute_contexts = collect_arrow_to_pg_attribute_contexts(tupledesc, &schema.fields); + let tupledesc_schema = Arc::new(tupledesc_schema); + + // Ensure that the arrow schema matches the tupledesc. + // Gets cast_to_types for each attribute if a cast is needed for the attribute's columnar array + // to match the expected columnar array for its tupledesc type. 
+ let cast_to_types = ensure_arrow_schema_match_tupledesc_schema( + parquet_file_schema.clone(), + tupledesc_schema.clone(), + &attributes, + ); + + let attribute_contexts = collect_arrow_to_pg_attribute_contexts( + &attributes, + &tupledesc_schema.fields, + Some(cast_to_types), + ); + + let binary_out_funcs = Self::collect_binary_out_funcs(&attributes); ParquetReaderContext { buffer: Vec::new(), @@ -60,14 +91,11 @@ impl ParquetReaderContext { } fn collect_binary_out_funcs( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], ) -> Vec> { unsafe { let mut binary_out_funcs = vec![]; - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); - for att in attributes.iter() { let typoid = att.type_oid(); @@ -94,11 +122,25 @@ impl ParquetReaderContext { for attribute_context in attribute_contexts { let name = attribute_context.name(); - let column = record_batch + let column_array = record_batch .column_by_name(name) .unwrap_or_else(|| panic!("column {} not found", name)); - let datum = to_pg_datum(column.to_data(), attribute_context); + let datum = if attribute_context.needs_cast() { + // should fail instead of returning None if the cast fails at runtime + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + let casted_column_array = + cast_with_options(&column_array, attribute_context.data_type(), &cast_options) + .unwrap_or_else(|e| panic!("failed to cast column {}: {}", name, e)); + + to_pg_datum(casted_column_array.to_data(), attribute_context) + } else { + to_pg_datum(column_array.to_data(), attribute_context) + }; datums.push(datum); } diff --git a/src/arrow_parquet/parquet_writer.rs b/src/arrow_parquet/parquet_writer.rs index 7c12009..e93ea8b 100644 --- a/src/arrow_parquet/parquet_writer.rs +++ b/src/arrow_parquet/parquet_writer.rs @@ -12,9 +12,12 @@ use url::Url; use crate::{ arrow_parquet::{ compression::{PgParquetCompression, PgParquetCompressionWithLevel}, - schema_parser::parse_arrow_schema_from_tupledesc, + schema_parser::{ + parquet_schema_string_from_attributes, parse_arrow_schema_from_attributes, + }, uri_utils::{parquet_writer_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }, + pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; @@ -57,12 +60,20 @@ impl ParquetWriterContext { .set_created_by("pg_parquet".to_string()) .build(); - let schema = parse_arrow_schema_from_tupledesc(tupledesc); + let attributes = collect_attributes_for(CollectAttributesFor::CopyTo, tupledesc); + + pgrx::debug2!( + "schema for tuples: {}", + parquet_schema_string_from_attributes(&attributes) + ); + + let schema = parse_arrow_schema_from_attributes(&attributes); let schema = Arc::new(schema); let parquet_writer = parquet_writer_from_uri(&uri, schema.clone(), writer_props); - let attribute_contexts = collect_pg_to_arrow_attribute_contexts(tupledesc, &schema.fields); + let attribute_contexts = + collect_pg_to_arrow_attribute_contexts(&attributes, &schema.fields); ParquetWriterContext { parquet_writer, diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs index 40cc03c..530c7f7 100644 --- a/src/arrow_parquet/pg_to_arrow.rs +++ b/src/arrow_parquet/pg_to_arrow.rs @@ -7,16 +7,17 @@ use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone, UnboxDatum}, heap_tuple::PgHeapTuple, pg_sys::{ - Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, - 
NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, + FormData_pg_attribute, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, + INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, + TIMESTAMPTZOID, TIMETZOID, }, - AllocatedByRust, AnyNumeric, FromDatum, PgTupleDesc, + AllocatedByRust, AnyNumeric, FromDatum, }; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, @@ -146,7 +147,10 @@ impl PgToArrowAttributeContext { _ => unreachable!(), }; - collect_pg_to_arrow_attribute_contexts(&attribute_tupledesc, &fields) + let attributes = + collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc); + + collect_pg_to_arrow_attribute_contexts(&attributes, &fields) }); Self { @@ -166,11 +170,9 @@ impl PgToArrowAttributeContext { } pub(crate) fn collect_pg_to_arrow_attribute_contexts( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], fields: &Fields, ) -> Vec { - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); let mut attribute_contexts = vec![]; for attribute in attributes { diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index 8dd79cf..e3b367a 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -1,18 +1,20 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use arrow::datatypes::{Field, Fields, Schema}; -use arrow_schema::FieldRef; +use arrow_cast::can_cast_types; +use arrow_schema::{DataType, FieldRef}; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ - Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, + can_coerce_type, CoercionContext::COERCION_EXPLICIT, FormData_pg_attribute, InvalidOid, Oid, + BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, }; use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc}; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ geometry::is_postgis_geometry_type, @@ -23,8 +25,10 @@ use crate::{ }, }; -pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> String { - let arrow_schema = parse_arrow_schema_from_tupledesc(tupledesc); +pub(crate) fn parquet_schema_string_from_attributes( + attributes: &[FormData_pg_attribute], +) -> String { + let arrow_schema = parse_arrow_schema_from_attributes(attributes); let parquet_schema = arrow_to_parquet_schema(&arrow_schema) .unwrap_or_else(|e| panic!("failed to convert arrow schema to parquet schema: {}", e)); @@ -33,14 +37,11 @@ pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> S String::from_utf8(buf).unwrap_or_else(|e| panic!("failed to convert schema to string: {}", e)) } -pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: 
&PgTupleDesc) -> Schema { +pub(crate) fn parse_arrow_schema_from_attributes(attributes: &[FormData_pg_attribute]) -> Schema { let mut field_id = 0; let mut struct_attribute_fields = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); - for attribute in attributes { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); @@ -92,8 +93,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i let mut child_fields: Vec> = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(&tupledesc, include_generated_columns); + let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); for attribute in attributes { if attribute.is_dropped() { @@ -130,10 +130,12 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i child_fields.push(child_field); } + let nullable = true; + Field::new( elem_name, arrow::datatypes::DataType::Struct(Fields::from(child_fields)), - true, + nullable, ) .with_metadata(metadata) .into() @@ -159,10 +161,12 @@ fn parse_list_schema(typoid: Oid, typmod: i32, array_name: &str, field_id: &mut parse_primitive_schema(typoid, typmod, array_name, field_id) }; + let nullable = true; + Field::new( array_name, arrow::datatypes::DataType::List(elem_field), - true, + nullable, ) .with_metadata(list_metadata) .into() @@ -177,13 +181,18 @@ fn parse_map_schema(typoid: Oid, typmod: i32, map_name: &str, field_id: &mut i32 *field_id += 1; let tupledesc = tuple_desc(typoid, typmod); + let entries_field = parse_struct_schema(tupledesc, map_name, field_id); let entries_field = adjust_map_entries_field(entries_field); + let keys_sorted = false; + + let nullable = true; + Field::new( map_name, - arrow::datatypes::DataType::Map(entries_field, false), - true, + arrow::datatypes::DataType::Map(entries_field, keys_sorted), + nullable, ) .with_metadata(map_metadata) .into() @@ -204,31 +213,33 @@ fn parse_primitive_schema( *field_id += 1; + let nullable = true; + let field = match typoid { - FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, true), - FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, true), - BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, true), - INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, true), - INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, true), - INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, true), + FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, nullable), + FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, nullable), + BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, nullable), + INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, nullable), + INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, nullable), + INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, nullable), NUMERICOID => { let (precision, scale) = extract_precision_and_scale_from_numeric_typmod(typmod); if should_write_numeric_as_text(precision) { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } else { Field::new( elem_name, arrow::datatypes::DataType::Decimal128(precision as _, scale as _), - true, + nullable, ) } } - DATEOID => Field::new(elem_name, 
arrow::datatypes::DataType::Date32, true), + DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, nullable), TIMESTAMPOID => Field::new( elem_name, arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - true, + nullable, ), TIMESTAMPTZOID => Field::new( elem_name, @@ -236,31 +247,31 @@ fn parse_primitive_schema( arrow::datatypes::TimeUnit::Microsecond, Some("+00:00".into()), ), - true, + nullable, ), TIMEOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ), TIMETZOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ) .with_metadata(HashMap::from_iter(vec![( "adjusted_to_utc".into(), "true".into(), )])), - CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, true), - OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, true), + CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable), + OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, nullable), _ => { if is_postgis_geometry_type(typoid) { - Field::new(elem_name, arrow::datatypes::DataType::Binary, true) + Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable) } else { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } } }; @@ -289,60 +300,210 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { let key_field = fields.find("key").expect("expected key field").1; let value_field = fields.find("val").expect("expected val field").1; - not_nullable_key_field = - Field::new(key_field.name(), key_field.data_type().clone(), false) - .with_metadata(key_field.metadata().clone()); + let key_nullable = false; - nullable_value_field = - Field::new(value_field.name(), value_field.data_type().clone(), true) - .with_metadata(value_field.metadata().clone()); + not_nullable_key_field = Field::new( + key_field.name(), + key_field.data_type().clone(), + key_nullable, + ) + .with_metadata(key_field.metadata().clone()); + + let value_nullable = true; + + nullable_value_field = Field::new( + value_field.name(), + value_field.data_type().clone(), + value_nullable, + ) + .with_metadata(value_field.metadata().clone()); } _ => { panic!("expected struct data type for map entries") } }; + let entries_nullable = false; + let entries_field = Field::new( name, arrow::datatypes::DataType::Struct(Fields::from(vec![ not_nullable_key_field, nullable_value_field, ])), - false, + entries_nullable, ) .with_metadata(metadata); Arc::new(entries_field) } -pub(crate) fn ensure_arrow_schema_match_tupledesc( - file_schema: Arc, - tupledesc: &PgTupleDesc, -) { - let table_schema = parse_arrow_schema_from_tupledesc(tupledesc); +// ensure_arrow_schema_match_tupledesc_schema throws an error if the arrow schema does not match the table schema. +// If the arrow schema is castable to the table schema, it returns a vector of Option to cast to +// for each field. 
+pub(crate) fn ensure_arrow_schema_match_tupledesc_schema(
+    arrow_schema: Arc<Schema>,
+    tupledesc_schema: Arc<Schema>,
+    attributes: &[FormData_pg_attribute],
+) -> Vec<Option<DataType>> {
+    let mut cast_to_types = Vec::new();
+
+    for (tupledesc_field, attribute) in tupledesc_schema.fields().iter().zip(attributes.iter()) {
+        let field_name = tupledesc_field.name();
 
-    for table_schema_field in table_schema.fields().iter() {
-        let table_schema_field_name = table_schema_field.name();
-        let table_schema_field_type = table_schema_field.data_type();
+        let arrow_field = arrow_schema.column_with_name(field_name);
 
-        let file_schema_field = file_schema.column_with_name(table_schema_field_name);
+        if arrow_field.is_none() {
+            panic!("column \"{}\" is not found in parquet file", field_name);
+        }
 
-        if let Some(file_schema_field) = file_schema_field {
-            let file_schema_field_type = file_schema_field.1.data_type();
+        let (_, arrow_field) = arrow_field.unwrap();
+        let arrow_field = Arc::new(arrow_field.clone());
 
-            if file_schema_field_type != table_schema_field_type {
-                panic!(
-                    "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"",
-                    table_schema_field_name,
-                    table_schema_field_type,
-                    file_schema_field_type,
-                );
-            }
-        } else {
+        let from_type = arrow_field.data_type();
+        let to_type = tupledesc_field.data_type();
+
+        // no cast needed
+        if from_type == to_type {
+            cast_to_types.push(None);
+            continue;
+        }
+
+        if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) {
             panic!(
-                "column \"{}\" is not found in parquet file",
-                table_schema_field_name
+                "type mismatch for column \"{}\" between table and parquet file.\n\n\
+                 table has \"{}\".\n\n\
+                 parquet file has \"{}\"",
+                field_name, to_type, from_type,
             );
         }
+
+        pgrx::debug2!(
+            "column \"{}\" is being cast from \"{}\" to \"{}\"",
+            field_name,
+            from_type,
+            to_type
+        );
+
+        cast_to_types.push(Some(to_type.clone()));
+    }
+
+    cast_to_types
+}
+
+// is_coercible first checks if "from_type" can be cast to "to_type" by arrow-cast.
+// Then, it checks if the cast is meaningful in Postgres by seeing if there is
+// an explicit coercion from "from_typoid" to "to_typoid".
+//
+// Additionally, we need to be careful about struct rules for the cast:
+// Arrow supports casting struct fields by field position instead of field name,
+// which is not the intended behavior for pg_parquet. Hence, we make sure the field names
+// match for structs. 
+fn is_coercible(from_type: &DataType, to_type: &DataType, to_typoid: Oid, to_typmod: i32) -> bool { + match (from_type, to_type) { + (DataType::Struct(from_fields), DataType::Struct(to_fields)) => { + if from_fields.len() != to_fields.len() { + return false; + } + + let tupledesc = tuple_desc(to_typoid, to_typmod); + + let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); + + for (from_field, (to_field, to_attribute)) in from_fields + .iter() + .zip(to_fields.iter().zip(attributes.iter())) + { + if from_field.name() != to_field.name() { + return false; + } + + if !is_coercible( + from_field.data_type(), + to_field.data_type(), + to_attribute.type_oid().value(), + to_attribute.type_mod(), + ) { + return false; + } + } + + true + } + (DataType::List(from_field), DataType::List(to_field)) => { + let element_oid = array_element_typoid(to_typoid); + let element_typmod = to_typmod; + + is_coercible( + from_field.data_type(), + to_field.data_type(), + element_oid, + element_typmod, + ) + } + (DataType::Map(from_entries_field, _), DataType::Map(to_entries_field, _)) => { + // entries field cannot be null + if from_entries_field.is_nullable() { + return false; + } + + let entries_typoid = domain_array_base_elem_typoid(to_typoid); + + is_coercible( + from_entries_field.data_type(), + to_entries_field.data_type(), + entries_typoid, + to_typmod, + ) + } + _ => { + // check if arrow-cast can cast the types + if !can_cast_types(from_type, to_type) { + return false; + } + + let from_typoid = pg_type_for_arrow_primitive_type(from_type); + + // pg_parquet could not recognize that arrow type + if from_typoid == InvalidOid { + return false; + } + + let n_args = 1; + let ccontext = COERCION_EXPLICIT; + let input_typeids = [from_typoid]; + let target_typeids = [to_typoid]; + + // check if coercion is meaningful at Postgres (it has a coercion path) + unsafe { + can_coerce_type( + n_args, + input_typeids.as_ptr(), + target_typeids.as_ptr(), + ccontext, + ) + } + } + } +} + +// pg_type_for_arrow_primitive_type returns Postgres type for given +// primitive arrow type. It returns InvalidOid if the arrow type is not recognized. 
+fn pg_type_for_arrow_primitive_type(data_type: &DataType) -> Oid { + match data_type { + DataType::Float32 | DataType::Float16 => FLOAT4OID, + DataType::Float64 => FLOAT8OID, + DataType::Int16 | DataType::UInt16 | DataType::Int8 | DataType::UInt8 => INT2OID, + DataType::Int32 | DataType::UInt32 => INT4OID, + DataType::Int64 | DataType::UInt64 => INT8OID, + DataType::Decimal128(_, _) => NUMERICOID, + DataType::Boolean => BOOLOID, + DataType::Date32 => DATEOID, + DataType::Time64(_) => TIMEOID, + DataType::Timestamp(_, None) => TIMESTAMPOID, + DataType::Timestamp(_, Some(_)) => TIMESTAMPTZOID, + DataType::Utf8 | DataType::LargeUtf8 => TEXTOID, + DataType::Binary | DataType::LargeBinary => BYTEAOID, + _ => InvalidOid, } } diff --git a/src/lib.rs b/src/lib.rs index 57584bb..fa644b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,11 @@ pub extern "C" fn _PG_init() { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { + use std::fs::File; use std::io::Write; use std::marker::PhantomData; + use std::sync::Arc; + use std::vec; use std::{collections::HashMap, fmt::Debug}; use crate::arrow_parquet::compression::PgParquetCompression; @@ -46,8 +49,19 @@ mod tests { use crate::type_compat::geometry::Geometry; use crate::type_compat::map::Map; use crate::type_compat::pg_arrow_type_conversions::{ + date_to_i32, time_to_i64, timestamp_to_i64, timestamptz_to_i64, timetz_to_i64, DEFAULT_UNBOUNDED_NUMERIC_PRECISION, DEFAULT_UNBOUNDED_NUMERIC_SCALE, }; + use arrow::array::{ + ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, + Float64Array, Int16Array, Int32Array, Int8Array, LargeBinaryArray, LargeStringArray, + ListArray, MapArray, RecordBatch, StringArray, StructArray, Time64MicrosecondArray, + TimestampMicrosecondArray, UInt16Array, UInt32Array, UInt64Array, + }; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::UInt16Type; + use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit}; + use parquet::arrow::ArrowWriter; use pgrx::pg_sys::Oid; use pgrx::{ composite_type, @@ -340,6 +354,14 @@ mod tests { Spi::get_one(&query).unwrap().unwrap() } + fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) { + let file = File::create("/tmp/test.parquet").unwrap(); + let mut writer = ArrowWriter::try_new(file, schema, None).unwrap(); + + writer.write(&record_batch).unwrap(); + writer.close().unwrap(); + } + #[pg_test] fn test_int2() { let test_table = TestTable::::new("int2".into()); @@ -1391,6 +1413,943 @@ mod tests { Spi::run("DROP TYPE dog;").unwrap(); } + #[pg_test] + fn test_coerce_primitive_types() { + // INT16 => {int, bigint} + let x_nullable = false; + let y_nullable = true; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int16, x_nullable), + Field::new("y", DataType::Int16, y_nullable), + ])); + + let x = Arc::new(Int16Array::from(vec![1])); + let y = Arc::new(Int16Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // INT32 => {bigint} + let schema = 
Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)])); + + let x = Arc::new(Int32Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT32 => {double} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, true)])); + + let x = Arc::new(Float32Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x double precision)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value as f32, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT64 => {float} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float64, true)])); + + let x = Arc::new(Float64Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x real)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // DATE32 => {timestamp} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Date32, true)])); + + let date = Date::new(2022, 5, 5).unwrap(); + + let x = Arc::new(Date32Array::from(vec![date_to_i32(date)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamp)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, Timestamp::from(date)); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIMESTAMP => {timestamptz} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )])); + + let timestamp = Timestamp::from(Date::new(2022, 5, 5).unwrap()); + + let x = Arc::new(TimestampMicrosecondArray::from(vec![timestamp_to_i64( + timestamp, + )])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamptz)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value.at_timezone("UTC").unwrap(), 
timestamp); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIMESTAMPTZ => {timestamp} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Timestamp(TimeUnit::Microsecond, Some("Europe/Paris".into())), + true, + )])); + + let timestamptz = + TimestampWithTimeZone::with_timezone(2022, 5, 5, 0, 0, 0.0, "Europe/Paris").unwrap(); + + let x = Arc::new( + TimestampMicrosecondArray::from(vec![timestamptz_to_i64(timestamptz)]) + .with_timezone("Europe/Paris"), + ); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamp)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, timestamptz.at_timezone("UTC").unwrap()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIME64 => {timetz} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Time64(TimeUnit::Microsecond), + true, + )])); + + let time = Time::new(13, 0, 0.0).unwrap(); + + let x = Arc::new(Time64MicrosecondArray::from(vec![time_to_i64(time)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timetz)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, time.into()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIME64 => {time} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Time64(TimeUnit::Microsecond), + true, + ) + .with_metadata(HashMap::from_iter(vec![( + "adjusted_to_utc".into(), + "true".into(), + )]))])); + + let timetz = TimeWithTimeZone::with_timezone(13, 0, 0.0, "UTC").unwrap(); + + let x = Arc::new(Time64MicrosecondArray::from(vec![timetz_to_i64(timetz)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x time)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::