diff --git a/pyproject.toml b/pyproject.toml index 8ae2f13f..8bbd870b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ test = [ "types-mock==5.1.*", "django>=4,<6", "numpy~=1.26.0", + "flask_wtf~=1.2.0", ] complexity = [ "radon==6.0.*", diff --git a/src/codemodder/dependency.py b/src/codemodder/dependency.py index 67a44ea9..1d1179e2 100644 --- a/src/codemodder/dependency.py +++ b/src/codemodder/dependency.py @@ -37,6 +37,19 @@ def __hash__(self): return hash(self.requirement) +FLaskWTF = Dependency( + Requirement("flask-wtf~=1.2.0"), + description="""\ + This package integrates WTForms into Flask. WTForms provides data validation and and CSRF protection which helps harden applications. +""", + _license=License( + "BSD-3-Clause", + "https://opensource.org/license/BSD-3-clause/", + ), + oss_link="https://github.com/wtforms/flask-wtf/", + package_link="https://pypi.org/project/Flask-WTF/", +) + DefusedXML = Dependency( Requirement("defusedxml~=0.7.1"), description="""\ diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index 822f94cf..9530c879 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -44,6 +44,7 @@ from .remove_debug_breakpoint import RemoveDebugBreakpoint from .combine_startswith_endswith import CombineStartswithEndswith from .fix_deprecated_logging_warn import FixDeprecatedLoggingWarn +from .flask_enable_csrf_protection import FlaskEnableCSRFProtection registry = CodemodCollection( origin="pixee", @@ -94,5 +95,6 @@ RemoveDebugBreakpoint, CombineStartswithEndswith, FixDeprecatedLoggingWarn, + FlaskEnableCSRFProtection, ], ) diff --git a/src/core_codemods/flask_enable_csrf_protection.py b/src/core_codemods/flask_enable_csrf_protection.py new file mode 100644 index 00000000..de5d2ffa --- /dev/null +++ b/src/core_codemods/flask_enable_csrf_protection.py @@ -0,0 +1,68 @@ +from typing import Union +import libcst as cst +from codemodder.codemods.api import BaseCodemod +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin +from codemodder.dependency import FLaskWTF + + +class FlaskEnableCSRFProtection( + BaseCodemod, NameResolutionMixin, AncestorPatternsMixin +): + NAME = "flask-enable-csrf-protection" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW + DESCRIPTION = "Uses CSRFProtect module to harden the app." + SUMMARY = "Enable CSRF protection globally for a Flask app." + REFERENCES = [ + {"url": "https://owasp.org/www-community/attacks/csrf", "description": ""}, + ] + + def leave_SimpleStatementLine( + self, + original_node: cst.SimpleStatementLine, + updated_node: cst.SimpleStatementLine, + ) -> Union[ + cst.BaseStatement, cst.FlattenSentinel[cst.BaseStatement], cst.RemovalSentinel + ]: + if self.filter_by_path_includes_or_excludes(self.node_position(original_node)): + match original_node.body: + case [cst.Assign(value=cst.Call() as call) as assign]: + base_name = self.find_base_name(call) + if base_name and base_name == "flask.Flask": + named_targets = self._find_named_targets(assign) + flows_into_csrf_protect = map( + self._flows_into_csrf_protect, named_targets + ) + if named_targets and not all(flows_into_csrf_protect): + self.add_needed_import("flask_wtf.csrf", "CSRFProtect") + self.add_dependency(FLaskWTF) + self.report_change(original_node) + return cst.FlattenSentinel( + [ + updated_node, + cst.parse_statement( + f"csrf = CSRFProtect({named_targets[0].value})" + ), + ] + ) + return updated_node + + def _find_named_targets(self, node: cst.Assign) -> list[cst.Name]: + all_names = [] + for at in node.targets: + match at: + case cst.AssignTarget(target=cst.Name() as target): + all_names.append(target) + return all_names + + def _flows_into_csrf_protect(self, name: cst.Name) -> bool: + accesses = self.find_accesses(name) + for access in accesses: + maybe_arg = self.is_argument_of_call(access.node) + maybe_call = self.get_parent(maybe_arg) if maybe_arg else None + if ( + maybe_call + and self.find_base_name(maybe_call) == "flask_wtf.csrf.CSRFProtect" + ): + return True + return False diff --git a/tests/samples/flask_enable_csrf_protection.py b/tests/samples/flask_enable_csrf_protection.py new file mode 100644 index 00000000..3610af90 --- /dev/null +++ b/tests/samples/flask_enable_csrf_protection.py @@ -0,0 +1,4 @@ +from flask import Flask +from flask_wtf.csrf import CSRFProtect + +app = Flask(__name__)