From 014bb7f50f4c4a2214f3db32dcae9122aa1ba54e Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Wed, 27 Dec 2023 22:05:40 +0800
Subject: [PATCH] fix unexpected off-heap memory overflow: 1. use smaller batch size for parquet scanning. 2. reduce native-to-spark ffi channel buffer size. 3. shorten batch lifetime in project-filtering and batch coalescing. 4. other minor code refactoring.
---
 native-engine/blaze-jni-bridge/Cargo.toml | 2 +-
 .../blaze-jni-bridge/src/jni_bridge.rs | 4 +-
 native-engine/blaze/src/alloc.rs | 76 ++++++++++++
 native-engine/blaze/src/exec.rs | 13 +-
 native-engine/blaze/src/lib.rs | 14 +--
 native-engine/blaze/src/rt.rs | 94 +++++++--------
 .../datafusion-ext-commons/src/cast.rs | 8 +-
 .../src/io/batch_serde.rs | 25 ++--
 .../datafusion-ext-commons/src/lib.rs | 19 +++
 .../datafusion-ext-commons/src/spark_hash.rs | 53 +++-----
 .../src/streams/coalesce_stream.rs | 35 ++++--
 .../src/get_indexed_field.rs | 38 +++---
 .../datafusion-ext-exprs/src/get_map_value.rs | 113 ++++++++++++------
 .../datafusion-ext-exprs/src/named_struct.rs | 8 +-
 .../src/spark_udf_wrapper.rs | 10 +-
 .../src/string_contains.rs | 8 +-
 .../src/string_ends_with.rs | 8 +-
 .../src/string_starts_with.rs | 8 +-
 .../datafusion-ext-functions/src/lib.rs | 12 +-
 .../src/spark_make_array.rs | 12 +-
 .../src/spark_null_if_zero.rs | 8 +-
 .../src/spark_strings.rs | 43 ++-----
 .../src/agg/agg_tables.rs | 8 +-
 .../datafusion-ext-plans/src/agg/avg.rs | 7 +-
 .../datafusion-ext-plans/src/agg/maxmin.rs | 24 +---
 .../datafusion-ext-plans/src/agg/mod.rs | 7 +-
 .../datafusion-ext-plans/src/agg/sum.rs | 24 +---
 .../src/broadcast_join_exec.rs | 7 +-
 .../datafusion-ext-plans/src/common/output.rs | 62 +++++-----
 .../datafusion-ext-plans/src/debug_exec.rs | 9 +-
 .../src/empty_partitions_exec.rs | 9 +-
 .../datafusion-ext-plans/src/expand_exec.rs | 14 +--
 .../src/ffi_reader_exec.rs | 15 ++-
 .../datafusion-ext-plans/src/filter_exec.rs | 12 +-
 .../datafusion-ext-plans/src/generate/mod.rs | 13 +-
 .../src/ipc_reader_exec.rs | 11 +-
 .../src/ipc_writer_exec.rs | 9 +-
 .../datafusion-ext-plans/src/limit_exec.rs | 9 +-
 .../datafusion-ext-plans/src/parquet_exec.rs | 26 ++--
 .../src/parquet_sink_exec.rs | 16 +--
 .../datafusion-ext-plans/src/project_exec.rs | 2 +
 .../src/rename_columns_exec.rs | 15 +--
 .../src/shuffle/bucket_repartitioner.rs | 3 +-
 .../src/shuffle/sort_repartitioner.rs | 4 +-
 .../src/shuffle_writer_exec.rs | 7 +-
 .../datafusion-ext-plans/src/sort_exec.rs | 7 +-
 .../src/sort_merge_join_exec.rs | 35 +++---
 .../sql/blaze/BlazeCallNativeWrapper.scala | 76 +++++++-----
 .../ArrowFFIStreamImportIterator.scala | 107 -----------------
 49 files changed, 522 insertions(+), 627 deletions(-)
 create mode 100644 native-engine/blaze/src/alloc.rs
 delete mode 100644 spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala
diff --git a/native-engine/blaze-jni-bridge/Cargo.toml b/native-engine/blaze-jni-bridge/Cargo.toml
index 4834a207b..7f2f42557 100644
--- a/native-engine/blaze-jni-bridge/Cargo.toml
+++ b/native-engine/blaze-jni-bridge/Cargo.toml
@@ -8,5 +8,5 @@ resolver = "1"
 datafusion = { workspace = true }
 jni = "0.20.0"
 log = "0.4.14"
-once_cell = "1.19.0"
+once_cell = "1.11.0"
 paste = "1.0.7"
diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs
index 9c48cbc60..3917d3ee6 100644
--- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs
+++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs
@@ -358,8 +358,8 @@ macro_rules!
jni_throw { #[macro_export] macro_rules! jni_fatal_error { - ($value:expr) => {{ - $crate::jni_bridge::THREAD_JNIENV.with(|env| env.fatal_error($value)) + ($($arg:tt)*) => {{ + $crate::jni_bridge::THREAD_JNIENV.with(|env| env.fatal_error(format!($($arg)*))) }}; } diff --git a/native-engine/blaze/src/alloc.rs b/native-engine/blaze/src/alloc.rs new file mode 100644 index 000000000..aa5772a4a --- /dev/null +++ b/native-engine/blaze/src/alloc.rs @@ -0,0 +1,76 @@ +// #[global_allocator] +// static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc; + +use std::{ + alloc::{GlobalAlloc, Layout}, + sync::{ + atomic::{AtomicUsize, Ordering::SeqCst}, + Mutex, + }, +}; + +#[global_allocator] +static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc; + +// only used for debugging +// +// #[global_allocator] +// static GLOBAL: DebugAlloc = +// DebugAlloc::new(jemallocator::Jemalloc); + +#[allow(unused)] +struct DebugAlloc { + inner: T, + last_updated: AtomicUsize, + current: AtomicUsize, + mutex: Mutex<()>, +} + +#[allow(unused)] +impl DebugAlloc { + pub const fn new(inner: T) -> Self { + Self { + inner, + last_updated: AtomicUsize::new(0), + current: AtomicUsize::new(0), + mutex: Mutex::new(()), + } + } + + fn update(&self) { + let _lock = self.mutex.lock().unwrap(); + let current = self.current.load(SeqCst); + let last_updated = self.last_updated.load(SeqCst); + let delta = (current as isize - last_updated as isize).abs(); + if delta > 104857600 { + eprintln!(" * ALLOC {} -> {}", last_updated, current); + self.last_updated.store(current, SeqCst); + } + } +} + +unsafe impl GlobalAlloc for DebugAlloc { + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + self.current.fetch_add(layout.size(), SeqCst); + self.update(); + self.inner.alloc(layout) + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + self.current.fetch_sub(layout.size(), SeqCst); + self.update(); + self.inner.dealloc(ptr, layout) + } + + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + self.current.fetch_add(layout.size(), SeqCst); + self.update(); + self.inner.alloc_zeroed(layout) + } + + unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { + self.current.fetch_add(new_size - layout.size(), SeqCst); + self.update(); + self.inner.realloc(ptr, layout, new_size) + } +} diff --git a/native-engine/blaze/src/exec.rs b/native-engine/blaze/src/exec.rs index cab2bcc32..07ecc9131 100644 --- a/native-engine/blaze/src/exec.rs +++ b/native-engine/blaze/src/exec.rs @@ -30,6 +30,7 @@ use datafusion::{ physical_plan::{displayable, ExecutionPlan}, prelude::{SessionConfig, SessionContext}, }; +use datafusion_ext_commons::df_execution_err; use datafusion_ext_plans::memmgr::MemManager; use jni::{ objects::{JClass, JObject}, @@ -87,22 +88,22 @@ pub extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_callNative( let task_definition = TaskDefinition::decode( jni_convert_byte_array!(raw_task_definition.as_obj())?.as_slice(), ) - .map_err(|err| DataFusionError::Plan(format!("cannot decode execution plan: {:?}", err)))?; + .or_else(|err| df_execution_err!("cannot decode execution plan: {err:?}"))?; let task_id = &task_definition.task_id.expect("task_id is empty"); let plan = &task_definition.plan.expect("plan is empty"); drop(raw_task_definition); // get execution plan - let execution_plan: Arc = plan.try_into().map_err(|err| { - DataFusionError::Plan(format!("cannot create execution plan: {:?}", err)) - })?; + let execution_plan: Arc = plan + .try_into() + .or_else(|err| 
df_execution_err!("cannot create execution plan: {err:?}"))?; let execution_plan_displayable = displayable(execution_plan.as_ref()) .indent(true) .to_string(); log::info!("Creating native execution plan succeeded"); - log::info!(" task_id={:?}", task_id); - log::info!(" execution plan:\n{}", execution_plan_displayable); + log::info!(" task_id={task_id:?}"); + log::info!(" execution plan:\n{execution_plan_displayable}"); // execute to stream let runtime = Box::new(NativeExecutionRuntime::start( diff --git a/native-engine/blaze/src/lib.rs b/native-engine/blaze/src/lib.rs index 0f204fe46..0a8707026 100644 --- a/native-engine/blaze/src/lib.rs +++ b/native-engine/blaze/src/lib.rs @@ -17,14 +17,12 @@ use std::{any::Any, error::Error, fmt::Debug, panic::AssertUnwindSafe}; use blaze_jni_bridge::*; use jni::objects::{JObject, JThrowable}; +mod alloc; mod exec; mod logging; mod metrics; mod rt; -#[global_allocator] -static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc; - fn handle_unwinded(err: Box) { // default handling: // * caused by Interrupted/TaskKilled: do nothing but just print a message. @@ -48,10 +46,7 @@ fn handle_unwinded(err: Box) { Ok(()) }; recover().unwrap_or_else(|err: Box| { - jni_fatal_error!(format!( - "Error recovering from panic, cannot resume: {:?}", - err - )); + jni_fatal_error!("Error recovering from panic, cannot resume: {err:?}"); }); } @@ -70,10 +65,7 @@ fn throw_runtime_exception(msg: &str, cause: JObject) -> datafusion::error::Resu let e = jni_new_object!(JavaRuntimeException(msg.as_obj(), cause))?; if let Err(err) = jni_throw!(JThrowable::from(e.as_obj())) { - jni_fatal_error!(format!( - "Error throwing RuntimeException, cannot result: {:?}", - err - )); + jni_fatal_error!("Error throwing RuntimeException, cannot result: {err:?}"); } Ok(()) } diff --git a/native-engine/blaze/src/rt.rs b/native-engine/blaze/src/rt.rs index d2c3fb07f..09cb0d743 100644 --- a/native-engine/blaze/src/rt.rs +++ b/native-engine/blaze/src/rt.rs @@ -13,6 +13,7 @@ // limitations under the License. 
use std::{ + error::Error, panic::AssertUnwindSafe, sync::{mpsc::Receiver, Arc}, }; @@ -35,8 +36,8 @@ use datafusion::{ ExecutionPlan, }, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; -use datafusion_ext_plans::common::output::WrappedRecordBatchSender; +use datafusion_ext_commons::{df_execution_err, streams::coalesce_stream::CoalesceInput}; +use datafusion_ext_plans::common::output::TaskOutputter; use futures::{FutureExt, StreamExt}; use jni::objects::{GlobalRef, JObject}; use tokio::runtime::Runtime; @@ -48,7 +49,6 @@ pub struct NativeExecutionRuntime { plan: Arc, task_context: Arc, partition: usize, - ffi_schema: Arc, batch_receiver: Receiver>>, rt: Runtime, } @@ -71,9 +71,9 @@ impl NativeExecutionRuntime { )?; // init ffi schema - let ffi_schema = Arc::new(FFI_ArrowSchema::try_from(schema.as_ref())?); + let ffi_schema = FFI_ArrowSchema::try_from(schema.as_ref())?; jni_call!(BlazeCallNativeWrapper(native_wrapper.as_obj()) - .importSchema(ffi_schema.as_ref() as *const FFI_ArrowSchema as i64) -> () + .importSchema(&ffi_schema as *const FFI_ArrowSchema as i64) -> () )?; // create tokio runtime @@ -92,60 +92,53 @@ impl NativeExecutionRuntime { }) .build()?; - let (batch_sender, batch_receiver) = std::sync::mpsc::sync_channel(2); + let (batch_sender, batch_receiver) = std::sync::mpsc::sync_channel(0); let nrt = Self { native_wrapper: native_wrapper.clone(), plan, partition, rt, - ffi_schema, batch_receiver, task_context: context, }; // spawn batch producer - let batch_sender_cloned = batch_sender.clone(); - let consume_stream = move || async move { + let err_sender = batch_sender.clone(); + let consume_stream = async move { while let Some(batch) = AssertUnwindSafe(stream.next()) .catch_unwind() .await .unwrap_or_else(|err| { let panic_message = panic_message::get_panic_message(&err).unwrap_or("unknown error"); - Some(Err(DataFusionError::Execution(panic_message.to_owned()))) + Some(df_execution_err!("{}", panic_message)) }) .transpose() - .map_err(|err| DataFusionError::Execution(format!("{}", err)))? + .or_else(|err| df_execution_err!("{err}"))? { - batch_sender.send(Ok(Some(batch))).map_err(|err| { - DataFusionError::Execution(format!("send batch error: {err}")) - })?; + batch_sender + .send(Ok(Some(batch))) + .or_else(|err| df_execution_err!("send batch error: {err}"))?; } batch_sender .send(Ok(None)) - .map_err(|err| DataFusionError::Execution(format!("send batch error: {err}")))?; - + .or_else(|err| df_execution_err!("send batch error: {err}"))?; log::info!("[partition={partition}] finished"); Ok::<_, DataFusionError>(()) }; nrt.rt.spawn(async move { - let result = consume_stream().await; - result.unwrap_or_else(|err| handle_unwinded_scope(|| -> Result<()> { - batch_sender_cloned.send( - Err(DataFusionError::Execution(format!("execution aborted"))) - ).map_err(|err| { - DataFusionError::Execution(format!("send batch error: {err}")) - })?; - let task_running = is_task_running(); - if !task_running { - log::warn!( - "[partition={partition}] task completed/interrupted before native execution done", - ); - return Ok(()); - } - - let cause = - if jni_exception_check!()? { + consume_stream.await.unwrap_or_else(|err| { + handle_unwinded_scope(|| { + let task_running = is_task_running(); + if !task_running { + log::warn!( + "[partition={partition}] task completed before native execution done" + ); + return Ok(()); + } + + err_sender.send(df_execution_err!("execution aborted"))?; + let cause = if jni_exception_check!()? 
{ log::error!("[partition={partition}] panics with an java exception: {err}"); Some(jni_exception_occurred!()?) } else { @@ -153,23 +146,26 @@ impl NativeExecutionRuntime { None }; - set_error( - &native_wrapper, - &format!("[partition={partition}] panics: {err}"), - cause.map(|e| e.as_obj()), - )?; - log::info!("[partition={partition}] exited abnormally."); - Ok::<_, DataFusionError>(()) - })); + set_error( + &native_wrapper, + &format!("[partition={partition}] panics: {err}"), + cause.map(|e| e.as_obj()), + )?; + log::info!("[partition={partition}] exited abnormally."); + Ok::<_, Box>(()) + }) + }); }); Ok(nrt) } pub fn next_batch(&self) -> bool { let next_batch = || -> Result { - match self.batch_receiver.recv().map_err(|err| { - DataFusionError::Execution(format!("receive batch error: {err}")) - })?? { + match self + .batch_receiver + .recv() + .or_else(|err| df_execution_err!("receive batch error: {err}"))?? + { Some(batch) => { let ffi_array = FFI_ArrowArray::new(&StructArray::from(batch).into_data()); jni_call!(BlazeCallNativeWrapper(self.native_wrapper.as_obj()) @@ -197,14 +193,10 @@ impl NativeExecutionRuntime { pub fn finalize(self) { log::info!("native execution [partition={}] finalizing", self.partition); - let _ = self.update_metrics(); - log::info!("native execution [partition={}] 1", self.partition); - drop(self.ffi_schema); - log::info!("native execution [partition={}] 2", self.partition); + self.update_metrics().unwrap_or_default(); drop(self.plan); - log::info!("native execution [partition={}] 3", self.partition); - WrappedRecordBatchSender::cancel_task(&self.task_context); // cancel all pending streams - log::info!("native execution [partition={}] 4", self.partition); + + self.task_context.cancel_task(); // cancel all pending streams self.rt.shutdown_background(); log::info!("native execution [partition={}] finalized", self.partition); } diff --git a/native-engine/datafusion-ext-commons/src/cast.rs b/native-engine/datafusion-ext-commons/src/cast.rs index e347539a6..2a32d78ab 100644 --- a/native-engine/datafusion-ext-commons/src/cast.rs +++ b/native-engine/datafusion-ext-commons/src/cast.rs @@ -18,11 +18,13 @@ use arrow::{array::*, datatypes::*}; use bigdecimal::{FromPrimitive, ToPrimitive}; use datafusion::common::{ cast::{as_float32_array, as_float64_array}, - DataFusionError, Result, + Result, }; use num::{cast::AsPrimitive, Bounded, Integer, Signed}; use paste::paste; +use crate::df_execution_err; + pub fn cast(array: &dyn Array, cast_type: &DataType) -> Result { return cast_impl(array, cast_type, false); } @@ -109,9 +111,7 @@ pub fn cast_impl( if !match_struct_fields { if to_fields.len() != struct_.num_columns() { - return Err(DataFusionError::Execution( - "cannot cast structs with different numbers of fields".to_string(), - )); + df_execution_err!("cannot cast structs with different numbers of fields")?; } let casted_arrays = struct_ diff --git a/native-engine/datafusion-ext-commons/src/io/batch_serde.rs b/native-engine/datafusion-ext-commons/src/io/batch_serde.rs index aa85ccb22..5957deee3 100644 --- a/native-engine/datafusion-ext-commons/src/io/batch_serde.rs +++ b/native-engine/datafusion-ext-commons/src/io/batch_serde.rs @@ -27,9 +27,12 @@ use arrow::{ record_batch::{RecordBatch, RecordBatchOptions}, }; use bitvec::prelude::BitVec; -use datafusion::common::{DataFusionError, Result}; +use datafusion::common::Result; -use crate::io::{read_bytes_slice, read_len, write_len}; +use crate::{ + df_execution_err, df_unimplemented_err, + io::{read_bytes_slice, read_len, 
write_len}, +}; pub fn write_batch( batch: &RecordBatch, @@ -201,12 +204,7 @@ pub fn write_array(array: &dyn Array, output: &mut W) -> Result<()> { DataType::List(_field) => write_list_array(as_list_array(array), output)?, DataType::Map(..) => write_map_array(as_map_array(array), output)?, DataType::Struct(_) => write_struct_array(as_struct_array(array), output)?, - other => { - return Err(DataFusionError::NotImplemented(format!( - "unsupported data type: {}", - other - ))); - } + other => df_unimplemented_err!("unsupported data type: {other}")?, } Ok(()) } @@ -252,12 +250,7 @@ pub fn read_array( read_map_array(num_rows, input, map_field, *is_sorted)? } DataType::Struct(fields) => read_struct_array(num_rows, input, fields)?, - other => { - return Err(DataFusionError::NotImplemented(format!( - "unsupported data type: {}", - other - ))); - } + other => df_unimplemented_err!("unsupported data type: {other}")?, }) } @@ -309,7 +302,7 @@ fn nameless_data_type(data_type: &DataType) -> DataType { pub fn write_data_type(data_type: &DataType, output: &mut W) -> Result<()> { let buf = postcard::to_allocvec(&nameless_data_type(data_type)) - .map_err(|err| DataFusionError::Execution(format!("serialize data type error: {err}")))?; + .or_else(|err| df_execution_err!("serialize data type error: {err}"))?; write_len(buf.len(), output)?; output.write_all(&buf)?; Ok(()) @@ -319,7 +312,7 @@ pub fn read_data_type(input: &mut R) -> Result { let buf_len = read_len(input)?; let buf = read_bytes_slice(input, buf_len)?; let data_type = postcard::from_bytes(&buf) - .map_err(|err| DataFusionError::Execution(format!("deserialize data type error: {err}")))?; + .or_else(|err| df_execution_err!("deserialize data type error: {err}"))?; Ok(data_type) } diff --git a/native-engine/datafusion-ext-commons/src/lib.rs b/native-engine/datafusion-ext-commons/src/lib.rs index 0c94cf42c..3229382d3 100644 --- a/native-engine/datafusion-ext-commons/src/lib.rs +++ b/native-engine/datafusion-ext-commons/src/lib.rs @@ -27,3 +27,22 @@ pub mod slim_bytes; pub mod spark_hash; pub mod streams; pub mod uda; + +#[macro_export] +macro_rules! df_execution_err { + ($($arg:tt)*) => { + Err(datafusion::common::DataFusionError::Execution(format!($($arg)*))) + } +} +#[macro_export] +macro_rules! df_unimplemented_err { + ($($arg:tt)*) => { + Err(datafusion::common::DataFusionError::NotImplemented(format!($($arg)*))) + } +} +#[macro_export] +macro_rules! 
df_external_err { + ($($arg:tt)*) => { + Err(datafusion::common::DataFusionError::External(format!($($arg)*))) + } +} diff --git a/native-engine/datafusion-ext-commons/src/spark_hash.rs b/native-engine/datafusion-ext-commons/src/spark_hash.rs index af5626e54..6a76bb953 100644 --- a/native-engine/datafusion-ext-commons/src/spark_hash.rs +++ b/native-engine/datafusion-ext-commons/src/spark_hash.rs @@ -23,7 +23,9 @@ use arrow::{ Int8Type, TimeUnit, }, }; -use datafusion::error::{DataFusionError, Result}; +use datafusion::error::Result; + +use crate::df_execution_err; #[inline] fn spark_compatible_murmur3_hash>(data: T, seed: u32) -> u32 { @@ -170,14 +172,14 @@ fn create_hashes_dictionary( for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { if let Some(key) = key { - let idx = key.to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert key value {:?} to usize in dictionary of type {:?}", - key, - dict_array.data_type() - )) - })?; - *hash = dict_hashes[idx] + if let Some(idx) = key.to_usize() { + *hash = dict_hashes[idx]; + } else { + let dt = dict_array.data_type(); + df_execution_err!( + "Can not convert key value {key:?} to usize in dictionary of type {dt:?}" + )?; + } } // no update for Null, consistent with other hashes } Ok(()) @@ -269,25 +271,12 @@ fn hash_array(array: &ArrayRef, hashes_buffer: &mut [u32]) -> Result<()> { DataType::Decimal128(..) => { hash_array_decimal!(Decimal128Array, array, hashes_buffer); } - DataType::Dictionary(index_type, _) => match **index_type { - DataType::Int8 => { - create_hashes_dictionary::(array, hashes_buffer)?; - } - DataType::Int16 => { - create_hashes_dictionary::(array, hashes_buffer)?; - } - DataType::Int32 => { - create_hashes_dictionary::(array, hashes_buffer)?; - } - DataType::Int64 => { - create_hashes_dictionary::(array, hashes_buffer)?; - } - _ => { - return Err(DataFusionError::Internal(format!( - "Unsupported dictionary type in hasher hashing: {}", - array.data_type(), - ))) - } + DataType::Dictionary(index_type, _) => match &**index_type { + DataType::Int8 => create_hashes_dictionary::(array, hashes_buffer)?, + DataType::Int16 => create_hashes_dictionary::(array, hashes_buffer)?, + DataType::Int32 => create_hashes_dictionary::(array, hashes_buffer)?, + DataType::Int64 => create_hashes_dictionary::(array, hashes_buffer)?, + other => df_execution_err!("Unsupported dictionary type in hasher hashing: {other}")?, }, _ => { for idx in 0..array.len() { @@ -407,13 +396,7 @@ fn hash_one(col: &ArrayRef, idx: usize, hash: &mut u32) -> Result<()> { hash_one(col, idx, hash)?; } } - _ => { - // This is internal because we should have caught this before. 
- return Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {}", - col.data_type() - ))); - } + other => df_execution_err!("Unsupported data type in hasher: {other}")?, } } Ok(()) diff --git a/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs b/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs index 2b612db18..d96f55554 100644 --- a/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs +++ b/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs @@ -18,19 +18,21 @@ use std::{ task::{ready, Context, Poll}, }; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{ + datatypes::SchemaRef, + record_batch::{RecordBatch, RecordBatchOptions}, +}; use datafusion::{ common::Result, execution::TaskContext, physical_plan::{ - coalesce_batches::concat_batches, metrics::{BaselineMetrics, Time}, RecordBatchStream, SendableRecordBatchStream, }, }; use futures::{Stream, StreamExt}; -const STAGING_BATCHES_MEM_SIZE_LIMIT: usize = 1 << 26; // limit output batch size to 64MB +const STAGING_BATCHES_MEM_SIZE_LIMIT: usize = 1 << 25; // limit output batch size to 32MB pub trait CoalesceInput { fn coalesce_input( @@ -92,14 +94,31 @@ impl CoalesceStream { } fn coalesce(&mut self) -> Result { - let coalesced = concat_batches( - &self.schema(), - &std::mem::take(&mut self.staging_batches), - self.staging_rows, + // better concat_batches() implementation that releases old batch columns asap. + let schema = self.input.schema(); + + // collect all columns + let mut all_cols = schema.fields().iter().map(|_| vec![]).collect::>(); + for batch in std::mem::take(&mut self.staging_batches) { + for i in 0..all_cols.len() { + all_cols[i].push(batch.column(i).clone()); + } + } + + // coalesce each column + let mut coalesced_cols = vec![]; + for cols in all_cols { + let ref_cols = cols.iter().map(|col| col.as_ref()).collect::>(); + coalesced_cols.push(arrow::compute::concat(&ref_cols)?); + } + let coalesced_batch = RecordBatch::try_new_with_options( + schema, + coalesced_cols, + &RecordBatchOptions::new().with_row_count(Some(self.staging_rows)), )?; self.staging_rows = 0; self.staging_batches_mem_size = 0; - Ok(coalesced) + Ok(coalesced_batch) } fn should_flush(&self) -> bool { diff --git a/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs b/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs index 531d3fcaf..4efd6d503 100644 --- a/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs +++ b/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs @@ -24,11 +24,12 @@ use arrow::{array::*, compute::*, datatypes::*, record_batch::RecordBatch}; use datafusion::{ common::{ cast::{as_list_array, as_struct_array}, - DataFusionError, Result, ScalarValue, + Result, ScalarValue, }, logical_expr::ColumnarValue, physical_expr::PhysicalExpr, }; +use datafusion_ext_commons::df_execution_err; use crate::down_cast_any_ref; @@ -117,18 +118,18 @@ impl PhysicalExpr for GetIndexedFieldExpr { as_struct_array.column(*k as usize).clone(), )) } - (DataType::List(_), key) => Err(DataFusionError::Execution(format!( + (DataType::List(_), key) => df_execution_err!( "get indexed field is only possible on lists with int64 indexes. \ Tried with {key:?} index" - ))), - (DataType::Struct(_), key) => Err(DataFusionError::Execution(format!( + ), + (DataType::Struct(_), key) => df_execution_err!( "get indexed field is only possible on struct with int32 indexes. 
\ Tried with {key:?} index" - ))), - (dt, key) => Err(DataFusionError::Execution(format!( + ), + (dt, key) => df_execution_err!( "get indexed field is only possible on lists with int64 indexes or struct \ with utf8 indexes. Tried {dt:?} with {key:?} index" - ))), + ), } } @@ -169,22 +170,19 @@ fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { (DataType::Struct(fields), ScalarValue::Int32(Some(k))) => { let field = fields.get(*k as usize); match field { - None => Err(DataFusionError::Plan(format!( - "Field {k} not found in struct" - ))), + None => df_execution_err!("Field {k} not found in struct"), Some(f) => Ok(f.as_ref().clone()), } } - (DataType::Struct(_), _) => Err(DataFusionError::Plan( - "Only ints are valid as an indexed field in a struct".to_string(), - )), - (DataType::List(_), _) => Err(DataFusionError::Plan( - "Only ints are valid as an indexed field in a list".to_string(), - )), - _ => Err(DataFusionError::Plan( - "The expression to get an indexed field is only valid for List or Struct types" - .to_string(), - )), + (DataType::Struct(_), _) => { + df_execution_err!("Only ints are valid as an indexed field in a struct",) + } + (DataType::List(_), _) => { + df_execution_err!("Only ints are valid as an indexed field in a list",) + } + _ => df_execution_err!( + "The expression to get an indexed field is only valid for List or Struct types", + ), } } diff --git a/native-engine/datafusion-ext-exprs/src/get_map_value.rs b/native-engine/datafusion-ext-exprs/src/get_map_value.rs index 0ac5e29db..5a0913ee6 100644 --- a/native-engine/datafusion-ext-exprs/src/get_map_value.rs +++ b/native-engine/datafusion-ext-exprs/src/get_map_value.rs @@ -26,10 +26,11 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::{ - common::{DataFusionError, Result, ScalarValue}, + common::{Result, ScalarValue}, logical_expr::ColumnarValue, physical_expr::PhysicalExpr, }; +use datafusion_ext_commons::{df_execution_err, df_unimplemented_err}; use crate::down_cast_any_ref; @@ -87,23 +88,34 @@ impl PhysicalExpr for GetMapValueExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let array = self.arg.evaluate(batch)?.into_array(1); match (array.data_type(), &self.key) { - (DataType::Map(_, _), _) if self.key.is_null() => { - Err(DataFusionError::NotImplemented("map key not support Null Type".to_string())) + (DataType::Map(..), _) if self.key.is_null() => { + df_unimplemented_err!("map key not support Null Type") } - (DataType::Map(_, _), _) => { + (DataType::Map(..), _) => { let as_map_array = array.as_any().downcast_ref::().unwrap(); - if !as_map_array.key_type().equals_datatype(&self.key.get_datatype()) { - return Err(DataFusionError::Execution("MapArray key type must equal to GetMapValue key type".to_string())) + if !as_map_array + .key_type() + .equals_datatype(&self.key.get_datatype()) + { + df_execution_err!("MapArray key type must equal to GetMapValue key type")?; } macro_rules! get_boolean_value { ($keyarrowty:ident, $scalar:expr) => {{ type A = paste::paste! 
{[< $keyarrowty Array >]}; - let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); + let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); let ans_boolean = eq_dyn_bool_scalar(key_array, $scalar)?; - let ans_index = ans_boolean.iter().enumerate() - .filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false }) - .map(|(idx, _)|idx as i32) + let ans_index = ans_boolean + .iter() + .enumerate() + .filter(|(_, ans)| { + if let Some(res) = ans { + res.clone() + } else { + false + } + }) + .map(|(idx, _)| idx as i32) .collect::>(); let mut indices = vec![]; if ans_index.len() == 0 { @@ -124,7 +136,8 @@ impl PhysicalExpr for GetMapValueExpr { } } let indice_array = UInt32Array::from(indices); - let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?; + let ans_array = + arrow::compute::take(as_map_array.values(), &indice_array, None)?; Ok(ColumnarValue::Array(ans_array)) }}; } @@ -132,11 +145,19 @@ impl PhysicalExpr for GetMapValueExpr { macro_rules! get_prim_value { ($keyarrowty:ident, $scalar:expr) => {{ type A = paste::paste! {[< $keyarrowty Array >]}; - let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); + let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); let ans_boolean = eq_dyn_scalar(key_array, $scalar)?; - let ans_index = ans_boolean.iter().enumerate() - .filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false }) - .map(|(idx, _)|idx as i32) + let ans_index = ans_boolean + .iter() + .enumerate() + .filter(|(_, ans)| { + if let Some(res) = ans { + res.clone() + } else { + false + } + }) + .map(|(idx, _)| idx as i32) .collect::>(); let mut indices = vec![]; if ans_index.len() == 0 { @@ -157,7 +178,8 @@ impl PhysicalExpr for GetMapValueExpr { } } let indice_array = UInt32Array::from(indices); - let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?; + let ans_array = + arrow::compute::take(as_map_array.values(), &indice_array, None)?; Ok(ColumnarValue::Array(ans_array)) }}; } @@ -165,11 +187,19 @@ impl PhysicalExpr for GetMapValueExpr { macro_rules! get_str_value { ($keyarrowty:ident, $scalar:expr) => {{ type A = paste::paste! {[< $keyarrowty Array >]}; - let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); + let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); let ans_boolean = eq_dyn_utf8_scalar(key_array, $scalar)?; - let ans_index = ans_boolean.iter().enumerate() - .filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false }) - .map(|(idx, _)|idx as i32) + let ans_index = ans_boolean + .iter() + .enumerate() + .filter(|(_, ans)| { + if let Some(res) = ans { + res.clone() + } else { + false + } + }) + .map(|(idx, _)| idx as i32) .collect::>(); let mut indices = vec![]; if ans_index.len() == 0 { @@ -190,7 +220,8 @@ impl PhysicalExpr for GetMapValueExpr { } } let indice_array = UInt32Array::from(indices); - let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?; + let ans_array = + arrow::compute::take(as_map_array.values(), &indice_array, None)?; Ok(ColumnarValue::Array(ans_array)) }}; } @@ -198,11 +229,19 @@ impl PhysicalExpr for GetMapValueExpr { macro_rules! get_binary_value { ($keyarrowty:ident, $scalar:expr) => {{ type A = paste::paste! 
{[< $keyarrowty Array >]}; - let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); + let key_array = as_map_array.keys().as_any().downcast_ref::().unwrap(); let ans_boolean = eq_dyn_binary_scalar(key_array, $scalar)?; - let ans_index = ans_boolean.iter().enumerate() - .filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false }) - .map(|(idx, _)|idx as i32) + let ans_index = ans_boolean + .iter() + .enumerate() + .filter(|(_, ans)| { + if let Some(res) = ans { + res.clone() + } else { + false + } + }) + .map(|(idx, _)| idx as i32) .collect::>(); let mut indices = vec![]; if ans_index.len() == 0 { @@ -223,7 +262,8 @@ impl PhysicalExpr for GetMapValueExpr { } } let indice_array = UInt32Array::from(indices); - let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?; + let ans_array = + arrow::compute::take(as_map_array.values(), &indice_array, None)?; Ok(ColumnarValue::Array(ans_array)) }}; } @@ -243,16 +283,15 @@ impl PhysicalExpr for GetMapValueExpr { ScalarValue::Utf8(Some(i)) => get_str_value!(String, i.as_str()), ScalarValue::LargeUtf8(Some(i)) => get_str_value!(LargeString, i.as_str()), ScalarValue::Binary(Some(i)) => get_binary_value!(Binary, i.as_slice()), - ScalarValue::LargeBinary(Some(i)) => get_binary_value!(LargeBinary, i.as_slice()), - t => { - Err(DataFusionError::Execution( - format!("get map value (Map) not support {} as key type", t))) - }, + ScalarValue::LargeBinary(Some(i)) => { + get_binary_value!(LargeBinary, i.as_slice()) + } + t => df_execution_err!("get map value (Map) not support {t} as key type"), } } (dt, key) => { - Err(DataFusionError::Execution(format!("get map value (Map) is only possible on map with no-null key. Tried {:?} with {:?} key", dt, key))) - }, + df_execution_err!("get map value (Map) is only possible on map with no-null key. Tried {dt:?} with {key:?} key") + } } } @@ -279,14 +318,10 @@ fn get_data_type_field(data_type: &DataType) -> Result { if let DataType::Struct(fields) = field.data_type() { Ok(fields[1].as_ref().clone()) // values field } else { - Err(DataFusionError::NotImplemented( - "Map field only support Struct".to_string(), - )) + df_unimplemented_err!("Map field only support Struct") } } - _ => Err(DataFusionError::Plan( - "The expression to get map value is only valid for `Map` types".to_string(), - )), + _ => df_execution_err!("The expression to get map value is only valid for `Map` types"), } } diff --git a/native-engine/datafusion-ext-exprs/src/named_struct.rs b/native-engine/datafusion-ext-exprs/src/named_struct.rs index 3c3e3ccd5..537ae0e5d 100644 --- a/native-engine/datafusion-ext-exprs/src/named_struct.rs +++ b/native-engine/datafusion-ext-exprs/src/named_struct.rs @@ -30,11 +30,11 @@ use datafusion::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }, - common::{DataFusionError, Result}, + common::Result, logical_expr::ColumnarValue, physical_expr::{expr_list_eq_any_order, PhysicalExpr}, }; -use datafusion_ext_commons::io::name_batch; +use datafusion_ext_commons::{df_execution_err, io::name_batch}; use crate::down_cast_any_ref; @@ -51,9 +51,7 @@ impl NamedStructExpr { let return_schema = match &return_type { DataType::Struct(fields) => Arc::new(Schema::new(fields.clone())), other => { - return Err(DataFusionError::Execution(format!( - "NamedStruct expects returning struct type, but got {other}" - ))) + df_execution_err!("NamedStruct expects returning struct type, but got {other}")? 
} }; Ok(Self { diff --git a/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs b/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs index 9c4ad5f3d..2e9bbb5d1 100644 --- a/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs +++ b/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs @@ -30,10 +30,10 @@ use blaze_jni_bridge::{ jni_new_object, }; use datafusion::{ - common::DataFusionError, error::Result, logical_expr::ColumnarValue, - physical_expr::utils::expr_list_eq_any_order, physical_plan::PhysicalExpr, + error::Result, logical_expr::ColumnarValue, physical_expr::utils::expr_list_eq_any_order, + physical_plan::PhysicalExpr, }; -use datafusion_ext_commons::cast::cast; +use datafusion_ext_commons::{cast::cast, df_execution_err}; use jni::objects::GlobalRef; use once_cell::sync::OnceCell; @@ -123,9 +123,7 @@ impl PhysicalExpr for SparkUDFWrapperExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { if !is_task_running() { - return Err(DataFusionError::Execution( - "SparkUDFWrapper: is_task_running=false".to_string(), - )); + df_execution_err!("SparkUDFWrapper: is_task_running=false")?; } let batch_schema = batch.schema(); diff --git a/native-engine/datafusion-ext-exprs/src/string_contains.rs b/native-engine/datafusion-ext-exprs/src/string_contains.rs index 1ee33bb7c..255208b28 100644 --- a/native-engine/datafusion-ext-exprs/src/string_contains.rs +++ b/native-engine/datafusion-ext-exprs/src/string_contains.rs @@ -25,10 +25,11 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::{ - common::{DataFusionError, Result, ScalarValue}, + common::{Result, ScalarValue}, logical_expr::ColumnarValue, physical_plan::PhysicalExpr, }; +use datafusion_ext_commons::df_execution_err; use crate::down_cast_any_ref; @@ -96,10 +97,7 @@ impl PhysicalExpr for StringContainsExpr { let ret = maybe_string.map(|string| string.contains(&self.infix)); Ok(ColumnarValue::Scalar(ScalarValue::Boolean(ret))) } - expr => Err(DataFusionError::Plan(format!( - "contains: invalid expr: {:?}", - expr - ))), + expr => df_execution_err!("contains: invalid expr: {expr:?}")?, } } diff --git a/native-engine/datafusion-ext-exprs/src/string_ends_with.rs b/native-engine/datafusion-ext-exprs/src/string_ends_with.rs index 0dc7bd5c9..5cfa5ea4b 100644 --- a/native-engine/datafusion-ext-exprs/src/string_ends_with.rs +++ b/native-engine/datafusion-ext-exprs/src/string_ends_with.rs @@ -25,10 +25,11 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::{ - common::{DataFusionError, Result, ScalarValue}, + common::{Result, ScalarValue}, logical_expr::ColumnarValue, physical_plan::PhysicalExpr, }; +use datafusion_ext_commons::df_execution_err; use crate::down_cast_any_ref; @@ -95,10 +96,7 @@ impl PhysicalExpr for StringEndsWithExpr { let ret = maybe_string.map(|string| string.ends_with(&self.suffix)); Ok(ColumnarValue::Scalar(ScalarValue::Boolean(ret))) } - expr => Err(DataFusionError::Plan(format!( - "ends_with: invalid expr: {:?}", - expr - ))), + expr => df_execution_err!("ends_with: invalid expr: {expr:?}"), } } diff --git a/native-engine/datafusion-ext-exprs/src/string_starts_with.rs b/native-engine/datafusion-ext-exprs/src/string_starts_with.rs index d46b37a78..8b490d7bc 100644 --- a/native-engine/datafusion-ext-exprs/src/string_starts_with.rs +++ b/native-engine/datafusion-ext-exprs/src/string_starts_with.rs @@ -25,10 +25,11 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::{ - common::{DataFusionError, Result, ScalarValue}, + common::{Result, ScalarValue}, 
logical_expr::ColumnarValue, physical_plan::PhysicalExpr, }; +use datafusion_ext_commons::df_execution_err; use crate::down_cast_any_ref; @@ -95,10 +96,7 @@ impl PhysicalExpr for StringStartsWithExpr { let ret = maybe_string.map(|string| string.starts_with(&self.prefix)); Ok(ColumnarValue::Scalar(ScalarValue::Boolean(ret))) } - expr => Err(DataFusionError::Plan(format!( - "starts_with: invalid expr: {:?}", - expr - ))), + expr => df_execution_err!("starts_with: invalid expr: {expr:?}"), } } diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index b3ad80b71..848eb4c0d 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -14,10 +14,8 @@ use std::sync::Arc; -use datafusion::{ - common::{DataFusionError, Result}, - logical_expr::ScalarFunctionImplementation, -}; +use datafusion::{common::Result, logical_expr::ScalarFunctionImplementation}; +use datafusion_ext_commons::df_unimplemented_err; mod spark_check_overflow; mod spark_get_json_object; @@ -47,10 +45,6 @@ pub fn create_spark_ext_function(name: &str) -> Result Arc::new(spark_strings::string_concat_ws), "StringLower" => Arc::new(spark_strings::string_lower), "StringUpper" => Arc::new(spark_strings::string_upper), - - _ => Err(DataFusionError::NotImplemented(format!( - "spark ext function not implemented: {}", - name - )))?, + _ => df_unimplemented_err!("spark ext function not implemented: {name}")?, }) } diff --git a/native-engine/datafusion-ext-functions/src/spark_make_array.rs b/native-engine/datafusion-ext-functions/src/spark_make_array.rs index e8b578035..72fa655e5 100644 --- a/native-engine/datafusion-ext-functions/src/spark_make_array.rs +++ b/native-engine/datafusion-ext-functions/src/spark_make_array.rs @@ -19,9 +19,9 @@ use std::sync::Arc; use arrow::{array::*, datatypes::DataType}; use datafusion::{ common::{Result, ScalarValue}, - error::DataFusionError, logical_expr::ColumnarValue, }; +use datafusion_ext_commons::df_execution_err; macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ @@ -29,7 +29,7 @@ macro_rules! downcast_vec { .iter() .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { Some(array) => Ok(array), - _ => Err(DataFusionError::Internal("failed to downcast".to_string())), + _ => df_execution_err!("failed to downcast"), }) }}; } @@ -62,9 +62,7 @@ macro_rules! array { .iter() .any(|arg| arg.len() != 1 && arg.len() != num_rows) { - return Err(DataFusionError::Execution(format!( - "all columns of array must have the same length" - ))); + df_execution_err!("all columns of array must have the same length")?; } // downcast all arguments to their common format @@ -91,9 +89,7 @@ macro_rules! array { fn array_array(args: &[ArrayRef]) -> Result { // do not accept 0 arguments. 
if args.is_empty() { - return Err(DataFusionError::Internal( - "array requires at least one argument".to_string(), - )); + df_execution_err!("array requires at least one argument")?; } let res = match args[0].data_type() { diff --git a/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs b/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs index 94dcd0a73..d941ba503 100644 --- a/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs +++ b/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs @@ -16,9 +16,10 @@ use std::sync::Arc; use arrow::{array::*, compute::*, datatypes::*}; use datafusion::{ - common::{DataFusionError, Result, ScalarValue}, + common::{Result, ScalarValue}, physical_plan::ColumnarValue, }; +use datafusion_ext_commons::df_unimplemented_err; /// used to avoid DivideByZero error in divide/modulo pub fn spark_null_if_zero(args: &[ColumnarValue]) -> Result { @@ -74,10 +75,7 @@ pub fn spark_null_if_zero(args: &[ColumnarValue]) -> Result { handle_decimal!(Decimal256, *precision, *scale) } dt => { - return Err(DataFusionError::Execution(format!( - "Unsupported data type: {:?}", - dt - ))); + return df_unimplemented_err!("Unsupported data type: {dt:?}"); } }) } diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index 95efb8137..666a2aafc 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs @@ -21,10 +21,11 @@ use arrow::{ use datafusion::{ common::{ cast::{as_int32_array, as_list_array, as_string_array}, - DataFusionError, Result, ScalarValue, + Result, ScalarValue, }, physical_plan::ColumnarValue, }; +use datafusion_ext_commons::df_execution_err; pub fn string_lower(args: &[ColumnarValue]) -> Result { match &args[0] { @@ -36,11 +37,7 @@ pub fn string_lower(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(ScalarValue::Utf8(Some(str))) => Ok(ColumnarValue::Scalar( ScalarValue::Utf8(Some(str.to_lowercase())), )), - _ => { - return Err(DataFusionError::Execution(format!( - "string_lower only supports literal utf8" - ))); - } + _ => df_execution_err!("string_lower only supports literal utf8"), } } @@ -54,11 +51,7 @@ pub fn string_upper(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(ScalarValue::Utf8(Some(str))) => Ok(ColumnarValue::Scalar( ScalarValue::Utf8(Some(str.to_uppercase())), )), - _ => { - return Err(DataFusionError::Execution(format!( - "string_lower only supports literal utf8" - ))); - } + _ => df_execution_err!("string_lower only supports literal utf8"), } } @@ -79,11 +72,7 @@ pub fn string_repeat(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(scalar) if scalar.is_null() => { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } - _ => { - return Err(DataFusionError::Execution(format!( - "string_repeat n only supports literal int32" - ))); - } + _ => df_execution_err!("string_repeat n only supports literal int32")?, }; let repeated_string_array: ArrayRef = Arc::new(StringArray::from_iter( @@ -98,11 +87,7 @@ pub fn string_split(args: &[ColumnarValue]) -> Result { let string_array = args[0].clone().into_array(1); let pat = match &args[1] { ColumnarValue::Scalar(ScalarValue::Utf8(Some(pat))) if !pat.is_empty() => pat, - _ => { - return Err(DataFusionError::Execution(format!( - "string_split pattern only supports non-empty literal string" - ))); - } + _ => df_execution_err!("string_split pattern only supports non-empty 
literal string")?, }; let mut splitted_builder = ListBuilder::new(StringBuilder::new()); @@ -128,10 +113,10 @@ pub fn string_split(args: &[ColumnarValue]) -> Result { pub fn string_concat(args: &[ColumnarValue]) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal(format!( + df_execution_err!( "concat was called with {} arguments. It requires at least 1.", - args.len() - ))); + args.len(), + )?; } // first, decide whether to return a scalar or a vector. @@ -206,11 +191,7 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } - _ => { - return Err(DataFusionError::Execution(format!( - "string_concat_ws separator only supports literal string" - ))); - } + _ => df_execution_err!("string_concat_ws separator only supports literal string")?, }; #[derive(Clone)] @@ -269,9 +250,7 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result { } } } - Err(DataFusionError::Execution(format!( - "concat_ws args must be string or array" - ))) + df_execution_err!("concat_ws args must be string or array") }) .collect::>>()?; diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs b/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs index 7e6d3fce8..e947f6bb0 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs @@ -23,11 +23,11 @@ use ahash::RandomState; use arrow::{record_batch::RecordBatch, row::Rows}; use async_trait::async_trait; use datafusion::{ - common::Result, error::DataFusionError, execution::context::TaskContext, - physical_plan::metrics::BaselineMetrics, + common::Result, execution::context::TaskContext, physical_plan::metrics::BaselineMetrics, }; use datafusion_ext_commons::{ bytes_arena::BytesArena, + df_execution_err, io::{read_bytes_slice, read_len, write_len}, loser_tree::LoserTree, rdxsort, @@ -553,9 +553,7 @@ impl InMemTable { } write_len(65536, &mut writer)?; // EOF write_len(0, &mut writer)?; - writer - .finish() - .map_err(|err| DataFusionError::Execution(format!("{}", err)))?; + writer.finish().or_else(|err| df_execution_err!("{err}"))?; spill.complete()?; Ok(spill) } diff --git a/native-engine/datafusion-ext-plans/src/agg/avg.rs b/native-engine/datafusion-ext-plans/src/agg/avg.rs index fb541fba9..177cb05d1 100644 --- a/native-engine/datafusion-ext-plans/src/agg/avg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/avg.rs @@ -24,9 +24,9 @@ use datafusion::{ cast::{as_decimal128_array, as_int64_array}, Result, ScalarValue, }, - error::DataFusionError, physical_expr::PhysicalExpr, }; +use datafusion_ext_commons::df_unimplemented_err; use crate::agg::{ agg_buf::{AccumInitialValue, AggBuf}, @@ -265,9 +265,6 @@ fn get_final_merger(dt: &DataType) -> Result ScalarValue DataType::UInt32 => get_fn!(UInt32, f64), DataType::UInt64 => get_fn!(UInt64, f64), DataType::Decimal128(..) 
=> get_fn!(Decimal128), - other => Err(DataFusionError::NotImplemented(format!( - "unsupported data type in avg(): {}", - other - ))), + other => df_unimplemented_err!("unsupported data type in avg(): {other}"), } } diff --git a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs index d87b27805..692e0e2a0 100644 --- a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs +++ b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs @@ -23,9 +23,9 @@ use std::{ use arrow::{array::*, datatypes::*}; use datafusion::{ common::{Result, ScalarValue}, - error::DataFusionError, physical_expr::PhysicalExpr, }; +use datafusion_ext_commons::df_unimplemented_err; use paste::paste; use crate::agg::{ @@ -203,12 +203,7 @@ impl Agg for AggMaxMin

{ } } } - other => { - return Err(DataFusionError::NotImplemented(format!( - "unsupported data type in {}(): {other}", - P::NAME, - ))); - } + other => df_unimplemented_err!("unsupported data type in {}(): {other}", P::NAME)?, } Ok(()) } @@ -320,10 +315,7 @@ fn get_partial_updater( } } }), - other => Err(DataFusionError::NotImplemented(format!( - "unsupported data type in {}(): {other}", - P::NAME, - ))), + other => df_unimplemented_err!("unsupported data type in {}(): {other}", P::NAME), } } @@ -379,10 +371,7 @@ fn get_partial_batch_updater( } } }), - other => Err(DataFusionError::NotImplemented(format!( - "unsupported data type in {}(): {other}", - P::NAME, - ))), + other => df_unimplemented_err!("unsupported data type in {}(): {other}", P::NAME), } } @@ -436,10 +425,7 @@ fn get_partial_buf_merger( } } }), - other => Err(DataFusionError::NotImplemented(format!( - "unsupported data type in {}(): {other}", - P::NAME, - ))), + other => df_unimplemented_err!("unsupported data type in {}(): {other}", P::NAME), } } diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs index c12752e06..92a1f6c57 100644 --- a/native-engine/datafusion-ext-plans/src/agg/mod.rs +++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs @@ -28,10 +28,11 @@ use std::{any::Any, fmt::Debug, sync::Arc}; use arrow::{array::*, datatypes::*}; use datafusion::{ - common::{DataFusionError, Result, ScalarValue}, + common::{Result, ScalarValue}, logical_expr::aggregate_function, physical_expr::PhysicalExpr, }; +use datafusion_ext_commons::df_unimplemented_err; use datafusion_ext_exprs::cast::TryCastExpr; use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynBinary, AggDynScalar, AggDynStr}; @@ -244,9 +245,7 @@ pub trait Agg: Send + Sync + Debug { { s.value.clone() } else { - return Err(DataFusionError::NotImplemented(format!( - "unsupported data type: {other}" - ))); + return df_unimplemented_err!("unsupported data type: {other}"); } } }) diff --git a/native-engine/datafusion-ext-plans/src/agg/sum.rs b/native-engine/datafusion-ext-plans/src/agg/sum.rs index a9f71bbcd..27cb9a7ff 100644 --- a/native-engine/datafusion-ext-plans/src/agg/sum.rs +++ b/native-engine/datafusion-ext-plans/src/agg/sum.rs @@ -22,9 +22,9 @@ use std::{ use arrow::{array::*, datatypes::*}; use datafusion::{ common::{Result, ScalarValue}, - error::DataFusionError, physical_expr::PhysicalExpr, }; +use datafusion_ext_commons::df_unimplemented_err; use paste::paste; use crate::agg::{ @@ -157,12 +157,7 @@ impl Agg for AggSum { DataType::UInt32 => handle!(UInt32), DataType::UInt64 => handle!(UInt64), DataType::Decimal128(..) => handle!(Decimal128), - other => { - return Err(DataFusionError::NotImplemented(format!( - "unsupported data type in sum(): {}", - other - ))); - } + other => df_unimplemented_err!("unsupported data type in sum(): {other}")?, } Ok(()) } @@ -228,10 +223,7 @@ fn get_partial_updater(dt: &DataType) -> Result fn_fixed!(UInt32), DataType::UInt64 => fn_fixed!(UInt64), DataType::Decimal128(..) => fn_fixed!(Decimal128), - other => Err(DataFusionError::NotImplemented(format!( - "unsupported data type in sum(): {}", - other - ))), + other => df_unimplemented_err!("unsupported data type in sum(): {other}"), } } @@ -262,10 +254,7 @@ fn get_partial_batch_updater(dt: &DataType) -> Result fn_fixed!(UInt32), DataType::UInt64 => fn_fixed!(UInt64), DataType::Decimal128(..) 
=> fn_fixed!(Decimal128), - other => Err(DataFusionError::NotImplemented(format!( - "unsupported data type in sum(): {}", - other - ))), + other => df_unimplemented_err!("unsupported data type in sum(): {other}"), } } @@ -295,9 +284,6 @@ fn get_partial_buf_merger(dt: &DataType) -> Result fn_fixed!(UInt32), DataType::UInt64 => fn_fixed!(UInt64), DataType::Decimal128(..) => fn_fixed!(Decimal128), - other => Err(DataFusionError::NotImplemented(format!( - "unsupported data type in sum(): {}", - other - ))), + other => df_unimplemented_err!("unsupported data type in sum(): {other}"), } } diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs index 9377dd6fc..e7648f0d7 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs @@ -26,7 +26,7 @@ use blaze_jni_bridge::{ conf::{BooleanConf, IntConf}, }; use datafusion::{ - common::{DataFusionError, Result, Statistics}, + common::{Result, Statistics}, execution::context::TaskContext, logical_expr::JoinType, physical_expr::PhysicalSortExpr, @@ -42,6 +42,7 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }, }; +use datafusion_ext_commons::df_execution_err; use futures::{stream::once, StreamExt, TryStreamExt}; use parking_lot::Mutex; @@ -78,9 +79,7 @@ impl BroadcastJoinExec { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti, ) { if join_filter.is_some() { - return Err(DataFusionError::Plan(format!( - "Semi/Anti join with filter is not supported yet" - ))); + df_execution_err!("Semi/Anti join with filter is not supported yet")?; } } diff --git a/native-engine/datafusion-ext-plans/src/common/output.rs b/native-engine/datafusion-ext-plans/src/common/output.rs index 95f3fdc05..97c4cf79f 100644 --- a/native-engine/datafusion-ext-plans/src/common/output.rs +++ b/native-engine/datafusion-ext-plans/src/common/output.rs @@ -22,14 +22,17 @@ use std::{ use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use blaze_jni_bridge::is_task_running; use datafusion::{ - common::{DataFusionError, Result}, + common::Result, execution::context::TaskContext, physical_plan::{ metrics::ScopedTimerGuard, stream::RecordBatchReceiverStream, SendableRecordBatchStream, }, }; -use datafusion_ext_commons::io::{read_one_batch, write_one_batch}; -use futures::{FutureExt, StreamExt, TryFutureExt}; +use datafusion_ext_commons::{ + df_execution_err, + io::{read_one_batch, write_one_batch}, +}; +use futures::{FutureExt, StreamExt}; use once_cell::sync::OnceCell; use parking_lot::Mutex; use tokio::sync::mpsc::Sender; @@ -63,9 +66,10 @@ impl WrappedRecordBatchSender { .into_iter() .filter(|wrapped| match wrapped.upgrade() { Some(wrapped) if Arc::ptr_eq(&wrapped.task_context, task_context) => { - let _ = wrapped.sender.send(Err(DataFusionError::Execution(format!( - "task completed/cancelled" - )))); + wrapped + .sender + .try_send(df_execution_err!("task completed/cancelled")) + .unwrap_or_default(); false } Some(_) => true, // do not modify senders from other tasks @@ -81,13 +85,13 @@ impl WrappedRecordBatchSender { ) { // panic if we meet an error let batch = batch_result - .unwrap_or_else(|err| panic!("output_with_sender: received an error: {}", err)); + .unwrap_or_else(|err| panic!("output_with_sender: received an error: {err}")); stop_timer.iter_mut().for_each(|timer| timer.stop()); self.sender .send(Ok(batch)) .await - 
.unwrap_or_else(|err| panic!("output_with_sender: send error: {}", err)); + .unwrap_or_else(|err| panic!("output_with_sender: send error: {err}")); stop_timer.iter_mut().for_each(|timer| timer.restart()); } } @@ -105,6 +109,8 @@ pub trait TaskOutputter { mem_consumer: Arc, stream: SendableRecordBatchStream, ) -> Result; + + fn cancel_task(&self); } impl TaskOutputter for Arc { @@ -115,27 +121,15 @@ impl TaskOutputter for Arc { output: impl FnOnce(Arc) -> Fut + Send + 'static, ) -> Result { let mut stream_builder = RecordBatchReceiverStream::builder(output_schema, 1); - let sender = stream_builder.tx().clone(); - let err_sender = sender.clone(); - let wrapped_sender = WrappedRecordBatchSender::new(self.clone(), sender); + let err_sender = stream_builder.tx().clone(); + let wrapped_sender = + WrappedRecordBatchSender::new(self.clone(), stream_builder.tx().clone()); stream_builder.spawn(async move { let result = AssertUnwindSafe(async move { - let task_running = is_task_running(); - if !task_running { - panic!( - "output_with_sender[{}] canceled due to task finished/killed", - desc - ); + if let Err(err) = output(wrapped_sender).await { + panic!("output_with_sender[{desc}]: output() returns error: {err}"); } - output(wrapped_sender) - .unwrap_or_else(|err| { - panic!( - "output_with_sender[{}]: output() returns error: {}", - desc, err - ); - }) - .await }) .catch_unwind() .await @@ -143,22 +137,22 @@ impl TaskOutputter for Arc { .unwrap_or_else(|err| { let panic_message = panic_message::get_panic_message(&err).unwrap_or("unknown error"); - Err(DataFusionError::Execution(panic_message.to_owned())) + df_execution_err!("{panic_message}") }); if let Err(err) = result { let err_message = err.to_string(); - let _ = err_sender.send(Err(err)).await; + err_sender + .send(df_execution_err!("{err}")) + .await + .unwrap_or_default(); // panic current spawn let task_running = is_task_running(); if !task_running { - panic!( - "output_with_sender[{}] canceled due to task finished/killed", - desc - ); + panic!("output_with_sender[{desc}] canceled due to task finished/killed"); } else { - panic!("output_with_sender[{}] error: {}", desc, err_message); + panic!("output_with_sender[{desc}] error: {err_message}"); } } }); @@ -206,4 +200,8 @@ impl TaskOutputter for Arc { Ok(()) }) } + + fn cancel_task(&self) { + WrappedRecordBatchSender::cancel_task(self); + } } diff --git a/native-engine/datafusion-ext-plans/src/debug_exec.rs b/native-engine/datafusion-ext-plans/src/debug_exec.rs index 729337a67..42ba51225 100644 --- a/native-engine/datafusion-ext-plans/src/debug_exec.rs +++ b/native-engine/datafusion-ext-plans/src/debug_exec.rs @@ -23,7 +23,7 @@ use std::{ use arrow::{datatypes::SchemaRef, record_batch::RecordBatch, util::pretty::pretty_format_batches}; use async_trait::async_trait; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_expr::PhysicalSortExpr, physical_plan::{ @@ -81,13 +81,8 @@ impl ExecutionPlan for DebugExec { fn with_new_children( self: Arc, - children: Vec>, + _children: Vec>, ) -> Result> { - if children.len() != 1 { - return Err(DataFusionError::Plan( - "DebugExec expects one children".to_string(), - )); - } Ok(Arc::new(DebugExec::new( self.input.clone(), self.debug_id.clone(), diff --git a/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs b/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs index adb5544eb..eb1413732 100644 --- a/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs +++ 
b/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs @@ -23,7 +23,7 @@ use std::{ use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use async_trait::async_trait; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_expr::PhysicalSortExpr, physical_plan::{ @@ -83,13 +83,8 @@ impl ExecutionPlan for EmptyPartitionsExec { fn with_new_children( self: Arc, - children: Vec>, + _children: Vec>, ) -> Result> { - if !children.is_empty() { - return Err(DataFusionError::Plan( - "EmptyPartitionsExec expects no children".to_string(), - )); - } Ok(self) } diff --git a/native-engine/datafusion-ext-plans/src/expand_exec.rs b/native-engine/datafusion-ext-plans/src/expand_exec.rs index dcf5f1475..36807ac40 100644 --- a/native-engine/datafusion-ext-plans/src/expand_exec.rs +++ b/native-engine/datafusion-ext-plans/src/expand_exec.rs @@ -19,7 +19,7 @@ use arrow::{ record_batch::{RecordBatch, RecordBatchOptions}, }; use datafusion::{ - common::{DataFusionError, Result, Statistics}, + common::{Result, Statistics}, execution::context::TaskContext, physical_expr::{PhysicalExpr, PhysicalSortExpr}, physical_plan::{ @@ -30,7 +30,7 @@ use datafusion::{ SendableRecordBatchStream, }, }; -use datafusion_ext_commons::cast::cast; +use datafusion_ext_commons::{cast::cast, df_execution_err}; use futures::{stream::once, StreamExt, TryStreamExt}; use crate::common::output::TaskOutputter; @@ -59,10 +59,7 @@ impl ExpandExec { .transpose()?; if projection_data_type.as_ref() != Some(schema_data_type) { - return Err(DataFusionError::Plan(format!( - "ExpandExec data type not matches: {:?} vs {:?}", - projection_data_type, schema_data_type - ))); + df_execution_err!("ExpandExec data type not matches: {projection_data_type:?} vs {schema_data_type:?}")?; } } } @@ -106,11 +103,6 @@ impl ExecutionPlan for ExpandExec { self: Arc, children: Vec>, ) -> Result> { - if children.len() != 1 { - return Err(DataFusionError::Plan( - "ExpandExec expects one children".to_string(), - )); - } Ok(Arc::new(Self { schema: self.schema(), projections: self.projections.clone(), diff --git a/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs index 6638a2e05..24c041dcf 100644 --- a/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs @@ -21,7 +21,7 @@ use std::{ use arrow::datatypes::SchemaRef; use blaze_jni_bridge::{jni_call, jni_call_static, jni_new_global_ref, jni_new_string}; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_expr::PhysicalSortExpr, physical_plan::{ @@ -91,14 +91,13 @@ impl ExecutionPlan for FFIReaderExec { fn with_new_children( self: Arc, - children: Vec>, + _children: Vec>, ) -> Result> { - if !children.is_empty() { - return Err(DataFusionError::Plan( - "Blaze FFIReaderExec expects 0 children".to_owned(), - )); - } - Ok(self) + Ok(Arc::new(Self::new( + self.num_partitions, + self.export_iter_provider_resource_id.clone(), + self.schema.clone(), + ))) } fn execute( diff --git a/native-engine/datafusion-ext-plans/src/filter_exec.rs b/native-engine/datafusion-ext-plans/src/filter_exec.rs index b3f51c88a..6972e0cef 100644 --- a/native-engine/datafusion-ext-plans/src/filter_exec.rs +++ b/native-engine/datafusion-ext-plans/src/filter_exec.rs @@ -16,7 +16,7 @@ use std::{any::Any, fmt::Formatter, sync::Arc}; use arrow::datatypes::{DataType, SchemaRef}; use 
datafusion::{ - common::{DataFusionError, Result, Statistics}, + common::{Result, Statistics}, execution::context::TaskContext, physical_expr::{expressions::Column, PhysicalExprRef, PhysicalSortExpr}, physical_plan::{ @@ -25,7 +25,7 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; +use datafusion_ext_commons::{df_execution_err, streams::coalesce_stream::CoalesceInput}; use futures::{stream::once, StreamExt, TryStreamExt}; use itertools::Itertools; @@ -54,17 +54,13 @@ impl FilterExec { let schema = input.schema(); if predicates.is_empty() { - return Err(DataFusionError::Plan(format!( - "Filter requires at least one predicate" - ))); + df_execution_err!("Filter requires at least one predicate")?; } if !predicates .iter() .all(|pred| matches!(pred.data_type(&schema), Ok(DataType::Boolean))) { - return Err(DataFusionError::Plan(format!( - "Filter predicate must return boolean values" - ))); + df_execution_err!("Filter predicate must return boolean values")?; } Ok(Self { input, diff --git a/native-engine/datafusion-ext-plans/src/generate/mod.rs b/native-engine/datafusion-ext-plans/src/generate/mod.rs index d7771b053..74fe99352 100644 --- a/native-engine/datafusion-ext-plans/src/generate/mod.rs +++ b/native-engine/datafusion-ext-plans/src/generate/mod.rs @@ -21,7 +21,8 @@ use arrow::{ datatypes::{DataType, SchemaRef}, record_batch::RecordBatch, }; -use datafusion::{common::Result, error::DataFusionError, physical_plan::PhysicalExpr}; +use datafusion::{common::Result, physical_plan::PhysicalExpr}; +use datafusion_ext_commons::df_unimplemented_err; use crate::generate::explode::{ExplodeArray, ExplodeMap}; @@ -54,18 +55,12 @@ pub fn create_generator( GenerateFunc::Explode => match children[0].data_type(input_schema)? { DataType::List(..) => Ok(Arc::new(ExplodeArray::new(children[0].clone(), false))), DataType::Map(..) => Ok(Arc::new(ExplodeMap::new(children[0].clone(), false))), - other => Err(DataFusionError::Plan(format!( - "unsupported explode type: {}", - other - ))), + other => df_unimplemented_err!("unsupported explode type: {other}"), }, GenerateFunc::PosExplode => match children[0].data_type(input_schema)? { DataType::List(..) => Ok(Arc::new(ExplodeArray::new(children[0].clone(), true))), DataType::Map(..) 
=> Ok(Arc::new(ExplodeMap::new(children[0].clone(), true))), - other => Err(DataFusionError::Plan(format!( - "unsupported pos_explode type: {}", - other - ))), + other => df_unimplemented_err!("unsupported pos_explode type: {other}"), }, } } diff --git a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs index 6c43df1bd..dc08fb193 100644 --- a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs @@ -22,7 +22,7 @@ use arrow::datatypes::SchemaRef; use async_trait::async_trait; use blaze_jni_bridge::{jni_call, jni_call_static, jni_new_global_ref, jni_new_string}; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_plan::{ expressions::PhysicalSortExpr, @@ -95,9 +95,12 @@ impl ExecutionPlan for IpcReaderExec { self: Arc, _children: Vec>, ) -> Result> { - Err(DataFusionError::Plan( - "Blaze ShuffleReaderExec does not support with_new_children()".to_owned(), - )) + Ok(Arc::new(Self::new( + self.num_partitions, + self.ipc_provider_resource_id.clone(), + self.schema.clone(), + self.mode, + ))) } fn execute( diff --git a/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs index cd669df54..86324bfd6 100644 --- a/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs @@ -20,7 +20,7 @@ use blaze_jni_bridge::{ jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_string, }; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_expr::PhysicalSortExpr, physical_plan::{ @@ -83,13 +83,8 @@ impl ExecutionPlan for IpcWriterExec { fn with_new_children( self: Arc, - children: Vec>, + _children: Vec>, ) -> Result> { - if children.len() != 1 { - return Err(DataFusionError::Plan( - "IpcWriterExec expects one children".to_string(), - )); - } Ok(Arc::new(IpcWriterExec::new( self.input.clone(), self.ipc_consumer_resource_id.clone(), diff --git a/native-engine/datafusion-ext-plans/src/limit_exec.rs b/native-engine/datafusion-ext-plans/src/limit_exec.rs index 8dc8048f0..5c59bf891 100644 --- a/native-engine/datafusion-ext-plans/src/limit_exec.rs +++ b/native-engine/datafusion-ext-plans/src/limit_exec.rs @@ -8,7 +8,7 @@ use std::{ use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use datafusion::{ - common::{DataFusionError, Result}, + common::Result, execution::context::TaskContext, physical_expr::PhysicalSortExpr, physical_plan::{ @@ -67,12 +67,7 @@ impl ExecutionPlan for LimitExec { self: Arc, children: Vec>, ) -> Result> { - match children.len() { - 1 => Ok(Arc::new(Self::new(children[0].clone(), self.limit))), - _ => Err(DataFusionError::Internal( - "LimitExec wrong number of children".to_string(), - )), - } + Ok(Arc::new(Self::new(children[0].clone(), self.limit))) } fn execute( diff --git a/native-engine/datafusion-ext-plans/src/parquet_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_exec.rs index f70a5637c..840685631 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_exec.rs @@ -53,7 +53,11 @@ use datafusion::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }, }; -use datafusion_ext_commons::hadoop_fs::{FsDataInputStream, FsProvider}; +use datafusion_ext_commons::{ + df_execution_err, + 
hadoop_fs::{FsDataInputStream, FsProvider}, + streams::coalesce_stream::CoalesceInput, +}; use fmt::Debug; use futures::{future::BoxFuture, stream::once, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use object_store::ObjectMeta; @@ -225,10 +229,12 @@ impl ExecutionPlan for ParquetExec { None => (0..self.base_config.file_schema.fields().len()).collect(), }; + let batch_size = context.session_config().batch_size(); + let sub_batch_size = (1.0 + batch_size as f64).sqrt() as usize; let opener = ParquetOpener { partition_index, projection: Arc::from(projection), - batch_size: context.session_config().batch_size(), + batch_size: sub_batch_size, limit: self.base_config.limit, predicate: self.predicate.clone(), pruning_predicate: self.pruning_predicate.clone(), @@ -250,10 +256,11 @@ impl ExecutionPlan for ParquetExec { file_stream = file_stream.with_on_error(OnError::Skip); } let mut stream = Box::pin(file_stream); - Ok(Box::pin(RecordBatchStreamAdapter::new( + let context_cloned = context.clone(); + let timed_stream = Box::pin(RecordBatchStreamAdapter::new( self.schema(), once(async move { - context.output_with_sender( + context_cloned.output_with_sender( "ParquetScan", stream.schema(), move |sender| async move { @@ -266,7 +273,8 @@ impl ExecutionPlan for ParquetExec { ) }) .try_flatten(), - ))) + )); + context.coalesce_with_default_batch_size(timed_stream, &baseline_metrics) } fn metrics(&self) -> Option { @@ -339,11 +347,9 @@ impl ParquetFileReader { let path = BASE64_URL_SAFE_NO_PAD .decode(self.meta.location.filename().expect("missing filename")) .map(|bytes| String::from_utf8_lossy(&bytes).to_string()) - .map_err(|_| { - DataFusionError::Execution(format!( - "cannot decode filename: {:?}", - self.meta.location.filename() - )) + .or_else(|_| { + let filename = self.meta.location.filename(); + df_execution_err!("cannot decode filename: {filename:?}") })?; let fs = self.fs_provider.provide(&path)?; Ok(Arc::new(fs.open(&path)?)) diff --git a/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs index 3bf05e36d..1bef99f4d 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs @@ -41,6 +41,7 @@ use datafusion::{ }; use datafusion_ext_commons::{ cast::cast, + df_execution_err, hadoop_fs::{FsDataOutputStream, FsProvider}, }; use futures::{stream::once, StreamExt, TryStreamExt}; @@ -182,16 +183,17 @@ async fn execute_parquet_sink( let mut timer = metrics.elapsed_compute().timer(); // parse hive_schema from props - let hive_schema = props + let hive_schema = match props .iter() .find(|(key, _)| key == "parquet.hive.schema") .map(|(_, value)| value) .and_then(|value| parse_message_type(value.as_str()).ok()) .and_then(|tp| parquet_to_arrow_schema(&SchemaDescriptor::new(Arc::new(tp)), None).ok()) .map(Arc::new) - .ok_or(DataFusionError::Execution(format!( - "missing parquet.hive.schema" - )))?; + { + Some(hive_schema) => hive_schema, + _ => df_execution_err!("missing parquet.hive.schema")?, + }; // parse row group byte size from props let block_size = props @@ -239,8 +241,7 @@ async fn execute_parquet_sink( metrics.record_output(num_rows); Ok::<_, DataFusionError>(()) }); - fut.await - .map_err(|err| DataFusionError::Execution(format!("{err}")))??; + fut.await.or_else(|err| df_execution_err!("{err}"))??; timer.stop(); } @@ -251,8 +252,7 @@ async fn execute_parquet_sink( w.close()?; Ok::<_, DataFusionError>(()) }); - fut.await - .map_err(|err| 
DataFusionError::Execution(format!("{err}")))??; + fut.await.or_else(|err| df_execution_err!("{err}"))??; } // parquet sink does not provide any output records diff --git a/native-engine/datafusion-ext-plans/src/project_exec.rs b/native-engine/datafusion-ext-plans/src/project_exec.rs index a0d76452d..33d60cf7c 100644 --- a/native-engine/datafusion-ext-plans/src/project_exec.rs +++ b/native-engine/datafusion-ext-plans/src/project_exec.rs @@ -221,6 +221,8 @@ async fn execute_project_with_filtering( let mut timer = baseline_metrics.elapsed_compute().timer(); let output_batch = cached_expr_evaluator.filter_project(&batch, output_schema.clone())?; + drop(batch); + baseline_metrics.record_output(output_batch.num_rows()); sender.send(Ok(output_batch), Some(&mut timer)).await; } diff --git a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs index 5e56a9062..7d1574f19 100644 --- a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs +++ b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs @@ -26,7 +26,7 @@ use arrow::{ }; use async_trait::async_trait; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_expr::PhysicalSortExpr, physical_plan::{ @@ -35,6 +35,7 @@ use datafusion::{ SendableRecordBatchStream, Statistics, }, }; +use datafusion_ext_commons::df_execution_err; use futures::{Stream, StreamExt}; use crate::agg::AGG_BUF_COLUMN_NAME; @@ -64,11 +65,10 @@ impl RenameColumnsExec { } } if new_names.len() != input_schema.fields().len() { - return Err(DataFusionError::Plan(format!( + df_execution_err!( "renamed_column_names length not matched with input schema, \ - renames: {:?}, input schema: {}", - renamed_column_names, input_schema, - ))); + renames: {renamed_column_names:?}, input schema: {input_schema}", + )?; } let renamed_column_names = new_names; let renamed_schema = Arc::new(Schema::new( @@ -122,11 +122,6 @@ impl ExecutionPlan for RenameColumnsExec { self: Arc, children: Vec>, ) -> Result> { - if children.len() != 1 { - return Err(DataFusionError::Plan( - "RenameColumnsExec expects one children".to_string(), - )); - } Ok(Arc::new(RenameColumnsExec::try_new( children[0].clone(), self.renamed_column_names.clone(), diff --git a/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs index 77248ce9d..a64401155 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs @@ -38,6 +38,7 @@ use datafusion::{ }; use datafusion_ext_commons::{ array_builder::{builder_extend, make_batch, new_array_builders}, + df_execution_err, io::write_one_batch, }; use futures::lock::Mutex; @@ -244,7 +245,7 @@ impl ShuffleRepartitioner for BucketShuffleRepartitioner { Ok::<(), DataFusionError>(()) }) .await - .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {:?}", e)))??; + .or_else(|e| df_execution_err!("shuffle write error: {e:?}"))??; // update disk spill size let spill_disk_usage = raw_spills diff --git a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs index 5356c89ef..30d2df549 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs @@ -28,7 +28,7 @@ use 
datafusion::{ Partitioning, }, }; -use datafusion_ext_commons::{io::write_one_batch, loser_tree::LoserTree}; +use datafusion_ext_commons::{df_execution_err, io::write_one_batch, loser_tree::LoserTree}; use derivative::Derivative; use futures::lock::Mutex; @@ -391,7 +391,7 @@ impl ShuffleRepartitioner for SortShuffleRepartitioner { Ok::<(), DataFusionError>(()) }) .await - .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {:?}", e)))??; + .or_else(|e| df_execution_err!("shuffle write error: {e:?}"))??; // update disk spill size let spill_disk_usage = raw_spills diff --git a/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs b/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs index 705e04f3e..a13b6bbff 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs @@ -19,7 +19,7 @@ use std::{any::Any, fmt::Debug, sync::Arc}; use arrow::datatypes::SchemaRef; use async_trait::async_trait; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_plan::{ expressions::PhysicalSortExpr, @@ -29,6 +29,7 @@ use datafusion::{ Statistics, }, }; +use datafusion_ext_commons::df_execution_err; use futures::{stream::once, TryStreamExt}; use crate::{ @@ -99,9 +100,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.output_data_file.clone(), self.output_index_file.clone(), )?)), - _ => Err(DataFusionError::Internal( - "ShuffleWriterExec wrong number of children".to_string(), - )), + _ => df_execution_err!("ShuffleWriterExec wrong number of children"), } } diff --git a/native-engine/datafusion-ext-plans/src/sort_exec.rs b/native-engine/datafusion-ext-plans/src/sort_exec.rs index 7aa93378a..05016e535 100644 --- a/native-engine/datafusion-ext-plans/src/sort_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_exec.rs @@ -31,7 +31,7 @@ use arrow::{ }; use async_trait::async_trait; use datafusion::{ - common::{DataFusionError, Result, Statistics}, + common::{Result, Statistics}, execution::context::TaskContext, physical_expr::PhysicalSortExpr, physical_plan::{ @@ -42,6 +42,7 @@ use datafusion::{ }; use datafusion_ext_commons::{ bytes_arena::BytesArena, + df_execution_err, io::{read_bytes_slice, read_len, read_one_batch, write_len, write_one_batch}, loser_tree::LoserTree, slim_bytes::SlimBytes, @@ -772,9 +773,7 @@ impl SortedBatches { writer.write_all(key)?; } } - writer - .finish() - .map_err(|err| DataFusionError::Execution(format!("{}", err)))?; + writer.finish().or_else(|err| df_execution_err!("{err}"))?; spill.complete()?; Ok(Some(spill)) } diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs index 6b17bb2e0..5c9f3c983 100644 --- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs @@ -23,7 +23,7 @@ use arrow::{ row::{Row, RowConverter, Rows, SortField}, }; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, logical_expr::{JoinType, JoinType::*}, physical_expr::PhysicalSortExpr, @@ -37,7 +37,7 @@ use datafusion::{ Statistics, }, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; +use datafusion_ext_commons::{df_execution_err, streams::coalesce_stream::CoalesceInput}; use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex as SyncMutex; @@ -82,19 +82,17 @@ impl SortMergeJoinExec { if 
matches!(join_type, LeftSemi | LeftAnti | RightSemi | RightAnti,) { if join_filter.is_some() { - return Err(DataFusionError::Plan(format!( - "Semi/Anti join with filter is not supported yet" - ))); + df_execution_err!("Semi/Anti join with filter is not supported yet")?; } } check_join_is_valid(&left_schema, &right_schema, &on)?; if sort_options.len() != on.len() { - return Err(DataFusionError::Plan(format!( + df_execution_err!( "Expected number of sort options: {}, actual: {}", on.len(), - sort_options.len() - ))); + sort_options.len(), + )?; } let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); @@ -176,19 +174,14 @@ impl ExecutionPlan for SortMergeJoinExec { self: Arc, children: Vec>, ) -> Result> { - match &children[..] { - [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right.clone(), - self.on.clone(), - self.join_type, - self.join_filter.clone(), - self.sort_options.clone(), - )?)), - _ => Err(DataFusionError::Internal( - "SortMergeJoin wrong number of children".to_string(), - )), - } + Ok(Arc::new(SortMergeJoinExec::try_new( + children[0].clone(), + children[1].clone(), + self.on.clone(), + self.join_type, + self.join_filter.clone(), + self.sort_options.clone(), + )?)) } fn execute( diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala index f6eaa7048..90fd93601 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala @@ -19,7 +19,6 @@ import java.io.File import java.io.IOException import java.nio.file.Files import java.nio.file.StandardCopyOption -import java.util import java.util.concurrent.atomic.AtomicReference import org.apache.arrow.c.ArrowArray @@ -35,6 +34,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.CompletionIterator import org.apache.spark.util.Utils import org.blaze.protobuf.PartitionId @@ -51,26 +52,48 @@ case class BlazeCallNativeWrapper( BlazeCallNativeWrapper.initNative() private val error: AtomicReference[Throwable] = new AtomicReference(null) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator(this.getClass.getName, 0, Long.MaxValue) private val dictionaryProvider = new CDataDictionaryProvider() - private val recordsQueue = new util.ArrayDeque[InternalRow]() private var arrowSchema: Schema = _ + private var schema: StructType = _ + private var toUnsafe: UnsafeProjection = _ + private var root: VectorSchemaRoot = _ + private var batch: ColumnarBatch = _ + private var batchCurRowIdx = 0 logInfo(s"Start executing native plan") private var nativeRuntimePtr = JniBridge.callNative(NativeHelper.nativeMemory, this) private lazy val rowIterator = new Iterator[InternalRow] { - override def hasNext: Boolean = - !recordsQueue.isEmpty || JniBridge.nextBatch(nativeRuntimePtr) + override def hasNext: Boolean = { + checkError() + if (batch != null && batchCurRowIdx < batch.numRows()) { + return true + } + if (root != null) { + batch.close() + root.close() + } + if (nativeRuntimePtr != 0 && 
JniBridge.nextBatch(nativeRuntimePtr)) { // load next batch + return hasNext + } + false + } - override def next(): InternalRow = - recordsQueue.poll() + override def next(): InternalRow = { + checkError() + val row = toUnsafe(batch.getRow(batchCurRowIdx)).copy() + batchCurRowIdx += 1 + row + } } context.foreach(_.addTaskCompletionListener[Unit]((_: TaskContext) => close())) context.foreach(_.addTaskFailureListener((_, _) => close())) def getRowIterator: Iterator[InternalRow] = { - CompletionIterator[InternalRow, Iterator[InternalRow]](rowIterator, this.close()) + CompletionIterator[InternalRow, Iterator[InternalRow]](rowIterator, close()) } protected def getMetrics: MetricNode = @@ -78,33 +101,17 @@ case class BlazeCallNativeWrapper( protected def importSchema(ffiSchemaPtr: Long): Unit = { val ffiSchema = ArrowSchema.wrap(ffiSchemaPtr) - arrowSchema = Data.importSchema(ArrowUtils.rootAllocator, ffiSchema, dictionaryProvider) + arrowSchema = Data.importSchema(allocator, ffiSchema, dictionaryProvider) + schema = ArrowUtils.fromArrowSchema(arrowSchema) + toUnsafe = UnsafeProjection.create(schema) } protected def importBatch(ffiArrayPtr: Long): Unit = { - checkError() - - val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) val ffiArray = ArrowArray.wrap(ffiArrayPtr) - Utils.tryWithSafeFinally { - Data.importIntoVectorSchemaRoot( - ArrowUtils.rootAllocator, - ffiArray, - root, - dictionaryProvider) - - val toUnsafe = UnsafeProjection.create(ArrowUtils.fromArrowSchema(root.getSchema)) - toUnsafe.initialize(Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)) - - val batch = ColumnarHelper.rootAsBatch(root) - for (row <- ColumnarHelper.batchAsRowIter(batch)) { - checkError() - recordsQueue.offer(toUnsafe(row).copy()) - } - } { - root.close() - ffiArray.close() - } + root = VectorSchemaRoot.create(arrowSchema, allocator) + Data.importIntoVectorSchemaRoot(allocator, ffiArray, root, dictionaryProvider) + batch = ColumnarHelper.rootAsBatch(root) + batchCurRowIdx = 0 } protected def setError(error: Throwable): Unit = { @@ -114,6 +121,7 @@ case class BlazeCallNativeWrapper( protected def checkError(): Unit = { val throwable = error.getAndSet(null) if (throwable != null) { + close() throw throwable } } @@ -139,8 +147,14 @@ case class BlazeCallNativeWrapper( if (nativeRuntimePtr != 0) { JniBridge.finalizeNative(nativeRuntimePtr) nativeRuntimePtr = 0 + if (root != null) { + batch.close() + root.close() + } + dictionaryProvider.close() + allocator.close() + checkError() } - checkError() } } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala deleted file mode 100644 index 7c8e76407..000000000 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.blaze.arrowio - -import org.apache.arrow.c.ArrowArrayStream -import org.apache.arrow.c.Data -import org.apache.spark.TaskContext - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils - -class ArrowFFIStreamImportIterator( - taskContext: Option[TaskContext], - arrowFFIStreamPtr: Long, - checkError: () => Unit = () => Unit) - extends Iterator[InternalRow] { - - private var stream = ArrowArrayStream.wrap(arrowFFIStreamPtr) - private var reader = Data.importArrayStream(ArrowUtils.rootAllocator, stream) - private var currentRows: Iterator[InternalRow] = Iterator.empty - - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => close())) - - override def hasNext: Boolean = { - if (stream == null) { // closed? - return false - } - if (currentRows.hasNext) { - return true - } - - // load next batch - var hasNextBatch = false - try { - hasNextBatch = reader.loadNextBatch() - checkError() - } catch { - case _ if taskContext.exists(tc => tc.isCompleted() || tc.isInterrupted()) => - hasNextBatch = false - } - if (!hasNextBatch) { - close() - return false - } - - val currentBatch = ColumnarHelper.rootAsBatch(reader.getVectorSchemaRoot) - try { - // convert batch to persisted row iterator - val toUnsafe = this.toUnsafe - val rowIterator = currentBatch.rowIterator() - val currentRowsArray = new Array[UnsafeRow](currentBatch.numRows()) - var i = 0 - while (rowIterator.hasNext) { - currentRowsArray(i) = toUnsafe(rowIterator.next()).copy() - i += 1 - } - currentRows = currentRowsArray.iterator - - } finally { - // current batch can be closed after all rows converted to UnsafeRow - currentBatch.close() - // reader.getVectorSchemaRoot.clear() - } - hasNext - } - - override def next(): InternalRow = { - currentRows.next() - } - - private lazy val toUnsafe: UnsafeProjection = { - val localOutput = ArrowUtils - .fromArrowSchema(reader.getVectorSchemaRoot.getSchema) - .map(field => AttributeReference(field.name, field.dataType, field.nullable)()) - - val toUnsafe = UnsafeProjection.create(localOutput, localOutput) - toUnsafe.initialize(taskContext.map(_.partitionId()).getOrElse(0)) - toUnsafe - } - - def close(): Unit = { - if (stream != null) { - currentRows = Iterator.empty - reader.close() - reader = null - stream.close() - stream = null - checkError() - } - } -}
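
A minimal standalone sketch of the sub-batch sizing rule introduced in parquet_exec.rs above (the helper name and the main() driver below are illustrative, not part of the patch): the parquet opener now reads with roughly the square root of the configured batch size, which bounds the memory held by any single decoded batch, and the scan output is then passed through coalesce_with_default_batch_size so downstream operators still receive full-size batches.

    // sketch only: mirrors `(1.0 + batch_size as f64).sqrt() as usize` from the patch
    fn sub_batch_size(batch_size: usize) -> usize {
        (1.0 + batch_size as f64).sqrt() as usize
    }

    fn main() {
        // with a configured batch size of 8192 rows, parquet scanning decodes
        // ~90-row batches before the coalesce stream merges them back up
        assert_eq!(sub_batch_size(8192), 90);
        assert_eq!(sub_batch_size(10000), 100);
    }

The same memory-pressure concern motivates the drop(batch) added in project_exec.rs (the input batch is released before the projected batch is sent) and the BlazeCallNativeWrapper rewrite, which imports one FFI batch at a time into a dedicated child allocator instead of queueing converted rows for a whole stream.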