diff --git a/src/codemodder/codemods/api/__init__.py b/src/codemodder/codemods/api/__init__.py index 36b1e706c..0a05313bc 100644 --- a/src/codemodder/codemods/api/__init__.py +++ b/src/codemodder/codemods/api/__init__.py @@ -135,3 +135,8 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): def leave_Assign(self, original_node, updated_node): return self._new_or_updated_node(original_node, updated_node) + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + return self._new_or_updated_node(original_node, updated_node) diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index d35236102..b1001ca4d 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -1,3 +1,6 @@ +from typing import Union +import libcst as cst +from libcst import matchers from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin @@ -6,8 +9,8 @@ class HardenPyyaml(SemgrepCodemod, NameResolutionMixin): NAME = "harden-pyyaml" REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Use SafeLoader in `yaml.load()` Calls" - DESCRIPTION = "Ensures all calls to yaml.load use `SafeLoader`." + SUMMARY = "Replace unsafe `pyyaml` loader with `SafeLoader`" + DESCRIPTION = "Replace unsafe `pyyaml` loader with `SafeLoader` in calls to `yaml.load` or custom loader classes." REFERENCES = [ { "url": "https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data", @@ -50,18 +53,76 @@ def rule(cls): - pattern: yaml.BaseLoader - pattern: yaml.FullLoader - pattern: yaml.UnsafeLoader + - patterns: + - pattern: | + class $X(...,$LOADER, ...): + ... + - metavariable-pattern: + metavariable: $LOADER + patterns: + - pattern-either: + - pattern: yaml.Loader + - pattern: yaml.BaseLoader + - pattern: yaml.FullLoader + - pattern: yaml.UnsafeLoader """ - def on_result_found(self, original_node, updated_node): - maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) - maybe_name = maybe_name or self._module_name - if maybe_name == self._module_name: - self.add_needed_import(self._module_name) - new_args = [ - *updated_node.args[:1], - updated_node.args[1].with_changes( - value=self.parse_expression(f"{maybe_name}.SafeLoader") - ), + def on_result_found( + self, + original_node: Union[cst.Call, cst.ClassDef], + updated_node: Union[cst.Call, cst.ClassDef], + ): + # TODO: provide different change description for each case. + match original_node: + case cst.Call(): + maybe_name = self.get_aliased_prefix_name( + original_node, self._module_name + ) + maybe_name = maybe_name or self._module_name + if maybe_name == self._module_name: + self.add_needed_import(self._module_name) + new_args = [ + *updated_node.args[:1], + updated_node.args[1].with_changes( + value=self.parse_expression(f"{maybe_name}.SafeLoader") + ), + ] + return self.update_arg_target(updated_node, new_args) + case cst.ClassDef(): + return updated_node.with_changes( + bases=self._update_bases(original_node) + ) + return updated_node + + def _update_bases(self, original_node: cst.ClassDef) -> list[cst.Arg]: + new = [] + unsafe_name_matchers = ( + matchers.Name(value="UnsafeLoader") + | matchers.Name(value="Loader") + | matchers.Name(value="BaseLoader") + | matchers.Name(value="FullLoader") + ) + base_names = [ + f"yaml.{klas}" + for klas in ("UnsafeLoader", "Loader", "BaseLoader", "FullLoader") ] - return self.update_arg_target(updated_node, new_args) + for base_arg in original_node.bases: + base_name = self.find_base_name(base_arg.value) + if base_name not in base_names: + new.append(base_arg) + continue + + match base_arg.value: + case cst.Name(): + self.add_needed_import(self._module_name, "SafeLoader") + self.remove_unused_import(base_arg.value) + base_arg = base_arg.with_changes( + value=base_arg.value.with_changes(value="SafeLoader") + ) + case cst.Attribute(): + base_arg = base_arg.with_changes( + value=base_arg.value.with_changes(attr=cst.Name("SafeLoader")) + ) + new.append(base_arg) + return new diff --git a/tests/codemods/test_harden_pyyaml.py b/tests/codemods/test_harden_pyyaml.py index d5619b46e..1e30f2827 100644 --- a/tests/codemods/test_harden_pyyaml.py +++ b/tests/codemods/test_harden_pyyaml.py @@ -5,6 +5,7 @@ UNSAFE_LOADERS = yaml.loader.__all__.copy() # type: ignore UNSAFE_LOADERS.remove("SafeLoader") +loaders = pytest.mark.parametrize("loader", UNSAFE_LOADERS) class TestHardenPyyaml(BaseSemgrepCodemodTest): @@ -19,8 +20,9 @@ def test_safe_loader(self, tmpdir): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, input_code) + assert len(self.file_context.codemod_changes) == 0 - @pytest.mark.parametrize("loader", UNSAFE_LOADERS) + @loaders def test_all_unsafe_loaders_arg(self, tmpdir, loader): input_code = f"""import yaml data = b'!!python/object/apply:subprocess.Popen \\n- ls' @@ -32,8 +34,9 @@ def test_all_unsafe_loaders_arg(self, tmpdir, loader): deserialized_data = yaml.load(data, yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 - @pytest.mark.parametrize("loader", UNSAFE_LOADERS) + @loaders def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): input_code = f"""import yaml data = b'!!python/object/apply:subprocess.Popen \\n- ls' @@ -45,6 +48,7 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """import yaml as yam @@ -60,6 +64,7 @@ def test_import_alias(self, tmpdir): deserialized_data = yam.load(data, Loader=yam.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 def test_preserve_custom_loader(self, tmpdir): expected = input_code = """ @@ -70,6 +75,7 @@ def test_preserve_custom_loader(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 0 def test_preserve_custom_loader_kwarg(self, tmpdir): expected = input_code = """ @@ -80,3 +86,116 @@ def test_preserve_custom_loader_kwarg(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 0 + + +class TestHardenPyyamlClassInherit(BaseSemgrepCodemodTest): + codemod = HardenPyyaml + + def test_safe_loader(self, tmpdir): + input_code = """\ + import yaml + + class MyCustomLoader(yaml.SafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + + """ + + self.run_and_assert(tmpdir, input_code, input_code) + assert len(self.file_context.codemod_changes) == 0 + + @loaders + def test_unsafe_loaders(self, tmpdir, loader): + input_code = f"""\ + import yaml + + class MyCustomLoader(yaml.{loader}): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + expected = """\ + import yaml + + class MyCustomLoader(yaml.SafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 + + def test_from_import(self, tmpdir): + input_code = """\ + from yaml import UnsafeLoader + + class MyCustomLoader(UnsafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + expected = """\ + from yaml import SafeLoader + + class MyCustomLoader(SafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 + + def test_import_alias(self, tmpdir): + input_code = """\ + import yaml as yam + + class MyCustomLoader(yam.UnsafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + expected = """\ + import yaml as yam + + class MyCustomLoader(yam.SafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 + + def test_multiple_bases(self, tmpdir): + input_code = """\ + from abc import ABC + import yaml as yam + from whatever import Loader + + class MyCustomLoader(ABC, yam.UnsafeLoader, Loader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + expected = """\ + from abc import ABC + import yaml as yam + from whatever import Loader + + class MyCustomLoader(ABC, yam.SafeLoader, Loader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 + + def test_different_yaml(self, tmpdir): + input_code = """\ + from yaml import UnsafeLoader + import whatever as yaml + + class MyLoader(UnsafeLoader, yaml.Loader): + ... + """ + expected = """\ + from yaml import SafeLoader + import whatever as yaml + + class MyLoader(SafeLoader, yaml.Loader): + ... + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1