diff --git a/README.md b/README.md index cd4c8c9..353b01f 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting pos `pg_parquet` supports the following options in the `COPY FROM` command: - `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension, +- `match_by <string>`: method to match Parquet file fields to PostgreSQL table columns. The available methods are `position` and `name`. The default method is `position`, which matches columns by their position in the schema. You can set it to `name` to match columns by their names instead. Matching by `name` is useful when the field order differs between the Parquet file and the table but the field names match. ## Configuration There is currently only one GUC parameter to enable/disable the `pg_parquet`: diff --git a/src/arrow_parquet.rs b/src/arrow_parquet.rs index 7e445a4..e6ca8b0 100644 --- a/src/arrow_parquet.rs +++ b/src/arrow_parquet.rs @@ -1,6 +1,7 @@ pub(crate) mod arrow_to_pg; pub(crate) mod arrow_utils; pub(crate) mod compression; +pub(crate) mod match_by; pub(crate) mod parquet_reader; pub(crate) mod parquet_writer; pub(crate) mod pg_to_arrow; diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs index 079ee08..2aa8d0e 100644 --- a/src/arrow_parquet/arrow_to_pg.rs +++ b/src/arrow_parquet/arrow_to_pg.rs @@ -163,7 +163,7 @@ impl ArrowToPgAttributeContext { }; let attributes = - collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc); + collect_attributes_for(CollectAttributesFor::Other, attribute_tupledesc); // we only cast the top-level attributes, which already covers the nested attributes let cast_to_types = None; diff --git a/src/arrow_parquet/match_by.rs b/src/arrow_parquet/match_by.rs new file mode 100644 index 0000000..f115b56 --- /dev/null +++ b/src/arrow_parquet/match_by.rs @@ -0,0 +1,20 @@ +use std::str::FromStr; + 
+#[derive(Debug, Clone, Copy, Default, PartialEq)] +pub(crate) enum MatchBy { + #[default] + Position, // match parquet fields to table columns by ordinal position (the default) + Name, // match parquet fields to table columns by name +} + +impl FromStr for MatchBy { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "position" => Ok(MatchBy::Position), + "name" => Ok(MatchBy::Name), + _ => Err(format!("unrecognized match_by method: {}", s)), // comparison is exact and case-sensitive + } + } +} diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index cbeff07..9a2cf1d 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -16,7 +16,11 @@ use url::Url; use crate::{ arrow_parquet::{ - arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes, + arrow_to_pg::to_pg_datum, + schema_parser::{ + error_if_copy_from_match_by_position_with_generated_columns, + parquet_schema_string_from_attributes, + }, }, pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, @@ -24,6 +28,7 @@ use crate::{ use super::{ arrow_to_pg::{collect_arrow_to_pg_attribute_contexts, ArrowToPgAttributeContext}, + match_by::MatchBy, schema_parser::{ ensure_file_schema_match_tupledesc_schema, parse_arrow_schema_from_attributes, }, @@ -38,15 +43,18 @@ pub(crate) struct ParquetReaderContext { parquet_reader: ParquetRecordBatchStream, attribute_contexts: Vec, binary_out_funcs: Vec>, + match_by: MatchBy, } impl ParquetReaderContext { - pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self { + pub(crate) fn new(uri: Url, match_by: MatchBy, tupledesc: &PgTupleDesc) -> Self { // Postgis and Map contexts are used throughout reading the parquet file. // We need to reset them to avoid reading the stale data. (e.g. 
extension could be dropped) reset_postgis_context(); reset_map_context(); + error_if_copy_from_match_by_position_with_generated_columns(tupledesc, match_by); + let parquet_reader = parquet_reader_from_uri(&uri); let parquet_file_schema = parquet_reader.schema(); @@ -69,6 +77,7 @@ impl ParquetReaderContext { parquet_file_schema.clone(), tupledesc_schema.clone(), &attributes, + match_by, ); let attribute_contexts = collect_arrow_to_pg_attribute_contexts( @@ -85,6 +94,7 @@ impl ParquetReaderContext { attribute_contexts, parquet_reader, binary_out_funcs, + match_by, started: false, finished: false, } @@ -116,15 +126,23 @@ impl ParquetReaderContext { fn record_batch_to_tuple_datums( record_batch: RecordBatch, attribute_contexts: &[ArrowToPgAttributeContext], + match_by: MatchBy, ) -> Vec> { let mut datums = vec![]; - for attribute_context in attribute_contexts { + for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() { let name = attribute_context.name(); - let column_array = record_batch - .column_by_name(name) - .unwrap_or_else(|| panic!("column {} not found", name)); + let column_array = match match_by { + MatchBy::Position => record_batch + .columns() + .get(attribute_idx) + .unwrap_or_else(|| panic!("column {} not found", name)), + + MatchBy::Name => record_batch + .column_by_name(name) + .unwrap_or_else(|| panic!("column {} not found", name)), + }; let datum = if attribute_context.needs_cast() { // should fail instead of returning None if the cast fails at runtime @@ -181,8 +199,11 @@ impl ParquetReaderContext { self.buffer.extend_from_slice(&attnum_len_bytes); // convert the columnar arrays in record batch to tuple datums - let tuple_datums = - Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts); + let tuple_datums = Self::record_batch_to_tuple_datums( + record_batch, + &self.attribute_contexts, + self.match_by, + ); // write the tuple datums to the ParquetReader's internal buffer in PG copy format for (datum, 
out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter()) diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs index 530c7f7..17d774b 100644 --- a/src/arrow_parquet/pg_to_arrow.rs +++ b/src/arrow_parquet/pg_to_arrow.rs @@ -148,7 +148,7 @@ impl PgToArrowAttributeContext { }; let attributes = - collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc); + collect_attributes_for(CollectAttributesFor::Other, &attribute_tupledesc); collect_pg_to_arrow_attribute_contexts(&attributes, &fields) }); diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index b76ee70..f100e4e 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -16,7 +16,7 @@ use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc}; use crate::{ pgrx_utils::{ array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, - is_composite_type, tuple_desc, CollectAttributesFor, + is_composite_type, is_generated_attribute, tuple_desc, CollectAttributesFor, }, type_compat::{ geometry::is_postgis_geometry_type, @@ -27,6 +27,8 @@ use crate::{ }, }; +use super::match_by::MatchBy; + pub(crate) fn parquet_schema_string_from_attributes( attributes: &[FormData_pg_attribute], ) -> String { @@ -95,7 +97,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i let mut child_fields: Vec> = vec![]; - let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); + let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc); for attribute in attributes { if attribute.is_dropped() { @@ -342,6 +344,30 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { Arc::new(entries_field) } +pub(crate) fn error_if_copy_from_match_by_position_with_generated_columns( + tupledesc: &PgTupleDesc, + match_by: MatchBy, +) { + // match_by 'name' can handle generated columns + if let MatchBy::Name = match_by { + 
return; + } + + let attributes = collect_attributes_for(CollectAttributesFor::Other, tupledesc); + + for attribute in attributes { // a single generated column is enough to reject positional matching + if is_generated_attribute(&attribute) { + ereport!( + PgLogLevel::ERROR, + PgSqlErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED, + "COPY FROM parquet with generated columns is not supported", + "Try COPY FROM parquet WITH (match_by 'name'). \ + It works only if the column names match with parquet file's.", + ); + } + } +} + // ensure_file_schema_match_tupledesc_schema throws an error if the file's schema does not match the table schema. // If the file's arrow schema is castable to the table's arrow schema, it returns a vector of Option // to cast to for each field. @@ -349,21 +375,42 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema( file_schema: Arc, tupledesc_schema: Arc, attributes: &[FormData_pg_attribute], + match_by: MatchBy, ) -> Vec> { let mut cast_to_types = Vec::new(); + if match_by == MatchBy::Position + && tupledesc_schema.fields().len() != file_schema.fields().len() + { + panic!( + "column count mismatch between table and parquet file. 
\ + parquet file has {} columns, but table has {} columns", + file_schema.fields().len(), + tupledesc_schema.fields().len() + ); + } + for (tupledesc_schema_field, attribute) in tupledesc_schema.fields().iter().zip(attributes.iter()) { let field_name = tupledesc_schema_field.name(); - let file_schema_field = file_schema.column_with_name(field_name); + let file_schema_field = match match_by { + MatchBy::Position => file_schema.field(attribute.attnum as usize - 1), - if file_schema_field.is_none() { - panic!("column \"{}\" is not found in parquet file", field_name); - } + MatchBy::Name => { + let file_schema_field = file_schema.column_with_name(field_name); + + if file_schema_field.is_none() { + panic!("column \"{}\" is not found in parquet file", field_name); + } + + let (_, file_schema_field) = file_schema_field.unwrap(); + + file_schema_field + } + }; - let (_, file_schema_field) = file_schema_field.unwrap(); let file_schema_field = Arc::new(file_schema_field.clone()); let from_type = file_schema_field.data_type(); @@ -378,7 +425,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema( if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) { panic!( "type mismatch for column \"{}\" between table and parquet file.\n\n\ - table has \"{}\"\n\nparquet file has \"{}\"", + table has \"{}\"\n\nparquet file has \"{}\"", field_name, to_type, from_type ); } @@ -413,7 +460,7 @@ fn is_coercible(from_type: &DataType, to_type: &DataType, to_typoid: Oid, to_typ let tupledesc = tuple_desc(to_typoid, to_typmod); - let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); + let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc); for (from_field, (to_field, to_attribute)) in from_fields .iter() diff --git a/src/parquet_copy_hook/copy_from.rs b/src/parquet_copy_hook/copy_from.rs index bf3a878..ae86e07 100644 --- a/src/parquet_copy_hook/copy_from.rs +++ b/src/parquet_copy_hook/copy_from.rs @@ -20,8 +20,8 @@ 
use crate::{ }; use super::copy_utils::{ - copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state, - create_filtered_tupledesc_for_relation, + copy_from_stmt_match_by, copy_stmt_attribute_list, copy_stmt_create_namespace_item, + copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation, }; // stack to store parquet reader contexts for COPY FROM. @@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from( let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation); + let match_by = copy_from_stmt_match_by(p_stmt); + unsafe { // parquet reader context is used throughout the COPY FROM operation. - let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc); + let parquet_reader_context = ParquetReaderContext::new(uri, match_by, &tupledesc); push_parquet_reader_context(parquet_reader_context); // makes sure to set binary format diff --git a/src/parquet_copy_hook/copy_utils.rs b/src/parquet_copy_hook/copy_utils.rs index 068e95e..dde7848 100644 --- a/src/parquet_copy_hook/copy_utils.rs +++ b/src/parquet_copy_hook/copy_utils.rs @@ -15,6 +15,7 @@ use url::Url; use crate::arrow_parquet::{ compression::{all_supported_compressions, PgParquetCompression}, + match_by::MatchBy, parquet_writer::{DEFAULT_ROW_GROUP_SIZE, DEFAULT_ROW_GROUP_SIZE_BYTES}, uri_utils::parse_uri, }; @@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox, uri: &Url) { } pub(crate) fn validate_copy_from_options(p_stmt: &PgBox) { - validate_copy_option_names(p_stmt, &["format", "freeze"]); + validate_copy_option_names(p_stmt, &["format", "match_by", "freeze"]); let format_option = copy_stmt_get_option(p_stmt, "format"); @@ -253,6 +254,24 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox) -> new_copy_options } +pub(crate) fn copy_from_stmt_match_by(p_stmt: &PgBox) -> MatchBy { + let match_by_option = copy_stmt_get_option(p_stmt, "match_by"); + + if match_by_option.is_null() { + MatchBy::default() + } else { + 
let match_by = unsafe { defGetString(match_by_option.as_ptr()) }; + + let match_by = unsafe { + CStr::from_ptr(match_by) + .to_str() + .expect("match_by option is not a valid CString") + }; + + MatchBy::from_str(match_by).unwrap_or_else(|e| panic!("{}", e)) + } +} + pub(crate) fn copy_stmt_get_option( p_stmt: &PgBox, option_name: &str, diff --git a/src/pgrx_tests/copy_from_coerce.rs b/src/pgrx_tests/copy_from_coerce.rs index 75c7af8..03d557f 100644 --- a/src/pgrx_tests/copy_from_coerce.rs +++ b/src/pgrx_tests/copy_from_coerce.rs @@ -966,7 +966,7 @@ mod tests { } #[pg_test] - fn test_table_with_different_field_position() { + fn test_table_with_different_position_match_by_name() { let copy_to = format!( "COPY (SELECT 1 as x, 'hello' as y) TO '{}'", LOCAL_TEST_FILE_PATH @@ -976,13 +976,44 @@ mod tests { let create_table = "CREATE TABLE test_table (y text, x int)"; Spi::run(create_table).unwrap(); - let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH); + let copy_from = format!( + "COPY test_table FROM '{}' WITH (match_by 'name')", + LOCAL_TEST_FILE_PATH + ); Spi::run(&copy_from).unwrap(); let result = Spi::get_two::<&str, i32>("SELECT y, x FROM test_table LIMIT 1").unwrap(); assert_eq!(result, (Some("hello"), Some(1))); } + #[pg_test] + fn test_table_with_different_name_match_by_position() { // file columns (a, b) differ in name from table columns (x, y); positional matching must still load the row + let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (x bigint, y varchar)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by 'position')"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_two::<i64, &str>("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(result, (Some(1), Some("hello"))); + } + + #[pg_test] + #[should_panic(expected = "column count mismatch between table and parquet file")] + fn test_table_with_different_name_match_by_position_fail() { + let copy_to = "COPY (SELECT 1 as a, 
'hello' as b) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + #[pg_test] #[should_panic(expected = "column \"name\" is not found in parquet file")] fn test_missing_column_in_parquet() { @@ -992,7 +1023,10 @@ mod tests { let copy_to_parquet = format!("copy (select 100 as id) to '{}';", LOCAL_TEST_FILE_PATH); Spi::run(&copy_to_parquet).unwrap(); - let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH); + let copy_from = format!( + "COPY test_table FROM '{}' with (match_by 'name')", + LOCAL_TEST_FILE_PATH + ); Spi::run(&copy_from).unwrap(); } diff --git a/src/pgrx_tests/copy_options.rs b/src/pgrx_tests/copy_options.rs index 0e223ce..eaccd20 100644 --- a/src/pgrx_tests/copy_options.rs +++ b/src/pgrx_tests/copy_options.rs @@ -392,4 +392,19 @@ mod tests { assert_eq!(result_metadata, vec![10]); } + + #[pg_test] + #[should_panic(expected = "unrecognized match_by method: invalid_match_by")] + fn test_invalid_match_by() { // an unknown match_by value must be rejected with the parser's error message + let mut copy_from_options = HashMap::new(); + copy_from_options.insert( + "match_by".to_string(), + CopyOptionValue::StringOption("invalid_match_by".to_string()), + ); + + let test_table = + TestTable::<i32>::new("int4".into()).with_copy_from_options(copy_from_options); + test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);"); + test_table.assert_expected_and_result_rows(); + } } diff --git a/src/pgrx_tests/copy_pg_rules.rs b/src/pgrx_tests/copy_pg_rules.rs index b0347e2..c4a2816 100644 --- a/src/pgrx_tests/copy_pg_rules.rs +++ b/src/pgrx_tests/copy_pg_rules.rs @@ -101,6 +101,21 @@ mod tests { Spi::run(&copy_from).unwrap(); } + #[pg_test] + #[should_panic(expected = "COPY FROM parquet with generated columns is not supported")] + fn test_copy_from_by_position_with_generated_columns_not_supported() { + 
Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); + + Spi::run("CREATE TABLE test_table (a int, b int generated always as (10) stored, c text);") + .unwrap(); + + let copy_from_query = format!( + "COPY test_table FROM '{}' WITH (format parquet);", + LOCAL_TEST_FILE_PATH + ); + Spi::run(copy_from_query.as_str()).unwrap(); + } + #[pg_test] fn test_with_generated_and_dropped_columns() { Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); @@ -123,7 +138,7 @@ mod tests { Spi::run("TRUNCATE test_table;").unwrap(); let copy_from_query = format!( - "COPY test_table FROM '{}' WITH (format parquet);", + "COPY test_table FROM '{}' WITH (format parquet, match_by 'name');", LOCAL_TEST_FILE_PATH ); Spi::run(copy_from_query.as_str()).unwrap(); diff --git a/src/pgrx_utils.rs b/src/pgrx_utils.rs index cb8f9fa..0793def 100644 --- a/src/pgrx_utils.rs +++ b/src/pgrx_utils.rs @@ -12,7 +12,7 @@ use pgrx::{ pub(crate) enum CollectAttributesFor { CopyFrom, CopyTo, - Struct, + Other, } // collect_attributes_for collects not-dropped attributes from the tuple descriptor. @@ -23,7 +23,7 @@ pub(crate) fn collect_attributes_for( ) -> Vec { let include_generated_columns = match copy_operation { CollectAttributesFor::CopyFrom => false, - CollectAttributesFor::CopyTo | CollectAttributesFor::Struct => true, + CollectAttributesFor::CopyTo | CollectAttributesFor::Other => true, }; let mut attributes = vec![]; @@ -35,7 +35,7 @@ pub(crate) fn collect_attributes_for( continue; } - if !include_generated_columns && attribute.attgenerated != 0 { + if !include_generated_columns && is_generated_attribute(attribute) { continue; } @@ -55,6 +55,10 @@ pub(crate) fn collect_attributes_for( attributes } +pub(crate) fn is_generated_attribute(attribute: &FormData_pg_attribute) -> bool { + attribute.attgenerated != 0 +} + pub(crate) fn tuple_desc(typoid: Oid, typmod: i32) -> PgTupleDesc<'static> { let tupledesc = unsafe { lookup_rowtype_tupdesc(typoid, typmod) }; unsafe { PgTupleDesc::from_pg(tupledesc) }