-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement codemod for replacing xml with defusedxml
- Loading branch information
Showing
6 changed files
with
365 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import abc | ||
from typing import Sequence | ||
|
||
import libcst as cst | ||
from libcst import matchers | ||
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand | ||
from libcst.metadata import PositionProvider | ||
|
||
from codemodder.change import Change | ||
from codemodder.codemods.utils_mixin import NameResolutionMixin | ||
from codemodder.file_context import FileContext | ||
|
||
|
||
class ImportedCallModifier(
    VisitorBasedCodemodCommand, NameResolutionMixin, metaclass=abc.ABCMeta
):
    """Abstract LibCST visitor that rewrites calls to a configured set of functions.

    A call is rewritten when it passes the include/exclude line filter, is
    reported by ``_is_direct_call_from_imported_module`` and its resolved name
    (``find_base_name``) is one of the keys of ``matching_functions``.
    Both helpers are not defined in this class; presumably they come from
    ``NameResolutionMixin`` -- confirm.

    Subclasses must implement ``update_attribute`` and ``update_simple_name``
    and may override ``updated_args``.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(
        self,
        codemod_context: CodemodContext,
        file_context: FileContext,
        matching_functions: dict[str, str],
        change_description: str,
    ):
        """Store the line filters, the functions to match and the change text.

        Args:
            codemod_context: LibCST context forwarded to the base command.
            file_context: Source of the ``line_exclude`` and ``line_include``
                collections used to filter nodes by line.
            matching_functions: Mapping keyed by the resolved name of every
                function whose calls should be rewritten. This class only
                consults the keys; the values are left for subclasses.
            change_description: Text recorded with every reported change.
        """
        super().__init__(codemod_context)
        self.line_exclude = file_context.line_exclude
        self.line_include = file_context.line_include
        self.matching_functions = matching_functions
        self._matching_function_names = list(matching_functions.keys())
        self.change_description = change_description
        # NOTE(review): the entries appended in leave_Call are the result of
        # Change.to_json(), so the element type may not be Change -- confirm.
        self.changes_in_file: list[Change] = []

    def updated_args(self, original_args: Sequence[cst.Arg]):
        """Return the arguments for the rewritten call (unchanged by default)."""
        return original_args

    @abc.abstractmethod
    def update_attribute(self, true_name, original_node, updated_node, new_args):
        """Build the replacement for a call whose ``func`` is an attribute access.

        Invoked by ``leave_Call`` for calls such as ``a.call()``.
        """

    @abc.abstractmethod
    def update_simple_name(self, true_name, original_node, updated_node, new_args):
        """Build the replacement for a call whose ``func`` is not an attribute.

        Invoked by ``leave_Call`` for calls such as ``call()``.
        """

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
        """Rewrite a matching call and record the change; otherwise pass it through.

        Returns:
            The node produced by ``update_attribute`` or ``update_simple_name``
            for a matching call, or ``updated_node`` untouched otherwise.
        """
        pos_to_match = self.node_position(original_node)
        if not self.filter_by_path_includes_or_excludes(pos_to_match):
            return updated_node

        true_name = self.find_base_name(original_node.func)
        is_target = (
            self._is_direct_call_from_imported_module(original_node)
            and true_name in self._matching_function_names
        )
        if not is_target:
            return updated_node

        line_number = pos_to_match.start.line
        self.changes_in_file.append(
            Change(str(line_number), self.change_description).to_json()
        )

        new_args = self.updated_args(updated_node.args)

        # has a prefix, e.g. a.call() -> a.new_call()
        if matchers.matches(original_node.func, matchers.Attribute()):
            return self.update_attribute(
                true_name, original_node, updated_node, new_args
            )

        # it is a simple name, e.g. call() -> module.new_call()
        return self.update_simple_name(
            true_name, original_node, updated_node, new_args
        )

    def filter_by_path_includes_or_excludes(self, pos_to_match):
        """Return True when the node at ``pos_to_match`` should be processed.

        * If ``line_exclude`` is non-empty, the node is processed unless it
          matches one of those lines; ``line_include`` is not consulted.
        * Otherwise, if ``line_include`` is non-empty, the node is processed
          only when it matches one of those lines.
        * With neither configured, every node is processed.
        """
        # excludes takes precedence if defined
        if self.line_exclude:
            return not any(match_line(pos_to_match, line) for line in self.line_exclude)
        if self.line_include:
            return any(match_line(pos_to_match, line) for line in self.line_include)
        return True

    def node_position(self, node):
        """Return the ``PositionProvider`` metadata recorded for ``node``."""
        # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
        return self.get_metadata(PositionProvider, node)
|
||
|
||
def match_line(pos, line):
    """Return True when the node at ``pos`` both starts and ends on ``line``.

    ``pos`` must expose ``start.line`` and ``end.line``. A node that spans
    several lines never matches, even if it begins or ends on ``line``.
    """
    return pos.start.line == line and pos.end.line == line
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import libcst as cst | ||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor | ||
|
||
from codemodder.codemods.base_codemod import ReviewGuidance | ||
from codemodder.codemods.api import BaseCodemod | ||
from codemodder.codemods.imported_call_modifier import ImportedCallModifier | ||
|
||
|
||
class DefusedXmlModifier(ImportedCallModifier):
    """Rewrites matched XML calls so they go through the mapped replacement module.

    Attribute-style and simple-name calls are rewritten identically, so both
    hooks delegate to one private helper.
    """

    def _call_via_replacement_module(self, true_name, updated_node, new_args):
        """Return ``updated_node`` re-targeted at ``<replacement module>.<function>``.

        The replacement module is looked up in ``self.matching_functions`` by
        ``true_name`` and registered with ``AddImportsVisitor`` as a needed
        import. The function name is the last dotted component of
        ``true_name``.
        """
        import_name = self.matching_functions[true_name]
        AddImportsVisitor.add_needed_import(self.context, import_name)
        function_name = true_name.rpartition(".")[2]
        return updated_node.with_changes(
            args=new_args,
            func=cst.Attribute(
                value=cst.parse_expression(import_name),
                attr=cst.Name(value=function_name),
            ),
        )

    def update_attribute(self, true_name, _, updated_node, new_args):
        """Rewrite a call made through an attribute access; the original node is unused."""
        return self._call_via_replacement_module(true_name, updated_node, new_args)

    def update_simple_name(self, true_name, _, updated_node, new_args):
        """Rewrite a call made through a bare name; the original node is unused."""
        return self._call_via_replacement_module(true_name, updated_node, new_args)
|
||
|
||
# Function names that UseDefusedXml prefixes with "xml.etree.ElementTree." and
# "xml.etree.cElementTree." to build its matching table.
ETREE_METHODS = [
    "parse",
    "fromstring",
    "iterparse",
    "XMLParser",
]
# Function names prefixed with "xml.sax.".
SAX_METHODS = [
    "parse",
    "make_parser",
    "parseString",
]
# Function names prefixed with "xml.dom.minidom." and "xml.dom.pulldom.".
DOM_METHODS = [
    "parse",
    "parseString",
]
# TODO: add expat methods?
|
||
|
||
class UseDefusedXml(BaseCodemod):
    """Codemod that replaces standard-library XML parsing calls with defusedxml ones."""

    NAME = "use-defused-xml"
    SUMMARY = "Use defusedxml for parsing XML"
    REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW
    DESCRIPTION = "Use method from defusedxml"

    # Mapping from original function to corresponding defusedxml module.
    # Built from (stdlib module, defusedxml module, function names) rows; the
    # row order fixes the insertion order of the resulting dict.
    matching_functions: dict[str, str] = {
        f"{stdlib_module}.{method}": defused_module
        for stdlib_module, defused_module, methods in (
            ("xml.etree.ElementTree", "defusedxml.ElementTree", ETREE_METHODS),
            # defusedxml.cElementTree is deprecated so we use defusedxml.ElementTree
            ("xml.etree.cElementTree", "defusedxml.ElementTree", ETREE_METHODS),
            ("xml.sax", "defusedxml.sax", SAX_METHODS),
            ("xml.dom.minidom", "defusedxml.minidom", DOM_METHODS),
            ("xml.dom.pulldom", "defusedxml.pulldom", DOM_METHODS),
        )
        for method in methods
    }

    def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        """Run ``DefusedXmlModifier`` over ``tree`` and publish the changes it records.

        The recorded changes are appended to ``self.file_context.codemod_changes``.
        When at least one change was made, imports referenced by the original
        tree are scheduled for removal if they end up unused.

        Returns:
            The tree produced by the modifier.
        """
        # NOTE(review): CHANGE_DESCRIPTION is not defined in this class;
        # presumably it is supplied by BaseCodemod -- confirm.
        modifier = DefusedXmlModifier(
            self.context,
            self.file_context,
            self.matching_functions,
            self.CHANGE_DESCRIPTION,
        )
        result_tree = modifier.transform_module(tree)

        recorded_changes = modifier.changes_in_file
        self.file_context.codemod_changes.extend(recorded_changes)
        if recorded_changes:
            RemoveImportsVisitor.remove_unused_import_by_node(self.context, tree)
            # TODO: add dependency

        return result_tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.