diff --git a/sql/bootstrap.sql b/sql/bootstrap.sql new file mode 100644 index 0000000..97c85ca --- /dev/null +++ b/sql/bootstrap.sql @@ -0,0 +1,16 @@ +-- create roles for parquet object store read and write if they do not exist +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'parquet_object_store_read') THEN + CREATE ROLE parquet_object_store_read; + END IF; + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'parquet_object_store_write') THEN + CREATE ROLE parquet_object_store_write; + END IF; +END $$; + +-- create parquet schema if it does not exist +CREATE SCHEMA IF NOT EXISTS parquet; +REVOKE ALL ON SCHEMA parquet FROM public; +GRANT USAGE ON SCHEMA parquet TO public; +GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA parquet TO public; diff --git a/src/arrow_parquet/uri_utils.rs b/src/arrow_parquet/uri_utils.rs index 4c9315e..ef1c0eb 100644 --- a/src/arrow_parquet/uri_utils.rs +++ b/src/arrow_parquet/uri_utils.rs @@ -21,9 +21,13 @@ use parquet::{ }, file::properties::WriterProperties, }; +use pgrx::pg_sys::AsPgCStr; use crate::arrow_parquet::parquet_writer::DEFAULT_ROW_GROUP_SIZE; +const PARQUET_OBJECT_STORE_READ_ROLE: &str = "parquet_object_store_read"; +const PARQUET_OBJECT_STORE_WRITE_ROLE: &str = "parquet_object_store_write"; + #[derive(Debug, PartialEq)] enum UriFormat { File, @@ -57,18 +61,20 @@ fn parse_bucket_and_key(uri: &str) -> (String, String) { (bucket.to_string(), key.to_string()) } -async fn object_store_with_location(uri: &str) -> (Arc, Path) { +async fn object_store_with_location(uri: &str, read_only: bool) -> (Arc, Path) { let uri_format = UriFormat::from_str(uri).unwrap_or_else(|e| panic!("{}", e)); match uri_format { UriFormat::File => { - // create or overwrite the local file - std::fs::OpenOptions::new() - .write(true) - .truncate(true) - .create(true) - .open(uri) - .unwrap_or_else(|e| panic!("{}", e)); + if !read_only { + // create or overwrite the local file + std::fs::OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open(uri) + .unwrap_or_else(|e| panic!("{}", e)); + } let storage_container = Arc::new(LocalFileSystem::new()); @@ -77,6 +83,8 @@ async fn object_store_with_location(uri: &str) -> (Arc, Path) { (storage_container, location) } UriFormat::S3 => { + ensure_object_store_access(read_only); + let (bucket_name, key) = parse_bucket_and_key(uri); let storage_container = Arc::new(get_s3_object_store(&bucket_name).await); @@ -97,7 +105,8 @@ pub(crate) async fn parquet_schema_from_uri(uri: &str) -> SchemaDescriptor { } pub(crate) async fn parquet_metadata_from_uri(uri: &str) -> Arc { - let (parquet_object_store, location) = object_store_with_location(uri).await; + let read_only = true; + let (parquet_object_store, location) = object_store_with_location(uri, read_only).await; let object_store_meta = parquet_object_store .head(&location) @@ -116,7 +125,8 @@ pub(crate) async fn parquet_metadata_from_uri(uri: &str) -> Arc pub(crate) async fn parquet_reader_from_uri( uri: &str, ) -> ParquetRecordBatchStream { - let (parquet_object_store, location) = object_store_with_location(uri).await; + let read_only = true; + let (parquet_object_store, location) = object_store_with_location(uri, read_only).await; let object_store_meta = parquet_object_store .head(&location) @@ -142,7 +152,8 @@ pub(crate) async fn parquet_writer_from_uri( arrow_schema: SchemaRef, writer_props: WriterProperties, ) -> AsyncArrowWriter { - let (parquet_object_store, location) = object_store_with_location(uri).await; + let read_only = false; + let (parquet_object_store, location) = object_store_with_location(uri, read_only).await; let parquet_object_writer = ParquetObjectWriter::new(parquet_object_store, location); @@ -204,3 +215,29 @@ pub async fn get_s3_object_store(bucket_name: &str) -> AmazonS3 { aws_s3_builder.build().unwrap_or_else(|e| panic!("{}", e)) } + +fn ensure_object_store_access(read_only: bool) { + if unsafe { pgrx::pg_sys::superuser() } { + return; + } + + let user_id = unsafe { pgrx::pg_sys::GetUserId() }; + + let required_role_name = if read_only { + PARQUET_OBJECT_STORE_READ_ROLE + } else { + PARQUET_OBJECT_STORE_WRITE_ROLE + }; + + let required_role_id = + unsafe { pgrx::pg_sys::get_role_oid(required_role_name.to_string().as_pg_cstr(), false) }; + + let operation_str = if read_only { "read" } else { "write" }; + + if !unsafe { pgrx::pg_sys::has_privs_of_role(user_id, required_role_id) } { + panic!( + "current user does not have the role, named {}, to {} the bucket", + required_role_name, operation_str + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index c03fa87..c9c8dc7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ use parquet_copy_hook::hook::{init_parquet_copy_hook, ENABLE_PARQUET_COPY_HOOK}; -use pg_sys::MarkGUCPrefixReserved; +use pg_sys::{AsPgCStr, MarkGUCPrefixReserved}; use pgrx::{prelude::*, GucContext, GucFlags, GucRegistry}; mod arrow_parquet; @@ -17,7 +17,8 @@ pub use crate::parquet_copy_hook::copy_to_dest_receiver::create_copy_to_parquet_ pgrx::pg_module_magic!(); -#[allow(static_mut_refs)] +extension_sql_file!("../sql/bootstrap.sql", name = "role_setup", bootstrap); + #[pg_guard] pub extern "C" fn _PG_init() { GucRegistry::define_bool_guc( @@ -29,7 +30,7 @@ pub extern "C" fn _PG_init() { GucFlags::default(), ); - unsafe { MarkGUCPrefixReserved("pg_parquet".as_ptr() as _) }; + unsafe { MarkGUCPrefixReserved("pg_parquet".as_pg_cstr()) }; init_parquet_copy_hook(); } @@ -1173,6 +1174,94 @@ mod tests { test_helper(test_table); } + #[pg_test] + #[should_panic( + expected = "current user does not have the role, named parquet_object_store_read" + )] + fn test_s3_no_read_access() { + // create regular user + Spi::run("CREATE USER regular_user;").unwrap(); + + // grant write access to the regular user but not read access + Spi::run("GRANT parquet_object_store_write TO regular_user;").unwrap(); + + // grant all permissions for public schema + Spi::run("GRANT ALL ON SCHEMA public TO regular_user;").unwrap(); + + // set the current user to the regular user + Spi::run("SET SESSION AUTHORIZATION regular_user;").unwrap(); + + dotenvy::from_path("/tmp/.env").unwrap(); + + let test_bucket_name: String = + std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found"); + + let s3_uri = format!("s3://{}/pg_parquet_test.parquet", test_bucket_name); + + let test_table = TestTable::::new("int4".into()).with_uri(s3_uri.clone()); + + test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);"); + + // can write to s3 + let copy_to_command = format!( + "COPY (SELECT a FROM generate_series(1,10) a) TO '{}';", + s3_uri + ); + Spi::run(copy_to_command.as_str()).unwrap(); + + // cannot read from s3 + let copy_from_command = format!("COPY test_expected FROM '{}';", s3_uri); + Spi::run(copy_from_command.as_str()).unwrap(); + } + + #[pg_test] + #[should_panic( + expected = "current user does not have the role, named parquet_object_store_write" + )] + fn test_s3_no_write_access() { + // create regular user + Spi::run("CREATE USER regular_user;").unwrap(); + + // grant read access to the regular user but not write access + Spi::run("GRANT parquet_object_store_read TO regular_user;").unwrap(); + + // grant usage access to parquet schema and its udfs + Spi::run("GRANT USAGE ON SCHEMA parquet TO regular_user;").unwrap(); + Spi::run("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA parquet TO regular_user;").unwrap(); + + // grant all permissions for public schema + Spi::run("GRANT ALL ON SCHEMA public TO regular_user;").unwrap(); + + // set the current user to the regular user + Spi::run("SET SESSION AUTHORIZATION regular_user;").unwrap(); + + dotenvy::from_path("/tmp/.env").unwrap(); + + let test_bucket_name: String = + std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found"); + + let s3_uri = format!("s3://{}/pg_parquet_test.parquet", test_bucket_name); + + // can call metadata udf (requires read access) + let metadata_query = format!("SELECT parquet.metadata('{}');", s3_uri.clone()); + Spi::run(&metadata_query).unwrap(); + + let test_table = TestTable::::new("int4".into()).with_uri(s3_uri.clone()); + + test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);"); + + // can read from s3 + let copy_from_command = format!("COPY test_expected FROM '{}';", s3_uri); + Spi::run(copy_from_command.as_str()).unwrap(); + + // cannot write to s3 + let copy_to_command = format!( + "COPY (SELECT a FROM generate_series(1,10) a) TO '{}';", + s3_uri + ); + Spi::run(copy_to_command.as_str()).unwrap(); + } + #[pg_test] #[should_panic(expected = "404 Not Found")] fn test_s3_object_store_write_invalid_uri() {