-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
255 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from typing import List, Tuple | ||
|
||
import libcst as cst | ||
from libcst._position import CodeRange | ||
from libcst import matchers as m | ||
from libcst.metadata import ParentNodeProvider, ScopeProvider | ||
|
||
from codemodder.change import Change | ||
from codemodder.codemods.base_codemod import ReviewGuidance | ||
from codemodder.codemods.utils_mixin import NameResolutionMixin | ||
from codemodder.codemods.api import SemgrepCodemod | ||
|
||
|
||
class UseWalrusIf(SemgrepCodemod, NameResolutionMixin): | ||
METADATA_DEPENDENCIES = SemgrepCodemod.METADATA_DEPENDENCIES + ( | ||
ParentNodeProvider, | ||
ScopeProvider, | ||
) | ||
NAME = "use-walrus-if" | ||
SUMMARY = ( | ||
"Replaces multiple expressions involving `if` operator with 'walrus' operator" | ||
) | ||
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW | ||
DESCRIPTION = ( | ||
"Replaces multiple expressions involving `if` operator with 'walrus' operator" | ||
) | ||
|
||
@classmethod | ||
def rule(cls): | ||
return """ | ||
rules: | ||
- patterns: | ||
- pattern: | | ||
$ASSIGN | ||
if $COND: | ||
$BODY | ||
- metavariable-pattern: | ||
metavariable: $ASSIGN | ||
patterns: | ||
- pattern: $VAR = $RHS | ||
- metavariable-pattern: | ||
metavariable: $COND | ||
patterns: | ||
- pattern: $VAR | ||
- metavariable-pattern: | ||
metavariable: $BODY | ||
pattern: $VAR | ||
- focus-metavariable: $ASSIGN | ||
""" | ||
|
||
_modify_next_if: List[Tuple[CodeRange, cst.Assign]] | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self._modify_next_if = [] | ||
|
||
def add_change(self, position: CodeRange): | ||
self.file_context.codemod_changes.append( | ||
Change( | ||
lineNumber=position.start.line, | ||
description=self.CHANGE_DESCRIPTION, | ||
).to_json() | ||
) | ||
|
||
def leave_If(self, original_node, updated_node): | ||
if self._modify_next_if: | ||
position, if_node = self._modify_next_if.pop() | ||
is_name = m.matches(updated_node.test, m.Name()) | ||
named_expr = cst.NamedExpr( | ||
target=if_node.targets[0].target, | ||
value=if_node.value, | ||
lpar=[] if is_name else [cst.LeftParen()], | ||
rpar=[] if is_name else [cst.RightParen()], | ||
) | ||
self.add_change(position) | ||
return ( | ||
updated_node.with_changes(test=named_expr) | ||
if is_name | ||
else updated_node.with_changes( | ||
test=updated_node.test.with_changes(left=named_expr) | ||
) | ||
) | ||
|
||
return original_node | ||
|
||
def _is_valid_modification(self, node): | ||
""" | ||
Restricts the kind of modifications we can make to the AST. | ||
This is necessary since the semgrep rule can't fully encode this restriction. | ||
""" | ||
if parent := self.get_metadata(ParentNodeProvider, node): | ||
if gparent := self.get_metadata(ParentNodeProvider, parent): | ||
if (idx := gparent.children.index(parent)) >= 0: | ||
return m.matches( | ||
gparent.children[idx + 1], | ||
m.If(test=(m.Name() | m.Comparison(left=m.Name()))), | ||
) | ||
return False | ||
|
||
def leave_Assign(self, original_node, updated_node): | ||
if self.node_is_selected(original_node): | ||
if self._is_valid_modification(original_node): | ||
position = self.node_position(original_node) | ||
self._modify_next_if.append((position, updated_node)) | ||
return cst.RemoveFromParent() | ||
|
||
return original_node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import pytest | ||
|
||
from codemodder.codemods.use_walrus_if import UseWalrusIf | ||
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest | ||
|
||
|
||
class TestUseWalrusIf(BaseSemgrepCodemodTest): | ||
codemod = UseWalrusIf | ||
|
||
@pytest.mark.parametrize( | ||
"condition", | ||
[ | ||
"is None", | ||
"is not None", | ||
"== 42", | ||
'!= "bar"', | ||
], | ||
) | ||
def test_simple_use_walrus_if(self, tmpdir, condition): | ||
input_code = f""" | ||
val = do_something() | ||
if val {condition}: | ||
do_something_else(val) | ||
""" | ||
expected_output = f""" | ||
if (val := do_something()) {condition}: | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, expected_output) | ||
|
||
def test_walrus_if_name_only(self, tmpdir): | ||
input_code = """ | ||
val = do_something() | ||
if val: | ||
do_something_else(val) | ||
""" | ||
expected_output = """ | ||
if val := do_something(): | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, expected_output) | ||
|
||
def test_walrus_if_multiple(self, tmpdir): | ||
input_code = """ | ||
val = do_something() | ||
if val is not None: | ||
do_something_else(val) | ||
foo = hello() | ||
if foo == "bar": | ||
whatever(foo) | ||
""" | ||
# TODO: not sure why libcst isn't preserving empty lines | ||
expected_output = """ | ||
if (val := do_something()) is not None: | ||
do_something_else(val) | ||
if (foo := hello()) == "bar": | ||
whatever(foo) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, expected_output) | ||
|
||
def test_walrus_if_in_function(self, tmpdir): | ||
"""Make sure this works inside more complex code""" | ||
input_code = """ | ||
def foo(): | ||
val = do_something() | ||
if val is not None: | ||
do_something_else(val) | ||
""" | ||
expected_output = """ | ||
def foo(): | ||
if (val := do_something()) is not None: | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, expected_output) | ||
|
||
def test_walrus_if_nested(self, tmpdir): | ||
"""Make sure this works inside more complex code""" | ||
input_code = """ | ||
x = do_something() | ||
if x is not None: | ||
y = do_something_else(x) | ||
if y is not None: | ||
bizbaz(x, y) | ||
""" | ||
expected_output = """ | ||
if (x := do_something()) is not None: | ||
if (y := do_something_else(x)) is not None: | ||
bizbaz(x, y) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, expected_output) | ||
|
||
@pytest.mark.parametrize("space", ["", "\n"]) | ||
def test_with_whitespace(self, tmpdir, space): | ||
input_code = f""" | ||
val = do_something(){space} | ||
if val is not None: | ||
do_something_else(val) | ||
""" | ||
expected_output = f""" | ||
{space}if (val := do_something()) is not None: | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, expected_output) | ||
|
||
def test_dont_apply_walrus_different_variable(self, tmpdir): | ||
input_code = """ | ||
val = do_something() | ||
if foo is not None: | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, input_code) | ||
|
||
def test_dont_apply_walrus_method_call(self, tmpdir): | ||
input_code = """ | ||
val = do_something() | ||
if val.method() is not None: | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, input_code) | ||
|
||
def test_dont_apply_walrus_call_with_func(self, tmpdir): | ||
input_code = """ | ||
val = do_something() | ||
if woot(val) is not None: | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, input_code) | ||
|
||
def test_dont_apply_walrus_expr(self, tmpdir): | ||
input_code = """ | ||
val = do_something() | ||
if val + 42 is not None: | ||
do_something_else(val) | ||
""" | ||
self.run_and_assert(tmpdir, input_code, input_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
bad = f"hello" | ||
bad = "hello" | ||
good = f"{2+3}" |