From 4241caef4129cc7983ea0f5c1c1f6dc864e82880 Mon Sep 17 00:00:00 2001
From: Aykut Bozkurt
Date: Tue, 26 Nov 2024 23:07:41 +0300
Subject: [PATCH 1/2] Match fields by name via option

We add a `COPY FROM` option called `match_by_name`, which matches
Parquet file fields to PostgreSQL table columns by their names rather
than by their order in the schema. The option is `false` by default.
It is useful when the field order differs between the Parquet file
and the table, but the names match.

**!!IMPORTANT!!**: This is a breaking change. Before this PR, we always
matched by name. That is rather strict and not the common way to match
schemas (e.g. PostgreSQL's `COPY FROM` for csv and DuckDB's `COPY FROM`
both match by field position by default). This is why we now match by
position by default and provide the `COPY FROM` option `match_by_name`,
which can be set to `true` for the old behaviour.

Closes #39.
---
 README.md                           |  1 +
 src/arrow_parquet/arrow_to_pg.rs    |  2 +-
 src/arrow_parquet/parquet_reader.rs | 36 ++++++++++++++----
 src/arrow_parquet/pg_to_arrow.rs    |  2 +-
 src/arrow_parquet/schema_parser.rs  | 59 ++++++++++++++++++++++++-----
 src/parquet_copy_hook/copy_from.rs  |  8 ++--
 src/parquet_copy_hook/copy_utils.rs | 23 ++++++++---
 src/pgrx_tests/copy_from_coerce.rs  | 40 +++++++++++++++++--
 src/pgrx_tests/copy_pg_rules.rs     | 17 ++++++++-
 src/pgrx_utils.rs                   | 10 +++--
 10 files changed, 163 insertions(+), 35 deletions(-)

diff --git a/README.md b/README.md
index cd4c8c9..6e93a20 100644
--- a/README.md
+++ b/README.md
@@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting pos
 `pg_parquet` supports the following options in the `COPY FROM` command:
 - `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension,
+- `match_by_name`: matches Parquet file fields to PostgreSQL table columns by their name instead of by their position in the schema (the default). The option is `false` by default. It is useful when the field order differs between the Parquet file and the table, but the names match.
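A minimal usage sketch of the new option, mirroring the test added later in this patch (the file path is hypothetical):

```sql
-- Write a file whose field order is (x, y):
COPY (SELECT 1 AS x, 'hello' AS y) TO '/tmp/test.parquet';

-- The table declares the same columns in a different order:
CREATE TABLE test_table (y text, x int);

-- Default positional matching would pair y with the file's first field;
-- match_by_name pairs fields by column name instead:
COPY test_table FROM '/tmp/test.parquet' WITH (match_by_name true);
```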
 
 ## Configuration
 There is currently only one GUC parameter to enable/disable the `pg_parquet`:
diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs
index 079ee08..2aa8d0e 100644
--- a/src/arrow_parquet/arrow_to_pg.rs
+++ b/src/arrow_parquet/arrow_to_pg.rs
@@ -163,7 +163,7 @@ impl ArrowToPgAttributeContext {
         };
 
         let attributes =
-            collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc);
+            collect_attributes_for(CollectAttributesFor::Other, attribute_tupledesc);
 
         // we only cast the top-level attributes, which already covers the nested attributes
         let cast_to_types = None;
diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs
index cbeff07..e3366ae 100644
--- a/src/arrow_parquet/parquet_reader.rs
+++ b/src/arrow_parquet/parquet_reader.rs
@@ -16,7 +16,11 @@ use url::Url;
 
 use crate::{
     arrow_parquet::{
-        arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes,
+        arrow_to_pg::to_pg_datum,
+        schema_parser::{
+            error_if_copy_from_match_by_position_with_generated_columns,
+            parquet_schema_string_from_attributes,
+        },
     },
     pgrx_utils::{collect_attributes_for, CollectAttributesFor},
     type_compat::{geometry::reset_postgis_context, map::reset_map_context},
@@ -38,15 +42,18 @@ pub(crate) struct ParquetReaderContext {
     parquet_reader: ParquetRecordBatchStream<ParquetObjectReader>,
     attribute_contexts: Vec<ArrowToPgAttributeContext>,
     binary_out_funcs: Vec<PgBox<FmgrInfo>>,
+    match_by_name: bool,
 }
 
 impl ParquetReaderContext {
-    pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self {
+    pub(crate) fn new(uri: Url, match_by_name: bool, tupledesc: &PgTupleDesc) -> Self {
         // Postgis and Map contexts are used throughout reading the parquet file.
         // We need to reset them to avoid reading stale data. (e.g. extension could be dropped)
         reset_postgis_context();
         reset_map_context();
 
+        error_if_copy_from_match_by_position_with_generated_columns(tupledesc, match_by_name);
+
         let parquet_reader = parquet_reader_from_uri(&uri);
 
         let parquet_file_schema = parquet_reader.schema();
@@ -69,6 +76,7 @@ impl ParquetReaderContext {
             parquet_file_schema.clone(),
             tupledesc_schema.clone(),
             &attributes,
+            match_by_name,
         );
 
         let attribute_contexts = collect_arrow_to_pg_attribute_contexts(
@@ -85,6 +93,7 @@ impl ParquetReaderContext {
             attribute_contexts,
             parquet_reader,
             binary_out_funcs,
+            match_by_name,
             started: false,
             finished: false,
         }
@@ -116,15 +125,23 @@ impl ParquetReaderContext {
     fn record_batch_to_tuple_datums(
         record_batch: RecordBatch,
         attribute_contexts: &[ArrowToPgAttributeContext],
+        match_by_name: bool,
     ) -> Vec<Option<Datum>> {
         let mut datums = vec![];
 
-        for attribute_context in attribute_contexts {
+        for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() {
             let name = attribute_context.name();
 
-            let column_array = record_batch
-                .column_by_name(name)
-                .unwrap_or_else(|| panic!("column {} not found", name));
+            let column_array = if match_by_name {
+                record_batch
+                    .column_by_name(name)
+                    .unwrap_or_else(|| panic!("column {} not found", name))
+            } else {
+                record_batch
+                    .columns()
+                    .get(attribute_idx)
+                    .unwrap_or_else(|| panic!("column {} not found", name))
+            };
 
             let datum = if attribute_context.needs_cast() {
                 // should fail instead of returning None if the cast fails at runtime
@@ -181,8 +198,11 @@ impl ParquetReaderContext {
         self.buffer.extend_from_slice(&attnum_len_bytes);
 
         // convert the columnar arrays in record batch to tuple datums
-        let tuple_datums =
-            Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts);
+        let tuple_datums = Self::record_batch_to_tuple_datums(
+            record_batch,
+            &self.attribute_contexts,
+            self.match_by_name,
+        );
 
         // write the tuple datums to the ParquetReader's internal buffer in PG copy format
         for (datum, out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter())
diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs
index 530c7f7..17d774b 100644
--- a/src/arrow_parquet/pg_to_arrow.rs
+++ b/src/arrow_parquet/pg_to_arrow.rs
@@ -148,7 +148,7 @@ impl PgToArrowAttributeContext {
             };
 
             let attributes =
-                collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc);
+                collect_attributes_for(CollectAttributesFor::Other, &attribute_tupledesc);
 
             collect_pg_to_arrow_attribute_contexts(&attributes, &fields)
         });
diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs
index b76ee70..b4c5798 100644
--- a/src/arrow_parquet/schema_parser.rs
+++ b/src/arrow_parquet/schema_parser.rs
@@ -16,7 +16,7 @@ use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc};
 use crate::{
     pgrx_utils::{
         array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type,
-        is_composite_type, tuple_desc, CollectAttributesFor,
+        is_composite_type, is_generated_attribute, tuple_desc, CollectAttributesFor,
     },
     type_compat::{
         geometry::is_postgis_geometry_type,
@@ -95,7 +95,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i
 
     let mut child_fields: Vec<Arc<Field>> = vec![];
 
-    let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc);
+    let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc);
 
     for attribute in attributes {
         if attribute.is_dropped() {
@@ -342,6 +342,30 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef {
     Arc::new(entries_field)
 }
 
+pub(crate) fn error_if_copy_from_match_by_position_with_generated_columns(
+    tupledesc: &PgTupleDesc,
+    match_by_name: bool,
+) {
+    // match_by_name can handle generated columns
+    if match_by_name {
+        return;
+    }
+
+    let attributes = collect_attributes_for(CollectAttributesFor::Other, tupledesc);
+
+    for attribute in attributes {
+        if is_generated_attribute(&attribute) {
+            ereport!(
+                PgLogLevel::ERROR,
+                PgSqlErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED,
+                "COPY FROM parquet with generated columns is not supported",
+                "Try COPY FROM parquet WITH (match_by_name true). \
+                 It only works if the column names match the parquet file's.",
+            );
+        }
+    }
+}
+
 // ensure_file_schema_match_tupledesc_schema throws an error if the file's schema does not match the table schema.
 // If the file's arrow schema is castable to the table's arrow schema, it returns a vector of Option<DataType>
 // to cast to for each field.
@@ -349,21 +373,38 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
     file_schema: Arc<Schema>,
     tupledesc_schema: Arc<Schema>,
     attributes: &[FormData_pg_attribute],
+    match_by_name: bool,
 ) -> Vec<Option<DataType>> {
     let mut cast_to_types = Vec::new();
 
+    if !match_by_name && tupledesc_schema.fields().len() != file_schema.fields().len() {
+        panic!(
+            "column count mismatch between table and parquet file. \
+            parquet file has {} columns, but table has {} columns",
+            file_schema.fields().len(),
+            tupledesc_schema.fields().len()
+        );
+    }
+
     for (tupledesc_schema_field, attribute) in
         tupledesc_schema.fields().iter().zip(attributes.iter())
     {
         let field_name = tupledesc_schema_field.name();
 
-        let file_schema_field = file_schema.column_with_name(field_name);
+        let file_schema_field = if match_by_name {
+            let file_schema_field = file_schema.column_with_name(field_name);
 
-        if file_schema_field.is_none() {
-            panic!("column \"{}\" is not found in parquet file", field_name);
-        }
+            if file_schema_field.is_none() {
+                panic!("column \"{}\" is not found in parquet file", field_name);
+            }
+
+            let (_, file_schema_field) = file_schema_field.unwrap();
+
+            file_schema_field
+        } else {
+            file_schema.field(attribute.attnum as usize - 1)
+        };
 
-        let (_, file_schema_field) = file_schema_field.unwrap();
         let file_schema_field = Arc::new(file_schema_field.clone());
 
         let from_type = file_schema_field.data_type();
@@ -378,7 +419,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
         if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) {
             panic!(
                 "type mismatch for column \"{}\" between table and parquet file.\n\n\
-                table has \"{}\"\n\nparquet file has \"{}\"",
+                 table has \"{}\"\n\nparquet file has \"{}\"",
                 field_name, to_type, from_type
             );
         }
@@ -413,7 +454,7 @@ fn is_coercible(from_type: &DataType, to_type: &DataType, to_typoid: Oid, to_typ
 
     let tupledesc = tuple_desc(to_typoid, to_typmod);
 
-    let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc);
+    let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc);
 
     for (from_field, (to_field, to_attribute)) in from_fields
         .iter()
diff --git a/src/parquet_copy_hook/copy_from.rs b/src/parquet_copy_hook/copy_from.rs
index bf3a878..6a44446 100644
--- a/src/parquet_copy_hook/copy_from.rs
+++ b/src/parquet_copy_hook/copy_from.rs
@@ -20,8 +20,8 @@ use crate::{
 };
 
 use super::copy_utils::{
-    copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state,
-    create_filtered_tupledesc_for_relation,
+    copy_from_stmt_match_by_name, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
+    copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation,
 };
 
 // stack to store parquet reader contexts for COPY FROM.
@@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from(
 
     let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation);
 
+    let match_by_name = copy_from_stmt_match_by_name(p_stmt);
+
     unsafe {
         // parquet reader context is used throughout the COPY FROM operation.
-        let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc);
+        let parquet_reader_context = ParquetReaderContext::new(uri, match_by_name, &tupledesc);
         push_parquet_reader_context(parquet_reader_context);
 
         // makes sure to set binary format
diff --git a/src/parquet_copy_hook/copy_utils.rs b/src/parquet_copy_hook/copy_utils.rs
index 068e95e..02c3c5d 100644
--- a/src/parquet_copy_hook/copy_utils.rs
+++ b/src/parquet_copy_hook/copy_utils.rs
@@ -3,11 +3,12 @@ use std::{ffi::CStr, str::FromStr};
 use pgrx::{
     is_a,
     pg_sys::{
-        addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name,
-        get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier,
-        AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node,
-        NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment,
-        RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry,
+        addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString,
+        get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate,
+        quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc,
+        DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState,
+        PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock,
+        TupleDescInitEntry,
     },
     PgBox, PgList, PgRelation, PgTupleDesc,
 };
@@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
 }
 
 pub(crate) fn validate_copy_from_options(p_stmt: &PgBox<PlannedStmt>) {
-    validate_copy_option_names(p_stmt, &["format", "freeze"]);
+    validate_copy_option_names(p_stmt, &["format", "match_by_name", "freeze"]);
 
     let format_option = copy_stmt_get_option(p_stmt, "format");
 
@@ -253,6 +254,16 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox<PlannedStmt>) ->
     new_copy_options
 }
 
+pub(crate) fn copy_from_stmt_match_by_name(p_stmt: &PgBox<PlannedStmt>) -> bool {
+    let match_by_name_option = copy_stmt_get_option(p_stmt, "match_by_name");
+
+    if match_by_name_option.is_null() {
+        false
+    } else {
+        unsafe { defGetBoolean(match_by_name_option.as_ptr()) }
+    }
+}
+
 pub(crate) fn copy_stmt_get_option(
     p_stmt: &PgBox<PlannedStmt>,
     option_name: &str,
diff --git a/src/pgrx_tests/copy_from_coerce.rs b/src/pgrx_tests/copy_from_coerce.rs
index 75c7af8..ed0ebb9 100644
--- a/src/pgrx_tests/copy_from_coerce.rs
+++ b/src/pgrx_tests/copy_from_coerce.rs
@@ -966,7 +966,7 @@ mod tests {
     }
 
     #[pg_test]
-    fn test_table_with_different_field_position() {
+    fn test_table_with_different_position_match_by_name() {
         let copy_to = format!(
             "COPY (SELECT 1 as x, 'hello' as y) TO '{}'",
             LOCAL_TEST_FILE_PATH
@@ -976,13 +976,44 @@
         let create_table = "CREATE TABLE test_table (y text, x int)";
         Spi::run(create_table).unwrap();
 
-        let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH);
+        let copy_from = format!(
+            "COPY test_table FROM '{}' WITH (match_by_name true)",
+            LOCAL_TEST_FILE_PATH
+        );
         Spi::run(&copy_from).unwrap();
 
         let result = Spi::get_two::<&str, i32>("SELECT y, x FROM test_table LIMIT 1").unwrap();
         assert_eq!(result, (Some("hello"), Some(1)));
     }
 
+    #[pg_test]
+    fn test_table_with_different_name_match_by_position() {
+        let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
+        Spi::run(copy_to).unwrap();
+
+        let create_table = "CREATE TABLE test_table (x bigint, y varchar)";
+        Spi::run(create_table).unwrap();
+
+        let copy_from = "COPY test_table FROM '/tmp/test.parquet'";
'/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(result, (Some(1), Some("hello"))); + } + + #[pg_test] + #[should_panic(expected = "column count mismatch between table and parquet file")] + fn test_table_with_different_name_match_by_position_fail() { + let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + #[pg_test] #[should_panic(expected = "column \"name\" is not found in parquet file")] fn test_missing_column_in_parquet() { @@ -992,7 +1023,10 @@ mod tests { let copy_to_parquet = format!("copy (select 100 as id) to '{}';", LOCAL_TEST_FILE_PATH); Spi::run(©_to_parquet).unwrap(); - let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH); + let copy_from = format!( + "COPY test_table FROM '{}' with (match_by_name true)", + LOCAL_TEST_FILE_PATH + ); Spi::run(©_from).unwrap(); } diff --git a/src/pgrx_tests/copy_pg_rules.rs b/src/pgrx_tests/copy_pg_rules.rs index b0347e2..35c44d6 100644 --- a/src/pgrx_tests/copy_pg_rules.rs +++ b/src/pgrx_tests/copy_pg_rules.rs @@ -101,6 +101,21 @@ mod tests { Spi::run(©_from).unwrap(); } + #[pg_test] + #[should_panic(expected = "COPY FROM parquet with generated columns is not supported")] + fn test_copy_from_by_position_with_generated_columns_not_supported() { + Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); + + Spi::run("CREATE TABLE test_table (a int, b int generated always as (10) stored, c text);") + .unwrap(); + + let copy_from_query = format!( + "COPY test_table FROM '{}' WITH (format parquet);", + LOCAL_TEST_FILE_PATH + ); + Spi::run(copy_from_query.as_str()).unwrap(); + } + #[pg_test] fn test_with_generated_and_dropped_columns() { Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); @@ -123,7 +138,7 @@ mod tests { Spi::run("TRUNCATE test_table;").unwrap(); let copy_from_query = format!( - "COPY test_table FROM '{}' WITH (format parquet);", + "COPY test_table FROM '{}' WITH (format parquet, match_by_name true);", LOCAL_TEST_FILE_PATH ); Spi::run(copy_from_query.as_str()).unwrap(); diff --git a/src/pgrx_utils.rs b/src/pgrx_utils.rs index cb8f9fa..0793def 100644 --- a/src/pgrx_utils.rs +++ b/src/pgrx_utils.rs @@ -12,7 +12,7 @@ use pgrx::{ pub(crate) enum CollectAttributesFor { CopyFrom, CopyTo, - Struct, + Other, } // collect_attributes_for collects not-dropped attributes from the tuple descriptor. 
@@ -23,7 +23,7 @@ pub(crate) fn collect_attributes_for(
 ) -> Vec<FormData_pg_attribute> {
     let include_generated_columns = match copy_operation {
         CollectAttributesFor::CopyFrom => false,
-        CollectAttributesFor::CopyTo | CollectAttributesFor::Struct => true,
+        CollectAttributesFor::CopyTo | CollectAttributesFor::Other => true,
     };
 
     let mut attributes = vec![];
@@ -35,7 +35,7 @@ pub(crate) fn collect_attributes_for(
             continue;
         }
 
-        if !include_generated_columns && attribute.attgenerated != 0 {
+        if !include_generated_columns && is_generated_attribute(attribute) {
             continue;
         }
 
@@ -55,6 +55,10 @@ pub(crate) fn collect_attributes_for(
     attributes
 }
 
+pub(crate) fn is_generated_attribute(attribute: &FormData_pg_attribute) -> bool {
+    attribute.attgenerated != 0
+}
+
 pub(crate) fn tuple_desc(typoid: Oid, typmod: i32) -> PgTupleDesc<'static> {
     let tupledesc = unsafe { lookup_rowtype_tupdesc(typoid, typmod) };
     unsafe { PgTupleDesc::from_pg(tupledesc) }

From 95416434302fa4b914a788e0e91e4ad75578b34f Mon Sep 17 00:00:00 2001
From: Aykut Bozkurt
Date: Wed, 27 Nov 2024 22:54:30 +0300
Subject: [PATCH 2/2] Rename option to match_by

The option now takes a matching method as its value: `position`
(the default) or `name`. This replaces the boolean `match_by_name`
option from the previous commit.
---
 README.md                           |  2 +-
 src/arrow_parquet.rs                |  1 +
 src/arrow_parquet/match_by.rs       | 20 ++++++++++++++
 src/arrow_parquet/parquet_reader.rs | 29 ++++++++++-----------
 src/arrow_parquet/schema_parser.rs  | 36 +++++++++++++++----------
 src/parquet_copy_hook/copy_from.rs  |  6 ++---
 src/parquet_copy_hook/copy_utils.rs | 32 +++++++++++++---------
 src/pgrx_tests/copy_from_coerce.rs  |  6 ++---
 src/pgrx_tests/copy_options.rs      | 15 ++++++++++++
 src/pgrx_tests/copy_pg_rules.rs     |  2 +-
 10 files changed, 100 insertions(+), 49 deletions(-)
 create mode 100644 src/arrow_parquet/match_by.rs

diff --git a/README.md b/README.md
index 6e93a20..353b01f 100644
--- a/README.md
+++ b/README.md
@@ -193,7 +193,7 @@ Alternatively, you can use the following environment variables when starting pos
 `pg_parquet` supports the following options in the `COPY FROM` command:
 - `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension,
-- `match_by_name`: matches Parquet file fields to PostgreSQL table columns by their name instead of by their position in the schema (the default). The option is `false` by default. It is useful when the field order differs between the Parquet file and the table, but the names match.
+- `match_by <string>`: method to match Parquet file fields to PostgreSQL table columns. The available methods are `position` (the default) and `name`. Matching by `name` is useful when the field order differs between the Parquet file and the table, but the names match.
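The renamed option in use, against the same hypothetical table and file as the earlier sketch (the value is a string; `position` is the default):

```sql
-- Equivalent to the default positional behaviour:
COPY test_table FROM '/tmp/test.parquet' WITH (match_by 'position');

-- Pair fields by column name instead:
COPY test_table FROM '/tmp/test.parquet' WITH (match_by 'name');
```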
 
 ## Configuration
 There is currently only one GUC parameter to enable/disable the `pg_parquet`:
diff --git a/src/arrow_parquet.rs b/src/arrow_parquet.rs
index 7e445a4..e6ca8b0 100644
--- a/src/arrow_parquet.rs
+++ b/src/arrow_parquet.rs
@@ -1,6 +1,7 @@
 pub(crate) mod arrow_to_pg;
 pub(crate) mod arrow_utils;
 pub(crate) mod compression;
+pub(crate) mod match_by;
 pub(crate) mod parquet_reader;
 pub(crate) mod parquet_writer;
 pub(crate) mod pg_to_arrow;
diff --git a/src/arrow_parquet/match_by.rs b/src/arrow_parquet/match_by.rs
new file mode 100644
index 0000000..f115b56
--- /dev/null
+++ b/src/arrow_parquet/match_by.rs
@@ -0,0 +1,20 @@
+use std::str::FromStr;
+
+#[derive(Debug, Clone, Copy, Default, PartialEq)]
+pub(crate) enum MatchBy {
+    #[default]
+    Position,
+    Name,
+}
+
+impl FromStr for MatchBy {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "position" => Ok(MatchBy::Position),
+            "name" => Ok(MatchBy::Name),
+            _ => Err(format!("unrecognized match_by method: {}", s)),
+        }
+    }
+}
diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs
index e3366ae..9a2cf1d 100644
--- a/src/arrow_parquet/parquet_reader.rs
+++ b/src/arrow_parquet/parquet_reader.rs
@@ -28,6 +28,7 @@ use crate::{
 
 use super::{
     arrow_to_pg::{collect_arrow_to_pg_attribute_contexts, ArrowToPgAttributeContext},
+    match_by::MatchBy,
     schema_parser::{
         ensure_file_schema_match_tupledesc_schema, parse_arrow_schema_from_attributes,
     },
@@ -42,17 +43,17 @@ pub(crate) struct ParquetReaderContext {
     parquet_reader: ParquetRecordBatchStream<ParquetObjectReader>,
     attribute_contexts: Vec<ArrowToPgAttributeContext>,
     binary_out_funcs: Vec<PgBox<FmgrInfo>>,
-    match_by_name: bool,
+    match_by: MatchBy,
 }
 
 impl ParquetReaderContext {
-    pub(crate) fn new(uri: Url, match_by_name: bool, tupledesc: &PgTupleDesc) -> Self {
+    pub(crate) fn new(uri: Url, match_by: MatchBy, tupledesc: &PgTupleDesc) -> Self {
         // Postgis and Map contexts are used throughout reading the parquet file.
         // We need to reset them to avoid reading stale data. (e.g. extension could be dropped)
         reset_postgis_context();
         reset_map_context();
 
-        error_if_copy_from_match_by_position_with_generated_columns(tupledesc, match_by_name);
+        error_if_copy_from_match_by_position_with_generated_columns(tupledesc, match_by);
 
         let parquet_reader = parquet_reader_from_uri(&uri);
 
@@ -76,7 +77,7 @@ impl ParquetReaderContext {
             parquet_file_schema.clone(),
             tupledesc_schema.clone(),
             &attributes,
-            match_by_name,
+            match_by,
         );
 
         let attribute_contexts = collect_arrow_to_pg_attribute_contexts(
@@ -94,7 +94,7 @@ impl ParquetReaderContext {
             attribute_contexts,
             parquet_reader,
             binary_out_funcs,
-            match_by_name,
+            match_by,
             started: false,
             finished: false,
         }
@@ -125,22 +126,22 @@ impl ParquetReaderContext {
     fn record_batch_to_tuple_datums(
         record_batch: RecordBatch,
         attribute_contexts: &[ArrowToPgAttributeContext],
-        match_by_name: bool,
+        match_by: MatchBy,
     ) -> Vec<Option<Datum>> {
         let mut datums = vec![];
 
         for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() {
             let name = attribute_context.name();
 
-            let column_array = if match_by_name {
-                record_batch
-                    .column_by_name(name)
-                    .unwrap_or_else(|| panic!("column {} not found", name))
-            } else {
-                record_batch
+            let column_array = match match_by {
+                MatchBy::Position => record_batch
                     .columns()
                     .get(attribute_idx)
-                    .unwrap_or_else(|| panic!("column {} not found", name))
+                    .unwrap_or_else(|| panic!("column {} not found", name)),
+
+                MatchBy::Name => record_batch
+                    .column_by_name(name)
+                    .unwrap_or_else(|| panic!("column {} not found", name)),
             };
 
             let datum = if attribute_context.needs_cast() {
@@ -201,7 +202,7 @@ impl ParquetReaderContext {
         let tuple_datums = Self::record_batch_to_tuple_datums(
             record_batch,
             &self.attribute_contexts,
-            self.match_by_name,
+            self.match_by,
         );
 
         // write the tuple datums to the ParquetReader's internal buffer in PG copy format
diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs
index b4c5798..f100e4e 100644
--- a/src/arrow_parquet/schema_parser.rs
+++ b/src/arrow_parquet/schema_parser.rs
@@ -27,6 +27,8 @@ use crate::{
     },
 };
 
+use super::match_by::MatchBy;
+
 pub(crate) fn parquet_schema_string_from_attributes(
     attributes: &[FormData_pg_attribute],
 ) -> String {
@@ -344,10 +346,10 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef {
 
 pub(crate) fn error_if_copy_from_match_by_position_with_generated_columns(
     tupledesc: &PgTupleDesc,
-    match_by_name: bool,
+    match_by: MatchBy,
 ) {
-    // match_by_name can handle generated columns
-    if match_by_name {
+    // match_by 'name' can handle generated columns
+    if let MatchBy::Name = match_by {
         return;
     }
 
@@ -359,7 +361,7 @@ pub(crate) fn error_if_copy_from_match_by_position_with_generated_columns(
                 PgLogLevel::ERROR,
                 PgSqlErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED,
                 "COPY FROM parquet with generated columns is not supported",
-                "Try COPY FROM parquet WITH (match_by_name true). \
+                "Try COPY FROM parquet WITH (match_by 'name'). \
                  It only works if the column names match the parquet file's.",
             );
         }
@@ -373,11 +375,13 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
     file_schema: Arc<Schema>,
     tupledesc_schema: Arc<Schema>,
     attributes: &[FormData_pg_attribute],
-    match_by_name: bool,
+    match_by: MatchBy,
 ) -> Vec<Option<DataType>> {
     let mut cast_to_types = Vec::new();
 
-    if !match_by_name && tupledesc_schema.fields().len() != file_schema.fields().len() {
+    if match_by == MatchBy::Position
+        && tupledesc_schema.fields().len() != file_schema.fields().len()
+    {
         panic!(
             "column count mismatch between table and parquet file. \
+            parquet file has {} columns, but table has {} columns",
+            file_schema.fields().len(),
+            tupledesc_schema.fields().len()
+        );
+    }
 
     for (tupledesc_schema_field, attribute) in
         tupledesc_schema.fields().iter().zip(attributes.iter())
     {
         let field_name = tupledesc_schema_field.name();
 
-        let file_schema_field = if match_by_name {
-            let file_schema_field = file_schema.column_with_name(field_name);
+        let file_schema_field = match match_by {
+            MatchBy::Position => file_schema.field(attribute.attnum as usize - 1),
 
-            if file_schema_field.is_none() {
-                panic!("column \"{}\" is not found in parquet file", field_name);
-            }
+            MatchBy::Name => {
+                let file_schema_field = file_schema.column_with_name(field_name);
 
-            let (_, file_schema_field) = file_schema_field.unwrap();
+                if file_schema_field.is_none() {
+                    panic!("column \"{}\" is not found in parquet file", field_name);
+                }
 
-            file_schema_field
-        } else {
-            file_schema.field(attribute.attnum as usize - 1)
+                let (_, file_schema_field) = file_schema_field.unwrap();
+
+                file_schema_field
+            }
         };
 
         let file_schema_field = Arc::new(file_schema_field.clone());
diff --git a/src/parquet_copy_hook/copy_from.rs b/src/parquet_copy_hook/copy_from.rs
index 6a44446..ae86e07 100644
--- a/src/parquet_copy_hook/copy_from.rs
+++ b/src/parquet_copy_hook/copy_from.rs
@@ -20,7 +20,7 @@ use crate::{
 };
 
 use super::copy_utils::{
-    copy_from_stmt_match_by_name, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
+    copy_from_stmt_match_by, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
     copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation,
 };
 
@@ -131,11 +131,11 @@ pub(crate) fn execute_copy_from(
 
     let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation);
 
-    let match_by_name = copy_from_stmt_match_by_name(p_stmt);
+    let match_by = copy_from_stmt_match_by(p_stmt);
 
     unsafe {
         // parquet reader context is used throughout the COPY FROM operation.
-        let parquet_reader_context = ParquetReaderContext::new(uri, match_by_name, &tupledesc);
+        let parquet_reader_context = ParquetReaderContext::new(uri, match_by, &tupledesc);
         push_parquet_reader_context(parquet_reader_context);
 
         // makes sure to set binary format
diff --git a/src/parquet_copy_hook/copy_utils.rs b/src/parquet_copy_hook/copy_utils.rs
index 02c3c5d..dde7848 100644
--- a/src/parquet_copy_hook/copy_utils.rs
+++ b/src/parquet_copy_hook/copy_utils.rs
@@ -3,12 +3,11 @@ use std::{ffi::CStr, str::FromStr};
 use pgrx::{
     is_a,
     pg_sys::{
-        addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString,
-        get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate,
-        quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc,
-        DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState,
-        PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock,
-        TupleDescInitEntry,
+        addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name,
+        get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier,
+        AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node,
+        NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment,
+        RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry,
     },
     PgBox, PgList, PgRelation, PgTupleDesc,
 };
@@ -16,6 +15,7 @@ use url::Url;
 
 use crate::arrow_parquet::{
     compression::{all_supported_compressions, PgParquetCompression},
+    match_by::MatchBy,
     parquet_writer::{DEFAULT_ROW_GROUP_SIZE, DEFAULT_ROW_GROUP_SIZE_BYTES},
     uri_utils::parse_uri,
 };
@@ -110,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
 }
 
 pub(crate) fn validate_copy_from_options(p_stmt: &PgBox<PlannedStmt>) {
-    validate_copy_option_names(p_stmt, &["format", "match_by_name", "freeze"]);
+    validate_copy_option_names(p_stmt, &["format", "match_by", "freeze"]);
 
     let format_option = copy_stmt_get_option(p_stmt, "format");
 
@@ -254,13 +254,21 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox<PlannedStmt>) ->
     new_copy_options
 }
 
-pub(crate) fn copy_from_stmt_match_by_name(p_stmt: &PgBox<PlannedStmt>) -> bool {
-    let match_by_name_option = copy_stmt_get_option(p_stmt, "match_by_name");
+pub(crate) fn copy_from_stmt_match_by(p_stmt: &PgBox<PlannedStmt>) -> MatchBy {
+    let match_by_option = copy_stmt_get_option(p_stmt, "match_by");
 
-    if match_by_name_option.is_null() {
-        false
+    if match_by_option.is_null() {
+        MatchBy::default()
     } else {
-        unsafe { defGetBoolean(match_by_name_option.as_ptr()) }
+        let match_by = unsafe { defGetString(match_by_option.as_ptr()) };
+
+        let match_by = unsafe {
+            CStr::from_ptr(match_by)
+                .to_str()
+                .expect("match_by option is not a valid CString")
+        };
+
+        MatchBy::from_str(match_by).unwrap_or_else(|e| panic!("{}", e))
     }
 }
 
diff --git a/src/pgrx_tests/copy_from_coerce.rs b/src/pgrx_tests/copy_from_coerce.rs
index ed0ebb9..03d557f 100644
--- a/src/pgrx_tests/copy_from_coerce.rs
+++ b/src/pgrx_tests/copy_from_coerce.rs
@@ -977,7 +977,7 @@ mod tests {
         Spi::run(create_table).unwrap();
 
         let copy_from = format!(
-            "COPY test_table FROM '{}' WITH (match_by_name true)",
+            "COPY test_table FROM '{}' WITH (match_by 'name')",
             LOCAL_TEST_FILE_PATH
         );
         Spi::run(&copy_from).unwrap();
@@ -994,7 +994,7 @@ mod tests {
         let create_table = "CREATE TABLE test_table (x bigint, y varchar)";
         Spi::run(create_table).unwrap();
 
-        let copy_from = "COPY test_table FROM '/tmp/test.parquet'";
+        let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by 'position')";
'/tmp/test.parquet'"; + let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by 'position')"; Spi::run(copy_from).unwrap(); let result = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); @@ -1024,7 +1024,7 @@ mod tests { Spi::run(©_to_parquet).unwrap(); let copy_from = format!( - "COPY test_table FROM '{}' with (match_by_name true)", + "COPY test_table FROM '{}' with (match_by 'name')", LOCAL_TEST_FILE_PATH ); Spi::run(©_from).unwrap(); diff --git a/src/pgrx_tests/copy_options.rs b/src/pgrx_tests/copy_options.rs index 0e223ce..eaccd20 100644 --- a/src/pgrx_tests/copy_options.rs +++ b/src/pgrx_tests/copy_options.rs @@ -392,4 +392,19 @@ mod tests { assert_eq!(result_metadata, vec![10]); } + + #[pg_test] + #[should_panic(expected = "unrecognized match_by method: invalid_match_by")] + fn test_invalid_match_by() { + let mut copy_from_options = HashMap::new(); + copy_from_options.insert( + "match_by".to_string(), + CopyOptionValue::StringOption("invalid_match_by".to_string()), + ); + + let test_table = + TestTable::::new("int4".into()).with_copy_from_options(copy_from_options); + test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);"); + test_table.assert_expected_and_result_rows(); + } } diff --git a/src/pgrx_tests/copy_pg_rules.rs b/src/pgrx_tests/copy_pg_rules.rs index 35c44d6..c4a2816 100644 --- a/src/pgrx_tests/copy_pg_rules.rs +++ b/src/pgrx_tests/copy_pg_rules.rs @@ -138,7 +138,7 @@ mod tests { Spi::run("TRUNCATE test_table;").unwrap(); let copy_from_query = format!( - "COPY test_table FROM '{}' WITH (format parquet, match_by_name true);", + "COPY test_table FROM '{}' WITH (format parquet, match_by 'name');", LOCAL_TEST_FILE_PATH ); Spi::run(copy_from_query.as_str()).unwrap();