Skip to content

Commit

Permalink
use number of physical cores as tokio parallelism, removing blaze.tokio.num.worker.threads

Browse files Browse the repository at this point in the history

sort multiple batches with interleaver in repartitioner

use in-place radix sort

use sync channel in ipc reading

fix collect_set/collect_list incorrect spill
  • Loading branch information
zhangli20 committed Dec 19, 2024
1 parent 7f8319c commit 9061f4b
Show file tree
Hide file tree
Showing 19 changed files with 563 additions and 358 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion native-engine/blaze-jni-bridge/src/conf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ macro_rules! define_conf {

define_conf!(IntConf, BATCH_SIZE);
define_conf!(DoubleConf, MEMORY_FRACTION);
define_conf!(IntConf, TOKIO_NUM_WORKER_THREADS);
define_conf!(BooleanConf, SMJ_INEQUALITY_JOIN_ENABLE);
define_conf!(BooleanConf, CASE_CONVERT_FUNCTIONS_ENABLE);
define_conf!(BooleanConf, INPUT_BATCH_STATISTICS_ENABLE);
Expand Down
1 change: 1 addition & 0 deletions native-engine/blaze/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ once_cell = "1.20.2"
panic-message = "0.3.0"
paste = "1.0.15"
prost = "0.13.3"
raw-cpuid = "11.2.0"
tokio = "=1.42.0"

[target.'cfg(not(windows))'.dependencies]
Expand Down
27 changes: 21 additions & 6 deletions native-engine/blaze/src/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ use arrow::{
record_batch::RecordBatch,
};
use blaze_jni_bridge::{
conf::{IntConf, TOKIO_NUM_WORKER_THREADS},
is_task_running,
jni_bridge::JavaClasses,
jni_call, jni_call_static, jni_convert_byte_array, jni_exception_check, jni_exception_occurred,
jni_new_global_ref, jni_new_object, jni_new_string,
is_task_running, jni_bridge::JavaClasses, jni_call, jni_call_static, jni_convert_byte_array,
jni_exception_check, jni_exception_occurred, jni_new_global_ref, jni_new_object,
jni_new_string,
};
use blaze_serde::protobuf::TaskDefinition;
use datafusion::{
Expand All @@ -49,6 +47,7 @@ use datafusion_ext_plans::{
use futures::{FutureExt, StreamExt};
use jni::objects::{GlobalRef, JObject};
use prost::Message;
use raw_cpuid::CpuId;
use tokio::{runtime::Runtime, task::JoinHandle};

use crate::{
Expand Down Expand Up @@ -95,13 +94,29 @@ impl NativeExecutionRuntime {
&ExecutionPlanMetricsSet::new(),
);

        // determine number of tokio worker threads
        // use the real number of available physical cores
        //
        // `available_parallelism` reports the logical CPUs visible to this
        // process (per std docs it also honors cgroup quotas / CPU affinity);
        // fall back to 1 when the query fails.
        let default_parallelism = std::thread::available_parallelism()
            .map(|v| v.get())
            .unwrap_or(1);
        // NOTE(review): the CPUID HTT flag only indicates the CPU *supports*
        // hyper-threading, not that SMT is currently enabled — on an
        // HTT-capable host with SMT switched off this halves the real core
        // count. Also raw-cpuid is x86-specific; confirm non-x86 targets.
        let has_htt = CpuId::new()
            .get_feature_info()
            .map(|info| info.has_htt())
            .unwrap_or(false);
        // assume 2 logical CPUs per physical core when hyper-threading is
        // reported, so worker threads ≈ physical cores
        let mut num_worker_threads = if has_htt {
            default_parallelism / 2
        } else {
            default_parallelism
        };
        // never drop below one worker (e.g. single-CPU container quota)
        num_worker_threads = num_worker_threads.max(1);

// create tokio runtime
// propagate classloader and task context to spawned children threads
let spark_task_context = jni_call_static!(JniBridge.getTaskContext() -> JObject)?;
let spark_task_context_global = jni_new_global_ref!(spark_task_context.as_obj())?;
let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
.thread_name(format!("blaze-native-stage-{stage_id}-part-{partition_id}"))
.worker_threads(TOKIO_NUM_WORKER_THREADS.value()? as usize)
.worker_threads(num_worker_threads)
.on_thread_start(move || {
let classloader = JavaClasses::get().classloader;
let _ = jni_call_static!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct RadixTournamentTree<T> {
#[allow(clippy::len_without_is_empty)]
impl<T: KeyForRadixTournamentTree> RadixTournamentTree<T> {
pub fn new(values: Vec<T>, num_keys: usize) -> Self {
let num_keys = num_keys + 1; // avoid overflow
let num_values = values.len();
let mut tree = unsafe {
// safety:
Expand Down
121 changes: 89 additions & 32 deletions native-engine/datafusion-ext-commons/src/algorithm/rdxsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,103 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::vec::IntoIter;
use crate::unchecked;

use radsort::Key;
/// Performs an unstable, in-place radix sort of `array` into
/// `counts.len()` buckets (American-flag-sort style).
///
/// - `array`: the array to be sorted
/// - `counts`: per-bucket counters; MUST be zero-initialized and contain one
///   slot per possible key. After the call it holds the number of elements
///   in each bucket.
/// - `key`: extracts the bucket index from an element; every returned value
///   must be `< counts.len()`, otherwise this function panics (the previous
///   unchecked-index version had undefined behavior here).
pub fn radix_sort_by_key<T>(array: &mut [T], counts: &mut [usize], key: impl Fn(&T) -> usize) {
    /// `[cur, end)` is the not-yet-filled region of one bucket.
    #[derive(Default, Clone, Copy)]
    struct Part {
        cur: usize,
        end: usize,
    }

    let num_keys = counts.len();

    // counting pass: histogram of bucket sizes
    for item in array.iter() {
        counts[key(item)] += 1;
    }

    // prefix-sum the histogram into [cur, end) ranges, one per bucket;
    // empty buckets keep the default (0, 0) range and are never visited
    let mut parts = vec![Part::default(); num_keys];
    let mut beg = 0;
    for (part, &count) in parts.iter_mut().zip(counts.iter()) {
        if count > 0 {
            *part = Part {
                cur: beg,
                end: beg + count,
            };
            beg += count;
        }
    }

    // permutation passes: sweep the unfinished region of every bucket,
    // swapping each element toward its destination bucket. A swapped-in
    // element is NOT re-examined within the same sweep, so sweeps repeat
    // until at most one bucket is unfinished — the remaining elements of
    // that last bucket are then already in place.
    let mut inexhausted: Vec<usize> = (0..num_keys).collect();
    while {
        inexhausted.retain(|&i| parts[i].cur < parts[i].end);
        inexhausted.len() > 1
    } {
        for &part_idx in &inexhausted {
            // snapshot: parts[part_idx].cur may advance while we sweep
            let Part { cur, end } = parts[part_idx];
            for item_idx in cur..end {
                let target = &mut parts[key(&array[item_idx])];
                // safe `swap` instead of the unstable `slice::swap_unchecked`:
                // both indices are bounded by the prefix sums computed above
                array.swap(item_idx, target.cur);
                target.cur += 1;
            }
        }
    }
}

pub trait RadixSortIterExt: Iterator {
fn radix_sorted_unstable(self) -> IntoIter<Self::Item>
where
Self: Sized,
Self::Item: Key + Ord,
{
let mut vec: Vec<Self::Item> = self.collect();
radix_sort_unstable(&mut vec);
vec.into_iter()
#[cfg(test)]
mod test {
    use rand::Rng;

    use super::*;

    /// Sorts `n` random u16 values with `radix_sort_by_key` and checks the
    /// result against the standard library's unstable sort.
    fn check_against_std_sort(n: usize) {
        // hoisted out of the fill loop — one rng handle per check, not per element
        let mut rng = rand::thread_rng();
        let array: Vec<u16> = (0..n).map(|_| rng.gen::<u16>()).collect();

        let mut actual = array.clone();
        radix_sort_by_key(&mut actual, &mut [0; 65536], |key| *key as usize);

        let mut expected = array;
        expected.sort_unstable();

        assert_eq!(actual, expected);
    }

    #[test]
    fn fuzzytest_u16_small() {
        // exercise many small sizes, including the empty array
        for n in 0..1000 {
            check_against_std_sort(n);
        }
    }

    #[test]
    fn fuzzytest_u16_1m() {
        check_against_std_sort(1_000_000);
    }
}
1 change: 1 addition & 0 deletions native-engine/datafusion-ext-commons/src/arrow/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ pub fn cast_impl(
)
}
(&DataType::List(_), DataType::List(to_field)) => {
log::info!("XXX cast list to_field={to_field:?}");
let list = as_list_array(array);
let items = cast_impl(list.values(), to_field.data_type(), match_struct_fields)?;
make_array(
Expand Down
23 changes: 17 additions & 6 deletions native-engine/datafusion-ext-commons/src/arrow/selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,13 @@ pub fn create_array_interleaver(
});
}
}
let v = interleaver.arrays[*a].value(*b);
values.push(v)
let array = &interleaver.arrays[*a];
if array.is_valid(*b) {
let v = interleaver.arrays[*a].value(*b);
values.push(v)
} else {
values.push(Default::default());
}
}

let array = PrimitiveArray::<T>::new(values.into(), nulls);
Expand Down Expand Up @@ -172,9 +177,12 @@ pub fn create_array_interleaver(
});
}
}
let o = interleaver.arrays[*a].value_offsets();
let element_len = o[*b + 1].as_usize() - o[*b].as_usize();
capacity += element_len;
let array = &interleaver.arrays[*a];
if array.is_valid(*b) {
let o = array.value_offsets();
let element_len = o[*b + 1].as_usize() - o[*b].as_usize();
capacity += element_len;
}
offsets.append(T::Offset::from_usize(capacity).expect("overflow"));
}

Expand All @@ -192,7 +200,10 @@ pub fn create_array_interleaver(
});
}
}
values.extend_from_slice(interleaver.arrays[*a].value(*b).as_ref());
let array = &interleaver.arrays[*a];
if array.is_valid(*b) {
values.extend_from_slice(interleaver.arrays[*a].value(*b).as_ref());
}
}

// Safety: safe by construction
Expand Down
6 changes: 4 additions & 2 deletions native-engine/datafusion-ext-plans/src/agg/agg_hash_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,10 @@ impl AggHashMap {
}

    /// Upserts `keys` into the hash map via `upsert_many`, returning one
    /// `u32` per input key (presumably the record id assigned to that key —
    /// verify against `upsert_many`'s contract).
    ///
    /// Reserving and rehashing can be CPU-heavy for large key batches, so the
    /// work runs inside `tokio::task::block_in_place`, letting the runtime
    /// move other pending tasks off this worker thread meanwhile.
    /// NOTE(review): `block_in_place` panics on a current-thread tokio
    /// runtime — assumes the multi-thread runtime built in blaze/src/rt.rs.
    pub fn upsert_records(&mut self, keys: Vec<impl AggHashMapKey>) -> Vec<u32> {
        tokio::task::block_in_place(|| {
            self.map.reserve(keys.len());
            self.map.upsert_many(keys)
        })
    }

pub fn take_keys(&mut self) -> Vec<OwnedKey> {
Expand Down
Loading

0 comments on commit 9061f4b

Please sign in to comment.