diff --git a/ci_tests/test_webgoat_findings.py b/ci_tests/test_webgoat_findings.py index 5b3d893e..8ce1d1ad 100644 --- a/ci_tests/test_webgoat_findings.py +++ b/ci_tests/test_webgoat_findings.py @@ -12,6 +12,7 @@ "pixee:python/harden-pyyaml", "pixee:python/django-debug-flag-on", "pixee:python/url-sandbox", + "pixee:python/use-walrus-if", ] diff --git a/integration_tests/test_use_walrus_if.py b/integration_tests/test_use_walrus_if.py new file mode 100644 index 00000000..702c8c26 --- /dev/null +++ b/integration_tests/test_use_walrus_if.py @@ -0,0 +1,32 @@ +from codemodder.codemods.use_walrus_if import UseWalrusIf +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestUseWalrusIf(BaseIntegrationTest): + codemod = UseWalrusIf + code_path = "tests/samples/use_walrus_if.py" + original_code, _ = original_and_expected_from_code_path(code_path, []) + expected_new_code = """ +if (x := foo()) is not None: + print(x) + +if y := bar(): + print(y) + +z = baz() +print(z) + + +def whatever(): + if (b := biz()) == 10: + print(b) +""".lstrip() + + expected_diff = "--- \n+++ \n@@ -1,9 +1,7 @@\n-x = foo()\n-if x is not None:\n+if (x := foo()) is not None:\n print(x)\n \n-y = bar()\n-if y:\n+if y := bar():\n print(y)\n \n z = baz()\n@@ -11,6 +9,5 @@\n \n \n def whatever():\n- b = biz()\n- if b == 10:\n+ if (b := biz()) == 10:\n print(b)\n" + + num_changes = 3 + expected_line_change = 1 + change_description = UseWalrusIf.CHANGE_DESCRIPTION diff --git a/src/codemodder/codemods/__init__.py b/src/codemodder/codemods/__init__.py index 1ce4261c..36761be7 100644 --- a/src/codemodder/codemods/__init__.py +++ b/src/codemodder/codemods/__init__.py @@ -18,6 +18,7 @@ from codemodder.codemods.remove_unnecessary_f_str import RemoveUnnecessaryFStr from codemodder.codemods.tempfile_mktemp import TempfileMktemp from codemodder.codemods.requests_verify import RequestsVerify +from codemodder.codemods.use_walrus_if import UseWalrusIf DEFAULT_CODEMODS = { DjangoDebugFlagOn, @@ -36,6 +37,7 @@ UrlSandbox, TempfileMktemp, RequestsVerify, + UseWalrusIf, } ALL_CODEMODS = DEFAULT_CODEMODS diff --git a/src/codemodder/codemods/api/__init__.py b/src/codemodder/codemods/api/__init__.py index c4944d5e..39eb1fa5 100644 --- a/src/codemodder/codemods/api/__init__.py +++ b/src/codemodder/codemods/api/__init__.py @@ -119,10 +119,7 @@ def __init__( # similar when they define their `on_result_found` method. # Right now this is just to demonstrate a particular use case. def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): - pos_to_match = self.node_position(original_node) - if self.filter_by_result( - pos_to_match - ) and self.filter_by_path_includes_or_excludes(pos_to_match): + if self.node_is_selected(original_node): self.report_change(original_node) if (attr := getattr(self, "on_result_found", None)) is not None: # pylint: disable=not-callable diff --git a/src/codemodder/codemods/base_visitor.py b/src/codemodder/codemods/base_visitor.py index 81b36c63..cac304ae 100644 --- a/src/codemodder/codemods/base_visitor.py +++ b/src/codemodder/codemods/base_visitor.py @@ -25,6 +25,12 @@ def filter_by_path_includes_or_excludes(self, pos_to_match): return any(match_line(pos_to_match, line) for line in self.line_include) return True + def node_is_selected(self, node) -> bool: + pos_to_match = self.node_position(node) + return self.filter_by_result( + pos_to_match + ) and self.filter_by_path_includes_or_excludes(pos_to_match) + def node_position(self, node): # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 return self.get_metadata(self.METADATA_DEPENDENCIES[0], node) diff --git a/src/codemodder/codemods/use_walrus_if.py b/src/codemodder/codemods/use_walrus_if.py new file mode 100644 index 00000000..0a264747 --- /dev/null +++ b/src/codemodder/codemods/use_walrus_if.py @@ -0,0 +1,141 @@ +from typing import List, Tuple + +import libcst as cst +from libcst._position import CodeRange +from libcst import matchers as m +from libcst.metadata import ParentNodeProvider, ScopeProvider + +from codemodder.change import Change +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.utils_mixin import NameResolutionMixin +from codemodder.codemods.api import SemgrepCodemod + + +class UseWalrusIf(SemgrepCodemod, NameResolutionMixin): + METADATA_DEPENDENCIES = SemgrepCodemod.METADATA_DEPENDENCIES + ( + ParentNodeProvider, + ScopeProvider, + ) + NAME = "use-walrus-if" + SUMMARY = ( + "Replaces multiple expressions involving `if` operator with 'walrus' operator" + ) + REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW + DESCRIPTION = ( + "Replaces multiple expressions involving `if` operator with 'walrus' operator" + ) + + @classmethod + def rule(cls): + return """ + rules: + - patterns: + - pattern: | + $ASSIGN + if $COND: + $BODY + - metavariable-pattern: + metavariable: $ASSIGN + patterns: + - pattern: $VAR = $RHS + - metavariable-pattern: + metavariable: $COND + patterns: + - pattern: $VAR + - metavariable-pattern: + metavariable: $BODY + pattern: $VAR + - focus-metavariable: $ASSIGN + """ + + _modify_next_if: List[Tuple[CodeRange, cst.Assign]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modify_next_if = [] + + def add_change(self, position: CodeRange): + self.file_context.codemod_changes.append( + Change( + lineNumber=position.start.line, + description=self.CHANGE_DESCRIPTION, + ).to_json() + ) + + def leave_If(self, original_node, updated_node): + if self._modify_next_if: + position, if_node = self._modify_next_if.pop() + is_name = m.matches(updated_node.test, m.Name()) + named_expr = cst.NamedExpr( + target=if_node.targets[0].target, + value=if_node.value, + lpar=[] if is_name else [cst.LeftParen()], + rpar=[] if is_name else [cst.RightParen()], + ) + self.add_change(position) + return ( + updated_node.with_changes(test=named_expr) + if is_name + else updated_node.with_changes( + test=updated_node.test.with_changes(left=named_expr) + ) + ) + + return original_node + + def _is_valid_modification(self, node): + """ + Restricts the kind of modifications we can make to the AST. + + This is necessary since the semgrep rule can't fully encode this restriction. + """ + if parent := self.get_metadata(ParentNodeProvider, node): + if gparent := self.get_metadata(ParentNodeProvider, parent): + if (idx := gparent.children.index(parent)) >= 0: + return m.matches( + gparent.children[idx + 1], + m.If(test=(m.Name() | m.Comparison(left=m.Name()))), + ) + return False + + def leave_Assign(self, original_node, updated_node): + if self.node_is_selected(original_node): + if self._is_valid_modification(original_node): + position = self.node_position(original_node) + self._modify_next_if.append((position, updated_node)) + return cst.RemoveFromParent() + + return original_node + + def leave_SimpleStatementLine(self, original_node, updated_node): + """ + Preserves the whitespace and comments in the line when all children are removed. + + This feels like a bug in libCST but we'll work around it for now. + """ + if not updated_node.body: + trailing_whitespace = ( + ( + original_node.trailing_whitespace.with_changes( + whitespace=cst.SimpleWhitespace(""), + ), + ) + if original_node.trailing_whitespace.comment + else () + ) + # NOTE: The effect of this is to preserve the + # whitespace and comments. However, the type expected by + # cst.Module.body is Sequence[Union[SimpleStatementLine, BaseCompoundStatement]]. + # So technically this violates the expected return type since we + # are not adding a new SimpleStatementLine but instead just bare + # EmptyLine and Comment nodes. + # A more correct solution would involve transferring any whitespace + # and comments to the subsequent SimpleStatementLine (which + # contains the If statement), but this would require a lot more + # state management to fit within the visitor pattern. We should + # revisit this at some point later. + return cst.FlattenSentinel( + original_node.leading_lines + trailing_whitespace + ) + + return updated_node diff --git a/tests/codemods/test_walrus_if.py b/tests/codemods/test_walrus_if.py new file mode 100644 index 00000000..65367c59 --- /dev/null +++ b/tests/codemods/test_walrus_if.py @@ -0,0 +1,150 @@ +import pytest + +from codemodder.codemods.use_walrus_if import UseWalrusIf +from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest + + +class TestUseWalrusIf(BaseSemgrepCodemodTest): + codemod = UseWalrusIf + + @pytest.mark.parametrize( + "condition", + [ + "is None", + "is not None", + "== 42", + '!= "bar"', + ], + ) + def test_simple_use_walrus_if(self, tmpdir, condition): + input_code = f""" +val = do_something() +if val {condition}: + do_something_else(val) +""" + expected_output = f""" +if (val := do_something()) {condition}: + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_walrus_if_name_only(self, tmpdir): + input_code = """ +val = do_something() +if val: + do_something_else(val) +""" + expected_output = """ +if val := do_something(): + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_walrus_if_preserve_comments(self, tmpdir): + input_code = """ +val = do_something() # comment +if val is not None: # another comment + do_something_else(val) +""" + expected_output = """ +# comment +if (val := do_something()) is not None: # another comment + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_walrus_if_multiple(self, tmpdir): + input_code = """ +val = do_something() +if val is not None: + do_something_else(val) + +foo = hello() +if foo == "bar": + whatever(foo) +""" + expected_output = """ +if (val := do_something()) is not None: + do_something_else(val) + +if (foo := hello()) == "bar": + whatever(foo) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_walrus_if_in_function(self, tmpdir): + """Make sure this works inside more complex code""" + input_code = """ +def foo(): + val = do_something() + if val is not None: + do_something_else(val) +""" + expected_output = """ +def foo(): + if (val := do_something()) is not None: + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_walrus_if_nested(self, tmpdir): + """Make sure this works inside more complex code""" + input_code = """ +x = do_something() +if x is not None: + y = do_something_else(x) + if y is not None: + bizbaz(x, y) +""" + expected_output = """ +if (x := do_something()) is not None: + if (y := do_something_else(x)) is not None: + bizbaz(x, y) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + @pytest.mark.parametrize("space", ["", "\n"]) + def test_with_whitespace(self, tmpdir, space): + input_code = f""" +val = do_something(){space} +if val is not None: + do_something_else(val) +""" + expected_output = f""" +{space}if (val := do_something()) is not None: + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + @pytest.mark.parametrize("variable", ["foo", "value", "oval"]) + def test_dont_apply_walrus_different_variable(self, tmpdir, variable): + input_code = f""" +val = do_something() +if {variable} is not None: + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, input_code) + + def test_dont_apply_walrus_method_call(self, tmpdir): + input_code = """ +val = do_something() +if val.method() is not None: + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, input_code) + + def test_dont_apply_walrus_call_with_func(self, tmpdir): + input_code = """ +val = do_something() +if woot(val) is not None: + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, input_code) + + def test_dont_apply_walrus_expr(self, tmpdir): + input_code = """ +val = do_something() +if val + 42 is not None: + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, input_code) diff --git a/tests/samples/use_walrus_if.py b/tests/samples/use_walrus_if.py new file mode 100644 index 00000000..df591dfd --- /dev/null +++ b/tests/samples/use_walrus_if.py @@ -0,0 +1,16 @@ +x = foo() +if x is not None: + print(x) + +y = bar() +if y: + print(y) + +z = baz() +print(z) + + +def whatever(): + b = biz() + if b == 10: + print(b)