update range partition
guoying06 committed Dec 31, 2024
commit 7e59a18 (1 parent: 5a942c1)
Showing 4 changed files with 32 additions and 28 deletions.
22 changes: 19 additions & 3 deletions native-engine/blaze-serde/src/from_proto.rs
@@ -1190,10 +1190,26 @@ pub fn parse_protobuf_partitioning(
         let bound_cols: Vec<ArrayRef> = value_list
             .iter()
             .map(|x| {
-                let key_values = convert_required!(x.value);
-                ScalarValue::iter_to_array(key_values)
+                let xx = x.clone().value.unwrap();
+                let values_ref = match xx {
+                    protobuf::scalar_value::Value::ListValue(scalar_list) => {
+                        let protobuf::ScalarListValue {
+                            values,
+                            datatype: _opt_scalar_type,
+                        } = scalar_list;
+                        let value_vec: Vec<ScalarValue> = values
+                            .iter()
+                            .map(|val| val.try_into())
+                            .collect::<Result<Vec<_>, _>>()
+                            .map_err(|_| proto_error("partition::from_proto() error"))?;
+                        ScalarValue::iter_to_array(value_vec)
+                            .map_err(|_| proto_error("partition::from_proto() error"))
+                    }
+                    _ => Err(proto_error("partition::from_proto() bound_list type error")),
+                };
+                values_ref
             })
-            .collect::<Result<_, DataFusionError>>()?;
+            .collect::<Result<Vec<ArrayRef>, _>>()?;

         let bound_rows = sort_row_converter.lock().convert_columns(&bound_cols)?;
         Ok(Some(RePartitioning::RangePartitioning(
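For context, a minimal sketch of the conversion this hunk performs for each sort key: collapsing a list of decoded ScalarValue bounds into one Arrow column via ScalarValue::iter_to_array. The crate path and the Int32 bound values below are illustrative assumptions, not taken from this commit.

// Illustrative only: the datafusion_common crate path and Int32 values are assumed.
use datafusion_common::{Result, ScalarValue};

fn bounds_to_column() -> Result<()> {
    // One scalar per partition boundary of a single sort column.
    let bounds = vec![
        ScalarValue::Int32(Some(10)),
        ScalarValue::Int32(Some(20)),
        ScalarValue::Int32(Some(30)),
    ];
    // iter_to_array collapses the scalars into a single ArrayRef;
    // every scalar must share one data type, otherwise it errors.
    let column = ScalarValue::iter_to_array(bounds)?;
    assert_eq!(column.len(), 3);
    Ok(())
}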
@@ -31,6 +31,8 @@ use datafusion_ext_commons::{
     compute_suggested_batch_size_for_output, df_execution_err,
 };
 use jni::objects::GlobalRef;
+#[cfg(test)]
+use parking_lot::Mutex;

 use crate::{
     common::{ipc_compression::IpcCompressionWriter, timer_helper::TimerHelper},
@@ -39,7 +41,6 @@ use crate::{
         evaluate_robin_partition_ids, rss::RssWriter, RePartitioning,
     },
 };
-use parking_lot::Mutex;

 pub struct BufferedData {
     partition_id: usize,
@@ -336,7 +337,7 @@ fn sort_batches_by_partition_id(
             round_robin_start_rows += batch.num_rows();
             round_robin_start_rows %= partitioning.partition_count();
         }
-        RePartitioning::RangePartitioning(sort_expr, partition_num, bounds) => {
+        RePartitioning::RangePartitioning(sort_expr, _, bounds) => {
             part_ids = evaluate_range_partition_ids(&batch, sort_expr, bounds).unwrap();
         }
         _ => unreachable!("unsupported partitioning: {:?}", partitioning),
17 changes: 7 additions & 10 deletions native-engine/datafusion-ext-plans/src/shuffle/mod.rs
@@ -250,11 +250,8 @@ fn get_partition(key_row: Row, bound_rows: &Arc<Rows>, ascending: bool) -> u32 {
         }
     } else {
         // Determine which binary search method to use only once.
-        partition = binary_search(bound_rows, key_row, 0, num_rows);
+        partition = binary_search(bound_rows, key_row, 0, num_rows as isize);
         // binarySearch either returns the match location or -[insertion point]-1
-        if partition < 0 {
-            partition = 0 - partition - 1
-        }
         if partition > num_rows {
             partition = num_rows
         }
@@ -265,21 +262,21 @@ fn get_partition(key_row: Row, bound_rows: &Arc<Rows>, ascending: bool) -> u32 {
     partition as u32
 }

-fn binary_search(rows: &Arc<Rows>, target: Row, from_index: usize, to_index: usize) -> usize {
-    let mut low = from_index;
-    let mut high = to_index - 1;
+fn binary_search(rows: &Arc<Rows>, target: Row, from_index: isize, to_index: isize) -> usize {
+    let mut low: isize = from_index;
+    let mut high: isize = to_index - 1;

     while low <= high {
         let mid = (low + high) >> 1;
-        let mid_val = rows.row(mid);
+        let mid_val = rows.row(mid as usize);

         if mid_val < target {
             low = mid + 1;
         } else if mid_val > target {
             high = mid - 1;
         } else {
-            return mid; // key found
+            return mid as usize; // key found
         }
     }
-    return low; // key not found.
+    return low as usize; // key not found.
 }
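Taken together, the two hunks above switch get_partition to a lower-bound style search that returns the insertion point directly instead of the Java-style -(insertion point) - 1 encoding. A self-contained sketch of that lookup, with plain i32 bounds standing in for the Row-based bounds (values are illustrative):

// Self-contained sketch of the partition lookup after this change: a binary
// search over sorted upper bounds that returns the insertion point directly,
// clamped to the number of bounds.
fn partition_for(key: i32, bounds: &[i32]) -> usize {
    let mut low: isize = 0;
    let mut high: isize = bounds.len() as isize - 1;
    while low <= high {
        let mid = (low + high) >> 1;
        if bounds[mid as usize] < key {
            low = mid + 1;
        } else if bounds[mid as usize] > key {
            high = mid - 1;
        } else {
            return mid as usize; // exact match on a bound
        }
    }
    // Not found: `low` is the insertion point, i.e. the partition index.
    (low as usize).min(bounds.len())
}

fn main() {
    let bounds = [10, 20, 30]; // three bounds => four partitions
    assert_eq!(partition_for(5, &bounds), 0);
    assert_eq!(partition_for(10, &bounds), 0);
    assert_eq!(partition_for(15, &bounds), 1);
    assert_eq!(partition_for(35, &bounds), 3);
}

With three bounds, a key can land in any of four partitions (indices 0 through 3), which matches the clamp to num_rows in get_partition above.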
@@ -200,6 +200,7 @@ abstract class NativeShuffleExchangeBase(
           new SQLShuffleWriteMetricsReporter(shuffleWriteMetrics, metrics).incWriteTime(v)
         case _ =>
       }))
+    var numPartitionsRest = numPartitions;
     // if RangePartitioning => sample and find bounds
     val nativeBounds = outputPartitioning match {
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -223,18 +224,7 @@ abstract class NativeShuffleExchangeBase(
           numPartitions,
           rddForSampling,
           samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
-
-        // bounds.map { internal_row =>
-        //   {
-        //     val valueSeq = sortingExpressions.zipWithIndex.map { case (field, index) =>
-        //       NativeConverters.convertValue(
-        //         internal_row.get(index, field.dataType),
-        //         field.dataType)
-        //     }
-        //     NativeConverters.convertValueSeq(valueSeq)
-        //   }
-        // }.toList
-
+        numPartitionsRest = bounds.length + 1
         sortingExpressions.zipWithIndex.map { case (field, index) =>
           val valueList = bounds.map { internal_row =>
             internal_row.get(index, field.dataType)
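The bookkeeping change here is that the sampled bounds, not the requested partition count, now determine the number of output partitions, since sampling can return fewer split points than numPartitions - 1. N split points always describe N + 1 ranges; a trivial sketch (in Rust for consistency with the examples above, values illustrative):

// Illustrative values only: N split points found by sampling => N + 1 partitions,
// which is what numPartitionsRest records and PhysicalRangeRepartition receives below.
fn main() {
    let bounds = vec![10, 20, 30];              // sampled split points
    let num_partitions_rest = bounds.len() + 1; // one extra range past the last bound
    assert_eq!(num_partitions_rest, 4);
}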
@@ -280,7 +270,7 @@ abstract class NativeShuffleExchangeBase(
           .setRangeRepartition(
             PhysicalRangeRepartition
               .newBuilder()
-              .setPartitionCount(nativeBounds.length + 1) // reset partition num
+              .setPartitionCount(numPartitionsRest) // reset partition num
               .addAllListValue(nativeBounds.asJava)
               .setSortExpr(nativeSortExecNode))
         case p =>