diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 3a4fa1ab611a7..fd50215cee9ae 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -3716,10 +3716,13 @@ cdef class Scanner(_Weakrefable):
 
         Parameters
         ----------
-        source : Iterator
-            The iterator of Batches.
+        source : Iterator or Arrow-compatible stream object
+            The iterator of Batches. This can be a pyarrow RecordBatchReader,
+            any object that implements the Arrow PyCapsule Protocol for
+            streams, or an actual Python iterator of RecordBatches.
         schema : Schema
-            The schema of the batches.
+            The schema of the batches (required when passing a Python
+            iterator).
         columns : list[str] or dict[str, Expression], default None
             The columns to project. This can be a list of column names to
             include (order and duplicates will be preserved), or a dictionary
@@ -3775,6 +3778,12 @@ cdef class Scanner(_Weakrefable):
                 raise ValueError('Cannot specify a schema when providing '
                                  'a RecordBatchReader')
             reader = source
+        elif hasattr(source, "__arrow_c_stream__"):
+            if schema:
+                raise ValueError(
+                    'Cannot specify a schema when providing an object '
+                    'implementing the Arrow PyCapsule Protocol')
+            reader = pa.ipc.RecordBatchReader.from_stream(source)
         elif _is_iterable(source):
             if schema is None:
                 raise ValueError('Must provide schema to construct scanner '
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 1efbfe1665a75..c61e13ee75801 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -964,7 +964,11 @@ def file_visitor(written_file):
     elif isinstance(data, (pa.RecordBatch, pa.Table)):
         schema = schema or data.schema
         data = InMemoryDataset(data, schema=schema)
-    elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+    elif (
+        isinstance(data, pa.ipc.RecordBatchReader)
+        or hasattr(data, "__arrow_c_stream__")
+        or _is_iterable(data)
+    ):
         data = Scanner.from_batches(data, schema=schema)
         schema = None
     elif not isinstance(data, (Dataset, Scanner)):
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 772670ad79fd3..b6aaa2840d83c 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -66,6 +66,14 @@
 pytestmark = pytest.mark.dataset
 
 
+class TableStreamWrapper:
+    def __init__(self, table):
+        self.table = table
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        return self.table.__arrow_c_stream__(requested_schema)
+
+
 def _generate_data(n):
     import datetime
     import itertools
@@ -2543,6 +2551,7 @@ def test_scan_iterator(use_threads):
     for factory, schema in (
             (lambda: pa.RecordBatchReader.from_batches(
                 batch.schema, [batch]), None),
+            (lambda: TableStreamWrapper(table), None),
             (lambda: (batch for _ in range(1)), batch.schema),
     ):
         # Scanning the fragment consumes the underlying iterator
@@ -4674,15 +4683,20 @@ def test_write_iterable(tempdir):
     base_dir = tempdir / 'inmemory_iterable'
     ds.write_dataset((batch for batch in table.to_batches()), base_dir,
                      schema=table.schema,
-                     basename_template='dat_{i}.arrow', format="feather")
+                     basename_template='dat_{i}.arrow', format="ipc")
     result = ds.dataset(base_dir, format="ipc").to_table()
     assert result.equals(table)
 
     base_dir = tempdir / 'inmemory_reader'
     reader = pa.RecordBatchReader.from_batches(table.schema,
                                                table.to_batches())
-    ds.write_dataset(reader, base_dir,
-                     basename_template='dat_{i}.arrow', format="feather")
+    ds.write_dataset(reader, base_dir, basename_template='dat_{i}.arrow', format="ipc")
+    result = ds.dataset(base_dir, format="ipc").to_table()
+    assert result.equals(table)
+
+    base_dir = tempdir / 'inmemory_pycapsule'
+    stream = TableStreamWrapper(table)
+    ds.write_dataset(stream, base_dir, basename_template='dat_{i}.arrow', format="ipc")
     result = ds.dataset(base_dir, format="ipc").to_table()
     assert result.equals(table)