diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py new file mode 100644 index 0000000000..025772f627 --- /dev/null +++ b/pyiceberg/expressions/residual_evaluator.py @@ -0,0 +1,247 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from abc import ABC +from typing import Any, Set + +from pyiceberg.expressions import And, Or +from pyiceberg.expressions.literals import Literal +from pyiceberg.expressions.visitors import ( + AlwaysFalse, + AlwaysTrue, + BooleanExpression, + BoundBooleanExpressionVisitor, + BoundPredicate, + BoundTerm, + Not, + UnboundPredicate, + visit, +) +from pyiceberg.partitioning import PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.typedef import Record +from pyiceberg.types import L + + +class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC): + schema: Schema + spec: PartitionSpec + case_sensitive: bool + + def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, expr: BooleanExpression): + self.schema = schema + self.spec = spec + self.case_sensitive = case_sensitive + self.expr = expr + + def eval(self, partition_data: Record) -> BooleanExpression: + self.struct = partition_data + return visit(self.expr, visitor=self) + + def visit_true(self) -> BooleanExpression: + return AlwaysTrue() + + def visit_false(self) -> BooleanExpression: + return AlwaysFalse() + + def visit_not(self, child_result: BooleanExpression) -> BooleanExpression: + return Not(child_result) + + def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return And(left_result, right_result) + + def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return Or(left_result, right_result) + + def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is None: + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is not None: + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression: + val = term.eval(self.struct) + if val is None: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression: + val = term.eval(self.struct) + if val is not None: + return self.visit_true() + else: + return self.visit_false() + + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) < literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) <= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) > literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) >= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) == literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) != literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: + if term.eval(self.struct) in literals: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: + if term.eval(self.struct) not in literals: + return self.visit_true() + else: + return self.visit_false() + + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + eval_res = term.eval(self.struct) + if eval_res is not None and str(eval_res).startswith(str(literal.value)): + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if not self.visit_starts_with(term, literal): + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + """ + If there is no strict projection or if it evaluates to false, then return the predicate. + + Get the strict projection and inclusive projection of this predicate in partition data, + then use them to determine whether to return the original predicate. The strict projection + returns true iff the original predicate would have returned true, so the predicate can be + eliminated if the strict projection evaluates to true. Similarly the inclusive projection + returns false iff the original predicate would have returned false, so the predicate can + also be eliminated if the inclusive projection evaluates to false. + + """ + parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) + if parts == []: + return predicate + + from pyiceberg.types import StructType + + def struct_to_schema(struct: StructType) -> Schema: + return Schema(*[f for f in struct.fields]) + + for part in parts: + strict_projection = part.transform.strict_project(part.name, predicate) + strict_result = None + + if strict_projection is not None: + bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + if isinstance(bound, BoundPredicate): + strict_result = super().visit_bound_predicate(bound) + else: + strict_result = bound + + if strict_result is not None and isinstance(strict_result, AlwaysTrue): + return AlwaysTrue() + + inclusive_projection = part.transform.project(part.name, predicate) + inclusive_result = None + if inclusive_projection is not None: + bound_inclusive = inclusive_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + if isinstance(bound_inclusive, BoundPredicate): + # using predicate method specific to inclusive + inclusive_result = super().visit_bound_predicate(bound_inclusive) + else: + # if the result is not a predicate, then it must be a constant like alwaysTrue or + # alwaysFalse + inclusive_result = bound_inclusive + if inclusive_result is not None and isinstance(inclusive_result, AlwaysFalse): + return AlwaysFalse() + + return predicate + + def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + bound = predicate.bind(self.schema, case_sensitive=True) + + if isinstance(bound, BoundPredicate): + bound_residual = self.visit_bound_predicate(predicate=bound) + # if isinstance(bound_residual, BooleanExpression): + if bound_residual not in (AlwaysFalse(), AlwaysTrue()): + # replace inclusive original unbound predicate + return predicate + + # use the non-predicate residual (e.g. alwaysTrue) + return bound_residual + + # if binding didn't result in a Predicate, return the expression + return bound + + +class ResidualEvaluator(ResidualVisitor): + def residual_for(self, partition_data: Record) -> BooleanExpression: + return self.eval(partition_data) + + +class UnpartitionedResidualEvaluator(ResidualEvaluator): + # Finds the residuals for an Expression the partitions in the given PartitionSpec + def __init__(self, schema: Schema, expr: BooleanExpression): + from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + + super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) + self.expr = expr + + def residual_for(self, partition_data: Record) -> BooleanExpression: + return self.expr + + +def residual_evaluator_of( + spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema +) -> ResidualEvaluator: + if len(spec.fields) != 0: + return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive) + else: + return UnpartitionedResidualEvaluator(schema=schema, expr=expr) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 4ec3403bb3..1bf023cd50 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1328,6 +1328,9 @@ def filter(self: S, expr: Union[str, BooleanExpression]) -> S: def with_case_sensitive(self: S, case_sensitive: bool = True) -> S: return self.update(case_sensitive=case_sensitive) + @abstractmethod + def count(self) -> int: ... + class ScanTask(ABC): pass @@ -1594,6 +1597,13 @@ def to_ray(self) -> ray.data.dataset.Dataset: return ray.data.from_arrow(self.to_arrow()) + def count(self) -> int: + res = 0 + tasks = self.plan_files() + for task in tasks: + res += task.file.record_count + return res + @dataclass(frozen=True) class WriteTask: diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 7f72568b41..d5daea900e 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -1360,6 +1360,58 @@ def test_append_table(catalog: SqlCatalog, table_schema_simple: Schema, table_id assert df == table.scan().to_arrow() +@pytest.mark.parametrize( + "catalog", + [ + lazy_fixture("catalog_memory"), + lazy_fixture("catalog_sqlite"), + lazy_fixture("catalog_sqlite_without_rowcount"), + lazy_fixture("catalog_sqlite_fsspec"), + ], +) +@pytest.mark.parametrize( + "table_identifier", + [ + lazy_fixture("random_table_identifier"), + lazy_fixture("random_hierarchical_identifier"), + lazy_fixture("random_table_identifier_with_catalog"), + ], +) +def test_count_table(catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier) -> None: + table_identifier_nocatalog = catalog._identifier_to_tuple_without_catalog(table_identifier) + namespace = Catalog.namespace_from(table_identifier_nocatalog) + catalog.create_namespace(namespace) + table = catalog.create_table(table_identifier, table_schema_simple) + + df = pa.Table.from_pydict( + { + "foo": ["a"], + "bar": [1], + "baz": [True], + }, + schema=schema_to_pyarrow(table_schema_simple), + ) + + table.append(df) + + # new snapshot is written in APPEND mode + assert len(table.metadata.snapshots) == 1 + assert table.metadata.snapshots[0].snapshot_id == table.metadata.current_snapshot_id + assert table.metadata.snapshots[0].parent_snapshot_id is None + assert table.metadata.snapshots[0].sequence_number == 1 + assert table.metadata.snapshots[0].summary is not None + assert table.metadata.snapshots[0].summary.operation == Operation.APPEND + assert table.metadata.snapshots[0].summary["added-data-files"] == "1" + assert table.metadata.snapshots[0].summary["added-records"] == "1" + assert table.metadata.snapshots[0].summary["total-data-files"] == "1" + assert table.metadata.snapshots[0].summary["total-records"] == "1" + assert len(table.metadata.metadata_log) == 1 + + # read back the data + assert df == table.scan().to_arrow() + assert len(table.scan().to_arrow()) == table.scan().count() + + @pytest.mark.parametrize( "catalog", [ diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py new file mode 100644 index 0000000000..c7210eaf01 --- /dev/null +++ b/tests/expressions/test_residual_evaluator.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name +import pytest + +from pyiceberg.expressions import ( + AlwaysFalse, + AlwaysTrue, + And, + EqualTo, + GreaterThan, + GreaterThanOrEqual, + In, + IsNaN, + IsNull, + LessThan, + NotIn, + NotNaN, + NotNull, + NotStartsWith, + Or, + StartsWith, + UnboundPredicate, +) +from pyiceberg.expressions.literals import literal +from pyiceberg.expressions.residual_evaluator import residual_evaluator_of +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.transforms import DayTransform, IdentityTransform +from pyiceberg.typedef import Record +from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, TimestampType + + +def test_identity_transform_residual() -> None: + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) + + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) + + predicate = Or( + Or( + And(EqualTo("dateint", 20170815), LessThan("hour", 12)), + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)), + ), + And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)), + ) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20170815)) + + # assert residual == True + assert isinstance(residual, UnboundPredicate) + assert residual.term.name == "hour" + # assert residual.term.field.name == 'hour' + assert residual.literal.value == 12 + assert type(residual) == LessThan + + residual = res_eval.residual_for(Record(dateint=20170801)) + + assert isinstance(residual, UnboundPredicate) + assert residual.term.name == "hour" + assert residual.literal.value == 11 + assert type(residual) == GreaterThan + + residual = res_eval.residual_for(Record(dateint=20170812)) + + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(dateint=20170817)) + + assert residual == AlwaysFalse() + + +def test_case_insensitive_identity_transform_residuals() -> None: + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) + + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) + + predicate = Or( + Or( + And(EqualTo("DATEINT", 20170815), LessThan("HOUR", 12)), + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)), + ), + And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)), + ) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + with pytest.raises(ValueError) as e: + residual = res_eval.residual_for(Record(dateint=20170815)) + assert "Could not find field with name DATEINT, case_sensitive=True" in str(e.value) + + +def test_unpartitioned_residuals() -> None: + expressions = [ + AlwaysTrue(), + AlwaysFalse(), + LessThan("a", 5), + GreaterThanOrEqual("b", 16), + NotNull("c"), + IsNull("d"), + In("e", [1, 2, 3]), + NotIn("f", [1, 2, 3]), + NotNaN("g"), + IsNaN("h"), + StartsWith("data", "abcd"), + NotStartsWith("data", "abcd"), + ] + + schema = Schema( + NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()), NestedField(52, "a", IntegerType()) + ) + for expr in expressions: + from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + + residual_evaluator = residual_evaluator_of(UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema) + assert residual_evaluator.residual_for(Record()) == expr + + +def test_in() -> None: + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) + + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) + + predicate = In("dateint", [20170815, 20170816, 20170817]) + + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20170815)) + + assert residual == AlwaysTrue() + + +def test_in_timestamp() -> None: + schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "hour", IntegerType())) + + spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) + + date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value + date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value + + day = DayTransform().transform(TimestampType()) + # assert date_20191201 == True + ts_day = day(date_20191201) + + # assert ts_day == True + + pred = In("ts", [date_20191202, date_20191201]) + + res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(ts_day)) + assert residual == pred + + residual = res_eval.residual_for(Record(ts_day + 3)) + assert residual == AlwaysFalse() + + +def test_not_in() -> None: + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) + + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) + + predicate = NotIn("dateint", [20170815, 20170816, 20170817]) + + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20180815)) + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(dateint=20170815)) + assert residual == AlwaysFalse() + + +def test_is_nan() -> None: + schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "hour", IntegerType())) + + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) + + predicate = IsNaN("double") + + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysFalse() + + +def test_is_not_nan() -> None: + schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "float", FloatType())) + + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) + + predicate = NotNaN("double") + + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysFalse() + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysTrue() + + spec = PartitionSpec(PartitionField(51, 1051, IdentityTransform(), "float_part")) + + predicate = NotNaN("float") + + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysFalse() + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysTrue() + + +def test_not_in_timestamp() -> None: + schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "dateint", IntegerType())) + + spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) + + date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value + date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value + + day = DayTransform().transform(TimestampType()) + # assert date_20191201 == True + ts_day = day(date_20191201) + + # assert ts_day == True + + pred = NotIn("ts", [date_20191202, date_20191201]) + + res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(ts_day)) + assert residual == pred + + residual = res_eval.residual_for(Record(ts_day + 3)) + assert residual == AlwaysTrue()