diff --git a/src/parquet_copy_hook.rs b/src/parquet_copy_hook.rs index 6d48c12..9dc0c8e 100644 --- a/src/parquet_copy_hook.rs +++ b/src/parquet_copy_hook.rs @@ -1 +1,3 @@ +pub(crate) mod copy_to; +pub(crate) mod copy_to_dest_receiver; pub(crate) mod copy_utils; diff --git a/src/parquet_copy_hook/copy_to.rs b/src/parquet_copy_hook/copy_to.rs new file mode 100644 index 0000000..82c14f1 --- /dev/null +++ b/src/parquet_copy_hook/copy_to.rs @@ -0,0 +1,184 @@ +use pgrx::{ + is_a, + pg_sys::{ + self, makeRangeVar, pg_analyze_and_rewrite_fixedparams, pg_plan_query, CommandTag, + CopyStmt, CreateNewPortal, DestReceiver, GetActiveSnapshot, NodeTag::T_CopyStmt, + PlannedStmt, PortalDefineQuery, PortalDrop, PortalRun, PortalStart, QueryCompletion, + RawStmt, CURSOR_OPT_PARALLEL_OK, + }, + AllocatedByRust, PgBox, PgList, PgRelation, +}; + +use crate::parquet_copy_hook::copy_utils::{ + copy_stmt_has_relation, copy_stmt_lock_mode, copy_stmt_relation_oid, +}; + +// execute_copy_to_with_dest_receiver executes a COPY TO statement with our custom DestReceiver +// for writing to Parquet files. 
+// - converts the table relation to a SELECT statement if necessary +// - analyzes and rewrites the raw query +// - plans the rewritten query +// - creates a portal for the planned query by using the custom DestReceiver +// - executes the query with the portal +pub(crate) fn execute_copy_to_with_dest_receiver( + pstmt: &PgBox<PlannedStmt>, + query_string: &core::ffi::CStr, + params: PgBox<pg_sys::ParamListInfoData>, + query_env: PgBox<pg_sys::QueryEnvironment>, + parquet_dest: PgBox<DestReceiver>, +) -> u64 { + unsafe { + debug_assert!(is_a(pstmt.utilityStmt, T_CopyStmt)); + let copy_stmt = PgBox::<CopyStmt>::from_pg(pstmt.utilityStmt as _); + + let mut relation = PgRelation::from_pg(std::ptr::null_mut()); + + if copy_stmt_has_relation(pstmt) { + let rel_oid = copy_stmt_relation_oid(pstmt); + let lock_mode = copy_stmt_lock_mode(pstmt); + relation = PgRelation::with_lock(rel_oid, lock_mode); + } + + let raw_query = prepare_copy_to_raw_stmt(pstmt, &copy_stmt, &relation); + + let rewritten_queries = pg_analyze_and_rewrite_fixedparams( + raw_query.as_ptr(), + query_string.as_ptr(), + std::ptr::null_mut(), + 0, + query_env.as_ptr(), + ); + + let query = PgList::from_pg(rewritten_queries) + .pop() + .expect("rewritten query is empty"); + + let plan = pg_plan_query( + query, + std::ptr::null(), + CURSOR_OPT_PARALLEL_OK as _, + params.as_ptr(), + ); + + let portal = CreateNewPortal(); + let mut portal = PgBox::from_pg(portal); + portal.visible = false; + + let mut plans = PgList::<PlannedStmt>::new(); + plans.push(plan); + + PortalDefineQuery( + portal.as_ptr(), + std::ptr::null(), + query_string.as_ptr(), + CommandTag::CMDTAG_COPY, + plans.as_ptr(), + std::ptr::null_mut(), + ); + + PortalStart(portal.as_ptr(), params.as_ptr(), 0, GetActiveSnapshot()); + + let mut completion_tag = QueryCompletion { + commandTag: CommandTag::CMDTAG_COPY, + nprocessed: 0, + }; + + PortalRun( + portal.as_ptr(), + i64::MAX, + false, + true, + parquet_dest.as_ptr(), + parquet_dest.as_ptr(), + &mut completion_tag as _, + ); + + PortalDrop(portal.as_ptr(), false); + + completion_tag.nprocessed + } +}
+ +fn prepare_copy_to_raw_stmt( + pstmt: &PgBox<PlannedStmt>, + copy_stmt: &PgBox<CopyStmt>, + relation: &PgRelation, +) -> PgBox<RawStmt> { + let mut raw_query = unsafe { PgBox::<RawStmt>::alloc_node(pg_sys::NodeTag::T_RawStmt) }; + raw_query.stmt_location = pstmt.stmt_location; + raw_query.stmt_len = pstmt.stmt_len; + + if relation.is_null() { + raw_query.stmt = copy_stmt.query; + } else { + // convert relation to query + let mut target_list = PgList::new(); + + if copy_stmt.attlist.is_null() { + // SELECT * FROM relation + let mut col_ref = + unsafe { PgBox::<pg_sys::ColumnRef>::alloc_node(pg_sys::NodeTag::T_ColumnRef) }; + let a_star = unsafe { PgBox::<pg_sys::A_Star>::alloc_node(pg_sys::NodeTag::T_A_Star) }; + + let mut field_list = PgList::new(); + field_list.push(a_star.into_pg()); + + col_ref.fields = field_list.into_pg(); + col_ref.location = -1; + + let mut target = + unsafe { PgBox::<pg_sys::ResTarget>::alloc_node(pg_sys::NodeTag::T_ResTarget) }; + target.name = std::ptr::null_mut(); + target.indirection = std::ptr::null_mut(); + target.val = col_ref.into_pg() as _; + target.location = -1; + + target_list.push(target.into_pg()); + } else { + // SELECT a,b,...
FROM relation + let attlist = unsafe { PgList::<*mut i8>::from_pg(copy_stmt.attlist) }; + for attname in attlist.iter_ptr() { + let mut col_ref = + unsafe { PgBox::<pg_sys::ColumnRef>::alloc_node(pg_sys::NodeTag::T_ColumnRef) }; + + let mut field_list = PgList::new(); + field_list.push(unsafe { *attname }); + + col_ref.fields = field_list.into_pg(); + col_ref.location = -1; + + let mut target = + unsafe { PgBox::<pg_sys::ResTarget>::alloc_node(pg_sys::NodeTag::T_ResTarget) }; + target.name = std::ptr::null_mut(); + target.indirection = std::ptr::null_mut(); + target.val = col_ref.into_pg() as _; + target.location = -1; + + target_list.push(target.into_pg()); + } + } + + let from = unsafe { + makeRangeVar( + relation.namespace().as_ptr() as _, + relation.name().as_ptr() as _, + -1, + ) + }; + let mut from = unsafe { PgBox::from_pg(from) }; + from.inh = false; + + let mut select_stmt = + unsafe { PgBox::<pg_sys::SelectStmt>::alloc_node(pg_sys::NodeTag::T_SelectStmt) }; + + select_stmt.targetList = target_list.into_pg(); + + let mut from_list = PgList::new(); + from_list.push(from.into_pg()); + select_stmt.fromClause = from_list.into_pg(); + + raw_query.stmt = select_stmt.into_pg() as _; + } + + raw_query +} diff --git a/src/parquet_copy_hook/copy_to_dest_receiver.rs b/src/parquet_copy_hook/copy_to_dest_receiver.rs new file mode 100644 index 0000000..2c38f5f --- /dev/null +++ b/src/parquet_copy_hook/copy_to_dest_receiver.rs @@ -0,0 +1,293 @@ +use pg_sys::{ + slot_getallattrs, AsPgCStr, BlessTupleDesc, CommandDest, CurrentMemoryContext, Datum, + DestReceiver, HeapTupleData, List, MemoryContext, TupleDesc, TupleTableSlot, +}; +use pgrx::{prelude::*, PgList, PgMemoryContexts, PgTupleDesc}; + +use crate::{ + arrow_parquet::{ + codec::ParquetCodecOption, parquet_writer::ParquetWriterContext, + schema_visitor::parquet_schema_string_from_tupledesc, uri_utils::parse_uri, + }, + parquet_copy_hook::copy_utils::tuple_column_sizes, + pgrx_utils::collect_valid_attributes, +}; + +#[repr(C)] +struct CopyToParquetDestReceiver { + dest: 
DestReceiver, + uri: *mut i8, + tupledesc: TupleDesc, + tuple_count: i64, + tuples: *mut List, + natts: usize, + column_sizes: *mut i64, + codec: ParquetCodecOption, + row_group_size: i64, + per_copy_context: MemoryContext, +} + +impl CopyToParquetDestReceiver { + fn collect_tuple(&mut self, tuple: PgHeapTuple<AllocatedByRust>, tuple_column_sizes: Vec<i32>) { + let mut tuples = unsafe { PgList::from_pg(self.tuples) }; + tuples.push(tuple.into_pg()); + self.tuples = tuples.into_pg(); + + let column_sizes = unsafe { std::slice::from_raw_parts_mut(self.column_sizes, self.natts) }; + column_sizes + .iter_mut() + .zip(tuple_column_sizes.iter()) + .for_each(|(a, b)| *a += *b as i64); + + self.tuple_count += 1; + } + + fn reset_tuples(&mut self) { + unsafe { pg_sys::MemoryContextReset(self.per_copy_context) }; + + self.tuple_count = 0; + self.tuples = PgList::<HeapTupleData>::new().into_pg(); + self.column_sizes = unsafe { + pg_sys::MemoryContextAllocZero( + self.per_copy_context, + std::mem::size_of::<i64>() * self.natts, + ) as *mut i64 + }; + } + + fn exceeds_row_group_size(&self) -> bool { + self.tuple_count >= self.row_group_size + } + + fn exceeds_max_col_size(&self, tuple_column_sizes: &[i32]) -> bool { + let column_sizes = unsafe { std::slice::from_raw_parts(self.column_sizes, self.natts) }; + column_sizes + .iter() + .zip(tuple_column_sizes) + .map(|(a, b)| *a + *b as i64) + .any(|size| size > i32::MAX as i64) + } + + fn write_tuples_to_parquet(&mut self) { + debug_assert!(!self.tupledesc.is_null()); + + let tupledesc = unsafe { PgTupleDesc::from_pg(self.tupledesc) }; + + let tuples = unsafe { PgList::from_pg(self.tuples) }; + let tuples = tuples + .iter_ptr() + .map(|tup_ptr: *mut HeapTupleData| unsafe { + if tup_ptr.is_null() { + None + } else { + let tup = PgHeapTuple::from_heap_tuple(tupledesc.clone(), tup_ptr).into_owned(); + Some(tup) + } + }) + .collect::<Vec<_>>(); + + pgrx::debug2!( + "schema for tuples: {}", + parquet_schema_string_from_tupledesc(&tupledesc) + ); + + let current_parquet_writer_context 
= + peek_parquet_writer_context().expect("parquet writer context is not found"); + current_parquet_writer_context.write_new_row_group(tuples); + + self.reset_tuples(); + } + + fn cleanup(&mut self) { + unsafe { pg_sys::MemoryContextDelete(self.per_copy_context) }; + } +} + +// stack to store parquet writer contexts for COPY TO. +// This needs to be a stack since COPY command can be nested. +static mut PARQUET_WRITER_CONTEXT_STACK: Vec<ParquetWriterContext> = vec![]; + +pub(crate) fn peek_parquet_writer_context() -> Option<&'static mut ParquetWriterContext> { + unsafe { PARQUET_WRITER_CONTEXT_STACK.last_mut() } +} + +pub(crate) fn pop_parquet_writer_context(throw_error: bool) -> Option<ParquetWriterContext> { + let mut current_parquet_writer_context = unsafe { PARQUET_WRITER_CONTEXT_STACK.pop() }; + + if current_parquet_writer_context.is_none() { + let level = if throw_error { + PgLogLevel::ERROR + } else { + PgLogLevel::DEBUG2 + }; + + ereport!( + level, + PgSqlErrorCode::ERRCODE_INTERNAL_ERROR, + "parquet writer context stack is already empty" + ); + + None + } else { + current_parquet_writer_context.take() + } +} + +pub(crate) fn push_parquet_writer_context(writer_ctx: ParquetWriterContext) { + unsafe { PARQUET_WRITER_CONTEXT_STACK.push(writer_ctx) }; +} + +#[pg_guard] +extern "C" fn copy_startup(dest: *mut DestReceiver, _operation: i32, tupledesc: TupleDesc) { + let parquet_dest = unsafe { + (dest as *mut CopyToParquetDestReceiver) + .as_mut() + .expect("invalid parquet dest receiver ptr") + }; + + // bless tupledesc, otherwise lookup_row_tupledesc would fail for row types + let tupledesc = unsafe { BlessTupleDesc(tupledesc) }; + let tupledesc = unsafe { PgTupleDesc::from_pg(tupledesc) }; + + let include_generated_columns = true; + let attributes = collect_valid_attributes(&tupledesc, include_generated_columns); + + // update the parquet dest receiver's missing fields + parquet_dest.tupledesc = tupledesc.as_ptr(); + parquet_dest.tuples = PgList::<HeapTupleData>::new().into_pg(); + parquet_dest.column_sizes = unsafe { +
 pg_sys::MemoryContextAllocZero( + parquet_dest.per_copy_context, + std::mem::size_of::<i64>() * attributes.len(), + ) as *mut i64 + }; + parquet_dest.natts = attributes.len(); + + let uri = unsafe { std::ffi::CStr::from_ptr(parquet_dest.uri) } + .to_str() + .expect("uri is not a valid C string"); + + let uri = parse_uri(uri); + + let codec = parquet_dest.codec; + + // parquet writer context is used throughout the COPY TO operation. + // This might be put into ParquetCopyDestReceiver, but it's hard to preserve repr(C). + let parquet_writer_context = ParquetWriterContext::new(uri, codec, &tupledesc); + push_parquet_writer_context(parquet_writer_context); +} + +#[pg_guard] +extern "C" fn copy_receive(slot: *mut TupleTableSlot, dest: *mut DestReceiver) -> bool { + let parquet_dest = unsafe { + (dest as *mut CopyToParquetDestReceiver) + .as_mut() + .expect("invalid parquet dest receiver ptr") + }; + + unsafe { + let mut per_copy_ctx = PgMemoryContexts::For(parquet_dest.per_copy_context); + + per_copy_ctx.switch_to(|_context| { + // extracts all attributes in statement "SELECT * FROM table" + slot_getallattrs(slot); + + let slot = PgBox::from_pg(slot); + + let natts = parquet_dest.natts; + + let datums = slot.tts_values; + let datums = std::slice::from_raw_parts(datums, natts); + + let nulls = slot.tts_isnull; + let nulls = std::slice::from_raw_parts(nulls, natts); + + let datums: Vec<Option<Datum>> = datums + .iter() + .zip(nulls) + .map(|(datum, is_null)| if *is_null { None } else { Some(*datum) }) + .collect(); + + let tupledesc = PgTupleDesc::from_pg(parquet_dest.tupledesc); + + let column_sizes = tuple_column_sizes(&datums, &tupledesc); + + if parquet_dest.exceeds_max_col_size(&column_sizes) { + parquet_dest.write_tuples_to_parquet(); + } + + let heap_tuple = PgHeapTuple::from_datums(tupledesc, datums) + .unwrap_or_else(|e| panic!("failed to create heap tuple from datums: {}", e)); + + parquet_dest.collect_tuple(heap_tuple, column_sizes); + + if parquet_dest.exceeds_row_group_size() { 
+ parquet_dest.write_tuples_to_parquet(); + } + }); + }; + + true +} + +#[pg_guard] +extern "C" fn copy_shutdown(dest: *mut DestReceiver) { + let parquet_dest = unsafe { + (dest as *mut CopyToParquetDestReceiver) + .as_mut() + .expect("invalid parquet dest receiver ptr") + }; + + if parquet_dest.tuple_count > 0 { + parquet_dest.write_tuples_to_parquet(); + } + + parquet_dest.cleanup(); + + let throw_error = true; + let current_parquet_writer_context = pop_parquet_writer_context(throw_error); + current_parquet_writer_context + .expect("current parquet writer context is not found") + .close(); +} + +#[pg_guard] +extern "C" fn copy_destroy(_dest: *mut DestReceiver) {} + +#[pg_guard] +#[no_mangle] +pub extern "C" fn create_copy_to_parquet_dest_receiver( + uri: *mut i8, + row_group_size: i64, + codec: ParquetCodecOption, +) -> *mut DestReceiver { + let per_copy_context = unsafe { + pg_sys::AllocSetContextCreateExtended( + CurrentMemoryContext as _, + "ParquetCopyDestReceiver".as_pg_cstr(), + pg_sys::ALLOCSET_DEFAULT_MINSIZE as _, + pg_sys::ALLOCSET_DEFAULT_INITSIZE as _, + pg_sys::ALLOCSET_DEFAULT_MAXSIZE as _, + ) + }; + + let mut parquet_dest = + unsafe { PgBox::<CopyToParquetDestReceiver>::alloc0() }; + + parquet_dest.dest.receiveSlot = Some(copy_receive); + parquet_dest.dest.rStartup = Some(copy_startup); + parquet_dest.dest.rShutdown = Some(copy_shutdown); + parquet_dest.dest.rDestroy = Some(copy_destroy); + parquet_dest.dest.mydest = CommandDest::DestCopyOut; + parquet_dest.uri = uri; + parquet_dest.tupledesc = std::ptr::null_mut(); + parquet_dest.natts = 0; + parquet_dest.tuple_count = 0; + parquet_dest.tuples = std::ptr::null_mut(); + parquet_dest.column_sizes = std::ptr::null_mut(); + parquet_dest.row_group_size = row_group_size; + parquet_dest.codec = codec; + parquet_dest.per_copy_context = per_copy_context; + + unsafe { std::mem::transmute(parquet_dest) } +}