diff --git a/src/codemodder/codemods/api.py b/src/codemodder/codemods/api.py index 4c9c3c0b..8655ed4f 100644 --- a/src/codemodder/codemods/api.py +++ b/src/codemodder/codemods/api.py @@ -44,9 +44,11 @@ def __new__(cls, *args, **kwargs): return cls.codemod_base( metadata=cls.metadata, - detector=SemgrepRuleDetector(cls.detector_pattern) - if getattr(cls, "detector_pattern", None) - else None, + detector=( + SemgrepRuleDetector(cls.detector_pattern) + if getattr(cls, "detector_pattern", None) + else None + ), # This allows the transformer to inherit all the methods of the class itself transformer=LibcstTransformerPipeline(cls), ) diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index 9ec3cfbc..e3cc059e 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -82,13 +82,11 @@ def __init__( @property @abstractmethod - def origin(self) -> str: - ... + def origin(self) -> str: ... @property @abstractmethod - def docs_module_path(self) -> str: - ... + def docs_module_path(self) -> str: ... @property def name(self) -> str: diff --git a/src/codemodder/codemods/base_detector.py b/src/codemodder/codemods/base_detector.py index c1bae4d6..f3fb87d6 100644 --- a/src/codemodder/codemods/base_detector.py +++ b/src/codemodder/codemods/base_detector.py @@ -12,5 +12,4 @@ def apply( codemod_id: str, context: CodemodExecutionContext, files_to_analyze: list[Path], - ) -> ResultSet: - ... + ) -> ResultSet: ... diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index 2bbcbd27..108fdd39 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -26,12 +26,14 @@ def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]: """ # The current implementation covers some common cases and is in no way complete match node: - case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( - func=cst.Name("int") - ) | cst.Call(func=cst.Name("float")) | cst.Call( - func=cst.Name("abs") - ) | cst.Call( - func=cst.Name("len") + case ( + cst.Integer() + | cst.Imaginary() + | cst.Float() + | cst.Call(func=cst.Name("int")) + | cst.Call(func=cst.Name("float")) + | cst.Call(func=cst.Name("abs")) + | cst.Call(func=cst.Name("len")) ): return BaseType.NUMBER case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 8f5cee5f..1fe27f08 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -293,10 +293,11 @@ def is_value_of_assignment( """ parent = self.get_metadata(ParentNodeProvider, expr) match parent: - case cst.AnnAssign(value=value) | cst.Assign(value=value) | cst.WithItem( - item=value - ) | cst.NamedExpr( - value=value + case ( + cst.AnnAssign(value=value) + | cst.Assign(value=value) + | cst.WithItem(item=value) + | cst.NamedExpr(value=value) ) if expr == value: # type: ignore return parent return None @@ -448,9 +449,12 @@ class NameAndAncestorResolutionMixin(NameResolutionMixin, AncestorPatternsMixin) def extract_value(self, node: cst.AnnAssign | cst.Assign | cst.WithItem): match node: - case cst.AnnAssign(value=value) | cst.Assign(value=value) | cst.WithItem( - item=value - ) | cst.NamedExpr(value=value): + case ( + cst.AnnAssign(value=value) + | cst.Assign(value=value) + | cst.WithItem(item=value) + | cst.NamedExpr(value=value) + ): return value return None diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index f903948d..838b4b40 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -89,12 +89,14 @@ def _gather_and_update_params( ) add_annotation = add_annotation or annotation is not None updated_params.append( - updated.with_changes( - default=cst.Name("None"), - annotation=annotation, - ) - if needs_update - else updated, + ( + updated.with_changes( + default=cst.Name("None"), + annotation=annotation, + ) + if needs_update + else updated + ), ) return updated_params, new_var_decls, add_annotation diff --git a/src/core_codemods/literal_or_new_object_identity.py b/src/core_codemods/literal_or_new_object_identity.py index 32f907aa..3e91b85f 100644 --- a/src/core_codemods/literal_or_new_object_identity.py +++ b/src/core_codemods/literal_or_new_object_identity.py @@ -24,7 +24,18 @@ class LiteralOrNewObjectIdentity(SimpleCodemod, NameAndAncestorResolutionMixin): def _is_object_creation_or_literal(self, node: cst.BaseExpression): match node: - case cst.List() | cst.Dict() | cst.Tuple() | cst.Set() | cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString() | cst.ConcatenatedString() | cst.FormattedString(): + case ( + cst.List() + | cst.Dict() + | cst.Tuple() + | cst.Set() + | cst.Integer() + | cst.Float() + | cst.Imaginary() + | cst.SimpleString() + | cst.ConcatenatedString() + | cst.FormattedString() + ): return True case cst.Call(func=cst.Name() as name): return self.is_builtin_function(node) and name.value in ( diff --git a/src/core_codemods/refactor/refactor_new_api.py b/src/core_codemods/refactor/refactor_new_api.py index 372d5305..df86aeff 100644 --- a/src/core_codemods/refactor/refactor_new_api.py +++ b/src/core_codemods/refactor/refactor_new_api.py @@ -172,13 +172,15 @@ def leave_ImportFrom(self, original: cst.ImportFrom, updated: cst.ImportFrom): def leave_ClassDef(self, original: cst.ClassDef, new: cst.ClassDef) -> cst.ClassDef: new_bases: list[cst.Arg] = [ - base.with_changes(value=cst.Name(self.new_api_class)) - if self.find_base_name(base.value) - in ( - "codemodder.codemods.api.BaseCodemod", - "codemodder.codemods.api.SemgrepCodemod", + ( + base.with_changes(value=cst.Name(self.new_api_class)) + if self.find_base_name(base.value) + in ( + "codemodder.codemods.api.BaseCodemod", + "codemodder.codemods.api.SemgrepCodemod", + ) + else base ) - else base for base in original.bases ] diff --git a/src/core_codemods/remove_module_global.py b/src/core_codemods/remove_module_global.py index eff23af5..95dc07cd 100644 --- a/src/core_codemods/remove_module_global.py +++ b/src/core_codemods/remove_module_global.py @@ -18,7 +18,10 @@ def leave_Global( self, original_node: cst.Global, updated_node: cst.Global, - ) -> Union[cst.Global, cst.RemovalSentinel,]: + ) -> Union[ + cst.Global, + cst.RemovalSentinel, + ]: if not self.filter_by_path_includes_or_excludes( self.node_position(original_node) ): diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index df115eb1..636abaf2 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -492,9 +492,9 @@ def leave_Call(self, original_node: cst.Call) -> None: first_arg.value.visit(query_visitor) for expr in query_visitor.leaves: match expr: - case cst.SimpleString() | cst.FormattedStringText() if self._has_keyword( - expr.value - ): + case ( + cst.SimpleString() | cst.FormattedStringText() + ) if self._has_keyword(expr.value): self.calls[original_node] = query_visitor.leaves diff --git a/tests/codemods/test_combine_startswith_endswith.py b/tests/codemods/test_combine_startswith_endswith.py index be8ec368..8363732d 100644 --- a/tests/codemods/test_combine_startswith_endswith.py +++ b/tests/codemods/test_combine_startswith_endswith.py @@ -38,7 +38,9 @@ def test_no_change(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ x = "foo" x.startswith("foo") or x.startswith("f") """ diff --git a/tests/codemods/test_enable_jinja2_autoescape.py b/tests/codemods/test_enable_jinja2_autoescape.py index 9d0eafe4..5998e5b3 100644 --- a/tests/codemods/test_enable_jinja2_autoescape.py +++ b/tests/codemods/test_enable_jinja2_autoescape.py @@ -133,7 +133,9 @@ def test_aiohttp_import_alias(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected_output) def test_aiohttp_import_alias_no_change(self, tmpdir): - expected_output = input_code = """ + expected_output = ( + input_code + ) = """ from aiohttp_jinja2 import foo as setup setup_jinja2(app) """ diff --git a/tests/codemods/test_exception_without_raise.py b/tests/codemods/test_exception_without_raise.py index 653d062f..320c2562 100644 --- a/tests/codemods/test_exception_without_raise.py +++ b/tests/codemods/test_exception_without_raise.py @@ -51,7 +51,9 @@ def test_raised_exception(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ print(1) ValueError("Bad value!") """ diff --git a/tests/codemods/test_fix_assert_tuple.py b/tests/codemods/test_fix_assert_tuple.py index b4412fd4..dd361366 100644 --- a/tests/codemods/test_fix_assert_tuple.py +++ b/tests/codemods/test_fix_assert_tuple.py @@ -77,7 +77,9 @@ def test_no_change(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ assert (1, 2) """ lines_to_exclude = [1] diff --git a/tests/codemods/test_fix_deprecated_abstractproperty.py b/tests/codemods/test_fix_deprecated_abstractproperty.py index 4645b729..be0ead2a 100644 --- a/tests/codemods/test_fix_deprecated_abstractproperty.py +++ b/tests/codemods/test_fix_deprecated_abstractproperty.py @@ -66,7 +66,9 @@ def foo(self): self.run_and_assert(tmpdir, original_code, new_code) def test_different_abstractproperty(self, tmpdir): - new_code = original_code = """ + new_code = ( + original_code + ) = """ from xyz import abstractproperty class A: @@ -123,7 +125,9 @@ def foo(self): self.run_and_assert(tmpdir, original_code, new_code) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ import abc class A: diff --git a/tests/codemods/test_fix_empty_sequence_comparison.py b/tests/codemods/test_fix_empty_sequence_comparison.py index e3019ef7..b8bcd7ae 100644 --- a/tests/codemods/test_fix_empty_sequence_comparison.py +++ b/tests/codemods/test_fix_empty_sequence_comparison.py @@ -228,7 +228,9 @@ def test_no_change(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ x = [1] if x != []: pass diff --git a/tests/codemods/test_harden_pyyaml.py b/tests/codemods/test_harden_pyyaml.py index cec378ba..6ec5f03e 100644 --- a/tests/codemods/test_harden_pyyaml.py +++ b/tests/codemods/test_harden_pyyaml.py @@ -63,7 +63,9 @@ def test_import_alias(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected) def test_preserve_custom_loader(self, tmpdir): - expected = input_code = """ + expected = ( + input_code + ) = """ import yaml from custom import CustomLoader @@ -73,7 +75,9 @@ def test_preserve_custom_loader(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected) def test_preserve_custom_loader_kwarg(self, tmpdir): - expected = input_code = """ + expected = ( + input_code + ) = """ import yaml from custom import CustomLoader diff --git a/tests/codemods/test_remove_debug_breakpoint.py b/tests/codemods/test_remove_debug_breakpoint.py index c97597ae..c7fbf102 100644 --- a/tests/codemods/test_remove_debug_breakpoint.py +++ b/tests/codemods/test_remove_debug_breakpoint.py @@ -82,7 +82,9 @@ def something(): self.run_and_assert(tmpdir, input_code, expected) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ x = "foo" breakpoint() """ diff --git a/tests/codemods/test_remove_unnecessary_f_str.py b/tests/codemods/test_remove_unnecessary_f_str.py index e2b31c91..e6c0ed82 100644 --- a/tests/codemods/test_remove_unnecessary_f_str.py +++ b/tests/codemods/test_remove_unnecessary_f_str.py @@ -32,7 +32,9 @@ def test_change(self, tmpdir): self.run_and_assert(tmpdir, before, after, num_changes=3) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ bad: str = f"bad" + "bad" """ lines_to_exclude = [1] diff --git a/tests/codemods/test_subprocess_shell_false.py b/tests/codemods/test_subprocess_shell_false.py index 71a4ff2f..f4e44c9b 100644 --- a/tests/codemods/test_subprocess_shell_false.py +++ b/tests/codemods/test_subprocess_shell_false.py @@ -54,7 +54,9 @@ def test_shell_False(self, tmpdir, func): self.run_and_assert(tmpdir, input_code, input_code) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ import subprocess subprocess.run(args, shell=True) """ diff --git a/tests/codemods/test_url_sandbox.py b/tests/codemods/test_url_sandbox.py index 800649ce..df262547 100644 --- a/tests/codemods/test_url_sandbox.py +++ b/tests/codemods/test_url_sandbox.py @@ -188,7 +188,9 @@ def test_requests_with_alias(self, add_dependency, tmpdir): add_dependency.assert_called_once_with(Security) def test_ignore_hardcoded(self, _, tmpdir): - expected = input_code = """ + expected = ( + input_code + ) = """ import requests requests.get("www.google.com") @@ -197,7 +199,9 @@ def test_ignore_hardcoded(self, _, tmpdir): self.run_and_assert(tmpdir, input_code, expected) def test_ignore_hardcoded_from_global_variable(self, _, tmpdir): - expected = input_code = """ + expected = ( + input_code + ) = """ import requests URL = "www.google.com" @@ -207,7 +211,9 @@ def test_ignore_hardcoded_from_global_variable(self, _, tmpdir): self.run_and_assert(tmpdir, input_code, expected) def test_ignore_hardcoded_from_local_variable(self, _, tmpdir): - expected = input_code = """ + expected = ( + input_code + ) = """ import requests def foo(): @@ -218,7 +224,9 @@ def foo(): self.run_and_assert(tmpdir, input_code, expected) def test_ignore_hardcoded_from_local_variable_transitive(self, _, tmpdir): - expected = input_code = """ + expected = ( + input_code + ) = """ import requests def foo(): diff --git a/tests/codemods/test_use_generator.py b/tests/codemods/test_use_generator.py index 9b9947ea..9746a9ee 100644 --- a/tests/codemods/test_use_generator.py +++ b/tests/codemods/test_use_generator.py @@ -18,20 +18,26 @@ def test_list_comprehension(self, tmpdir, func): self.run_and_assert(tmpdir, original_code, new_code) def test_not_special_builtin(self, tmpdir): - expected = original_code = """ + expected = ( + original_code + ) = """ x = some([i for i in range(10)]) """ self.run_and_assert(tmpdir, original_code, expected) def test_not_global_function(self, tmpdir): - expected = original_code = """ + expected = ( + original_code + ) = """ from foo import any x = any([i for i in range(10)]) """ self.run_and_assert(tmpdir, original_code, expected) def test_exclude_line(self, tmpdir): - input_code = expected = """\ + input_code = ( + expected + ) = """\ x = any([i for i in range(10)]) """ lines_to_exclude = [1]