Skip to content

Commit

Permalink
Initial implementation of flask-enable-csrf-protection
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva authored and drdavella committed Jan 17, 2024
1 parent 65878b1 commit 2f30748
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ test = [
"types-mock==5.1.*",
"django>=4,<6",
"numpy~=1.26.0",
"flask_wtf~=1.2.0",
]
complexity = [
"radon==6.0.*",
Expand Down
13 changes: 13 additions & 0 deletions src/codemodder/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ def __hash__(self):
return hash(self.requirement)


FLaskWTF = Dependency(
Requirement("flask-wtf~=1.2.0"),
description="""\
This package integrates WTForms into Flask. WTForms provides data validation and and CSRF protection which helps harden applications.
""",
_license=License(
"BSD-3-Clause",
"https://opensource.org/license/BSD-3-clause/",
),
oss_link="https://github.com/wtforms/flask-wtf/",
package_link="https://pypi.org/project/Flask-WTF/",
)

DefusedXML = Dependency(
Requirement("defusedxml~=0.7.1"),
description="""\
Expand Down
2 changes: 2 additions & 0 deletions src/core_codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .remove_debug_breakpoint import RemoveDebugBreakpoint
from .combine_startswith_endswith import CombineStartswithEndswith
from .fix_deprecated_logging_warn import FixDeprecatedLoggingWarn
from .flask_enable_csrf_protection import FlaskEnableCSRFProtection

registry = CodemodCollection(
origin="pixee",
Expand Down Expand Up @@ -94,5 +95,6 @@
RemoveDebugBreakpoint,
CombineStartswithEndswith,
FixDeprecatedLoggingWarn,
FlaskEnableCSRFProtection,
],
)
68 changes: 68 additions & 0 deletions src/core_codemods/flask_enable_csrf_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Union
import libcst as cst
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin
from codemodder.dependency import FLaskWTF


class FlaskEnableCSRFProtection(
BaseCodemod, NameResolutionMixin, AncestorPatternsMixin
):
NAME = "flask-enable-csrf-protection"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW
DESCRIPTION = "Uses CSRFProtect module to harden the app."
SUMMARY = "Enable CSRF protection globally for a Flask app."
REFERENCES = [
{"url": "https://owasp.org/www-community/attacks/csrf", "description": ""},
]

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine,
updated_node: cst.SimpleStatementLine,
) -> Union[
cst.BaseStatement, cst.FlattenSentinel[cst.BaseStatement], cst.RemovalSentinel
]:
if self.filter_by_path_includes_or_excludes(self.node_position(original_node)):
match original_node.body:
case [cst.Assign(value=cst.Call() as call) as assign]:
base_name = self.find_base_name(call)
if base_name and base_name == "flask.Flask":
named_targets = self._find_named_targets(assign)
flows_into_csrf_protect = map(
self._flows_into_csrf_protect, named_targets
)
if named_targets and not all(flows_into_csrf_protect):
self.add_needed_import("flask_wtf.csrf", "CSRFProtect")
self.add_dependency(FLaskWTF)
self.report_change(original_node)
return cst.FlattenSentinel(
[
updated_node,
cst.parse_statement(
f"csrf = CSRFProtect({named_targets[0].value})"
),
]
)
return updated_node

def _find_named_targets(self, node: cst.Assign) -> list[cst.Name]:
all_names = []
for at in node.targets:
match at:
case cst.AssignTarget(target=cst.Name() as target):
all_names.append(target)
return all_names

def _flows_into_csrf_protect(self, name: cst.Name) -> bool:
accesses = self.find_accesses(name)
for access in accesses:
maybe_arg = self.is_argument_of_call(access.node)
maybe_call = self.get_parent(maybe_arg) if maybe_arg else None
if (
maybe_call
and self.find_base_name(maybe_call) == "flask_wtf.csrf.CSRFProtect"
):
return True
return False
4 changes: 4 additions & 0 deletions tests/samples/flask_enable_csrf_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from flask import Flask
from flask_wtf.csrf import CSRFProtect

app = Flask(__name__)

0 comments on commit 2f30748

Please sign in to comment.