Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codemod: parameterization of file path in Flask's send-file #214

Merged
merged 5 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions integration_tests/test_replace_flask_send_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from core_codemods.replace_flask_send_file import (
ReplaceFlaskSendFile,
)
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)


class TestReplaceFlaskSendFile(BaseIntegrationTest):
# Integration test for the ReplaceFlaskSendFile codemod, run against the
# sample file referenced by `code_path`.
codemod = ReplaceFlaskSendFile
code_path = "tests/samples/replace_flask_send_file.py"
# Each (index, text) pair below describes a line of the expected output.
# NOTE(review): presumably the helper substitutes these lines into the
# original sample to build `expected_new_code` -- confirm against
# `original_and_expected_from_code_path`.
original_code, expected_new_code = original_and_expected_from_code_path(
code_path,
[
(0, """from flask import Flask\n"""),
(1, """import flask\n"""),
(2, """from pathlib import Path\n"""),
(3, """\n"""),
(4, """app = Flask(__name__)\n"""),
(5, """\n"""),
(6, """@app.route("/uploads/<path:name>")\n"""),
(7, """def download_file(name):\n"""),
(
8,
"""    return flask.send_from_directory((p := Path(f'path/to/{name}.txt')).parent, p.name)\n""",
),
],
)

# Unified diff the codemod is expected to report for the sample file.
# Formatting is disabled so each diff line stays on its own source line.
# fmt: off
expected_diff =(
"""--- \n"""
"""+++ \n"""
"""@@ -1,7 +1,9 @@\n"""
"""-from flask import Flask, send_file\n"""
"""+from flask import Flask\n"""
"""+import flask\n"""
"""+from pathlib import Path\n"""
""" \n"""
""" app = Flask(__name__)\n"""
""" \n"""
""" @app.route("/uploads/<path:name>")\n"""
""" def download_file(name):\n"""
"""-    return send_file(f'path/to/{name}.txt')\n"""
"""+    return flask.send_from_directory((p := Path(f'path/to/{name}.txt')).parent, p.name)\n"""

)
# fmt: on

# Line of the original sample on which the change is reported (the
# `return send_file(...)` statement in the diff above).
expected_line_change = "7"
change_description = ReplaceFlaskSendFile.CHANGE_DESCRIPTION
num_changed_files = 1
2 changes: 2 additions & 0 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
return infer_expression_type(node.left)
case cst.BinaryOperation(operator=cst.Add()):
return infer_expression_type(node.left) or infer_expression_type(node.right)
case cst.BinaryOperation(operator=cst.Modulo()):
return infer_expression_type(node.left) or infer_expression_type(node.right)

Check warning on line 50 in src/codemodder/codemods/utils.py

View check run for this annotation

Codecov / codecov/patch

src/codemodder/codemods/utils.py#L50

Added line #L50 was not covered by tests
case cst.IfExp():
if_true = infer_expression_type(node.body)
or_else = infer_expression_type(node.orelse)
Expand Down
74 changes: 72 additions & 2 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from typing import Any, Collection, Optional, Tuple, Union
import libcst as cst
from libcst import MetadataDependent, matchers
Expand All @@ -9,6 +10,7 @@
BuiltinAssignment,
ImportAssignment,
ParentNodeProvider,
Scope,
ScopeProvider,
)
from libcst.metadata.scope_provider import GlobalScope
Expand All @@ -17,7 +19,7 @@
class NameResolutionMixin(MetadataDependent):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,)

def _find_imported_name(self, node: cst.Name):
def _find_imported_name(self, node: cst.Name) -> Optional[str]:
match self.find_single_assignment(node):
case ImportAssignment(
name=node.value,
Expand All @@ -40,7 +42,7 @@ def _find_imported_name(self, node: cst.Name):

return node.value

def find_base_name(self, node):
def find_base_name(self, node) -> Optional[str]:
"""
Given a node, resolve its name to its basest form.

Expand Down Expand Up @@ -160,6 +162,21 @@ def find_assignments(
return set(next(iter(scope.accesses[node])).referents)
return set()

def generate_available_name(self, node, preference: list[str]) -> str:
    """
    Generate a name that is not yet used within node's scope.

    The names in `preference` are tried in order and the first unused one is
    returned. If the list is exhausted, returns the first unused name of the
    form {name}_{count}, where name is the first entry of `preference` and
    count starts at 1.
    """
    taken = self.find_used_names_within_nodes_scope(node)
    preferred = next((name for name in preference if name not in taken), None)
    if preferred is not None:
        return preferred

    # Every preferred name is taken: append an increasing numeric suffix to
    # the first preference until the result is free.
    base = preference[0]
    suffix = 1
    while f"{base}_{suffix}" in taken:
        suffix += 1
    return f"{base}_{suffix}"

def find_used_names_in_module(self):
"""
Find all the used names in the scope of a libcst Module.
Expand All @@ -176,6 +193,59 @@ def find_used_names_in_module(self):
names.extend(visitor.names)
return names

def find_used_names_within_nodes_scope(self, node: cst.CSTNode) -> set[str]:
    """
    Find all the names used within the ancestor and descendent scopes of the
    scope that a given node belongs to.

    Returns an empty set when no scope metadata is available for the node.
    """
    # TODO support for global and nonlocal statements
    node_scope = self.get_metadata(ScopeProvider, node, None)
    if not node_scope:
        return set()
    return self.find_used_names_within_scope(node_scope)

def find_used_names_within_scope(self, scope: Scope) -> set[str]:
    """
    Find all the names used within all the ancestor and descendent scopes of
    a given scope.
    """
    ancestors = self._find_ancestor_scopes(scope)
    descendents = self._find_descendent_scopes(scope)

    collected: set[str] = set()
    for related_scope in itertools.chain(ancestors, descendents):
        collected |= self._find_used_names_scope_only(related_scope)
    return collected

def _find_ancestor_scopes(self, scope: Scope) -> set[Scope]:
    """
    Collect the given scope together with every enclosing scope, walking the
    `parent` links until the global scope has been included.
    """
    lineage = [scope]
    while not isinstance(lineage[-1], GlobalScope):
        lineage.append(lineage[-1].parent)
    return set(lineage)

def _build_scopes_child_tree(self) -> dict[Scope, list[Scope]]:
    """
    Build a mapping from each scope in the module to its direct child scopes.

    Every scope reported by ScopeProvider gets an entry (leaf scopes map to
    an empty list). Global scopes are never registered as children.
    """
    all_scopes = {
        scope
        for scope in self.context.wrapper.resolve(ScopeProvider).values()
        if scope
    }
    tree: dict[Scope, list[Scope]] = {scope: [] for scope in all_scopes}
    for scope in all_scopes:
        if not isinstance(scope, GlobalScope):
            # setdefault (rather than get) so a child is still recorded even
            # if its parent was not among the resolved scopes; with get() the
            # child would be appended to a throwaway list and silently lost.
            tree.setdefault(scope.parent, []).append(scope)
    return tree

def _find_descendent_scopes(self, scope: Scope):
    """
    Collect every scope nested (directly or transitively) inside the given
    scope. The given scope itself is not part of the result.
    """
    children_of = self._build_scopes_child_tree()
    found = set()
    pending = [scope]
    # Iterative depth-first walk over the child tree.
    while pending:
        children = children_of[pending.pop()]
        found.update(children)
        pending.extend(children)
    return found

def _find_used_names_scope_only(self, scope: Scope) -> set[str]:
    """
    Return the names of the assignments recorded directly in the given scope,
    without looking at its ancestors or descendents.
    """
    names: set[str] = set()
    for assignment in scope.assignments:
        names.add(assignment.name)
    return names

def find_global_scope(self):
"""Find the global scope for a libcst Module node."""
scopes = self.context.wrapper.resolve(ScopeProvider).values()
Expand Down
4 changes: 4 additions & 0 deletions src/codemodder/scripts/generate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ class DocMetadata:
importance="High",
guidance_explained="Flask views may require proper handling of CSRF to function as expected and thus this change may break some views.",
),
"replace-flask-send-file": DocMetadata(
importance="Medium",
guidance_explained="We believe this change is safe and will not cause any issues.",
),
}


Expand Down
16 changes: 16 additions & 0 deletions src/codemodder/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Sequence
import libcst as cst
from functools import cache

Expand Down Expand Up @@ -52,3 +53,18 @@ def extract_targets_of_assignment(
if assignment.asname:
return [assignment.asname.name]
return []


def positional_to_keyword(
    args: Sequence[cst.Arg], pos_to_keyword: list[str | None]
) -> list[cst.Arg]:
    """
    Given a sequence of Args, converts the positional arguments into keyword
    arguments according to a given map.

    `pos_to_keyword[i]` is the keyword to give to the argument at index `i`;
    a `None` entry means "leave that position alone". Arguments are returned
    unchanged when they:
      - already have a keyword,
      - are starred (`*args` / `**kwargs`), since an Arg cannot carry both a
        star and a keyword,
      - come after a `*args` unpacking, since their real position is unknown,
      - fall beyond the end of `pos_to_keyword`.
    """
    new_args = []
    positions_known = True
    for i, arg in enumerate(args):
        if arg.star == "*":
            # Anything after an iterable unpacking has an indeterminate
            # position, so index-based mapping is no longer reliable.
            positions_known = False
        keyword = pos_to_keyword[i] if i < len(pos_to_keyword) else None
        if (
            positions_known
            and arg.keyword is None
            and not arg.star
            and keyword is not None
        ):
            new_args.append(arg.with_changes(keyword=cst.Name(keyword)))
        else:
            new_args.append(arg)
    return new_args
2 changes: 2 additions & 0 deletions src/core_codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .combine_startswith_endswith import CombineStartswithEndswith
from .fix_deprecated_logging_warn import FixDeprecatedLoggingWarn
from .flask_enable_csrf_protection import FlaskEnableCSRFProtection
from .replace_flask_send_file import ReplaceFlaskSendFile

registry = CodemodCollection(
origin="pixee",
Expand Down Expand Up @@ -96,5 +97,6 @@
CombineStartswithEndswith,
FixDeprecatedLoggingWarn,
FlaskEnableCSRFProtection,
ReplaceFlaskSendFile,
],
)
19 changes: 19 additions & 0 deletions src/core_codemods/docs/pixee_python_replace-flask-send-file.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
The `send_file` function from Flask is susceptible to a path traversal attack if its input is not properly validated.
In a path traversal attack, the malicious agent can craft a path containing special paths like `./` or `../` to resolve a file outside of the expected directory path. This potentially allows the agent to overwrite, delete or read arbitrary files. In the case of `flask.send_file`, the result is that a malicious user could potentially download sensitive files that exist on the filesystem where the application is being hosted.
Flask offers a native solution with the `flask.send_from_directory` function that validates the given path.

Our changes look something like this:

```diff
-from flask import Flask, send_file
+from flask import Flask
+import flask
+from pathlib import Path

app = Flask(__name__)

@app.route("/uploads/<path:name>")
def download_file(name):
- return send_file(f'path/to/{name}.txt')
+ return flask.send_from_directory((p := Path(f'path/to/{name}.txt')).parent, p.name)
```
117 changes: 117 additions & 0 deletions src/core_codemods/replace_flask_send_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import libcst as cst
from typing import Optional
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.utils import BaseType, infer_expression_type
from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin
from codemodder.utils.utils import positional_to_keyword


class ReplaceFlaskSendFile(BaseCodemod, NameAndAncestorResolutionMixin):
    """
    Codemod that rewrites `flask.send_file(path, ...)` calls into
    `flask.send_from_directory(path.parent, path.name, ...)`.

    The rewrite is applied only when the path argument can be identified as a
    string or as a `pathlib.Path(...)` call; any other call is left untouched.
    """

    NAME = "replace-flask-send-file"
    SUMMARY = "Replace unsafe usage of `flask.send_file`"
    DESCRIPTION = SUMMARY
    REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
    REFERENCES = [
        {
            "url": "https://flask.palletsprojects.com/en/3.0.x/api/#flask.send_from_directory",
            "description": "",
        },
        {
            "url": "https://owasp.org/www-community/attacks/Path_Traversal",
            "description": "",
        },
    ]

    # Keyword names for the `flask.send_file` parameters that follow the path,
    # in positional order. Used to turn the remaining positional arguments
    # into keyword arguments of `send_from_directory`.
    pos_to_key_map: list[str | None] = [
        "mimetype",
        "as_attachment",
        "download_name",
        "conditional",
        "etag",
        "last_modified",
        "max_age",
    ]

    # Keyword names under which the path may be passed to `flask.send_file`
    # (current name first, then the pre-2.0 name).
    _path_keywords: tuple[str, ...] = ("path_or_file", "filename_or_fp")

    def leave_Call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BaseExpression:
        """
        Replace a `flask.send_file` call with `flask.send_from_directory`.

        Returns `updated_node` unchanged when the call is filtered out, is not
        a `flask.send_file` call, or its path argument cannot be parameterized.
        """
        if not self.filter_by_path_includes_or_excludes(original_node):
            return updated_node
        if self.find_base_name(original_node) != "flask.send_file":
            return updated_node

        path_arg = self._find_path_arg(original_node)
        if path_arg is None:
            return updated_node

        maybe_tuple = self.parameterize_path(path_arg)
        if not maybe_tuple:
            return updated_node

        directory_arg, name_arg = maybe_tuple
        new_args = [
            directory_arg,
            name_arg,
            # Take the trailing args from updated_node so that changes already
            # applied to nested nodes are not thrown away.
            *positional_to_keyword(updated_node.args[1:], self.pos_to_key_map),
        ]
        self.report_change(original_node)
        self.add_needed_import("flask")
        self.remove_unused_import(original_node)
        new_func = cst.parse_expression("flask.send_from_directory")
        return updated_node.with_changes(func=new_func, args=new_args)

    def _find_path_arg(self, node: cst.Call) -> Optional[cst.Arg]:
        """
        Return the argument of a `send_file` call that carries the path, or
        None when it cannot be identified safely.

        Only the first argument is considered: it must be a plain positional
        argument or be passed under one of `_path_keywords`. Calls without
        arguments, or whose first argument is starred or is a different
        keyword, are rejected.
        """
        if not node.args:
            return None
        first = node.args[0]
        if first.star:
            return None
        if first.keyword is not None and first.keyword.value not in self._path_keywords:
            return None
        return first

    def _wrap_in_path(self, expr) -> cst.Call:
        """Build `Path(expr)` and schedule the `from pathlib import Path` import."""
        self.add_needed_import("pathlib", "Path")
        return cst.Call(func=cst.Name(value="Path"), args=[cst.Arg(expr)])

    def _attribute_reference(self, expr, attribute: str) -> cst.Attribute:
        """Build the attribute access `expr.<attribute>`."""
        return cst.Attribute(value=expr, attr=cst.Name(attribute))

    def _build_args(self, expr):
        """Build the argument pair `(expr.parent, expr.name)`."""
        return (
            cst.Arg(self._attribute_reference(expr, "parent")),
            cst.Arg(self._attribute_reference(expr, "name")),
        )

    def _build_named_expr_args(self, expr, *, wrap_in_path: bool):
        """
        Build the argument pair `((p := value).parent, p.name)`.

        `value` is `expr` itself, or `Path(expr)` when `wrap_in_path` is true.
        The target name is `p` or, if taken, the first free `p_<n>`; it is
        looked up using `expr`, the node taken from the original tree.
        """
        available_name = self.generate_available_name(expr, ["p"])
        value = self._wrap_in_path(expr) if wrap_in_path else expr
        named_expr = cst.NamedExpr(
            target=cst.Name(available_name),
            value=value,
            lpar=[cst.LeftParen()],
            rpar=[cst.RightParen()],
        )
        return (
            cst.Arg(self._attribute_reference(named_expr, "parent")),
            cst.Arg(self._attribute_reference(cst.Name(available_name), "name")),
        )

    def _build_args_with_named_expr(self, expr):
        """Build `((p := expr).parent, p.name)`."""
        return self._build_named_expr_args(expr, wrap_in_path=False)

    def _build_args_with_path_and_named_expr(self, expr):
        """Build `((p := Path(expr)).parent, p.name)`."""
        return self._build_named_expr_args(expr, wrap_in_path=True)

    def parameterize_path(self, arg: cst.Arg) -> Optional[tuple[cst.Arg, cst.Arg]]:
        """
        Split the path argument of `send_file` into the `(directory, name)`
        argument pair expected by `send_from_directory`.

        Returns None when the argument is neither inferred to be a string nor
        resolved to a `pathlib.Path(...)` call.
        """
        expr = self.resolve_expression(arg.value)

        # is it a string?
        # TODO support for inferring types from string methods e.g. 'a'.capitalize()
        if infer_expression_type(expr) == BaseType.STRING:
            return self._build_args_with_path_and_named_expr(arg.value)

        # is it a Path object?
        # TODO support for identifying Path operators/function e.g. Path('1') / Path('2')
        if isinstance(expr, cst.Call) and self.find_base_name(expr) == "pathlib.Path":
            if arg.value is expr:
                # The Path(...) call is written inline: bind it to a name so
                # it is evaluated only once.
                return self._build_args_with_named_expr(arg.value)
            # The argument is an existing reference that resolves to a
            # Path(...) call, so it can simply be used twice.
            return self._build_args(arg.value)

        return None
Loading
Loading