Skip to content

Commit

Permalink
Check extension at hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
aykut-bozkurt committed Dec 16, 2024
1 parent 8f812ed commit 222dc23
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 12 deletions.
41 changes: 33 additions & 8 deletions src/parquet_copy_hook/copy_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ use pgrx::{
};
use url::Url;

use crate::arrow_parquet::{
compression::{all_supported_compressions, PgParquetCompression},
match_by::MatchBy,
parquet_writer::{DEFAULT_ROW_GROUP_SIZE, DEFAULT_ROW_GROUP_SIZE_BYTES},
uri_utils::parse_uri,
use crate::{
arrow_parquet::{
compression::{all_supported_compressions, PgParquetCompression},
match_by::MatchBy,
parquet_writer::{DEFAULT_ROW_GROUP_SIZE, DEFAULT_ROW_GROUP_SIZE_BYTES},
uri_utils::parse_uri,
},
pgrx_utils::extension_exists,
};

use super::pg_compat::strVal;
use super::{hook::ENABLE_PARQUET_COPY_HOOK, pg_compat::strVal};

pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
validate_copy_option_names(
Expand Down Expand Up @@ -298,6 +301,11 @@ pub(crate) fn copy_stmt_get_option(
}

pub(crate) fn is_copy_to_parquet_stmt(p_stmt: &PgBox<PlannedStmt>) -> bool {
// the GUC pg_parquet.enable_copy_hook must be set to true
if !ENABLE_PARQUET_COPY_HOOK.get() {
return false;
}

let is_copy_stmt = unsafe { is_a(p_stmt.utilityStmt, T_CopyStmt) };

if !is_copy_stmt {
Expand All @@ -320,10 +328,21 @@ pub(crate) fn is_copy_to_parquet_stmt(p_stmt: &PgBox<PlannedStmt>) -> bool {

let uri = copy_stmt_uri(p_stmt).expect("uri is None");

is_parquet_format_option(p_stmt) || is_parquet_uri(uri)
if !is_parquet_format_option(p_stmt) && !is_parquet_uri(uri) {
return false;
}

// extension checks require a catalog lookup (postgres does not serve them from a cache until pg18),
// which is why we run them last, only after the uri checks have passed
extension_exists("pg_parquet") && !extension_exists("crunchy_query_engine")
}

pub(crate) fn is_copy_from_parquet_stmt(p_stmt: &PgBox<PlannedStmt>) -> bool {
// the GUC pg_parquet.enable_copy_hook must be set to true
if !ENABLE_PARQUET_COPY_HOOK.get() {
return false;
}

let is_copy_stmt = unsafe { is_a(p_stmt.utilityStmt, T_CopyStmt) };

if !is_copy_stmt {
Expand All @@ -346,7 +365,13 @@ pub(crate) fn is_copy_from_parquet_stmt(p_stmt: &PgBox<PlannedStmt>) -> bool {

let uri = copy_stmt_uri(p_stmt).expect("uri is None");

is_parquet_format_option(p_stmt) || is_parquet_uri(uri)
if !is_parquet_format_option(p_stmt) && !is_parquet_uri(uri) {
return false;
}

// extension checks require a catalog lookup (postgres does not serve them from a cache until pg18),
// which is why we run them last, only after the uri checks have passed
extension_exists("pg_parquet") && !extension_exists("crunchy_query_engine")
}

fn is_parquet_format_option(p_stmt: &PgBox<PlannedStmt>) -> bool {
Expand Down
4 changes: 2 additions & 2 deletions src/parquet_copy_hook/hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,15 @@ extern "C" fn parquet_copy_hook(
let query_env = unsafe { PgBox::from_pg(query_env) };
let mut completion_tag = unsafe { PgBox::from_pg(completion_tag) };

if ENABLE_PARQUET_COPY_HOOK.get() && is_copy_to_parquet_stmt(&p_stmt) {
if is_copy_to_parquet_stmt(&p_stmt) {
let nprocessed = process_copy_to_parquet(&p_stmt, query_string, &params, &query_env);

if !completion_tag.is_null() {
completion_tag.nprocessed = nprocessed;
completion_tag.commandTag = CommandTag::CMDTAG_COPY;
}
return;
} else if ENABLE_PARQUET_COPY_HOOK.get() && is_copy_from_parquet_stmt(&p_stmt) {
} else if is_copy_from_parquet_stmt(&p_stmt) {
let nprocessed = process_copy_from_parquet(&p_stmt, query_string, &query_env);

if !completion_tag.is_null() {
Expand Down
10 changes: 8 additions & 2 deletions src/pgrx_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::collections::HashSet;

use pgrx::{
pg_sys::{
getBaseType, get_element_type, lookup_rowtype_tupdesc, type_is_array, type_is_rowtype,
FormData_pg_attribute, InvalidOid, Oid,
getBaseType, get_element_type, get_extension_oid, lookup_rowtype_tupdesc, type_is_array,
type_is_rowtype, AsPgCStr, FormData_pg_attribute, InvalidOid, Oid,
},
PgTupleDesc,
};
Expand Down Expand Up @@ -99,3 +99,9 @@ pub(crate) fn domain_array_base_elem_typoid(domain_typoid: Oid) -> Oid {

array_element_typoid(base_array_typoid)
}

/// Reports whether an extension called `extension_name` can be resolved by
/// postgres.
///
/// The name is looked up with `get_extension_oid`; the extension counts as
/// existing when the returned oid is anything other than `InvalidOid`.
pub(crate) fn extension_exists(extension_name: &str) -> bool {
    // SAFETY: `as_pg_cstr` yields a C string built from a valid `&str`, and the
    // second argument (`missing_ok` in postgres) is `true`, so an unknown
    // extension is reported as `InvalidOid` instead of raising an error.
    // NOTE(review): assumes the caller runs inside a backend where catalog
    // access is allowed — confirm against call sites.
    let oid = unsafe { get_extension_oid(extension_name.as_pg_cstr(), true) };

    oid != InvalidOid
}

0 comments on commit 222dc23

Please sign in to comment.