From 091c0afd96a774db6250fe9379a8e9cfcff74677 Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:46:22 +0530 Subject: [PATCH 1/4] implemeted residual_evaluator.py with tests --- pyiceberg/expressions/residual_evaluator.py | 220 +++++++++++++ tests/expressions/test_residual_evaluator.py | 306 +++++++++++++++++++ 2 files changed, 526 insertions(+) create mode 100644 pyiceberg/expressions/residual_evaluator.py create mode 100644 tests/expressions/test_residual_evaluator.py diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py new file mode 100644 index 0000000000..4aac91e3a6 --- /dev/null +++ b/pyiceberg/expressions/residual_evaluator.py @@ -0,0 +1,220 @@ +from abc import ABC +from pyiceberg.expressions.visitors import ( + BoundBooleanExpressionVisitor, + BooleanExpression, + UnboundPredicate, + BoundPredicate, + visit, + BoundTerm, + AlwaysFalse, + AlwaysTrue +) +from pyiceberg.expressions.literals import Literal +from pyiceberg.expressions import ( + And, + Or +) +from pyiceberg.types import L +from pyiceberg.partitioning import PartitionSpec +from pyiceberg.schema import Schema +from typing import Any, List, Set +from pyiceberg.typedef import Record + + +class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC): + schema: Schema + spec: PartitionSpec + case_sensitive: bool + + def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, expr: BooleanExpression): + self.schema = schema + self.spec = spec + self.case_sensitive = case_sensitive + self.expr = expr + + + def eval(self, partition_data: Record): + self.struct = partition_data + return visit(self.expr, visitor=self) + + + def visit_true(self) -> BooleanExpression: + return AlwaysTrue() + + def visit_false(self) -> BooleanExpression: + return AlwaysFalse() + + def visit_not(self, child_result: BooleanExpression) -> BooleanExpression: + return Not(child_result) + def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return And(left_result, right_result) + + def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return Or(left_result, right_result) + + + def visit_is_null(self, term: BoundTerm[L]) -> bool: + return term.eval(self.struct) is None + + def visit_not_null(self, term: BoundTerm[L]) -> bool: + return term.eval(self.struct) is not None + + def visit_is_nan(self, term: BoundTerm[L]) -> bool: + val = term.eval(self.struct) + if val is None: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_nan(self, term: BoundTerm[L]) -> bool: + val = term.eval(self.struct) + if val is not None: + return self.visit_true() + else: + return self.visit_false() + + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) < literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) <= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) > literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) >= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) == literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) != literal.value: + return self.visit_true() + else: + return self.visit_false() + + + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + if term.eval(self.struct) in literals: + return self.visit_true() + else: + return self.visit_false() + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + if term.eval(self.struct) not in literals: + return self.visit_true() + else: + return self.visit_false() + + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + eval_res = term.eval(self.struct) + return eval_res is not None and str(eval_res).startswith(str(literal.value)) + + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + return not self.visit_starts_with(term, literal) + + def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> List[str]: + """ + called from eval + input + """ + parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) + if parts == []: + return predicate + + from pyiceberg.types import StructType + def struct_to_schema(struct: StructType) -> Schema: + return Schema(*[f for f in struct.fields]) + + for part in parts: + + strict_projection = part.transform.strict_project(part.name, predicate) + strict_result = None + + if strict_projection is not None: + bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + assert isinstance(bound, BoundPredicate) + if isinstance(bound, BoundPredicate): + strict_result = super().visit_bound_predicate(bound) + else: + strict_result = bound + + if strict_result is not None and isinstance(strict_result, AlwaysTrue): + return AlwaysTrue() + + inclusive_projection = part.transform.project(part.name, predicate) + inclusive_result = None + if inclusive_projection is not None: + bound_inclusive = inclusive_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + if isinstance(bound_inclusive, BoundPredicate): + # using predicate method specific to inclusive + inclusive_result = super().visit_bound_predicate(bound_inclusive) + else: + # if the result is not a predicate, then it must be a constant like alwaysTrue or + # alwaysFalse + inclusive_result = bound_inclusive + if inclusive_result is not None and isinstance(inclusive_result, AlwaysFalse): + return AlwaysFalse() + + return predicate + + def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + bound = predicate.bind(self.schema, case_sensitive=True) + + if isinstance(bound, BoundPredicate): + bound_residual = self.visit_bound_predicate(predicate=bound) + # if isinstance(bound_residual, BooleanExpression): + if bound_residual not in (AlwaysFalse(), AlwaysTrue()): + # replace inclusive original unbound predicate + return predicate + + # use the non-predicate residual (e.g. alwaysTrue) + return bound_residual + + # if binding didn't result in a Predicate, return the expression + return bound + + + + + +class ResidualEvaluator(ResidualVisitor): + def residual_for(self, partition_data): + return self.eval(partition_data) + +class UnpartitionedResidualEvaluator(ResidualEvaluator): + + def __init__(self, schema: Schema,expr: BooleanExpression): + from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) + self.expr = expr + + def residual_for(self, partition_data): + return self.expr + + +def residual_evaluator_of( + spec: PartitionSpec, + expr: BooleanExpression, + case_sensitive: bool, + schema: Schema +) -> ResidualEvaluator: + if len(spec.fields) != 0: + return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive) + else: + return UnpartitionedResidualEvaluator(schema=schema,expr=expr) \ No newline at end of file diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py new file mode 100644 index 0000000000..f47cabac38 --- /dev/null +++ b/tests/expressions/test_residual_evaluator.py @@ -0,0 +1,306 @@ +import pytest +from pyiceberg.expressions import ( + AlwaysTrue, + EqualTo, + LessThan, + AlwaysFalse, + And, + Or, + GreaterThan, + GreaterThanOrEqual, + UnboundPredicate, + BoundPredicate, + BoundReference, + BooleanExpression, + BoundLessThan, + BoundGreaterThan, + NotNull, + IsNull, + In, + NotIn, + NotNaN, + IsNaN, + StartsWith, + NotStartsWith +) +from pyiceberg.expressions.residual_evaluator import residual_evaluator_of +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.transforms import IdentityTransform, DayTransform +from pyiceberg.typedef import Record +from pyiceberg.types import ( + IntegerType, + DoubleType, + FloatType, + NestedField, + StringType, + TimestampType +) +from pyiceberg.utils.datetime import timestamp_to_micros +from pyiceberg.expressions.literals import literal + + +def test_identity_transform_residual(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = Or( + Or( + And(EqualTo("dateint", 20170815), LessThan("hour", 12)), + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + ), + And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)) + ) + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20170815)) + + # assert residual == True + assert isinstance(residual, UnboundPredicate) + assert residual.term.name == 'hour' + # assert residual.term.field.name == 'hour' + assert residual.literal.value == 12 + assert type(residual) == LessThan + + residual = res_eval.residual_for(Record(dateint=20170801)) + + assert isinstance(residual, UnboundPredicate) + assert residual.term.name == 'hour' + assert residual.literal.value == 11 + assert type(residual) == GreaterThan + + residual = res_eval.residual_for(Record(dateint=20170812)) + + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(dateint=20170817)) + + assert residual == AlwaysFalse() + + +def test_case_insensitive_identity_transform_residuals(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = Or( + Or( + And(EqualTo("DATEINT", 20170815), LessThan("HOUR", 12)), + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + ), + And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)) + ) + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + + with pytest.raises(ValueError) as e: + residual = res_eval.residual_for(Record(dateint=20170815)) + assert "Could not find field with name DATEINT, case_sensitive=True" in str(e.value) + + +def test_unpartitioned_residuals(): + + + expressions = [ + AlwaysTrue(), + AlwaysFalse(), + LessThan("a", 5), + GreaterThanOrEqual("b", 16), + NotNull("c"), + IsNull("d"), + In("e",[1, 2, 3]), + NotIn("f", [1, 2, 3]), + NotNaN("g"), + IsNaN("h"), + StartsWith("data", "abcd"), + NotStartsWith("data", "abcd") + ] + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()), + NestedField(52, "a", IntegerType()) + ) + for expr in expressions: + from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + residual_evaluator = residual_evaluator_of( + UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema + ) + assert residual_evaluator.residual_for(Record()) == expr + + +def test_in(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = In("dateint", [20170815, 20170816, 20170817]) + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20170815)) + + assert residual == AlwaysTrue() + + +def test_in_timestamp(): + + schema = Schema( + NestedField(50, "ts", TimestampType()), + NestedField(51, "hour", IntegerType()) + ) + + + spec = PartitionSpec( + PartitionField(50, 1000, DayTransform(), "ts_part") + ) + + date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value + date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value + + day = DayTransform().transform(TimestampType()) + # assert date_20191201 == True + ts_day = day(date_20191201) + + # assert ts_day == True + + pred = In("ts", [ date_20191202, date_20191201]) + + res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(ts_day)) + assert residual == pred + + residual = res_eval.residual_for(Record(ts_day+3)) + assert residual == AlwaysFalse() + + +def test_not_in(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = NotIn("dateint", [20170815, 20170816, 20170817]) + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20180815)) + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(dateint=20170815)) + assert residual == AlwaysFalse() + + +def test_is_nan(): + schema = Schema( + NestedField(50, "double", DoubleType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "double_part") + ) + + predicate = IsNaN("double") + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysFalse() + + +def test_is_not_nan(): + schema = Schema( + NestedField(50, "double", DoubleType()), + NestedField(51, "float", FloatType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "double_part") + ) + + predicate = NotNaN("double") + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysFalse() + + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysTrue() + + + spec = PartitionSpec( + PartitionField(51, 1051, IdentityTransform(), "float_part") + ) + + predicate = NotNaN("float") + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysFalse() + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysTrue() + + +def test_not_in_timestamp(): + + schema = Schema( + NestedField(50, "ts", TimestampType()), + NestedField(51, "dateint", IntegerType()) + ) + + + spec = PartitionSpec( + PartitionField(50, 1000, DayTransform(), "ts_part") + ) + + date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value + date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value + + day = DayTransform().transform(TimestampType()) + # assert date_20191201 == True + ts_day = day(date_20191201) + + # assert ts_day == True + + pred = NotIn("ts", [ date_20191202, date_20191201]) + + res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(ts_day)) + assert residual == pred + + residual = res_eval.residual_for(Record(ts_day+3)) + assert residual == AlwaysTrue() \ No newline at end of file From 3cd797deb3809570bc17698cbcd504c29473dda6 Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:49:00 +0530 Subject: [PATCH 2/4] added license --- pyiceberg/expressions/residual_evaluator.py | 16 ++++++++++++++++ tests/expressions/test_residual_evaluator.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py index 4aac91e3a6..bbd954cb07 100644 --- a/pyiceberg/expressions/residual_evaluator.py +++ b/pyiceberg/expressions/residual_evaluator.py @@ -1,3 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. from abc import ABC from pyiceberg.expressions.visitors import ( BoundBooleanExpressionVisitor, diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index f47cabac38..4af30f9827 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name import pytest from pyiceberg.expressions import ( AlwaysTrue, From 6b0924e89863950f81a7c97d41746ecb42d6d2ba Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:59:36 +0530 Subject: [PATCH 3/4] fixed lint --- pyiceberg/expressions/residual_evaluator.py | 58 +++--- tests/expressions/test_residual_evaluator.py | 176 ++++++------------- 2 files changed, 82 insertions(+), 152 deletions(-) diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py index bbd954cb07..4d382bf24d 100644 --- a/pyiceberg/expressions/residual_evaluator.py +++ b/pyiceberg/expressions/residual_evaluator.py @@ -15,26 +15,25 @@ # specific language governing permissions and limitations # under the License. from abc import ABC +from typing import Any, List, Set + +from pyiceberg.expressions import And, Or +from pyiceberg.expressions.literals import Literal from pyiceberg.expressions.visitors import ( - BoundBooleanExpressionVisitor, + AlwaysFalse, + AlwaysTrue, BooleanExpression, - UnboundPredicate, + BoundBooleanExpressionVisitor, BoundPredicate, - visit, BoundTerm, - AlwaysFalse, - AlwaysTrue -) -from pyiceberg.expressions.literals import Literal -from pyiceberg.expressions import ( - And, - Or + Not, + UnboundPredicate, + visit, ) -from pyiceberg.types import L from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema -from typing import Any, List, Set from pyiceberg.typedef import Record +from pyiceberg.types import L class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC): @@ -48,12 +47,10 @@ def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, ex self.case_sensitive = case_sensitive self.expr = expr - def eval(self, partition_data: Record): self.struct = partition_data return visit(self.expr, visitor=self) - def visit_true(self) -> BooleanExpression: return AlwaysTrue() @@ -62,13 +59,13 @@ def visit_false(self) -> BooleanExpression: def visit_not(self, child_result: BooleanExpression) -> BooleanExpression: return Not(child_result) + def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return And(left_result, right_result) def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_is_null(self, term: BoundTerm[L]) -> bool: return term.eval(self.struct) is None @@ -125,12 +122,12 @@ def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: else: return self.visit_false() - def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: if term.eval(self.struct) in literals: return self.visit_true() else: return self.visit_false() + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: if term.eval(self.struct) not in literals: return self.visit_true() @@ -146,19 +143,26 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> List[str]: """ - called from eval - input + If there is no strict projection or if it evaluates to false, then return the predicate. + + Get the strict projection and inclusive projection of this predicate in partition data, + then use them to determine whether to return the original predicate. The strict projection + returns true iff the original predicate would have returned true, so the predicate can be + eliminated if the strict projection evaluates to true. Similarly the inclusive projection + returns false iff the original predicate would have returned false, so the predicate can + also be eliminated if the inclusive projection evaluates to false. + """ parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) if parts == []: return predicate from pyiceberg.types import StructType + def struct_to_schema(struct: StructType) -> Schema: return Schema(*[f for f in struct.fields]) for part in parts: - strict_projection = part.transform.strict_project(part.name, predicate) strict_result = None @@ -206,17 +210,16 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr return bound - - - class ResidualEvaluator(ResidualVisitor): def residual_for(self, partition_data): return self.eval(partition_data) -class UnpartitionedResidualEvaluator(ResidualEvaluator): - def __init__(self, schema: Schema,expr: BooleanExpression): +class UnpartitionedResidualEvaluator(ResidualEvaluator): + # Finds the residuals for an Expression the partitions in the given PartitionSpec + def __init__(self, schema: Schema, expr: BooleanExpression): from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) self.expr = expr @@ -225,12 +228,9 @@ def residual_for(self, partition_data): def residual_evaluator_of( - spec: PartitionSpec, - expr: BooleanExpression, - case_sensitive: bool, - schema: Schema + spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema ) -> ResidualEvaluator: if len(spec.fields) != 0: return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive) else: - return UnpartitionedResidualEvaluator(schema=schema,expr=expr) \ No newline at end of file + return UnpartitionedResidualEvaluator(schema=schema, expr=expr) diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index 4af30f9827..d49d18eb8d 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -16,72 +16,54 @@ # under the License. # pylint:disable=redefined-outer-name import pytest + from pyiceberg.expressions import ( - AlwaysTrue, - EqualTo, - LessThan, AlwaysFalse, + AlwaysTrue, And, - Or, + EqualTo, GreaterThan, GreaterThanOrEqual, - UnboundPredicate, - BoundPredicate, - BoundReference, - BooleanExpression, - BoundLessThan, - BoundGreaterThan, - NotNull, - IsNull, In, + IsNaN, + IsNull, + LessThan, NotIn, NotNaN, - IsNaN, + NotNull, + NotStartsWith, + Or, StartsWith, - NotStartsWith + UnboundPredicate, ) +from pyiceberg.expressions.literals import literal from pyiceberg.expressions.residual_evaluator import residual_evaluator_of from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.transforms import IdentityTransform, DayTransform +from pyiceberg.transforms import DayTransform, IdentityTransform from pyiceberg.typedef import Record -from pyiceberg.types import ( - IntegerType, - DoubleType, - FloatType, - NestedField, - StringType, - TimestampType -) -from pyiceberg.utils.datetime import timestamp_to_micros -from pyiceberg.expressions.literals import literal +from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, TimestampType def test_identity_transform_residual(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = Or( Or( And(EqualTo("dateint", 20170815), LessThan("hour", 12)), - And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)), ), - And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)) + And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)), ) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(dateint=20170815)) # assert residual == True assert isinstance(residual, UnboundPredicate) - assert residual.term.name == 'hour' + assert residual.term.name == "hour" # assert residual.term.field.name == 'hour' assert residual.literal.value == 12 assert type(residual) == LessThan @@ -89,7 +71,7 @@ def test_identity_transform_residual(): residual = res_eval.residual_for(Record(dateint=20170801)) assert isinstance(residual, UnboundPredicate) - assert residual.term.name == 'hour' + assert residual.term.name == "hour" assert residual.literal.value == 11 assert type(residual) == GreaterThan @@ -103,25 +85,18 @@ def test_identity_transform_residual(): def test_case_insensitive_identity_transform_residuals(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = Or( Or( And(EqualTo("DATEINT", 20170815), LessThan("HOUR", 12)), - And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)), ), - And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)) + And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)), ) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) - + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) with pytest.raises(ValueError) as e: residual = res_eval.residual_for(Record(dateint=20170815)) @@ -129,8 +104,6 @@ def test_case_insensitive_identity_transform_residuals(): def test_unpartitioned_residuals(): - - expressions = [ AlwaysTrue(), AlwaysFalse(), @@ -138,41 +111,32 @@ def test_unpartitioned_residuals(): GreaterThanOrEqual("b", 16), NotNull("c"), IsNull("d"), - In("e",[1, 2, 3]), + In("e", [1, 2, 3]), NotIn("f", [1, 2, 3]), NotNaN("g"), IsNaN("h"), StartsWith("data", "abcd"), - NotStartsWith("data", "abcd") + NotStartsWith("data", "abcd"), ] schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()), - NestedField(52, "a", IntegerType()) + NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()), NestedField(52, "a", IntegerType()) ) for expr in expressions: from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC - residual_evaluator = residual_evaluator_of( - UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema - ) + + residual_evaluator = residual_evaluator_of(UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema) assert residual_evaluator.residual_for(Record()) == expr def test_in(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = In("dateint", [20170815, 20170816, 20170817]) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(dateint=20170815)) @@ -180,16 +144,9 @@ def test_in(): def test_in_timestamp(): + schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "ts", TimestampType()), - NestedField(51, "hour", IntegerType()) - ) - - - spec = PartitionSpec( - PartitionField(50, 1000, DayTransform(), "ts_part") - ) + spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value @@ -200,31 +157,25 @@ def test_in_timestamp(): # assert ts_day == True - pred = In("ts", [ date_20191202, date_20191201]) + pred = In("ts", [date_20191202, date_20191201]) res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(ts_day)) assert residual == pred - residual = res_eval.residual_for(Record(ts_day+3)) + residual = res_eval.residual_for(Record(ts_day + 3)) assert residual == AlwaysFalse() def test_not_in(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = NotIn("dateint", [20170815, 20170816, 20170817]) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(dateint=20180815)) assert residual == AlwaysTrue() @@ -234,18 +185,13 @@ def test_not_in(): def test_is_nan(): - schema = Schema( - NestedField(50, "double", DoubleType()), - NestedField(51, "hour", IntegerType()) - ) + schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "hour", IntegerType())) - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "double_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) predicate = IsNaN("double") - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(double=None)) assert residual == AlwaysTrue() @@ -255,34 +201,25 @@ def test_is_nan(): def test_is_not_nan(): - schema = Schema( - NestedField(50, "double", DoubleType()), - NestedField(51, "float", FloatType()) - ) + schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "float", FloatType())) - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "double_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) predicate = NotNaN("double") - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(double=None)) assert residual == AlwaysFalse() - residual = res_eval.residual_for(Record(double=2)) assert residual == AlwaysTrue() - - spec = PartitionSpec( - PartitionField(51, 1051, IdentityTransform(), "float_part") - ) + spec = PartitionSpec(PartitionField(51, 1051, IdentityTransform(), "float_part")) predicate = NotNaN("float") - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(double=None)) assert residual == AlwaysFalse() @@ -292,16 +229,9 @@ def test_is_not_nan(): def test_not_in_timestamp(): + schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "dateint", IntegerType())) - schema = Schema( - NestedField(50, "ts", TimestampType()), - NestedField(51, "dateint", IntegerType()) - ) - - - spec = PartitionSpec( - PartitionField(50, 1000, DayTransform(), "ts_part") - ) + spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value @@ -312,12 +242,12 @@ def test_not_in_timestamp(): # assert ts_day == True - pred = NotIn("ts", [ date_20191202, date_20191201]) + pred = NotIn("ts", [date_20191202, date_20191201]) res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(ts_day)) assert residual == pred - residual = res_eval.residual_for(Record(ts_day+3)) - assert residual == AlwaysTrue() \ No newline at end of file + residual = res_eval.residual_for(Record(ts_day + 3)) + assert residual == AlwaysTrue() From 96cb4e9977513e5ecd95722fa644ef84e61bc00f Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Tue, 24 Dec 2024 13:42:32 +0530 Subject: [PATCH 4/4] fixed lint errors --- pyiceberg/expressions/residual_evaluator.py | 59 ++++++++++++-------- tests/expressions/test_residual_evaluator.py | 18 +++--- 2 files changed, 44 insertions(+), 33 deletions(-) diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py index 4d382bf24d..025772f627 100644 --- a/pyiceberg/expressions/residual_evaluator.py +++ b/pyiceberg/expressions/residual_evaluator.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from abc import ABC -from typing import Any, List, Set +from typing import Any, Set from pyiceberg.expressions import And, Or from pyiceberg.expressions.literals import Literal @@ -47,7 +47,7 @@ def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, ex self.case_sensitive = case_sensitive self.expr = expr - def eval(self, partition_data: Record): + def eval(self, partition_data: Record) -> BooleanExpression: self.struct = partition_data return visit(self.expr, visitor=self) @@ -66,82 +66,94 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_is_null(self, term: BoundTerm[L]) -> bool: - return term.eval(self.struct) is None + def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is None: + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_not_null(self, term: BoundTerm[L]) -> bool: - return term.eval(self.struct) is not None + def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is not None: + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression: val = term.eval(self.struct) if val is None: return self.visit_true() else: return self.visit_false() - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression: val = term.eval(self.struct) if val is not None: return self.visit_true() else: return self.visit_false() - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) < literal.value: return self.visit_true() else: return self.visit_false() - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) <= literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) > literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) >= literal.value: return self.visit_true() else: return self.visit_false() - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) == literal.value: return self.visit_true() else: return self.visit_false() - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) != literal.value: return self.visit_true() else: return self.visit_false() - def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: if term.eval(self.struct) in literals: return self.visit_true() else: return self.visit_false() - def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: if term.eval(self.struct) not in literals: return self.visit_true() else: return self.visit_false() - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: eval_res = term.eval(self.struct) - return eval_res is not None and str(eval_res).startswith(str(literal.value)) + if eval_res is not None and str(eval_res).startswith(str(literal.value)): + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: - return not self.visit_starts_with(term, literal) + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if not self.visit_starts_with(term, literal): + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> List[str]: + def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: """ If there is no strict projection or if it evaluates to false, then return the predicate. @@ -168,7 +180,6 @@ def struct_to_schema(struct: StructType) -> Schema: if strict_projection is not None: bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) - assert isinstance(bound, BoundPredicate) if isinstance(bound, BoundPredicate): strict_result = super().visit_bound_predicate(bound) else: @@ -211,7 +222,7 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr class ResidualEvaluator(ResidualVisitor): - def residual_for(self, partition_data): + def residual_for(self, partition_data: Record) -> BooleanExpression: return self.eval(partition_data) @@ -223,7 +234,7 @@ def __init__(self, schema: Schema, expr: BooleanExpression): super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) self.expr = expr - def residual_for(self, partition_data): + def residual_for(self, partition_data: Record) -> BooleanExpression: return self.expr diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index d49d18eb8d..c7210eaf01 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -45,7 +45,7 @@ from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, TimestampType -def test_identity_transform_residual(): +def test_identity_transform_residual() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -84,7 +84,7 @@ def test_identity_transform_residual(): assert residual == AlwaysFalse() -def test_case_insensitive_identity_transform_residuals(): +def test_case_insensitive_identity_transform_residuals() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -103,7 +103,7 @@ def test_case_insensitive_identity_transform_residuals(): assert "Could not find field with name DATEINT, case_sensitive=True" in str(e.value) -def test_unpartitioned_residuals(): +def test_unpartitioned_residuals() -> None: expressions = [ AlwaysTrue(), AlwaysFalse(), @@ -129,7 +129,7 @@ def test_unpartitioned_residuals(): assert residual_evaluator.residual_for(Record()) == expr -def test_in(): +def test_in() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -143,7 +143,7 @@ def test_in(): assert residual == AlwaysTrue() -def test_in_timestamp(): +def test_in_timestamp() -> None: schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) @@ -168,7 +168,7 @@ def test_in_timestamp(): assert residual == AlwaysFalse() -def test_not_in(): +def test_not_in() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -184,7 +184,7 @@ def test_not_in(): assert residual == AlwaysFalse() -def test_is_nan(): +def test_is_nan() -> None: schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) @@ -200,7 +200,7 @@ def test_is_nan(): assert residual == AlwaysFalse() -def test_is_not_nan(): +def test_is_not_nan() -> None: schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "float", FloatType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) @@ -228,7 +228,7 @@ def test_is_not_nan(): assert residual == AlwaysTrue() -def test_not_in_timestamp(): +def test_not_in_timestamp() -> None: schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "dateint", IntegerType())) spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part"))