From 5bd39d5948e066ce628056cf71ce00865c1dbb18 Mon Sep 17 00:00:00 2001
From: Aykut Bozkurt
Date: Mon, 11 Nov 2024 14:50:37 +0300
Subject: [PATCH] Coerce types on read

`COPY FROM parquet` is too strict when matching the Postgres tupledesc
schema to the schema of the parquet file, e.g. an `INT32` column in the
parquet schema cannot be read into a Postgres `bigint` column. We can
avoid this situation by adding an `is_coercible_types(from_type, to_type)`
check while matching the expected schema against the schema read from the
parquet file. With that check in place, the parquet source types below can
be coerced to the listed Postgres destination types:
- INT16 => {int, bigint}
- INT32 => {bigint}
- UINT16 => {smallint, int, bigint}
- UINT32 => {int, bigint}
- UINT64 => {bigint}
- FLOAT32 => {double precision}

As we use arrow as the intermediate format, an external writer might have
used the `LargeUtf8` or `LargeBinary` types instead of `Utf8` and `Binary`.
That is why we also support the coercions below for arrow source types:
- `Utf8 | LargeUtf8` => {text}
- `Binary | LargeBinary` => {bytea}

Closes #67.
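For illustration, a load like the following sketch now succeeds even
though the parquet columns are narrower than the table's (the table,
column, and file names here are hypothetical):

    -- parquet file schema: x INT16, y INT32, s LargeUtf8
    CREATE TABLE metrics (x int, y bigint, s text);
    COPY metrics FROM '/tmp/metrics.parquet';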
---
 src/arrow_parquet/arrow_to_pg.rs              | 475 ++++++++++++------
 src/arrow_parquet/arrow_to_pg/bytea.rs        |  27 +-
 src/arrow_parquet/arrow_to_pg/char.rs         |  28 +-
 .../arrow_to_pg/fallback_to_text.rs           |  28 +-
 src/arrow_parquet/arrow_to_pg/float4.rs       |  21 +
 src/arrow_parquet/arrow_to_pg/geometry.rs     |  27 +-
 src/arrow_parquet/arrow_to_pg/int2.rs         | 107 +++-
 src/arrow_parquet/arrow_to_pg/int4.rs         |  65 ++-
 src/arrow_parquet/arrow_to_pg/int8.rs         |  23 +-
 src/arrow_parquet/arrow_to_pg/text.rs         |  24 +-
 src/arrow_parquet/schema_parser.rs            | 180 +++++--
 src/lib.rs                                    | 332 ++++++++++++
 12 files changed, 1116 insertions(+), 221 deletions(-)

diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs
index ec7c9ce..43592df 100644
--- a/src/arrow_parquet/arrow_to_pg.rs
+++ b/src/arrow_parquet/arrow_to_pg.rs
@@ -1,14 +1,17 @@
+use std::ops::Deref;
+
 use arrow::array::{
     Array, ArrayData, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array,
-    Float64Array, Int16Array, Int32Array, Int64Array, ListArray, MapArray, StringArray,
-    StructArray, Time64MicrosecondArray, TimestampMicrosecondArray, UInt32Array,
+    Float64Array, Int16Array, Int32Array, Int64Array, LargeBinaryArray, LargeStringArray,
+    ListArray, MapArray, StringArray, StructArray, Time64MicrosecondArray,
+    TimestampMicrosecondArray, UInt16Array, UInt32Array, UInt64Array,
 };
-use arrow_schema::Fields;
+use arrow_schema::{DataType, FieldRef, Fields, TimeUnit};
 use pgrx::{
     datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone},
     pg_sys::{
-        Datum, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID,
-        INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID,
+        Datum, Oid, CHAROID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID,
+        TEXTOID, TIMEOID,
     },
     prelude::PgHeapTuple,
     AllocatedByRust, AnyNumeric, IntoDatum, PgTupleDesc,
@@ -23,9 +26,7 @@ use crate::{
         fallback_to_text::{reset_fallback_to_text_context, FallbackToText},
         geometry::{is_postgis_geometry_type, Geometry},
         map::{is_map_type, Map},
-        pg_arrow_type_conversions::{
-            extract_precision_and_scale_from_numeric_typmod, should_write_numeric_as_text,
-        },
+        pg_arrow_type_conversions::extract_precision_and_scale_from_numeric_typmod,
     },
 };
@@ -57,12 +58,10 @@ pub(crate) trait ArrowArrayToPgType<T: IntoDatum>: From<ArrayData> {
 #[derive(Clone)]
 pub(crate) struct ArrowToPgAttributeContext {
     name: String,
+    field: FieldRef,
     typoid: Oid,
     typmod: i32,
-    is_array: bool,
-    is_composite: bool,
     is_geometry: bool,
-    is_map: bool,
     attribute_contexts: Option<Vec<ArrowToPgAttributeContext>>,
     attribute_tupledesc: Option<PgTupleDesc<'static>>,
     precision: Option<u32>,
@@ -157,12 +156,10 @@ impl ArrowToPgAttributeContext {
         Self {
             name: name.to_string(),
+            field,
             typoid: attribute_typoid,
             typmod,
-            is_array,
-            is_composite,
             is_geometry,
-            is_map,
             attribute_contexts,
             attribute_tupledesc,
             scale,
@@ -206,7 +203,7 @@ pub(crate) fn to_pg_datum(
     attribute_array: ArrayData,
     attribute_context: &ArrowToPgAttributeContext,
 ) -> Option<Datum> {
-    if attribute_context.is_array {
+    if matches!(attribute_array.data_type(), DataType::List(_)) {
         to_pg_array_datum(attribute_array, attribute_context)
     } else {
         to_pg_nonarray_datum(attribute_array, attribute_context)
@@ -227,43 +224,71 @@ fn to_pg_nonarray_datum(
     primitive_array: ArrayData,
     attribute_context: &ArrowToPgAttributeContext,
 ) -> Option<Datum> {
-    match attribute_context.typoid {
-        FLOAT4OID => {
-            to_pg_datum!(Float32Array, f32, primitive_array, attribute_context)
+    match attribute_context.field.data_type() {
+        DataType::Float32 => {
+            if attribute_context.typoid == FLOAT4OID {
+                to_pg_datum!(Float32Array, f32, primitive_array, attribute_context)
+            } else {
+                debug_assert!(attribute_context.typoid == FLOAT8OID);
+                to_pg_datum!(Float32Array, f64, primitive_array, attribute_context)
+            }
         }
-        FLOAT8OID => {
+        DataType::Float64 => {
             to_pg_datum!(Float64Array, f64, primitive_array, attribute_context)
         }
-        INT2OID => {
-            to_pg_datum!(Int16Array, i16, primitive_array, attribute_context)
-        }
-        INT4OID => {
-            to_pg_datum!(Int32Array, i32, primitive_array, attribute_context)
+        DataType::Int16 => {
+            if attribute_context.typoid == INT2OID {
+                to_pg_datum!(Int16Array, i16, primitive_array, attribute_context)
+            } else if attribute_context.typoid == INT4OID {
+                to_pg_datum!(Int16Array, i32, primitive_array, attribute_context)
+            } else {
+                debug_assert!(attribute_context.typoid == INT8OID);
+                to_pg_datum!(Int16Array, i64, primitive_array, attribute_context)
+            }
         }
-        INT8OID => {
-            to_pg_datum!(Int64Array, i64, primitive_array, attribute_context)
+        DataType::UInt16 => {
+            if attribute_context.typoid == INT2OID {
+                to_pg_datum!(UInt16Array, i16, primitive_array, attribute_context)
+            } else if attribute_context.typoid == INT4OID {
+                to_pg_datum!(UInt16Array, i32, primitive_array, attribute_context)
+            } else {
+                debug_assert!(attribute_context.typoid == INT8OID);
+                to_pg_datum!(UInt16Array, i64, primitive_array, attribute_context)
+            }
         }
-        BOOLOID => {
-            to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context)
+        DataType::Int32 => {
+            if attribute_context.typoid == INT4OID {
+                to_pg_datum!(Int32Array, i32, primitive_array, attribute_context)
+            } else {
+                debug_assert!(attribute_context.typoid == INT8OID);
+                to_pg_datum!(Int32Array, i64, primitive_array, attribute_context)
+            }
         }
-        CHAROID => {
-            to_pg_datum!(StringArray, i8, primitive_array, attribute_context)
+        DataType::UInt32 => {
+            if attribute_context.typoid == OIDOID {
+                to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context)
+            } else if attribute_context.typoid == INT4OID {
+                to_pg_datum!(UInt32Array, i32, primitive_array, attribute_context)
+            } else {
+                debug_assert!(attribute_context.typoid == INT8OID);
+                to_pg_datum!(UInt32Array, i64, primitive_array, attribute_context)
+            }
         }
-        TEXTOID => {
-            to_pg_datum!(StringArray, String, primitive_array, attribute_context)
+        DataType::Int64 => {
+            to_pg_datum!(Int64Array, i64, primitive_array, attribute_context)
         }
-        BYTEAOID => {
-            to_pg_datum!(BinaryArray, Vec<u8>, primitive_array,
attribute_context) + DataType::UInt64 => { + to_pg_datum!(UInt64Array, i64, primitive_array, attribute_context) } - OIDOID => { - to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) + DataType::Boolean => { + to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(StringArray, String, primitive_array, attribute_context) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -272,72 +297,110 @@ fn to_pg_nonarray_datum( primitive_array, attribute_context ) + } + } + DataType::LargeUtf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(LargeStringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(LargeStringArray, String, primitive_array, attribute_context) } else { + reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); + to_pg_datum!( - Decimal128Array, - AnyNumeric, + LargeStringArray, + FallbackToText, primitive_array, attribute_context ) } } - DATEOID => { - to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) + } else { + to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) + } + } + DataType::LargeBinary => { + if attribute_context.is_geometry { + to_pg_datum!( + LargeBinaryArray, + Geometry, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + LargeBinaryArray, + Vec, + primitive_array, + attribute_context + ) + } } - TIMEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Time64MicrosecondArray, - Time, + Decimal128Array, + AnyNumeric, primitive_array, attribute_context ) } - TIMETZOID => { + DataType::Date32 => { + to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + } + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Time, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + TimeWithTimeZone, + primitive_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - TimeWithTimeZone, + TimestampMicrosecondArray, + Timestamp, primitive_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Timestamp, + TimestampWithTimeZone, primitive_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - TimestampWithTimeZone, + StructArray, + PgHeapTuple, primitive_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Map, primitive_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - PgHeapTuple, - primitive_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Map, primitive_array, attribute_context) - } else if 
attribute_context.is_geometry { - to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - FallbackToText, - primitive_array, - attribute_context - ) - } + panic!( + "unsupported data type: {:?}", + attribute_context.field.data_type() + ); } } } @@ -354,16 +417,31 @@ fn to_pg_array_datum( let list_array = list_array.value(0).to_data(); - match attribute_context.typoid { - FLOAT4OID => { - to_pg_datum!( - Float32Array, - Vec>, - list_array, - attribute_context - ) + let element_field = match attribute_context.field.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + match element_field.data_type() { + DataType::Float32 => { + if attribute_context.typoid == FLOAT4OID { + to_pg_datum!( + Float32Array, + Vec>, + list_array, + attribute_context + ) + } else { + debug_assert!(attribute_context.typoid == FLOAT8OID); + to_pg_datum!( + Float32Array, + Vec>, + list_array, + attribute_context + ) + } } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!( Float64Array, Vec>, @@ -371,51 +449,69 @@ fn to_pg_array_datum( attribute_context ) } - INT2OID => { - to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + DataType::Int16 => { + if attribute_context.typoid == INT2OID { + to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + } } - INT4OID => { - to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) + DataType::UInt16 => { + if attribute_context.typoid == INT2OID { + to_pg_datum!(UInt16Array, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(UInt16Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(UInt16Array, Vec>, list_array, attribute_context) + } } - INT8OID => { - to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) + DataType::Int32 => { + if attribute_context.typoid == INT4OID { + to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) + } } - BOOLOID => { - to_pg_datum!( - BooleanArray, - Vec>, - list_array, - attribute_context - ) + DataType::UInt32 => { + if attribute_context.typoid == OIDOID { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } } - CHAROID => { - to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + DataType::Int64 => { + to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) } - TEXTOID => { - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) + DataType::UInt64 => { + to_pg_datum!(UInt64Array, Vec>, list_array, attribute_context) } - BYTEAOID => { + DataType::Boolean => { to_pg_datum!( - BinaryArray, - Vec>>, + BooleanArray, + Vec>, list_array, attribute_context ) } - OIDOID => { - to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) - } - NUMERICOID => { - 
let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + StringArray, + Vec>, + list_array, + attribute_context + ) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -424,82 +520,135 @@ fn to_pg_array_datum( list_array, attribute_context ) + } + } + DataType::LargeUtf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!( + LargeStringArray, + Vec>, + list_array, + attribute_context + ) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + LargeStringArray, + Vec>, + list_array, + attribute_context + ) + } else { + reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); + + to_pg_datum!( + LargeStringArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!( + BinaryArray, + Vec>, + list_array, + attribute_context + ) } else { to_pg_datum!( - Decimal128Array, - Vec>, + BinaryArray, + Vec>>, list_array, attribute_context ) } } - DATEOID => { + DataType::LargeBinary => { + if attribute_context.is_geometry { + to_pg_datum!( + LargeBinaryArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + LargeBinaryArray, + Vec>>, + list_array, + attribute_context + ) + } + } + DataType::Decimal128(_, _) => { to_pg_datum!( - Date32Array, - Vec>, + Decimal128Array, + Vec>, list_array, attribute_context ) } - TIMEOID => { + DataType::Date32 => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + Date32Array, + Vec>, list_array, attribute_context ) } - TIMETZOID => { + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + TimestampMicrosecondArray, + Vec>, list_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Vec>, + Vec>, list_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - Vec>, + StructArray, + Vec>>, list_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Vec>, list_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - Vec>>, - list_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Vec>, list_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!( - BinaryArray, - Vec>, - list_array, - attribute_context - ) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } + panic!( + "unsupported data type: {:?}", + attribute_context.field.data_type() + ); } } } diff --git a/src/arrow_parquet/arrow_to_pg/bytea.rs b/src/arrow_parquet/arrow_to_pg/bytea.rs index fc67c2c..d17262b 100644 --- 
a/src/arrow_parquet/arrow_to_pg/bytea.rs +++ b/src/arrow_parquet/arrow_to_pg/bytea.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, BinaryArray}; +use arrow::array::{Array, BinaryArray, LargeBinaryArray}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -13,6 +13,16 @@ impl ArrowArrayToPgType> for BinaryArray { } } +impl ArrowArrayToPgType> for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option> { + if self.is_null(0) { + None + } else { + Some(self.value(0).to_vec()) + } + } +} + // Bytea[] impl ArrowArrayToPgType>>> for BinaryArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>>> { @@ -28,3 +38,18 @@ impl ArrowArrayToPgType>>> for BinaryArray { Some(vals) } } + +impl ArrowArrayToPgType>>> for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>>> { + let mut vals = vec![]; + for val in self.iter() { + if let Some(val) = val { + vals.push(Some(val.to_vec())); + } else { + vals.push(None); + } + } + + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/char.rs b/src/arrow_parquet/arrow_to_pg/char.rs index 2a23187..7b55d23 100644 --- a/src/arrow_parquet/arrow_to_pg/char.rs +++ b/src/arrow_parquet/arrow_to_pg/char.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, LargeStringArray, StringArray}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -15,6 +15,18 @@ impl ArrowArrayToPgType for StringArray { } } +impl ArrowArrayToPgType for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0); + let val: i8 = val.chars().next().expect("unexpected ascii char") as i8; + Some(val) + } + } +} + // Char[] impl ArrowArrayToPgType>> for StringArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -29,3 +41,17 @@ impl ArrowArrayToPgType>> for StringArray { Some(vals) } } + +impl ArrowArrayToPgType>> for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + let val = val.map(|val| { + let val: i8 = val.chars().next().expect("unexpected ascii char") as i8; + val + }); + vals.push(val); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs b/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs index 5144787..a07bd08 100644 --- a/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs +++ b/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, LargeStringArray, StringArray}; use crate::type_compat::fallback_to_text::FallbackToText; @@ -17,6 +17,18 @@ impl ArrowArrayToPgType for StringArray { } } +impl ArrowArrayToPgType for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let text_repr = self.value(0).to_string(); + let val = FallbackToText(text_repr); + Some(val) + } + } +} + // Text[] representation of any type impl ArrowArrayToPgType>> for StringArray { fn to_pg_type( @@ -31,3 +43,17 @@ impl ArrowArrayToPgType>> for StringArray { Some(vals) } } + +impl ArrowArrayToPgType>> for LargeStringArray { + fn to_pg_type( + self, + _context: &ArrowToPgAttributeContext, + ) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + let val = val.map(|val| FallbackToText(val.to_string())); + vals.push(val); + } + Some(vals) + } +} diff 
--git a/src/arrow_parquet/arrow_to_pg/float4.rs b/src/arrow_parquet/arrow_to_pg/float4.rs index 48f36e2..19ffb6a 100644 --- a/src/arrow_parquet/arrow_to_pg/float4.rs +++ b/src/arrow_parquet/arrow_to_pg/float4.rs @@ -14,6 +14,17 @@ impl ArrowArrayToPgType for Float32Array { } } +impl ArrowArrayToPgType for Float32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Float4[] impl ArrowArrayToPgType>> for Float32Array { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -24,3 +35,13 @@ impl ArrowArrayToPgType>> for Float32Array { Some(vals) } } + +impl ArrowArrayToPgType>> for Float32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/geometry.rs b/src/arrow_parquet/arrow_to_pg/geometry.rs index eea86af..6b8e3c8 100644 --- a/src/arrow_parquet/arrow_to_pg/geometry.rs +++ b/src/arrow_parquet/arrow_to_pg/geometry.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, BinaryArray}; +use arrow::array::{Array, BinaryArray, LargeBinaryArray}; use crate::type_compat::geometry::Geometry; @@ -15,6 +15,16 @@ impl ArrowArrayToPgType for BinaryArray { } } +impl ArrowArrayToPgType for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + Some(self.value(0).to_vec().into()) + } + } +} + // Geometry[] impl ArrowArrayToPgType>> for BinaryArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -30,3 +40,18 @@ impl ArrowArrayToPgType>> for BinaryArray { Some(vals) } } + +impl ArrowArrayToPgType>> for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + if let Some(val) = val { + vals.push(Some(val.to_vec().into())); + } else { + vals.push(None); + } + } + + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/int2.rs b/src/arrow_parquet/arrow_to_pg/int2.rs index 6f814db..d1c4e73 100644 --- a/src/arrow_parquet/arrow_to_pg/int2.rs +++ b/src/arrow_parquet/arrow_to_pg/int2.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, Int16Array}; +use arrow::array::{Array, Int16Array, UInt16Array}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,61 @@ impl ArrowArrayToPgType for Int16Array { } } +impl ArrowArrayToPgType for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } 
else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Int2[] impl ArrowArrayToPgType>> for Int16Array { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -24,3 +79,53 @@ impl ArrowArrayToPgType>> for Int16Array { Some(vals) } } + +impl ArrowArrayToPgType>> for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/int4.rs b/src/arrow_parquet/arrow_to_pg/int4.rs index 87a06e4..ecae4d4 100644 --- a/src/arrow_parquet/arrow_to_pg/int4.rs +++ b/src/arrow_parquet/arrow_to_pg/int4.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, Int32Array}; +use arrow::array::{Array, Int32Array, UInt32Array}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,39 @@ impl ArrowArrayToPgType for Int32Array { } } +impl ArrowArrayToPgType for Int32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Int4[] impl ArrowArrayToPgType>> for Int32Array { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -24,3 +57,33 @@ impl ArrowArrayToPgType>> for Int32Array { Some(vals) } } + +impl ArrowArrayToPgType>> for Int32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/int8.rs b/src/arrow_parquet/arrow_to_pg/int8.rs index 151b99e..978f70b 100644 --- 
a/src/arrow_parquet/arrow_to_pg/int8.rs +++ b/src/arrow_parquet/arrow_to_pg/int8.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, Int64Array}; +use arrow::array::{Array, Int64Array, UInt64Array}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,17 @@ impl ArrowArrayToPgType for Int64Array { } } +impl ArrowArrayToPgType for UInt64Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Int8[] impl ArrowArrayToPgType>> for Int64Array { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -24,3 +35,13 @@ impl ArrowArrayToPgType>> for Int64Array { Some(vals) } } + +impl ArrowArrayToPgType>> for UInt64Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|v| v as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/text.rs b/src/arrow_parquet/arrow_to_pg/text.rs index ba784e0..b4190a1 100644 --- a/src/arrow_parquet/arrow_to_pg/text.rs +++ b/src/arrow_parquet/arrow_to_pg/text.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, LargeStringArray, StringArray}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,17 @@ impl ArrowArrayToPgType for StringArray { } } +impl ArrowArrayToPgType for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0); + Some(val.to_string()) + } + } +} + // Text[] impl ArrowArrayToPgType>> for StringArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -25,3 +36,14 @@ impl ArrowArrayToPgType>> for StringArray { Some(vals) } } + +impl ArrowArrayToPgType>> for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + let val = val.map(|val| val.to_string()); + vals.push(val); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index 8dd79cf..f18b0ce 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use arrow::datatypes::{Field, Fields, Schema}; -use arrow_schema::FieldRef; +use arrow_schema::{DataType, FieldRef}; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, @@ -130,10 +130,12 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i child_fields.push(child_field); } + let nullable = true; + Field::new( elem_name, arrow::datatypes::DataType::Struct(Fields::from(child_fields)), - true, + nullable, ) .with_metadata(metadata) .into() @@ -159,10 +161,12 @@ fn parse_list_schema(typoid: Oid, typmod: i32, array_name: &str, field_id: &mut parse_primitive_schema(typoid, typmod, array_name, field_id) }; + let nullable = true; + Field::new( array_name, arrow::datatypes::DataType::List(elem_field), - true, + nullable, ) .with_metadata(list_metadata) .into() @@ -177,13 +181,18 @@ fn parse_map_schema(typoid: Oid, typmod: i32, map_name: &str, field_id: &mut i32 *field_id += 1; let tupledesc = tuple_desc(typoid, typmod); + let entries_field = parse_struct_schema(tupledesc, map_name, field_id); let entries_field = 
adjust_map_entries_field(entries_field); + let keys_sorted = false; + + let nullable = true; + Field::new( map_name, - arrow::datatypes::DataType::Map(entries_field, false), - true, + arrow::datatypes::DataType::Map(entries_field, keys_sorted), + nullable, ) .with_metadata(map_metadata) .into() @@ -204,31 +213,33 @@ fn parse_primitive_schema( *field_id += 1; + let nullable = true; + let field = match typoid { - FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, true), - FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, true), - BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, true), - INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, true), - INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, true), - INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, true), + FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, nullable), + FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, nullable), + BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, nullable), + INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, nullable), + INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, nullable), + INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, nullable), NUMERICOID => { let (precision, scale) = extract_precision_and_scale_from_numeric_typmod(typmod); if should_write_numeric_as_text(precision) { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } else { Field::new( elem_name, arrow::datatypes::DataType::Decimal128(precision as _, scale as _), - true, + nullable, ) } } - DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, true), + DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, nullable), TIMESTAMPOID => Field::new( elem_name, arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - true, + nullable, ), TIMESTAMPTZOID => Field::new( elem_name, @@ -236,31 +247,31 @@ fn parse_primitive_schema( arrow::datatypes::TimeUnit::Microsecond, Some("+00:00".into()), ), - true, + nullable, ), TIMEOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ), TIMETZOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ) .with_metadata(HashMap::from_iter(vec![( "adjusted_to_utc".into(), "true".into(), )])), - CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, true), - OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, true), + CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable), + OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, nullable), _ => { if is_postgis_geometry_type(typoid) { - Field::new(elem_name, arrow::datatypes::DataType::Binary, true) + Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable) } else { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, 
arrow::datatypes::DataType::Utf8, nullable) } } }; @@ -289,26 +300,38 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { let key_field = fields.find("key").expect("expected key field").1; let value_field = fields.find("val").expect("expected val field").1; - not_nullable_key_field = - Field::new(key_field.name(), key_field.data_type().clone(), false) - .with_metadata(key_field.metadata().clone()); + let key_nullable = false; + + not_nullable_key_field = Field::new( + key_field.name(), + key_field.data_type().clone(), + key_nullable, + ) + .with_metadata(key_field.metadata().clone()); + + let value_nullable = true; - nullable_value_field = - Field::new(value_field.name(), value_field.data_type().clone(), true) - .with_metadata(value_field.metadata().clone()); + nullable_value_field = Field::new( + value_field.name(), + value_field.data_type().clone(), + value_nullable, + ) + .with_metadata(value_field.metadata().clone()); } _ => { panic!("expected struct data type for map entries") } }; + let entries_nullable = false; + let entries_field = Field::new( name, arrow::datatypes::DataType::Struct(Fields::from(vec![ not_nullable_key_field, nullable_value_field, ])), - false, + entries_nullable, ) .with_metadata(metadata); @@ -316,33 +339,90 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { } pub(crate) fn ensure_arrow_schema_match_tupledesc( - file_schema: Arc, + arrow_schema: Arc, tupledesc: &PgTupleDesc, ) { - let table_schema = parse_arrow_schema_from_tupledesc(tupledesc); + let tupledesc_schema = parse_arrow_schema_from_tupledesc(tupledesc); - for table_schema_field in table_schema.fields().iter() { - let table_schema_field_name = table_schema_field.name(); - let table_schema_field_type = table_schema_field.data_type(); + for tupledesc_field in tupledesc_schema.fields().iter() { + let field_name = tupledesc_field.name(); - let file_schema_field = file_schema.column_with_name(table_schema_field_name); + let arrow_field = arrow_schema.column_with_name(field_name); - if let Some(file_schema_field) = file_schema_field { - let file_schema_field_type = file_schema_field.1.data_type(); + if arrow_field.is_none() { + panic!("column \"{}\" is not found in parquet file", field_name); + } - if file_schema_field_type != table_schema_field_type { - panic!( - "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"", - table_schema_field_name, - table_schema_field_type, - file_schema_field_type, - ); - } - } else { + let (_, arrow_field) = arrow_field.unwrap(); + + let arrow_field_type = arrow_field.data_type(); + + let tupledesc_field_type = tupledesc_field.data_type(); + + if !is_coercible_types(arrow_field_type, tupledesc_field_type) { panic!( - "column \"{}\" is not found in parquet file", - table_schema_field_name + "type mismatch for column \"{}\" between table and parquet file. 
table expected \"{}\" but parquet file had \"{}\"", + field_name, + tupledesc_field_type, + arrow_field_type, ); } } } + +fn is_coercible_types(from_type: &DataType, to_type: &DataType) -> bool { + if from_type == to_type { + return true; + } + + if matches!( + (from_type, to_type), + (DataType::Float32, DataType::Float64) + | (DataType::Int16, DataType::Int32) + | (DataType::Int16, DataType::Int64) + | (DataType::Int32, DataType::Int64) + | (DataType::UInt16, DataType::Int16) + | (DataType::UInt16, DataType::Int32) + | (DataType::UInt16, DataType::Int64) + | (DataType::UInt32, DataType::Int32) + | (DataType::UInt32, DataType::Int64) + | (DataType::UInt64, DataType::Int64) + | (DataType::LargeUtf8, DataType::Utf8) + | (DataType::LargeBinary, DataType::Binary) + ) { + return true; + } + + if let (DataType::List(from_elem_field), DataType::List(to_elem_field)) = (from_type, to_type) { + return is_coercible_types(from_elem_field.data_type(), to_elem_field.data_type()); + } else if let (DataType::Struct(from_fields), DataType::Struct(to_fields)) = + (from_type, to_type) + { + if from_fields.len() != to_fields.len() { + return false; + } + + for (from_field, to_field) in from_fields.iter().zip(to_fields.iter()) { + if from_field.name() != to_field.name() { + return false; + } + + if !is_coercible_types(from_field.data_type(), to_field.data_type()) { + return false; + } + } + + return true; + } else if let (DataType::Map(from_entries_field, _), DataType::Map(to_entries_field, _)) = + (from_type, to_type) + { + // crunchy_map entries are not allowed to be null + if from_entries_field.is_nullable() { + return false; + } + + return is_coercible_types(from_entries_field.data_type(), to_entries_field.data_type()); + } + + false +} diff --git a/src/lib.rs b/src/lib.rs index 57584bb..cb4fe3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,11 @@ pub extern "C" fn _PG_init() { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { + use std::fs::File; use std::io::Write; use std::marker::PhantomData; + use std::sync::Arc; + use std::vec; use std::{collections::HashMap, fmt::Debug}; use crate::arrow_parquet::compression::PgParquetCompression; @@ -48,6 +51,13 @@ mod tests { use crate::type_compat::pg_arrow_type_conversions::{ DEFAULT_UNBOUNDED_NUMERIC_PRECISION, DEFAULT_UNBOUNDED_NUMERIC_SCALE, }; + use arrow::array::{ + ArrayRef, Float32Array, Int16Array, Int32Array, LargeBinaryArray, LargeStringArray, + ListArray, RecordBatch, StructArray, UInt16Array, UInt32Array, UInt64Array, + }; + use arrow::datatypes::UInt16Type; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use parquet::arrow::ArrowWriter; use pgrx::pg_sys::Oid; use pgrx::{ composite_type, @@ -340,6 +350,14 @@ mod tests { Spi::get_one(&query).unwrap().unwrap() } + fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) { + let file = File::create("/tmp/test.parquet").unwrap(); + let mut writer = ArrowWriter::try_new(file, schema, None).unwrap(); + + writer.write(&record_batch).unwrap(); + writer.close().unwrap(); + } + #[pg_test] fn test_int2() { let test_table = TestTable::::new("int2".into()); @@ -1391,6 +1409,320 @@ mod tests { Spi::run("DROP TYPE dog;").unwrap(); } + #[pg_test] + fn test_coerce_primitive_types() { + // INT16 => {int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int16, true), + Field::new("y", DataType::Int16, true), + ])); + + let x = Arc::new(Int16Array::from(vec![1])); + let y = Arc::new(Int16Array::from(vec![2])); + + let batch = 
RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // INT32 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)])); + + let x = Arc::new(Int32Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT32 => {double} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, true)])); + + let x = Arc::new(Float32Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x double precision)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value as f32, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT16 => {smallint, int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt16, true), + Field::new("y", DataType::UInt16, true), + Field::new("z", DataType::UInt16, true), + ])); + + let x = Arc::new(UInt16Array::from(vec![1])); + let y = Arc::new(UInt16Array::from(vec![2])); + let z = Arc::new(UInt16Array::from(vec![3])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y, z]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x smallint, y int, z bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = + Spi::get_three::("SELECT x, y, z FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2), Some(3))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT32 => {int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt32, true), + Field::new("y", DataType::UInt32, true), + ])); + + let x = Arc::new(UInt32Array::from(vec![1])); + let y = Arc::new(UInt32Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + 
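+        // tear down the UINT32 coercion table before moving on to the UINT64 case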
Spi::run(drop_table).unwrap(); + + // UINT64 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::UInt64, true)])); + + let x = Arc::new(UInt64Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeUtf8 => {text} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeUtf8, + true, + )])); + + let x = Arc::new(LargeStringArray::from(vec!["test"])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x text)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "test"); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeBinary => {bytea} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeBinary, + true, + )])); + + let x = Arc::new(LargeBinaryArray::from(vec!["abc".as_bytes()])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bytea)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::>("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "abc".as_bytes()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + fn test_coerce_list_types() { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "x", + DataType::List(Field::new("item", DataType::UInt16, true).into()), + true, + ), + Field::new( + "y", + DataType::List(Field::new("item", DataType::UInt16, true).into()), + true, + ), + ])); + + let x = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + ])); + + let y = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(3), Some(4)]), + ])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int[], y bigint[])"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::>, Vec>>( + "SELECT x, y FROM test_table LIMIT 1", + ) + .unwrap(); + assert_eq!( + value, + (Some(vec![Some(1), Some(2)]), Some(vec![Some(3), Some(4)])) + ); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + fn test_coerce_struct_types() { + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Struct( + vec![ + Field::new("a", DataType::UInt16, false), + Field::new("b", DataType::UInt16, false), + ] + .into(), + ), + false, + )])); + + let a: ArrayRef = Arc::new(UInt16Array::from(vec![Some(1)])); 
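+        // `b` is the second UInt16 struct child; on read it should coerce to the `bigint` attribute of test_type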
+ let b: ArrayRef = Arc::new(UInt16Array::from(vec![Some(2)])); + + let x = Arc::new(StructArray::try_from(vec![("a", a), ("b", b)]).unwrap()); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_type = "CREATE TYPE test_type AS (a int, b bigint)"; + Spi::run(create_type).unwrap(); + + let create_table = "CREATE TABLE test_table (x test_type)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = + Spi::get_two::("SELECT (x).a, (x).b FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + #[should_panic(expected = "violates not-null constraint")] + fn test_copy_not_null_table() { + let create_table = "CREATE TABLE test_table (x int NOT NULL)"; + Spi::run(create_table).unwrap(); + + // first copy non-null value to file + let copy_to = "COPY (SELECT 1 as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_one::("SELECT x FROM test_table") + .unwrap() + .unwrap(); + assert_eq!(result, 1); + + // then copy null value to file + let copy_to = "COPY (SELECT NULL::int as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + // this should panic + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + #[pg_test] fn test_copy_with_empty_options() { let test_table = TestTable::::new("int4".into())