Skip to content

Commit

Permalink
Added support for semicolon separated statements
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva authored and drdavella committed Jan 17, 2024
1 parent 2f30748 commit 9578c99
Showing 1 changed file with 51 additions and 21 deletions.
72 changes: 51 additions & 21 deletions src/core_codemods/flask_enable_csrf_protection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Optional, Union
import libcst as cst
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.base_codemod import ReviewGuidance
Expand All @@ -17,6 +17,20 @@ class FlaskEnableCSRFProtection(
{"url": "https://owasp.org/www-community/attacks/csrf", "description": ""},
]

def leave_SimpleStatementSuite(
self,
original_node: cst.SimpleStatementSuite,
updated_node: cst.SimpleStatementSuite,
) -> cst.BaseSuite:
if self.filter_by_path_includes_or_excludes(self.node_position(original_node)):
new_stmts = self._get_new_stmts(original_node)
if new_stmts:
self.add_needed_import("flask_wtf.csrf", "CSRFProtect")
self.add_dependency(FLaskWTF)
self.report_change(original_node)
return updated_node.with_changes(body=[*original_node.body, *new_stmts])
return updated_node

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine,
Expand All @@ -25,28 +39,44 @@ def leave_SimpleStatementLine(
cst.BaseStatement, cst.FlattenSentinel[cst.BaseStatement], cst.RemovalSentinel
]:
if self.filter_by_path_includes_or_excludes(self.node_position(original_node)):
match original_node.body:
case [cst.Assign(value=cst.Call() as call) as assign]:
base_name = self.find_base_name(call)
if base_name and base_name == "flask.Flask":
named_targets = self._find_named_targets(assign)
flows_into_csrf_protect = map(
self._flows_into_csrf_protect, named_targets
)
if named_targets and not all(flows_into_csrf_protect):
self.add_needed_import("flask_wtf.csrf", "CSRFProtect")
self.add_dependency(FLaskWTF)
self.report_change(original_node)
return cst.FlattenSentinel(
[
updated_node,
cst.parse_statement(
f"csrf = CSRFProtect({named_targets[0].value})"
),
]
)
new_stmts = self._get_new_stmts(original_node)
if new_stmts:
self.add_needed_import("flask_wtf.csrf", "CSRFProtect")
self.add_dependency(FLaskWTF)
self.report_change(original_node)
if len(original_node.body) > 1:
return updated_node.with_changes(
body=[*original_node.body, *new_stmts]
)
return cst.FlattenSentinel(
(updated_node, cst.SimpleStatementLine(body=[new_stmts[0]]))
)
return updated_node

def _get_new_stmts(self, original_node):
new_stmts = []
for stmt in original_node.body:
if maybe_small_stmt := self._handle_statement(stmt):
new_stmts.append(maybe_small_stmt)
return new_stmts

def _handle_statement(self, stmt) -> Optional[cst.BaseSmallStatement]:
match stmt:
case cst.Assign(value=cst.Call() as call) as assign:
base_name = self.find_base_name(call)
if base_name and base_name == "flask.Flask":
named_targets = self._find_named_targets(assign)
flows_into_csrf_protect = map(
self._flows_into_csrf_protect, named_targets
)
if named_targets and not all(flows_into_csrf_protect):
new_stmt = cst.parse_statement(
f"csrf_{named_targets[0].value} = CSRFProtect({named_targets[0].value})"
)
new_stmt = cst.ensure_type(new_stmt, cst.SimpleStatementLine)
return new_stmt.body[0]
return None

def _find_named_targets(self, node: cst.Assign) -> list[cst.Name]:
all_names = []
for at in node.targets:
Expand Down

0 comments on commit 9578c99

Please sign in to comment.