Skip to content

Commit

Permalink
Codemod Lxml parser defaults (#61)
Browse files Browse the repository at this point in the history
* lxml parser defaults codemod

* generalize add arg to add args

* add integration test

* add docs

* add namedtuple NewArg

* fix summary, desc, and add test
  • Loading branch information
clavedeluna authored Oct 4, 2023
1 parent 61beb53 commit 0790cc2
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 38 deletions.
15 changes: 10 additions & 5 deletions integration_tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,16 @@ def setup_class(cls):
cls.codemod_registry = registry.load_registered_codemods()

def setup_method(self):
    """Locate the registered wrapper for ``self.codemod`` before each test.

    Scans ``self.codemod_registry._codemods`` for the first entry whose
    ``codemod`` attribute equals ``self.codemod`` and stores it on
    ``self.codemod_wrapper``.

    :raises IndexError: if no registered entry wraps ``self.codemod``.
    """
    registered = (
        cmod
        for cmod in self.codemod_registry._codemods
        if cmod.codemod == self.codemod
    )
    try:
        # Only the first match is needed, so stop scanning as soon as it is found.
        self.codemod_wrapper = next(registered)
    except StopIteration as exc:
        # IndexError is kept as the public failure type for existing callers.
        raise IndexError(
            "You must register the codemod to a CodemodCollection."
        ) from exc

def _assert_run_fields(self, run, output_path):
assert run["vendor"] == "pixee"
Expand Down
16 changes: 16 additions & 0 deletions integration_tests/test_lxml_safe_parser_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from core_codemods.lxml_safe_parser_defaults import LxmlSafeParserDefaults
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)


class TestLxmlSafeParserDefaults(BaseIntegrationTest):
    """Integration test configuration for the ``LxmlSafeParserDefaults`` codemod."""

    codemod = LxmlSafeParserDefaults
    code_path = "tests/samples/lxml_parser.py"
    # NOTE(review): each pair looks like (zero-based line index, replacement
    # line) -- confirm against `original_and_expected_from_code_path`.
    original_code, expected_new_code = original_and_expected_from_code_path(
        code_path,
        [(1, "parser = lxml.etree.XMLParser(resolve_entities=False)\n")],
    )
    # Adjacent literals concatenate into the single-line diff text.
    expected_diff = (
        "--- \n"
        "+++ \n"
        "@@ -1,2 +1,2 @@\n"
        " import lxml\n"
        "-parser = lxml.etree.XMLParser()\n"
        "+parser = lxml.etree.XMLParser(resolve_entities=False)\n"
    )
    expected_line_change = "2"
    change_description = LxmlSafeParserDefaults.CHANGE_DESCRIPTION
50 changes: 32 additions & 18 deletions src/codemodder/codemods/api/helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections import namedtuple
import libcst as cst
from libcst import matchers
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from codemodder.codemods.utils import get_call_name

# One keyword argument that a codemod wants to enforce on a call:
#   name           -- keyword name looked up among the call's arguments
#   value          -- source text of the replacement expression
#   add_if_missing -- append the keyword when the call does not pass it
NewArg = namedtuple("NewArg", "name value add_if_missing")


class Helpers:
def remove_unused_import(self, original_node):
Expand Down Expand Up @@ -47,34 +50,34 @@ def update_assign_rhs(self, updated_node: cst.Assign, rhs: str):
def parse_expression(self, expression: str):
    """Parse ``expression`` with ``cst.parse_expression`` and return the result."""
    return cst.parse_expression(expression)

def replace_args(self, original_node, args_info):
    """
    Return the args of `original_node` with the values of matching keyword
    args replaced according to `args_info`.

    Each existing arg whose keyword matches the name of an entry in
    `args_info` is rebuilt with that entry's value. Entries that match no
    existing arg are appended at the end when their `add_if_missing` is
    True, and ignored otherwise. `args_info` itself is left unmodified.

    :param original_node: libcst node with args attribute.
    :param list args_info: List of NewArg
    :return: list of args, in the original order, followed by any added args.
    """
    assert hasattr(original_node, "args")
    assert all(
        isinstance(arg, NewArg) for arg in args_info
    ), "`args_info` must contain `NewArg` types."

    # Consume entries from a private copy so the caller's list is not
    # mutated (callers may reuse the same list across several nodes).
    pending = list(args_info)
    new_args = []

    for arg in original_node.args:
        arg_name, replacement_val, idx = _match_with_existing_arg(arg, pending)
        if arg_name is None:
            new_args.append(arg)
            continue
        new_args.append(self.make_new_arg(arg_name, replacement_val, arg))
        # Each entry is applied at most once; drop it so it is not re-added below.
        del pending[idx]

    # Whatever is still pending was absent from the call.
    new_args.extend(
        self.make_new_arg(arg_name, replacement_val)
        for arg_name, replacement_val, add_if_missing in pending
        if add_if_missing
    )
    return new_args

def make_new_arg(self, name, value, existing_arg=None):
Expand All @@ -91,3 +94,14 @@ def make_new_arg(self, name, value, existing_arg=None):
value=cst.parse_expression(value),
equal=equal,
)


def _match_with_existing_arg(arg, args_info):
    """
    Look up the first entry of `args_info` whose name matches the keyword
    of `arg`.

    :param arg: node exposing a `keyword` attribute.
    :param args_info: sequence of `(name, value, add_if_missing)` triples.
    :return: `(name, value, index)` for the first matching entry, or
        `(None, None, None)` when nothing matches.
    """
    # Lazy generator: entries are tested in order and evaluation stops at
    # the first hit, exactly like an early-returning loop.
    hits = (
        (name, value, position)
        for position, (name, value, _add_if_missing) in enumerate(args_info)
        if matchers.matches(arg.keyword, matchers.Name(name))
    )
    return next(hits, (None, None, None))
2 changes: 2 additions & 0 deletions src/core_codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .https_connection import HTTPSConnection
from .jwt_decode_verify import JwtDecodeVerify
from .limit_readline import LimitReadline
from .lxml_safe_parser_defaults import LxmlSafeParserDefaults
from .order_imports import OrderImports
from .process_creation_sandbox import ProcessSandbox
from .remove_unnecessary_f_str import RemoveUnnecessaryFStr
Expand All @@ -35,6 +36,7 @@
HTTPSConnection,
JwtDecodeVerify,
LimitReadline,
LxmlSafeParserDefaults,
OrderImports,
ProcessSandbox,
RemoveUnnecessaryFStr,
Expand Down
22 changes: 22 additions & 0 deletions src/core_codemods/docs/pixee_python_safe-lxml-parser-defaults.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
This codemod configures safe parameter values when initializing `lxml.etree.XMLParser`, `lxml.etree.ETCompatXMLParser`,
`lxml.etree.XMLTreeBuilder`, or `lxml.etree.XMLPullParser`. If parameters `resolve_entities`, `no_network`,
and `dtd_validation` are not set to safe values, your code may be vulnerable to entity expansion
attacks and external entity (XXE) attacks.

Parameters `no_network` and `dtd_validation` have safe default values of `True` and `False`, respectively, so this
codemod will set each to the default safe value if your code has assigned either to an unsafe value.

Parameter `resolve_entities` has an unsafe default value of `True`. This codemod will set `resolve_entities=False` if it is set to `True` or omitted.

The changes look as follows:

```diff
import lxml

- parser = lxml.etree.XMLParser()
- parser = lxml.etree.XMLParser(resolve_entities=True)
- parser = lxml.etree.XMLParser(resolve_entities=True, no_network=False, dtd_validation=True)
+ parser = lxml.etree.XMLParser(resolve_entities=False)
+ parser = lxml.etree.XMLParser(resolve_entities=False)
+ parser = lxml.etree.XMLParser(resolve_entities=False, no_network=True, dtd_validation=False)
```
6 changes: 4 additions & 2 deletions src/core_codemods/enable_jinja2_autoescape.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.api.helpers import NewArg


class EnableJinja2Autoescape(SemgrepCodemod):
Expand All @@ -21,7 +22,8 @@ def rule(cls):
"""

def on_result_found(self, original_node, updated_node):
    """Force ``autoescape=True`` on the matched call.

    An existing ``autoescape`` keyword has its value replaced; when the
    call does not pass one, it is requested via ``add_if_missing=True``.
    """
    autoescape = NewArg(name="autoescape", value="True", add_if_missing=True)
    new_args = self.replace_args(original_node, [autoescape])
    return self.update_arg_target(updated_node, new_args)
5 changes: 4 additions & 1 deletion src/core_codemods/harden_ruamel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.api.helpers import NewArg


class HardenRuamel(SemgrepCodemod):
Expand Down Expand Up @@ -27,5 +28,7 @@ def rule(cls):
"""

def on_result_found(self, original_node, updated_node):
    """Replace the value of an existing ``typ`` keyword with ``"safe"``.

    ``add_if_missing=False``: the keyword is not added when absent.
    """
    safe_typ = NewArg(name="typ", value='"safe"', add_if_missing=False)
    new_args = self.replace_args(original_node, [safe_typ])
    return self.update_arg_target(updated_node, new_args)
15 changes: 6 additions & 9 deletions src/core_codemods/jwt_decode_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from libcst import matchers
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.api.helpers import NewArg


class JwtDecodeVerify(SemgrepCodemod):
Expand Down Expand Up @@ -64,18 +65,14 @@ def replace_options_arg(self, node_args):
new_args.append(new)
return new_args

def replace_args(self, original_node, args_info):
    """Apply the inherited keyword replacement, then post-process the result.

    The args returned by the parent ``replace_args`` are passed through
    ``self.replace_options_arg`` and that result is returned.
    """
    new_args = super().replace_args(original_node, args_info)
    return self.replace_options_arg(new_args)

def on_result_found(self, original_node, updated_node):
    """Replace the value of an existing ``verify`` keyword with ``True``.

    ``add_if_missing=False``: the keyword is not added when absent.
    """
    verify = NewArg(name="verify", value="True", add_if_missing=False)
    new_args = self.replace_args(original_node, [verify])
    return self.update_arg_target(updated_node, new_args)


Expand Down
43 changes: 43 additions & 0 deletions src/core_codemods/lxml_safe_parser_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.api.helpers import NewArg


class LxmlSafeParserDefaults(SemgrepCodemod):
    """Codemod that sets safe keyword values on ``lxml.etree`` parser calls.

    ``resolve_entities`` is set to ``False`` (and requested even when the
    call omits it); ``no_network`` and ``dtd_validation`` are set to
    ``True`` / ``False`` only when the call already passes them.
    """

    NAME = "safe-lxml-parser-defaults"
    REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
    SUMMARY = "Use safe defaults for lxml parsers"
    DESCRIPTION = "Replace lxml parser parameters with safe defaults"

    @classmethod
    def rule(cls):
        """Return the semgrep rule text selecting the parser calls to rewrite."""
        # The YAML nesting below is significant: `metavariable` and its
        # `patterns` belong to `metavariable-pattern`, and the block scalar
        # after `pattern-inside: |` must be indented under its key.
        #
        # NOTE(review): the first `pattern-not` appears to subsume the two
        # that follow, so a call that already passes `resolve_entities=False`
        # is never matched even if `no_network`/`dtd_validation` are unsafe
        # -- confirm whether that is intended.
        return """
            rules:
              - patterns:
                - pattern: lxml.etree.$CLASS(...)
                - pattern-not: lxml.etree.$CLASS(..., resolve_entities=False, ...)
                - pattern-not: lxml.etree.$CLASS(..., no_network=True, ..., resolve_entities=False, ...)
                - pattern-not: lxml.etree.$CLASS(..., dtd_validation=False, ..., resolve_entities=False, ...)
                - metavariable-pattern:
                    metavariable: $CLASS
                    patterns:
                      - pattern-either:
                        - pattern: XMLParser
                        - pattern: ETCompatXMLParser
                        - pattern: XMLTreeBuilder
                        - pattern: XMLPullParser
                - pattern-inside: |
                    import lxml
                    ...
        """

    def on_result_found(self, original_node, updated_node):
        """Rewrite the matched call's keyword arguments to their safe values."""
        safe_values = [
            # Unsafe by default, so it is requested even when omitted.
            NewArg(name="resolve_entities", value="False", add_if_missing=True),
            # Only rewritten when the call passes them explicitly.
            NewArg(name="no_network", value="True", add_if_missing=False),
            NewArg(name="dtd_validation", value="False", add_if_missing=False),
        ]
        new_args = self.replace_args(original_node, safe_values)
        return self.update_arg_target(updated_node, new_args)
5 changes: 4 additions & 1 deletion src/core_codemods/requests_verify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.api.helpers import NewArg


class RequestsVerify(SemgrepCodemod):
Expand All @@ -22,5 +23,7 @@ def rule(cls):
"""

def on_result_found(self, original_node, updated_node):
    """Replace the value of an existing ``verify`` keyword with ``True``.

    ``add_if_missing=False``: the keyword is not added when absent.
    """
    verify = NewArg(name="verify", value="True", add_if_missing=False)
    new_args = self.replace_args(original_node, [verify])
    return self.update_arg_target(updated_node, new_args)
4 changes: 2 additions & 2 deletions tests/codemods/test_enable_jinja2_autoescape.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
class TestEnableJinja2Autoescape(BaseSemgrepCodemodTest):
codemod = EnableJinja2Autoescape

def test_name(self):
    """The codemod's ``name()`` must return its published identifier."""
    assert self.codemod.name() == "enable-jinja2-autoescape"

def test_import(self, tmpdir):
input_code = """import jinja2
Expand Down
Loading

0 comments on commit 0790cc2

Please sign in to comment.