diff --git a/src/codemodder/codemods/use_walrus_if.py b/src/codemodder/codemods/use_walrus_if.py index e4f1d295..f8754c3a 100644 --- a/src/codemodder/codemods/use_walrus_if.py +++ b/src/codemodder/codemods/use_walrus_if.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Optional import libcst as cst from libcst._position import CodeRange @@ -47,10 +47,12 @@ def rule(cls): """ _modify_next_if: List[Tuple[CodeRange, cst.Assign]] + _if_stack: List[Optional[Tuple[CodeRange, cst.Assign]]] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._modify_next_if = [] + self._if_stack = [] def add_change(self, position: CodeRange): self.file_context.codemod_changes.append( @@ -60,9 +62,14 @@ def add_change(self, position: CodeRange): ).to_json() ) + def visit_If(self, node): + self._if_stack.append( + self._modify_next_if.pop() if len(self._modify_next_if) else None + ) + def leave_If(self, original_node, updated_node): - if self._modify_next_if: - position, if_node = self._modify_next_if.pop() + if (result := self._if_stack.pop()) is not None: + position, if_node = result is_name = m.matches(updated_node.test, m.Name()) named_expr = cst.NamedExpr( target=if_node.targets[0].target, diff --git a/tests/codemods/test_walrus_if.py b/tests/codemods/test_walrus_if.py index 65367c59..b533039a 100644 --- a/tests/codemods/test_walrus_if.py +++ b/tests/codemods/test_walrus_if.py @@ -103,6 +103,21 @@ def test_walrus_if_nested(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected_output) + def test_walrus_if_used_inner(self, tmpdir): + """Make sure this works inside more complex code""" + input_code = """ +result = foo() +if result is not None: + if something_else(): + print(result) +""" + expected_output = """ +if (result := foo()) is not None: + if something_else(): + print(result) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + @pytest.mark.parametrize("space", ["", "\n"]) def test_with_whitespace(self, tmpdir, space): input_code = f"""