Skip to content

Commit

Permalink
Tests and docs for flask-enable-csrf-protection
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Jan 11, 2024
1 parent eea8940 commit 192c7f0
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 2 deletions.
2 changes: 1 addition & 1 deletion integration_tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _lines_from_codepath(code_path):
def _replace_lines_with(lines, replacements):
total_lines = len(lines)
for lineno, replacement in replacements:
if lineno > total_lines:
if lineno >= total_lines:
lines.extend(replacement)
continue
lines[lineno] = replacement
Expand Down
36 changes: 36 additions & 0 deletions integration_tests/test_flask_enable_csrf_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from core_codemods.flask_enable_csrf_protection import FlaskEnableCSRFProtection
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)


class TestFlaskEnableCSRFProtection(BaseIntegrationTest):
codemod = FlaskEnableCSRFProtection
code_path = "tests/samples/flask_enable_csrf_protection.py"
original_code, expected_new_code = original_and_expected_from_code_path(
code_path,
[
(1, """from flask_wtf.csrf import CSRFProtect\n"""),
(2, """\n"""),
(3, """app = Flask(__name__)\n"""),
(4, """csrf_app = CSRFProtect(app)\n"""),
],
)

# fmt: off
expected_diff =(
"""--- \n"""
"""+++ \n"""
"""@@ -1,3 +1,5 @@\n"""
""" from flask import Flask\n"""
"""+from flask_wtf.csrf import CSRFProtect\n"""
""" \n"""
""" app = Flask(__name__)\n"""
"""+csrf_app = CSRFProtect(app)\n"""
)
# fmt: on

expected_line_change = "3"
change_description = FlaskEnableCSRFProtection.CHANGE_DESCRIPTION
num_changed_files = 2
4 changes: 4 additions & 0 deletions src/codemodder/scripts/generate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class DocMetadata:
importance="Low",
guidance_explained="This change fixes deprecated uses and is safe.",
),
"flask-enable-csrf-protection": DocMetadata(
importance="High",
guidance_explained="Flask views may require proper handling of CSRF to function as expected and thus this change may break some views.",
),
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Flask views using `FlaskForm` have CSRF protection enabled by default. However other views may use AJAX to perform unsafe HTTP methods. FlaskWTF provides a way to enable CSRF protection globally for all views of a Flask app.
Our changes look something like this:

```diff
from flask import Flask
+from flask_wtf.csrf import CSRFProtect

app = Flask(__name__)
+csrf_app = CSRFProtect(app)
```
97 changes: 97 additions & 0 deletions tests/codemods/test_flask_enable_csrf_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from core_codemods.flask_enable_csrf_protection import FlaskEnableCSRFProtection
from tests.codemods.base_codemod_test import BaseCodemodTest
from textwrap import dedent


class TestFlaskEnableCSRFProtection(BaseCodemodTest):
codemod = FlaskEnableCSRFProtection

def test_name(self):
assert self.codemod.name() == "flask-enable-csrf-protection"

def test_simple(self, tmpdir):
input_code = """\
from flask import Flask
app = Flask(__name__)
"""
expected = """\
from flask import Flask
from flask_wtf.csrf import CSRFProtect
app = Flask(__name__)
csrf_app = CSRFProtect(app)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_simple_alias(self, tmpdir):
input_code = """\
from flask import Flask as Flosk
app = Flosk(__name__)
"""
expected = """\
from flask import Flask as Flosk
from flask_wtf.csrf import CSRFProtect
app = Flosk(__name__)
csrf_app = CSRFProtect(app)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_multiple(self, tmpdir):
input_code = """\
from flask import Flask
app = Flask(__name__)
app2 = Flask(__name__)
"""
expected = """\
from flask import Flask
from flask_wtf.csrf import CSRFProtect
app = Flask(__name__)
csrf_app = CSRFProtect(app)
app2 = Flask(__name__)
csrf_app2 = CSRFProtect(app2)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 2

def test_multiple_inline(self, tmpdir):
input_code = """\
from flask import Flask
app = Flask(__name__); app2 = Flask(__name__)
"""
expected = """\
from flask import Flask
from flask_wtf.csrf import CSRFProtect
app = Flask(__name__); app2 = Flask(__name__); csrf_app = CSRFProtect(app); csrf_app2 = CSRFProtect(app2)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_multiple_inline_suite(self, tmpdir):
input_code = """\
from flask import Flask
if True: app = Flask(__name__); app2 = Flask(__name__)
"""
expected = """\
from flask import Flask
from flask_wtf.csrf import CSRFProtect
if True: app = Flask(__name__); app2 = Flask(__name__); csrf_app = CSRFProtect(app); csrf_app2 = CSRFProtect(app2)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_simple_protected(self, tmpdir):
input_code = """\
from flask import Flask
from flask_wtf.csrf import CSRFProtect
app = Flask(__name__)
csrf_app = CSRFProtect(app)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
assert len(self.file_context.codemod_changes) == 0
1 change: 0 additions & 1 deletion tests/samples/flask_enable_csrf_protection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from flask import Flask
from flask_wtf.csrf import CSRFProtect

app = Flask(__name__)

0 comments on commit 192c7f0

Please sign in to comment.