Skip to content

Commit

Permalink
GH-43410: [Python] Support Arrow PyCapsule stream objects in write_da…
Browse files Browse the repository at this point in the history
…taset (#43771)

### Rationale for this change

This expands the set of places within pyarrow that accept objects implementing the Arrow PyCapsule interface. This PR adds that support to `ds.write_dataset()`, which already accepts a RecordBatchReader.

### What changes are included in this PR?

`ds.write_dataset()` and `ds.Scanner.from_batches()` now accept any object implementing the Arrow PyCapsule interface for streams.

### Are these changes tested?

Yes

### Are there any user-facing changes?

No
* GitHub Issue: #43410

Authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
jorisvandenbossche authored Nov 18, 2024
1 parent 4dc0492 commit ad75248
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
15 changes: 12 additions & 3 deletions python/pyarrow/_dataset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3716,10 +3716,13 @@ cdef class Scanner(_Weakrefable):
Parameters
----------
source : Iterator
The iterator of Batches.
source : Iterator or Arrow-compatible stream object
The iterator of Batches. This can be a pyarrow RecordBatchReader,
any object that implements the Arrow PyCapsule Protocol for
streams, or an actual Python iterator of RecordBatches.
schema : Schema
The schema of the batches.
The schema of the batches (required when passing a Python
iterator).
columns : list[str] or dict[str, Expression], default None
The columns to project. This can be a list of column names to
include (order and duplicates will be preserved), or a dictionary
Expand Down Expand Up @@ -3775,6 +3778,12 @@ cdef class Scanner(_Weakrefable):
raise ValueError('Cannot specify a schema when providing '
'a RecordBatchReader')
reader = source
elif hasattr(source, "__arrow_c_stream__"):
if schema:
raise ValueError(
'Cannot specify a schema when providing an object '
'implementing the Arrow PyCapsule Protocol')
reader = pa.ipc.RecordBatchReader.from_stream(source)
elif _is_iterable(source):
if schema is None:
raise ValueError('Must provide schema to construct scanner '
Expand Down
6 changes: 5 additions & 1 deletion python/pyarrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,11 @@ def file_visitor(written_file):
elif isinstance(data, (pa.RecordBatch, pa.Table)):
schema = schema or data.schema
data = InMemoryDataset(data, schema=schema)
elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
elif (
isinstance(data, pa.ipc.RecordBatchReader)
or hasattr(data, "__arrow_c_stream__")
or _is_iterable(data)
):
data = Scanner.from_batches(data, schema=schema)
schema = None
elif not isinstance(data, (Dataset, Scanner)):
Expand Down
20 changes: 17 additions & 3 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@
pytestmark = pytest.mark.dataset


# Test helper: wraps an object and defines just the `__arrow_c_stream__`
# method, forwarding it to the wrapped object. The tests in this file pass
# instances of it wherever a stream-like input is accepted.
class TableStreamWrapper:
def __init__(self, table):
# Keep a reference to the wrapped object; it must itself provide
# `__arrow_c_stream__`, which is called below.
self.table = table

def __arrow_c_stream__(self, requested_schema=None):
# Delegate to the wrapped object, passing `requested_schema` through
# unchanged, and return whatever it returns.
return self.table.__arrow_c_stream__(requested_schema)


def _generate_data(n):
import datetime
import itertools
Expand Down Expand Up @@ -2543,6 +2551,7 @@ def test_scan_iterator(use_threads):
for factory, schema in (
(lambda: pa.RecordBatchReader.from_batches(
batch.schema, [batch]), None),
(lambda: TableStreamWrapper(table), None),
(lambda: (batch for _ in range(1)), batch.schema),
):
# Scanning the fragment consumes the underlying iterator
Expand Down Expand Up @@ -4674,15 +4683,20 @@ def test_write_iterable(tempdir):
base_dir = tempdir / 'inmemory_iterable'
ds.write_dataset((batch for batch in table.to_batches()), base_dir,
schema=table.schema,
basename_template='dat_{i}.arrow', format="feather")
basename_template='dat_{i}.arrow', format="ipc")
result = ds.dataset(base_dir, format="ipc").to_table()
assert result.equals(table)

base_dir = tempdir / 'inmemory_reader'
reader = pa.RecordBatchReader.from_batches(table.schema,
table.to_batches())
ds.write_dataset(reader, base_dir,
basename_template='dat_{i}.arrow', format="feather")
ds.write_dataset(reader, base_dir, basename_template='dat_{i}.arrow', format="ipc")
result = ds.dataset(base_dir, format="ipc").to_table()
assert result.equals(table)

base_dir = tempdir / 'inmemory_pycapsule'
stream = TableStreamWrapper(table)
ds.write_dataset(stream, base_dir, basename_template='dat_{i}.arrow', format="ipc")
result = ds.dataset(base_dir, format="ipc").to_table()
assert result.equals(table)

Expand Down

0 comments on commit ad75248

Please sign in to comment.