diff --git a/src/arrow_parquet/parquet_writer.rs b/src/arrow_parquet/parquet_writer.rs index a008f43..f525c7c 100644 --- a/src/arrow_parquet/parquet_writer.rs +++ b/src/arrow_parquet/parquet_writer.rs @@ -108,6 +108,7 @@ impl<'a> ParquetWriterContext<'a> { } pub(crate) fn close(self) { - self.runtime.block_on(self.parquet_writer.close()).unwrap(); + // should not panic as we can call from try catch block + self.runtime.block_on(self.parquet_writer.close()).ok(); } } diff --git a/src/lib.rs b/src/lib.rs index a6c35ed..8a66cf3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ -use parquet_copy_hook::hook::init_parquet_copy_hook; -use pgrx::prelude::*; +use parquet_copy_hook::hook::{init_parquet_copy_hook, ENABLE_PARQUET_COPY_HOOK}; +use pg_sys::MarkGUCPrefixReserved; +use pgrx::{prelude::*, GucContext, GucFlags, GucRegistry}; mod arrow_parquet; mod parquet_copy_hook; @@ -19,6 +20,17 @@ pgrx::pg_module_magic!(); #[allow(static_mut_refs)] #[pg_guard] pub extern "C" fn _PG_init() { + GucRegistry::define_bool_guc( + "pg_parquet.enable_copy_hooks", + "Enable parquet copy hooks", + "Enable parquet copy hooks", + &ENABLE_PARQUET_COPY_HOOK, + GucContext::Userset, + GucFlags::default(), + ); + + unsafe { MarkGUCPrefixReserved("pg_parquet".as_ptr() as _) }; + init_parquet_copy_hook(); } @@ -2004,6 +2016,13 @@ mod tests { Spi::run("DROP TABLE workers; DROP TYPE worker, person;").unwrap(); } + + #[pg_test] + #[should_panic(expected = "relative path not allowed for COPY to file")] + fn test_disabled_hooks() { + Spi::run("SET pg_parquet.enable_copy_hooks TO false;").unwrap(); + Spi::run("COPY (SELECT 1 as id) TO 'file:///tmp/test.parquet'").unwrap(); + } } /// This module is required by `cargo pgrx test` invocations. diff --git a/src/parquet_copy_hook/copy_utils.rs b/src/parquet_copy_hook/copy_utils.rs index fb927be..18a678b 100644 --- a/src/parquet_copy_hook/copy_utils.rs +++ b/src/parquet_copy_hook/copy_utils.rs @@ -220,11 +220,7 @@ pub(crate) fn is_copy_to_parquet_stmt(pstmt: &PgBox) -> boo return false; } - if is_parquet_format(©_stmt) { - return true; - } - - is_parquet_file(©_stmt) + is_parquet_format(©_stmt) || is_parquet_file(©_stmt) } pub(crate) fn is_copy_from_parquet_stmt(pstmt: &PgBox) -> bool { @@ -242,11 +238,7 @@ pub(crate) fn is_copy_from_parquet_stmt(pstmt: &PgBox) -> b return false; } - if is_parquet_format(©_stmt) { - return true; - } - - is_parquet_file(©_stmt) + is_parquet_format(©_stmt) || is_parquet_file(©_stmt) } pub(crate) fn copy_has_relation(pstmt: &PgBox) -> bool { diff --git a/src/parquet_copy_hook/hook.rs b/src/parquet_copy_hook/hook.rs index 69074a6..ddebbec 100644 --- a/src/parquet_copy_hook/hook.rs +++ b/src/parquet_copy_hook/hook.rs @@ -1,7 +1,7 @@ use std::ffi::CStr; use pg_sys::{standard_ProcessUtility, CommandTag, ProcessUtility_hook, ProcessUtility_hook_type}; -use pgrx::prelude::*; +use pgrx::{prelude::*, GucSetting}; use crate::parquet_copy_hook::{ copy_to_dest_receiver::create_copy_to_parquet_dest_receiver, @@ -18,6 +18,8 @@ use super::{ copy_utils::{copy_stmt_codec, validate_copy_from_options, validate_copy_to_options}, }; +pub(crate) static ENABLE_PARQUET_COPY_HOOK: GucSetting = GucSetting::::new(true); + static mut PREV_PROCESS_UTILITY_HOOK: ProcessUtility_hook_type = None; #[pg_guard] @@ -49,7 +51,7 @@ extern "C" fn parquet_copy_hook( let params = unsafe { PgBox::from_pg(params) }; let query_env = unsafe { PgBox::from_pg(query_env) }; - if is_copy_to_parquet_stmt(&pstmt) { + if ENABLE_PARQUET_COPY_HOOK.get() && is_copy_to_parquet_stmt(&pstmt) { validate_copy_to_options(&pstmt); let filename = copy_stmt_filename(&pstmt); @@ -86,7 +88,7 @@ extern "C" fn parquet_copy_hook( .execute(); return; - } else if is_copy_from_parquet_stmt(&pstmt) { + } else if ENABLE_PARQUET_COPY_HOOK.get() && is_copy_from_parquet_stmt(&pstmt) { validate_copy_from_options(&pstmt); PgTryBuilder::new(|| {