diff --git a/Cargo.lock b/Cargo.lock index d594d73..7150e6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2231,6 +2231,7 @@ name = "pg_parquet" version = "0.1.0" dependencies = [ "arrow", + "arrow-cast", "arrow-schema", "aws-config", "aws-credential-types", diff --git a/Cargo.toml b/Cargo.toml index b5a372b..76a9e1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ pg_test = [] [dependencies] arrow = {version = "53", default-features = false} +arrow-cast = {version = "53", default-features = false} arrow-schema = {version = "53", default-features = false} aws-config = { version = "1.5", default-features = false, features = ["rustls"]} aws-credential-types = {version = "1.2", default-features = false} diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs index ec7c9ce..963634a 100644 --- a/src/arrow_parquet/arrow_to_pg.rs +++ b/src/arrow_parquet/arrow_to_pg.rs @@ -1,15 +1,14 @@ +use std::ops::Deref; + use arrow::array::{ Array, ArrayData, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, ListArray, MapArray, StringArray, StructArray, Time64MicrosecondArray, TimestampMicrosecondArray, UInt32Array, }; -use arrow_schema::Fields; +use arrow_schema::{DataType, FieldRef, Fields, TimeUnit}; use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone}, - pg_sys::{ - Datum, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, - INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, - }, + pg_sys::{Datum, Oid, CHAROID, NUMERICOID, TEXTOID, TIMEOID}, prelude::PgHeapTuple, AllocatedByRust, AnyNumeric, IntoDatum, PgTupleDesc, }; @@ -23,9 +22,7 @@ use crate::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, geometry::{is_postgis_geometry_type, Geometry}, map::{is_map_type, Map}, - pg_arrow_type_conversions::{ - extract_precision_and_scale_from_numeric_typmod, 
should_write_numeric_as_text, - }, + pg_arrow_type_conversions::extract_precision_and_scale_from_numeric_typmod, }, }; @@ -57,12 +54,11 @@ pub(crate) trait ArrowArrayToPgType: From { #[derive(Clone)] pub(crate) struct ArrowToPgAttributeContext { name: String, + cast_to: Option, + field: FieldRef, typoid: Oid, typmod: i32, - is_array: bool, - is_composite: bool, is_geometry: bool, - is_map: bool, attribute_contexts: Option>, attribute_tupledesc: Option>, precision: Option, @@ -70,7 +66,13 @@ pub(crate) struct ArrowToPgAttributeContext { } impl ArrowToPgAttributeContext { - pub(crate) fn new(name: &str, typoid: Oid, typmod: i32, fields: Fields) -> Self { + pub(crate) fn new( + name: &str, + typoid: Oid, + typmod: i32, + fields: Fields, + cast_to: Option, + ) -> Self { let field = fields .iter() .find(|field| field.name() == name) @@ -147,9 +149,11 @@ impl ArrowToPgAttributeContext { _ => unreachable!(), }; + // we only cast the top-level attributes, which already covers the nested attributes Some(collect_arrow_to_pg_attribute_contexts( attribute_tupledesc, &fields, + None, )) } else { None @@ -157,12 +161,11 @@ impl ArrowToPgAttributeContext { Self { name: name.to_string(), + cast_to, + field, typoid: attribute_typoid, typmod, - is_array, - is_composite, is_geometry, - is_map, attribute_contexts, attribute_tupledesc, scale, @@ -173,27 +176,48 @@ impl ArrowToPgAttributeContext { pub(crate) fn name(&self) -> &str { &self.name } + + pub(crate) fn cast_to(&self) -> &Option { + &self.cast_to + } + + pub(crate) fn data_type(&self) -> &DataType { + if let Some(cast_to) = &self.cast_to { + cast_to + } else { + self.field.data_type() + } + } } pub(crate) fn collect_arrow_to_pg_attribute_contexts( tupledesc: &PgTupleDesc, fields: &Fields, + cast_to_types: Option>>, ) -> Vec { // parquet file does not contain generated columns. PG will handle them. 
let include_generated_columns = false; let attributes = collect_valid_attributes(tupledesc, include_generated_columns); let mut attribute_contexts = vec![]; - for attribute in attributes { + for (idx, attribute) in attributes.iter().enumerate() { let attribute_name = attribute.name(); let attribute_typoid = attribute.type_oid().value(); let attribute_typmod = attribute.type_mod(); + let cast_to = if let Some(cast_to_types) = cast_to_types.as_ref() { + debug_assert!(cast_to_types.len() == attributes.len()); + cast_to_types.get(idx).cloned().expect("cast_to null") + } else { + None + }; + let attribute_context = ArrowToPgAttributeContext::new( attribute_name, attribute_typoid, attribute_typmod, fields.clone(), + cast_to, ); attribute_contexts.push(attribute_context); @@ -206,7 +230,7 @@ pub(crate) fn to_pg_datum( attribute_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - if attribute_context.is_array { + if matches!(attribute_array.data_type(), DataType::List(_)) { to_pg_array_datum(attribute_array, attribute_context) } else { to_pg_nonarray_datum(attribute_array, attribute_context) @@ -227,43 +251,34 @@ fn to_pg_nonarray_datum( primitive_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - match attribute_context.typoid { - FLOAT4OID => { + match attribute_context.data_type() { + DataType::Float32 => { to_pg_datum!(Float32Array, f32, primitive_array, attribute_context) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!(Float64Array, f64, primitive_array, attribute_context) } - INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, i16, primitive_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, i32, primitive_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, i64, primitive_array, attribute_context) } - BOOLOID => { - to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) - } - CHAROID => { - 
to_pg_datum!(StringArray, i8, primitive_array, attribute_context) - } - TEXTOID => { - to_pg_datum!(StringArray, String, primitive_array, attribute_context) - } - BYTEAOID => { - to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) - } - OIDOID => { + DataType::UInt32 => { to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Boolean => { + to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) + } + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(StringArray, String, primitive_array, attribute_context) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -272,72 +287,74 @@ fn to_pg_nonarray_datum( primitive_array, attribute_context ) - } else { - to_pg_datum!( - Decimal128Array, - AnyNumeric, - primitive_array, - attribute_context - ) } } - DATEOID => { - to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) + } else { + to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) + } } - TIMEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Time64MicrosecondArray, - Time, + Decimal128Array, + AnyNumeric, primitive_array, attribute_context ) } - TIMETZOID => { + DataType::Date32 => { + to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + } + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Time, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + 
TimeWithTimeZone, + primitive_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - TimeWithTimeZone, + TimestampMicrosecondArray, + Timestamp, primitive_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Timestamp, + TimestampWithTimeZone, primitive_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - TimestampWithTimeZone, + StructArray, + PgHeapTuple, primitive_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Map, primitive_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - PgHeapTuple, - primitive_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Map, primitive_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - FallbackToText, - primitive_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } @@ -354,8 +371,13 @@ fn to_pg_array_datum( let list_array = list_array.value(0).to_data(); - match attribute_context.typoid { - FLOAT4OID => { + let element_field = match attribute_context.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + match element_field.data_type() { + DataType::Float32 => { to_pg_datum!( Float32Array, Vec>, @@ -363,7 +385,7 @@ fn to_pg_array_datum( attribute_context ) } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!( Float64Array, Vec>, @@ -371,16 +393,19 @@ fn to_pg_array_datum( attribute_context ) } - 
INT2OID => { + DataType::Int16 => { to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) } - INT4OID => { + DataType::Int32 => { to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) } - INT8OID => { + DataType::Int64 => { to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) } - BOOLOID => { + DataType::UInt32 => { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } + DataType::Boolean => { to_pg_datum!( BooleanArray, Vec>, @@ -388,34 +413,17 @@ fn to_pg_array_datum( attribute_context ) } - CHAROID => { - to_pg_datum!(StringArray, Vec>, list_array, attribute_context) - } - TEXTOID => { - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } - BYTEAOID => { - to_pg_datum!( - BinaryArray, - Vec>>, - list_array, - attribute_context - ) - } - OIDOID => { - to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) - } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + StringArray, + Vec>, + list_array, + attribute_context + ) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -424,82 +432,89 @@ fn to_pg_array_datum( list_array, attribute_context ) + } + } + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!( + BinaryArray, + Vec>, + list_array, + attribute_context + ) } else { to_pg_datum!( - Decimal128Array, - Vec>, + BinaryArray, + Vec>>, list_array, attribute_context ) } } - DATEOID => { + DataType::Decimal128(_, _) => { to_pg_datum!( - Date32Array, - Vec>, + Decimal128Array, + Vec>, list_array, attribute_context ) } - TIMEOID => { + DataType::Date32 => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + 
Date32Array, + Vec>, list_array, attribute_context ) } - TIMETZOID => { + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + TimestampMicrosecondArray, + Vec>, list_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Vec>, + Vec>, list_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - Vec>, + StructArray, + Vec>>, list_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Vec>, list_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - Vec>>, - list_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Vec>, list_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!( - BinaryArray, - Vec>, - list_array, - attribute_context - ) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } + panic!("unsupported data type: {:?}", attribute_context.data_type()); } } } diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index a3cd53b..dbd02f7 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use arrow::array::RecordBatch; +use arrow_cast::{cast_with_options, CastOptions}; use futures::StreamExt; use 
parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream}; use pgrx::{ @@ -18,7 +21,9 @@ use crate::{ use super::{ arrow_to_pg::{collect_arrow_to_pg_attribute_contexts, ArrowToPgAttributeContext}, - schema_parser::ensure_arrow_schema_match_tupledesc, + schema_parser::{ + ensure_arrow_schema_match_tupledesc_schema, parse_arrow_schema_from_tupledesc, + }, uri_utils::{parquet_reader_from_uri, PG_BACKEND_TOKIO_RUNTIME}, }; @@ -41,12 +46,29 @@ impl ParquetReaderContext { let parquet_reader = parquet_reader_from_uri(&uri); - let schema = parquet_reader.schema(); - ensure_arrow_schema_match_tupledesc(schema.clone(), tupledesc); + let parquet_file_schema = parquet_reader.schema(); - let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); + let include_generated_columns = false; + let tupledesc_schema = + parse_arrow_schema_from_tupledesc(tupledesc, include_generated_columns); + + let tupledesc_schema = Arc::new(tupledesc_schema); + + // Ensure that the arrow schema matches the tupledesc. + // Gets cast_to types for each attribute if a cast is needed for the attribute's columnar array + // to match the expected columnar array for its tupledesc type. 
+ let cast_to_types = ensure_arrow_schema_match_tupledesc_schema( + parquet_file_schema.clone(), + tupledesc_schema.clone(), + ); - let attribute_contexts = collect_arrow_to_pg_attribute_contexts(tupledesc, &schema.fields); + let attribute_contexts = collect_arrow_to_pg_attribute_contexts( + tupledesc, + &tupledesc_schema.fields, + Some(cast_to_types), + ); + + let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); ParquetReaderContext { buffer: Vec::new(), @@ -93,12 +115,26 @@ impl ParquetReaderContext { for attribute_context in attribute_contexts { let name = attribute_context.name(); + let cast_to = attribute_context.cast_to(); - let column = record_batch + let column_array = record_batch .column_by_name(name) .unwrap_or_else(|| panic!("column {} not found", name)); - let datum = to_pg_datum(column.to_data(), attribute_context); + let datum = if let Some(cast_to) = cast_to { + // should fail instead of returning None if the cast fails at runtime + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + let casted_column_array = cast_with_options(&column_array, cast_to, &cast_options) + .unwrap_or_else(|e| panic!("failed to cast column {}: {}", name, e)); + + to_pg_datum(casted_column_array.to_data(), attribute_context) + } else { + to_pg_datum(column_array.to_data(), attribute_context) + }; datums.push(datum); } diff --git a/src/arrow_parquet/parquet_writer.rs b/src/arrow_parquet/parquet_writer.rs index 7c12009..21d867c 100644 --- a/src/arrow_parquet/parquet_writer.rs +++ b/src/arrow_parquet/parquet_writer.rs @@ -57,7 +57,8 @@ impl ParquetWriterContext { .set_created_by("pg_parquet".to_string()) .build(); - let schema = parse_arrow_schema_from_tupledesc(tupledesc); + let include_generated_columns = true; + let schema = parse_arrow_schema_from_tupledesc(tupledesc, include_generated_columns); let schema = Arc::new(schema); let parquet_writer = parquet_writer_from_uri(&uri, schema.clone(), writer_props); diff --git 
a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index 8dd79cf..9a94048 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use arrow::datatypes::{Field, Fields, Schema}; -use arrow_schema::FieldRef; +use arrow_cast::can_cast_types; +use arrow_schema::{DataType, FieldRef}; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, @@ -24,7 +25,8 @@ use crate::{ }; pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> String { - let arrow_schema = parse_arrow_schema_from_tupledesc(tupledesc); + let include_generated_columns = true; + let arrow_schema = parse_arrow_schema_from_tupledesc(tupledesc, include_generated_columns); let parquet_schema = arrow_to_parquet_schema(&arrow_schema) .unwrap_or_else(|e| panic!("failed to convert arrow schema to parquet schema: {}", e)); @@ -33,12 +35,14 @@ pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: &PgTupleDesc) -> S String::from_utf8(buf).unwrap_or_else(|e| panic!("failed to convert schema to string: {}", e)) } -pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: &PgTupleDesc) -> Schema { +pub(crate) fn parse_arrow_schema_from_tupledesc( + tupledesc: &PgTupleDesc, + include_generated_columns: bool, +) -> Schema { let mut field_id = 0; let mut struct_attribute_fields = vec![]; - let include_generated_columns = true; let attributes = collect_valid_attributes(tupledesc, include_generated_columns); for attribute in attributes { @@ -130,10 +134,12 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i child_fields.push(child_field); } + let nullable = true; + Field::new( elem_name, arrow::datatypes::DataType::Struct(Fields::from(child_fields)), - true, + nullable, ) .with_metadata(metadata) .into() @@ -159,10 
+165,12 @@ fn parse_list_schema(typoid: Oid, typmod: i32, array_name: &str, field_id: &mut parse_primitive_schema(typoid, typmod, array_name, field_id) }; + let nullable = true; + Field::new( array_name, arrow::datatypes::DataType::List(elem_field), - true, + nullable, ) .with_metadata(list_metadata) .into() @@ -177,13 +185,18 @@ fn parse_map_schema(typoid: Oid, typmod: i32, map_name: &str, field_id: &mut i32 *field_id += 1; let tupledesc = tuple_desc(typoid, typmod); + let entries_field = parse_struct_schema(tupledesc, map_name, field_id); let entries_field = adjust_map_entries_field(entries_field); + let keys_sorted = false; + + let nullable = true; + Field::new( map_name, - arrow::datatypes::DataType::Map(entries_field, false), - true, + arrow::datatypes::DataType::Map(entries_field, keys_sorted), + nullable, ) .with_metadata(map_metadata) .into() @@ -204,31 +217,33 @@ fn parse_primitive_schema( *field_id += 1; + let nullable = true; + let field = match typoid { - FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, true), - FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, true), - BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, true), - INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, true), - INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, true), - INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, true), + FLOAT4OID => Field::new(elem_name, arrow::datatypes::DataType::Float32, nullable), + FLOAT8OID => Field::new(elem_name, arrow::datatypes::DataType::Float64, nullable), + BOOLOID => Field::new(elem_name, arrow::datatypes::DataType::Boolean, nullable), + INT2OID => Field::new(elem_name, arrow::datatypes::DataType::Int16, nullable), + INT4OID => Field::new(elem_name, arrow::datatypes::DataType::Int32, nullable), + INT8OID => Field::new(elem_name, arrow::datatypes::DataType::Int64, nullable), NUMERICOID => { let (precision, scale) = 
extract_precision_and_scale_from_numeric_typmod(typmod); if should_write_numeric_as_text(precision) { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } else { Field::new( elem_name, arrow::datatypes::DataType::Decimal128(precision as _, scale as _), - true, + nullable, ) } } - DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, true), + DATEOID => Field::new(elem_name, arrow::datatypes::DataType::Date32, nullable), TIMESTAMPOID => Field::new( elem_name, arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - true, + nullable, ), TIMESTAMPTZOID => Field::new( elem_name, @@ -236,31 +251,31 @@ fn parse_primitive_schema( arrow::datatypes::TimeUnit::Microsecond, Some("+00:00".into()), ), - true, + nullable, ), TIMEOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ), TIMETZOID => Field::new( elem_name, arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), - true, + nullable, ) .with_metadata(HashMap::from_iter(vec![( "adjusted_to_utc".into(), "true".into(), )])), - CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, true), - BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, true), - OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, true), + CHAROID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + TEXTOID => Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable), + BYTEAOID => Field::new(elem_name, arrow::datatypes::DataType::Binary, nullable), + OIDOID => Field::new(elem_name, arrow::datatypes::DataType::UInt32, nullable), _ => { if is_postgis_geometry_type(typoid) { - Field::new(elem_name, arrow::datatypes::DataType::Binary, true) + Field::new(elem_name, 
arrow::datatypes::DataType::Binary, nullable) } else { - Field::new(elem_name, arrow::datatypes::DataType::Utf8, true) + Field::new(elem_name, arrow::datatypes::DataType::Utf8, nullable) } } }; @@ -289,60 +304,92 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { let key_field = fields.find("key").expect("expected key field").1; let value_field = fields.find("val").expect("expected val field").1; - not_nullable_key_field = - Field::new(key_field.name(), key_field.data_type().clone(), false) - .with_metadata(key_field.metadata().clone()); + let key_nullable = false; + + not_nullable_key_field = Field::new( + key_field.name(), + key_field.data_type().clone(), + key_nullable, + ) + .with_metadata(key_field.metadata().clone()); + + let value_nullable = true; - nullable_value_field = - Field::new(value_field.name(), value_field.data_type().clone(), true) - .with_metadata(value_field.metadata().clone()); + nullable_value_field = Field::new( + value_field.name(), + value_field.data_type().clone(), + value_nullable, + ) + .with_metadata(value_field.metadata().clone()); } _ => { panic!("expected struct data type for map entries") } }; + let entries_nullable = false; + let entries_field = Field::new( name, arrow::datatypes::DataType::Struct(Fields::from(vec![ not_nullable_key_field, nullable_value_field, ])), - false, + entries_nullable, ) .with_metadata(metadata); Arc::new(entries_field) } -pub(crate) fn ensure_arrow_schema_match_tupledesc( - file_schema: Arc, - tupledesc: &PgTupleDesc, -) { - let table_schema = parse_arrow_schema_from_tupledesc(tupledesc); +// ensure_arrow_schema_match_tupledesc_schema throws an error if the arrow schema does not match the table schema. +// If the arrow schema is castable to the table schema, it returns a vector of Option to cast to +// for each field. 
+pub(crate) fn ensure_arrow_schema_match_tupledesc_schema( + arrow_schema: Arc, + tupledesc_schema: Arc, +) -> Vec> { + let mut cast_to = Vec::new(); - for table_schema_field in table_schema.fields().iter() { - let table_schema_field_name = table_schema_field.name(); - let table_schema_field_type = table_schema_field.data_type(); + for tupledesc_field in tupledesc_schema.fields().iter() { + let field_name = tupledesc_field.name(); - let file_schema_field = file_schema.column_with_name(table_schema_field_name); + let arrow_field = arrow_schema.column_with_name(field_name); - if let Some(file_schema_field) = file_schema_field { - let file_schema_field_type = file_schema_field.1.data_type(); + if arrow_field.is_none() { + panic!("column \"{}\" is not found in parquet file", field_name); + } - if file_schema_field_type != table_schema_field_type { - panic!( - "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"", - table_schema_field_name, - table_schema_field_type, - file_schema_field_type, - ); - } - } else { + let (_, arrow_field) = arrow_field.unwrap(); + let arrow_field = Arc::new(arrow_field.clone()); + + let from_type = arrow_field.data_type(); + let to_type = tupledesc_field.data_type(); + + // no cast needed + if from_type == to_type { + cast_to.push(None); + continue; + } + + if !can_cast_types(from_type, to_type) { panic!( - "column \"{}\" is not found in parquet file", - table_schema_field_name + "type mismatch for column \"{}\" between table and parquet file. 
table expected \"{}\" but parquet file had \"{}\"", + field_name, + to_type, + from_type, ); } + + pgrx::debug2!( + "column \"{}\" is being cast from \"{}\" to \"{}\"", + field_name, + from_type, + to_type + ); + + cast_to.push(Some(to_type.clone())); } + + cast_to } diff --git a/src/lib.rs b/src/lib.rs index 57584bb..14c294a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,11 @@ pub extern "C" fn _PG_init() { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { + use std::fs::File; use std::io::Write; use std::marker::PhantomData; + use std::sync::Arc; + use std::vec; use std::{collections::HashMap, fmt::Debug}; use crate::arrow_parquet::compression::PgParquetCompression; @@ -48,6 +51,14 @@ mod tests { use crate::type_compat::pg_arrow_type_conversions::{ DEFAULT_UNBOUNDED_NUMERIC_PRECISION, DEFAULT_UNBOUNDED_NUMERIC_SCALE, }; + use arrow::array::{ + ArrayRef, Float32Array, Int16Array, Int32Array, LargeBinaryArray, LargeStringArray, + ListArray, RecordBatch, StructArray, UInt16Array, UInt32Array, UInt64Array, + }; + use arrow::buffer::{OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::UInt16Type; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use parquet::arrow::ArrowWriter; use pgrx::pg_sys::Oid; use pgrx::{ composite_type, @@ -340,6 +351,14 @@ mod tests { Spi::get_one(&query).unwrap().unwrap() } + fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) { + let file = File::create("/tmp/test.parquet").unwrap(); + let mut writer = ArrowWriter::try_new(file, schema, None).unwrap(); + + writer.write(&record_batch).unwrap(); + writer.close().unwrap(); + } + #[pg_test] fn test_int2() { let test_table = TestTable::::new("int2".into()); @@ -1391,6 +1410,403 @@ mod tests { Spi::run("DROP TYPE dog;").unwrap(); } + #[pg_test] + fn test_coerce_primitive_types() { + // INT16 => {int, bigint} + let x_nullable = false; + let y_nullable = true; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int16, 
x_nullable), + Field::new("y", DataType::Int16, y_nullable), + ])); + + let x = Arc::new(Int16Array::from(vec![1])); + let y = Arc::new(Int16Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // INT32 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)])); + + let x = Arc::new(Int32Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT32 => {double} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, true)])); + + let x = Arc::new(Float32Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x double precision)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value as f32, 1.123); + + let drop_table = "DROP 
TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT16 => {smallint, int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt16, true), + Field::new("y", DataType::UInt16, true), + Field::new("z", DataType::UInt16, true), + ])); + + let x = Arc::new(UInt16Array::from(vec![1])); + let y = Arc::new(UInt16Array::from(vec![2])); + let z = Arc::new(UInt16Array::from(vec![3])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y, z]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x smallint, y int, z bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = + Spi::get_three::("SELECT x, y, z FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2), Some(3))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT32 => {int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt32, true), + Field::new("y", DataType::UInt32, true), + ])); + + let x = Arc::new(UInt32Array::from(vec![1])); + let y = Arc::new(UInt32Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT64 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::UInt64, true)])); + + let x = Arc::new(UInt64Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + 
write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::<i64>("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeUtf8 => {text} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeUtf8, + true, + )])); + + let x = Arc::new(LargeStringArray::from(vec!["test"])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x text)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::<String>("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "test"); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeBinary => {bytea} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeBinary, + true, + )])); + + let x = Arc::new(LargeBinaryArray::from(vec!["abc".as_bytes()])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bytea)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::<Vec<u8>>("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "abc".as_bytes()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + fn test_coerce_list_types() { + let x_nullable = false; + let field_x = Field::new( + "x", + 
DataType::List(Field::new("item", DataType::UInt16, false).into()), + x_nullable, + ); + + let x = Arc::new(UInt16Array::from(vec![1, 2])); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 2])); + let x = Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::UInt16, false)), + offsets, + x, + None, + )); + + let y_nullable = true; + let field_y = Field::new( + "y", + DataType::List(Field::new("item", DataType::UInt16, true).into()), + y_nullable, + ); + + let y = Arc::new(ListArray::from_iter_primitive::<UInt16Type, _, _>(vec![ + Some(vec![Some(3), Some(4)]), + ])); + + let schema = Arc::new(Schema::new(vec![field_x, field_y])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int[], y bigint[])"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::<Vec<Option<i32>>, Vec<Option<i64>>>( + "SELECT x, y FROM test_table LIMIT 1", + ) + .unwrap(); + assert_eq!( + value, + (Some(vec![Some(1), Some(2)]), Some(vec![Some(3), Some(4)])) + ); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + fn test_coerce_struct_types() { + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Struct( + vec![ + Field::new("a", DataType::UInt16, false), + Field::new("b", DataType::UInt16, false), + ] + .into(), + ), + false, + )])); + + let a: ArrayRef = Arc::new(UInt16Array::from(vec![Some(1)])); + let b: ArrayRef = Arc::new(UInt16Array::from(vec![Some(2)])); + + let x = Arc::new(StructArray::try_from(vec![("a", a), ("b", b)]).unwrap()); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_type = "CREATE TYPE test_type AS (a int, b bigint)"; + Spi::run(create_type).unwrap(); + + let create_table = "CREATE TABLE test_table (x test_type)"; + 
Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = + Spi::get_two::<i32, i64>("SELECT (x).a, (x).b FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + + #[pg_test] + #[should_panic(expected = "violates not-null constraint")] + fn test_copy_not_null_table() { + let create_table = "CREATE TABLE test_table (x int NOT NULL)"; + Spi::run(create_table).unwrap(); + + // first copy non-null value to file + let copy_to = "COPY (SELECT 1 as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_one::<i32>("SELECT x FROM test_table") + .unwrap() + .unwrap(); + assert_eq!(result, 1); + + // then copy null value to file + let copy_to = "COPY (SELECT NULL::int as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + // this should panic + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + + #[pg_test] + fn test_table_with_different_field_position() { + let copy_to = "COPY (SELECT 1 as x, 'hello' as y) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (y text, x int)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_two::<&str, i32>("SELECT y, x FROM test_table LIMIT 1").unwrap(); + assert_eq!(result, (Some("hello"), Some(1))); + } + + #[pg_test] + #[should_panic(expected = "Cannot cast string 'hello' to value of Int32 type")] + fn test_copy_composite_with_different_field_position() { + let create_type = "CREATE TYPE test_type AS (y int, x text)"; + Spi::run(create_type).unwrap(); + + let copy_to = "COPY (SELECT ROW(1, 'hello')::test_type as x) TO 
'/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_another_type = "CREATE TYPE another_test_type AS (x text, y int)"; + Spi::run(create_another_type).unwrap(); + + let create_table = "CREATE TABLE test_table (x another_test_type)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + + #[pg_test] + #[should_panic(expected = "type mismatch for column \"x\" between table and parquet file.")] + fn test_copy_composite_with_less_fields() { + let create_type = "CREATE TYPE test_type AS (x text)"; + Spi::run(create_type).unwrap(); + + let copy_to = "COPY (SELECT ROW('hello')::test_type as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_another_type = "CREATE TYPE another_test_type AS (x text, y int)"; + Spi::run(create_another_type).unwrap(); + + let create_table = "CREATE TABLE test_table (x another_test_type)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + + #[pg_test] + #[should_panic(expected = "type mismatch for column \"x\" between table and parquet file.")] + fn test_copy_composite_with_extra_field() { + let create_type = "CREATE TYPE test_type AS (x text, y int)"; + Spi::run(create_type).unwrap(); + + let copy_to = "COPY (SELECT ROW('hello', 1)::test_type as x) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_another_type = "CREATE TYPE another_test_type AS (x text)"; + Spi::run(create_another_type).unwrap(); + + let create_table = "CREATE TABLE test_table (x another_test_type)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + #[pg_test] fn test_copy_with_empty_options() { let test_table = TestTable::<i32>::new("int4".into())