Implement equivalent transform func for arrow1 (with tests)
H-Plus-Time committed Sep 5, 2023
1 parent 7a0f1df commit 9c43488
Showing 7 changed files with 118 additions and 32 deletions.
3 changes: 3 additions & 0 deletions src/arrow1/mod.rs
@@ -11,6 +11,9 @@ pub mod writer;
#[cfg(feature = "writer")]
pub mod writer_properties;

#[cfg(all(feature = "writer", feature = "async"))]
pub mod writer_async;

pub mod error;

#[cfg(all(feature = "reader", feature = "async"))]
21 changes: 21 additions & 0 deletions src/arrow1/wasm.rs
@@ -82,3 +82,24 @@ pub async fn read_parquet_stream(
    });
    Ok(wasm_streams::ReadableStream::from_stream(stream).into_raw())
}

#[wasm_bindgen(js_name = "transformParquetStream")]
#[cfg(all(feature = "writer", feature = "async"))]
pub fn transform_parquet_stream(
    stream: wasm_streams::readable::sys::ReadableStream,
    writer_properties: Option<crate::arrow1::writer_properties::WriterProperties>,
) -> WasmResult<wasm_streams::readable::sys::ReadableStream> {
    use futures::StreamExt;
    let batches = wasm_streams::ReadableStream::from_raw(stream)
        .into_stream()
        .map(|maybe_chunk| {
            let chunk = maybe_chunk.unwrap();
            let transformed: arrow_wasm::arrow1::RecordBatch = chunk.try_into().unwrap();
            transformed
        });
    let output_stream = super::writer_async::transform_parquet_stream(
        batches,
        writer_properties.unwrap_or_default(),
    );
    Ok(output_stream.unwrap())
}
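
For context, a minimal sketch of driving this new binding from JavaScript (not part of this commit; the function name and URL are illustrative):

// Re-encode a remote parquet file by streaming batches through the transform.
import * as wasm from "../../pkg/node/arrow1";

async function reencode(url: string): Promise<Uint8Array> {
  // ReadableStream of Arrow record batches
  const batches = await wasm.readParquetStream(url);
  // ReadableStream of parquet bytes; the second (writer properties)
  // argument is optional and falls back to defaults on the Rust side
  const parquetStream = await wasm.transformParquetStream(batches);
  return new Uint8Array(await new Response(parquetStream).arrayBuffer());
}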
38 changes: 38 additions & 0 deletions src/arrow1/writer_async.rs
@@ -0,0 +1,38 @@
use crate::arrow1::error::Result;
use crate::common::stream::WrappedWritableStream;
use async_compat::CompatExt;
use futures::StreamExt;
use parquet::arrow::async_writer::AsyncArrowWriter;
use wasm_bindgen_futures::spawn_local;

pub fn transform_parquet_stream(
    batches: impl futures::Stream<Item = arrow_wasm::arrow1::RecordBatch> + 'static,
    writer_properties: crate::arrow1::writer_properties::WriterProperties,
) -> Result<wasm_streams::readable::sys::ReadableStream> {
    let options = Some(writer_properties.into());
    // let encoding = writer_properties.get_encoding();

    let (writable_stream, output_stream) = {
        let raw_stream = wasm_streams::transform::sys::TransformStream::new();
        let raw_writable = raw_stream.writable();
        let inner_writer = wasm_streams::WritableStream::from_raw(raw_writable).into_async_write();
        let writable_stream = WrappedWritableStream {
            stream: inner_writer,
        };
        (writable_stream, raw_stream.readable())
    };
    spawn_local::<_>(async move {
        let mut adapted_stream = batches.peekable();
        let mut pinned_stream = std::pin::pin!(adapted_stream);
        let first_batch = pinned_stream.as_mut().peek().await.unwrap();
        let schema = first_batch.schema().into_inner();
        // Need to create an encoding for each column
        let mut writer =
            AsyncArrowWriter::try_new(writable_stream.compat(), schema, 1024, options).unwrap();
        while let Some(batch) = pinned_stream.next().await {
            let _ = writer.write(&batch.into()).await;
        }
        let _ = writer.close().await;
    });
    Ok(output_stream)
}
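
The spawned task above writes encoded bytes into the writable half of a TransformStream as batches arrive, so the readable half can yield parquet chunks before the file is finalized. A hypothetical JS-side consumer (illustrative only, not part of this commit):

// Observe parquet bytes incrementally instead of buffering the whole file.
async function logProgress(parquetStream: ReadableStream<Uint8Array>) {
  const reader = parquetStream.getReader();
  let written = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    written += value.byteLength; // e.g. drive a progress indicator
  }
  console.log(`${written} bytes of parquet emitted`);
}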
33 changes: 2 additions & 31 deletions src/arrow2/writer_async.rs
@@ -1,38 +1,9 @@
 use crate::arrow2::error::Result;
+use crate::common::stream::WrappedWritableStream;
 use arrow2::io::parquet::write::FileSink;
-use futures::{AsyncWrite, SinkExt, StreamExt};
+use futures::{SinkExt, StreamExt};
 use wasm_bindgen_futures::spawn_local;
 
-struct WrappedWritableStream<'writer> {
-    stream: wasm_streams::writable::IntoAsyncWrite<'writer>,
-}
-
-impl<'writer> AsyncWrite for WrappedWritableStream<'writer> {
-    fn poll_write(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-        buf: &[u8],
-    ) -> std::task::Poll<std::io::Result<usize>> {
-        AsyncWrite::poll_write(std::pin::Pin::new(&mut self.get_mut().stream), cx, buf)
-    }
-
-    fn poll_flush(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        AsyncWrite::poll_flush(std::pin::Pin::new(&mut self.get_mut().stream), cx)
-    }
-
-    fn poll_close(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        AsyncWrite::poll_close(std::pin::Pin::new(&mut self.get_mut().stream), cx)
-    }
-}
-
-unsafe impl<'writer> Send for WrappedWritableStream<'writer> {}
-
 pub fn transform_parquet_stream(
     batches: impl futures::Stream<Item = arrow_wasm::arrow2::RecordBatch> + 'static,
     writer_properties: crate::arrow2::writer_properties::WriterProperties,
3 changes: 3 additions & 0 deletions src/common/mod.rs
@@ -3,3 +3,6 @@ pub mod writer_properties;

#[cfg(feature = "async")]
pub mod fetch;

#[cfg(feature = "async")]
pub mod stream;
31 changes: 31 additions & 0 deletions src/common/stream.rs
@@ -0,0 +1,31 @@
use futures::AsyncWrite;

pub struct WrappedWritableStream<'writer> {
    pub stream: wasm_streams::writable::IntoAsyncWrite<'writer>,
}

impl<'writer> AsyncWrite for WrappedWritableStream<'writer> {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        AsyncWrite::poll_write(std::pin::Pin::new(&mut self.get_mut().stream), cx, buf)
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        AsyncWrite::poll_flush(std::pin::Pin::new(&mut self.get_mut().stream), cx)
    }

    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        AsyncWrite::poll_close(std::pin::Pin::new(&mut self.get_mut().stream), cx)
    }
}

unsafe impl<'writer> Send for WrappedWritableStream<'writer> {}
21 changes: 20 additions & 1 deletion tests/js/arrow1.ts
@@ -2,7 +2,7 @@ import * as test from "tape";
 import * as wasm from "../../pkg/node/arrow1";
 import { readFileSync } from "fs";
 import { tableFromIPC, tableToIPC } from "apache-arrow";
-import { testArrowTablesEqual, readExpectedArrowData } from "./utils";
+import { testArrowTablesEqual, readExpectedArrowData, temporaryServer } from "./utils";
 
 // Path from repo root
 const dataDir = "tests/data";
@@ -83,3 +83,22 @@ test("error produced trying to read file with arrayBuffer", (t) => {

  t.end();
});

test("read stream-write stream-read stream round trip (no writer properties provided)", async (t) => {
  const server = await temporaryServer();
  const listeningPort = server.addresses()[0].port;
  const rootUrl = `http://localhost:${listeningPort}`;

  const expectedTable = readExpectedArrowData();

  const url = `${rootUrl}/1-partition-brotli.parquet`;
  const originalStream = await wasm.readParquetStream(url);

  const stream = await wasm.transformParquetStream(originalStream);
  const accumulatedBuffer = new Uint8Array(await new Response(stream).arrayBuffer());
  const roundtripTable = tableFromIPC(wasm.readParquet(accumulatedBuffer).intoIPC());

  testArrowTablesEqual(t, expectedTable, roundtripTable);
  await server.close();
  t.end();
});
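
A hypothetical companion test exercising the optional writer-properties argument (not part of this commit; WriterPropertiesBuilder and Compression are assumed here to match the surface the arrow2 bindings expose):

test("round trip with explicit writer properties", async (t) => {
  const server = await temporaryServer();
  const { port } = server.addresses()[0];
  // Assumption: builder chain mirrors the arrow2 module's API.
  const props = new wasm.WriterPropertiesBuilder()
    .setCompression(wasm.Compression.SNAPPY)
    .build();
  const original = await wasm.readParquetStream(`http://localhost:${port}/1-partition-brotli.parquet`);
  const stream = await wasm.transformParquetStream(original, props);
  const buffer = new Uint8Array(await new Response(stream).arrayBuffer());
  t.ok(buffer.byteLength > 0, "re-encoded parquet bytes were produced");
  await server.close();
  t.end();
});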
