From a8e36d0895a8be338609814310613f3a5e2db3a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Silva?= <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 19 Feb 2024 12:18:16 -0300 Subject: [PATCH] SQLQueryParameterization will now correctly parameterize names (#279) * SQLQueryParameterization will now correctly parameterize names Also refactored to use mixins functions. * Hardening suggestions for codemodder-python / sqlp-fix (#280) Use Assignment Expression (Walrus) In Conditional Co-authored-by: pixeebot[bot] <23113631+pixeebot@users.noreply.github.com> --------- Co-authored-by: pixeebot[bot] <104101892+pixeebot[bot]@users.noreply.github.com> Co-authored-by: pixeebot[bot] <23113631+pixeebot@users.noreply.github.com> --- .../test_sql_parameterization.py | 7 +- src/codemodder/codemods/base_visitor.py | 6 +- src/codemodder/codemods/utils_mixin.py | 73 ++++++- src/core_codemods/file_resource_leak.py | 70 +------ .../secure_flask_session_config.py | 7 +- src/core_codemods/sql_parameterization.py | 193 ++++++++++++------ src/core_codemods/subprocess_shell_false.py | 5 +- src/core_codemods/use_walrus_if.py | 3 +- tests/codemods/test_sql_parameterization.py | 25 +++ 9 files changed, 244 insertions(+), 145 deletions(-) diff --git a/integration_tests/test_sql_parameterization.py b/integration_tests/test_sql_parameterization.py index 99242d33..bd3b844e 100644 --- a/integration_tests/test_sql_parameterization.py +++ b/integration_tests/test_sql_parameterization.py @@ -1,4 +1,7 @@ -from core_codemods.sql_parameterization import SQLQueryParameterization +from core_codemods.sql_parameterization import ( + SQLQueryParameterization, + SQLQueryParameterizationTransformer, +) from integration_tests.base_test import ( BaseIntegrationTest, original_and_expected_from_code_path, @@ -37,5 +40,5 @@ class TestSQLQueryParameterization(BaseIntegrationTest): # fmt: on expected_line_change = "12" - change_description = SQLQueryParameterization.change_description + change_description = SQLQueryParameterizationTransformer.change_description num_changed_files = 1 diff --git a/src/codemodder/codemods/base_visitor.py b/src/codemodder/codemods/base_visitor.py index cb0bbb26..430b4cdb 100644 --- a/src/codemodder/codemods/base_visitor.py +++ b/src/codemodder/codemods/base_visitor.py @@ -1,14 +1,14 @@ -from typing import Any, Tuple +from typing import ClassVar, Collection from libcst import MetadataDependent from libcst.codemod import ContextAwareVisitor, VisitorBasedCodemodCommand -from libcst.metadata import PositionProvider +from libcst.metadata import PositionProvider, ProviderT from codemodder.result import Result # TODO: this should just be part of BaseTransformer and BaseVisitor? class UtilsMixin(MetadataDependent): - METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,) + METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (PositionProvider,) def __init__( self, diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 729fa761..ee64e699 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -1,5 +1,5 @@ import itertools -from typing import Any, Collection, Optional, Tuple, Union +from typing import ClassVar, Collection, Optional, Union import libcst as cst from libcst import MetadataDependent, matchers from libcst.helpers import get_full_name_for_node @@ -10,14 +10,17 @@ BuiltinAssignment, ImportAssignment, ParentNodeProvider, + ProviderT, Scope, ScopeProvider, ) from libcst.metadata.scope_provider import GlobalScope +from codemodder.utils.utils import extract_targets_of_assignment + class NameResolutionMixin(MetadataDependent): - METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,) + METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (ScopeProvider,) def _find_imported_name(self, node: cst.Name) -> Optional[str]: match self.find_single_assignment(node): @@ -280,7 +283,7 @@ def find_accesses(self, node) -> Collection[Access]: class AncestorPatternsMixin(MetadataDependent): - METADATA_DEPENDENCIES: Tuple[Any, ...] = (ParentNodeProvider,) + METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (ParentNodeProvider,) def is_value_of_assignment( self, expr @@ -449,12 +452,10 @@ def get_parent(self, node: cst.CSTNode) -> Optional[cst.CSTNode]: class NameAndAncestorResolutionMixin(NameResolutionMixin, AncestorPatternsMixin): - METADATA_DEPENDENCIES: Tuple[Any, ...] = ( - ScopeProvider, - ParentNodeProvider, - ) - def extract_value(self, node: cst.AnnAssign | cst.Assign | cst.WithItem): + def extract_value( + self, node: cst.AnnAssign | cst.Assign | cst.WithItem | cst.NamedExpr + ): match node: case ( cst.AnnAssign(value=value) @@ -467,7 +468,7 @@ def extract_value(self, node: cst.AnnAssign | cst.Assign | cst.WithItem): def resolve_expression(self, node: cst.BaseExpression) -> cst.BaseExpression: """ - If the expression is a Name, transitively resolves the name to another type of expression. Otherwise returns self. + If the expression is a Name, transitively resolves the name to another expression through single assignments. Otherwise returns self. """ maybe_expr = None match node: @@ -490,6 +491,60 @@ def _resolve_name_transitive(self, node: cst.Name) -> Optional[cst.BaseExpressio return value return None + def _find_direct_name_assignment_targets( + self, name: cst.Name + ) -> list[cst.BaseAssignTargetExpression]: + name_targets = [] + accesses = self.find_accesses(name) + for node in (access.node for access in accesses): + if maybe_assigned := self.is_value_of_assignment(node): + targets = extract_targets_of_assignment(maybe_assigned) + name_targets.extend(targets) + return name_targets + + def _find_name_assignment_targets( + self, name: cst.Name + ) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]: + named_targets, other_targets = self._sieve_targets( + self._find_direct_name_assignment_targets(name) + ) + + for child in named_targets: + c_named_targets, c_other_targets = self._find_name_assignment_targets(child) + named_targets.extend(c_named_targets) + other_targets.extend(c_other_targets) + return named_targets, other_targets + + def _sieve_targets( + self, targets + ) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]: + named_targets = [] + other_targets = [] + for t in targets: + # TODO maybe consider subscript here for named_targets + if isinstance(t, cst.Name): + named_targets.append(t) + else: + other_targets.append(t) + return named_targets, other_targets + + def find_transitive_assignment_targets( + self, expr + ) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]: + """ + Returns all the targets that an expression can reach. It returns a pair of lists, where the first list contains all targets that are Name, and the second contains all others. + """ + if maybe_assigned := self.is_value_of_assignment(expr): + named_targets, other_targets = self._sieve_targets( + extract_targets_of_assignment(maybe_assigned) + ) + for n in named_targets: + n_named_targets, n_other_targets = self._find_name_assignment_targets(n) + named_targets.extend(n_named_targets) + other_targets.extend(n_other_targets) + return named_targets, other_targets + return ([], []) + def iterate_left_expressions(node: cst.BaseExpression): yield node diff --git a/src/core_codemods/file_resource_leak.py b/src/core_codemods/file_resource_leak.py index 5b251d10..7bf189a7 100644 --- a/src/core_codemods/file_resource_leak.py +++ b/src/core_codemods/file_resource_leak.py @@ -4,7 +4,6 @@ LibcstTransformerPipeline, ) from codemodder.result import Result -from codemodder.utils.utils import extract_targets_of_assignment import libcst as cst from libcst import SimpleStatementLine, ensure_type, matchers from libcst.codemod import ( @@ -19,7 +18,11 @@ ) from codemodder.change import Change from codemodder.codemods.utils import MetadataPreservingTransformer -from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin +from codemodder.codemods.utils_mixin import ( + AncestorPatternsMixin, + NameAndAncestorResolutionMixin, + NameResolutionMixin, +) from codemodder.file_context import FileContext from functools import partial from core_codemods.api import ( @@ -161,14 +164,8 @@ def _is_resource(self, call: cst.Call) -> bool: return False -class ResourceLeakFixer( - MetadataPreservingTransformer, NameResolutionMixin, AncestorPatternsMixin -): - METADATA_DEPENDENCIES = ( - PositionProvider, - ScopeProvider, - ParentNodeProvider, - ) +class ResourceLeakFixer(MetadataPreservingTransformer, NameAndAncestorResolutionMixin): + METADATA_DEPENDENCIES = (PositionProvider,) def __init__( self, @@ -222,7 +219,7 @@ def _handle_block( # 1 would point to 0 since f.read() would be included in the with statement of 0 new_index_of_original_stmt = list(range(len(new_stmts))) for stmt, assignment, resource in reversed(leak): - named_targets, other_targets = self._find_transitive_assignment_targets( + named_targets, other_targets = self.find_transitive_assignment_targets( resource ) index = original_block.body.index(stmt) @@ -309,57 +306,6 @@ def _last_ancestor_index(self, node, node_sequence) -> Optional[int]: last = i return last - def _find_direct_name_assignment_targets( - self, name: cst.Name - ) -> list[cst.BaseAssignTargetExpression]: - name_targets = [] - accesses = self.find_accesses(name) - for node in (access.node for access in accesses): - if maybe_assigned := self.is_value_of_assignment(node): - targets = extract_targets_of_assignment(maybe_assigned) - name_targets.extend(targets) - return name_targets - - def _find_name_assignment_targets( - self, name: cst.Name - ) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]: - named_targets, other_targets = self._sieve_targets( - self._find_direct_name_assignment_targets(name) - ) - - for child in named_targets: - c_named_targets, c_other_targets = self._find_name_assignment_targets(child) - named_targets.extend(c_named_targets) - other_targets.extend(c_other_targets) - return named_targets, other_targets - - def _sieve_targets( - self, targets - ) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]: - named_targets = [] - other_targets = [] - for t in targets: - # TODO maybe consider subscript here for named_targets - if isinstance(t, cst.Name): - named_targets.append(t) - else: - other_targets.append(t) - return named_targets, other_targets - - def _find_transitive_assignment_targets( - self, expr - ) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]: - if maybe_assigned := self.is_value_of_assignment(expr): - named_targets, other_targets = self._sieve_targets( - extract_targets_of_assignment(maybe_assigned) - ) - for n in named_targets: - n_named_targets, n_other_targets = self._find_name_assignment_targets(n) - named_targets.extend(n_named_targets) - other_targets.extend(n_other_targets) - return named_targets, other_targets - return ([], []) - # pylint: disable-next=too-many-arguments def _wrap_in_with_statement( self, diff --git a/src/core_codemods/secure_flask_session_config.py b/src/core_codemods/secure_flask_session_config.py index 6e01a5dd..44d80680 100644 --- a/src/core_codemods/secure_flask_session_config.py +++ b/src/core_codemods/secure_flask_session_config.py @@ -1,6 +1,6 @@ import libcst as cst from libcst.codemod import Codemod, CodemodContext -from libcst.metadata import ParentNodeProvider, PositionProvider +from libcst.metadata import ParentNodeProvider from libcst import matchers from codemodder.codemods.utils_mixin import NameResolutionMixin @@ -87,7 +87,10 @@ class FixFlaskConfig(BaseTransformer, NameResolutionMixin): Visitor to find calls to flask.Flask and related `.config` accesses. """ - METADATA_DEPENDENCIES = (PositionProvider, ParentNodeProvider) + METADATA_DEPENDENCIES = ( + *BaseTransformer.METADATA_DEPENDENCIES, + ParentNodeProvider, + ) SECURE_SESSION_CONFIGS = { # None value indicates unassigned, using default is safe # values in order of precedence diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 1bbcdfd9..3b845543 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -4,9 +4,6 @@ import libcst as cst from libcst import ( - FormattedString, - SimpleString, - SimpleWhitespace, ensure_type, matchers, ) @@ -22,9 +19,12 @@ PositionProvider, ScopeProvider, ) +from codemodder.codemods.libcst_transformer import ( + LibcstResultTransformer, + LibcstTransformerPipeline, +) from core_codemods.api import ( - SimpleCodemod, Metadata, Reference, ReviewGuidance, @@ -41,7 +41,10 @@ get_function_name_node, infer_expression_type, ) -from codemodder.codemods.utils_mixin import NameResolutionMixin +from codemodder.codemods.utils_mixin import ( + NameAndAncestorResolutionMixin, +) +from core_codemods.api.core_codemod import CoreCodemod parameter_token = "?" @@ -49,16 +52,7 @@ raw_quote_pattern = re.compile(r"(? cst.Module: break # Step (3) - ep = ExtractParameters(self.context, query) + ep = ExtractParameters(self.context, query, find_queries.aliased) tree.visit(ep) # Step (4) - build tuple elements and fix injection params_elements: list[cst.Element] = [] for start, middle, end in ep.injection_patterns: - prepend, append = self._fix_injection(start, middle, end) - expr = self._build_param_element(prepend, middle, append) + prepend, append = self._fix_injection( + start, middle, end, find_queries.aliased + ) + expr = self._build_param_element( + prepend, middle, append, find_queries.aliased + ) params_elements.append( cst.Element( value=expr, - comma=cst.Comma(whitespace_after=SimpleWhitespace(" ")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), ) ) @@ -162,7 +161,10 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.changed_nodes = {} line_number = self.get_metadata(PositionProvider, call).start.line self.file_context.codemod_changes.append( - Change(line_number, SQLQueryParameterization.change_description) + Change( + line_number, + SQLQueryParameterizationTransformer.change_description, + ) ) # Normalization and cleanup result = result.visit(RemoveEmptyStringConcatenation()) @@ -174,10 +176,17 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: return result def _fix_injection( - self, start: cst.CSTNode, middle: list[cst.CSTNode], end: cst.CSTNode + self, + start: cst.CSTNode, + middle: list[cst.CSTNode], + end: cst.CSTNode, + aliased_expr: dict[cst.CSTNode, cst.CSTNode], ): for expr in middle: - if isinstance( + # TODO aliased + if expr in aliased_expr: + self.changed_nodes[aliased_expr[expr]] = cst.parse_expression('""') + elif isinstance( expr, cst.FormattedStringText | cst.FormattedStringExpression ): self.changed_nodes[expr] = cst.RemovalSentinel.REMOVE @@ -225,7 +234,7 @@ def _fix_injection( # pylint: disable-next=too-many-arguments def _remove_literal_and_gather_extra( self, original_node, updated_node, prefix, new_raw_value, extra_raw_value - ) -> Optional[SimpleString]: + ) -> Optional[cst.SimpleString]: extra = None match updated_node: case cst.SimpleString(): @@ -263,16 +272,30 @@ def _remove_literal_and_gather_extra( return extra -class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin): +SQLQueryParameterization = CoreCodemod( + metadata=Metadata( + name="sql-parameterization", + summary="Parameterize SQL Queries", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference(url="https://cwe.mitre.org/data/definitions/89.html"), + Reference(url="https://owasp.org/www-community/attacks/SQL_Injection"), + ], + ), + transformer=LibcstTransformerPipeline(SQLQueryParameterizationTransformer), + detector=None, +) + + +class LinearizeQuery(ContextAwareVisitor, NameAndAncestorResolutionMixin): """ Gather all the expressions that are concatenated to build the query. """ - METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, context) -> None: - super().__init__(context) self.leaves: list[cst.CSTNode] = [] + self.aliased: dict[cst.CSTNode, cst.CSTNode] = {} + super().__init__(context) def on_visit(self, node: cst.CSTNode): # We only care about expressions, ignore everything else @@ -316,29 +339,15 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: return False def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]: - if assignment := self.find_single_assignment(node): - base_scope = assignment.scope - # TODO make this check in detect injection, to be more precise - - # Ensure that this variable is not used anywhere else - # variables used in the global scope / class scope may be referenced in other files - if ( - not isinstance(base_scope, GlobalScope) - and not isinstance(base_scope, ClassScope) - and len(assignment.references) == 1 - ): - maybe_gparent = self._find_gparent(assignment.node) - if gparent := maybe_gparent: - match gparent: - case cst.AnnAssign() | cst.Assign(): - if gparent.value: - gparent_scope = self.get_metadata( - ScopeProvider, gparent - ) - if gparent_scope and gparent_scope == base_scope: - visitor = LinearizeQuery(self.context) - gparent.value.visit(visitor) - return visitor.leaves + # if the expression is a name, try to find its single assignment + if (resolved := self.resolve_expression(node)) != node: + visitor = LinearizeQuery(self.context) + resolved.visit(visitor) + if len(visitor.leaves) == 1: + self.aliased[resolved] = node + return [resolved] + self.aliased |= visitor.aliased + return visitor.leaves return [node] def recurse_Attribute(self, node: cst.Attribute) -> list[cst.CSTNode]: @@ -356,21 +365,26 @@ def _find_gparent(self, n: cst.CSTNode) -> Optional[cst.CSTNode]: return gparent -class ExtractParameters(ContextAwareVisitor): +class ExtractParameters(ContextAwareVisitor, NameAndAncestorResolutionMixin): """ Detects injections and gather the expressions that are injectable. """ - METADATA_DEPENDENCIES = ( - ScopeProvider, - ParentNodeProvider, - ) - - def __init__(self, context: CodemodContext, query: list[cst.CSTNode]) -> None: + def __init__( + self, + context: CodemodContext, + query: list[cst.CSTNode], + aliased: dict[cst.CSTNode, cst.CSTNode], + ) -> None: self.query: list[cst.CSTNode] = query self.injection_patterns: list[ - Tuple[cst.CSTNode, list[cst.CSTNode], cst.CSTNode] + tuple[ + cst.CSTNode, + list[cst.CSTNode], + cst.CSTNode, + ] ] = [] + self.aliased: dict[cst.CSTNode, cst.CSTNode] = aliased super().__init__(context) def leave_Module(self, original_node: cst.Module): @@ -378,6 +392,7 @@ def leave_Module(self, original_node: cst.Module): modulo_2 = 1 # treat it as a stack while leaves: + # TODO check if we can change values here any expression in middle should not be from GlobalScope or ClassScope # search for the literal start, we detect the single quote start = leaves.pop() if not self._is_literal_start(start, modulo_2): @@ -391,7 +406,12 @@ def leave_Module(self, original_node: cst.Module): break end = leaves.pop() if any(map(self._is_injectable, middle)): - self.injection_patterns.append((start, middle, end)) + if ( + self._can_be_changed(start) + and self._can_be_changed(end) + and all(map(self._can_be_changed_middle, middle)) + ): + self.injection_patterns.append((start, middle, end)) # end may contain the start of another literal, put it back # should not be a single quote @@ -402,7 +422,8 @@ def leave_Module(self, original_node: cst.Module): modulo_2 = 1 def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: - t = _extract_prefix_raw_value(self, expression) + value = expression + t = _extract_prefix_raw_value(self, value) if not t: return True prefix, raw_value = t @@ -412,10 +433,48 @@ def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: return raw_quote_pattern.fullmatch(raw_value) is None return quote_pattern.fullmatch(raw_value) is None + def _is_assigned_to_exposed_scope(self, expression): + named, other = self.find_transitive_assignment_targets(expression) + for t in itertools.chain(named, other): + scope = self.get_metadata(ScopeProvider, t, None) + match scope: + case GlobalScope() | ClassScope() | None: + return True + return False + + def _is_target_in_expose_scope(self, expression): + assignments = self.find_assignments(expression) + for assignment in assignments: + match assignment.scope: + case GlobalScope() | ClassScope() | None: + return True + return False + + def _can_be_changed_middle(self, expression): + # is it assigned to a variable with global/class scope? + # is itself a target in global/class scope? + # if the expression is aliased, it is just a reference and we can always change + if expression in self.aliased: + return True + return not ( + self._is_target_in_expose_scope(expression) + or self._is_assigned_to_exposed_scope(expression) + ) + + def _can_be_changed(self, expression): + # is it assigned to a variable with global/class scope? + # is itself a target in global/class scope? + return not ( + self._is_target_in_expose_scope(expression) + or self._is_assigned_to_exposed_scope(expression) + ) + def _is_injectable(self, expression: cst.BaseExpression) -> bool: return not bool(infer_expression_type(expression)) - def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: + def _is_literal_start( + self, node: cst.CSTNode | tuple[cst.CSTNode, cst.CSTNode], modulo_2: int + ) -> bool: t = _extract_prefix_raw_value(self, node) if not t: return False @@ -430,7 +489,9 @@ def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: # don't count \\' as these are escaped in string literals return (matches != None) and len(matches) % 2 == modulo_2 - def _is_literal_end(self, node: cst.CSTNode) -> bool: + def _is_literal_end( + self, node: cst.CSTNode | tuple[cst.CSTNode, cst.CSTNode] + ) -> bool: t = _extract_prefix_raw_value(self, node) if not t: return False @@ -476,6 +537,7 @@ class FindQueryCalls(ContextAwareVisitor): def __init__(self, context: CodemodContext) -> None: self.calls: dict = {} + self.aliased: dict[cst.CSTNode, cst.CSTNode] = {} super().__init__(context) def _has_keyword(self, string: str) -> bool: @@ -500,6 +562,7 @@ def leave_Call(self, original_node: cst.Call) -> None: cst.SimpleString() | cst.FormattedStringText() ) if self._has_keyword(expr.value): self.calls[original_node] = query_visitor.leaves + self.aliased |= query_visitor.aliased def _extract_prefix_raw_value(self, node) -> Optional[Tuple[str, str]]: @@ -509,7 +572,7 @@ def _extract_prefix_raw_value(self, node) -> Optional[Tuple[str, str]]: case cst.FormattedStringText(): try: parent = self.get_metadata(ParentNodeProvider, node) - parent = ensure_type(parent, FormattedString) + parent = ensure_type(parent, cst.FormattedString) except Exception: return None return parent.start.lower(), node.value diff --git a/src/core_codemods/subprocess_shell_false.py b/src/core_codemods/subprocess_shell_false.py index 2cf7b408..33c900c0 100644 --- a/src/core_codemods/subprocess_shell_false.py +++ b/src/core_codemods/subprocess_shell_false.py @@ -34,7 +34,10 @@ class SubprocessShellFalse(SimpleCodemod, NameResolutionMixin): for func in {"run", "call", "check_output", "check_call", "Popen"} ] - METADATA_DEPENDENCIES = SimpleCodemod.METADATA_DEPENDENCIES + (ParentNodeProvider,) + METADATA_DEPENDENCIES = ( + *SimpleCodemod.METADATA_DEPENDENCIES, + ParentNodeProvider, + ) IGNORE_ANNOTATIONS = ["S603"] def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): diff --git a/src/core_codemods/use_walrus_if.py b/src/core_codemods/use_walrus_if.py index 7c5e4465..ffc1a313 100644 --- a/src/core_codemods/use_walrus_if.py +++ b/src/core_codemods/use_walrus_if.py @@ -37,7 +37,8 @@ class UseWalrusIf(SimpleCodemod): change_description = ( "Replaces multiple expressions involving `if` operator with 'walrus' operator." ) - METADATA_DEPENDENCIES = SimpleCodemod.METADATA_DEPENDENCIES + ( + METADATA_DEPENDENCIES = ( + *SimpleCodemod.METADATA_DEPENDENCIES, ParentNodeProvider, ScopeProvider, ) diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index 6b9ec10c..a4446621 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -173,6 +173,31 @@ def test_formatted_string_simple(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) + def test_formatted_string_simple_aliased(self, tmpdir): + input_code = """ + import sqlite3 + + def foo(): + table_name = "TABLE" + search_vector = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + query = f"DELETE FROM {table_name} WHERE embeddings <=> '{search_vector}'" + cursor.execute(query) + """ + expected = """ + import sqlite3 + + def foo(): + table_name = "TABLE" + search_vector = input() + connection = sqlite3.connect("my_db.db") + cursor = connection.cursor() + query = f"DELETE FROM {table_name} WHERE embeddings <=> ?" + cursor.execute(query, (search_vector, )) + """ + self.run_and_assert(tmpdir, input_code, expected) + def test_formatted_string_quote_in_middle(self, tmpdir): input_code = """ import sqlite3