diff --git a/src/codemodder/utils/utils.py b/src/codemodder/utils/utils.py index 6ec7f6e6..4f223778 100644 --- a/src/codemodder/utils/utils.py +++ b/src/codemodder/utils/utils.py @@ -19,3 +19,20 @@ def true_value(node: cst.Name | cst.SimpleString) -> str | int | bool: return False return val return "" + + +def extract_targets_of_assignment( + assignment: cst.AnnAssign | cst.Assign | cst.WithItem | cst.NamedExpr, +) -> list[cst.BaseExpression]: + match assignment: + case cst.AnnAssign(): + if assignment.target: + return [assignment.target] + case cst.Assign(): + return [t.target for t in assignment.targets] + case cst.NamedExpr(): + return [assignment.target] + case cst.WithItem(): + if assignment.asname: + return [assignment.asname.name] + return [] diff --git a/src/core_codemods/secure_flask_session_config.py b/src/core_codemods/secure_flask_session_config.py index 72f6e391..0bc4c7ee 100644 --- a/src/core_codemods/secure_flask_session_config.py +++ b/src/core_codemods/secure_flask_session_config.py @@ -6,7 +6,7 @@ from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import BaseCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin -from codemodder.utils.utils import true_value +from codemodder.utils.utils import extract_targets_of_assignment, true_value from codemodder.codemods.base_visitor import BaseTransformer from codemodder.change import Change from codemodder.file_context import FileContext @@ -103,8 +103,12 @@ def _store_flask_app(self, original_node) -> None: flask_app_parent = self.get_metadata(ParentNodeProvider, original_node) match flask_app_parent: case cst.AnnAssign() | cst.Assign(): - flask_app_attr = flask_app_parent.targets[0].target - self.flask_app_name = flask_app_attr.value + targets = extract_targets_of_assignment(flask_app_parent) + # TODO: handle other assignments ex. l[0] = Flask(...) , a.b = Flask(...) + if targets and matchers.matches( + first_target := targets[0], matchers.Name() + ): + self.flask_app_name = first_target.value # def _remove_config(self, key): # try: @@ -163,18 +167,18 @@ def assign_node_with_secure_config( return updated_node def _is_config_update_call(self, original_node: cst.Call): - config = cst.Name(value="config") - app_name = cst.Name(value=self.flask_app_name) - app_config_node = cst.Attribute(value=app_name, attr=config) + config = matchers.Name(value="config") + app_name = matchers.Name(value=self.flask_app_name) + app_config_node = matchers.Attribute(value=app_name, attr=config) update = cst.Name(value="update") return matchers.matches( original_node.func, matchers.Attribute(value=app_config_node, attr=update) ) def _is_config_subscript(self, original_node: cst.Assign): - config = cst.Name(value="config") - app_name = cst.Name(value=self.flask_app_name) - app_config_node = cst.Attribute(value=app_name, attr=config) + config = matchers.Name(value="config") + app_name = matchers.Name(value=self.flask_app_name) + app_config_node = matchers.Attribute(value=app_name, attr=config) return matchers.matches( original_node.targets[0].target, matchers.Subscript(value=app_config_node) ) diff --git a/tests/codemods/test_secure_flask_session_config.py b/tests/codemods/test_secure_flask_session_config.py index f38b7501..0700e868 100644 --- a/tests/codemods/test_secure_flask_session_config.py +++ b/tests/codemods/test_secure_flask_session_config.py @@ -56,14 +56,14 @@ def test_from_import(self, tmpdir): input_code = """\ from flask import Flask - app = flask.Flask(__name__) + app = Flask(__name__) app.secret_key = "dev" app.config.update(SESSION_COOKIE_SECURE=False) """ expexted_output = """\ from flask import Flask - app = flask.Flask(__name__) + app = Flask(__name__) app.secret_key = "dev" app.config.update(SESSION_COOKIE_SECURE=True) """ @@ -88,6 +88,38 @@ def test_import_alias(self, tmpdir): self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) assert len(self.file_context.codemod_changes) == 1 + def test_annotated_assign(self, tmpdir): + input_code = """\ + import flask + app: flask.Flask = flask.Flask(__name__) + app.secret_key = "dev" + # more code + app.config.update(SESSION_COOKIE_SECURE=False) + """ + expexted_output = """\ + import flask + app: flask.Flask = flask.Flask(__name__) + app.secret_key = "dev" + # more code + app.config.update(SESSION_COOKIE_SECURE=True) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) + assert len(self.file_context.codemod_changes) == 1 + + def test_other_assignment_type(self, tmpdir): + input_code = """\ + import flask + class AppStore: + pass + store = AppStore() + store.app = flask.Flask(__name__) + store.app.secret_key = "dev" + # more code + store.app.config.update(SESSION_COOKIE_SECURE=False) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + @pytest.mark.parametrize( "config_lines,expected_config_lines", [