Commit

fix test failures
zhangli20 committed Dec 25, 2024
1 parent a0cf56c commit cb087c7
Showing 9 changed files with 59 additions and 44 deletions.
4 changes: 2 additions & 2 deletions native-engine/datafusion-ext-plans/src/agg/agg_table.rs
@@ -509,7 +509,7 @@ impl HashingData {
         // EOF
         write_len(NUM_SPILL_BUCKETS, &mut writer)?;
         write_len(0, &mut writer)?;
-        writer.flush()?;
+        writer.finish()?;
         Ok(())
     }
 }
@@ -617,7 +617,7 @@ impl MergingData {
         // EOF
         write_len(NUM_SPILL_BUCKETS, &mut writer)?;
         write_len(0, &mut writer)?;
-        writer.flush()?;
+        writer.finish()?;
         Ok(())
     }
 }
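
Both hunks replace flush() with finish() on the spill writer. A plausible reading, not stated in the commit message, is that flushing a frame-based compression writer only drains its buffered bytes, while finishing also writes the end-of-stream marker that the unspill path expects, so a merely flushed spill reads back as truncated. A minimal sketch of that difference, using the zstd crate purely as a stand-in codec; none of the names below come from this repository:

// Illustration only: a frame-based encoder that is flushed but never
// finished produces a stream that readers typically reject as truncated.
use std::io::Write;

fn main() -> std::io::Result<()> {
    let mut finished = Vec::new();
    let mut enc = zstd::stream::write::Encoder::new(&mut finished, 0)?;
    enc.write_all(b"spill payload")?;
    enc.finish()?; // writes the end-of-frame marker

    let mut flushed_only = Vec::new();
    let mut enc = zstd::stream::write::Encoder::new(&mut flushed_only, 0)?;
    enc.write_all(b"spill payload")?;
    enc.flush()?; // drains buffered bytes, but the frame is never terminated
    drop(enc);

    println!("finished decodes:     {}", zstd::stream::decode_all(&finished[..]).is_ok());
    println!("flushed-only decodes: {}", zstd::stream::decode_all(&flushed_only[..]).is_ok());
    Ok(())
}
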
7 changes: 3 additions & 4 deletions native-engine/datafusion-ext-plans/src/agg/collect.rs
@@ -707,12 +707,11 @@ mod tests {
         acc_col.append_item(2, &ScalarValue::Int32(Some(7)));
 
         let mut spill: Box<dyn Spill> = Box::new(vec![]);
+        let mut spill_writer = spill.get_compressed_writer();
         acc_col
-            .spill(
-                IdxSelection::Range(0, 3),
-                &mut spill.get_compressed_writer(),
-            )
+            .spill(IdxSelection::Range(0, 3), &mut spill_writer)
             .unwrap();
+        spill_writer.finish().unwrap();
 
         let mut acc_col_unspill = AccSetColumn::empty(DataType::Int32);
         acc_col_unspill
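
The old test passed a temporary returned by spill.get_compressed_writer() straight into spill(), so the writer was dropped at the end of the statement without ever being finished. Binding it to spill_writer and calling finish() explicitly terminates the compressed stream before the unspill step below reads it back, which is presumably what made this test pass again.
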
4 changes: 2 additions & 2 deletions native-engine/datafusion-ext-plans/src/agg_exec.rs
@@ -436,7 +436,7 @@ mod test {
         Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn test_agg() -> Result<()> {
         MemManager::init(10000);
 
@@ -647,7 +647,7 @@ mod fuzztest {
         memmgr::MemManager,
     };
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn fuzztest() -> Result<()> {
         MemManager::init(1000); // small memory config to trigger spill
         let session_ctx =
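
The default #[tokio::test] runs on a single-threaded current_thread runtime. A likely reason for switching to the multi_thread flavor (an assumption, since the commit message only says "fix test failures") is that code under test enters blocking sections via tokio::task::block_in_place, which panics on the current_thread runtime; pinning worker_threads to 1 keeps execution as deterministic as before. A hypothetical sketch, not taken from this repository:

// With plain #[tokio::test], block_in_place panics at runtime because
// blocking sections are only allowed on the multi-threaded runtime.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn blocking_section_is_allowed() {
    let answer = tokio::task::block_in_place(|| 41 + 1);
    assert_eq!(answer, 42);
}
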
@@ -77,7 +77,7 @@ impl<W: Write> IpcCompressionWriter<W> {
     pub fn finish_current_buf(&mut self) -> Result<()> {
         if !self.block_empty {
             // finish current buf
-            self.block_writer.finish()?;
+            self.block_writer.finish_internal()?;
 
             // write
             let block_len = self.shared_buf.inner().len() - 4;
@@ -186,7 +186,11 @@ impl<W: Write> IoCompressionWriter<W> {
         }
     }
 
-    pub fn finish(&mut self) -> Result<()> {
+    pub fn finish(mut self) -> Result<()> {
+        self.finish_internal()
+    }
+
+    fn finish_internal(&mut self) -> Result<()> {
         match self {
             IoCompressionWriter::LZ4(w) => {
                 w.try_finish()
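
Besides introducing finish_internal, this change makes the public finish consume self, so a finished writer can no longer be written to or flushed by mistake, while internal callers such as finish_current_buf still finalize a block through a mutable reference. A sketch of that pattern with hypothetical types; this is not the actual IoCompressionWriter:

use std::io::{Result, Write};

// Hypothetical writer demonstrating the consuming-finish pattern.
struct ExampleWriter<W: Write> {
    inner: W,
}

impl<W: Write> ExampleWriter<W> {
    fn new(inner: W) -> Self {
        Self { inner }
    }

    // Public API: takes self by value, so use-after-finish is a compile error.
    pub fn finish(mut self) -> Result<()> {
        self.finish_internal()
    }

    // Internal: writes the end-of-stream marker without giving up ownership.
    fn finish_internal(&mut self) -> Result<()> {
        self.inner.write_all(b"EOF")?; // placeholder terminator
        self.inner.flush()
    }
}

fn main() -> Result<()> {
    let writer = ExampleWriter::new(Vec::<u8>::new());
    writer.finish()?;
    // writer.finish_internal(); // would not compile: `writer` was moved
    Ok(())
}
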
36 changes: 18 additions & 18 deletions native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -269,7 +269,7 @@ mod tests {
         SHJRightProbed,
     ];
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_inner_one() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -304,7 +304,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_inner_two() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -344,7 +344,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_inner_two_two() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -385,7 +385,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_inner_with_nulls() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table_i32_nullable(
@@ -425,7 +425,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_left_one() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -459,7 +459,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_right_one() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -493,7 +493,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_full_one() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -530,7 +530,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_anti() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -563,7 +563,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_semi() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -597,7 +597,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_with_duplicated_column_names() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -631,7 +631,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_date32() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_date_table(
@@ -667,7 +667,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_date64() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_date64_table(
@@ -704,7 +704,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_left_sort_order() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -742,7 +742,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_right_sort_order() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left = build_table(
@@ -776,7 +776,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_left_multiple_batches() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left_batch_1 = build_table_i32(
@@ -827,7 +827,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_right_multiple_batches() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let right_batch_1 = build_table_i32(
@@ -878,7 +878,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_full_multiple_batches() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left_batch_1 = build_table_i32(
@@ -931,7 +931,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_existence_multiple_batches() -> Result<()> {
         for test_type in ALL_TEST_TYPE {
             let left_batch_1 = build_table_i32(
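
All 18 join tests receive the same runtime change as agg_exec.rs above: the default single-threaded test runtime is replaced by the multi_thread flavor pinned to one worker.
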
32 changes: 21 additions & 11 deletions native-engine/datafusion-ext-plans/src/orc_exec.rs
@@ -23,6 +23,7 @@ use bytes::Bytes;
 use datafusion::{
     datasource::{
         physical_plan::{FileMeta, FileOpenFuture, FileOpener, FileScanConfig, FileStream},
+        schema_adapter::SchemaMapper,
     },
     error::Result,
     execution::context::TaskContext,
@@ -33,14 +34,14 @@ use datafusion::{
         PlanProperties, SendableRecordBatchStream, Statistics,
     },
 };
-use datafusion::datasource::schema_adapter::SchemaMapper;
 use datafusion_ext_commons::{batch_size, df_execution_err, hadoop_fs::FsProvider};
 use futures::{future::BoxFuture, FutureExt, StreamExt};
 use futures_util::TryStreamExt;
 use once_cell::sync::OnceCell;
 use orc_rust::{
-    arrow_reader::ArrowReaderBuilder, projection::ProjectionMask, reader::AsyncChunkReader,
-    reader::metadata::FileMetadata,
+    arrow_reader::ArrowReaderBuilder,
+    projection::ProjectionMask,
+    reader::{metadata::FileMetadata, AsyncChunkReader},
 };
 
 use crate::{
@@ -221,7 +222,8 @@ impl FileOpener for OrcOpener {
             builder = builder.with_file_byte_range(range);
         }
 
-        let (schema_mapping, projection) = schema_adapter.map_schema(builder.file_metadata())?;
+        let (schema_mapping, projection) =
+            schema_adapter.map_schema(builder.file_metadata())?;
 
         let projection_mask =
             ProjectionMask::roots(builder.file_metadata().root_data_type(), projection);
@@ -268,26 +270,34 @@ struct SchemaAdapter {
 
 impl SchemaAdapter {
     pub fn new(table_schema: SchemaRef, projection: Vec<usize>) -> Self {
-        Self { table_schema, projection }
+        Self {
+            table_schema,
+            projection,
+        }
     }
 
-    fn map_schema(&self, orc_file_meta: &FileMetadata) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
+    fn map_schema(
+        &self,
+        orc_file_meta: &FileMetadata,
+    ) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
         let projected_schema = SchemaRef::from(self.table_schema.project(&self.projection)?);
 
         let mut projection = Vec::with_capacity(projected_schema.fields().len());
         let mut field_mappings = vec![None; self.table_schema.fields().len()];
 
-        for nameColumn in orc_file_meta.root_data_type().children() {
+        for named_column in orc_file_meta.root_data_type().children() {
             if let Some((table_idx, _table_field)) =
-                projected_schema.fields().find(nameColumn.name()) {
+                projected_schema.fields().find(named_column.name())
+            {
                 field_mappings[table_idx] = Some(projection.len());
-                projection.push(nameColumn.data_type().column_index());
+                projection.push(named_column.data_type().column_index());
             }
         }
 
         Ok((
-            Arc::new(BlazeSchemaMapping::new(self.table_schema.clone(),
-                field_mappings,
+            Arc::new(BlazeSchemaMapping::new(
+                self.table_schema.clone(),
+                field_mappings,
             )),
             projection,
         ))
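
The orc_exec.rs hunks are behavior-preserving cleanups: the SchemaMapper import moves into the existing use datafusion::{...} tree, the orc_rust imports are regrouped, nameColumn is renamed to the idiomatic named_column, and the remaining changes are rustfmt line wrapping.
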
5 changes: 4 additions & 1 deletion native-engine/datafusion-ext-plans/src/scan/mod.rs
@@ -87,7 +87,10 @@ pub struct BlazeSchemaMapping {
 
 impl BlazeSchemaMapping {
     pub fn new(table_schema: SchemaRef, field_mappings: Vec<Option<usize>>) -> Self {
-        Self { table_schema, field_mappings }
+        Self {
+            table_schema,
+            field_mappings,
+        }
     }
 }

@@ -431,8 +431,8 @@ mod test {
         "| 10 | 9 | 4 |",
         "| 17 | 2 | 7 |",
         "| 13 | 6 | 1 |",
-        "| 16 | 3 | 8 |",
         "| 12 | 7 | 2 |",
+        "| 16 | 3 | 8 |",
         "| 19 | 0 | 5 |",
         "| 15 | 4 | 9 |",
         "| 11 | 8 | 3 |",

5 changes: 2 additions & 3 deletions native-engine/datafusion-ext-plans/src/sort_exec.rs
@@ -277,7 +277,7 @@ impl BufferedData {
             write_one_batch(batch.num_rows(), batch.columns(), &mut writer)?;
             writer.write_all(&key_collector.store)?;
         }
-        writer.flush()?;
+        writer.finish()?;
         Ok(())
     }
 
@@ -947,8 +947,7 @@ fn merge_spills(
         )?;
         output_writer.write_all(&key_collector.store)?;
     }
-    output_writer.flush()?;
-    drop(output_writer);
+    output_writer.finish()?;
     Ok(output_spill)
 }
Expand Down
