From b408fac33455fa0aa971d172bf0e271d385b540a Mon Sep 17 00:00:00 2001 From: Aykut Bozkurt Date: Mon, 11 Nov 2024 14:50:37 +0300 Subject: [PATCH 1/3] Cast types on read `COPY FROM parquet` is too strict when matching Postgres tupledesc schema to the parquet file schema. e.g. `INT32` type in the parquet schema cannot be read into a Postgres column with `int64` type. We can avoid this situation by casting arrow array to the array that is expected by the tupledesc schema, if the cast is possible. We can make use of `arrow-cast` crate, which is in the same project with `arrow`. Its public api lets us check if a cast possible between 2 arrow types and perform the cast. To make sure the cast is possible, we need to do 2 checks: 1. arrow-cast allows the cast from "arrow type at the parquet file" to "arrow type at the schema that is generated for tupledesc", 2. the cast is meaningful at Postgres. We check if there is an explicit cast from "Postgres type that corresponds for the arrow type at Parquet file" to "Postgres type at tupledesc". With that we can cast between many castable types as shown below: - INT16 => INT32 - UINT32 => INT64 - FLOAT32 => FLOAT64 - LargeUtf8 => UTF8 - LargeBinary => Binary - Struct, Array, and Map with castable fields, e.g. [UINT16] => [INT64] or struct {'x': UINT16} => struct {'x': INT64} **NOTE**: Struct fields must match by name and position to be cast. Closes #67. --- Cargo.lock | 1 + Cargo.toml | 1 + README.md | 8 +- src/arrow_parquet.rs | 1 + src/arrow_parquet/arrow_to_pg.rs | 395 ++++--- src/arrow_parquet/arrow_to_pg/timestamptz.rs | 19 +- src/arrow_parquet/cast_mode.rs | 22 + src/arrow_parquet/parquet_reader.rs | 74 +- src/arrow_parquet/parquet_writer.rs | 17 +- src/arrow_parquet/pg_to_arrow.rs | 20 +- src/arrow_parquet/schema_parser.rs | 361 ++++-- src/lib.rs | 1010 +++++++++++++++++ src/parquet_copy_hook/copy_from.rs | 8 +- .../copy_to_dest_receiver.rs | 11 +- src/parquet_copy_hook/copy_utils.rs | 40 +- src/pgrx_utils.rs | 20 +- src/type_compat/pg_arrow_type_conversions.rs | 8 +- 17 files changed, 1715 insertions(+), 301 deletions(-) create mode 100644 src/arrow_parquet/cast_mode.rs diff --git a/Cargo.lock b/Cargo.lock index d594d73..7150e6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2231,6 +2231,7 @@ name = "pg_parquet" version = "0.1.0" dependencies = [ "arrow", + "arrow-cast", "arrow-schema", "aws-config", "aws-credential-types", diff --git a/Cargo.toml b/Cargo.toml index b5a372b..76a9e1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ pg_test = [] [dependencies] arrow = {version = "53", default-features = false} +arrow-cast = {version = "53", default-features = false} arrow-schema = {version = "53", default-features = false} aws-config = { version = "1.5", default-features = false, features = ["rustls"]} aws-credential-types = {version = "1.2", default-features = false} diff --git a/README.md b/README.md index 1ebd4d0..1cd1ee3 100644 --- a/README.md +++ b/README.md @@ -185,12 +185,16 @@ Alternatively, you can use the following environment variables when starting pos ## Copy Options `pg_parquet` supports the following options in the `COPY TO` command: -- `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension. (This is the only option that `COPY FROM` command supports.), +- `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension, - `row_group_size `: the number of rows in each row group while writing Parquet files. The default row group size is `122880`, - `row_group_size_bytes `: the total byte size of rows in each row group while writing Parquet files. The default row group size bytes is `row_group_size * 1024`, -- `compression `: the compression format to use while writing Parquet files. The supported compression formats are `uncompressed`, `snappy`, `gzip`, `brotli`, `lz4`, `lz4raw` and `zstd`. The default compression format is `snappy`. If not specified, the compression format is determined by the file extension. +- `compression `: the compression format to use while writing Parquet files. The supported compression formats are `uncompressed`, `snappy`, `gzip`, `brotli`, `lz4`, `lz4raw` and `zstd`. The default compression format is `snappy`. If not specified, the compression format is determined by the file extension, - `compression_level `: the compression level to use while writing Parquet files. The supported compression levels are only supported for `gzip`, `zstd` and `brotli` compression formats. The default compression level is `6` for `gzip (0-10)`, `1` for `zstd (1-22)` and `1` for `brotli (0-11)`. +`pg_parquet` supports the following options in the `COPY FROM` command: +- `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension, +- `cast_mode `: Specifies the casting behavior, which can be set to either `strict` or `relaxed`. This determines whether lossy conversions are allowed. By default, the mode is `strict`, which does not permit lossy conversions (e.g., `bigint => int` causes a schema mismatch error during schema validation). When set to `relaxed`, lossy conversions are allowed, and errors will only be raised at runtime if a value cannot be properly converted. This option provides flexibility to handle schema mismatches by deferring error checks to runtime. + ## Configuration There is currently only one GUC parameter to enable/disable the `pg_parquet`: - `pg_parquet.enable_copy_hooks`: you can set this parameter to `on` or `off` to enable or disable the `pg_parquet` extension. The default value is `on`. diff --git a/src/arrow_parquet.rs b/src/arrow_parquet.rs index 7e445a4..fb5cb0d 100644 --- a/src/arrow_parquet.rs +++ b/src/arrow_parquet.rs @@ -1,5 +1,6 @@ pub(crate) mod arrow_to_pg; pub(crate) mod arrow_utils; +pub(crate) mod cast_mode; pub(crate) mod compression; pub(crate) mod parquet_reader; pub(crate) mod parquet_writer; diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs index ec7c9ce..079ee08 100644 --- a/src/arrow_parquet/arrow_to_pg.rs +++ b/src/arrow_parquet/arrow_to_pg.rs @@ -3,29 +3,23 @@ use arrow::array::{ Float64Array, Int16Array, Int32Array, Int64Array, ListArray, MapArray, StringArray, StructArray, Time64MicrosecondArray, TimestampMicrosecondArray, UInt32Array, }; -use arrow_schema::Fields; +use arrow_schema::{DataType, FieldRef, Fields, TimeUnit}; use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone}, - pg_sys::{ - Datum, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, - INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, - }, + pg_sys::{Datum, FormData_pg_attribute, Oid, CHAROID, TEXTOID, TIMEOID}, prelude::PgHeapTuple, AllocatedByRust, AnyNumeric, IntoDatum, PgTupleDesc, }; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, geometry::{is_postgis_geometry_type, Geometry}, map::{is_map_type, Map}, - pg_arrow_type_conversions::{ - extract_precision_and_scale_from_numeric_typmod, should_write_numeric_as_text, - }, }, }; @@ -57,25 +51,33 @@ pub(crate) trait ArrowArrayToPgType: From { #[derive(Clone)] pub(crate) struct ArrowToPgAttributeContext { name: String, + data_type: DataType, + needs_cast: bool, typoid: Oid, typmod: i32, - is_array: bool, - is_composite: bool, is_geometry: bool, - is_map: bool, attribute_contexts: Option>, attribute_tupledesc: Option>, precision: Option, scale: Option, + timezone: Option, } impl ArrowToPgAttributeContext { - pub(crate) fn new(name: &str, typoid: Oid, typmod: i32, fields: Fields) -> Self { - let field = fields - .iter() - .find(|field| field.name() == name) - .unwrap_or_else(|| panic!("failed to find field {}", name)) - .clone(); + pub(crate) fn new( + name: &str, + typoid: Oid, + typmod: i32, + field: FieldRef, + cast_to_type: Option, + ) -> Self { + let needs_cast = cast_to_type.is_some(); + + let data_type = if let Some(cast_to_type) = &cast_to_type { + cast_to_type.clone() + } else { + field.data_type().clone() + }; let is_array = is_array_type(typoid); let is_composite; @@ -123,16 +125,29 @@ impl ArrowToPgAttributeContext { None }; - let precision; - let scale; - if attribute_typoid == NUMERICOID { - let (p, s) = extract_precision_and_scale_from_numeric_typmod(typmod); - precision = Some(p); - scale = Some(s); - } else { - precision = None; - scale = None; - } + let (precision, scale) = match &data_type { + DataType::Decimal128(p, s) => (Some(*p as _), Some(*s as _)), + DataType::List(field) => { + if let DataType::Decimal128(p, s) = field.data_type() { + (Some(*p as _), Some(*s as _)) + } else { + (None, None) + } + } + _ => (None, None), + }; + + let timezone = match &data_type { + DataType::Timestamp(_, Some(timezone)) => Some(timezone.to_string()), + DataType::List(field) => { + if let DataType::Timestamp(_, Some(timezone)) = field.data_type() { + Some(timezone.to_string()) + } else { + None + } + } + _ => None, + }; // for composite and map types, recursively collect attribute contexts let attribute_contexts = if let Some(attribute_tupledesc) = &attribute_tupledesc { @@ -147,9 +162,16 @@ impl ArrowToPgAttributeContext { _ => unreachable!(), }; + let attributes = + collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc); + + // we only cast the top-level attributes, which already covers the nested attributes + let cast_to_types = None; + Some(collect_arrow_to_pg_attribute_contexts( - attribute_tupledesc, + &attributes, &fields, + cast_to_types, )) } else { None @@ -157,43 +179,63 @@ impl ArrowToPgAttributeContext { Self { name: name.to_string(), + data_type, + needs_cast, typoid: attribute_typoid, typmod, - is_array, - is_composite, is_geometry, - is_map, attribute_contexts, attribute_tupledesc, scale, precision, + timezone, } } pub(crate) fn name(&self) -> &str { &self.name } + + pub(crate) fn needs_cast(&self) -> bool { + self.needs_cast + } + + pub(crate) fn data_type(&self) -> &DataType { + &self.data_type + } } pub(crate) fn collect_arrow_to_pg_attribute_contexts( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], fields: &Fields, + cast_to_types: Option>>, ) -> Vec { - // parquet file does not contain generated columns. PG will handle them. - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); let mut attribute_contexts = vec![]; - for attribute in attributes { + for (idx, attribute) in attributes.iter().enumerate() { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); let attribute_typmod = attribute.type_mod(); + let field = fields + .iter() + .find(|field| field.name() == attribute_name) + .unwrap_or_else(|| panic!("failed to find field {}", attribute_name)) + .clone(); + + let cast_to_type = if let Some(cast_to_types) = cast_to_types.as_ref() { + debug_assert!(cast_to_types.len() == attributes.len()); + cast_to_types.get(idx).cloned().expect("cast_to_type null") + } else { + None + }; + let attribute_context = ArrowToPgAttributeContext::new( attribute_name, attribute_typoid, attribute_typmod, - fields.clone(), + field, + cast_to_type, ); attribute_contexts.push(attribute_context); @@ -206,7 +248,7 @@ pub(crate) fn to_pg_datum( attribute_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - if attribute_context.is_array { + if matches!(attribute_array.data_type(), DataType::List(_)) { to_pg_array_datum(attribute_array, attribute_context) } else { to_pg_nonarray_datum(attribute_array, attribute_context) @@ -227,43 +269,34 @@ fn to_pg_nonarray_datum( primitive_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - match attribute_context.typoid { - FLOAT4OID => { + match attribute_context.data_type() { + DataType::Float32 => { to_pg_datum!(Float32Array, f32, primitive_array, attribute_context) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!(Float64Array, f64, primitive_array, attribute_context) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, i16, primitive_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, i32, primitive_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, i64, primitive_array, attribute_context) } - BOOLOID => { - to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) - } - CHAROID => { - to_pg_datum!(StringArray, i8, primitive_array, attribute_context) - } - TEXTOID => { - to_pg_datum!(StringArray, String, primitive_array, attribute_context) - } - BYTEAOID => { - to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) - } - OIDOID => { + DataType::UInt32 => { to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Boolean => { + to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) + } + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(StringArray, String, primitive_array, attribute_context) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -272,72 +305,72 @@ fn to_pg_nonarray_datum( primitive_array, attribute_context ) - } else { - to_pg_datum!( - Decimal128Array, - AnyNumeric, - primitive_array, - attribute_context - ) } } - DATEOID => { - to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) + } else { + to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) + } } - TIMEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Time64MicrosecondArray, - Time, + Decimal128Array, + AnyNumeric, primitive_array, attribute_context ) } - TIMETZOID => { + DataType::Date32 => { + to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + } + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Time, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + TimeWithTimeZone, + primitive_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - TimeWithTimeZone, + TimestampMicrosecondArray, + Timestamp, primitive_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { to_pg_datum!( TimestampMicrosecondArray, - Timestamp, + TimestampWithTimeZone, primitive_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - TimestampWithTimeZone, + StructArray, + PgHeapTuple, primitive_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Map, primitive_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - PgHeapTuple, - primitive_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Map, primitive_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - FallbackToText, - primitive_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } @@ -354,8 +387,13 @@ fn to_pg_array_datum( let list_array = list_array.value(0).to_data(); - match attribute_context.typoid { - FLOAT4OID => { + let element_field = match attribute_context.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + match element_field.data_type() { + DataType::Float32 => { to_pg_datum!( Float32Array, Vec>, @@ -363,7 +401,7 @@ fn to_pg_array_datum( attribute_context ) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!( Float64Array, Vec>, @@ -371,16 +409,19 @@ fn to_pg_array_datum( attribute_context ) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) } - BOOLOID => { + DataType::UInt32 => { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } + DataType::Boolean => { to_pg_datum!( BooleanArray, Vec>, @@ -388,34 +429,17 @@ fn to_pg_array_datum( attribute_context ) } - CHAROID => { - to_pg_datum!(StringArray, Vec>, list_array, attribute_context) - } - TEXTOID => { - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } - BYTEAOID => { - to_pg_datum!( - BinaryArray, - Vec>>, - list_array, - attribute_context - ) - } - OIDOID => { - to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) - } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + StringArray, + Vec>, + list_array, + attribute_context + ) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -424,82 +448,87 @@ fn to_pg_array_datum( list_array, attribute_context ) + } + } + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!( + BinaryArray, + Vec>, + list_array, + attribute_context + ) } else { to_pg_datum!( - Decimal128Array, - Vec>, + BinaryArray, + Vec>>, list_array, attribute_context ) } } - DATEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Date32Array, - Vec>, + Decimal128Array, + Vec>, list_array, attribute_context ) } - TIMEOID => { + DataType::Date32 => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + Date32Array, + Vec>, list_array, attribute_context ) } - TIMETZOID => { + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + TimestampMicrosecondArray, + Vec>, list_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { to_pg_datum!( TimestampMicrosecondArray, - Vec>, + Vec>, list_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - Vec>, + StructArray, + Vec>>, list_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Vec>, list_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - Vec>>, - list_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Vec>, list_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!( - BinaryArray, - Vec>, - list_array, - attribute_context - ) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } diff --git a/src/arrow_parquet/arrow_to_pg/timestamptz.rs b/src/arrow_parquet/arrow_to_pg/timestamptz.rs index 81bf5f9..a5a4736 100644 --- a/src/arrow_parquet/arrow_to_pg/timestamptz.rs +++ b/src/arrow_parquet/arrow_to_pg/timestamptz.rs @@ -7,11 +7,16 @@ use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; // Timestamptz impl ArrowArrayToPgType for TimestampMicrosecondArray { - fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + fn to_pg_type(self, context: &ArrowToPgAttributeContext) -> Option { if self.is_null(0) { None } else { - Some(i64_to_timestamptz(self.value(0))) + let timezone = context + .timezone + .as_ref() + .expect("timezone is required for timestamptz"); + + Some(i64_to_timestamptz(self.value(0), timezone)) } } } @@ -20,11 +25,17 @@ impl ArrowArrayToPgType for TimestampMicrosecondArray { impl ArrowArrayToPgType>> for TimestampMicrosecondArray { fn to_pg_type( self, - _context: &ArrowToPgAttributeContext, + context: &ArrowToPgAttributeContext, ) -> Option>> { let mut vals = vec![]; + + let timezone = context + .timezone + .as_ref() + .expect("timezone is required for timestamptz[]"); + for val in self.iter() { - let val = val.map(i64_to_timestamptz); + let val = val.map(|v| i64_to_timestamptz(v, timezone)); vals.push(val); } Some(vals) diff --git a/src/arrow_parquet/cast_mode.rs b/src/arrow_parquet/cast_mode.rs new file mode 100644 index 0000000..43ee6ac --- /dev/null +++ b/src/arrow_parquet/cast_mode.rs @@ -0,0 +1,22 @@ +use std::str::FromStr; + +#[derive(Debug, Copy, Clone)] +pub(crate) enum CastMode { + Strict, + Relaxed, +} + +impl FromStr for CastMode { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "strict" => Ok(Self::Strict), + "relaxed" => Ok(Self::Relaxed), + _ => Err(format!( + "{} is not a valid cast_mode. Set it to either 'strict' or 'relaxed'.", + s + )), + } + } +} diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index a3cd53b..0e9d78c 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -1,24 +1,33 @@ +use std::sync::Arc; + use arrow::array::RecordBatch; +use arrow_cast::{cast_with_options, CastOptions}; use futures::StreamExt; use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream}; use pgrx::{ check_for_interrupts, pg_sys::{ - fmgr_info, getTypeBinaryOutputInfo, varlena, Datum, FmgrInfo, InvalidOid, SendFunctionCall, + fmgr_info, getTypeBinaryOutputInfo, varlena, Datum, FmgrInfo, FormData_pg_attribute, + InvalidOid, SendFunctionCall, }, vardata_any, varsize_any_exhdr, void_mut_ptr, AllocatedByPostgres, PgBox, PgTupleDesc, }; use url::Url; use crate::{ - arrow_parquet::arrow_to_pg::to_pg_datum, - pgrx_utils::collect_valid_attributes, + arrow_parquet::{ + arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes, + }, + pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; use super::{ arrow_to_pg::{collect_arrow_to_pg_attribute_contexts, ArrowToPgAttributeContext}, - schema_parser::ensure_arrow_schema_match_tupledesc, + cast_mode::CastMode, + schema_parser::{ + ensure_arrow_schema_match_tupledesc_schema, parse_arrow_schema_from_attributes, + }, uri_utils::{parquet_reader_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }; @@ -33,7 +42,7 @@ pub(crate) struct ParquetReaderContext { } impl ParquetReaderContext { - pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self { + pub(crate) fn new(uri: Url, cast_mode: CastMode, tupledesc: &PgTupleDesc) -> Self { // Postgis and Map contexts are used throughout reading the parquet file. // We need to reset them to avoid reading the stale data. (e.g. extension could be dropped) reset_postgis_context(); @@ -41,12 +50,36 @@ impl ParquetReaderContext { let parquet_reader = parquet_reader_from_uri(&uri); - let schema = parquet_reader.schema(); - ensure_arrow_schema_match_tupledesc(schema.clone(), tupledesc); + let parquet_file_schema = parquet_reader.schema(); + + let attributes = collect_attributes_for(CollectAttributesFor::CopyFrom, tupledesc); + + pgrx::debug2!( + "schema for tuples: {}", + parquet_schema_string_from_attributes(&attributes) + ); - let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); + let tupledesc_schema = parse_arrow_schema_from_attributes(&attributes); - let attribute_contexts = collect_arrow_to_pg_attribute_contexts(tupledesc, &schema.fields); + let tupledesc_schema = Arc::new(tupledesc_schema); + + // Ensure that the arrow schema matches the tupledesc. + // Gets cast_to_types for each attribute if a cast is needed for the attribute's columnar array + // to match the expected columnar array for its tupledesc type. + let cast_to_types = ensure_arrow_schema_match_tupledesc_schema( + parquet_file_schema.clone(), + tupledesc_schema.clone(), + &attributes, + cast_mode, + ); + + let attribute_contexts = collect_arrow_to_pg_attribute_contexts( + &attributes, + &tupledesc_schema.fields, + Some(cast_to_types), + ); + + let binary_out_funcs = Self::collect_binary_out_funcs(&attributes); ParquetReaderContext { buffer: Vec::new(), @@ -60,14 +93,11 @@ impl ParquetReaderContext { } fn collect_binary_out_funcs( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], ) -> Vec> { unsafe { let mut binary_out_funcs = vec![]; - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); - for att in attributes.iter() { let typoid = att.type_oid(); @@ -94,11 +124,25 @@ impl ParquetReaderContext { for attribute_context in attribute_contexts { let name = attribute_context.name(); - let column = record_batch + let column_array = record_batch .column_by_name(name) .unwrap_or_else(|| panic!("column {} not found", name)); - let datum = to_pg_datum(column.to_data(), attribute_context); + let datum = if attribute_context.needs_cast() { + // should fail instead of returning None if the cast fails at runtime + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + let casted_column_array = + cast_with_options(&column_array, attribute_context.data_type(), &cast_options) + .unwrap_or_else(|e| panic!("failed to cast column {}: {}", name, e)); + + to_pg_datum(casted_column_array.to_data(), attribute_context) + } else { + to_pg_datum(column_array.to_data(), attribute_context) + }; datums.push(datum); } diff --git a/src/arrow_parquet/parquet_writer.rs b/src/arrow_parquet/parquet_writer.rs index 7c12009..e93ea8b 100644 --- a/src/arrow_parquet/parquet_writer.rs +++ b/src/arrow_parquet/parquet_writer.rs @@ -12,9 +12,12 @@ use url::Url; use crate::{ arrow_parquet::{ compression::{PgParquetCompression, PgParquetCompressionWithLevel}, - schema_parser::parse_arrow_schema_from_tupledesc, + schema_parser::{ + parquet_schema_string_from_attributes, parse_arrow_schema_from_attributes, + }, uri_utils::{parquet_writer_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }, + pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; @@ -57,12 +60,20 @@ impl ParquetWriterContext { .set_created_by("pg_parquet".to_string()) .build(); - let schema = parse_arrow_schema_from_tupledesc(tupledesc); + let attributes = collect_attributes_for(CollectAttributesFor::CopyTo, tupledesc); + + pgrx::debug2!( + "schema for tuples: {}", + parquet_schema_string_from_attributes(&attributes) + ); + + let schema = parse_arrow_schema_from_attributes(&attributes); let schema = Arc::new(schema); let parquet_writer = parquet_writer_from_uri(&uri, schema.clone(), writer_props); - let attribute_contexts = collect_pg_to_arrow_attribute_contexts(tupledesc, &schema.fields); + let attribute_contexts = + collect_pg_to_arrow_attribute_contexts(&attributes, &schema.fields); ParquetWriterContext { parquet_writer, diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs index 40cc03c..530c7f7 100644 --- a/src/arrow_parquet/pg_to_arrow.rs +++ b/src/arrow_parquet/pg_to_arrow.rs @@ -7,16 +7,17 @@ use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone, UnboxDatum}, heap_tuple::PgHeapTuple, pg_sys::{ - Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, - NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, + FormData_pg_attribute, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, + INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, + TIMESTAMPTZOID, TIMETZOID, }, - AllocatedByRust, AnyNumeric, FromDatum, PgTupleDesc, + AllocatedByRust, AnyNumeric, FromDatum, }; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, @@ -146,7 +147,10 @@ impl PgToArrowAttributeContext { _ => unreachable!(), }; - collect_pg_to_arrow_attribute_contexts(&attribute_tupledesc, &fields) + let attributes = + collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc); + + collect_pg_to_arrow_attribute_contexts(&attributes, &fields) }); Self { @@ -166,11 +170,9 @@ impl PgToArrowAttributeContext { } pub(crate) fn collect_pg_to_arrow_attribute_contexts( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], fields: &Fields, ) -> Vec { - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); let mut attribute_contexts = vec![]; for attribute in attributes { diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index 8dd79cf..ddcc18d 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -1,18 +1,22 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use arrow::datatypes::{Field, Fields, Schema}; -use arrow_schema::FieldRef; +use arrow_cast::can_cast_types; +use arrow_schema::{DataType, FieldRef}; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ - Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, - NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, + can_coerce_type, + CoercionContext::{self, COERCION_EXPLICIT, COERCION_IMPLICIT}, + FormData_pg_attribute, InvalidOid, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, + FLOAT8OID, INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, + TIMESTAMPTZOID, TIMETZOID, }; use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc}; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ geometry::is_postgis_geometry_type, @@ -23,8 +27,12 @@ use crate::{ }, }; -pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> String { - let arrow_schema = parse_arrow_schema_from_tupledesc(tupledesc); +use super::cast_mode::CastMode; + +pub(crate) fn parquet_schema_string_from_attributes( + attributes: &[FormData_pg_attribute], +) -> String { + let arrow_schema = parse_arrow_schema_from_attributes(attributes); let parquet_schema = arrow_to_parquet_schema(&arrow_schema) .unwrap_or_else(|e| panic!("failed to convert arrow schema to parquet schema: {}", e)); @@ -33,14 +41,11 @@ pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> S String::from_utf8(buf).unwrap_or_else(|e| panic!("failed to convert schema to string: {}", e)) } -pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Schema { +pub(crate) fn parse_arrow_schema_from_attributes(attributes: &[FormData_pg_attribute]) -> Schema { let mut field_id = 0; let mut struct_attribute_fields = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); - for attribute in attributes { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); @@ -92,8 +97,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i let mut child_fields: Vec> = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(&tupledesc, include_generated_columns); + let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); for attribute in attributes { if attribute.is_dropped() { @@ -130,10 +134,12 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i child_fields.push(child_field); } + let nullable = true; + Field::new( elem_name, arrow::datatypes::DataType::Struct(Fields::from(child_fields)), - true, + nullable, ) .with_metadata(metadata) .into() @@ -159,10 +165,12 @@ fn parse_list_schema(typoid: Oid, typmod: i32, array_name: &str, field_id: &mut parse_primitive_schema(typoid, typmod, array_name, field_id) }; + let nullable = true; + Field::new( array_name, arrow::datatypes::DataType::List(elem_field), - true, + nullable, ) .with_metadata(list_metadata) .into() @@ -177,13 +185,18 @@ fn parse_map_schema(typoid: Oid, typmod: i32, map_name: &str, field_id: &mut i32 *field_id += 1; let tupledesc = tuple_desc(typoid, typmod); + let entries_field = parse_struct_schema(tupledesc, map_name, field_id); let entries_field = adjust_map_entries_field(entries_field); + let keys_sorted = false; + + let nullable = true; + Field::new( map_name, - arrow::datatypes::DataType::Map(entries_field, false), - true, + arrow::datatypes::DataType::Map(entries_field, keys_sorted), + nullable, ) .with_metadata(map_metadata) .into() @@ -204,31 +217,33 @@ fn parse_primitive_schema( *field_id += 1; + let nullable = true; + let field = match typoid { - FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, true), - FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, true), - BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, true), - INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, true), - INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, true), - INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, true), + FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, nullable), + FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, nullable), + BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, nullable), + INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, nullable), + INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, nullable), + INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, nullable), NUMERICOID => { let (precision, scale) = extract_precision_and_scale_from_numeric_typmod(typmod); if should_write_numeric_as_text(precision) { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } else { Field::new( elem_name, arrow::datatypes::DataType::Decimal128(precision as _, scale as _), - true, + nullable, ) } } - DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, true), + DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, nullable), TIMESTAMPOID => Field::new( elem_name, arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - true, + nullable, ), TIMESTAMPTZOID => Field::new( elem_name, @@ -236,31 +251,31 @@ fn parse_primitive_schema( arrow::datatypes::TimeUnit::Microsecond, Some("+00:00".into()), ), - true, + nullable, ), TIMEOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ), TIMETZOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ) .with_metadata(HashMap::from_iter(vec![( "adjusted_to_utc".into(), "true".into(), )])), - CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, true), - OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, true), + CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable), + OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, nullable), _ => { if is_postgis_geometry_type(typoid) { - Field::new(elem_name, arrow::datatypes::DataType::Binary, true) + Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable) } else { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } } }; @@ -289,60 +304,274 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { let key_field = fields.find("key").expect("expected key field").1; let value_field = fields.find("val").expect("expected val field").1; - not_nullable_key_field = - Field::new(key_field.name(), key_field.data_type().clone(), false) - .with_metadata(key_field.metadata().clone()); + let key_nullable = false; - nullable_value_field = - Field::new(value_field.name(), value_field.data_type().clone(), true) - .with_metadata(value_field.metadata().clone()); + not_nullable_key_field = Field::new( + key_field.name(), + key_field.data_type().clone(), + key_nullable, + ) + .with_metadata(key_field.metadata().clone()); + + let value_nullable = true; + + nullable_value_field = Field::new( + value_field.name(), + value_field.data_type().clone(), + value_nullable, + ) + .with_metadata(value_field.metadata().clone()); } _ => { panic!("expected struct data type for map entries") } }; + let entries_nullable = false; + let entries_field = Field::new( name, arrow::datatypes::DataType::Struct(Fields::from(vec![ not_nullable_key_field, nullable_value_field, ])), - false, + entries_nullable, ) .with_metadata(metadata); Arc::new(entries_field) } -pub(crate) fn ensure_arrow_schema_match_tupledesc( - file_schema: Arc, - tupledesc: &PgTupleDesc, -) { - let table_schema = parse_arrow_schema_from_tupledesc(tupledesc); +// ensure_arrow_schema_match_tupledesc_schema throws an error if the arrow schema does not match the table schema. +// If the arrow schema is castable to the table schema, it returns a vector of Option to cast to +// for each field. +pub(crate) fn ensure_arrow_schema_match_tupledesc_schema( + arrow_schema: Arc, + tupledesc_schema: Arc, + attributes: &[FormData_pg_attribute], + cast_mode: CastMode, +) -> Vec> { + let mut cast_to_types = Vec::new(); - for table_schema_field in table_schema.fields().iter() { - let table_schema_field_name = table_schema_field.name(); - let table_schema_field_type = table_schema_field.data_type(); + for (tupledesc_field, attribute) in tupledesc_schema.fields().iter().zip(attributes.iter()) { + let field_name = tupledesc_field.name(); - let file_schema_field = file_schema.column_with_name(table_schema_field_name); + let arrow_field = arrow_schema.column_with_name(field_name); - if let Some(file_schema_field) = file_schema_field { - let file_schema_field_type = file_schema_field.1.data_type(); + if arrow_field.is_none() { + panic!("column \"{}\" is not found in parquet file", field_name); + } - if file_schema_field_type != table_schema_field_type { - panic!( - "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"", - table_schema_field_name, - table_schema_field_type, - file_schema_field_type, - ); - } - } else { - panic!( - "column \"{}\" is not found in parquet file", - table_schema_field_name + let (_, arrow_field) = arrow_field.unwrap(); + let arrow_field = Arc::new(arrow_field.clone()); + + let from_type = arrow_field.data_type(); + let to_type = tupledesc_field.data_type(); + + // no cast needed + if from_type == to_type { + cast_to_types.push(None); + continue; + } + + if let Err(coercion_error) = is_coercible( + from_type, + to_type, + attribute.atttypid, + attribute.atttypmod, + cast_mode, + ) { + let type_mismatch_message = format!( + "type mismatch for column \"{}\" between table and parquet file.\n\n\ + table has \"{}\"\n\nparquet file has \"{}\"", + field_name, to_type, from_type ); + + match coercion_error { + CoercionError::NoStrictCoercionPath => ereport!( + pgrx::PgLogLevel::ERROR, + PgSqlErrorCode::ERRCODE_CANNOT_COERCE, + type_mismatch_message, + "Try COPY FROM '..' WITH (cast_mode = 'relaxed') to allow lossy casts with runtime checks." + ), + CoercionError::NoCoercionPath => ereport!( + pgrx::PgLogLevel::ERROR, + PgSqlErrorCode::ERRCODE_CANNOT_COERCE, + type_mismatch_message + ), + CoercionError::MapEntriesNullable => ereport!( + pgrx::PgLogLevel::ERROR, + PgSqlErrorCode::ERRCODE_CANNOT_COERCE, + format!("entries field in map type cannot be nullable for column \"{}\"", field_name) + ), + } } + + pgrx::debug2!( + "column \"{}\" is being cast from \"{}\" to \"{}\"", + field_name, + from_type, + to_type + ); + + cast_to_types.push(Some(to_type.clone())); + } + + cast_to_types +} + +enum CoercionError { + NoStrictCoercionPath, + NoCoercionPath, + MapEntriesNullable, +} + +// is_coercible first checks if "from_type" can be cast to "to_type" by arrow-cast. +// Then, it checks if the cast is meaningful at Postgres by seeing if there is +// an explicit coercion from "from_typoid" to "to_typoid". +// +// Additionaly, we need to be careful about struct rules for the cast: +// Arrow supports casting struct fields by field position instead of field name, +// which is not the intended behavior for pg_parquet. Hence, we make sure the field names +// match for structs. +fn is_coercible( + from_type: &DataType, + to_type: &DataType, + to_typoid: Oid, + to_typmod: i32, + cast_mode: CastMode, +) -> Result<(), CoercionError> { + match (from_type, to_type) { + (DataType::Struct(from_fields), DataType::Struct(to_fields)) => { + if from_fields.len() != to_fields.len() { + return Err(CoercionError::NoCoercionPath); + } + + let tupledesc = tuple_desc(to_typoid, to_typmod); + + let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); + + for (from_field, (to_field, to_attribute)) in from_fields + .iter() + .zip(to_fields.iter().zip(attributes.iter())) + { + if from_field.name() != to_field.name() { + return Err(CoercionError::NoCoercionPath); + } + + is_coercible( + from_field.data_type(), + to_field.data_type(), + to_attribute.type_oid().value(), + to_attribute.type_mod(), + cast_mode, + )?; + } + + Ok(()) + } + (DataType::List(from_field), DataType::List(to_field)) => { + let element_oid = array_element_typoid(to_typoid); + let element_typmod = to_typmod; + + is_coercible( + from_field.data_type(), + to_field.data_type(), + element_oid, + element_typmod, + cast_mode, + ) + } + (DataType::Map(from_entries_field, _), DataType::Map(to_entries_field, _)) => { + // entries field cannot be null + if from_entries_field.is_nullable() { + return Err(CoercionError::MapEntriesNullable); + } + + let entries_typoid = domain_array_base_elem_typoid(to_typoid); + + is_coercible( + from_entries_field.data_type(), + to_entries_field.data_type(), + entries_typoid, + to_typmod, + cast_mode, + ) + } + _ => { + // check if arrow-cast can cast the types + if !can_cast_types(from_type, to_type) { + return Err(CoercionError::NoCoercionPath); + } + + let from_typoid = pg_type_for_arrow_primitive_type(from_type); + + // pg_parquet could not recognize that arrow type + if from_typoid == InvalidOid { + return Err(CoercionError::NoCoercionPath); + } + + let can_coerce_via_relaxed_mode = + can_pg_coerce_types(from_typoid, to_typoid, COERCION_EXPLICIT); + + // check if coercion is meaningful at Postgres (it has a coercion path) + match cast_mode { + CastMode::Strict => { + let can_coerce_via_strict_mode = + can_pg_coerce_types(from_typoid, to_typoid, COERCION_IMPLICIT); + + if !can_coerce_via_strict_mode && can_coerce_via_relaxed_mode { + Err(CoercionError::NoStrictCoercionPath) + } else if !can_coerce_via_strict_mode { + Err(CoercionError::NoCoercionPath) + } else { + Ok(()) + } + } + CastMode::Relaxed => { + if !can_coerce_via_relaxed_mode { + Err(CoercionError::NoCoercionPath) + } else { + Ok(()) + } + } + } + } + } +} + +fn can_pg_coerce_types(from_typoid: Oid, to_typoid: Oid, ccontext: CoercionContext::Type) -> bool { + let n_args = 1; + let input_typeids = [from_typoid]; + let target_typeids = [to_typoid]; + + unsafe { + can_coerce_type( + n_args, + input_typeids.as_ptr(), + target_typeids.as_ptr(), + ccontext, + ) + } +} + +// pg_type_for_arrow_primitive_type returns Postgres type for given +// primitive arrow type. It returns InvalidOid if the arrow type is not recognized. +fn pg_type_for_arrow_primitive_type(data_type: &DataType) -> Oid { + match data_type { + DataType::Float32 | DataType::Float16 => FLOAT4OID, + DataType::Float64 => FLOAT8OID, + DataType::Int16 | DataType::UInt16 | DataType::Int8 | DataType::UInt8 => INT2OID, + DataType::Int32 | DataType::UInt32 => INT4OID, + DataType::Int64 | DataType::UInt64 => INT8OID, + DataType::Decimal128(_, _) => NUMERICOID, + DataType::Boolean => BOOLOID, + DataType::Date32 => DATEOID, + DataType::Time64(_) => TIMEOID, + DataType::Timestamp(_, None) => TIMESTAMPOID, + DataType::Timestamp(_, Some(_)) => TIMESTAMPTZOID, + DataType::Utf8 | DataType::LargeUtf8 => TEXTOID, + DataType::Binary | DataType::LargeBinary => BYTEAOID, + _ => InvalidOid, } } diff --git a/src/lib.rs b/src/lib.rs index 57584bb..77e95e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,11 @@ pub extern "C" fn _PG_init() { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { + use std::fs::File; use std::io::Write; use std::marker::PhantomData; + use std::sync::Arc; + use std::vec; use std::{collections::HashMap, fmt::Debug}; use crate::arrow_parquet::compression::PgParquetCompression; @@ -46,8 +49,19 @@ mod tests { use crate::type_compat::geometry::Geometry; use crate::type_compat::map::Map; use crate::type_compat::pg_arrow_type_conversions::{ + date_to_i32, time_to_i64, timestamp_to_i64, timestamptz_to_i64, timetz_to_i64, DEFAULT_UNBOUNDED_NUMERIC_PRECISION, DEFAULT_UNBOUNDED_NUMERIC_SCALE, }; + use arrow::array::{ + ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, + Float64Array, Int16Array, Int32Array, Int8Array, LargeBinaryArray, LargeStringArray, + ListArray, MapArray, RecordBatch, StringArray, StructArray, Time64MicrosecondArray, + TimestampMicrosecondArray, UInt16Array, UInt32Array, UInt64Array, + }; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::UInt16Type; + use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit}; + use parquet::arrow::ArrowWriter; use pgrx::pg_sys::Oid; use pgrx::{ composite_type, @@ -340,6 +354,14 @@ mod tests { Spi::get_one(&query).unwrap().unwrap() } + fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) { + let file = File::create("/tmp/test.parquet").unwrap(); + let mut writer = ArrowWriter::try_new(file, schema, None).unwrap(); + + writer.write(&record_batch).unwrap(); + writer.close().unwrap(); + } + #[pg_test] fn test_int2() { let test_table = TestTable::::new("int2".into()); @@ -1391,6 +1413,980 @@ mod tests { Spi::run("DROP TYPE dog;").unwrap(); } + #[pg_test] + fn test_coerce_primitive_types() { + // INT16 => {int, bigint} + let x_nullable = false; + let y_nullable = true; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int16, x_nullable), + Field::new("y", DataType::Int16, y_nullable), + ])); + + let x = Arc::new(Int16Array::from(vec![1])); + let y = Arc::new(Int16Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // INT32 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)])); + + let x = Arc::new(Int32Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT32 => {double} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, true)])); + + let x = Arc::new(Float32Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x double precision)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value as f32, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT64 => {float} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float64, true)])); + + let x = Arc::new(Float64Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x real)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (cast_mode 'relaxed')"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // DATE32 => {timestamp} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Date32, true)])); + + let date = Date::new(2022, 5, 5).unwrap(); + + let x = Arc::new(Date32Array::from(vec![date_to_i32(date)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamp)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, Timestamp::from(date)); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIMESTAMP => {timestamptz} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )])); + + let timestamp = Timestamp::from(Date::new(2022, 5, 5).unwrap()); + + let x = Arc::new(TimestampMicrosecondArray::from(vec![timestamp_to_i64( + timestamp, + )])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamptz)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value.at_timezone("UTC").unwrap(), timestamp); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIMESTAMPTZ => {timestamp} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Timestamp(TimeUnit::Microsecond, Some("Europe/Paris".into())), + true, + )])); + + let timestamptz = + TimestampWithTimeZone::with_timezone(2022, 5, 5, 0, 0, 0.0, "Europe/Paris").unwrap(); + + let x = Arc::new( + TimestampMicrosecondArray::from(vec![timestamptz_to_i64(timestamptz)]) + .with_timezone("Europe/Paris"), + ); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamp)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (cast_mode 'relaxed')"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, timestamptz.at_timezone("UTC").unwrap()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIME64 => {timetz} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Time64(TimeUnit::Microsecond), + true, + )])); + + let time = Time::new(13, 0, 0.0).unwrap(); + + let x = Arc::new(Time64MicrosecondArray::from(vec![time_to_i64(time)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timetz)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, time.into()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIME64 => {time} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Time64(TimeUnit::Microsecond), + true, + ) + .with_metadata(HashMap::from_iter(vec![( + "adjusted_to_utc".into(), + "true".into(), + )]))])); + + let timetz = TimeWithTimeZone::with_timezone(13, 0, 0.0, "UTC").unwrap(); + + let x = Arc::new(Time64MicrosecondArray::from(vec![timetz_to_i64(timetz)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x time)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::