def _update_bases(self, original_node: cst.ClassDef) -> list[cst.Arg]:
    """Return the bases of ``original_node`` with unsafe loaders made safe.

    Two spellings of a base class are rewritten:

    * a bare name (``UnsafeLoader``, ``Loader``, ``BaseLoader`` or
      ``FullLoader``) is renamed to ``SafeLoader``; an import of
      ``SafeLoader`` from ``self._module_name`` is requested and the old
      name is handed to ``remove_unused_import``;
    * an attribute access whose final attribute is one of those names
      (``yaml.Loader``, ``yam.UnsafeLoader``, ...) has only that attribute
      replaced, so the module prefix or alias is kept as written.

    Every other base is returned untouched, in its original position.

    Args:
        original_node: The class definition whose bases are inspected.

    Returns:
        A new list of ``cst.Arg`` nodes, one per original base.
    """
    unsafe_loader_name = (
        matchers.Name(value="UnsafeLoader")
        | matchers.Name(value="Loader")
        | matchers.Name(value="BaseLoader")
        | matchers.Name(value="FullLoader")
    )
    # Match on the attribute name only. Matching every ``yaml.<anything>``
    # would also rewrite unrelated bases such as ``yaml.YAMLObject`` and
    # could emit the same base twice, which is a TypeError at class creation.
    unsafe_loader_attr = matchers.Attribute(attr=unsafe_loader_name)

    new_bases = []
    for base_arg in original_node.bases:
        base = base_arg.value
        if matchers.matches(base, unsafe_loader_name):
            # NOTE(review): a bare ``Loader`` is assumed to come from the
            # yaml module; confirm the upstream detection guarantees this.
            self.add_needed_import(self._module_name, "SafeLoader")
            self.remove_unused_import(base)
            base_arg = base_arg.with_changes(
                value=base.with_changes(value="SafeLoader")
            )
        elif matchers.matches(base, unsafe_loader_attr):
            base_arg = base_arg.with_changes(
                value=base.with_changes(attr=cst.Name("SafeLoader"))
            )
        new_bases.append(base_arg)
    return new_bases
from core_codemods.harden_pyyaml import HardenPyyaml @@ -21,6 +20,7 @@ def test_safe_loader(self, tmpdir): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, input_code) + assert len(self.file_context.codemod_changes) == 0 @loaders def test_all_unsafe_loaders_arg(self, tmpdir, loader): @@ -34,6 +34,7 @@ def test_all_unsafe_loaders_arg(self, tmpdir, loader): deserialized_data = yaml.load(data, yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 @loaders def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): @@ -47,6 +48,7 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """import yaml as yam @@ -62,6 +64,7 @@ def test_import_alias(self, tmpdir): deserialized_data = yam.load(data, Loader=yam.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 def test_preserve_custom_loader(self, tmpdir): expected = input_code = """ @@ -72,6 +75,7 @@ def test_preserve_custom_loader(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 0 def test_preserve_custom_loader_kwarg(self, tmpdir): expected = input_code = """ @@ -82,6 +86,7 @@ def test_preserve_custom_loader_kwarg(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 0 class TestHardenPyyamlClassInherit(BaseSemgrepCodemodTest): @@ -97,7 +102,8 @@ def __init__(self, *args, **kwargs): """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + self.run_and_assert(tmpdir, input_code, input_code) + assert len(self.file_context.codemod_changes) == 0 @loaders def 
test_unsafe_loaders(self, tmpdir, loader): @@ -108,29 +114,68 @@ class MyCustomLoader(yaml.{loader}): def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ - expected = f"""\ + expected = """\ import yaml class MyCustomLoader(yaml.SafeLoader): def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 + + def test_from_import(self, tmpdir): + input_code = """\ + from yaml import UnsafeLoader + + class MyCustomLoader(UnsafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + expected = """\ + from yaml import SafeLoader + + class MyCustomLoader(SafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): - input_code = """import yaml as yam -from yaml import Loader + input_code = """\ + import yaml as yam -data = b'!!python/object/apply:subprocess.Popen \\n- ls' -deserialized_data = yam.load(data, Loader=Loader) -""" - expected = """import yaml as yam -from yaml import Loader + class MyCustomLoader(yam.UnsafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + expected = """\ + import yaml as yam -data = b'!!python/object/apply:subprocess.Popen \\n- ls' -deserialized_data = yam.load(data, Loader=yam.SafeLoader) -""" + class MyCustomLoader(yam.SafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 - # todo test_import_alias - # todo: multiple class 
inheritance + def test_multiple_bases(self, tmpdir): + input_code = """\ + from abc import ABC + import yaml as yam + + class MyCustomLoader(ABC, yam.UnsafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + expected = """\ + from abc import ABC + import yaml as yam + + class MyCustomLoader(ABC, yam.SafeLoader): + def __init__(self, *args, **kwargs): + super(MyCustomLoader, self).__init__(*args, **kwargs) + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1