Skip to content

Commit

Permalink
Add test and fix edge case in use-walrus-if
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Sep 25, 2023
1 parent 92c2b8a commit 03a1d7e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/codemodder/codemods/use_walrus_if.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional

import libcst as cst
from libcst._position import CodeRange
Expand Down Expand Up @@ -47,10 +47,12 @@ def rule(cls):
"""

_modify_next_if: List[Tuple[CodeRange, cst.Assign]]
_if_stack: List[Optional[Tuple[CodeRange, cst.Assign]]]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modify_next_if = []
self._if_stack = []

def add_change(self, position: CodeRange):
self.file_context.codemod_changes.append(
Expand All @@ -60,9 +62,14 @@ def add_change(self, position: CodeRange):
).to_json()
)

def visit_If(self, node):
self._if_stack.append(
self._modify_next_if.pop() if len(self._modify_next_if) else None
)

def leave_If(self, original_node, updated_node):
if self._modify_next_if:
position, if_node = self._modify_next_if.pop()
if (result := self._if_stack.pop()) is not None:
position, if_node = result
is_name = m.matches(updated_node.test, m.Name())
named_expr = cst.NamedExpr(
target=if_node.targets[0].target,
Expand Down
15 changes: 15 additions & 0 deletions tests/codemods/test_walrus_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,21 @@ def test_walrus_if_nested(self, tmpdir):
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_used_inner(self, tmpdir):
"""Make sure this works inside more complex code"""
input_code = """
result = foo()
if result is not None:
if something_else():
print(result)
"""
expected_output = """
if (result := foo()) is not None:
if something_else():
print(result)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("space", ["", "\n"])
def test_with_whitespace(self, tmpdir, space):
input_code = f"""
Expand Down

0 comments on commit 03a1d7e

Please sign in to comment.