diff --git a/Cargo.lock b/Cargo.lock index d594d73..7150e6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2231,6 +2231,7 @@ name = "pg_parquet" version = "0.1.0" dependencies = [ "arrow", + "arrow-cast", "arrow-schema", "aws-config", "aws-credential-types", diff --git a/Cargo.toml b/Cargo.toml index b5a372b..76a9e1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ pg_test = [] [dependencies] arrow = {version = "53", default-features = false} +arrow-cast = {version = "53", default-features = false} arrow-schema = {version = "53", default-features = false} aws-config = { version = "1.5", default-features = false, features = ["rustls"]} aws-credential-types = {version = "1.2", default-features = false} diff --git a/README.md b/README.md index 1ebd4d0..cd4c8c9 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ SELECT * FROM parquet.schema('/tmp/product_example.parquet') LIMIT 10; /tmp/product_example.parquet | name | BYTE_ARRAY | | OPTIONAL | | UTF8 | | | 3 | STRING /tmp/product_example.parquet | items | | | OPTIONAL | 1 | LIST | | | 4 | LIST /tmp/product_example.parquet | list | | | REPEATED | 1 | | | | | - /tmp/product_example.parquet | items | | | OPTIONAL | 3 | | | | 5 | + /tmp/product_example.parquet | element | | | OPTIONAL | 3 | | | | 5 | /tmp/product_example.parquet | id | INT32 | | OPTIONAL | | | | | 6 | /tmp/product_example.parquet | name | BYTE_ARRAY | | OPTIONAL | | UTF8 | | | 7 | STRING (10 rows) @@ -185,12 +185,15 @@ Alternatively, you can use the following environment variables when starting pos ## Copy Options `pg_parquet` supports the following options in the `COPY TO` command: -- `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension. (This is the only option that `COPY FROM` command supports.), +- `format parquet`: you need to specify this option to write Parquet files that do not end with the `.parquet[.]` extension, - `row_group_size `: the number of rows in each row group while writing Parquet files. The default row group size is `122880`, - `row_group_size_bytes `: the total byte size of rows in each row group while writing Parquet files. The default row group size bytes is `row_group_size * 1024`, -- `compression `: the compression format to use while writing Parquet files. The supported compression formats are `uncompressed`, `snappy`, `gzip`, `brotli`, `lz4`, `lz4raw` and `zstd`. The default compression format is `snappy`. If not specified, the compression format is determined by the file extension. +- `compression `: the compression format to use while writing Parquet files. The supported compression formats are `uncompressed`, `snappy`, `gzip`, `brotli`, `lz4`, `lz4raw` and `zstd`. The default compression format is `snappy`. If not specified, the compression format is determined by the file extension, - `compression_level `: the compression level to use while writing Parquet files. Compression levels are only supported for the `gzip`, `zstd` and `brotli` compression formats. The default compression level is `6` for `gzip (0-10)`, `1` for `zstd (1-22)` and `1` for `brotli (0-11)`. 
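To make the options concrete, here is a minimal sketch written in the style of this patch's own `pg_test` suite (the table name and path are illustrative, not part of the PR):

```rust
// Hypothetical pg_test sketch: exercises the COPY options documented above.
// Assumes the pg_parquet COPY hooks are enabled (pg_parquet.enable_copy_hooks = on).
#[pg_test]
fn test_copy_options_example() {
    Spi::run("CREATE TABLE products (id int, name text)").unwrap();
    Spi::run("INSERT INTO products VALUES (1, 'widget')").unwrap();

    // `format parquet` is required because the path does not end with .parquet.
    Spi::run(
        "COPY products TO '/tmp/products_data' WITH \
         (format parquet, row_group_size 1000, compression zstd, compression_level 3)",
    )
    .unwrap();

    // COPY FROM supports only `format parquet`.
    Spi::run("COPY products FROM '/tmp/products_data' WITH (format parquet)").unwrap();
}
```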
+`pg_parquet` supports the following options in the `COPY FROM` command: +- `format parquet`: you need to specify this option to read Parquet files that do not end with the `.parquet[.]` extension, + ## Configuration There is currently only one GUC parameter to enable/disable the `pg_parquet`: - `pg_parquet.enable_copy_hooks`: you can set this parameter to `on` or `off` to enable or disable the `pg_parquet` extension. The default value is `on`. diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs index ec7c9ce..079ee08 100644 --- a/src/arrow_parquet/arrow_to_pg.rs +++ b/src/arrow_parquet/arrow_to_pg.rs @@ -3,29 +3,23 @@ use arrow::array::{ Float64Array, Int16Array, Int32Array, Int64Array, ListArray, MapArray, StringArray, StructArray, Time64MicrosecondArray, TimestampMicrosecondArray, UInt32Array, }; -use arrow_schema::Fields; +use arrow_schema::{DataType, FieldRef, Fields, TimeUnit}; use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone}, - pg_sys::{ - Datum, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, - INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, - }, + pg_sys::{Datum, FormData_pg_attribute, Oid, CHAROID, TEXTOID, TIMEOID}, prelude::PgHeapTuple, AllocatedByRust, AnyNumeric, IntoDatum, PgTupleDesc, }; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, geometry::{is_postgis_geometry_type, Geometry}, map::{is_map_type, Map}, - pg_arrow_type_conversions::{ - extract_precision_and_scale_from_numeric_typmod, should_write_numeric_as_text, - }, }, }; @@ -57,25 +51,33 @@ pub(crate) trait ArrowArrayToPgType: From { #[derive(Clone)] pub(crate) struct ArrowToPgAttributeContext { name: String, + data_type: DataType, + needs_cast: bool, typoid: Oid, typmod: i32, - is_array: bool, - is_composite: bool, is_geometry: bool, - is_map: bool, attribute_contexts: Option>, attribute_tupledesc: Option>, precision: Option, scale: Option, + timezone: Option, } impl ArrowToPgAttributeContext { - pub(crate) fn new(name: &str, typoid: Oid, typmod: i32, fields: Fields) -> Self { - let field = fields - .iter() - .find(|field| field.name() == name) - .unwrap_or_else(|| panic!("failed to find field {}", name)) - .clone(); + pub(crate) fn new( + name: &str, + typoid: Oid, + typmod: i32, + field: FieldRef, + cast_to_type: Option, + ) -> Self { + let needs_cast = cast_to_type.is_some(); + + let data_type = if let Some(cast_to_type) = &cast_to_type { + cast_to_type.clone() + } else { + field.data_type().clone() + }; let is_array = is_array_type(typoid); let is_composite; @@ -123,16 +125,29 @@ impl ArrowToPgAttributeContext { None }; - let precision; - let scale; - if attribute_typoid == NUMERICOID { - let (p, s) = extract_precision_and_scale_from_numeric_typmod(typmod); - precision = Some(p); - scale = Some(s); - } else { - precision = None; - scale = None; - } + let (precision, scale) = match &data_type { + DataType::Decimal128(p, s) => (Some(*p as _), Some(*s as _)), + DataType::List(field) => { + if let DataType::Decimal128(p, s) = field.data_type() { + (Some(*p as _), Some(*s as _)) + } else { + (None, None) + } + } + _ => (None, None), + };
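Both the `Decimal128` lookup above and the `timezone` lookup that follows peel one level of `List` to reach the element type. A hypothetical helper (not part of this patch) could factor out that pattern; a minimal sketch:

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field};

// Sketch only: return the element type for a list, or the type itself
// otherwise, so decimal precision/scale and timestamp timezone extraction
// can share one lookup.
fn leaf_data_type(data_type: &DataType) -> &DataType {
    match data_type {
        DataType::List(field) => field.data_type(),
        other => other,
    }
}

fn main() {
    let list = DataType::List(Arc::new(Field::new(
        "element",
        DataType::Decimal128(10, 2),
        true,
    )));

    // A decimal list yields Decimal128(10, 2), as the match above expects.
    assert_eq!(leaf_data_type(&list), &DataType::Decimal128(10, 2));
}
```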
+ + let timezone = match &data_type { + DataType::Timestamp(_, Some(timezone)) => Some(timezone.to_string()), + DataType::List(field) => { + if let DataType::Timestamp(_, Some(timezone)) = field.data_type() { + Some(timezone.to_string()) + } else { + None + } + } + _ => None, + }; // for composite and map types, recursively collect attribute contexts let attribute_contexts = if let Some(attribute_tupledesc) = &attribute_tupledesc { @@ -147,9 +162,16 @@ impl ArrowToPgAttributeContext { _ => unreachable!(), }; + let attributes = + collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc); + + // we only cast the top-level attributes, which already covers the nested attributes + let cast_to_types = None; + Some(collect_arrow_to_pg_attribute_contexts( - attribute_tupledesc, + &attributes, &fields, + cast_to_types, )) } else { None @@ -157,43 +179,63 @@ impl ArrowToPgAttributeContext { Self { name: name.to_string(), + data_type, + needs_cast, typoid: attribute_typoid, typmod, - is_array, - is_composite, is_geometry, - is_map, attribute_contexts, attribute_tupledesc, scale, precision, + timezone, } } pub(crate) fn name(&self) -> &str { &self.name } + + pub(crate) fn needs_cast(&self) -> bool { + self.needs_cast + } + + pub(crate) fn data_type(&self) -> &DataType { + &self.data_type + } } pub(crate) fn collect_arrow_to_pg_attribute_contexts( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], fields: &Fields, + cast_to_types: Option>>, ) -> Vec { - // parquet file does not contain generated columns. PG will handle them. - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); let mut attribute_contexts = vec![]; - for attribute in attributes { + for (idx, attribute) in attributes.iter().enumerate() { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); let attribute_typmod = attribute.type_mod(); + let field = fields + .iter() + .find(|field| field.name() == attribute_name) + .unwrap_or_else(|| panic!("failed to find field {}", attribute_name)) + .clone(); + + let cast_to_type = if let Some(cast_to_types) = cast_to_types.as_ref() { + debug_assert!(cast_to_types.len() == attributes.len()); + cast_to_types.get(idx).cloned().expect("cast_to_type null") + } else { + None + }; + let attribute_context = ArrowToPgAttributeContext::new( attribute_name, attribute_typoid, attribute_typmod, - fields.clone(), + field, + cast_to_type, ); attribute_contexts.push(attribute_context); @@ -206,7 +248,7 @@ pub(crate) fn to_pg_datum( attribute_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - if attribute_context.is_array { + if matches!(attribute_array.data_type(), DataType::List(_)) { to_pg_array_datum(attribute_array, attribute_context) } else { to_pg_nonarray_datum(attribute_array, attribute_context) @@ -227,43 +269,34 @@ fn to_pg_nonarray_datum( primitive_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - match attribute_context.typoid { - FLOAT4OID => { + match attribute_context.data_type() { + DataType::Float32 => { to_pg_datum!(Float32Array, f32, primitive_array, attribute_context) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!(Float64Array, f64, primitive_array, attribute_context) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, i16, primitive_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, i32, primitive_array, attribute_context) } - INT8OID => 
{ + DataType::Int64 => { to_pg_datum!(Int64Array, i64, primitive_array, attribute_context) } - BOOLOID => { - to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) - } - CHAROID => { - to_pg_datum!(StringArray, i8, primitive_array, attribute_context) - } - TEXTOID => { - to_pg_datum!(StringArray, String, primitive_array, attribute_context) - } - BYTEAOID => { - to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) - } - OIDOID => { + DataType::UInt32 => { to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Boolean => { + to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) + } + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(StringArray, String, primitive_array, attribute_context) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -272,72 +305,72 @@ fn to_pg_nonarray_datum( primitive_array, attribute_context ) - } else { - to_pg_datum!( - Decimal128Array, - AnyNumeric, - primitive_array, - attribute_context - ) } } - DATEOID => { - to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) + } else { + to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) + } } - TIMEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Time64MicrosecondArray, - Time, + Decimal128Array, + AnyNumeric, primitive_array, attribute_context ) } - TIMETZOID => { + DataType::Date32 => { + to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + } + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Time, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + TimeWithTimeZone, + primitive_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - TimeWithTimeZone, + TimestampMicrosecondArray, + Timestamp, primitive_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { to_pg_datum!( TimestampMicrosecondArray, - Timestamp, + TimestampWithTimeZone, primitive_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - TimestampWithTimeZone, + StructArray, + PgHeapTuple, primitive_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Map, primitive_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - PgHeapTuple, - primitive_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Map, primitive_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - FallbackToText, - primitive_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", 
attribute_context.data_type()); } } } @@ -354,8 +387,13 @@ fn to_pg_array_datum( let list_array = list_array.value(0).to_data(); - match attribute_context.typoid { - FLOAT4OID => { + let element_field = match attribute_context.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + match element_field.data_type() { + DataType::Float32 => { to_pg_datum!( Float32Array, Vec>, @@ -363,7 +401,7 @@ fn to_pg_array_datum( attribute_context ) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!( Float64Array, Vec>, @@ -371,16 +409,19 @@ fn to_pg_array_datum( attribute_context ) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) } - BOOLOID => { + DataType::UInt32 => { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } + DataType::Boolean => { to_pg_datum!( BooleanArray, Vec>, @@ -388,34 +429,17 @@ fn to_pg_array_datum( attribute_context ) } - CHAROID => { - to_pg_datum!(StringArray, Vec>, list_array, attribute_context) - } - TEXTOID => { - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } - BYTEAOID => { - to_pg_datum!( - BinaryArray, - Vec>>, - list_array, - attribute_context - ) - } - OIDOID => { - to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) - } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + StringArray, + Vec>, + list_array, + attribute_context + ) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -424,82 +448,87 @@ fn to_pg_array_datum( list_array, attribute_context ) + } + } + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!( + BinaryArray, + Vec>, + list_array, + attribute_context + ) } else { to_pg_datum!( - Decimal128Array, - Vec>, + BinaryArray, + Vec>>, list_array, attribute_context ) } } - DATEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Date32Array, - Vec>, + Decimal128Array, + Vec>, list_array, attribute_context ) } - TIMEOID => { + DataType::Date32 => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + Date32Array, + Vec>, list_array, attribute_context ) } - TIMETZOID => { + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + TimestampMicrosecondArray, + Vec>, list_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { to_pg_datum!( TimestampMicrosecondArray, - Vec>, + Vec>, list_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - Vec>, + StructArray, + Vec>>, list_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Vec>, list_array, attribute_context) + } _ 
=> { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - Vec>>, - list_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Vec>, list_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!( - BinaryArray, - Vec>, - list_array, - attribute_context - ) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } diff --git a/src/arrow_parquet/arrow_to_pg/timestamptz.rs b/src/arrow_parquet/arrow_to_pg/timestamptz.rs index 81bf5f9..a5a4736 100644 --- a/src/arrow_parquet/arrow_to_pg/timestamptz.rs +++ b/src/arrow_parquet/arrow_to_pg/timestamptz.rs @@ -7,11 +7,16 @@ use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; // Timestamptz impl ArrowArrayToPgType for TimestampMicrosecondArray { - fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + fn to_pg_type(self, context: &ArrowToPgAttributeContext) -> Option { if self.is_null(0) { None } else { - Some(i64_to_timestamptz(self.value(0))) + let timezone = context + .timezone + .as_ref() + .expect("timezone is required for timestamptz"); + + Some(i64_to_timestamptz(self.value(0), timezone)) } } } @@ -20,11 +25,17 @@ impl ArrowArrayToPgType for TimestampMicrosecondArray { impl ArrowArrayToPgType>> for TimestampMicrosecondArray { fn to_pg_type( self, - _context: &ArrowToPgAttributeContext, + context: &ArrowToPgAttributeContext, ) -> Option>> { let mut vals = vec![]; + + let timezone = context + .timezone + .as_ref() + .expect("timezone is required for timestamptz[]"); + for val in self.iter() { - let val = val.map(i64_to_timestamptz); + let val = val.map(|v| i64_to_timestamptz(v, timezone)); vals.push(val); } Some(vals) diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index a3cd53b..cbeff07 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -1,24 +1,32 @@ +use std::sync::Arc; + use arrow::array::RecordBatch; +use arrow_cast::{cast_with_options, CastOptions}; use futures::StreamExt; use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream}; use pgrx::{ check_for_interrupts, pg_sys::{ - fmgr_info, getTypeBinaryOutputInfo, varlena, Datum, FmgrInfo, InvalidOid, SendFunctionCall, + fmgr_info, getTypeBinaryOutputInfo, varlena, Datum, FmgrInfo, FormData_pg_attribute, + InvalidOid, SendFunctionCall, }, vardata_any, varsize_any_exhdr, void_mut_ptr, AllocatedByPostgres, PgBox, PgTupleDesc, }; use url::Url; use crate::{ - arrow_parquet::arrow_to_pg::to_pg_datum, - pgrx_utils::collect_valid_attributes, + arrow_parquet::{ + arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes, + }, + pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; use super::{ arrow_to_pg::{collect_arrow_to_pg_attribute_contexts, ArrowToPgAttributeContext}, - schema_parser::ensure_arrow_schema_match_tupledesc, + schema_parser::{ + ensure_file_schema_match_tupledesc_schema, parse_arrow_schema_from_attributes, + }, uri_utils::{parquet_reader_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }; @@ -41,12 +49,35 @@ impl ParquetReaderContext { let parquet_reader = parquet_reader_from_uri(&uri); - let schema = parquet_reader.schema(); - ensure_arrow_schema_match_tupledesc(schema.clone(), 
tupledesc); + let parquet_file_schema = parquet_reader.schema(); + + let attributes = collect_attributes_for(CollectAttributesFor::CopyFrom, tupledesc); + + pgrx::debug2!( + "schema for tuples: {}", + parquet_schema_string_from_attributes(&attributes) + ); - let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); + let tupledesc_schema = parse_arrow_schema_from_attributes(&attributes); - let attribute_contexts = collect_arrow_to_pg_attribute_contexts(tupledesc, &schema.fields); + let tupledesc_schema = Arc::new(tupledesc_schema); + + // Ensure that the file schema matches the tupledesc schema. + // Gets cast_to_types for each attribute if a cast is needed for the attribute's columnar array + // to match the expected columnar array for its tupledesc type. + let cast_to_types = ensure_file_schema_match_tupledesc_schema( + parquet_file_schema.clone(), + tupledesc_schema.clone(), + &attributes, + ); + + let attribute_contexts = collect_arrow_to_pg_attribute_contexts( + &attributes, + &tupledesc_schema.fields, + Some(cast_to_types), + ); + + let binary_out_funcs = Self::collect_binary_out_funcs(&attributes); ParquetReaderContext { buffer: Vec::new(), @@ -60,14 +91,11 @@ impl ParquetReaderContext { } fn collect_binary_out_funcs( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], ) -> Vec> { unsafe { let mut binary_out_funcs = vec![]; - let include_generated_columns = false; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); - for att in attributes.iter() { let typoid = att.type_oid(); @@ -94,11 +122,25 @@ impl ParquetReaderContext { for attribute_context in attribute_contexts { let name = attribute_context.name(); - let column = record_batch + let column_array = record_batch .column_by_name(name) .unwrap_or_else(|| panic!("column {} not found", name)); - let datum = to_pg_datum(column.to_data(), attribute_context); + let datum = if attribute_context.needs_cast() { + // should fail instead of returning None if the cast fails at runtime + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + let casted_column_array = + cast_with_options(&column_array, attribute_context.data_type(), &cast_options) + .unwrap_or_else(|e| panic!("failed to cast column {}: {}", name, e)); + + to_pg_datum(casted_column_array.to_data(), attribute_context) + } else { + to_pg_datum(column_array.to_data(), attribute_context) + }; datums.push(datum); } diff --git a/src/arrow_parquet/parquet_writer.rs b/src/arrow_parquet/parquet_writer.rs index 7c12009..e93ea8b 100644 --- a/src/arrow_parquet/parquet_writer.rs +++ b/src/arrow_parquet/parquet_writer.rs @@ -12,9 +12,12 @@ use url::Url; use crate::{ arrow_parquet::{ compression::{PgParquetCompression, PgParquetCompressionWithLevel}, - schema_parser::parse_arrow_schema_from_tupledesc, + schema_parser::{ + parquet_schema_string_from_attributes, parse_arrow_schema_from_attributes, + }, uri_utils::{parquet_writer_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }, + pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, }; @@ -57,12 +60,20 @@ impl ParquetWriterContext { .set_created_by("pg_parquet".to_string()) .build(); - let schema = parse_arrow_schema_from_tupledesc(tupledesc); + let attributes = collect_attributes_for(CollectAttributesFor::CopyTo, tupledesc); + + pgrx::debug2!( + "schema for tuples: {}", + parquet_schema_string_from_attributes(&attributes) + ); + + let schema = parse_arrow_schema_from_attributes(&attributes); let 
schema = Arc::new(schema); let parquet_writer = parquet_writer_from_uri(&uri, schema.clone(), writer_props); - let attribute_contexts = collect_pg_to_arrow_attribute_contexts(tupledesc, &schema.fields); + let attribute_contexts = + collect_pg_to_arrow_attribute_contexts(&attributes, &schema.fields); ParquetWriterContext { parquet_writer, diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs index 40cc03c..530c7f7 100644 --- a/src/arrow_parquet/pg_to_arrow.rs +++ b/src/arrow_parquet/pg_to_arrow.rs @@ -7,16 +7,17 @@ use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone, UnboxDatum}, heap_tuple::PgHeapTuple, pg_sys::{ - Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, - NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, + FormData_pg_attribute, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, + INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, + TIMESTAMPTZOID, TIMETZOID, }, - AllocatedByRust, AnyNumeric, FromDatum, PgTupleDesc, + AllocatedByRust, AnyNumeric, FromDatum, }; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, @@ -146,7 +147,10 @@ impl PgToArrowAttributeContext { _ => unreachable!(), }; - collect_pg_to_arrow_attribute_contexts(&attribute_tupledesc, &fields) + let attributes = + collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc); + + collect_pg_to_arrow_attribute_contexts(&attributes, &fields) }); Self { @@ -166,11 +170,9 @@ impl PgToArrowAttributeContext { } pub(crate) fn collect_pg_to_arrow_attribute_contexts( - tupledesc: &PgTupleDesc, + attributes: &[FormData_pg_attribute], fields: &Fields, ) -> Vec { - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); let mut attribute_contexts = vec![]; for attribute in attributes { diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index 8dd79cf..b76ee70 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -1,18 +1,22 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use arrow::datatypes::{Field, Fields, Schema}; -use arrow_schema::FieldRef; +use arrow_cast::can_cast_types; +use arrow_schema::{DataType, FieldRef}; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ - Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, - NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, + can_coerce_type, + CoercionContext::{self, COERCION_EXPLICIT}, + FormData_pg_attribute, InvalidOid, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, + FLOAT8OID, INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, + TIMESTAMPTZOID, TIMETZOID, }; use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc}; use crate::{ pgrx_utils::{ - array_element_typoid, collect_valid_attributes, domain_array_base_elem_typoid, - is_array_type, is_composite_type, tuple_desc, + array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, + 
is_composite_type, tuple_desc, CollectAttributesFor, }, type_compat::{ geometry::is_postgis_geometry_type, @@ -23,8 +27,10 @@ use crate::{ }, }; -pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> String { - let arrow_schema = parse_arrow_schema_from_tupledesc(tupledesc); +pub(crate) fn parquet_schema_string_from_attributes( + attributes: &[FormData_pg_attribute], +) -> String { + let arrow_schema = parse_arrow_schema_from_attributes(attributes); let parquet_schema = arrow_to_parquet_schema(&arrow_schema) .unwrap_or_else(|e| panic!("failed to convert arrow schema to parquet schema: {}", e)); @@ -33,14 +39,11 @@ pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> S String::from_utf8(buf).unwrap_or_else(|e| panic!("failed to convert schema to string: {}", e)) } -pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Schema { +pub(crate) fn parse_arrow_schema_from_attributes(attributes: &[FormData_pg_attribute]) -> Schema { let mut field_id = 0; let mut struct_attribute_fields = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(tupledesc, include_generated_columns); - for attribute in attributes { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); @@ -92,8 +95,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i let mut child_fields: Vec> = vec![]; - let include_generated_columns = true; - let attributes = collect_valid_attributes(&tupledesc, include_generated_columns); + let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); for attribute in attributes { if attribute.is_dropped() { @@ -130,10 +132,12 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i child_fields.push(child_field); } + let nullable = true; + Field::new( elem_name, arrow::datatypes::DataType::Struct(Fields::from(child_fields)), - true, + nullable, ) .with_metadata(metadata) .into() @@ -149,20 +153,24 @@ fn parse_list_schema(typoid: Oid, typmod: i32, array_name: &str, field_id: &mut *field_id += 1; + let element_name = "element"; + let elem_field = if is_composite_type(typoid) { let tupledesc = tuple_desc(typoid, typmod); - parse_struct_schema(tupledesc, array_name, field_id) + parse_struct_schema(tupledesc, element_name, field_id) } else if is_map_type(typoid) { let base_elem_typoid = domain_array_base_elem_typoid(typoid); - parse_map_schema(base_elem_typoid, typmod, array_name, field_id) + parse_map_schema(base_elem_typoid, typmod, element_name, field_id) } else { - parse_primitive_schema(typoid, typmod, array_name, field_id) + parse_primitive_schema(typoid, typmod, element_name, field_id) }; + let nullable = true; + Field::new( array_name, arrow::datatypes::DataType::List(elem_field), - true, + nullable, ) .with_metadata(list_metadata) .into() @@ -177,13 +185,18 @@ fn parse_map_schema(typoid: Oid, typmod: i32, map_name: &str, field_id: &mut i32 *field_id += 1; let tupledesc = tuple_desc(typoid, typmod); + let entries_field = parse_struct_schema(tupledesc, map_name, field_id); let entries_field = adjust_map_entries_field(entries_field); + let keys_sorted = false; + + let nullable = true; + Field::new( map_name, - arrow::datatypes::DataType::Map(entries_field, false), - true, + arrow::datatypes::DataType::Map(entries_field, keys_sorted), + nullable, ) .with_metadata(map_metadata) .into() @@ -204,31 +217,33 @@ fn parse_primitive_schema( *field_id += 1; + let nullable = true; + 
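+    // The match below maps each supported Postgres type OID to its Arrow logical type; unrecognized types fall back to Utf8 (their text representation), except PostGIS geometry, which is written as Binary.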
let field = match typoid { - FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, true), - FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, true), - BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, true), - INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, true), - INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, true), - INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, true), + FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, nullable), + FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, nullable), + BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, nullable), + INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, nullable), + INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, nullable), + INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, nullable), NUMERICOID => { let (precision, scale) = extract_precision_and_scale_from_numeric_typmod(typmod); if should_write_numeric_as_text(precision) { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } else { Field::new( elem_name, arrow::datatypes::DataType::Decimal128(precision as _, scale as _), - true, + nullable, ) } } - DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, true), + DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, nullable), TIMESTAMPOID => Field::new( elem_name, arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - true, + nullable, ), TIMESTAMPTZOID => Field::new( elem_name, @@ -236,31 +251,31 @@ fn parse_primitive_schema( arrow::datatypes::TimeUnit::Microsecond, Some("+00:00".into()), ), - true, + nullable, ), TIMEOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ), TIMETZOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ) .with_metadata(HashMap::from_iter(vec![( "adjusted_to_utc".into(), "true".into(), )])), - CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, true), - OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, true), + CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable), + OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, nullable), _ => { if is_postgis_geometry_type(typoid) { - Field::new(elem_name, arrow::datatypes::DataType::Binary, true) + Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable) } else { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } } }; @@ -277,72 +292,226 @@ fn parse_primitive_schema( } fn adjust_map_entries_field(field: FieldRef) -> FieldRef { - let name = field.deref().name(); - let data_type = field.deref().data_type(); - let metadata = field.deref().metadata().clone(); - let not_nullable_key_field; let nullable_value_field; - match data_type { + match field.deref().data_type() 
{ arrow::datatypes::DataType::Struct(fields) => { let key_field = fields.find("key").expect("expected key field").1; let value_field = fields.find("val").expect("expected val field").1; - not_nullable_key_field = - Field::new(key_field.name(), key_field.data_type().clone(), false) - .with_metadata(key_field.metadata().clone()); + let key_nullable = false; + + not_nullable_key_field = Field::new( + key_field.name(), + key_field.data_type().clone(), + key_nullable, + ) + .with_metadata(key_field.metadata().clone()); + + let value_nullable = true; - nullable_value_field = - Field::new(value_field.name(), value_field.data_type().clone(), true) - .with_metadata(value_field.metadata().clone()); + nullable_value_field = Field::new( + value_field.name(), + value_field.data_type().clone(), + value_nullable, + ) + .with_metadata(value_field.metadata().clone()); } _ => { - panic!("expected struct data type for map entries") + panic!("expected struct data type for map key_value field") } }; + let entries_nullable = false; + + let entries_name = "key_value"; + + let metadata = field.deref().metadata().clone(); + let entries_field = Field::new( - name, + entries_name, arrow::datatypes::DataType::Struct(Fields::from(vec![ not_nullable_key_field, nullable_value_field, ])), - false, + entries_nullable, ) .with_metadata(metadata); Arc::new(entries_field) } -pub(crate) fn ensure_arrow_schema_match_tupledesc( +// ensure_file_schema_match_tupledesc_schema throws an error if the file's schema does not match the table schema. +// If the file's arrow schema is castable to the table's arrow schema, it returns a vector of Option +// to cast to for each field. +pub(crate) fn ensure_file_schema_match_tupledesc_schema( file_schema: Arc, - tupledesc: &PgTupleDesc, -) { - let table_schema = parse_arrow_schema_from_tupledesc(tupledesc); - - for table_schema_field in table_schema.fields().iter() { - let table_schema_field_name = table_schema_field.name(); - let table_schema_field_type = table_schema_field.data_type(); - - let file_schema_field = file_schema.column_with_name(table_schema_field_name); - - if let Some(file_schema_field) = file_schema_field { - let file_schema_field_type = file_schema_field.1.data_type(); - - if file_schema_field_type != table_schema_field_type { - panic!( - "type mismatch for column \"{}\" between table and parquet file. 
table expected \"{}\" but file had \"{}\"", - table_schema_field_name, - table_schema_field_type, - file_schema_field_type, - ); - } - } else { + tupledesc_schema: Arc, + attributes: &[FormData_pg_attribute], +) -> Vec> { + let mut cast_to_types = Vec::new(); + + for (tupledesc_schema_field, attribute) in + tupledesc_schema.fields().iter().zip(attributes.iter()) + { + let field_name = tupledesc_schema_field.name(); + + let file_schema_field = file_schema.column_with_name(field_name); + + if file_schema_field.is_none() { + panic!("column \"{}\" is not found in parquet file", field_name); + } + + let (_, file_schema_field) = file_schema_field.unwrap(); + let file_schema_field = Arc::new(file_schema_field.clone()); + + let from_type = file_schema_field.data_type(); + let to_type = tupledesc_schema_field.data_type(); + + // no cast needed + if from_type == to_type { + cast_to_types.push(None); + continue; + } + + if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) { panic!( - "column \"{}\" is not found in parquet file", - table_schema_field_name + "type mismatch for column \"{}\" between table and parquet file.\n\n\ + table has \"{}\"\n\nparquet file has \"{}\"", + field_name, to_type, from_type ); } + + pgrx::debug2!( + "column \"{}\" is being cast from \"{}\" to \"{}\"", + field_name, + from_type, + to_type + ); + + cast_to_types.push(Some(to_type.clone())); + } + + cast_to_types +} + +// is_coercible first checks if "from_type" can be cast to "to_type" by arrow-cast. +// Then, it checks if the cast is meaningful in Postgres by seeing if there is +// an explicit coercion from "from_typoid" to "to_typoid". +// +// Additionally, we need to be careful about struct rules for the cast: +// Arrow supports casting struct fields by field position instead of field name, +// which is not the intended behavior for pg_parquet. Hence, we make sure the field names +// match for structs. 
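+// For example, a parquet Int32 column can be coerced to a bigint attribute: arrow-cast can widen Int32 to Int64, and Postgres has an explicit int4 -> int8 coercion path (exercised by test_coerce_primitive_types in src/lib.rs).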
+fn is_coercible(from_type: &DataType, to_type: &DataType, to_typoid: Oid, to_typmod: i32) -> bool { + match (from_type, to_type) { + (DataType::Struct(from_fields), DataType::Struct(to_fields)) => { + if from_fields.len() != to_fields.len() { + return false; + } + + let tupledesc = tuple_desc(to_typoid, to_typmod); + + let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); + + for (from_field, (to_field, to_attribute)) in from_fields + .iter() + .zip(to_fields.iter().zip(attributes.iter())) + { + if from_field.name() != to_field.name() { + return false; + } + + if !is_coercible( + from_field.data_type(), + to_field.data_type(), + to_attribute.type_oid().value(), + to_attribute.type_mod(), + ) { + return false; + } + } + + true + } + (DataType::List(from_field), DataType::List(to_field)) => { + let element_oid = array_element_typoid(to_typoid); + let element_typmod = to_typmod; + + is_coercible( + from_field.data_type(), + to_field.data_type(), + element_oid, + element_typmod, + ) + } + (DataType::Map(from_entries_field, _), DataType::Map(to_entries_field, _)) => { + // entries field cannot be null + if from_entries_field.is_nullable() { + return false; + } + + let entries_typoid = domain_array_base_elem_typoid(to_typoid); + + is_coercible( + from_entries_field.data_type(), + to_entries_field.data_type(), + entries_typoid, + to_typmod, + ) + } + _ => { + // check if arrow-cast can cast the types + if !can_cast_types(from_type, to_type) { + return false; + } + + let from_typoid = pg_type_for_arrow_primitive_type(from_type); + + // pg_parquet could not recognize that arrow type + if from_typoid == InvalidOid { + return false; + } + + // check if coercion is meaningful at Postgres (it has a coercion path) + can_pg_coerce_types(from_typoid, to_typoid, COERCION_EXPLICIT) + } + } +} + +fn can_pg_coerce_types(from_typoid: Oid, to_typoid: Oid, ccontext: CoercionContext::Type) -> bool { + let n_args = 1; + let input_typeids = [from_typoid]; + let target_typeids = [to_typoid]; + + unsafe { + can_coerce_type( + n_args, + input_typeids.as_ptr(), + target_typeids.as_ptr(), + ccontext, + ) + } +} + +// pg_type_for_arrow_primitive_type returns Postgres type for given +// primitive arrow type. It returns InvalidOid if the arrow type is not recognized. 
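+// For example, Utf8 maps to TEXTOID, so a parquet string column is accepted whenever Postgres can explicitly coerce text to the target column's type.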
+fn pg_type_for_arrow_primitive_type(data_type: &DataType) -> Oid { + match data_type { + DataType::Float32 | DataType::Float16 => FLOAT4OID, + DataType::Float64 => FLOAT8OID, + DataType::Int16 | DataType::UInt16 | DataType::Int8 | DataType::UInt8 => INT2OID, + DataType::Int32 | DataType::UInt32 => INT4OID, + DataType::Int64 | DataType::UInt64 => INT8OID, + DataType::Decimal128(_, _) => NUMERICOID, + DataType::Boolean => BOOLOID, + DataType::Date32 => DATEOID, + DataType::Time64(_) => TIMEOID, + DataType::Timestamp(_, None) => TIMESTAMPOID, + DataType::Timestamp(_, Some(_)) => TIMESTAMPTZOID, + DataType::Utf8 | DataType::LargeUtf8 => TEXTOID, + DataType::Binary | DataType::LargeBinary => BYTEAOID, + _ => InvalidOid, } } diff --git a/src/lib.rs b/src/lib.rs index 57584bb..809a565 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,11 @@ pub extern "C" fn _PG_init() { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { + use std::fs::File; use std::io::Write; use std::marker::PhantomData; + use std::sync::Arc; + use std::vec; use std::{collections::HashMap, fmt::Debug}; use crate::arrow_parquet::compression::PgParquetCompression; @@ -46,8 +49,19 @@ mod tests { use crate::type_compat::geometry::Geometry; use crate::type_compat::map::Map; use crate::type_compat::pg_arrow_type_conversions::{ + date_to_i32, time_to_i64, timestamp_to_i64, timestamptz_to_i64, timetz_to_i64, DEFAULT_UNBOUNDED_NUMERIC_PRECISION, DEFAULT_UNBOUNDED_NUMERIC_SCALE, }; + use arrow::array::{ + ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, + Float64Array, Int16Array, Int32Array, Int8Array, LargeBinaryArray, LargeStringArray, + ListArray, MapArray, RecordBatch, StringArray, StructArray, Time64MicrosecondArray, + TimestampMicrosecondArray, UInt16Array, UInt32Array, UInt64Array, + }; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::UInt16Type; + use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit}; + use parquet::arrow::ArrowWriter; use pgrx::pg_sys::Oid; use pgrx::{ composite_type, @@ -340,6 +354,14 @@ mod tests { Spi::get_one(&query).unwrap().unwrap() } + fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) { + let file = File::create("/tmp/test.parquet").unwrap(); + let mut writer = ArrowWriter::try_new(file, schema, None).unwrap(); + + writer.write(&record_batch).unwrap(); + writer.close().unwrap(); + } + #[pg_test] fn test_int2() { let test_table = TestTable::::new("int2".into()); @@ -1140,7 +1162,7 @@ mod tests { test_helper(test_table); let parquet_schema_command = - "select precision, scale, logical_type, type_name from parquet.schema('/tmp/test.parquet') WHERE name = 'a' ORDER BY logical_type;"; + "select precision, scale, logical_type, type_name from parquet.schema('/tmp/test.parquet') WHERE name = 'element' ORDER BY logical_type;"; let attribute_schema = Spi::connect(|client| { let tup_table = client.select(parquet_schema_command, None, None).unwrap(); @@ -1158,7 +1180,7 @@ mod tests { results }); - assert_eq!(attribute_schema.len(), 2); + assert_eq!(attribute_schema.len(), 1); assert_eq!( attribute_schema[0], ( @@ -1168,7 +1190,6 @@ mod tests { Some("FIXED_LEN_BYTE_ARRAY".to_string()) ) ); - assert_eq!(attribute_schema[1], (None, None, "LIST".to_string(), None)); } #[pg_test] @@ -1391,6 +1412,973 @@ mod tests { Spi::run("DROP TYPE dog;").unwrap(); } + #[pg_test] + fn test_coerce_primitive_types() { + // INT16 => {int, bigint} + let x_nullable = false; + let y_nullable = true; + + let schema 
= Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int16, x_nullable), + Field::new("y", DataType::Int16, y_nullable), + ])); + + let x = Arc::new(Int16Array::from(vec![1])); + let y = Arc::new(Int16Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // INT32 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)])); + + let x = Arc::new(Int32Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT32 => {double} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, true)])); + + let x = Arc::new(Float32Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x double precision)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value as f32, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT64 => {float} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float64, true)])); + + let x = Arc::new(Float64Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x real)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // DATE32 => {timestamp} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Date32, true)])); + + let date = Date::new(2022, 5, 5).unwrap(); + + let x = Arc::new(Date32Array::from(vec![date_to_i32(date)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamp)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, Timestamp::from(date)); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + 
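+        // The timestamp cases below rely on arrow-cast semantics: casting a naive timestamp to a zoned one attaches the timezone without changing the stored microseconds (the naive value is read as UTC), while casting a zoned timestamp to a naive one keeps the same instant expressed in UTC.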
+ // TIMESTAMP => {timestamptz} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )])); + + let timestamp = Timestamp::from(Date::new(2022, 5, 5).unwrap()); + + let x = Arc::new(TimestampMicrosecondArray::from(vec![timestamp_to_i64( + timestamp, + )])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamptz)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value.at_timezone("UTC").unwrap(), timestamp); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIMESTAMPTZ => {timestamp} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Timestamp(TimeUnit::Microsecond, Some("Europe/Paris".into())), + true, + )])); + + let timestamptz = + TimestampWithTimeZone::with_timezone(2022, 5, 5, 0, 0, 0.0, "Europe/Paris").unwrap(); + + let x = Arc::new( + TimestampMicrosecondArray::from(vec![timestamptz_to_i64(timestamptz)]) + .with_timezone("Europe/Paris"), + ); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timestamp)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, timestamptz.at_timezone("UTC").unwrap()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIME64 => {timetz} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Time64(TimeUnit::Microsecond), + true, + )])); + + let time = Time::new(13, 0, 0.0).unwrap(); + + let x = Arc::new(Time64MicrosecondArray::from(vec![time_to_i64(time)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x timetz)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, time.into()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // TIME64 => {time} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Time64(TimeUnit::Microsecond), + true, + ) + .with_metadata(HashMap::from_iter(vec![( + "adjusted_to_utc".into(), + "true".into(), + )]))])); + + let timetz = TimeWithTimeZone::with_timezone(13, 0, 0.0, "UTC").unwrap(); + + let x = Arc::new(Time64MicrosecondArray::from(vec![timetz_to_i64(timetz)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x time)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::