diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index 00c7f2a68..84a450cd7 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -73,6 +73,7 @@ def on_result_found( original_node: Union[cst.Call, cst.ClassDef], updated_node: Union[cst.Call, cst.ClassDef], ): + # TODO: provide different change description for each case. match original_node: case cst.Call(): maybe_name = self.get_aliased_prefix_name( @@ -102,7 +103,16 @@ def _update_bases(self, original_node: cst.ClassDef) -> list[cst.Arg]: | matchers.Name(value="BaseLoader") | matchers.Name(value="FullLoader") ) + base_names = [ + f"yaml.{klas}" + for klas in ("UnsafeLoader", "Loader", "BaseLoader", "FullLoader") + ] for base_arg in original_node.bases: + base_name = self.find_base_name(base_arg.value) + if base_name not in base_names: + new.append(base_arg) + continue + if matchers.matches(base_arg.value, unsafe_name_matchers): self.add_needed_import(self._module_name, "SafeLoader") self.remove_unused_import(base_arg.value) diff --git a/tests/codemods/test_harden_pyyaml.py b/tests/codemods/test_harden_pyyaml.py index da0a0bb6d..1e30f2827 100644 --- a/tests/codemods/test_harden_pyyaml.py +++ b/tests/codemods/test_harden_pyyaml.py @@ -164,18 +164,38 @@ def test_multiple_bases(self, tmpdir): input_code = """\ from abc import ABC import yaml as yam + from whatever import Loader - class MyCustomLoader(ABC, yam.UnsafeLoader): + class MyCustomLoader(ABC, yam.UnsafeLoader, Loader): def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ expected = """\ from abc import ABC import yaml as yam + from whatever import Loader - class MyCustomLoader(ABC, yam.SafeLoader): + class MyCustomLoader(ABC, yam.SafeLoader, Loader): def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 + + def test_different_yaml(self, tmpdir): + input_code = """\ + from yaml import UnsafeLoader + import whatever as yaml + + class MyLoader(UnsafeLoader, yaml.Loader): + ... + """ + expected = """\ + from yaml import SafeLoader + import whatever as yaml + + class MyLoader(SafeLoader, yaml.Loader): + ... + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1