
Commit

Change DataScan to accept Metadata and io (#581)
* Change DataScan to accept Metadata and io

For the partial deletes I want to do a scan on in-memory metadata. Changing this API allows this.

* fix name-mapping issue

---------

Co-authored-by: HonahX <[email protected]>
Fokko and HonahX authored Apr 8, 2024
1 parent 07442cc commit 5039b5d
Showing 5 changed files with 102 additions and 110 deletions.
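
A minimal sketch (not part of the commit) of what the reworked constructor looks like to a caller, assuming `tbl` is a pyiceberg Table already loaded from a catalog; the keyword names follow the new TableScan signature shown in the diff below.

from pyiceberg.table import DataScan

def scan_with_new_api(tbl):
    # DataScan now takes the table's metadata and a FileIO directly instead of a
    # Table object, so the same machinery can also run against in-memory metadata.
    scan = DataScan(
        table_metadata=tbl.metadata,
        io=tbl.io,
        selected_fields=("*",),
        case_sensitive=True,
    )
    return scan.to_arrow()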
26 changes: 14 additions & 12 deletions pyiceberg/io/pyarrow.py
@@ -159,7 +159,7 @@
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string

if TYPE_CHECKING:
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table import FileScanTask

logger = logging.getLogger(__name__)

@@ -1046,7 +1046,8 @@ def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dic

def project_table(
tasks: Iterable[FileScanTask],
table: Table,
table_metadata: TableMetadata,
io: FileIO,
row_filter: BooleanExpression,
projected_schema: Schema,
case_sensitive: bool = True,
@@ -1056,7 +1057,8 @@
Args:
tasks (Iterable[FileScanTask]): A URI or a path to a local file.
table (Table): The table that's being queried.
table_metadata (TableMetadata): The table metadata of the table that's being queried
io (FileIO): A FileIO to open streams to the object store
row_filter (BooleanExpression): The expression for filtering rows.
projected_schema (Schema): The output schema.
case_sensitive (bool): Case sensitivity when looking up column names.
@@ -1065,24 +1067,24 @@
Raises:
ResolveError: When an incompatible query is done.
"""
scheme, netloc, _ = PyArrowFileIO.parse_location(table.location())
if isinstance(table.io, PyArrowFileIO):
fs = table.io.fs_by_scheme(scheme, netloc)
scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
if isinstance(io, PyArrowFileIO):
fs = io.fs_by_scheme(scheme, netloc)
else:
try:
from pyiceberg.io.fsspec import FsspecFileIO

if isinstance(table.io, FsspecFileIO):
if isinstance(io, FsspecFileIO):
from pyarrow.fs import PyFileSystem

fs = PyFileSystem(FSSpecHandler(table.io.get_fs(scheme)))
fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
else:
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}")
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
except ModuleNotFoundError as e:
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e

bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive)
bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)

projected_field_ids = {
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
@@ -1101,7 +1103,7 @@
deletes_per_file.get(task.file.file_path),
case_sensitive,
limit,
table.name_mapping(),
table_metadata.name_mapping(),
)
for task in tasks
]
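
The pyarrow changes above reflect the new call shape. As a hedged sketch (assuming `scan` is a DataScan built as in the earlier example), project_table is now handed the metadata and FileIO that the scan already carries:

from pyiceberg.io.pyarrow import project_table

def read_scan(scan):
    # Mirrors DataScan.to_arrow() in this diff: positional TableMetadata and FileIO
    # arguments replace the old Table argument.
    return project_table(
        scan.plan_files(),
        scan.table_metadata,
        scan.io,
        scan.row_filter,
        scan.projection(),
        case_sensitive=scan.case_sensitive,
    )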
70 changes: 29 additions & 41 deletions pyiceberg/table/__init__.py
@@ -103,7 +103,6 @@
)
from pyiceberg.table.name_mapping import (
NameMapping,
parse_mapping_from_json,
update_mapping,
)
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
@@ -1215,7 +1214,8 @@ def scan(
limit: Optional[int] = None,
) -> DataScan:
return DataScan(
table=self,
table_metadata=self.metadata,
io=self.io,
row_filter=row_filter,
selected_fields=selected_fields,
case_sensitive=case_sensitive,
@@ -1312,10 +1312,7 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive

def name_mapping(self) -> Optional[NameMapping]:
"""Return the table's field-id NameMapping."""
if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
return parse_mapping_from_json(name_mapping_json)
else:
return None
return self.metadata.name_mapping()

def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
@@ -1468,7 +1465,8 @@ def _parse_row_filter(expr: Union[str, BooleanExpression]) -> BooleanExpression:


class TableScan(ABC):
table: Table
table_metadata: TableMetadata
io: FileIO
row_filter: BooleanExpression
selected_fields: Tuple[str, ...]
case_sensitive: bool
@@ -1478,15 +1476,17 @@ class TableScan(ABC):

def __init__(
self,
table: Table,
table_metadata: TableMetadata,
io: FileIO,
row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
selected_fields: Tuple[str, ...] = ("*",),
case_sensitive: bool = True,
snapshot_id: Optional[int] = None,
options: Properties = EMPTY_DICT,
limit: Optional[int] = None,
):
self.table = table
self.table_metadata = table_metadata
self.io = io
self.row_filter = _parse_row_filter(row_filter)
self.selected_fields = selected_fields
self.case_sensitive = case_sensitive
@@ -1496,19 +1496,20 @@ def __init__(

def snapshot(self) -> Optional[Snapshot]:
if self.snapshot_id:
return self.table.snapshot_by_id(self.snapshot_id)
return self.table.current_snapshot()
return self.table_metadata.snapshot_by_id(self.snapshot_id)
return self.table_metadata.current_snapshot()

def projection(self) -> Schema:
current_schema = self.table.schema()
current_schema = self.table_metadata.schema()
if self.snapshot_id is not None:
snapshot = self.table.snapshot_by_id(self.snapshot_id)
snapshot = self.table_metadata.snapshot_by_id(self.snapshot_id)
if snapshot is not None:
if snapshot.schema_id is not None:
snapshot_schema = self.table.schemas().get(snapshot.schema_id)
if snapshot_schema is not None:
current_schema = snapshot_schema
else:
try:
current_schema = next(
schema for schema in self.table_metadata.schemas if schema.schema_id == snapshot.schema_id
)
except StopIteration:
warnings.warn(f"Metadata does not contain schema with id: {snapshot.schema_id}")
else:
raise ValueError(f"Snapshot not found: {self.snapshot_id}")
@@ -1534,7 +1535,7 @@ def update(self: S, **overrides: Any) -> S:
def use_ref(self: S, name: str) -> S:
if self.snapshot_id:
raise ValueError(f"Cannot override ref, already set snapshot id={self.snapshot_id}")
if snapshot := self.table.snapshot_by_name(name):
if snapshot := self.table_metadata.snapshot_by_name(name):
return self.update(snapshot_id=snapshot.snapshot_id)

raise ValueError(f"Cannot scan unknown ref={name}")
@@ -1626,33 +1627,21 @@ def _match_deletes_to_data_file(data_entry: ManifestEntry, positional_delete_ent


class DataScan(TableScan):
def __init__(
self,
table: Table,
row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
selected_fields: Tuple[str, ...] = ("*",),
case_sensitive: bool = True,
snapshot_id: Optional[int] = None,
options: Properties = EMPTY_DICT,
limit: Optional[int] = None,
):
super().__init__(table, row_filter, selected_fields, case_sensitive, snapshot_id, options, limit)

def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
project = inclusive_projection(self.table.schema(), self.table.specs()[spec_id])
project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id])
return project(self.row_filter)

@cached_property
def partition_filters(self) -> KeyDefaultDict[int, BooleanExpression]:
return KeyDefaultDict(self._build_partition_projection)

def _build_manifest_evaluator(self, spec_id: int) -> Callable[[ManifestFile], bool]:
spec = self.table.specs()[spec_id]
return manifest_evaluator(spec, self.table.schema(), self.partition_filters[spec_id], self.case_sensitive)
spec = self.table_metadata.specs()[spec_id]
return manifest_evaluator(spec, self.table_metadata.schema(), self.partition_filters[spec_id], self.case_sensitive)

def _build_partition_evaluator(self, spec_id: int) -> Callable[[DataFile], bool]:
spec = self.table.specs()[spec_id]
partition_type = spec.partition_type(self.table.schema())
spec = self.table_metadata.specs()[spec_id]
partition_type = spec.partition_type(self.table_metadata.schema())
partition_schema = Schema(*partition_type.fields)
partition_expr = self.partition_filters[spec_id]

@@ -1687,16 +1676,14 @@ def plan_files(self) -> Iterable[FileScanTask]:
if not snapshot:
return iter([])

io = self.table.io

# step 1: filter manifests using partition summaries
# the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id

manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)

manifests = [
manifest_file
for manifest_file in snapshot.manifests(io)
for manifest_file in snapshot.manifests(self.io)
if manifest_evaluators[manifest_file.partition_spec_id](manifest_file)
]

@@ -1705,7 +1692,7 @@ def plan_files(self) -> Iterable[FileScanTask]:

partition_evaluators: Dict[int, Callable[[DataFile], bool]] = KeyDefaultDict(self._build_partition_evaluator)
metrics_evaluator = _InclusiveMetricsEvaluator(
self.table.schema(), self.row_filter, self.case_sensitive, self.options.get("include_empty_files") == "true"
self.table_metadata.schema(), self.row_filter, self.case_sensitive, self.options.get("include_empty_files") == "true"
).eval

min_data_sequence_number = _min_data_file_sequence_number(manifests)
@@ -1719,7 +1706,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
lambda args: _open_manifest(*args),
[
(
io,
self.io,
manifest,
partition_evaluators[manifest.partition_spec_id],
metrics_evaluator,
@@ -1755,7 +1742,8 @@ def to_arrow(self) -> pa.Table:

return project_table(
self.plan_files(),
self.table,
self.table_metadata,
self.io,
self.row_filter,
self.projection(),
case_sensitive=self.case_sensitive,
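
A hedged illustration of the motivation from the commit message: because TableScan now stores TableMetadata and a FileIO rather than a Table, a scan can be planned from metadata held only in memory. Here `metadata` (a TableMetadata instance, e.g. one staged for a partial delete) and `io` (a FileIO that can reach the data files) are assumed to exist already, and the filter column is hypothetical.

from pyiceberg.expressions import GreaterThanOrEqual
from pyiceberg.table import DataScan

def plan_in_memory_scan(metadata, io):
    scan = DataScan(
        table_metadata=metadata,
        io=io,
        row_filter=GreaterThanOrEqual("id", 10),  # hypothetical filter column
        limit=100,
    )
    # plan_files() walks the current snapshot's manifests through the provided io;
    # no Table or catalog object is involved.
    return list(scan.plan_files())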
14 changes: 14 additions & 0 deletions pyiceberg/table/metadata.py
@@ -35,6 +35,7 @@
from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec, assign_fresh_partition_spec_ids
from pyiceberg.schema import Schema, assign_fresh_schema_ids
from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import MetadataLogEntry, Snapshot, SnapshotLogEntry
from pyiceberg.table.sorting import (
@@ -237,6 +238,13 @@ def schema(self) -> Schema:
"""Return the schema for this table."""
return next(schema for schema in self.schemas if schema.schema_id == self.current_schema_id)

def name_mapping(self) -> Optional[NameMapping]:
"""Return the table's field-id NameMapping."""
if name_mapping_json := self.properties.get("schema.name-mapping.default"):
return parse_mapping_from_json(name_mapping_json)
else:
return None

def spec(self) -> PartitionSpec:
"""Return the partition spec of this table."""
return next(spec for spec in self.partition_specs if spec.spec_id == self.default_spec_id)
@@ -278,6 +286,12 @@ def new_snapshot_id(self) -> int:

return snapshot_id

def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
"""Return the snapshot referenced by the given name or null if no such reference exists."""
if ref := self.refs.get(name):
return self.snapshot_by_id(ref.snapshot_id)
return None

def current_snapshot(self) -> Optional[Snapshot]:
"""Get the current snapshot for this table, or None if there is no current snapshot."""
if self.current_snapshot_id is not None:
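
The two helpers added to TableMetadata above can be exercised directly; a small sketch assuming `metadata` is a TableMetadata instance and that a "main" branch ref exists:

def describe_metadata(metadata):
    # Parses the "schema.name-mapping.default" property into a NameMapping, or None.
    mapping = metadata.name_mapping()
    # Resolves a branch or tag name through metadata.refs to its Snapshot, or None.
    snapshot = metadata.snapshot_by_name("main")
    return mapping, snapshot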
9 changes: 9 additions & 0 deletions tests/integration/test_add_files.py
@@ -158,6 +158,9 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog:
for col in df.columns:
assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null"

# check that the table can be read by pyiceberg
assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows"


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
@@ -255,6 +258,9 @@ def test_add_files_to_unpartitioned_table_with_schema_updates(
value_count = 1 if col == "quux" else 6
assert df.filter(df[col].isNotNull()).count() == value_count, f"Expected {value_count} rows to be non-null"

# check that the table can be read by pyiceberg
assert len(tbl.scan().to_arrow()) == 6, "Expected 6 rows"


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
@@ -324,6 +330,9 @@ def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Ca
assert [row.file_count for row in partition_rows] == [5]
assert [(row.partition.baz, row.partition.qux_month) for row in partition_rows] == [(123, 650)]

# check that the table can be read by pyiceberg
assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows"


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
