diff --git a/integration_tests/base_test.py b/integration_tests/base_test.py index 7de37135..c3922f1b 100644 --- a/integration_tests/base_test.py +++ b/integration_tests/base_test.py @@ -57,11 +57,16 @@ def setup_class(cls): cls.codemod_registry = registry.load_registered_codemods() def setup_method(self): - self.codemod_wrapper = [ - cmod - for cmod in self.codemod_registry._codemods - if cmod.codemod == self.codemod - ][0] + try: + self.codemod_wrapper = [ + cmod + for cmod in self.codemod_registry._codemods + if cmod.codemod == self.codemod + ][0] + except IndexError as exc: + raise IndexError( + "You must register the codemod to a CodemodCollection." + ) from exc def _assert_run_fields(self, run, output_path): assert run["vendor"] == "pixee" diff --git a/integration_tests/test_lxml_safe_parser_defaults.py b/integration_tests/test_lxml_safe_parser_defaults.py new file mode 100644 index 00000000..10ad552a --- /dev/null +++ b/integration_tests/test_lxml_safe_parser_defaults.py @@ -0,0 +1,16 @@ +from core_codemods.lxml_safe_parser_defaults import LxmlSafeParserDefaults +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestLxmlSafeParserDefaults(BaseIntegrationTest): + codemod = LxmlSafeParserDefaults + code_path = "tests/samples/lxml_parser.py" + original_code, expected_new_code = original_and_expected_from_code_path( + code_path, [(1, "parser = lxml.etree.XMLParser(resolve_entities=False)\n")] + ) + expected_diff = "--- \n+++ \n@@ -1,2 +1,2 @@\n import lxml\n-parser = lxml.etree.XMLParser()\n+parser = lxml.etree.XMLParser(resolve_entities=False)\n" + expected_line_change = "2" + change_description = LxmlSafeParserDefaults.CHANGE_DESCRIPTION diff --git a/src/codemodder/codemods/api/helpers.py b/src/codemodder/codemods/api/helpers.py index 15cab470..93073810 100644 --- a/src/codemodder/codemods/api/helpers.py +++ b/src/codemodder/codemods/api/helpers.py @@ -1,8 +1,11 @@ +from collections import namedtuple import libcst as cst from libcst import matchers from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor from codemodder.codemods.utils import get_call_name +NewArg = namedtuple("NewArg", ["name", "value", "add_if_missing"]) + class Helpers: def remove_unused_import(self, original_node): @@ -47,34 +50,34 @@ def update_assign_rhs(self, updated_node: cst.Assign, rhs: str): def parse_expression(self, expression: str): return cst.parse_expression(expression) - def replace_arg( - self, - original_node, - target_arg_name, - target_arg_replacement_val, - add_if_missing=False, - ): - """Given a node, return its args with one arg's value changed. + def replace_args(self, original_node, args_info): + """ + Iterate over the args in original_node and replace each arg + with any matching arg in `args_info`. - If add_if_missing is True, then if target arg is not present, add it. + :param original_node: libcst node with args attribute. + :param list args_info: List of NewArg """ assert hasattr(original_node, "args") + assert all( + isinstance(arg, NewArg) for arg in args_info + ), "`args_info` must contain `NewArg` types." new_args = [] - arg_added = False for arg in original_node.args: - if matchers.matches(arg.keyword, matchers.Name(target_arg_name)): - new = self.make_new_arg( - target_arg_name, target_arg_replacement_val, arg - ) - arg_added = True + arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info) + if arg_name is not None: + new = self.make_new_arg(arg_name, replacement_val, arg) + del args_info[idx] else: new = arg new_args.append(new) - if add_if_missing and not arg_added: - new = self.make_new_arg(target_arg_name, target_arg_replacement_val) - new_args.append(new) + for arg_name, replacement_val, add_if_missing in args_info: + if add_if_missing: + new = self.make_new_arg(arg_name, replacement_val) + new_args.append(new) + return new_args def make_new_arg(self, name, value, existing_arg=None): @@ -91,3 +94,14 @@ def make_new_arg(self, name, value, existing_arg=None): value=cst.parse_expression(value), equal=equal, ) + + +def _match_with_existing_arg(arg, args_info): + """ + Given an `arg` and a list of arg info, determine if + any of the names in arg_info match the arg. + """ + for idx, (arg_name, replacement_val, _) in enumerate(args_info): + if matchers.matches(arg.keyword, matchers.Name(arg_name)): + return arg_name, replacement_val, idx + return None, None, None diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index 5b550662..eb1cff26 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -9,6 +9,7 @@ from .https_connection import HTTPSConnection from .jwt_decode_verify import JwtDecodeVerify from .limit_readline import LimitReadline +from .lxml_safe_parser_defaults import LxmlSafeParserDefaults from .order_imports import OrderImports from .process_creation_sandbox import ProcessSandbox from .remove_unnecessary_f_str import RemoveUnnecessaryFStr @@ -35,6 +36,7 @@ HTTPSConnection, JwtDecodeVerify, LimitReadline, + LxmlSafeParserDefaults, OrderImports, ProcessSandbox, RemoveUnnecessaryFStr, diff --git a/src/core_codemods/docs/pixee_python_safe-lxml-parser-defaults.md b/src/core_codemods/docs/pixee_python_safe-lxml-parser-defaults.md new file mode 100644 index 00000000..541145a2 --- /dev/null +++ b/src/core_codemods/docs/pixee_python_safe-lxml-parser-defaults.md @@ -0,0 +1,22 @@ +This codemod configures safe parameter values when initializing `lxml.etree.XMLParser`, `lxml.etree.ETCompatXMLParser`, +`lxml.etree.XMLTreeBuilder`, or `lxml.etree.XMLPullParser`. If parameters `resolve_entities`, `no_network`, +and `dtd_validation` are not set to safe values, your code may be vulnerable to entity expansion +attacks and external entity (XXE) attacks. + +Parameters `no_network` and `dtd_validation` have safe default values of `True` and `False`, respectively, so this +codemod will set each to the default safe value if your code has assigned either to an unsafe value. + +Parameter `resolve_entities` has an unsafe default value of `True`. This codemod will set `resolve_entities=False` if set to `True` or omitted. + +The changes look as follows: + +```diff + import lxml + +- parser = lxml.etree.XMLParser() +- parser = lxml.etree.XMLParser(resolve_entities=True) +- parser = lxml.etree.XMLParser(resolve_entities=True, no_network=False, dtd_validation=True) ++ parser = lxml.etree.XMLParser(resolve_entities=False) ++ parser = lxml.etree.XMLParser(resolve_entities=False) ++ parser = lxml.etree.XMLParser(resolve_entities=False, no_network=True, dtd_validation=False) +``` diff --git a/src/core_codemods/enable_jinja2_autoescape.py b/src/core_codemods/enable_jinja2_autoescape.py index 36e5f14a..8e8fe798 100644 --- a/src/core_codemods/enable_jinja2_autoescape.py +++ b/src/core_codemods/enable_jinja2_autoescape.py @@ -1,5 +1,6 @@ from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.api.helpers import NewArg class EnableJinja2Autoescape(SemgrepCodemod): @@ -21,7 +22,8 @@ def rule(cls): """ def on_result_found(self, original_node, updated_node): - new_args = self.replace_arg( - original_node, "autoescape", "True", add_if_missing=True + new_args = self.replace_args( + original_node, + [NewArg(name="autoescape", value="True", add_if_missing=True)], ) return self.update_arg_target(updated_node, new_args) diff --git a/src/core_codemods/harden_ruamel.py b/src/core_codemods/harden_ruamel.py index 2b12b552..86c9a7c3 100644 --- a/src/core_codemods/harden_ruamel.py +++ b/src/core_codemods/harden_ruamel.py @@ -1,5 +1,6 @@ from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.api.helpers import NewArg class HardenRuamel(SemgrepCodemod): @@ -27,5 +28,7 @@ def rule(cls): """ def on_result_found(self, original_node, updated_node): - new_args = self.replace_arg(original_node, "typ", '"safe"') + new_args = self.replace_args( + original_node, [NewArg(name="typ", value='"safe"', add_if_missing=False)] + ) return self.update_arg_target(updated_node, new_args) diff --git a/src/core_codemods/jwt_decode_verify.py b/src/core_codemods/jwt_decode_verify.py index ae328008..f1e4f145 100644 --- a/src/core_codemods/jwt_decode_verify.py +++ b/src/core_codemods/jwt_decode_verify.py @@ -2,6 +2,7 @@ from libcst import matchers from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.api.helpers import NewArg class JwtDecodeVerify(SemgrepCodemod): @@ -64,18 +65,14 @@ def replace_options_arg(self, node_args): new_args.append(new) return new_args - def replace_arg( - self, - original_node, - target_arg_name, - target_arg_replacement_val, - add_if_missing=False, - ): - new_args = super().replace_arg(original_node, "verify", "True") + def replace_args(self, original_node, args_info): + new_args = super().replace_args(original_node, args_info) return self.replace_options_arg(new_args) def on_result_found(self, original_node, updated_node): - new_args = self.replace_arg(original_node, "verify", "True") + new_args = self.replace_args( + original_node, [NewArg(name="verify", value="True", add_if_missing=False)] + ) return self.update_arg_target(updated_node, new_args) diff --git a/src/core_codemods/lxml_safe_parser_defaults.py b/src/core_codemods/lxml_safe_parser_defaults.py new file mode 100644 index 00000000..e7b4dc2e --- /dev/null +++ b/src/core_codemods/lxml_safe_parser_defaults.py @@ -0,0 +1,43 @@ +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.api.helpers import NewArg + + +class LxmlSafeParserDefaults(SemgrepCodemod): + NAME = "safe-lxml-parser-defaults" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW + SUMMARY = "Use safe defaults for lxml parsers" + DESCRIPTION = "Replace lxml parser parameters with safe defaults" + + @classmethod + def rule(cls): + return """ + rules: + - patterns: + - pattern: lxml.etree.$CLASS(...) + - pattern-not: lxml.etree.$CLASS(..., resolve_entities=False, ...) + - pattern-not: lxml.etree.$CLASS(..., no_network=True, ..., resolve_entities=False, ...) + - pattern-not: lxml.etree.$CLASS(..., dtd_validation=False, ..., resolve_entities=False, ...) + - metavariable-pattern: + metavariable: $CLASS + patterns: + - pattern-either: + - pattern: XMLParser + - pattern: ETCompatXMLParser + - pattern: XMLTreeBuilder + - pattern: XMLPullParser + - pattern-inside: | + import lxml + ... + """ + + def on_result_found(self, original_node, updated_node): + new_args = self.replace_args( + original_node, + [ + NewArg(name="resolve_entities", value="False", add_if_missing=True), + NewArg(name="no_network", value="True", add_if_missing=False), + NewArg(name="dtd_validation", value="False", add_if_missing=False), + ], + ) + return self.update_arg_target(updated_node, new_args) diff --git a/src/core_codemods/requests_verify.py b/src/core_codemods/requests_verify.py index c8226f54..c01d4594 100644 --- a/src/core_codemods/requests_verify.py +++ b/src/core_codemods/requests_verify.py @@ -1,5 +1,6 @@ from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.api.helpers import NewArg class RequestsVerify(SemgrepCodemod): @@ -22,5 +23,7 @@ def rule(cls): """ def on_result_found(self, original_node, updated_node): - new_args = self.replace_arg(original_node, "verify", "True") + new_args = self.replace_args( + original_node, [NewArg(name="verify", value="True", add_if_missing=False)] + ) return self.update_arg_target(updated_node, new_args) diff --git a/tests/codemods/test_enable_jinja2_autoescape.py b/tests/codemods/test_enable_jinja2_autoescape.py index f821d919..055f537e 100644 --- a/tests/codemods/test_enable_jinja2_autoescape.py +++ b/tests/codemods/test_enable_jinja2_autoescape.py @@ -5,8 +5,8 @@ class TestEnableJinja2Autoescape(BaseSemgrepCodemodTest): codemod = EnableJinja2Autoescape - def test_rule_ids(self): - assert self.codemod.NAME == "enable-jinja2-autoescape" + def test_name(self): + assert self.codemod.name() == "enable-jinja2-autoescape" def test_import(self, tmpdir): input_code = """import jinja2 diff --git a/tests/codemods/test_lxml_safe_parameter_defaults.py b/tests/codemods/test_lxml_safe_parameter_defaults.py new file mode 100644 index 00000000..fd6dfba2 --- /dev/null +++ b/tests/codemods/test_lxml_safe_parameter_defaults.py @@ -0,0 +1,115 @@ +import pytest +from core_codemods.lxml_safe_parser_defaults import LxmlSafeParserDefaults +from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest + +each_class = pytest.mark.parametrize( + "klass", ["XMLParser", "ETCompatXMLParser", "XMLTreeBuilder", "XMLPullParser"] +) + + +class TestLxmlSafeParserDefaults(BaseSemgrepCodemodTest): + codemod = LxmlSafeParserDefaults + + def test_name(self): + assert self.codemod.name() == "safe-lxml-parser-defaults" + + @each_class + def test_import(self, tmpdir, klass): + input_code = f"""import lxml + +parser = lxml.etree.{klass}() +var = "hello" +""" + expexted_output = f"""import lxml + +parser = lxml.etree.{klass}(resolve_entities=False) +var = "hello" +""" + + self.run_and_assert(tmpdir, input_code, expexted_output) + + @each_class + def test_from_import(self, tmpdir, klass): + input_code = f"""from lxml.etree import {klass} + +parser = {klass}() +var = "hello" +""" + expexted_output = f"""from lxml.etree import {klass} + +parser = {klass}(resolve_entities=False) +var = "hello" +""" + + self.run_and_assert(tmpdir, input_code, expexted_output) + + @each_class + def test_from_import_module(self, tmpdir, klass): + input_code = f"""from lxml import etree + +parser = etree.{klass}() +var = "hello" +""" + expexted_output = f"""from lxml import etree + +parser = etree.{klass}(resolve_entities=False) +var = "hello" +""" + + self.run_and_assert(tmpdir, input_code, expexted_output) + + @each_class + def test_import_alias(self, tmpdir, klass): + input_code = f"""from lxml.etree import {klass} as xmlklass + +parser = xmlklass() +var = "hello" +""" + expexted_output = f"""from lxml.etree import {klass} as xmlklass + +parser = xmlklass(resolve_entities=False) +var = "hello" +""" + + self.run_and_assert(tmpdir, input_code, expexted_output) + + @pytest.mark.parametrize( + "input_args,expected_args", + [ + ( + "resolve_entities=True", + "resolve_entities=False", + ), + ( + "resolve_entities=False", + "resolve_entities=False", + ), + ( + """resolve_entities=True, no_network=False, dtd_validation=True""", + """resolve_entities=False, no_network=True, dtd_validation=False""", + ), + ( + """dtd_validation=True""", + """dtd_validation=False, resolve_entities=False""", + ), + ( + """no_network=False""", + """no_network=True, resolve_entities=False""", + ), + ( + """no_network=True""", + """no_network=True, resolve_entities=False""", + ), + ], + ) + @each_class + def test_verify_variations(self, tmpdir, klass, input_args, expected_args): + input_code = f"""import lxml +parser = lxml.etree.{klass}({input_args}) +var = "hello" +""" + expexted_output = f"""import lxml +parser = lxml.etree.{klass}({expected_args}) +var = "hello" +""" + self.run_and_assert(tmpdir, input_code, expexted_output) diff --git a/tests/samples/lxml_parser.py b/tests/samples/lxml_parser.py new file mode 100644 index 00000000..4458e9c0 --- /dev/null +++ b/tests/samples/lxml_parser.py @@ -0,0 +1,2 @@ +import lxml +parser = lxml.etree.XMLParser()