Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alias detection solution for codemods #106

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import libcst as cst
from libcst import MetadataDependent, matchers
from libcst.helpers import get_full_name_for_node
from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider
from libcst.metadata import (
Assignment,
BaseAssignment,
ImportAssignment,
ScopeProvider,
)
from libcst.metadata.scope_provider import GlobalScope


Expand Down Expand Up @@ -58,6 +63,43 @@ def _is_direct_call_from_imported_module(
return (import_node, alias)
return None

def get_imported_prefix(
self, node
) -> Optional[tuple[Union[cst.Import, cst.ImportFrom], cst.ImportAlias]]:
"""
Given a node representing an access, finds if any part of its prefix is imported.
Returns a import and import alias pair.
"""
for nodo in iterate_left_expressions(node):
match nodo:
case cst.Name() | cst.Attribute():
maybe_assignment = self.find_single_assignment(nodo)
if maybe_assignment and isinstance(
maybe_assignment, ImportAssignment
):
import_node = maybe_assignment.node
for alias in import_node.names:
if maybe_assignment.name in (
alias.evaluated_alias,
alias.evaluated_name,
):
return (import_node, alias)
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

surprised pylint nor mypy didn't complain but all returns of a func should return the same len, so this should be (None, None)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type is annotated as Optional[Tuple[...]], that's the reason.
Returning (None, None) would mess with None checks since bool((None,None)) == True.


def get_aliased_prefix_name(self, node: cst.CSTNode, name: str) -> Optional[str]:
"""
Returns the alias of name if name is imported and used as a prefix for this node.
"""
maybe_import = self.get_imported_prefix(node)
maybe_name = None
if maybe_import and matchers.matches(maybe_import[0], matchers.Import()):
_, ia = maybe_import
imp_name = get_full_name_for_node(ia.name)
if imp_name == name and ia.asname:
# AsName is always a Name for ImportAlias
maybe_name = ia.asname.name.value
return maybe_name

def find_assignments(
self,
node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator],
Expand Down Expand Up @@ -110,10 +152,13 @@ def find_single_assignment(

def iterate_left_expressions(node: cst.BaseExpression):
yield node
if matchers.matches(node, matchers.Attribute()):
yield from iterate_left_expressions(node.value)
if matchers.matches(node, matchers.Call()):
yield from iterate_left_expressions(node.func)
match node:
case cst.Attribute():
yield from iterate_left_expressions(node.value)
case cst.Call():
yield from iterate_left_expressions(node.func)
case cst.Subscript():
yield from iterate_left_expressions(node.value)


def get_leftmost_expression(node: cst.BaseExpression) -> cst.BaseExpression:
Expand Down
16 changes: 13 additions & 3 deletions src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class HardenPyyaml(SemgrepCodemod):
class HardenPyyaml(SemgrepCodemod, NameResolutionMixin):
NAME = "harden-pyyaml"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
SUMMARY = "Use SafeLoader in `yaml.load()` Calls"
Expand All @@ -14,6 +15,8 @@ class HardenPyyaml(SemgrepCodemod):
}
]

_module_name = "yaml"

@classmethod
def rule(cls):
return """
Expand Down Expand Up @@ -44,6 +47,13 @@ def rule(cls):

"""

def on_result_found(self, _, updated_node):
new_args = [*updated_node.args[:1], self.parse_expression("yaml.SafeLoader")]
def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
self.add_needed_import(self._module_name)
new_args = [
*updated_node.args[:1],
self.parse_expression(f"{maybe_name}.SafeLoader"),
]
return self.update_arg_target(updated_node, new_args)
12 changes: 9 additions & 3 deletions src/core_codemods/tempfile_mktemp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class TempfileMktemp(SemgrepCodemod):
class TempfileMktemp(SemgrepCodemod, NameResolutionMixin):
NAME = "secure-tempfile"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
SUMMARY = "Upgrade and Secure Temp File Creation"
Expand All @@ -14,6 +15,8 @@ class TempfileMktemp(SemgrepCodemod):
}
]

_module_name = "tempfile"

@classmethod
def rule(cls):
return """
Expand All @@ -26,6 +29,9 @@ def rule(cls):
"""

def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
self.add_needed_import("tempfile")
return self.update_call_target(updated_node, "tempfile", "mkstemp")
return self.update_call_target(updated_node, maybe_name, "mkstemp")
14 changes: 11 additions & 3 deletions src/core_codemods/upgrade_sslcontext_minimum_version.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class UpgradeSSLContextMinimumVersion(SemgrepCodemod):
class UpgradeSSLContextMinimumVersion(SemgrepCodemod, NameResolutionMixin):
NAME = "upgrade-sslcontext-minimum-version"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
SUMMARY = "Upgrade SSLContext Minimum Version"
Expand All @@ -19,6 +20,8 @@ class UpgradeSSLContextMinimumVersion(SemgrepCodemod):
},
]

_module_name = "ssl"

@classmethod
def rule(cls):
return """
Expand All @@ -45,6 +48,11 @@ def rule(cls):
"""

def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(
original_node.value, self._module_name
)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
self.add_needed_import(self._module_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these new green lines added are the same (or almost?) in every codemod added so maybe make this a separate method to call to dedupe

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These address a very specific problem that are shared by those codemods. I don't want to add it to the api because the solution needs NameResolutionMixin which depends on ScopeProvider.

Metadata is calculated at the start of every transform on a as-need basis (libcst looks at METADATA_DEPENDENCIES for that). Integrating it to the API means every codemod would calculate the ScopeProvider metadata, making it a bit more expensive.

Maybe creating a class for those types of codemods would be a good idea, but I'd rather do that when we try to rewrite those in pure libcst.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree the duplication here is not ideal. I think we need to revisit the whole concept of utility mixins for this reason, although I understand the reasons it was done this way.

self.remove_unused_import(original_node)
self.add_needed_import("ssl")
return self.update_assign_rhs(updated_node, "ssl.TLSVersion.TLSv1_2")
return self.update_assign_rhs(updated_node, f"{maybe_name}.TLSVersion.TLSv1_2")
7 changes: 4 additions & 3 deletions tests/codemods/test_harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,17 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader):
"""
self.run_and_assert(tmpdir, input_code, expected)

@pytest.mark.skip()
def test_import_alias(self, tmpdir):
input_code = """import yaml as yam
from yaml import Loader

data = b'!!python/object/apply:subprocess.Popen \\n- ls'
deserialized_data = yam.load(data, Loader=Loader)
"""
expected = """import yaml
expected = """import yaml as yam
from yaml import Loader

data = b'!!python/object/apply:subprocess.Popen \\n- ls'
deserialized_data = yaml.load(data, yaml.SafeLoader)
deserialized_data = yam.load(data, yam.SafeLoader)
"""
self.run_and_assert(tmpdir, input_code, expected)
3 changes: 1 addition & 2 deletions tests/codemods/test_harden_ruamel.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,13 @@ def test_unsafe_import(self, tmpdir, loader):
"""
self.run_and_assert(tmpdir, input_code, expected)

@pytest.mark.skip()
@pytest.mark.parametrize("loader", ["YAML(typ='base')", "YAML(typ='unsafe')"])
def test_import_alias(self, tmpdir, loader):
input_code = f"""from ruamel import yaml as yam
serializer = yam.{loader}
"""

expected = """import ruamel
expected = """from ruamel import yaml as yam
serializer = yam.YAML(typ="safe")
"""

Expand Down
12 changes: 12 additions & 0 deletions tests/codemods/test_https_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ def test_simple(self, tmpdir):
after = r"""import urllib3

urllib3.HTTPSConnectionPool("localhost", "80")
"""
self.run_and_assert(tmpdir, before, after)
assert len(self.file_context.codemod_changes) == 1

def test_module_alias(self, tmpdir):
before = r"""import urllib3 as module

module.HTTPConnectionPool("localhost", "80")
"""
after = r"""import urllib3 as module

module.HTTPSConnectionPool("localhost", "80")
"""
self.run_and_assert(tmpdir, before, after)
assert len(self.file_context.codemod_changes) == 1
Expand Down
15 changes: 13 additions & 2 deletions tests/codemods/test_tempfile_mktemp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from core_codemods.tempfile_mktemp import TempfileMktemp
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest

Expand Down Expand Up @@ -50,7 +49,6 @@ def test_from_import(self, tmpdir):
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.skip()
def test_import_alias(self, tmpdir):
input_code = """import tempfile as _tempfile

Expand All @@ -61,6 +59,19 @@ def test_import_alias(self, tmpdir):

_tempfile.mkstemp()
var = "hello"
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_import_method_alias(self, tmpdir):
input_code = """from tempfile import mktemp as get_temp_file

get_temp_file()
var = "hello"
"""
expected_output = """import tempfile

tempfile.mkstemp()
var = "hello"
"""
self.run_and_assert(tmpdir, input_code, expected_output)

Expand Down
5 changes: 2 additions & 3 deletions tests/codemods/test_upgrade_sslcontext_minimum_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
]


class TestUpgradeSSLContextMininumVersion(BaseSemgrepCodemodTest):
class TestUpgradeSSLContextMinimumVersion(BaseSemgrepCodemodTest):
codemod = UpgradeSSLContextMinimumVersion

@pytest.mark.parametrize("version", INSECURE_VERSIONS)
Expand Down Expand Up @@ -64,10 +64,9 @@ def test_import_with_alias(self, tmpdir):
context.minimum_version = whatever.TLSVersion.SSLv3
"""
expected_output = """import ssl as whatever
import ssl

context = whatever.SSLContext()
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.minimum_version = whatever.TLSVersion.TLSv1_2
"""
self.run_and_assert(tmpdir, input_code, expected_output)

Expand Down
20 changes: 20 additions & 0 deletions tests/codemods/test_use_defused_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,26 @@ def test_etree_simple_call(self, tmpdir, module, method):
self.run_and_assert(tmpdir, original_code, new_code)
self.assert_dependency(DefusedXML)

@pytest.mark.parametrize("method", ETREE_METHODS)
def test_etree_module_alias(self, tmpdir, method):
original_code = f"""
import xml.etree.ElementTree as alias
import xml.etree.cElementTree as calias

et = alias.{method}('some.xml')
cet = calias.{method}('some.xml')
"""

new_code = f"""
import defusedxml.ElementTree

et = defusedxml.ElementTree.{method}('some.xml')
cet = defusedxml.ElementTree.{method}('some.xml')
"""

self.run_and_assert(tmpdir, original_code, new_code)
self.assert_dependency(DefusedXML)

@pytest.mark.parametrize("method", ETREE_METHODS)
@pytest.mark.parametrize("module", ["ElementTree", "cElementTree"])
def test_etree_attribute_call(self, tmpdir, module, method):
Expand Down
Loading