-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement codemod for replacing unsafe xml methods with defusedxml #76
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import abc | ||
from typing import Generic, Mapping, Sequence, Set, TypeVar, Union | ||
|
||
import libcst as cst | ||
from libcst import matchers | ||
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand | ||
from libcst.metadata import PositionProvider | ||
|
||
from codemodder.change import Change | ||
from codemodder.codemods.utils_mixin import NameResolutionMixin | ||
from codemodder.file_context import FileContext | ||
|
||
|
||
# It seems to me like we actually want two separate bounds instead of a Union but this is what mypy wants | ||
FunctionMatchType = TypeVar("FunctionMatchType", bound=Union[Mapping, Set]) | ||
|
||
|
||
class ImportedCallModifier( | ||
Generic[FunctionMatchType], | ||
VisitorBasedCodemodCommand, | ||
NameResolutionMixin, | ||
metaclass=abc.ABCMeta, | ||
): | ||
METADATA_DEPENDENCIES = (PositionProvider,) | ||
|
||
def __init__( | ||
self, | ||
codemod_context: CodemodContext, | ||
file_context: FileContext, | ||
matching_functions: FunctionMatchType, | ||
change_description: str, | ||
): | ||
super().__init__(codemod_context) | ||
self.line_exclude = file_context.line_exclude | ||
self.line_include = file_context.line_include | ||
self.matching_functions: FunctionMatchType = matching_functions | ||
self.change_description = change_description | ||
self.changes_in_file: list[Mapping] = [] | ||
|
||
def updated_args(self, original_args: Sequence[cst.Arg]): | ||
return original_args | ||
|
||
@abc.abstractmethod | ||
def update_attribute( | ||
self, | ||
true_name: str, | ||
original_node: cst.Call, | ||
updated_node: cst.Call, | ||
new_args: Sequence[cst.Arg], | ||
): | ||
"""Callback to modify tree when the detected call is of the form a.call()""" | ||
|
||
@abc.abstractmethod | ||
def update_simple_name( | ||
self, | ||
true_name: str, | ||
original_node: cst.Call, | ||
updated_node: cst.Call, | ||
new_args: Sequence[cst.Arg], | ||
): | ||
"""Callback to modify tree when the detected call is of the form call()""" | ||
|
||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): | ||
pos_to_match = self.node_position(original_node) | ||
line_number = pos_to_match.start.line | ||
if self.filter_by_path_includes_or_excludes(pos_to_match): | ||
drdavella marked this conversation as resolved.
Show resolved
Hide resolved
|
||
true_name = self.find_base_name(original_node.func) | ||
if ( | ||
self._is_direct_call_from_imported_module(original_node) | ||
and true_name | ||
and true_name in self.matching_functions | ||
): | ||
self.changes_in_file.append( | ||
Change(str(line_number), self.change_description).to_json() | ||
) | ||
|
||
new_args = self.updated_args(updated_node.args) | ||
|
||
# has a prefix, e.g. a.call() -> a.new_call() | ||
if matchers.matches(original_node.func, matchers.Attribute()): | ||
return self.update_attribute( | ||
true_name, original_node, updated_node, new_args | ||
) | ||
|
||
# it is a simple name, e.g. call() -> module.new_call() | ||
return self.update_simple_name( | ||
true_name, original_node, updated_node, new_args | ||
) | ||
Comment on lines +85 to +88
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just a small clarification on this bit of code here. from lib import call
call() We have the option to simply change the call... from lib import call
import lib2
lib2.call2() or changing the call and import: from lib2 import call2
call2() The second has an advantage of not having a possibly useless import (not much of a problem since we have a codemod to amend that). Then again, the reason I favored the first one for the https codemod was that the change became more explicit. The second transformation is a bit harder to do but within the realm of possibility. Just tell me if you want me to support that. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @andrecsilva that makes sense. It might be good to revisit at some point but right now I like that the updated call becomes completely explicit. |
||
|
||
return updated_node | ||
|
||
def filter_by_path_includes_or_excludes(self, pos_to_match): | ||
""" | ||
Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags. | ||
""" | ||
# excludes takes precedence if defined | ||
if self.line_exclude: | ||
return not any(match_line(pos_to_match, line) for line in self.line_exclude) | ||
if self.line_include: | ||
return any(match_line(pos_to_match, line) for line in self.line_include) | ||
return True | ||
|
||
def node_position(self, node): | ||
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 | ||
return self.get_metadata(PositionProvider, node) | ||
|
||
|
||
def match_line(pos, line): | ||
return pos.start.line == line and pos.end.line == line | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,135 +1,82 @@ | ||
from typing import Sequence | ||
from libcst import matchers | ||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor | ||
from libcst.metadata import ( | ||
PositionProvider, | ||
) | ||
from codemodder.codemods.base_codemod import ( | ||
BaseCodemod, | ||
CodemodMetadata, | ||
ReviewGuidance, | ||
) | ||
from codemodder.change import Change | ||
from codemodder.codemods.utils_mixin import NameResolutionMixin | ||
from codemodder.file_context import FileContext | ||
from typing import Sequence, Set | ||
|
||
import libcst as cst | ||
from libcst.codemod import ( | ||
Codemod, | ||
CodemodContext, | ||
VisitorBasedCodemodCommand, | ||
) | ||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor | ||
from libcst.metadata import PositionProvider | ||
|
||
from codemodder.codemods.base_codemod import ReviewGuidance | ||
from codemodder.codemods.api import BaseCodemod | ||
from codemodder.codemods.imported_call_modifier import ImportedCallModifier | ||
|
||
class HTTPSConnection(BaseCodemod, Codemod): | ||
METADATA = CodemodMetadata( | ||
DESCRIPTION="Enforce HTTPS connection for `urllib3`.", | ||
NAME="https-connection", | ||
REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, | ||
REFERENCES=[ | ||
{ | ||
"url": "https://owasp.org/www-community/vulnerabilities/Insecure_Transport", | ||
"description": "", | ||
}, | ||
{ | ||
"url": "https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool", | ||
"description": "", | ||
}, | ||
], | ||
) | ||
CHANGE_DESCRIPTION = METADATA.DESCRIPTION | ||
SUMMARY = ( | ||
"Changes HTTPConnectionPool to HTTPSConnectionPool to Enforce Secure Connection" | ||
) | ||
|
||
class HTTPSConnectionModifier(ImportedCallModifier[Set[str]]): | ||
def updated_args(self, original_args): | ||
""" | ||
Last argument _proxy_config does not match new method | ||
|
||
We convert it to keyword | ||
""" | ||
new_args = list(original_args) | ||
if self.count_positional_args(new_args) == 10: | ||
new_args[9] = new_args[9].with_changes( | ||
keyword=cst.parse_expression("_proxy_config") | ||
) | ||
return new_args | ||
|
||
def update_attribute(self, true_name, original_node, updated_node, new_args): | ||
del true_name, original_node | ||
return updated_node.with_changes( | ||
args=new_args, | ||
func=updated_node.func.with_changes( | ||
attr=cst.Name(value="HTTPSConnectionPool") | ||
), | ||
) | ||
|
||
def update_simple_name(self, true_name, original_node, updated_node, new_args): | ||
del true_name | ||
AddImportsVisitor.add_needed_import(self.context, "urllib3") | ||
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) | ||
return updated_node.with_changes( | ||
args=new_args, | ||
func=cst.parse_expression("urllib3.HTTPSConnectionPool"), | ||
) | ||
|
||
def count_positional_args(self, arglist: Sequence[cst.Arg]) -> int: | ||
for idx, arg in enumerate(arglist): | ||
if arg.keyword: | ||
return idx | ||
return len(arglist) | ||
|
||
|
||
class HTTPSConnection(BaseCodemod): | ||
SUMMARY = "Enforce HTTPS Connection for `urllib3`" | ||
NAME = "https-connection" | ||
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW | ||
REFERENCES = [ | ||
{ | ||
"url": "https://owasp.org/www-community/vulnerabilities/Insecure_Transport", | ||
"description": "", | ||
}, | ||
{ | ||
"url": "https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool", | ||
"description": "", | ||
}, | ||
] | ||
|
||
METADATA_DEPENDENCIES = (PositionProvider,) | ||
|
||
matching_functions = { | ||
matching_functions: set[str] = { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You could have this be a keys-only dict. It would feel more consistent. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not aware of any such construct; do you just mean a dict where the values are |
||
"urllib3.HTTPConnectionPool", | ||
"urllib3.connectionpool.HTTPConnectionPool", | ||
} | ||
|
||
def __init__(self, codemod_context: CodemodContext, *codemod_args): | ||
Codemod.__init__(self, codemod_context) | ||
BaseCodemod.__init__(self, *codemod_args) | ||
|
||
def transform_module_impl(self, tree: cst.Module) -> cst.Module: | ||
visitor = ConnectionPollVisitor(self.context, self.file_context) | ||
visitor = HTTPSConnectionModifier( | ||
self.context, | ||
self.file_context, | ||
self.matching_functions, | ||
self.CHANGE_DESCRIPTION, | ||
) | ||
result_tree = visitor.transform_module(tree) | ||
self.file_context.codemod_changes.extend(visitor.changes_in_file) | ||
return result_tree | ||
|
||
|
||
class ConnectionPollVisitor(VisitorBasedCodemodCommand, NameResolutionMixin): | ||
METADATA_DEPENDENCIES = (PositionProvider,) | ||
|
||
def __init__(self, codemod_context: CodemodContext, file_context: FileContext): | ||
super().__init__(codemod_context) | ||
self.line_exclude = file_context.line_exclude | ||
self.line_include = file_context.line_include | ||
self.changes_in_file: list[Change] = [] | ||
|
||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): | ||
pos_to_match = self.node_position(original_node) | ||
line_number = pos_to_match.start.line | ||
if self.filter_by_path_includes_or_excludes(pos_to_match): | ||
true_name = self.find_base_name(original_node.func) | ||
if ( | ||
self._is_direct_call_from_imported_module(original_node) | ||
and true_name in HTTPSConnection.matching_functions | ||
): | ||
self.changes_in_file.append( | ||
Change( | ||
str(line_number), HTTPSConnection.CHANGE_DESCRIPTION | ||
).to_json() | ||
) | ||
# last argument _proxy_config does not match new method | ||
# we convert it to keyword | ||
new_args = list(original_node.args) | ||
if count_positional_args(original_node.args) == 10: | ||
new_args[9] = new_args[9].with_changes( | ||
keyword=cst.parse_expression("_proxy_config") | ||
) | ||
# has a prefix, e.g. a.call() -> a.new_call() | ||
if matchers.matches(original_node.func, matchers.Attribute()): | ||
return updated_node.with_changes( | ||
args=new_args, | ||
func=updated_node.func.with_changes( | ||
attr=cst.parse_expression("HTTPSConnectionPool") | ||
), | ||
) | ||
# it is a simple name, e.g. call() -> module.new_call() | ||
AddImportsVisitor.add_needed_import(self.context, "urllib3") | ||
RemoveImportsVisitor.remove_unused_import_by_node( | ||
self.context, original_node | ||
) | ||
return updated_node.with_changes( | ||
args=new_args, | ||
func=cst.parse_expression("urllib3.HTTPSConnectionPool"), | ||
) | ||
return updated_node | ||
|
||
def filter_by_path_includes_or_excludes(self, pos_to_match): | ||
""" | ||
Returns False if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags. | ||
""" | ||
# excludes takes precedence if defined | ||
if self.line_exclude: | ||
return not any(match_line(pos_to_match, line) for line in self.line_exclude) | ||
if self.line_include: | ||
return any(match_line(pos_to_match, line) for line in self.line_include) | ||
return True | ||
|
||
def node_position(self, node): | ||
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 | ||
return self.get_metadata(PositionProvider, node) | ||
|
||
|
||
def match_line(pos, line): | ||
return pos.start.line == line and pos.end.line == line | ||
|
||
|
||
def count_positional_args(arglist: Sequence[cst.Arg]) -> int: | ||
for i, arg in enumerate(arglist): | ||
if arg.keyword: | ||
return i | ||
return len(arglist) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this was mypy, right? Isn't it weird that we fix a typing problem by ... removing typing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I (wrongly) introduced this in a previous PR when `mypy` wasn't working for me. The underlying code itself is not typed correctly, and I thought adding this would fix it (but it does not).