diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index ce7b5ac5..1b29298c 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -16,42 +16,40 @@ class BaseType(Enum): STRING = 3 BYTES = 4 - @classmethod - # pylint: disable-next=R0911 - def infer_expression_type(cls, node: cst.BaseExpression) -> Optional["BaseType"]: - """ - Tries to infer if the type of a given expression is one of the base literal types. - """ - # The current implementation could be enhanced with a few more cases - match node: - case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( - func=cst.Name("int") - ) | cst.Call(func=cst.Name("float")) | cst.Call( - func=cst.Name("abs") - ) | cst.Call( - func=cst.Name("len") - ): - return BaseType.NUMBER - case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp(): - return BaseType.LIST - case cst.Call(func=cst.Name("str")) | cst.FormattedString(): - return BaseType.STRING - case cst.SimpleString(): - if "b" in node.prefix.lower(): - return BaseType.BYTES - return BaseType.STRING - case cst.ConcatenatedString(): - return cls.infer_expression_type(node.left) - case cst.BinaryOperation(operator=cst.Add()): - return cls.infer_expression_type( - node.left - ) or cls.infer_expression_type(node.right) - case cst.IfExp(): - if_true = cls.infer_expression_type(node.body) - or_else = cls.infer_expression_type(node.orelse) - if if_true == or_else: - return if_true - return None + +# pylint: disable-next=R0911 +def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]: + """ + Tries to infer which base literal type (NUMBER, LIST, STRING or BYTES) a given expression evaluates to; returns None when no case matches. 
+ """ + # The current implementation covers some common cases and is in no way complete + match node: + case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call( + func=cst.Name("int") + ) | cst.Call(func=cst.Name("float")) | cst.Call( + func=cst.Name("abs") + ) | cst.Call( + func=cst.Name("len") + ): + return BaseType.NUMBER + case cst.Call(func=cst.Name("list")) | cst.List() | cst.ListComp(): + return BaseType.LIST + case cst.Call(func=cst.Name("str")) | cst.FormattedString(): + return BaseType.STRING + case cst.SimpleString(): + if "b" in node.prefix.lower(): + return BaseType.BYTES + return BaseType.STRING + case cst.ConcatenatedString(): + return infer_expression_type(node.left) + case cst.BinaryOperation(operator=cst.Add()): + return infer_expression_type(node.left) or infer_expression_type(node.right) + case cst.IfExp(): + if_true = infer_expression_type(node.body) + or_else = infer_expression_type(node.orelse) + if if_true == or_else: + return if_true + return None class SequenceExtension: diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index f5d85cf6..3619111f 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -2,7 +2,13 @@ from typing import Any, Optional, Tuple import itertools import libcst as cst -from libcst import FormattedString, SimpleWhitespace, ensure_type, matchers +from libcst import ( + FormattedString, + SimpleString, + SimpleWhitespace, + ensure_type, + matchers, +) from libcst.codemod import ( Codemod, CodemodContext, @@ -32,6 +38,7 @@ BaseType, ReplaceNodes, get_function_name_node, + infer_expression_type, ) from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.file_context import FileContext @@ -113,6 +120,13 @@ def _build_param_element(self, prepend, middle, append): ) def transform_module_impl(self, tree: cst.Module) -> cst.Module: + # The transformation has four major steps: + # (1) FindQueryCalls - 
Find and gather all the SQL query execution calls. The result is a dict mapping each call node to the list of nodes composing its query (i.e. the output of step (2)). + # (2) LinearizeQuery - For each call, gathers all the string literals and expressions that compose the query. The result is a list of nodes whose concatenation is the query. + # (3) ExtractParameters - Detects which expressions are part of SQL string literals in the query. The result is a list of triples (a,b,c) such that a is the node that contains the start of the string literal, b is a list of expressions that compose that literal, and c is the node containing the end of the string literal. At least one node in b must be "injectable" (see _is_injectable). + # (4) SQLQueryParameterization - Executes steps (1)-(3) and gathers a list of injection triples. For each triple (a,b,c) it makes the associated changes to insert the query parameter token. All the expressions in b are then concatenated into a single expression and passed as a sequence of parameters to the execute call. 
+ + # Steps (1) and (2) find_queries = FindQueryCalls(self.context) tree.visit(find_queries) @@ -123,10 +137,11 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: if not self.filter_by_path_includes_or_excludes(call_pos): break + # Step (3) ep = ExtractParameters(self.context, query) tree.visit(ep) - # build tuple elements and fix injection + # Step (4) - build tuple elements and fix injection params_elements: list[cst.Element] = [] for start, middle, end in ep.injection_patterns: prepend, append = self._fix_injection(start, middle, end) @@ -173,13 +188,10 @@ def _fix_injection( else: self.changed_nodes[expr] = cst.parse_expression('""') - prepend = append = None - # remove quote literal from start - current_start = self.changed_nodes.get(start) or start - prepend_raw_value = None + updated_start = self.changed_nodes.get(start) or start - t = _extract_prefix_raw_value(self, current_start) + t = _extract_prefix_raw_value(self, updated_start) prefix, raw_value = t if t else ("", "") # gather string after the quote @@ -187,45 +199,18 @@ def _fix_injection( quote_span = list(raw_quote_pattern.finditer(raw_value))[-1] else: quote_span = list(quote_pattern.finditer(raw_value))[-1] + new_raw_value = raw_value[: quote_span.start()] + parameter_token prepend_raw_value = raw_value[quote_span.end() :] - match current_start: - case cst.SimpleString(): - # uses the same quote and prefixes to guarantee it will be correctly interpreted - if prepend_raw_value: - prepend = cst.SimpleString( - value=current_start.prefix - + current_start.quote - + prepend_raw_value - + current_start.quote - ) - - new_value = ( - current_start.prefix - + current_start.quote - + new_raw_value - + current_start.quote - ) - self.changed_nodes[start] = current_start.with_changes(value=new_value) - - case cst.FormattedStringText(): - if prepend_raw_value: - prepend = cst.SimpleString( - value=("r" if "r" in prefix else "") - + "'" - + prepend_raw_value - + "'" - ) - - new_value = 
new_raw_value - self.changed_nodes[start] = current_start.with_changes(value=new_value) + prepend = self._remove_literal_and_gather_extra( + start, updated_start, prefix, new_raw_value, prepend_raw_value + ) # remove quote literal from end - current_end = self.changed_nodes.get(end) or end - append_raw_value = None + updated_end = self.changed_nodes.get(end) or end - t = _extract_prefix_raw_value(self, current_end) + t = _extract_prefix_raw_value(self, updated_end) prefix, raw_value = t if t else ("", "") if "r" in prefix: quote_span = list(raw_quote_pattern.finditer(raw_value))[0] @@ -234,36 +219,52 @@ def _fix_injection( new_raw_value = raw_value[quote_span.end() :] append_raw_value = raw_value[: quote_span.start()] - match current_end: + + append = self._remove_literal_and_gather_extra( + end, updated_end, prefix, new_raw_value, append_raw_value + ) + + return (prepend, append) + + # pylint: disable-next=too-many-arguments + def _remove_literal_and_gather_extra( + self, original_node, updated_node, prefix, new_raw_value, extra_raw_value + ) -> Optional[SimpleString]: + extra = None + match updated_node: case cst.SimpleString(): - # gather string up to quote to parameter - if append_raw_value: - append = cst.SimpleString( - value=current_end.prefix - + current_end.quote - + append_raw_value - + current_end.quote + # wrap any leftover text (extra_raw_value) in its own literal, reusing this node's prefix and quote + if extra_raw_value: + extra = cst.SimpleString( + value=updated_node.prefix + + updated_node.quote + + extra_raw_value + + updated_node.quote ) new_value = ( - current_end.prefix - + current_end.quote + updated_node.prefix + + updated_node.quote + new_raw_value - + current_end.quote + + updated_node.quote + ) + self.changed_nodes[original_node] = updated_node.with_changes( + value=new_value ) - self.changed_nodes[end] = current_end.with_changes(value=new_value) case cst.FormattedStringText(): - if append_raw_value: - append = cst.SimpleString( + if extra_raw_value: + extra = cst.SimpleString( value=("r" if "r" 
in prefix else "") + "'" - + append_raw_value + + extra_raw_value + "'" ) new_value = new_raw_value - self.changed_nodes[end] = current_end.with_changes(value=new_value) - return (prepend, append) + self.changed_nodes[original_node] = updated_node.with_changes( + value=new_value + ) + return extra class LinearizeQuery(ContextAwareVisitor, NameResolutionMixin): @@ -303,7 +304,7 @@ def on_visit(self, node: cst.CSTNode): return False def visit_BinaryOperation(self, node: cst.BinaryOperation) -> Optional[bool]: - maybe_type = BaseType.infer_expression_type(node) + maybe_type = infer_expression_type(node) if not maybe_type or maybe_type == BaseType.STRING: return True self.leaves.append(node) @@ -417,7 +418,7 @@ def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool: return quote_pattern.fullmatch(raw_value) is None def _is_injectable(self, expression: cst.BaseExpression) -> bool: - return not bool(BaseType.infer_expression_type(expression)) + return not bool(infer_expression_type(expression)) def _is_literal_start(self, node: cst.CSTNode, modulo_2: int) -> bool: t = _extract_prefix_raw_value(self, node) diff --git a/tests/test_basetype.py b/tests/test_basetype.py index e72907ec..2096a31e 100644 --- a/tests/test_basetype.py +++ b/tests/test_basetype.py @@ -1,36 +1,36 @@ import libcst as cst -from codemodder.codemods.utils import BaseType +from codemodder.codemods.utils import BaseType, infer_expression_type class TestBaseType: def test_binary_op_number(self): e = cst.parse_expression("1 + float(2)") - assert BaseType.infer_expression_type(e) == BaseType.NUMBER + assert infer_expression_type(e) == BaseType.NUMBER def test_binary_op_string_mixed(self): e = cst.parse_expression('f"a"+foo()') - assert BaseType.infer_expression_type(e) == BaseType.STRING + assert infer_expression_type(e) == BaseType.STRING def test_binary_op_list(self): e = cst.parse_expression("[1,2] + [x for x in [3]] + list((4,5))") - assert BaseType.infer_expression_type(e) == BaseType.LIST + 
+ assert infer_expression_type(e) == BaseType.LIST def test_binary_op_none(self): e = cst.parse_expression("foo() + bar()") - assert BaseType.infer_expression_type(e) == None + assert infer_expression_type(e) is None def test_bytes(self): e = cst.parse_expression('b"123"') - assert BaseType.infer_expression_type(e) == BaseType.BYTES + assert infer_expression_type(e) == BaseType.BYTES def test_if_mixed(self): e = cst.parse_expression('1 if True else "a"') - assert BaseType.infer_expression_type(e) == None + assert infer_expression_type(e) is None def test_if_numbers(self): e = cst.parse_expression("abs(1) if True else 2") - assert BaseType.infer_expression_type(e) == BaseType.NUMBER + assert infer_expression_type(e) == BaseType.NUMBER def test_if_numbers2(self): e = cst.parse_expression("float(1) if True else len([1,2])") - assert BaseType.infer_expression_type(e) == BaseType.NUMBER + assert infer_expression_type(e) == BaseType.NUMBER