Skip to content

Commit

Permalink
Implement codemod for replacing xml with defusedxml
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Oct 16, 2023
1 parent 8128ee6 commit 0d0cca6
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 101 deletions.
90 changes: 90 additions & 0 deletions src/codemodder/codemods/imported_call_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import abc
from typing import Sequence

import libcst as cst
from libcst import matchers
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import PositionProvider

from codemodder.change import Change
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext


class ImportedCallModifier(
    VisitorBasedCodemodCommand, NameResolutionMixin, metaclass=abc.ABCMeta
):
    """
    Abstract libcst transformer that rewrites selected calls.

    A call is rewritten when its position passes the line include/exclude
    filter, ``_is_direct_call_from_imported_module`` accepts it, and its
    resolved base name is one of the keys of ``matching_functions``.
    Subclasses decide what the rewritten call looks like by implementing
    ``update_attribute`` and ``update_simple_name``.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(
        self,
        codemod_context: CodemodContext,
        file_context: FileContext,
        matching_functions: dict[str, str],
        change_description: str,
    ):
        """
        Store the line filters, the function mapping and the change text.

        ``matching_functions`` is kept as-is for subclasses; its keys are
        additionally copied into a list used for the name lookup in
        ``leave_Call``.
        """
        super().__init__(codemod_context)
        self.line_exclude = file_context.line_exclude
        self.line_include = file_context.line_include
        self.matching_functions = matching_functions
        self._matching_function_names = list(matching_functions)
        self.change_description = change_description
        # Filled by leave_Call with one serialized Change per rewritten call.
        self.changes_in_file: list[Change] = []

    def updated_args(self, original_args: Sequence[cst.Arg]):
        """Return the arguments for the rewritten call (unchanged by default)."""
        return original_args

    @abc.abstractmethod
    def update_attribute(self, true_name, original_node, updated_node, new_args):
        """Build the replacement for a call whose ``func`` is an attribute access."""

    @abc.abstractmethod
    def update_simple_name(self, true_name, original_node, updated_node, new_args):
        """Build the replacement for a call whose ``func`` is not an attribute."""

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
        """
        Rewrite the call if it is selected, otherwise return it untouched.

        Every rewritten call is also recorded in ``changes_in_file`` together
        with the line on which the original call starts.
        """
        position = self.node_position(original_node)
        line_number = position.start.line
        if not self.filter_by_path_includes_or_excludes(position):
            return updated_node

        true_name = self.find_base_name(original_node.func)
        selected = (
            self._is_direct_call_from_imported_module(original_node)
            and true_name in self._matching_function_names
        )
        if not selected:
            return updated_node

        self.changes_in_file.append(
            Change(str(line_number), self.change_description).to_json()
        )
        new_args = self.updated_args(updated_node.args)

        if matchers.matches(original_node.func, matchers.Attribute()):
            # has a prefix, e.g. a.call() -> a.new_call()
            builder = self.update_attribute
        else:
            # it is a simple name, e.g. call() -> module.new_call()
            builder = self.update_simple_name
        return builder(true_name, original_node, updated_node, new_args)

    def filter_by_path_includes_or_excludes(self, pos_to_match):
        """
        Tell whether the node located at ``pos_to_match`` may be modified.

        When exclude lines are configured they take precedence: the result is
        False if the node matches any excluded line and True otherwise.
        Otherwise, when include lines are configured, the result is True only
        if the node matches one of them. With neither configured every node
        is accepted.
        """
        if self.line_exclude:
            return all(
                not match_line(pos_to_match, line) for line in self.line_exclude
            )
        if self.line_include:
            return any(match_line(pos_to_match, line) for line in self.line_include)
        return True

    def node_position(self, node):
        """Return the ``PositionProvider`` metadata recorded for ``node``."""
        # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
        return self.get_metadata(PositionProvider, node)


def match_line(pos, line):
    """Return whether the range ``pos`` both starts and ends on ``line``."""
    starts_on_line = pos.start.line == line
    return starts_on_line and pos.end.line == line
147 changes: 54 additions & 93 deletions src/core_codemods/https_connection.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,55 @@
from typing import Sequence
from libcst import matchers

import libcst as cst
from libcst.codemod import Codemod, CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import (
PositionProvider,
)
from libcst.metadata import PositionProvider

from codemodder.codemods.base_codemod import (
BaseCodemod,
CodemodMetadata,
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext
import libcst as cst
from libcst.codemod import (
Codemod,
CodemodContext,
VisitorBasedCodemodCommand,
)
from codemodder.codemods.imported_call_modifier import ImportedCallModifier


class HTTPSConnectionModifier(ImportedCallModifier):
    """Turns the calls selected by the base class into ``HTTPSConnectionPool`` calls."""

    def updated_args(self, original_args):
        """
        Return the call arguments as a list suited to the new callable.

        The last argument _proxy_config does not match the new method, so
        when there are exactly ten leading positional arguments the tenth
        one is converted to a keyword argument.
        """
        args = list(original_args)
        if self.count_positional_args(args) != 10:
            return args
        args[9] = args[9].with_changes(keyword=cst.parse_expression("_proxy_config"))
        return args

    def update_attribute(self, true_name, original_node, updated_node, new_args):
        """Keep the call's prefix and only swap the attribute name."""
        del true_name, original_node
        secure_func = updated_node.func.with_changes(
            attr=cst.parse_expression("HTTPSConnectionPool")
        )
        return updated_node.with_changes(args=new_args, func=secure_func)

    def update_simple_name(self, true_name, original_node, updated_node, new_args):
        """Replace a bare-name call with ``urllib3.HTTPSConnectionPool(...)``."""
        del true_name
        # Schedule the import the new spelling needs and ask for removal of
        # the import used by the original call if it ends up unused.
        AddImportsVisitor.add_needed_import(self.context, "urllib3")
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
        secure_func = cst.parse_expression("urllib3.HTTPSConnectionPool")
        return updated_node.with_changes(args=new_args, func=secure_func)

    def count_positional_args(self, arglist: Sequence[cst.Arg]) -> int:
        """Return how many arguments precede the first one carrying a keyword."""
        return next(
            (index for index, arg in enumerate(arglist) if arg.keyword),
            len(arglist),
        )


class HTTPSConnection(BaseCodemod, Codemod):
Expand All @@ -43,93 +75,22 @@ class HTTPSConnection(BaseCodemod, Codemod):

METADATA_DEPENDENCIES = (PositionProvider,)

matching_functions = {
"urllib3.HTTPConnectionPool",
"urllib3.connectionpool.HTTPConnectionPool",
matching_functions: dict[str, str] = {
"urllib3.HTTPConnectionPool": "urllib3",
"urllib3.connectionpool.HTTPConnectionPool": "urllib3",
}

def __init__(self, codemod_context: CodemodContext, *codemod_args):
Codemod.__init__(self, codemod_context)
BaseCodemod.__init__(self, *codemod_args)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
visitor = ConnectionPollVisitor(self.context, self.file_context)
visitor = HTTPSConnectionModifier(
self.context,
self.file_context,
self.matching_functions,
self.CHANGE_DESCRIPTION,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
return result_tree


class ConnectionPollVisitor(VisitorBasedCodemodCommand, NameResolutionMixin):
    """
    libcst transformer that rewrites calls listed in
    ``HTTPSConnection.matching_functions`` into ``HTTPSConnectionPool`` calls
    and records every rewrite in ``changes_in_file``.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
        """Store the line include/exclude filters taken from ``file_context``."""
        super().__init__(codemod_context)
        self.line_exclude = file_context.line_exclude
        self.line_include = file_context.line_include
        # One serialized Change per rewritten call.
        self.changes_in_file: list[Change] = []

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
        """
        Rewrite the call if it is selected, otherwise return it untouched.

        Positions and name resolution are looked up on ``original_node``; the
        replacement is built from ``updated_node`` so that changes already
        made to nested nodes are kept.
        """
        pos_to_match = self.node_position(original_node)
        line_number = pos_to_match.start.line
        if self.filter_by_path_includes_or_excludes(pos_to_match):
            true_name = self.find_base_name(original_node.func)
            if (
                self._is_direct_call_from_imported_module(original_node)
                and true_name in HTTPSConnection.matching_functions
            ):
                self.changes_in_file.append(
                    Change(
                        str(line_number), HTTPSConnection.CHANGE_DESCRIPTION
                    ).to_json()
                )
                # last argument _proxy_config does not match new method
                # we convert it to keyword
                # Start from updated_node.args: original_node.args would drop
                # any rewrite already applied inside the argument expressions
                # (e.g. a nested call that was itself transformed).
                new_args = list(updated_node.args)
                if count_positional_args(updated_node.args) == 10:
                    new_args[9] = new_args[9].with_changes(
                        keyword=cst.parse_expression("_proxy_config")
                    )
                # has a prefix, e.g. a.call() -> a.new_call()
                if matchers.matches(original_node.func, matchers.Attribute()):
                    return updated_node.with_changes(
                        args=new_args,
                        func=updated_node.func.with_changes(
                            attr=cst.parse_expression("HTTPSConnectionPool")
                        ),
                    )
                # it is a simple name, e.g. call() -> module.new_call()
                AddImportsVisitor.add_needed_import(self.context, "urllib3")
                RemoveImportsVisitor.remove_unused_import_by_node(
                    self.context, original_node
                )
                return updated_node.with_changes(
                    args=new_args,
                    func=cst.parse_expression("urllib3.HTTPSConnectionPool"),
                )
        return updated_node

    def filter_by_path_includes_or_excludes(self, pos_to_match):
        """
        Tell whether the node located at ``pos_to_match`` may be modified.

        With exclude lines configured, returns False if the node matches any
        of them and True otherwise. Otherwise, with include lines configured,
        returns True only if the node matches one of them. With neither
        configured, returns True.
        """
        # excludes takes precedence if defined
        if self.line_exclude:
            return not any(match_line(pos_to_match, line) for line in self.line_exclude)
        if self.line_include:
            return any(match_line(pos_to_match, line) for line in self.line_include)
        return True

    def node_position(self, node):
        """Return the ``PositionProvider`` metadata recorded for ``node``."""
        # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
        return self.get_metadata(PositionProvider, node)


def match_line(pos, line):
    """Return whether the range ``pos`` starts and ends on the given ``line``."""
    if pos.start.line != line:
        return False
    return pos.end.line == line


def count_positional_args(arglist: Sequence[cst.Arg]) -> int:
    """Return how many arguments precede the first one carrying a keyword."""
    return next(
        (index for index, arg in enumerate(arglist) if arg.keyword),
        len(arglist),
    )
83 changes: 83 additions & 0 deletions src/core_codemods/use_defused_xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import libcst as cst
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.imported_call_modifier import ImportedCallModifier


class DefusedXmlModifier(ImportedCallModifier):
    """
    Rewrites each selected call as ``<mapped module>.<function>(...)``, where
    the module is looked up in ``matching_functions`` under the call's
    resolved name and the function is the last dotted component of that name.
    """

    def _call_from_mapped_module(self, true_name, updated_node, new_args):
        """Build the replacement call and schedule the import it relies on."""
        module_name = self.matching_functions[true_name]
        AddImportsVisitor.add_needed_import(self.context, module_name)
        module_expr = cst.parse_expression(module_name)
        # Text after the last dot, or the whole name when there is no dot.
        function_name = true_name.rpartition(".")[2]
        replacement = cst.Attribute(
            value=module_expr,
            attr=cst.Name(value=function_name),
        )
        return updated_node.with_changes(args=new_args, func=replacement)

    def update_attribute(self, true_name, _, updated_node, new_args):
        """Replace an attribute-style call; the original node is not used."""
        return self._call_from_mapped_module(true_name, updated_node, new_args)

    def update_simple_name(self, true_name, _, updated_node, new_args):
        """Replace a bare-name call; the original node is not used."""
        return self._call_from_mapped_module(true_name, updated_node, new_args)


# Function names combined with the xml.etree.ElementTree and
# xml.etree.cElementTree prefixes when the codemod's mapping is built.
ETREE_METHODS = [
    "parse",
    "fromstring",
    "iterparse",
    "XMLParser",
]
# Function names combined with the xml.sax prefix.
SAX_METHODS = [
    "parse",
    "make_parser",
    "parseString",
]
# Function names combined with the xml.dom.minidom and xml.dom.pulldom prefixes.
DOM_METHODS = [
    "parse",
    "parseString",
]
# TODO: add expat methods?


class UseDefusedXml(BaseCodemod):
    """Codemod that redirects the listed ``xml.*`` calls to ``defusedxml`` modules."""

    NAME = "use-defused-xml"
    SUMMARY = "Use defusedxml for parsing XML"
    REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW
    DESCRIPTION = "Use method from defusedxml"

    # Mapping from original function to corresponding defusedxml module
    matching_functions: dict[str, str] = {
        **{
            f"xml.etree.ElementTree.{method}": "defusedxml.ElementTree"
            for method in ETREE_METHODS
        },
        **{
            # defusedxml.cElementTree is deprecated so we use defusedxml.ElementTree
            f"xml.etree.cElementTree.{method}": "defusedxml.ElementTree"
            for method in ETREE_METHODS
        },
        **{f"xml.sax.{method}": "defusedxml.sax" for method in SAX_METHODS},
        **{
            f"xml.dom.minidom.{method}": "defusedxml.minidom"
            for method in DOM_METHODS
        },
        **{
            f"xml.dom.pulldom.{method}": "defusedxml.pulldom"
            for method in DOM_METHODS
        },
    }

    def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        """
        Run ``DefusedXmlModifier`` over ``tree`` and return the result.

        The changes collected by the modifier are appended to
        ``file_context.codemod_changes``; when there is at least one, removal
        of unused imports is requested for the module.
        """
        modifier = DefusedXmlModifier(
            self.context,
            self.file_context,
            self.matching_functions,
            self.CHANGE_DESCRIPTION,
        )
        transformed = modifier.transform_module(tree)
        recorded = modifier.changes_in_file
        self.file_context.codemod_changes.extend(recorded)
        if recorded:
            RemoveImportsVisitor.remove_unused_import_by_node(self.context, tree)
            # TODO: add dependency

        return transformed
18 changes: 10 additions & 8 deletions tests/codemods/base_codemod_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# pylint: disable=no-member,not-callable,attribute-defined-outside-init
from collections import defaultdict
import os
from pathlib import Path
from textwrap import dedent
from typing import ClassVar

import libcst as cst
from libcst.codemod import CodemodContext
from pathlib import Path
import os
from collections import defaultdict
import mock

from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext
from codemodder.registry import CodemodRegistry, CodemodCollection
from codemodder.semgrep import run as semgrep_run
from typing import ClassVar

import mock


class BaseCodemodTest:
Expand All @@ -24,7 +26,7 @@ def run_and_assert(self, tmpdir, input_code, expected):
self.run_and_assert_filepath(tmpdir, tmp_file_path, input_code, expected)

def run_and_assert_filepath(self, root, file_path, input_code, expected):
input_tree = cst.parse_module(input_code)
input_tree = cst.parse_module(dedent(input_code))
self.execution_context = CodemodExecutionContext(
directory=root,
dry_run=True,
Expand All @@ -44,7 +46,7 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected):
)
output_tree = command_instance.transform_module(input_tree)

assert output_tree.code == expected
assert output_tree.code == dedent(expected)


class BaseSemgrepCodemodTest(BaseCodemodTest):
Expand Down
1 change: 1 addition & 0 deletions tests/codemods/test_base_codemod.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import libcst as cst
from libcst.codemod import Codemod, CodemodContext
import mock

from codemodder.codemods.base_codemod import (
SemgrepCodemod,
CodemodMetadata,
Expand Down
Loading

0 comments on commit 0d0cca6

Please sign in to comment.