From d38b12bc38b21e5eaa25fedba0eb8cd8d0fedabb Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Tue, 5 Mar 2024 15:25:56 -0500 Subject: [PATCH] Refactor to share code for import modifier codemods --- .../codemods/import_modifier_codemod.py | 70 +++++++++++++++++ src/core_codemods/api/__init__.py | 2 +- src/core_codemods/api/core_codemod.py | 7 ++ src/core_codemods/harden_pickle_load.py | 75 +++---------------- src/core_codemods/use_defused_xml.py | 59 ++------------- 5 files changed, 97 insertions(+), 116 deletions(-) create mode 100644 src/codemodder/codemods/import_modifier_codemod.py diff --git a/src/codemodder/codemods/import_modifier_codemod.py b/src/codemodder/codemods/import_modifier_codemod.py new file mode 100644 index 000000000..f29d268a7 --- /dev/null +++ b/src/codemodder/codemods/import_modifier_codemod.py @@ -0,0 +1,70 @@ +from abc import ABCMeta, abstractmethod +from typing import Mapping + +import libcst as cst +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor + +from codemodder.codemods.api import SimpleCodemod +from codemodder.codemods.imported_call_modifier import ImportedCallModifier +from codemodder.dependency import Dependency + + +class MappingImportedCallModifier(ImportedCallModifier[Mapping[str, str]]): + def update_attribute(self, true_name, original_node, updated_node, new_args): + if not self.node_is_selected(original_node): + return updated_node + + import_name = self.matching_functions[true_name] + AddImportsVisitor.add_needed_import(self.context, import_name) + RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) + return updated_node.with_changes( + args=new_args, + func=cst.Attribute( + value=cst.parse_expression(import_name), + attr=cst.Name(value=true_name.split(".")[-1]), + ), + ) + + def update_simple_name(self, true_name, original_node, updated_node, new_args): + if not self.node_is_selected(original_node): + return updated_node + + import_name = self.matching_functions[true_name] + AddImportsVisitor.add_needed_import(self.context, import_name) + RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) + return updated_node.with_changes( + args=new_args, + func=cst.Attribute( + value=cst.parse_expression(import_name), + attr=cst.Name(value=true_name.split(".")[-1]), + ), + ) + + +class ImportModifierCodemod(SimpleCodemod, metaclass=ABCMeta): + @property + def dependency(self) -> Dependency | None: + return None + + @property + @abstractmethod + def mapping(self) -> Mapping[str, str]: + pass + + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + if not self.node_is_selected(tree): + return tree + + visitor = MappingImportedCallModifier( + self.context, + self.file_context, + self.mapping, + self.change_description, + self.results, + ) + result_tree = visitor.transform_module(tree) + self.file_context.codemod_changes.extend(visitor.changes_in_file) + if visitor.changes_in_file and (dependency := self.dependency): + self.add_dependency(dependency) + + return result_tree diff --git a/src/core_codemods/api/__init__.py b/src/core_codemods/api/__init__.py index c2762dda9..847ae4f71 100644 --- a/src/core_codemods/api/__init__.py +++ b/src/core_codemods/api/__init__.py @@ -1,4 +1,4 @@ # ruff: noqa: F401 from codemodder.codemods.api import Metadata, Reference, ReviewGuidance -from .core_codemod import CoreCodemod, SimpleCodemod +from .core_codemod import CoreCodemod, ImportModifierCodemod, SimpleCodemod diff --git a/src/core_codemods/api/core_codemod.py b/src/core_codemods/api/core_codemod.py index 15e5a8669..7f6246149 100644 --- a/src/core_codemods/api/core_codemod.py +++ b/src/core_codemods/api/core_codemod.py @@ -5,6 +5,9 @@ from codemodder.codemods.base_codemod import Metadata from codemodder.codemods.base_detector import BaseDetector from codemodder.codemods.base_transformer import BaseTransformerPipeline +from codemodder.codemods.import_modifier_codemod import ( + ImportModifierCodemod as _ImportModifierCodemod, +) from codemodder.context import CodemodExecutionContext @@ -52,3 +55,7 @@ class SimpleCodemod(_SimpleCodemod): """ codemod_base = CoreCodemod + + +class ImportModifierCodemod(SimpleCodemod, _ImportModifierCodemod): + pass diff --git a/src/core_codemods/harden_pickle_load.py b/src/core_codemods/harden_pickle_load.py index f9a582256..f08d300dd 100644 --- a/src/core_codemods/harden_pickle_load.py +++ b/src/core_codemods/harden_pickle_load.py @@ -1,51 +1,10 @@ from typing import Mapping -import libcst as cst -from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor +from codemodder.dependency import Dependency, Fickling +from core_codemods.api import ImportModifierCodemod, Metadata, Reference, ReviewGuidance -from core_codemods.api import ( - SimpleCodemod, - Metadata, - Reference, - ReviewGuidance, -) -from codemodder.codemods.imported_call_modifier import ImportedCallModifier -from codemodder.dependency import Fickling - -class HardenPickleModifier(ImportedCallModifier[Mapping[str, str]]): - def update_attribute(self, true_name, original_node, updated_node, new_args): - if not self.node_is_selected(original_node): - return updated_node - - import_name = self.matching_functions[true_name] - AddImportsVisitor.add_needed_import(self.context, import_name) - RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) - return updated_node.with_changes( - args=new_args, - func=cst.Attribute( - value=cst.parse_expression(import_name), - attr=cst.Name(value=true_name.split(".")[-1]), - ), - ) - - def update_simple_name(self, true_name, original_node, updated_node, new_args): - if not self.node_is_selected(original_node): - return updated_node - - import_name = self.matching_functions[true_name] - AddImportsVisitor.add_needed_import(self.context, import_name) - RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) - return updated_node.with_changes( - args=new_args, - func=cst.Attribute( - value=cst.parse_expression(import_name), - attr=cst.Name(value=true_name.split(".")[-1]), - ), - ) - - -class HardenPickleLoad(SimpleCodemod): +class HardenPickleLoad(ImportModifierCodemod): metadata = Metadata( name="harden-pickle-load", summary="Harden `pickle.load()` against deserialization attacks", @@ -66,23 +25,13 @@ class HardenPickleLoad(SimpleCodemod): change_description = "Harden `pickle.load()` against deserialization attacks" - def transform_module_impl(self, tree: cst.Module) -> cst.Module: - if not self.node_is_selected(tree): - return tree - - visitor = HardenPickleModifier( - self.context, - self.file_context, - # NOTE: the fickling api doesn't seem to support `loads` yet - { - "pickle.load": "fickling", - }, - self.change_description, - self.results, - ) - result_tree = visitor.transform_module(tree) - self.file_context.codemod_changes.extend(visitor.changes_in_file) - if visitor.changes_in_file: - self.add_dependency(Fickling) + @property + def dependency(self) -> Dependency: + return Fickling - return result_tree + @property + def mapping(self) -> Mapping[str, str]: + # NOTE: the fickling api doesn't seem to support `loads` yet + return { + "pickle.load": "fickling", + } diff --git a/src/core_codemods/use_defused_xml.py b/src/core_codemods/use_defused_xml.py index 615c498ca..0d21226ea 100644 --- a/src/core_codemods/use_defused_xml.py +++ b/src/core_codemods/use_defused_xml.py @@ -1,39 +1,7 @@ from functools import cached_property -from typing import Mapping - -import libcst as cst -from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor - -from codemodder.codemods.imported_call_modifier import ImportedCallModifier -from codemodder.dependency import DefusedXML -from core_codemods.api import Metadata, Reference, ReviewGuidance, SimpleCodemod - - -class DefusedXmlModifier(ImportedCallModifier[Mapping[str, str]]): - def update_attribute(self, true_name, original_node, updated_node, new_args): - import_name = self.matching_functions[true_name] - AddImportsVisitor.add_needed_import(self.context, import_name) - RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) - return updated_node.with_changes( - args=new_args, - func=cst.Attribute( - value=cst.parse_expression(import_name), - attr=cst.Name(value=true_name.split(".")[-1]), - ), - ) - - def update_simple_name(self, true_name, original_node, updated_node, new_args): - import_name = self.matching_functions[true_name] - AddImportsVisitor.add_needed_import(self.context, import_name) - RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) - return updated_node.with_changes( - args=new_args, - func=cst.Attribute( - value=cst.parse_expression(import_name), - attr=cst.Name(value=true_name.split(".")[-1]), - ), - ) +from codemodder.dependency import DefusedXML, Dependency +from core_codemods.api import ImportModifierCodemod, Metadata, Reference, ReviewGuidance ETREE_METHODS = ["parse", "fromstring", "iterparse", "XMLParser"] SAX_METHODS = ["parse", "make_parser", "parseString"] @@ -41,7 +9,7 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args): # TODO: add expat methods? -class UseDefusedXml(SimpleCodemod): +class UseDefusedXml(ImportModifierCodemod): metadata = Metadata( name="use-defusedxml", summary="Use `defusedxml` for Parsing XML", @@ -63,7 +31,7 @@ class UseDefusedXml(SimpleCodemod): change_description = "Replace builtin XML method with safe `defusedxml` method" @cached_property - def matching_functions(self) -> dict[str, str]: + def mapping(self) -> dict[str, str]: """Build a mapping of functions to their defusedxml imports""" _matching_functions: dict[str, str] = {} for module, defusedxml, methods in [ @@ -78,19 +46,6 @@ def matching_functions(self) -> dict[str, str]: ) return _matching_functions - def transform_module_impl(self, tree: cst.Module) -> cst.Module: - if not self.node_is_selected(tree): - return tree - - visitor = DefusedXmlModifier( - self.context, - self.file_context, - self.matching_functions, - self.change_description, - ) - result_tree = visitor.transform_module(tree) - self.file_context.codemod_changes.extend(visitor.changes_in_file) - if visitor.changes_in_file: - self.add_dependency(DefusedXML) - - return result_tree + @property + def dependency(self) -> Dependency: + return DefusedXML