diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 061fcccb..ec66f6b4 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -17,11 +17,14 @@ CodemodMetadata, ReviewGuidance, ) +from codemodder.codemods.base_visitor import UtilsMixin from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, ) from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node from codemodder.codemods.utils_mixin import NameResolutionMixin +from codemodder.context import CodemodExecutionContext +from codemodder.file_context import FileContext parameter_token = "?" @@ -30,7 +33,7 @@ literal = literal_number | literal_string -class SQLQueryParameterization(BaseCodemod, Codemod): +class SQLQueryParameterization(BaseCodemod, UtilsMixin, Codemod): SUMMARY = "Parameterize SQL queries." METADATA = CodemodMetadata( DESCRIPTION=SUMMARY, @@ -55,12 +58,19 @@ class SQLQueryParameterization(BaseCodemod, Codemod): ParentNodeProvider, ) - def __init__(self, context: CodemodContext, *codemod_args) -> None: + def __init__( + self, + context: CodemodContext, + execution_context: CodemodExecutionContext, + file_context: FileContext, + *codemod_args, + ) -> None: self.changed_nodes: dict[ cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel ] = {} + BaseCodemod.__init__(self, execution_context, file_context, *codemod_args) + UtilsMixin.__init__(self, context, {}) Codemod.__init__(self, context) - BaseCodemod.__init__(self, *codemod_args) def _build_param_element(self, middle, index: int) -> cst.BaseExpression: # TODO maybe a parameterized string would be better here @@ -83,6 +93,11 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: result = tree for call, query in find_queries.calls.items(): + # filter by line includes/excludes + call_pos = self.node_position(call) + if not self.filter_by_path_includes_or_excludes(call_pos): + break + ep = ExtractParameters(self.context, query) tree.visit(ep) @@ -342,6 +357,7 @@ def _is_injectable(self, expression: cst.CSTNode) -> bool: return True def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: + # TODO limited for now, won't include cases like "name = 'username_" + name + "_tail'" match node: case cst.SimpleString(): prefix = node.prefix.lower()