diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs
index a3cd53b..3d48f78 100644
--- a/src/arrow_parquet/parquet_reader.rs
+++ b/src/arrow_parquet/parquet_reader.rs
@@ -33,13 +33,13 @@ pub(crate) struct ParquetReaderContext {
 }
 
 impl ParquetReaderContext {
-    pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self {
+    pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc, check_perms: bool) -> Self {
         // Postgis and Map contexts are used throughout reading the parquet file.
         // We need to reset them to avoid reading the stale data. (e.g. extension could be dropped)
         reset_postgis_context();
         reset_map_context();
 
-        let parquet_reader = parquet_reader_from_uri(&uri);
+        let parquet_reader = parquet_reader_from_uri(&uri, check_perms);
 
         let schema = parquet_reader.schema();
         ensure_arrow_schema_match_tupledesc(schema.clone(), tupledesc);
diff --git a/src/arrow_parquet/parquet_writer.rs b/src/arrow_parquet/parquet_writer.rs
index a7b9ef4..845b785 100644
--- a/src/arrow_parquet/parquet_writer.rs
+++ b/src/arrow_parquet/parquet_writer.rs
@@ -37,6 +37,7 @@ impl ParquetWriterContext {
         compression: PgParquetCompression,
         compression_level: i32,
         tupledesc: &PgTupleDesc,
+        check_perms: bool,
     ) -> ParquetWriterContext {
         debug_assert!(tupledesc.oid() == RECORDOID);
 
@@ -60,7 +61,8 @@ impl ParquetWriterContext {
         let schema = parse_arrow_schema_from_tupledesc(tupledesc);
         let schema = Arc::new(schema);
 
-        let parquet_writer = parquet_writer_from_uri(&uri, schema.clone(), writer_props);
+        let parquet_writer =
+            parquet_writer_from_uri(&uri, schema.clone(), writer_props, check_perms);
 
         let attribute_contexts =
             collect_pg_to_arrow_attribute_contexts(tupledesc, &schema.fields);
diff --git a/src/arrow_parquet/uri_utils.rs b/src/arrow_parquet/uri_utils.rs
index 745058f..232ccad 100644
--- a/src/arrow_parquet/uri_utils.rs
+++ b/src/arrow_parquet/uri_utils.rs
@@ -56,10 +56,16 @@ fn parse_bucket_and_key(uri: &Url) -> (String, String) {
     (bucket.to_string(), key.to_string())
 }
 
-fn object_store_with_location(uri: &Url, copy_from: bool) -> (Arc<dyn ObjectStore>, Path) {
-    if uri.scheme() == "s3" {
-        ensure_object_store_access_privilege(copy_from);
+fn object_store_with_location(
+    uri: &Url,
+    copy_from: bool,
+    check_perms: bool,
+) -> (Arc<dyn ObjectStore>, Path) {
+    if check_perms {
+        ensure_access_privilege_to_uri(&uri, copy_from);
+    }
 
+    if uri.scheme() == "s3" {
         let (bucket_name, key) = parse_bucket_and_key(uri);
 
         let storage_container = PG_BACKEND_TOKIO_RUNTIME
@@ -71,8 +77,6 @@ fn object_store_with_location(uri: &Url, copy_from: bool) -> (Arc<dyn ObjectStore>, Path) {
         (storage_container, location)
     } else {
-        ensure_local_file_access_privilege(copy_from);
-
         let uri = uri_as_string(uri);
@@ -…,… +…,… @@
 pub(crate) fn uri_as_string(uri: &Url) -> String {
     uri.to_string()
 }
 
-pub(crate) fn parquet_schema_from_uri(uri: &Url) -> SchemaDescriptor {
-    let parquet_reader = parquet_reader_from_uri(uri);
+pub(crate) fn parquet_schema_from_uri(uri: &Url, check_perms: bool) -> SchemaDescriptor {
+    let parquet_reader = parquet_reader_from_uri(uri, check_perms);
 
     let arrow_schema = parquet_reader.schema();
 
     arrow_to_parquet_schema(arrow_schema).unwrap_or_else(|e| panic!("{}", e))
 }
 
-pub(crate) fn parquet_metadata_from_uri(uri: &Url) -> Arc<ParquetMetaData> {
+pub(crate) fn parquet_metadata_from_uri(uri: &Url, check_perms: bool) -> Arc<ParquetMetaData> {
     let copy_from = true;
 
-    let (parquet_object_store, location) = object_store_with_location(uri, copy_from);
+    let (parquet_object_store, location) = object_store_with_location(uri, copy_from, check_perms);
 
     PG_BACKEND_TOKIO_RUNTIME.block_on(async {
         let object_store_meta = parquet_object_store
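The hunks above converge every read and write path on the same two knobs: `copy_from` selects which privilege is required, and `check_perms` selects whether the check runs at all, with `object_store_with_location` as the single choke point. A condensed, self-contained sketch of the resulting call graph (hypothetical function names, Postgres types elided; not the extension's actual API):

```rust
// Condensed model of the call graph after this change: the privilege check
// lives in one place and is driven by an explicit flag from each entry point.
fn ensure_access(uri: &str, copy_from: bool) {
    let op = if copy_from { "read" } else { "write" };
    println!("checking {op} access to {uri}");
}

fn store_for(uri: &str, copy_from: bool, check_perms: bool) {
    if check_perms {
        ensure_access(uri, copy_from); // the only place the check happens
    }
    // ... resolve an object store (s3) or a local path (file), as before ...
}

fn reader_from(uri: &str, check_perms: bool) {
    store_for(uri, true, check_perms); // every reader implies copy_from = true
}

fn writer_from(uri: &str, check_perms: bool) {
    store_for(uri, false, check_perms); // every writer implies copy_from = false
}

fn main() {
    reader_from("s3://bucket/data.parquet", true); // e.g. a UDF: always checks
    writer_from("file:///tmp/out.parquet", false); // e.g. COPY TO: hook checked already
}
```

Because the flag has no default, each caller must state whether the check already happened upstream; the call-site changes below do exactly that.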
@@ -209,9 +213,12 @@ pub(crate) fn parquet_metadata_from_uri(uri: &Url) -> Arc<ParquetMetaData> {
     })
 }
 
-pub(crate) fn parquet_reader_from_uri(uri: &Url) -> ParquetRecordBatchStream<ParquetObjectReader> {
+pub(crate) fn parquet_reader_from_uri(
+    uri: &Url,
+    check_perms: bool,
+) -> ParquetRecordBatchStream<ParquetObjectReader> {
     let copy_from = true;
 
-    let (parquet_object_store, location) = object_store_with_location(uri, copy_from);
+    let (parquet_object_store, location) = object_store_with_location(uri, copy_from, check_perms);
 
     PG_BACKEND_TOKIO_RUNTIME.block_on(async {
         let object_store_meta = parquet_object_store
@@ -241,9 +248,10 @@ pub(crate) fn parquet_writer_from_uri(
     uri: &Url,
     arrow_schema: SchemaRef,
     writer_props: WriterProperties,
+    check_perms: bool,
 ) -> AsyncArrowWriter<ParquetObjectWriter> {
     let copy_from = false;
 
-    let (parquet_object_store, location) = object_store_with_location(uri, copy_from);
+    let (parquet_object_store, location) = object_store_with_location(uri, copy_from, check_perms);
 
     let parquet_object_writer = ParquetObjectWriter::new(parquet_object_store, location);
@@ -251,14 +259,21 @@ pub(crate) fn parquet_writer_from_uri(
         .unwrap_or_else(|e| panic!("failed to create parquet writer for uri {}: {}", uri, e))
 }
 
-fn ensure_object_store_access_privilege(copy_from: bool) {
+pub(crate) fn ensure_access_privilege_to_uri(uri: &Url, copy_from: bool) {
     if unsafe { superuser() } {
         return;
     }
 
     let user_id = unsafe { GetUserId() };
+    let is_file = uri.scheme() == "file";
 
-    let required_role_name = if copy_from {
+    let required_role_name = if is_file {
+        if copy_from {
+            "pg_read_server_files"
+        } else {
+            "pg_write_server_files"
+        }
+    } else if copy_from {
         PARQUET_OBJECT_STORE_READ_ROLE
     } else {
         PARQUET_OBJECT_STORE_WRITE_ROLE
@@ -268,46 +283,19 @@ fn ensure_object_store_access_privilege(copy_from: bool) {
         unsafe { get_role_oid(required_role_name.to_string().as_pg_cstr(), false) };
 
     let operation_str = if copy_from { "from" } else { "to" };
+    let object_type = if is_file { "file" } else { "remote uri" };
 
     if !unsafe { has_privs_of_role(user_id, required_role_id) } {
         ereport!(
             pgrx::PgLogLevel::ERROR,
             pgrx::PgSqlErrorCode::ERRCODE_INSUFFICIENT_PRIVILEGE,
-            format!("permission denied to COPY {} a remote uri", operation_str),
             format!(
-                "Only roles with privileges of the \"{}\" role may COPY {} a remote uri.",
-                required_role_name, operation_str
+                "permission denied to COPY {} a {}",
+                operation_str, object_type
             ),
-        );
-    }
-}
-
-fn ensure_local_file_access_privilege(copy_from: bool) {
-    if unsafe { superuser() } {
-        return;
-    }
-
-    let user_id = unsafe { GetUserId() };
-
-    let required_role_name = if copy_from {
-        "pg_read_server_files"
-    } else {
-        "pg_write_server_files"
-    };
-
-    let required_role_id =
-        unsafe { get_role_oid(required_role_name.to_string().as_pg_cstr(), false) };
-
-    let operation_str = if copy_from { "from" } else { "to" };
-
-    if !unsafe { has_privs_of_role(user_id, required_role_id) } {
-        ereport!(
-            pgrx::PgLogLevel::ERROR,
-            pgrx::PgSqlErrorCode::ERRCODE_INSUFFICIENT_PRIVILEGE,
-            format!("permission denied to COPY {} a file", operation_str),
             format!(
-                "Only roles with privileges of the \"{}\" role may COPY {} a file.",
-                required_role_name, operation_str
+                "Only roles with privileges of the \"{}\" role may COPY {} a {}.",
+                required_role_name, operation_str, object_type
             ),
         );
     }
diff --git a/src/lib.rs b/src/lib.rs
index fdba58c..ce61540 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2672,6 +2672,9 @@ mod tests {
         let create_role = "create role test_role;";
         Spi::run(create_role).unwrap();
 
+        let grant_role = "grant pg_write_server_files TO test_role;";
+        Spi::run(grant_role).unwrap();
+
         let set_role = "set role test_role;";
         Spi::run(set_role).unwrap();
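With the two role checks merged into `ensure_access_privilege_to_uri`, the privilege logic reduces to a four-way decision table plus the `superuser()` bypass. A standalone sketch of just the role selection; the two object-store strings below are placeholders for the crate's `PARQUET_OBJECT_STORE_READ_ROLE` and `PARQUET_OBJECT_STORE_WRITE_ROLE` constants, whose concrete role names are defined elsewhere in the crate:

```rust
// Standalone model of the role selection in ensure_access_privilege_to_uri.
// The object-store entries are placeholders for the crate's
// PARQUET_OBJECT_STORE_READ_ROLE / PARQUET_OBJECT_STORE_WRITE_ROLE constants.
fn required_role(is_file: bool, copy_from: bool) -> &'static str {
    match (is_file, copy_from) {
        (true, true) => "pg_read_server_files",   // COPY FROM a local file
        (true, false) => "pg_write_server_files", // COPY TO a local file
        (false, true) => "PARQUET_OBJECT_STORE_READ_ROLE",
        (false, false) => "PARQUET_OBJECT_STORE_WRITE_ROLE",
    }
}

fn main() {
    // The updated test above grants pg_write_server_files because test_role
    // COPYs to a local file.
    assert_eq!(required_role(true, false), "pg_write_server_files");
}
```

Superusers return early from the check, so only non-superuser roles need one of these grants.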
diff --git a/src/parquet_copy_hook/copy_from.rs b/src/parquet_copy_hook/copy_from.rs
index f1fdab9..66cae17 100644
--- a/src/parquet_copy_hook/copy_from.rs
+++ b/src/parquet_copy_hook/copy_from.rs
@@ -133,7 +133,7 @@ pub(crate) fn execute_copy_from(
     unsafe {
         // parquet reader context is used throughout the COPY FROM operation.
-        let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc);
+        let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc, true);
         push_parquet_reader_context(parquet_reader_context);
 
         // makes sure to set binary format
diff --git a/src/parquet_copy_hook/copy_to_dest_receiver.rs b/src/parquet_copy_hook/copy_to_dest_receiver.rs
index 07fe504..e03aed2 100644
--- a/src/parquet_copy_hook/copy_to_dest_receiver.rs
+++ b/src/parquet_copy_hook/copy_to_dest_receiver.rs
@@ -203,7 +203,7 @@ extern "C" fn copy_startup(dest: *mut DestReceiver, _operation: i32, tupledesc:
 
     // parquet writer context is used throughout the COPY TO operation.
    // This might be put into ParquetCopyDestReceiver, but it's hard to preserve repr(C).
     let parquet_writer_context =
-        ParquetWriterContext::new(uri, compression, compression_level, &tupledesc);
+        ParquetWriterContext::new(uri, compression, compression_level, &tupledesc, false);
     push_parquet_writer_context(parquet_writer_context);
 }
diff --git a/src/parquet_copy_hook/hook.rs b/src/parquet_copy_hook/hook.rs
index 8d4cea0..1329817 100644
--- a/src/parquet_copy_hook/hook.rs
+++ b/src/parquet_copy_hook/hook.rs
@@ -7,7 +7,10 @@ use pg_sys::{
 use pgrx::{prelude::*, GucSetting};
 
 use crate::{
-    arrow_parquet::{compression::INVALID_COMPRESSION_LEVEL, uri_utils::uri_as_string},
+    arrow_parquet::{
+        compression::INVALID_COMPRESSION_LEVEL,
+        uri_utils::{ensure_access_privilege_to_uri, uri_as_string},
+    },
     parquet_copy_hook::{
         copy_to_dest_receiver::create_copy_to_parquet_dest_receiver,
         copy_utils::{
@@ -59,6 +62,9 @@ extern "C" fn parquet_copy_hook(
     if ENABLE_PARQUET_COPY_HOOK.get() && is_copy_to_parquet_stmt(&p_stmt) {
         let uri = copy_stmt_uri(&p_stmt).expect("uri is None");
 
+        let copy_from = false;
+
+        ensure_access_privilege_to_uri(&uri, copy_from);
 
         validate_copy_to_options(&p_stmt, &uri);
@@ -106,6 +112,9 @@ extern "C" fn parquet_copy_hook(
         return;
     } else if ENABLE_PARQUET_COPY_HOOK.get() && is_copy_from_parquet_stmt(&p_stmt) {
         let uri = copy_stmt_uri(&p_stmt).expect("uri is None");
 
+        let copy_from = true;
+
+        ensure_access_privilege_to_uri(&uri, copy_from);
 
         validate_copy_from_options(&p_stmt);
diff --git a/src/parquet_udfs/metadata.rs b/src/parquet_udfs/metadata.rs
index 4d167c6..3bd1aad 100644
--- a/src/parquet_udfs/metadata.rs
+++ b/src/parquet_udfs/metadata.rs
@@ -39,7 +39,7 @@ mod parquet {
     > {
         let uri = parse_uri(&uri);
 
-        let parquet_metadata = parquet_metadata_from_uri(&uri);
+        let parquet_metadata = parquet_metadata_from_uri(&uri, true);
 
         let mut rows = vec![];
 
@@ -137,7 +137,7 @@ mod parquet {
     > {
         let uri = parse_uri(&uri);
 
-        let parquet_metadata = parquet_metadata_from_uri(&uri);
+        let parquet_metadata = parquet_metadata_from_uri(&uri, true);
 
         let created_by = parquet_metadata
             .file_metadata()
@@ -174,7 +174,7 @@ mod parquet {
     > {
         let uri = parse_uri(&uri);
 
-        let parquet_metadata = parquet_metadata_from_uri(&uri);
+        let parquet_metadata = parquet_metadata_from_uri(&uri, true);
 
         let kv_metadata = parquet_metadata.file_metadata().key_value_metadata();
 
         if kv_metadata.is_none() {
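The call sites fall into two camps: the COPY TO dest receiver passes `check_perms: false` because `parquet_copy_hook` has already called `ensure_access_privilege_to_uri` before execution, while the SQL-callable UDFs pass `true` since no hook guards them. A hypothetical regression test in the style of the existing tests in `lib.rs` (role and path names invented, and assuming the metadata UDF is exposed as `parquet.metadata`); the expected message follows the new `format!` in `uri_utils.rs`:

```rust
#[pg_test]
#[should_panic(expected = "permission denied to COPY from a file")]
fn test_metadata_udf_requires_read_role() {
    // Hypothetical test: a role without pg_read_server_files must not be able
    // to inspect a local parquet file through the metadata UDF.
    Spi::run("create role no_priv_role;").unwrap();
    Spi::run("set role no_priv_role;").unwrap();
    Spi::run("select * from parquet.metadata('/tmp/test.parquet');").unwrap();
}
```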
diff --git a/src/parquet_udfs/schema.rs b/src/parquet_udfs/schema.rs
index 7e5c68f..8039aa2 100644
--- a/src/parquet_udfs/schema.rs
+++ b/src/parquet_udfs/schema.rs
@@ -32,7 +32,7 @@ mod parquet {
     > {
         let uri = parse_uri(&uri);
 
-        let parquet_schema = parquet_schema_from_uri(&uri);
+        let parquet_schema = parquet_schema_from_uri(&uri, true);
 
         let root_type = parquet_schema.root_schema();
 
         let thrift_schema_elements = to_thrift(root_type).unwrap_or_else(|e| {
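For the write side, the updated test in `lib.rs` above shows the minimum grant a non-superuser now needs. A hypothetical read-side companion, again following the `lib.rs` test style (table, role, and file names invented; assumes `/tmp/test.parquet` already exists and matches the table's schema):

```rust
#[pg_test]
fn test_copy_from_file_with_read_role() {
    // Hypothetical companion to the updated COPY TO test: pg_read_server_files
    // is the role required to COPY FROM a local parquet file.
    Spi::run("create table dst (a int);").unwrap();
    Spi::run("create role reader_role;").unwrap();
    Spi::run("grant pg_read_server_files to reader_role;").unwrap();
    Spi::run("grant insert on dst to reader_role;").unwrap();
    Spi::run("set role reader_role;").unwrap();
    Spi::run("copy dst from '/tmp/test.parquet';").unwrap();
}
```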