verify parquet file schema before copying the file to Postgres table, which both prevents unnecessary work and gives more friendly errors
aykut-bozkurt committed Sep 26, 2024
1 parent ef8a258 commit 6a2be26
Showing 3 changed files with 95 additions and 8 deletions.
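
Before this change, a schema mismatch surfaced only partway through the COPY, after real work had been done; the check now runs before any row is read. A minimal session showing the friendlier up-front error, mirroring the test_missing_column_in_parquet test added below (table, path, and expected message are taken from that test; the exact ERROR rendering may differ):

    create table test(id int, name text);
    copy (select 100 as id) to '/tmp/test.parquet';

    -- the file has no "name" column, so the copy fails immediately:
    copy test from '/tmp/test.parquet';
    -- ERROR:  column "name" is not found in parquet file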
18 changes: 13 additions & 5 deletions src/arrow_parquet/parquet_reader.rs
@@ -17,7 +17,9 @@ use crate::{
     type_compat::{geometry::reset_postgis_context, map::reset_crunchy_map_context},
 };
 
-use super::uri_utils::parquet_reader_from_uri;
+use super::{
+    schema_visitor::ensure_arrow_schema_match_tupledesc, uri_utils::parquet_reader_from_uri,
+};
 
 pub(crate) struct ParquetReaderContext {
     buffer: Vec<u8>,
@@ -35,14 +37,19 @@ impl ParquetReaderContext {
         reset_postgis_context();
         reset_crunchy_map_context();
 
+        let pgtupledesc = unsafe { PgTupleDesc::from_pg_copy(tupledesc) };
+
         let runtime = tokio::runtime::Builder::new_current_thread()
             .enable_all()
             .build()
             .unwrap_or_else(|e| panic!("failed to create tokio runtime: {}", e));
 
         let parquet_reader = runtime.block_on(parquet_reader_from_uri(&uri));
 
-        let binary_out_funcs = Self::collect_binary_out_funcs(tupledesc);
+        let file_schema = parquet_reader.schema();
+        ensure_arrow_schema_match_tupledesc(file_schema.clone(), &pgtupledesc);
+
+        let binary_out_funcs = Self::collect_binary_out_funcs(&pgtupledesc);
 
         ParquetReaderContext {
             buffer: Vec::new(),
@@ -56,13 +63,14 @@
         }
     }
 
-    fn collect_binary_out_funcs(tupledesc: TupleDesc) -> Vec<PgBox<FmgrInfo, AllocatedByPostgres>> {
+    fn collect_binary_out_funcs(
+        tupledesc: &PgTupleDesc,
+    ) -> Vec<PgBox<FmgrInfo, AllocatedByPostgres>> {
         unsafe {
             let mut binary_out_funcs = vec![];
 
             let include_generated_columns = false;
-            let tupledesc = PgTupleDesc::from_pg_copy(tupledesc);
-            let attributes = collect_valid_attributes(&tupledesc, include_generated_columns);
+            let attributes = collect_valid_attributes(tupledesc, include_generated_columns);
 
             for att in attributes.iter() {
                 let typoid = att.type_oid();
36 changes: 33 additions & 3 deletions src/arrow_parquet/schema_visitor.rs
@@ -5,7 +5,7 @@ use arrow_schema::FieldRef
 use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY};
 use pg_sys::{
     Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID,
-    NUMERICOID, OIDOID, RECORDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID,
+    NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID,
 };
 use pgrx::{prelude::*, PgTupleDesc};
@@ -35,8 +35,6 @@ pub(crate) fn parquet_schema_string_from_tupledesc(tupledesc: PgTupleDesc) -> String
 }
 
 pub(crate) fn parse_arrow_schema_from_tupledesc(tupledesc: PgTupleDesc) -> Schema {
-    debug_assert!(tupledesc.oid() == RECORDOID);
-
     let mut field_id = 0;
 
     let mut struct_attribute_fields = vec![];
@@ -288,3 +286,35 @@ fn to_not_nullable_field(field: FieldRef) -> FieldRef {
     let field = Field::new(name, data_type.clone(), false).with_metadata(metadata);
     Arc::new(field)
 }
+
+pub(crate) fn ensure_arrow_schema_match_tupledesc(
+    file_schema: Arc<Schema>,
+    tupledesc: &PgTupleDesc,
+) {
+    let table_schema = parse_arrow_schema_from_tupledesc(tupledesc.clone());
+
+    for table_schema_field in table_schema.fields().iter() {
+        let table_schema_field_name = table_schema_field.name();
+        let table_schema_field_type = table_schema_field.data_type();
+
+        let file_schema_field = file_schema.column_with_name(table_schema_field_name);
+
+        if let Some(file_schema_field) = file_schema_field {
+            let file_schema_field_type = file_schema_field.1.data_type();
+
+            if file_schema_field_type != table_schema_field_type {
+                panic!(
+                    "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"",
+                    table_schema_field_name,
+                    table_schema_field_type,
+                    file_schema_field_type,
+                );
+            }
+        } else {
+            panic!(
+                "column \"{}\" is not found in parquet file",
+                table_schema_field_name
+            );
+        }
+    }
+}
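
The new check iterates only over the table's columns, looks each one up in the file schema by name, and requires exact Arrow DataType equality; extra columns in the file and nullability differences are not flagged. A minimal standalone sketch of the same matching rule (assuming only the arrow_schema crate; the schemas and names are illustrative, not from the commit):

    use arrow_schema::{DataType, Field, Schema};

    fn main() {
        // "Table" side: the Arrow schema derived from the tupledesc.
        let table = Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]);
        // "File" side: id was written as Int32, so the types do not match.
        let file = Schema::new(vec![Field::new("id", DataType::Int32, false)]);

        for table_field in table.fields().iter() {
            match file.column_with_name(table_field.name()) {
                // Present by name but with a different Arrow type.
                Some((_, file_field)) if file_field.data_type() != table_field.data_type() => {
                    panic!(
                        "type mismatch for column \"{}\": expected \"{}\", file had \"{}\"",
                        table_field.name(),
                        table_field.data_type(),
                        file_field.data_type(),
                    )
                }
                // Missing from the file entirely.
                None => panic!(
                    "column \"{}\" is not found in parquet file",
                    table_field.name()
                ),
                // Name and type both match: nothing to check.
                _ => {}
            }
        }
    }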
49 changes: 49 additions & 0 deletions src/lib.rs
@@ -2184,6 +2184,55 @@ mod tests {
         let parquet_metadata_command = "select * from parquet.metadata('/tmp/test.parquet');";
         Spi::run(parquet_metadata_command).unwrap();
     }
+
+    #[pg_test]
+    #[should_panic(
+        expected = "type mismatch for column \"location\" between table and parquet file"
+    )]
+    fn test_type_mismatch_between_parquet_and_table() {
+        let create_types = "create type dog as (name text, age int);
+                            create type person as (id bigint, name text, dogs dog[]);
+                            create type address as (loc text);";
+        Spi::run(create_types).unwrap();
+
+        let create_correct_table =
+            "create table factory_correct (id bigint, workers person[], name text, location address);";
+        Spi::run(create_correct_table).unwrap();
+
+        let create_wrong_table =
+            "create table factory_wrong (id bigint, workers person[], name text, location int);";
+        Spi::run(create_wrong_table).unwrap();
+
+        let copy_to_parquet = "copy (select 1::int8 as id,
+                                      array[
+                                          row(1, 'ali', array[row('lady', 4), NULL]::dog[])
+                                      ]::person[] as workers,
+                                      'Microsoft' as name,
+                                      row('istanbul')::address as location
+                               from generate_series(1,10) i) to '/tmp/test.parquet';";
+        Spi::run(copy_to_parquet).unwrap();
+
+        // copy to the correct table, which matches the parquet schema
+        let copy_to_correct_table = "copy factory_correct from '/tmp/test.parquet';";
+        Spi::run(copy_to_correct_table).unwrap();
+
+        // copy to the wrong table, which does not match the parquet schema
+        let copy_to_wrong_table = "copy factory_wrong from '/tmp/test.parquet';";
+        Spi::run(copy_to_wrong_table).unwrap();
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "column \"name\" is not found in parquet file")]
+    fn test_missing_column_in_parquet() {
+        let create_table = "create table test(id int, name text);";
+        Spi::run(create_table).unwrap();
+
+        let copy_to_parquet = "copy (select 100 as id) to '/tmp/test.parquet';";
+        Spi::run(copy_to_parquet).unwrap();
+
+        let copy_to_table = "copy test from '/tmp/test.parquet';";
+        Spi::run(copy_to_table).unwrap();
+    }
 }
 
 /// This module is required by `cargo pgrx test` invocations.