Skip to content

Commit

Permalink
SQLQueryParameterization will now correctly parameterize names (#279)
Browse files Browse the repository at this point in the history
* SQLQueryParameterization will now correctly parameterize names

Also refactored to use mixin functions.

* Hardening suggestions for codemodder-python / sqlp-fix (#280)

Use Assignment Expression (Walrus) In Conditional

Co-authored-by: pixeebot[bot] <[email protected]>

---------

Co-authored-by: pixeebot[bot] <104101892+pixeebot[bot]@users.noreply.github.com>
Co-authored-by: pixeebot[bot] <[email protected]>
  • Loading branch information
3 people authored Feb 19, 2024
1 parent 782311c commit a8e36d0
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 145 deletions.
7 changes: 5 additions & 2 deletions integration_tests/test_sql_parameterization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from core_codemods.sql_parameterization import SQLQueryParameterization
from core_codemods.sql_parameterization import (
SQLQueryParameterization,
SQLQueryParameterizationTransformer,
)
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
Expand Down Expand Up @@ -37,5 +40,5 @@ class TestSQLQueryParameterization(BaseIntegrationTest):
# fmt: on

expected_line_change = "12"
change_description = SQLQueryParameterization.change_description
change_description = SQLQueryParameterizationTransformer.change_description
num_changed_files = 1
6 changes: 3 additions & 3 deletions src/codemodder/codemods/base_visitor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Any, Tuple
from typing import ClassVar, Collection
from libcst import MetadataDependent
from libcst.codemod import ContextAwareVisitor, VisitorBasedCodemodCommand
from libcst.metadata import PositionProvider
from libcst.metadata import PositionProvider, ProviderT

from codemodder.result import Result


# TODO: this should just be part of BaseTransformer and BaseVisitor?
class UtilsMixin(MetadataDependent):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,)
METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (PositionProvider,)

def __init__(
self,
Expand Down
73 changes: 64 additions & 9 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Any, Collection, Optional, Tuple, Union
from typing import ClassVar, Collection, Optional, Union
import libcst as cst
from libcst import MetadataDependent, matchers
from libcst.helpers import get_full_name_for_node
Expand All @@ -10,14 +10,17 @@
BuiltinAssignment,
ImportAssignment,
ParentNodeProvider,
ProviderT,
Scope,
ScopeProvider,
)
from libcst.metadata.scope_provider import GlobalScope

from codemodder.utils.utils import extract_targets_of_assignment


class NameResolutionMixin(MetadataDependent):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,)
METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (ScopeProvider,)

def _find_imported_name(self, node: cst.Name) -> Optional[str]:
match self.find_single_assignment(node):
Expand Down Expand Up @@ -280,7 +283,7 @@ def find_accesses(self, node) -> Collection[Access]:


class AncestorPatternsMixin(MetadataDependent):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (ParentNodeProvider,)
METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (ParentNodeProvider,)

def is_value_of_assignment(
self, expr
Expand Down Expand Up @@ -449,12 +452,10 @@ def get_parent(self, node: cst.CSTNode) -> Optional[cst.CSTNode]:


class NameAndAncestorResolutionMixin(NameResolutionMixin, AncestorPatternsMixin):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (
ScopeProvider,
ParentNodeProvider,
)

def extract_value(self, node: cst.AnnAssign | cst.Assign | cst.WithItem):
def extract_value(
self, node: cst.AnnAssign | cst.Assign | cst.WithItem | cst.NamedExpr
):
match node:
case (
cst.AnnAssign(value=value)
Expand All @@ -467,7 +468,7 @@ def extract_value(self, node: cst.AnnAssign | cst.Assign | cst.WithItem):

def resolve_expression(self, node: cst.BaseExpression) -> cst.BaseExpression:
"""
If the expression is a Name, transitively resolves the name to another type of expression. Otherwise returns self.
If the expression is a Name, transitively resolves the name to another expression through single assignments. Otherwise returns self.
"""
maybe_expr = None
match node:
Expand All @@ -490,6 +491,60 @@ def _resolve_name_transitive(self, node: cst.Name) -> Optional[cst.BaseExpressio
return value
return None

def _find_direct_name_assignment_targets(
    self, name: cst.Name
) -> list[cst.BaseAssignTargetExpression]:
    """
    Gathers the targets of every assignment that ``is_value_of_assignment``
    reports for an access of ``name`` (as returned by ``find_accesses``).

    Only one hop is followed here; the returned targets are not resolved
    any further.
    """
    collected = []
    for access in self.find_accesses(name):
        assignment = self.is_value_of_assignment(access.node)
        if not assignment:
            # This access is not the value side of an assignment.
            continue
        collected.extend(extract_targets_of_assignment(assignment))
    return collected

def _find_name_assignment_targets(
    self, name: cst.Name
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
    """
    Transitively collects the assignment targets reachable from ``name``.

    Starts from the direct targets found by
    ``_find_direct_name_assignment_targets`` and recurses into every target
    that is itself a ``cst.Name``.

    Returns a pair ``(named_targets, other_targets)``: the first list holds
    the targets that are ``cst.Name`` nodes, the second holds all others.
    """
    named_targets, other_targets = self._sieve_targets(
        self._find_direct_name_assignment_targets(name)
    )

    # Iterate over a snapshot of the *direct* named targets only. Each
    # recursive call already returns the complete transitive result for that
    # child, so looping over the list while it grows would walk descendants
    # again and append their targets more than once.
    for child in tuple(named_targets):
        c_named_targets, c_other_targets = self._find_name_assignment_targets(child)
        named_targets.extend(c_named_targets)
        other_targets.extend(c_other_targets)
    return named_targets, other_targets

def _sieve_targets(
    self, targets
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
    """
    Partitions ``targets`` into ``cst.Name`` nodes and everything else.

    Returns ``(names, others)``; relative order within each list follows the
    order of ``targets``.
    """
    names = []
    others = []
    for target in targets:
        # TODO maybe consider subscript here for named_targets
        bucket = names if isinstance(target, cst.Name) else others
        bucket.append(target)
    return names, others

def find_transitive_assignment_targets(
    self, expr
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
    """
    Returns all the targets that an expression can reach.

    It returns a pair of lists, where the first list contains all targets
    that are ``cst.Name``, and the second contains all others. When ``expr``
    is not the value of an assignment (per ``is_value_of_assignment``), both
    lists are empty.
    """
    maybe_assigned = self.is_value_of_assignment(expr)
    if not maybe_assigned:
        return ([], [])

    named_targets, other_targets = self._sieve_targets(
        extract_targets_of_assignment(maybe_assigned)
    )
    # Iterate over a snapshot of the direct named targets:
    # _find_name_assignment_targets already follows each name transitively,
    # so looping over the list while it is being extended would revisit the
    # same names and append their targets repeatedly.
    for n in tuple(named_targets):
        n_named_targets, n_other_targets = self._find_name_assignment_targets(n)
        named_targets.extend(n_named_targets)
        other_targets.extend(n_other_targets)
    return named_targets, other_targets


def iterate_left_expressions(node: cst.BaseExpression):
yield node
Expand Down
70 changes: 8 additions & 62 deletions src/core_codemods/file_resource_leak.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
LibcstTransformerPipeline,
)
from codemodder.result import Result
from codemodder.utils.utils import extract_targets_of_assignment
import libcst as cst
from libcst import SimpleStatementLine, ensure_type, matchers
from libcst.codemod import (
Expand All @@ -19,7 +18,11 @@
)
from codemodder.change import Change
from codemodder.codemods.utils import MetadataPreservingTransformer
from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin
from codemodder.codemods.utils_mixin import (
AncestorPatternsMixin,
NameAndAncestorResolutionMixin,
NameResolutionMixin,
)
from codemodder.file_context import FileContext
from functools import partial
from core_codemods.api import (
Expand Down Expand Up @@ -161,14 +164,8 @@ def _is_resource(self, call: cst.Call) -> bool:
return False


class ResourceLeakFixer(
MetadataPreservingTransformer, NameResolutionMixin, AncestorPatternsMixin
):
METADATA_DEPENDENCIES = (
PositionProvider,
ScopeProvider,
ParentNodeProvider,
)
class ResourceLeakFixer(MetadataPreservingTransformer, NameAndAncestorResolutionMixin):
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(
self,
Expand Down Expand Up @@ -222,7 +219,7 @@ def _handle_block(
# 1 would point to 0 since f.read() would be included in the with statement of 0
new_index_of_original_stmt = list(range(len(new_stmts)))
for stmt, assignment, resource in reversed(leak):
named_targets, other_targets = self._find_transitive_assignment_targets(
named_targets, other_targets = self.find_transitive_assignment_targets(
resource
)
index = original_block.body.index(stmt)
Expand Down Expand Up @@ -309,57 +306,6 @@ def _last_ancestor_index(self, node, node_sequence) -> Optional[int]:
last = i
return last

def _find_direct_name_assignment_targets(
    self, name: cst.Name
) -> list[cst.BaseAssignTargetExpression]:
    """
    Collects, for each access of ``name`` returned by ``find_accesses``, the
    targets of the assignment that ``is_value_of_assignment`` reports for it.

    Follows a single hop only.
    """
    found = []
    for access in self.find_accesses(name):
        assigned = self.is_value_of_assignment(access.node)
        if assigned:
            found.extend(extract_targets_of_assignment(assigned))
    return found

def _find_name_assignment_targets(
    self, name: cst.Name
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
    """
    Transitively collects the assignment targets reachable from ``name``.

    Returns ``(named_targets, other_targets)``: targets that are ``cst.Name``
    nodes first, every other kind of target second.
    """
    named_targets, other_targets = self._sieve_targets(
        self._find_direct_name_assignment_targets(name)
    )

    # Loop over a snapshot of the direct named targets. The recursive call
    # returns the whole transitive result for each child, so iterating the
    # list as it grows would re-walk descendants and duplicate their targets.
    for child in tuple(named_targets):
        c_named_targets, c_other_targets = self._find_name_assignment_targets(child)
        named_targets.extend(c_named_targets)
        other_targets.extend(c_other_targets)
    return named_targets, other_targets

def _sieve_targets(
    self, targets
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
    """
    Splits ``targets`` into a list of ``cst.Name`` nodes and a list of all
    remaining targets, keeping the input order within each list.
    """
    name_nodes = []
    remaining = []
    for candidate in targets:
        # TODO maybe consider subscript here for named_targets
        if not isinstance(candidate, cst.Name):
            remaining.append(candidate)
            continue
        name_nodes.append(candidate)
    return name_nodes, remaining

def _find_transitive_assignment_targets(
    self, expr
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
    """
    Returns all the assignment targets that ``expr`` can reach.

    The result is a pair of lists: targets that are ``cst.Name`` nodes, and
    all other targets. Both are empty when ``is_value_of_assignment`` does
    not report an assignment for ``expr``.
    """
    maybe_assigned = self.is_value_of_assignment(expr)
    if not maybe_assigned:
        return ([], [])

    named_targets, other_targets = self._sieve_targets(
        extract_targets_of_assignment(maybe_assigned)
    )
    # Snapshot the direct named targets before extending the list:
    # _find_name_assignment_targets already resolves each name transitively,
    # so iterating the growing list would revisit names and add duplicates.
    for n in tuple(named_targets):
        n_named_targets, n_other_targets = self._find_name_assignment_targets(n)
        named_targets.extend(n_named_targets)
        other_targets.extend(n_other_targets)
    return named_targets, other_targets

# pylint: disable-next=too-many-arguments
def _wrap_in_with_statement(
self,
Expand Down
7 changes: 5 additions & 2 deletions src/core_codemods/secure_flask_session_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import libcst as cst
from libcst.codemod import Codemod, CodemodContext
from libcst.metadata import ParentNodeProvider, PositionProvider
from libcst.metadata import ParentNodeProvider

from libcst import matchers
from codemodder.codemods.utils_mixin import NameResolutionMixin
Expand Down Expand Up @@ -87,7 +87,10 @@ class FixFlaskConfig(BaseTransformer, NameResolutionMixin):
Visitor to find calls to flask.Flask and related `.config` accesses.
"""

METADATA_DEPENDENCIES = (PositionProvider, ParentNodeProvider)
METADATA_DEPENDENCIES = (
*BaseTransformer.METADATA_DEPENDENCIES,
ParentNodeProvider,
)
SECURE_SESSION_CONFIGS = {
# None value indicates unassigned, using default is safe
# values in order of precedence
Expand Down
Loading

0 comments on commit a8e36d0

Please sign in to comment.