Skip to content

Commit

Permalink
Adds support for format operators in SQLQueryParameterization (#361)
Browse files Browse the repository at this point in the history
* Format expressions initial implementation

* Format expressions initial implementation

* Transform to remove empty string formatting

* Refactoring and documentation

* Refactoring and documentation

* Tests for printf style string parser

* LinearizeStringExpression tests

* Tests for SQL parameterization with printf format strings

* Refactored and moved cleaning transformations

* Refactoring and more tests

* Linting

* fixup! Refactoring and more tests

* Hardening suggestions for codemodder-python / sqlp-formatop (#362)

Use Assignment Expression (Walrus) In Conditional

Co-authored-by: pixeebot[bot] <104101892+pixeebot[bot]@users.noreply.github.com>

* fixup! Hardening suggestions for codemodder-python / sqlp-formatop (#362)

* Small refactoring

* fixup! Small refactoring

* Better documentation

* Disables RemoveUnnecessarFStr and test

---------

Co-authored-by: pixeebot[bot] <104101892+pixeebot[bot]@users.noreply.github.com>
  • Loading branch information
andrecsilva and pixeebot[bot] authored Mar 14, 2024
1 parent 133ec48 commit ef74258
Show file tree
Hide file tree
Showing 17 changed files with 1,348 additions and 270 deletions.
12 changes: 10 additions & 2 deletions integration_tests/test_unnecessary_f_str.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import pytest

from codemodder.codemods.test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)
from core_codemods.remove_unnecessary_f_str import RemoveUnnecessaryFStr
from core_codemods.remove_unnecessary_f_str import (
RemoveUnnecessaryFStr,
RemoveUnnecessaryFStrTransform,
)


@pytest.mark.skipif(
True, reason="May fail if it runs after test_sql_parameterization. See Issue #378."
)
class TestFStr(BaseIntegrationTest):
codemod = RemoveUnnecessaryFStr
code_path = "tests/samples/unnecessary_f_str.py"
Expand All @@ -13,4 +21,4 @@ class TestFStr(BaseIntegrationTest):
)
expected_diff = '--- \n+++ \n@@ -1,2 +1,2 @@\n-bad = f"hello"\n+bad = "hello"\n good = f"{2+3}"\n'
expected_line_change = "1"
change_description = RemoveUnnecessaryFStr.change_description
change_description = RemoveUnnecessaryFStrTransform.change_description
4 changes: 2 additions & 2 deletions src/codemodder/codemods/libcst_transformer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections import namedtuple
from typing import cast

import libcst as cst
from libcst import matchers
from libcst._position import CodeRange
from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import PositionProvider

from codemodder.codemods.base_transformer import BaseTransformerPipeline
from codemodder.codemods.base_visitor import BaseTransformer
Expand Down Expand Up @@ -100,7 +100,7 @@ def leave_ClassDef(

def node_position(self, node):
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
return cast(CodeRange, self.get_metadata(self.METADATA_DEPENDENCIES[0], node))
return self.get_metadata(PositionProvider, node)

def add_change(self, node, description: str, start: bool = True):
position = self.node_position(node)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import Union

import libcst as cst
from libcst import CSTTransformer, RemovalSentinel, SimpleString
from libcst import RemovalSentinel, SimpleString
from libcst.codemod import ContextAwareTransformer

from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin
from codemodder.utils.utils import is_empty_string_literal

class RemoveEmptyStringConcatenation(CSTTransformer):

class RemoveEmptyStringConcatenation(
ContextAwareTransformer, NameAndAncestorResolutionMixin
):
"""
Removes concatenation with empty strings (e.g. "hello " + "") or "hello" ""
"""
Expand All @@ -19,15 +25,19 @@ def leave_FormattedStringExpression(
RemovalSentinel,
]:
expr = original_node.expression
match expr:
case SimpleString() if expr.raw_value == "": # type: ignore
resolved = self.resolve_expression(expr)
match resolved:
case SimpleString() if resolved.raw_value == "": # type: ignore
return RemovalSentinel.REMOVE
return updated_node

def leave_BinaryOperation(
self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
) -> cst.BaseExpression:
return self.handle_node(updated_node)
match original_node.operator:
case cst.Add():
return self.handle_node(updated_node)
return updated_node

def leave_ConcatenatedString(
self,
Expand All @@ -41,20 +51,12 @@ def handle_node(
) -> cst.BaseExpression:
left = updated_node.left
right = updated_node.right
if self._is_empty_string_literal(left):
if self._is_empty_string_literal(right):
if is_empty_string_literal(left):
if is_empty_string_literal(right):
return cst.SimpleString(value='""')
return right
if self._is_empty_string_literal(right):
if self._is_empty_string_literal(left):
if is_empty_string_literal(right):
if is_empty_string_literal(left):
return cst.SimpleString(value='""')
return left
return updated_node

def _is_empty_string_literal(self, node):
match node:
case cst.SimpleString() if node.raw_value == "":
return True
case cst.FormattedString() if not node.parts:
return True
return False
9 changes: 7 additions & 2 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, TypeAlias

import libcst as cst
from libcst import MetadataDependent, matchers
Expand Down Expand Up @@ -79,6 +79,11 @@ class Prepend(SequenceExtension):
pass


ReplacementNodeType: TypeAlias = (
cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel | dict[str, Any]
)


class ReplaceNodes(cst.CSTTransformer):
"""
Replace nodes with their corresponding values in a given dict. The replacements dictionary should either contain a mapping from a node to another node, RemovalSentinel, or FlattenSentinel to be replaced, or a dict mapping each attribute, by name, to a new value. Additionally if the attribute is a sequence, you may pass Append(l)/Prepend(l), where l is a list of nodes, to append or prepend, respectively.
Expand All @@ -88,7 +93,7 @@ def __init__(
self,
replacements: dict[
cst.CSTNode,
cst.CSTNode | cst.FlattenSentinel | cst.RemovalSentinel | dict[str, Any],
ReplacementNodeType | dict[str, Any],
],
):
self.replacements = replacements
Expand Down
253 changes: 253 additions & 0 deletions src/codemodder/utils/clean_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import itertools
from typing import Union

import libcst as cst
from libcst.codemod import (
Codemod,
CodemodContext,
ContextAwareTransformer,
ContextAwareVisitor,
VisitorBasedCodemodCommand,
)
from libcst.metadata import ClassScope, GlobalScope, ParentNodeProvider, ScopeProvider

from codemodder.codemods.utils import ReplacementNodeType, ReplaceNodes
from codemodder.codemods.utils_mixin import (
NameAndAncestorResolutionMixin,
NameResolutionMixin,
)
from codemodder.utils.format_string_parser import (
PrintfStringExpression,
PrintfStringText,
dict_to_values_dict,
expressions_from_replacements,
parse_formatted_string,
)
from codemodder.utils.linearize_string_expression import LinearizeStringMixin
from codemodder.utils.utils import is_empty_sequence_literal, is_empty_string_literal


class RemoveEmptyExpressionsFormatting(Codemod):
"""
Cleans and removes string format operator (i.e. `%`) expressions that formats empty expressions or strings. For example, `"abc%s123" % ""` -> `"abc123"`, or `"abc" % {}` -> `"abc"`.
"""

METADATA_DEPENDENCIES = (
ParentNodeProvider,
ScopeProvider,
)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
result = tree
visitor = RemoveEmptyExpressionsFormattingVisitor(self.context)
result.visit(visitor)
if visitor.node_replacements:
result = result.visit(ReplaceNodes(visitor.node_replacements))
return result

def should_allow_multiple_passes(self) -> bool:
return True


class RemoveEmptyExpressionsFormattingVisitor(
ContextAwareVisitor, NameAndAncestorResolutionMixin, LinearizeStringMixin
):

def __init__(self, context: CodemodContext) -> None:
self.node_replacements: dict[cst.CSTNode, ReplacementNodeType] = {}
super().__init__(context)

def _resolve_dict(
self, dict_node: cst.Dict
) -> dict[cst.BaseExpression, cst.BaseExpression]:
returned: dict[cst.BaseExpression, cst.BaseExpression] = {}
for element in dict_node.elements:
match element:
case cst.DictElement():
returned |= {element.key: element.value}
case cst.StarredDictElement():
resolved = self.resolve_expression(element.value)
if isinstance(resolved, cst.Dict):
returned |= self._resolve_dict(resolved)
return returned

def _build_replacements(self, node, node_parts, parts_to_remove):
new_raw_value = ""
change = False
for part in node_parts:
if part in parts_to_remove:
change = True
else:
new_raw_value += part.value
if change:
match node:
case cst.SimpleString():
self.node_replacements[node] = node.with_changes(
value=node.prefix + node.quote + new_raw_value + node.quote
)
case cst.FormattedStringText():
self.node_replacements[node] = node.with_changes(
value=new_raw_value
)

def _record_node_pieces(self, parts) -> dict:
node_pieces: dict[
cst.CSTNode,
list[PrintfStringExpression | PrintfStringText],
] = {}
for part in parts:
match part:
case PrintfStringText() | PrintfStringExpression():
if part.origin in node_pieces:
node_pieces[part.origin].append(part)
else:
node_pieces[part.origin] = [part]
return node_pieces

def leave_BinaryOperation(self, original_node: cst.BinaryOperation):
if not isinstance(original_node.operator, cst.Modulo):
return

# is left or right an empty literal?
if is_empty_string_literal(self.resolve_expression(original_node.left)):
self.node_replacements[original_node] = cst.SimpleString("''")
return
right = self.resolve_expression(right := original_node.right)
if is_empty_sequence_literal(right):
self.node_replacements[original_node] = original_node.left
return

# gather all the parts of the format operator
resolved_dict = {}
match right:
case cst.Dict():
resolved_dict = self._resolve_dict(right)
keys: dict | list = dict_to_values_dict(resolved_dict)
case _:
keys = expressions_from_replacements(right)
linearized_string_expr = self.linearize_string_expression(original_node.left)
parsed = parse_formatted_string(
linearized_string_expr.parts if linearized_string_expr else [], keys
)
node_pieces = self._record_node_pieces(parsed)

# failed parsing of expression, aborting
if not parsed:
return

# is there any expressions to replace? if not, remove the operator
if all(not isinstance(p, PrintfStringExpression) for p in parsed):
self.node_replacements[original_node] = original_node.left
return

# gather all the expressions parts that resolves to an empty string and remove them
to_remove = set()
for part in parsed:
match part:
case PrintfStringExpression():
resolved_part_expression = self.resolve_expression(part.expression)
if is_empty_string_literal(resolved_part_expression):
to_remove.add(part)
keys_to_remove = {part.key or 0 for part in to_remove}
for part in to_remove:
self._build_replacements(part.origin, node_pieces[part.origin], to_remove)

# remove all the elements on the right that resolves to an empty string
match right:
case cst.Dict():
for v in resolved_dict.values():
resolved_v = self.resolve_expression(v)
if is_empty_string_literal(resolved_v):
parent = self.get_parent(v)
if parent:
self.node_replacements[parent] = cst.RemovalSentinel.REMOVE

case cst.Tuple():
new_tuple_elements = []
# outright remove
if len(keys_to_remove) != len(keys):
for i, element in enumerate(right.elements):
if i not in keys_to_remove:
new_tuple_elements.append(element)
if len(new_tuple_elements) != len(right.elements):
if len(new_tuple_elements) == 1:
self.node_replacements[right] = new_tuple_elements[0].value
else:
self.node_replacements[right] = right.with_changes(
elements=new_tuple_elements
)
case _:
if keys_to_remove:
self.node_replacements[original_node] = self.node_replacements.get(
original_node.left, original_node.left
)


class RemoveUnusedVariables(VisitorBasedCodemodCommand, NameResolutionMixin):
"""
Removes assinments that aren't referenced anywhere else. It will preseve assignments that are in exposed scopes, that is, module or class scope.
"""

def _handle_target(self, node):
# TODO starred elements
# TODO list/tuple case, remove assignment values
match node:
# case cst.Tuple() | cst.List():
# new_elements = []
# for e in node.elements:
# new_expr = self._handle_target(e.value)
# if new_expr:
# new_elements.append(e.with_changes(value = new_expr))
# if new_elements:
# if len(new_elements) ==1:
# return new_elements[0]
# return node.with_changes(elements = new_elements)
# return None
case cst.Name():
if self.find_accesses(node):
return node
else:
return None
case _:
return node

def leave_Assign(
self, original_node: cst.Assign, updated_node: cst.Assign
) -> Union[
cst.BaseSmallStatement,
cst.FlattenSentinel[cst.BaseSmallStatement],
cst.RemovalSentinel,
]:
if scope := self.get_metadata(ScopeProvider, original_node, None):
if isinstance(scope, GlobalScope | ClassScope):
return updated_node

new_targets = []
for target in original_node.targets:
if new_target := self._handle_target(target.target):
new_targets.append(target.with_changes(target=new_target))
# remove everything
if not new_targets:
return cst.RemovalSentinel.REMOVE
return updated_node.with_changes(targets=new_targets)


class NormalizeFStrings(ContextAwareTransformer):
"""
Finds all the f-strings whose parts are only composed of FormattedStringText and concats all of them in a single part.
"""

def leave_FormattedString(
self, original_node: cst.FormattedString, updated_node: cst.FormattedString
) -> cst.BaseExpression:
all_parts = list(
itertools.takewhile(
lambda x: isinstance(x, cst.FormattedStringText), original_node.parts
)
)
if len(all_parts) != len(updated_node.parts):
return updated_node
new_part = cst.FormattedStringText(
value="".join(map(lambda x: x.value, all_parts))
)
return updated_node.with_changes(parts=[new_part])
Loading

0 comments on commit ef74258

Please sign in to comment.