Skip to content

Commit

Permalink
Fixing failing cases
Browse files Browse the repository at this point in the history
  • Loading branch information
prsabahrami committed May 18, 2024
1 parent 2af055c commit 2850d44
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 16 deletions.
45 changes: 31 additions & 14 deletions polarify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ def __iter__(self):
def build_polars_when_then_otherwise(body: Sequence[ResolvedCase], orelse: ast.expr) -> ast.Call:
nodes: list[ast.Call] = []

assert body, "No when-then cases provided."
assert body or orelse, "No when-then cases provided."

if not body:
"""
When a match statement has no valid cases (i.e., all cases except catch-all pattern are ignored),
we return the orelse expression but the test setup does not work with literal expressions.
"""
raise ValueError("No valid cases provided.")

for test, then in body:
when_node = ast.Call(
Expand Down Expand Up @@ -195,9 +202,9 @@ def translate_match(
"""
TODO: Explain the purpose and goal of this method, it's quite complex
"""
if isinstance(pattern, ast.MatchValue) and isinstance(subj, ast.Name):
if isinstance(pattern, ast.MatchValue):
equality_ast = ast.Compare(
left=ast.Name(id=subj.id, ctx=ast.Load()),
left=subj,
ops=[ast.Eq()],
comparators=[pattern.value],
)
Expand All @@ -210,14 +217,12 @@ def translate_match(
)

return equality_ast
elif isinstance(pattern, ast.MatchValue) and isinstance(subj, ast.Tuple):
return self.translate_match(subj, ast.MatchSequence(patterns=[pattern]))
elif isinstance(pattern, ast.MatchAs) and isinstance(subj, ast.Name):
elif isinstance(pattern, ast.MatchAs):
if pattern.name is not None:
self.handle_assign(
ast.Assign(
targets=[ast.Name(id=pattern.name, ctx=ast.Store())],
value=ast.Name(id=subj.id, ctx=ast.Load()),
value=subj,
)
)
return guard
Expand All @@ -234,10 +239,8 @@ def translate_match(
elif isinstance(pattern, ast.MatchSequence):
if isinstance(pattern.patterns[-1], ast.MatchStar):
raise ValueError("starred patterns are not supported.")
if isinstance(subj, ast.Tuple):
while len(subj.elts) > len(pattern.patterns):
pattern.patterns.append(ast.MatchValue(value=ast.Constant(value=None)))

if isinstance(subj, ast.Tuple):
# TODO: Use polars list operations in the future
left = self.translate_match(subj.elts[0], pattern.patterns[0], guard)
right = (
Expand Down Expand Up @@ -298,13 +301,27 @@ def handle_return(self, value: ast.expr):
self.node.orelse.handle_return(value)

def handle_match(self, stmt: ast.Match):
def is_catch_all(pattern: ast.pattern) -> bool:
return isinstance(pattern, ast.MatchAs) and pattern.name is None
def is_catch_all(case: ast.match_case) -> bool:
# We check if the case is a catch-all pattern without a guard
# If it has a guard, we treat it as a regular case
return (
isinstance(case.pattern, ast.MatchAs)
and case.pattern.name is None
and case.guard is None
)

def ignore_case(case: ast.match_case) -> bool:
# if the length of the pattern is not equal to the length of the subject, python ignores the case
return (
isinstance(case.pattern, ast.MatchSequence)
and isinstance(stmt.subject, ast.Tuple)
and len(stmt.subject.elts) != len(case.pattern.patterns)
) or (isinstance(case.pattern, ast.MatchValue) and isinstance(stmt.subject, ast.Tuple))

if isinstance(self.node, UnresolvedState):
# We can always rewrite catch-all patterns to orelse since python throws a SyntaxError if the catch-all pattern is not the last case.
orelse = next(
iter([case.body for case in stmt.cases if is_catch_all(case.pattern)]),
iter([case.body for case in stmt.cases if is_catch_all(case)]),
[],
)
self.node = ConditionalState(
Expand All @@ -319,7 +336,7 @@ def is_catch_all(pattern: ast.pattern) -> bool:
parse_body(case.body, copy(self.node.assignments)),
)
for case in stmt.cases
if not is_catch_all(case.pattern)
if not is_catch_all(case) and not ignore_case(case)
],
orelse=parse_body(
orelse,
Expand Down
39 changes: 37 additions & 2 deletions tests/functions_310.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def match_guarded_match_as(x):
return 3


def match_sequence_padded_length(x):
def match_sequence_padded_length_no_case(x):
y = 2
z = None

Expand All @@ -244,6 +244,38 @@ def match_sequence_padded_length(x):
return -1


def match_sequence_padded_length_return(x):
y = 1
z = 2

match x, y, z:
case 1, 2:
return 1
return -1


def match_sequence_padded_length(x):
y = 1
z = 2

match x, y, z:
case 1, 2:
return 1
case 3, 4:
return -1
case 1, 2, 3:
return 2
return -2


def match_guard_no_assignation(x):
match x:
case _ if x > 1:
return 0
case _:
return 2


functions_310 = [
nested_match,
match_assignments_inside_branch,
Expand All @@ -263,15 +295,18 @@ def match_sequence_padded_length(x):
match_complex_subject,
match_guarded_match_as,
match_sequence_padded_length,
match_guard_no_assignation,
]

xfail_functions_310 = [
match_mapping,
match_sequence_padded_length_no_case,
match_sequence_padded_length_return,
]

unsupported_functions_310 = [
(match_sequence_star, "starred patterns are not supported."),
(match_sequence, "Matching lists is not supported."),
(match_sequence_with_brackets, "Matching lists is not supported."),
(match_guarded_match_as_no_return, "return"),
(match_guarded_match_as_no_return, "Not all branches return"),
]

0 comments on commit 2850d44

Please sign in to comment.