Skip to content

Commit

Permalink
lxml parser defaults codemod
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Oct 3, 2023
1 parent dfdbba0 commit 3b35355
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 2 deletions.
34 changes: 34 additions & 0 deletions src/codemodder/codemods/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,29 @@ def update_assign_rhs(self, updated_node: cst.Assign, rhs: str):
def parse_expression(self, expression: str):
return cst.parse_expression(expression)

def replace_args(
self,
original_node,
args_info,
):
new_args = []

for arg in original_node.args:
arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info)
if arg_name is not None:
new = self.make_new_arg(arg_name, replacement_val, arg)
del args_info[idx]
else:
new = arg
new_args.append(new)

for arg_name, replacement_val, add_if_missing in args_info:
if add_if_missing:
new = self.make_new_arg(arg_name, replacement_val)
new_args.append(new)

return new_args

def replace_arg(
self,
original_node,
Expand Down Expand Up @@ -91,3 +114,14 @@ def make_new_arg(self, name, value, existing_arg=None):
value=cst.parse_expression(value),
equal=equal,
)


def _match_with_existing_arg(arg, args_info):
"""
Given an `arg` and a list of arg info, determine if
any of the names in arg_info match the arg.
"""
for idx, (arg_name, replacement_val, add_if_missing) in enumerate(args_info):
if matchers.matches(arg.keyword, matchers.Name(arg_name)):
return arg_name, replacement_val, idx
return None, None, None
42 changes: 42 additions & 0 deletions src/core_codemods/lxml_safe_parser_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod


class LxmlSafeParserDefaults(SemgrepCodemod):
NAME = "safe-lxml-parser-defaults"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW
SUMMARY = "Enable all security checks in `lxml.etree.XMLParser` call."
DESCRIPTION = "...........TODO"

@classmethod
def rule(cls):
return """
rules:
- patterns:
- pattern: lxml.etree.$CLASS(...)
- pattern-not: lxml.etree.$CLASS(..., resolve_entities=False, ...)
- pattern-not: lxml.etree.$CLASS(..., no_network=True, ..., resolve_entities=False, ...)
- pattern-not: lxml.etree.$CLASS(..., dtd_validation=False, ..., resolve_entities=False, ...)
- metavariable-pattern:
metavariable: $CLASS
patterns:
- pattern-either:
- pattern: XMLParser
- pattern: ETCompatXMLParser
- pattern: XMLTreeBuilder
- pattern: XMLPullParser
- pattern-inside: |
import lxml
...
"""

def on_result_found(self, original_node, updated_node):
new_args = self.replace_args(
original_node,
[
("resolve_entities", "False", True),
("no_network", "True", False),
("dtd_validation", "False", False),
],
)
return self.update_arg_target(updated_node, new_args)
4 changes: 2 additions & 2 deletions tests/codemods/test_enable_jinja2_autoescape.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
class TestEnableJinja2Autoescape(BaseSemgrepCodemodTest):
codemod = EnableJinja2Autoescape

def test_rule_ids(self):
assert self.codemod.NAME == "enable-jinja2-autoescape"
def test_name(self):
assert self.codemod.name() == "enable-jinja2-autoescape"

def test_import(self, tmpdir):
input_code = """import jinja2
Expand Down
100 changes: 100 additions & 0 deletions tests/codemods/test_lxml_safe_parameter_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest
from core_codemods.lxml_safe_parser_defaults import LxmlSafeParserDefaults
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest

each_class = pytest.mark.parametrize(
"klass", ["XMLParser", "ETCompatXMLParser", "XMLTreeBuilder", "XMLPullParser"]
)


class TestJwtDecodeVerify(BaseSemgrepCodemodTest):
codemod = LxmlSafeParserDefaults

def test_name(self):
assert self.codemod.name() == "safe-lxml-parser-defaults"

@each_class
def test_import(self, tmpdir, klass):
input_code = f"""import lxml
parser = lxml.etree.{klass}()
var = "hello"
"""
expexted_output = f"""import lxml
parser = lxml.etree.{klass}(resolve_entities=False)
var = "hello"
"""

self.run_and_assert(tmpdir, input_code, expexted_output)

@each_class
def test_from_import(self, tmpdir, klass):
input_code = f"""from lxml.etree import {klass}
parser = {klass}()
var = "hello"
"""
expexted_output = f"""from lxml.etree import {klass}
parser = {klass}(resolve_entities=False)
var = "hello"
"""

self.run_and_assert(tmpdir, input_code, expexted_output)

@each_class
def test_import_alias(self, tmpdir, klass):
input_code = f"""from lxml.etree import {klass} as xmlklass
parser = xmlklass()
var = "hello"
"""
expexted_output = f"""from lxml.etree import {klass} as xmlklass
parser = xmlklass(resolve_entities=False)
var = "hello"
"""

self.run_and_assert(tmpdir, input_code, expexted_output)

@pytest.mark.parametrize(
"input_args,expected_args",
[
(
"resolve_entities=True",
"resolve_entities=False",
),
(
"resolve_entities=False",
"resolve_entities=False",
),
(
"""resolve_entities=True, no_network=False, dtd_validation=True""",
"""resolve_entities=False, no_network=True, dtd_validation=False""",
),
(
"""dtd_validation=True""",
"""dtd_validation=False, resolve_entities=False""",
),
(
"""no_network=False""",
"""no_network=True, resolve_entities=False""",
),
(
"""no_network=True""",
"""no_network=True, resolve_entities=False""",
),
],
)
@each_class
def test_verify_variations(self, tmpdir, klass, input_args, expected_args):
input_code = f"""import lxml
parser = lxml.etree.{klass}({input_args})
var = "hello"
"""
expexted_output = f"""import lxml
parser = lxml.etree.{klass}({expected_args})
var = "hello"
"""
self.run_and_assert(tmpdir, input_code, expexted_output)

0 comments on commit 3b35355

Please sign in to comment.