From 6a2be26e7b1450796593b910b622eb79433c258e Mon Sep 17 00:00:00 2001 From: Aykut Bozkurt Date: Thu, 26 Sep 2024 14:24:12 +0300 Subject: [PATCH] verify parquet file schema before copying the file to Postgres table, which both prevents unnecessary work and gives more friendly errors --- src/arrow_parquet/parquet_reader.rs | 18 ++++++++--- src/arrow_parquet/schema_visitor.rs | 36 +++++++++++++++++++-- src/lib.rs | 49 +++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 8 deletions(-) diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index 42bef79..754d032 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -17,7 +17,9 @@ use crate::{ type_compat::{geometry::reset_postgis_context, map::reset_crunchy_map_context}, }; -use super::uri_utils::parquet_reader_from_uri; +use super::{ + schema_visitor::ensure_arrow_schema_match_tupledesc, uri_utils::parquet_reader_from_uri, +}; pub(crate) struct ParquetReaderContext { buffer: Vec, @@ -35,6 +37,8 @@ impl ParquetReaderContext { reset_postgis_context(); reset_crunchy_map_context(); + let pgtupledesc = unsafe { PgTupleDesc::from_pg_copy(tupledesc) }; + let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -42,7 +46,10 @@ impl ParquetReaderContext { let parquet_reader = runtime.block_on(parquet_reader_from_uri(&uri)); - let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc); + let file_schema = parquet_reader.schema(); + ensure_arrow_schema_match_tupledesc(file_schema.clone(), &pgtupledesc); + + let binary_out_funcs = Self::collect_binary_out_funcs(&pgtupledesc); ParquetReaderContext { buffer: Vec::new(), @@ -56,13 +63,14 @@ impl ParquetReaderContext { } } - fn collect_binary_out_funcs(tupledesc: TupleDesc) -> Vec> { + fn collect_binary_out_funcs( + tupledesc: &PgTupleDesc, + ) -> Vec> { unsafe { let mut binary_out_funcs = vec![]; let include_generated_columns = false; - let tupledesc = PgTupleDesc::from_pg_copy(tupledesc); - let attributes = collect_valid_attributes(&tupledesc, include_generated_columns); + let attributes = collect_valid_attributes(tupledesc, include_generated_columns); for att in attributes.iter() { let typoid = att.type_oid(); diff --git a/src/arrow_parquet/schema_visitor.rs b/src/arrow_parquet/schema_visitor.rs index 6cbe6fa..8321217 100644 --- a/src/arrow_parquet/schema_visitor.rs +++ b/src/arrow_parquet/schema_visitor.rs @@ -5,7 +5,7 @@ use arrow_schema::FieldRef; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, - NUMERICOID, OIDOID, RECORDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, + NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, }; use pgrx::{prelude::*, PgTupleDesc}; @@ -35,8 +35,6 @@ pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: PgTupleDesc) -> St } pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: PgTupleDesc) -> Schema { - debug_assert!(tupledesc.oid() == RECORDOID); - let mut field_id = 0; let mut struct_attribute_fields = vec![]; @@ -288,3 +286,35 @@ fn to_not_nullable_field(field: FieldRef) -> FieldRef { let field = Field::new(name, data_type.clone(), false).with_metadata(metadata); Arc::new(field) } + +pub(crate) fn ensure_arrow_schema_match_tupledesc( + file_schema: Arc, + tupledesc: &PgTupleDesc, +) { + let table_schema = parse_arrow_schema_from_tupledesc(tupledesc.clone()); + + for table_schema_field in table_schema.fields().iter() { + let table_schema_field_name = table_schema_field.name(); + let table_schema_field_type = table_schema_field.data_type(); + + let file_schema_field = file_schema.column_with_name(table_schema_field_name); + + if let Some(file_schema_field) = file_schema_field { + let file_schema_field_type = file_schema_field.1.data_type(); + + if file_schema_field_type != table_schema_field_type { + panic!( + "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"", + table_schema_field_name, + table_schema_field_type, + file_schema_field_type, + ); + } + } else { + panic!( + "column \"{}\" is not found in parquet file", + table_schema_field_name + ); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 452ab44..bb27e10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2184,6 +2184,55 @@ mod tests { let parquet_metadata_command = "select * from parquet.metadata('/tmp/test.parquet');"; Spi::run(parquet_metadata_command).unwrap(); } + + #[pg_test] + #[should_panic( + expected = "type mismatch for column \"location\" between table and parquet file" + )] + fn test_type_mismatch_between_parquet_and_table() { + let create_types = "create type dog as (name text, age int); + create type person as (id bigint, name text, dogs dog[]); + create type address as (loc text);"; + Spi::run(create_types).unwrap(); + + let create_correct_table = + "create table factory_correct (id bigint, workers person[], name text, location address);"; + Spi::run(create_correct_table).unwrap(); + + let create_wrong_table = + "create table factory_wrong (id bigint, workers person[], name text, location int);"; + Spi::run(create_wrong_table).unwrap(); + + let copy_to_parquet = "copy (select 1::int8 as id, + array[ + row(1, 'ali', array[row('lady', 4), NULL]::dog[]) + ]::person[] as workers, + 'Microsoft' as name, + row('istanbul')::address as location + from generate_series(1,10) i) to '/tmp/test.parquet';"; + Spi::run(copy_to_parquet).unwrap(); + + // copy to correct table which matches the parquet schema + let copy_to_correct_table = "copy factory_correct from '/tmp/test.parquet';"; + Spi::run(copy_to_correct_table).unwrap(); + + // copy to wrong table which does not match the parquet schema + let copy_to_wrong_table = "copy factory_wrong from '/tmp/test.parquet';"; + Spi::run(copy_to_wrong_table).unwrap(); + } + + #[pg_test] + #[should_panic(expected = "column \"name\" is not found in parquet file")] + fn test_missing_column_in_parquet() { + let create_table = "create table test(id int, name text);"; + Spi::run(create_table).unwrap(); + + let copy_to_parquet = "copy (select 100 as id) to '/tmp/test.parquet';"; + Spi::run(copy_to_parquet).unwrap(); + + let copy_to_table = "copy test from '/tmp/test.parquet';"; + Spi::run(copy_to_table).unwrap(); + } } /// This module is required by `cargo pgrx test` invocations.