Skip to content

Commit

Permalink
Refactor to share code for import modifier codemods
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Mar 5, 2024
1 parent 1d40db1 commit d38b12b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 116 deletions.
70 changes: 70 additions & 0 deletions src/codemodder/codemods/import_modifier_codemod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from abc import ABCMeta, abstractmethod
from typing import Mapping

import libcst as cst
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from codemodder.codemods.api import SimpleCodemod
from codemodder.codemods.imported_call_modifier import ImportedCallModifier
from codemodder.dependency import Dependency


class MappingImportedCallModifier(ImportedCallModifier[Mapping[str, str]]):
def update_attribute(self, true_name, original_node, updated_node, new_args):
if not self.node_is_selected(original_node):
return updated_node

import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)

def update_simple_name(self, true_name, original_node, updated_node, new_args):
if not self.node_is_selected(original_node):
return updated_node

import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)


class ImportModifierCodemod(SimpleCodemod, metaclass=ABCMeta):
@property
def dependency(self) -> Dependency | None:
return None

@property
@abstractmethod
def mapping(self) -> Mapping[str, str]:
pass

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
if not self.node_is_selected(tree):
return tree

visitor = MappingImportedCallModifier(
self.context,
self.file_context,
self.mapping,
self.change_description,
self.results,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
if visitor.changes_in_file and (dependency := self.dependency):
self.add_dependency(dependency)

return result_tree
2 changes: 1 addition & 1 deletion src/core_codemods/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# ruff: noqa: F401
from codemodder.codemods.api import Metadata, Reference, ReviewGuidance

from .core_codemod import CoreCodemod, SimpleCodemod
from .core_codemod import CoreCodemod, ImportModifierCodemod, SimpleCodemod
7 changes: 7 additions & 0 deletions src/core_codemods/api/core_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from codemodder.codemods.base_codemod import Metadata
from codemodder.codemods.base_detector import BaseDetector
from codemodder.codemods.base_transformer import BaseTransformerPipeline
from codemodder.codemods.import_modifier_codemod import (
ImportModifierCodemod as _ImportModifierCodemod,
)
from codemodder.context import CodemodExecutionContext


Expand Down Expand Up @@ -52,3 +55,7 @@ class SimpleCodemod(_SimpleCodemod):
"""

codemod_base = CoreCodemod


class ImportModifierCodemod(SimpleCodemod, _ImportModifierCodemod):
pass
75 changes: 12 additions & 63 deletions src/core_codemods/harden_pickle_load.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,10 @@
from typing import Mapping

import libcst as cst
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from codemodder.dependency import Dependency, Fickling
from core_codemods.api import ImportModifierCodemod, Metadata, Reference, ReviewGuidance

from core_codemods.api import (
SimpleCodemod,
Metadata,
Reference,
ReviewGuidance,
)
from codemodder.codemods.imported_call_modifier import ImportedCallModifier
from codemodder.dependency import Fickling


class HardenPickleModifier(ImportedCallModifier[Mapping[str, str]]):
def update_attribute(self, true_name, original_node, updated_node, new_args):
if not self.node_is_selected(original_node):
return updated_node

import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)

def update_simple_name(self, true_name, original_node, updated_node, new_args):
if not self.node_is_selected(original_node):
return updated_node

import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)


class HardenPickleLoad(SimpleCodemod):
class HardenPickleLoad(ImportModifierCodemod):
metadata = Metadata(
name="harden-pickle-load",
summary="Harden `pickle.load()` against deserialization attacks",
Expand All @@ -66,23 +25,13 @@ class HardenPickleLoad(SimpleCodemod):

change_description = "Harden `pickle.load()` against deserialization attacks"

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
if not self.node_is_selected(tree):
return tree

visitor = HardenPickleModifier(
self.context,
self.file_context,
# NOTE: the fickling api doesn't seem to support `loads` yet
{
"pickle.load": "fickling",
},
self.change_description,
self.results,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
if visitor.changes_in_file:
self.add_dependency(Fickling)
@property
def dependency(self) -> Dependency:
return Fickling

return result_tree
@property
def mapping(self) -> Mapping[str, str]:
# NOTE: the fickling api doesn't seem to support `loads` yet
return {
"pickle.load": "fickling",
}
59 changes: 7 additions & 52 deletions src/core_codemods/use_defused_xml.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,15 @@
from functools import cached_property
from typing import Mapping

import libcst as cst
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from codemodder.codemods.imported_call_modifier import ImportedCallModifier
from codemodder.dependency import DefusedXML
from core_codemods.api import Metadata, Reference, ReviewGuidance, SimpleCodemod


class DefusedXmlModifier(ImportedCallModifier[Mapping[str, str]]):
def update_attribute(self, true_name, original_node, updated_node, new_args):
import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)

def update_simple_name(self, true_name, original_node, updated_node, new_args):
import_name = self.matching_functions[true_name]
AddImportsVisitor.add_needed_import(self.context, import_name)
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
return updated_node.with_changes(
args=new_args,
func=cst.Attribute(
value=cst.parse_expression(import_name),
attr=cst.Name(value=true_name.split(".")[-1]),
),
)

from codemodder.dependency import DefusedXML, Dependency
from core_codemods.api import ImportModifierCodemod, Metadata, Reference, ReviewGuidance

ETREE_METHODS = ["parse", "fromstring", "iterparse", "XMLParser"]
SAX_METHODS = ["parse", "make_parser", "parseString"]
DOM_METHODS = ["parse", "parseString"]
# TODO: add expat methods?


class UseDefusedXml(SimpleCodemod):
class UseDefusedXml(ImportModifierCodemod):
metadata = Metadata(
name="use-defusedxml",
summary="Use `defusedxml` for Parsing XML",
Expand All @@ -63,7 +31,7 @@ class UseDefusedXml(SimpleCodemod):
change_description = "Replace builtin XML method with safe `defusedxml` method"

@cached_property
def matching_functions(self) -> dict[str, str]:
def mapping(self) -> dict[str, str]:
"""Build a mapping of functions to their defusedxml imports"""
_matching_functions: dict[str, str] = {}
for module, defusedxml, methods in [
Expand All @@ -78,19 +46,6 @@ def matching_functions(self) -> dict[str, str]:
)
return _matching_functions

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
if not self.node_is_selected(tree):
return tree

visitor = DefusedXmlModifier(
self.context,
self.file_context,
self.matching_functions,
self.change_description,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
if visitor.changes_in_file:
self.add_dependency(DefusedXML)

return result_tree
@property
def dependency(self) -> Dependency:
return DefusedXML

0 comments on commit d38b12b

Please sign in to comment.