Skip to content

Commit

Permalink
harden-pyyaml can detect inheriting unsafe loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Jan 9, 2024
1 parent 48a7831 commit 993c151
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 15 deletions.
5 changes: 5 additions & 0 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,8 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):

def leave_Assign(self, original_node, updated_node):
    """Hand an assignment node being left to ``_new_or_updated_node``.

    Returns whatever ``_new_or_updated_node`` returns for the
    ``(original_node, updated_node)`` pair.
    """
    result = self._new_or_updated_node(original_node, updated_node)
    return result

def leave_ClassDef(
    self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
    """Hand a class definition being left to ``_new_or_updated_node``.

    Returns whatever ``_new_or_updated_node`` returns for the
    ``(original_node, updated_node)`` pair.
    """
    result = self._new_or_updated_node(original_node, updated_node)
    return result
80 changes: 67 additions & 13 deletions src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union
import libcst as cst
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin
Expand All @@ -6,8 +8,8 @@
class HardenPyyaml(SemgrepCodemod, NameResolutionMixin):
NAME = "harden-pyyaml"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
SUMMARY = "Use SafeLoader in `yaml.load()` Calls"
DESCRIPTION = "Ensures all calls to yaml.load use `SafeLoader`."
SUMMARY = "Replace unsafe `pyyaml` loader with `SafeLoader`"
DESCRIPTION = "Replace unsafe `pyyaml` loader with `SafeLoader` in calls to `yaml.load` or custom loader classes."
REFERENCES = [
{
"url": "https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data",
Expand Down Expand Up @@ -50,18 +52,70 @@ def rule(cls):
- pattern: yaml.BaseLoader
- pattern: yaml.FullLoader
- pattern: yaml.UnsafeLoader
- patterns:
- pattern: |
class $X(...,$LOADER, ...):
...
- metavariable-pattern:
metavariable: $LOADER
patterns:
- pattern-either:
- pattern: yaml.Loader
- pattern: yaml.BaseLoader
- pattern: yaml.FullLoader
- pattern: yaml.UnsafeLoader
"""

def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
self.add_needed_import(self._module_name)
new_args = [
*updated_node.args[:1],
updated_node.args[1].with_changes(
value=self.parse_expression(f"{maybe_name}.SafeLoader")
),
def on_result_found(
    self,
    original_node: Union[cst.Call, cst.ClassDef],
    updated_node: Union[cst.Call, cst.ClassDef],
):
    """Rewrite a matched node so that it refers to ``SafeLoader``.

    * ``cst.Call``: the value of the second argument is replaced with
      ``<prefix>.SafeLoader``, where ``<prefix>`` is the name returned by
      ``get_aliased_prefix_name`` or, failing that, ``self._module_name``.
      When the prefix equals ``self._module_name`` an import of that module
      is requested through ``add_needed_import``.
    * ``cst.ClassDef``: ``bases`` is replaced with the list produced by
      ``_update_bases(original_node)``.
    * Any other node is returned as ``updated_node`` without changes.
    """
    # TODO: provide different change description for each case.
    if isinstance(original_node, cst.Call):
        maybe_name = (
            self.get_aliased_prefix_name(original_node, self._module_name)
            or self._module_name
        )
        if maybe_name == self._module_name:
            self.add_needed_import(self._module_name)

        # Keep the first argument untouched; only the loader argument
        # (position 1, positional or keyword) has its value swapped.
        leading_args = updated_node.args[:1]
        loader_arg = updated_node.args[1]
        hardened_arg = loader_arg.with_changes(
            value=self.parse_expression(f"{maybe_name}.SafeLoader")
        )
        return self.update_arg_target(updated_node, [*leading_args, hardened_arg])

    if isinstance(original_node, cst.ClassDef):
        safe_bases = self._update_bases(original_node)
        return updated_node.with_changes(bases=safe_bases)

    return updated_node

def _update_bases(self, original_node: cst.ClassDef) -> list[cst.Arg]:
new = []
base_names = [
f"yaml.{klas}"
for klas in ("UnsafeLoader", "Loader", "BaseLoader", "FullLoader")
]
return self.update_arg_target(updated_node, new_args)
for base_arg in original_node.bases:
base_name = self.find_base_name(base_arg.value)
if base_name not in base_names:
new.append(base_arg)
continue

match base_arg.value:
case cst.Name():
self.add_needed_import(self._module_name, "SafeLoader")
self.remove_unused_import(base_arg.value)
base_arg = base_arg.with_changes(
value=base_arg.value.with_changes(value="SafeLoader")
)
case cst.Attribute():
base_arg = base_arg.with_changes(
value=base_arg.value.with_changes(attr=cst.Name("SafeLoader"))
)
new.append(base_arg)
return new
123 changes: 121 additions & 2 deletions tests/codemods/test_harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

UNSAFE_LOADERS = yaml.loader.__all__.copy() # type: ignore
UNSAFE_LOADERS.remove("SafeLoader")
loaders = pytest.mark.parametrize("loader", UNSAFE_LOADERS)


class TestHardenPyyaml(BaseSemgrepCodemodTest):
Expand All @@ -19,8 +20,9 @@ def test_safe_loader(self, tmpdir):
deserialized_data = yaml.load(data, Loader=yaml.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, input_code)
assert len(self.file_context.codemod_changes) == 0

@pytest.mark.parametrize("loader", UNSAFE_LOADERS)
@loaders
def test_all_unsafe_loaders_arg(self, tmpdir, loader):
input_code = f"""import yaml
data = b'!!python/object/apply:subprocess.Popen \\n- ls'
Expand All @@ -32,8 +34,9 @@ def test_all_unsafe_loaders_arg(self, tmpdir, loader):
deserialized_data = yaml.load(data, yaml.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

@pytest.mark.parametrize("loader", UNSAFE_LOADERS)
@loaders
def test_all_unsafe_loaders_kwarg(self, tmpdir, loader):
input_code = f"""import yaml
data = b'!!python/object/apply:subprocess.Popen \\n- ls'
Expand All @@ -45,6 +48,7 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader):
deserialized_data = yaml.load(data, Loader=yaml.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

def test_import_alias(self, tmpdir):
input_code = """import yaml as yam
Expand All @@ -60,6 +64,7 @@ def test_import_alias(self, tmpdir):
deserialized_data = yam.load(data, Loader=yam.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

def test_preserve_custom_loader(self, tmpdir):
expected = input_code = """
Expand All @@ -70,6 +75,7 @@ def test_preserve_custom_loader(self, tmpdir):
"""

self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 0

def test_preserve_custom_loader_kwarg(self, tmpdir):
expected = input_code = """
Expand All @@ -80,3 +86,116 @@ def test_preserve_custom_loader_kwarg(self, tmpdir):
"""

self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 0


# Tests for the class-inheritance case of HardenPyyaml: each test feeds a
# snippet through run_and_assert and checks how many changes were recorded.
class TestHardenPyyamlClassInherit(BaseSemgrepCodemodTest):
codemod = HardenPyyaml

# A class that already derives from yaml.SafeLoader must come out unchanged
# and record no codemod changes.
def test_safe_loader(self, tmpdir):
input_code = """\
import yaml
class MyCustomLoader(yaml.SafeLoader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""

self.run_and_assert(tmpdir, input_code, input_code)
assert len(self.file_context.codemod_changes) == 0

# Parametrized over UNSAFE_LOADERS: every yaml.<loader> base is expected to
# become yaml.SafeLoader, with exactly one recorded change.
@loaders
def test_unsafe_loaders(self, tmpdir, loader):
input_code = f"""\
import yaml
class MyCustomLoader(yaml.{loader}):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
expected = """\
import yaml
class MyCustomLoader(yaml.SafeLoader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

# Loader brought in with `from yaml import ...`: both the import line and
# the bare base name are expected to switch to SafeLoader.
def test_from_import(self, tmpdir):
input_code = """\
from yaml import UnsafeLoader
class MyCustomLoader(UnsafeLoader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
expected = """\
from yaml import SafeLoader
class MyCustomLoader(SafeLoader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

# Module imported under an alias (`import yaml as yam`): the alias prefix is
# expected to be kept and only the attribute replaced.
def test_import_alias(self, tmpdir):
input_code = """\
import yaml as yam
class MyCustomLoader(yam.UnsafeLoader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
expected = """\
import yaml as yam
class MyCustomLoader(yam.SafeLoader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

# Several bases, one of them an unrelated `Loader` from another package:
# only the yaml-derived base is expected to change.
def test_multiple_bases(self, tmpdir):
input_code = """\
from abc import ABC
import yaml as yam
from whatever import Loader
class MyCustomLoader(ABC, yam.UnsafeLoader, Loader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
expected = """\
from abc import ABC
import yaml as yam
from whatever import Loader
class MyCustomLoader(ABC, yam.SafeLoader, Loader):
def __init__(self, *args, **kwargs):
super(MyCustomLoader, self).__init__(*args, **kwargs)
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

# The name `yaml` is bound to a different module (`import whatever as yaml`):
# `yaml.Loader` is expected to stay, while the real UnsafeLoader imported
# from yaml is replaced.
def test_different_yaml(self, tmpdir):
input_code = """\
from yaml import UnsafeLoader
import whatever as yaml
class MyLoader(UnsafeLoader, yaml.Loader):
...
"""
expected = """\
from yaml import SafeLoader
import whatever as yaml
class MyLoader(SafeLoader, yaml.Loader):
...
"""
self.run_and_assert(tmpdir, input_code, expected)
assert len(self.file_context.codemod_changes) == 1

0 comments on commit 993c151

Please sign in to comment.