Skip to content

Commit

Permalink
do not change random yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Jan 5, 2024
1 parent f1a9f37 commit afa3e4a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def _update_bases(self, original_node: cst.ClassDef) -> list[cst.Arg]:
| matchers.Name(value="FullLoader")
)
for base_arg in original_node.bases:
base_name = self.find_base_name(base_arg.value)
if "yaml" not in base_name:
new.append(base_arg)
continue

if matchers.matches(base_arg.value, unsafe_name_matchers):
self.add_needed_import(self._module_name, "SafeLoader")
self.remove_unused_import(base_arg.value)
Expand Down
24 changes: 22 additions & 2 deletions tests/codemods/test_harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,38 @@ def test_multiple_bases(self, tmpdir):
input_code = """\
from abc import ABC
import yaml as yam
from whatever import Loader
class MyCustomLoader(ABC, yam.UnsafeLoader):
class MyCustomLoader(ABC, yam.UnsafeLoader, Loader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
expected = """\
from abc import ABC
import yaml as yam
from whatever import Loader
class MyCustomLoader(ABC, yam.SafeLoader):
class MyCustomLoader(ABC, yam.SafeLoader, Loader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

def test_different_yaml(self, tmpdir):
input_code = """\
from yaml import UnsafeLoader
import whatever as yaml
class MyLoader(UnsafeLoader, yaml.Loader):
...
"""
expected = """\
from yaml import SafeLoader
import whatever as yaml
class MyLoader(SafeLoader, yaml.Loader):
...
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

0 comments on commit afa3e4a

Please sign in to comment.