-
Notifications
You must be signed in to change notification settings - Fork 199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Table.to_arrow_batch_reader
to return RecordBatchReader instead of a fully materialized Arrow Table
#786
Changes from 2 commits
1629d28
f604b15
83e09d6
7a4d7d2
63d2b78
6b95390
2570455
79cf181
905cc7a
39a99c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr | |||
} | ||||
|
||||
|
||||
def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array: | ||||
def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array: | ||||
if len(positional_deletes) == 1: | ||||
all_chunks = positional_deletes[0] | ||||
else: | ||||
all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes])) | ||||
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False) | ||||
return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index) | ||||
|
||||
|
||||
def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema: | ||||
|
@@ -967,17 +967,16 @@ def _field_id(self, field: pa.Field) -> int: | |||
return -1 | ||||
|
||||
|
||||
def _task_to_table( | ||||
def _task_to_record_batches( | ||||
fs: FileSystem, | ||||
task: FileScanTask, | ||||
bound_row_filter: BooleanExpression, | ||||
projected_schema: Schema, | ||||
projected_field_ids: Set[int], | ||||
positional_deletes: Optional[List[ChunkedArray]], | ||||
case_sensitive: bool, | ||||
limit: Optional[int] = None, | ||||
name_mapping: Optional[NameMapping] = None, | ||||
) -> Optional[pa.Table]: | ||||
) -> Iterator[pa.RecordBatch]: | ||||
_, _, path = PyArrowFileIO.parse_location(task.file.file_path) | ||||
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) | ||||
with fs.open_input_file(path) as fin: | ||||
|
@@ -1005,36 +1004,42 @@ def _task_to_table( | |||
columns=[col.name for col in file_project_schema.columns], | ||||
) | ||||
|
||||
if positional_deletes: | ||||
# Create the mask of indices that we're interested in | ||||
indices = _combine_positional_deletes(positional_deletes, fragment.count_rows()) | ||||
|
||||
if limit: | ||||
if pyarrow_filter is not None: | ||||
# In case of the filter, we don't exactly know how many rows | ||||
# we need to fetch upfront, can be optimized in the future: | ||||
# https://github.com/apache/arrow/issues/35301 | ||||
arrow_table = fragment_scanner.take(indices) | ||||
arrow_table = arrow_table.filter(pyarrow_filter) | ||||
arrow_table = arrow_table.slice(0, limit) | ||||
else: | ||||
arrow_table = fragment_scanner.take(indices[0:limit]) | ||||
else: | ||||
arrow_table = fragment_scanner.take(indices) | ||||
current_index = 0 | ||||
batches = fragment_scanner.to_batches() | ||||
for batch in batches: | ||||
if positional_deletes: | ||||
# Create the mask of indices that we're interested in | ||||
indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch)) | ||||
print(f"DEBUG: {indices=} {current_index=} {len(batch)=}") | ||||
print(f"{batch=}") | ||||
batch = batch.take(indices) | ||||
print(f"{batch=}") | ||||
# Apply the user filter | ||||
if pyarrow_filter is not None: | ||||
# we need to switch back and forth between RecordBatch and Table | ||||
# as Expression filter isn't yet supported in RecordBatch | ||||
# https://github.com/apache/arrow/issues/39220 | ||||
arrow_table = pa.Table.from_batches([batch]) | ||||
arrow_table = arrow_table.filter(pyarrow_filter) | ||||
else: | ||||
# If there are no deletes, we can just take the head | ||||
# and the user-filter is already applied | ||||
if limit: | ||||
arrow_table = fragment_scanner.head(limit) | ||||
else: | ||||
arrow_table = fragment_scanner.to_table() | ||||
batch = arrow_table.to_batches()[0] | ||||
yield to_requested_schema(projected_schema, file_project_schema, batch) | ||||
current_index += len(batch) | ||||
|
||||
if len(arrow_table) < 1: | ||||
return None | ||||
return to_requested_schema(projected_schema, file_project_schema, arrow_table) | ||||
|
||||
def _task_to_table( | ||||
fs: FileSystem, | ||||
task: FileScanTask, | ||||
bound_row_filter: BooleanExpression, | ||||
projected_schema: Schema, | ||||
projected_field_ids: Set[int], | ||||
positional_deletes: Optional[List[ChunkedArray]], | ||||
case_sensitive: bool, | ||||
name_mapping: Optional[NameMapping] = None, | ||||
) -> pa.Table: | ||||
batches = _task_to_record_batches( | ||||
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping | ||||
) | ||||
return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema)) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was exactly what I had in mind, looking good 👍 |
||||
|
||||
|
||||
def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: | ||||
|
@@ -1113,7 +1118,6 @@ def project_table( | |||
projected_field_ids, | ||||
deletes_per_file.get(task.file.file_path), | ||||
case_sensitive, | ||||
limit, | ||||
table_metadata.name_mapping(), | ||||
) | ||||
for task in tasks | ||||
|
@@ -1147,16 +1151,86 @@ def project_table( | |||
return result | ||||
|
||||
|
||||
def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table: | ||||
struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) | ||||
def project_batches( | ||||
tasks: Iterable[FileScanTask], | ||||
table_metadata: TableMetadata, | ||||
io: FileIO, | ||||
row_filter: BooleanExpression, | ||||
projected_schema: Schema, | ||||
case_sensitive: bool = True, | ||||
limit: Optional[int] = None, | ||||
) -> Iterator[pa.ReordBatch]: | ||||
sungwy marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
"""Resolve the right columns based on the identifier. | ||||
|
||||
Args: | ||||
tasks (Iterable[FileScanTask]): The file scan tasks whose data files are read. | ||||
table_metadata (TableMetadata): The table metadata of the table that's being queried | ||||
io (FileIO): A FileIO to open streams to the object store | ||||
row_filter (BooleanExpression): The expression for filtering rows. | ||||
projected_schema (Schema): The output schema. | ||||
case_sensitive (bool): Case sensitivity when looking up column names. | ||||
limit (Optional[int]): Limit the number of records. | ||||
|
||||
Raises: | ||||
ResolveError: When an incompatible query is done. | ||||
""" | ||||
scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location) | ||||
if isinstance(io, PyArrowFileIO): | ||||
fs = io.fs_by_scheme(scheme, netloc) | ||||
else: | ||||
try: | ||||
from pyiceberg.io.fsspec import FsspecFileIO | ||||
|
||||
if isinstance(io, FsspecFileIO): | ||||
from pyarrow.fs import PyFileSystem | ||||
|
||||
fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme))) | ||||
else: | ||||
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") | ||||
except ModuleNotFoundError as e: | ||||
# When FsSpec is not installed | ||||
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e | ||||
|
||||
bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive) | ||||
|
||||
projected_field_ids = { | ||||
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType)) | ||||
}.union(extract_field_ids(bound_row_filter)) | ||||
|
||||
deletes_per_file = _read_all_delete_files(fs, tasks) | ||||
|
||||
total_row_count = 0 | ||||
|
||||
for task in tasks: | ||||
batches = _task_to_record_batches( | ||||
fs, | ||||
task, | ||||
bound_row_filter, | ||||
projected_schema, | ||||
projected_field_ids, | ||||
deletes_per_file.get(task.file.file_path), | ||||
case_sensitive, | ||||
table_metadata.name_mapping(), | ||||
) | ||||
for batch in batches: | ||||
if limit is not None: | ||||
if total_row_count + len(batch) >= limit: | ||||
yield batch.take(limit - total_row_count) | ||||
break | ||||
yield batch | ||||
total_row_count += len(batch) | ||||
|
||||
|
||||
def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch: | ||||
struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) | ||||
|
||||
arrays = [] | ||||
fields = [] | ||||
for pos, field in enumerate(requested_schema.fields): | ||||
array = struct_array.field(pos) | ||||
arrays.append(array) | ||||
fields.append(pa.field(field.name, array.type, field.optional)) | ||||
return pa.Table.from_arrays(arrays, schema=pa.schema(fields)) | ||||
return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields)) | ||||
|
||||
|
||||
class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): | ||||
|
@@ -1257,8 +1331,8 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st | |||
|
||||
if isinstance(partner_struct, pa.StructArray): | ||||
return partner_struct.field(name) | ||||
elif isinstance(partner_struct, pa.Table): | ||||
return partner_struct.column(name).combine_chunks() | ||||
elif isinstance(partner_struct, pa.RecordBatch): | ||||
return partner_struct.column(name) | ||||
sungwy marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
|
||||
return None | ||||
|
||||
|
@@ -1795,15 +1869,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT | |||
|
||||
def write_parquet(task: WriteTask) -> DataFile: | ||||
table_schema = task.schema | ||||
arrow_table = pa.Table.from_batches(task.record_batches) | ||||
|
||||
# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly | ||||
# otherwise use the original schema | ||||
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema: | ||||
file_schema = sanitized_schema | ||||
arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table) | ||||
batches = [ | ||||
to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch) | ||||
for batch in task.record_batches | ||||
] | ||||
else: | ||||
file_schema = table_schema | ||||
|
||||
batches = task.record_batches | ||||
arrow_table = pa.Table.from_batches(batches) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, looking here, this forced materialization seems to preclude streaming writes, which you may want if, e.g., upserting large amounts of data. ParquetWriter can be used for streaming writes, so this seems unnecessary? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i.e., maybe we could do something like the following?: def sanitize_batches(batches: Iterator[RecordBatch], table_schema: Schema, sanitized_schema: Schema) -> Iterator[RecordBatch]:
if sanitized_schema != table_schema:
for batch in batches:
yield to_requested_schema(requested_schema=sanitized_schema, file_schema=table_schema, batch=batch)
else:
yield from batches
def write_parquet(task: WriteTask) -> DataFile:
table_schema = task.schema
# Check if schema needs to be transformed
sanitized_schema = sanitize_column_names(table_schema)
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=sanitized_schema.as_arrow(), **parquet_writer_kwargs) as writer:
for sanitized_batch in sanitize_batches(task.record_batches, table_schema, sanitized_schema):
writer.write_table(pa.Table.from_batches([sanitized_batch]), row_group_size=row_group_size) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep I totally agree. I wanted to focus this PR on introducing the reader first, and then work on a subsequent PR to incorporate batches into writes. This just maintains the existing functionality while making use of the refactored There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah so the change here is the order of operations. We want to call I wonder if we can push iceberg-python/pyiceberg/table/__init__.py Line 2920 in a29491a
Also in this #829, I wanted to introduce schema projection We can keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That’s a great suggestion @kevinjqliu - good to see that our work is going to be converging naturally here. I was hoping to focus on the new read API here and the necessary refactoring in the utility functions, and keep the changes to the write functions to a minimum. I could incorporate these changes and continue the discussion on updating the write functions in a follow-up PR. I think there’s much discussion worth continuing on that topic (e.g., can we avoid materializing an Arrow table and write with record batches?). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sgtm, ty! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for the review! |
||||
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' | ||||
fo = io.new_output(file_path) | ||||
with fo.create(overwrite=True) as fos: | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I review this, it occurs to me that it might be useful to expose options related to batching, read-ahead, etc., that pyarrow accepts when constructing the scanner. See the pyarrow docs for more details.
Specifically, I think setting batch_size is probably something that ought to be tunable, since the memory pressure will be a function of batch size and the number and types of columns in the table.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a great suggestion, @corleyma. I'll adopt this feedback when I make the next round of changes.