Skip to content

Commit

Permalink
to_arrow_batches
Browse files Browse the repository at this point in the history
  • Loading branch information
sungwy committed Jun 3, 2024
1 parent 1629d28 commit f604b15
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 15 deletions.
118 changes: 103 additions & 15 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,24 +1009,39 @@ def _task_to_record_batches(
for batch in batches:
if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, current_index, len(batch))

indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
print(f"DEBUG: {indices=} {current_index=} {len(batch)=}")
print(f"{batch=}")
batch = batch.take(indices)
print(f"{batch=}")
# Apply the user filter
if pyarrow_filter is not None:
# we need to switch back and forth between RecordBatch and Table
# as Expression filter isn't yet supported in RecordBatch
# https://github.com/apache/arrow/issues/39220
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
arrow_batches = arrow_table.to_batches()
for arrow_batch in arrow_batches:
yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
else:
yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
batch = arrow_table.to_batches()[0]
yield to_requested_schema(projected_schema, file_project_schema, batch)
current_index += len(batch)


def _task_to_table(
fs: FileSystem,
task: FileScanTask,
bound_row_filter: BooleanExpression,
projected_schema: Schema,
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
) -> pa.Table:
batches = _task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
)
return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema))


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
deletes_per_file: Dict[str, List[ChunkedArray]] = {}
unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
Expand Down Expand Up @@ -1103,7 +1118,6 @@ def project_table(
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
limit,
table_metadata.name_mapping(),
)
for task in tasks
Expand Down Expand Up @@ -1137,8 +1151,78 @@ def project_table(
return result


def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch:
struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
def project_batches(
tasks: Iterable[FileScanTask],
table_metadata: TableMetadata,
io: FileIO,
row_filter: BooleanExpression,
projected_schema: Schema,
case_sensitive: bool = True,
limit: Optional[int] = None,
) -> Iterator[pa.ReordBatch]:
"""Resolve the right columns based on the identifier.
Args:
tasks (Iterable[FileScanTask]): A URI or a path to a local file.
table_metadata (TableMetadata): The table metadata of the table that's being queried
io (FileIO): A FileIO to open streams to the object store
row_filter (BooleanExpression): The expression for filtering rows.
projected_schema (Schema): The output schema.
case_sensitive (bool): Case sensitivity when looking up column names.
limit (Optional[int]): Limit the number of records.
Raises:
ResolveError: When an incompatible query is done.
"""
scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
if isinstance(io, PyArrowFileIO):
fs = io.fs_by_scheme(scheme, netloc)
else:
try:
from pyiceberg.io.fsspec import FsspecFileIO

if isinstance(io, FsspecFileIO):
from pyarrow.fs import PyFileSystem

fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
else:
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
except ModuleNotFoundError as e:
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e

bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)

projected_field_ids = {
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
}.union(extract_field_ids(bound_row_filter))

deletes_per_file = _read_all_delete_files(fs, tasks)

total_row_count = 0

for task in tasks:
batches = _task_to_record_batches(
fs,
task,
bound_row_filter,
projected_schema,
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
table_metadata.name_mapping(),
)
for batch in batches:
if limit is not None:
if total_row_count + len(batch) >= limit:
yield batch.take(limit - total_row_count)
break
yield batch
total_row_count += len(batch)


def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

arrays = []
fields = []
Expand Down Expand Up @@ -1247,8 +1331,8 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st

if isinstance(partner_struct, pa.StructArray):
return partner_struct.field(name)
elif isinstance(partner_struct, pa.Table):
return partner_struct.column(name).combine_chunks()
elif isinstance(partner_struct, pa.RecordBatch):
return partner_struct.column(name)

return None

Expand Down Expand Up @@ -1785,15 +1869,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT

def write_parquet(task: WriteTask) -> DataFile:
table_schema = task.schema
arrow_table = pa.Table.from_batches(task.record_batches)

# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
file_schema = sanitized_schema
arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
batches = [
to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
for batch in task.record_batches
]
else:
file_schema = table_schema

batches = task.record_batches
arrow_table = pa.Table.from_batches(batches)
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
Expand Down
13 changes: 13 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,19 @@ def to_arrow(self) -> pa.Table:
limit=self.limit,
)

def to_arrow_batches(self) -> pa.Table:
from pyiceberg.io.pyarrow import project_batches

return project_batches(
self.plan_files(),
self.table_metadata,
self.io,
self.row_filter,
self.projection(),
case_sensitive=self.case_sensitive,
limit=self.limit,
)

def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
return self.to_arrow().to_pandas(**kwargs)

Expand Down

0 comments on commit f604b15

Please sign in to comment.