Skip to content

Commit

Permalink
Export ArrowArrayStream from polars data frame (#5)
Browse files Browse the repository at this point in the history
* update polars version

* getting there!

* best attempt

* try will's solution

* apply Will's suggestion

* update wrapper

* closer

* fix cargo lock, document

* use s3_register() + Suggests for nanoarrow and arrow instead of data frame method

* update renv lock

* update nanoarrow in renv.lock

* try refresh

* minor

Co-authored-by: sorhawell <[email protected]>
  • Loading branch information
paleolimbot and sorhawell authored Jan 12, 2023
1 parent ffe7772 commit 1daf373
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 7 deletions.
10 changes: 9 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
URL: https://github.com/pola-rs/r-polars
Suggests:
devtools, testthat (>= 3.0.0), rextendr, data.table, pkgdown
devtools,
testthat (>= 3.0.0),
rextendr,
data.table,
pkgdown,
nanoarrow,
arrow
Remotes: apache/arrow-nanoarrow/r
Config/testthat/edition: 3
Collate:
'utils.R'
Expand All @@ -30,6 +37,7 @@ Collate:
'lazyframe__lazy.R'
'options.R'
'parquet.R'
'pkg-nanoarrow.R'
'rlang.R'
'series__series.R'
'translation.R'
Expand Down
7 changes: 2 additions & 5 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
# Generated by extendr: Do not edit by hand

# nolint start

#
# This file was created with the following call:
# .Call("wrap__make_rpolars_wrappers", use_symbols = TRUE, package_name = "rpolars")
Expand Down Expand Up @@ -77,6 +74,8 @@ DataFrame$to_struct <- function(name) .Call(wrap__DataFrame__to_struct, self, na

DataFrame$unnest <- function(names) .Call(wrap__DataFrame__unnest, self, names)

DataFrame$export_stream <- function(stream_ptr) invisible(.Call(wrap__DataFrame__export_stream, self, stream_ptr))

#' @export
`$.DataFrame` <- function (self, name) { func <- DataFrame[[name]]; environment(func) <- environment(); func }

Expand Down Expand Up @@ -669,5 +668,3 @@ LazyGroupBy$tail <- function(n) .Call(wrap__LazyGroupBy__tail, self, n)
#' @export
`[[.LazyGroupBy` <- `$.LazyGroupBy`


# nolint end
23 changes: 23 additions & 0 deletions R/pkg-nanoarrow.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

as_nanoarrow_array_stream.DataFrame <- function(x, ..., schema = NULL) {
# Don't support the schema argument yet
stopifnot(is.null(schema))

stream = nanoarrow::nanoarrow_allocate_array_stream()
.pr$DataFrame$export_stream(x, nanoarrow::nanoarrow_pointer_addr_chr(stream))

stream
}

infer_nanoarrow_schema.DataFrame <- function(x, ...) {
as_nanoarrow_array_stream.DataFrame(x)$get_schema()
}

as_record_batch_reader.DataFrame <- function(x, ..., schema = NULL) {
arrow::as_record_batch_reader(as_nanoarrow_array_stream.DataFrame(x, schema = schema))
}

as_arrow_table.DataFrame <- function(x, ...) {
reader = as_record_batch_reader.DataFrame(x)
reader$read_table()
}
70 changes: 69 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ pl$mem_address = rpolars:::mem_address
#export dtypes directly into pl, because py-polars does that
move_env_elements(pl$dtypes,pl,names(pl$dtypes),remove = FALSE)


# register S3 methods for packages in Suggests
s3_register("nanoarrow::as_nanoarrow_array_stream", "DataFrame")
s3_register("nanoarrow::infer_nanoarrow_schema", "DataFrame")
s3_register("arrow::as_record_batch_reader", "DataFrame")
s3_register("arrow::as_arrow_table", "DataFrame")

pl$numeric_dtypes = pl$dtypes[substr(names(pl$dtypes),1,3) %in% c("Int","Flo")]

Expand All @@ -155,3 +159,67 @@ pl$mem_address = rpolars:::mem_address
print("")
print("done source")


# From the `vctrs` package (this function is intended to be copied
# without attribution or license requirements to avoid a hard dependency on
# vctrs:
# https://github.com/r-lib/vctrs/blob/c2a7710fe55e3a2249c4fdfe75bbccbafcf38804/R/register-s3.R#L25-L31
s3_register <- function(generic, class, method = NULL) {
stopifnot(is.character(generic), length(generic) == 1)
stopifnot(is.character(class), length(class) == 1)

pieces <- strsplit(generic, "::")[[1]]
stopifnot(length(pieces) == 2)
package <- pieces[[1]]
generic <- pieces[[2]]

caller <- parent.frame()

get_method_env <- function() {
top <- topenv(caller)
if (isNamespace(top)) {
asNamespace(environmentName(top))
} else {
caller
}
}
get_method <- function(method, env) {
if (is.null(method)) {
get(paste0(generic, ".", class), envir = get_method_env())
} else {
method
}
}

register <- function(...) {
envir <- asNamespace(package)

# Refresh the method each time, it might have been updated by
# `devtools::load_all()`
method_fn <- get_method(method)
stopifnot(is.function(method_fn))


# Only register if generic can be accessed
if (exists(generic, envir)) {
registerS3method(generic, class, method_fn, envir = envir)
} else if (identical(Sys.getenv("NOT_CRAN"), "true")) {
warning(sprintf(
"Can't find generic `%s` in package %s to register S3 method.",
generic,
package
))
}
}

# Always register hook in case package is later unloaded & reloaded
setHook(packageEvent(package, "onLoad"), register)

# Avoid registration failures during loading (pkgload or regular)
if (isNamespaceLoaded(package)) {
register()
}

invisible()
}
# nocov end
58 changes: 58 additions & 0 deletions renv.lock
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@
"Hash": "e9c08b94391e9f3f97355841229124f2",
"Requirements": []
},
"arrow": {
"Package": "arrow",
"Version": "10.0.1",
"Source": "Repository",
"Repository": "CRAN",
"Hash": "74b551e105beb7ee0fae1f7861668bac",
"Requirements": [
"R6",
"assertthat",
"bit64",
"cpp11",
"glue",
"purrr",
"rlang",
"tidyselect",
"vctrs"
]
},
"askpass": {
"Package": "askpass",
"Version": "1.1",
Expand All @@ -35,6 +53,14 @@
"sys"
]
},
"assertthat": {
"Package": "assertthat",
"Version": "0.2.1",
"Source": "Repository",
"Repository": "CRAN",
"Hash": "50c838a310445e954bc13f26f26a6ecf",
"Requirements": []
},
"base64enc": {
"Package": "base64enc",
"Version": "0.1-3",
Expand All @@ -43,6 +69,24 @@
"Hash": "543776ae6848fde2f48ff3816d0628bc",
"Requirements": []
},
"bit": {
"Package": "bit",
"Version": "4.0.5",
"Source": "Repository",
"Repository": "CRAN",
"Hash": "d242abec29412ce988848d0294b208fd",
"Requirements": []
},
"bit64": {
"Package": "bit64",
"Version": "4.0.5",
"Source": "Repository",
"Repository": "CRAN",
"Hash": "9fe98599ca456d6552421db0d6772d8f",
"Requirements": [
"bit"
]
},
"brew": {
"Package": "brew",
"Version": "1.0-8",
Expand Down Expand Up @@ -547,6 +591,20 @@
"shiny"
]
},
"nanoarrow": {
"Package": "nanoarrow",
"Version": "0.0.0.9000",
"Source": "GitHub",
"RemoteType": "github",
"RemoteHost": "api.github.com",
"RemoteUsername": "apache",
"RemoteRepo": "arrow-nanoarrow",
"RemoteSubdir": "r",
"RemoteRef": "main",
"RemoteSha": "da7b5ecc108848e538a2b2592315d0fcf0b12786",
"Hash": "be8adf716a3aaa52c31605f0e753f7d8",
"Requirements": []
},
"openssl": {
"Package": "openssl",
"Version": "2.0.5",
Expand Down
58 changes: 58 additions & 0 deletions src/rust/src/rdataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,52 @@ use rexpr::*;
pub use rseries::*;
use series_to_r::pl_series_to_list;

use polars_core::utils::arrow;
use polars::prelude::ArrowField;
use arrow::datatypes::DataType;

use crate::utils::r_result_list;

pub struct OwnedDataFrameIterator {
columns: Vec<polars::series::Series>,
data_type: arrow::datatypes::DataType,
idx: usize,
n_chunks: usize,
}

impl OwnedDataFrameIterator {
fn new(df: polars::frame::DataFrame ) -> Self {
let schema = df.schema().to_arrow();
let data_type = DataType::Struct(schema.fields);

Self {
columns: df.get_columns().clone(),
data_type,
idx: 0,
n_chunks: df.n_chunks()
}
}
}

impl Iterator for OwnedDataFrameIterator {
type Item = Result<Box<dyn arrow::array::Array>, arrow::error::Error>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx >= self.n_chunks {
None
} else {
// create a batch of the columns with the same chunk no.
let batch_cols = self.columns.iter().map(|s| s.to_arrow(self.idx)).collect();
self.idx += 1;

let chunk = polars::frame::ArrowChunk::new(batch_cols);
let array = arrow::array::StructArray::new(self.data_type.clone(), chunk.into_arrays(), std::option::Option::None);
Some(std::result::Result::Ok(Box::new(array)))
}
}
}


#[extendr]
#[derive(Debug, Clone)]
pub struct DataFrame(pub pl::DataFrame);
Expand Down Expand Up @@ -204,6 +248,20 @@ impl DataFrame {

r_result_list(self.0.unnest(names).map(|s| DataFrame(s)))
}

pub fn export_stream(&self, stream_ptr: &str) {
let schema = self.0.schema().to_arrow();
let data_type = DataType::Struct(schema.fields);
let field = ArrowField::new("", data_type.clone(), false);

let iter_boxed = Box::new(OwnedDataFrameIterator::new(self.0.clone()));
let mut stream = arrow::ffi::export_iterator(iter_boxed, field);
let stream_out_ptr_addr: usize = stream_ptr.parse().unwrap();
let stream_out_ptr = stream_out_ptr_addr as *mut arrow::ffi::ArrowArrayStream;
unsafe {
std::ptr::swap_nonoverlapping(stream_out_ptr, &mut stream as *mut arrow::ffi::ArrowArrayStream, 1);
}
}
}
use crate::utils::wrappers::null_to_opt;
impl DataFrame {
Expand Down
49 changes: 49 additions & 0 deletions tests/testthat/test-pkg-nanoarrow.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

test_that("as_nanoarrow_array_stream() works for DataFrame", {
skip_if_not_installed("nanoarrow")

df = pl$DataFrame(a = 1L, b = "two")
stream <- nanoarrow::as_nanoarrow_array_stream(df)
expect_s3_class(stream, "nanoarrow_array_stream")
expect_identical(
as.data.frame(stream),
data.frame(a = 1L, b = "two")
)
})

test_that("infer_nanoarrow_schema() works for DataFrame", {
skip_if_not_installed("nanoarrow")

df = pl$DataFrame(a = 1L, b = "two")
stream_schema <- nanoarrow::as_nanoarrow_array_stream(df)$get_schema()
inferred_schema <- nanoarrow::infer_nanoarrow_schema(df)
expect_identical(format(stream_schema), format(inferred_schema))
})

test_that("as_record_batch_reader() works for DataFrame", {
skip_if_not_installed("arrow")

df = pl$DataFrame(a = 1L, b = "two")
reader <- arrow::as_record_batch_reader(df)
expect_s3_class(reader, "RecordBatchReader")

expect_identical(
# two as.data.frame()s because arrow sometimes returns a tibble here
as.data.frame(as.data.frame(reader)),
data.frame(a = 1L, b = "two")
)
})

test_that("as_arrow_table() works for DataFrame", {
skip_if_not_installed("arrow")

df = pl$DataFrame(a = 1L, b = "two")
table <- arrow::as_arrow_table(df)
expect_s3_class(table, "Table")

expect_identical(
# two as.data.frame()s because arrow sometimes returns a tibble here
as.data.frame(as.data.frame(table)),
data.frame(a = 1L, b = "two")
)
})

0 comments on commit 1daf373

Please sign in to comment.