diff --git a/src/core_codemods/secure_flask_session_config.py b/src/core_codemods/secure_flask_session_config.py index 7944b8b7..a67d65bf 100644 --- a/src/core_codemods/secure_flask_session_config.py +++ b/src/core_codemods/secure_flask_session_config.py @@ -1,18 +1,23 @@ import libcst as cst -from libcst.codemod import Codemod, CodemodContext, ContextAwareVisitor -from libcst.metadata import ParentNodeProvider, ScopeProvider +from libcst.codemod import Codemod, CodemodContext, ContextAwareTransformer +from libcst.metadata import ParentNodeProvider, ScopeProvider, PositionProvider from libcst import matchers from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import BaseCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin - - -class SecureFlaskSessionConfig(BaseCodemod): - METADATA_DEPENDENCIES = BaseCodemod.METADATA_DEPENDENCIES + ( - ParentNodeProvider, - ScopeProvider, - ) +from codemodder.codemods.base_visitor import UtilsMixin +from codemodder.codemods.base_visitor import BaseTransformer +from codemodder.change import Change +from codemodder.file_context import FileContext +from typing import Union + + +class SecureFlaskSessionConfig(BaseCodemod, Codemod): + # METADATA_DEPENDENCIES = BaseCodemod.METADATA_DEPENDENCIES + ( + # ParentNodeProvider, + # ScopeProvider, + # ) NAME = "secure-flask-session-configuration" SUMMARY = "UTODO" REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW @@ -24,234 +29,167 @@ class SecureFlaskSessionConfig(BaseCodemod): } ] - SECURE_SESSION_CONFIGS = dict( - # None value indicates unassigned, using default is safe - # values in order of precedence - SESSION_COOKIE_HTTPONLY=[None, True], - SESSION_COOKIE_SECURE=[True], - SESSION_COOKIE_SAMESITE=["Lax", "Strict"], - ) + METADATA_DEPENDENCIES = (PositionProvider,) def transform_module_impl(self, tree: cst.Module) -> cst.Module: - flask_visitor = FindConfigCalls(self.context) - tree.visit(flask_visitor) - if not flask_visitor.flask_app_name: - return tree + flask_codemod = FixFlaskConfig(self.context, self.file_context) + result_tree = flask_codemod.transform_module(tree) - if len(config_access := flask_visitor.config_access) == 0: - return self.insert_config_line_endof_mod( - tree, flask_visitor.flask_app_name, self.SECURE_SESSION_CONFIGS - ) - - configs_to_write = self.SECURE_SESSION_CONFIGS.copy() - - for config_line in config_access: - match config_line: - case cst.Call(): - # app.config.update(...) - # defined_configs = self._get_configs(config_line) - for arg in config_line.args: - key = arg.keyword.value - val = [true_value(arg.value)] - if key in self.SECURE_SESSION_CONFIGS and val not in self.SECURE_SESSION_CONFIGS[key]: - del configs_to_write[key] - secure_val = self.SECURE_SESSION_CONFIGS[key][0] - arg.with_changes(value=cst.parse_expression(f"{secure_val}")) - - case cst.Assign(): - # app.config['...'] - defined_config = self._get_config_from_slice(config_line) - - (key, vals), = defined_config.items() - - defined_val = true_value(vals[-1]) - if key in self.SECURE_SESSION_CONFIGS and defined_val not in self.SECURE_SESSION_CONFIGS[key]: - del configs_to_write[key] - secure_val = self.SECURE_SESSION_CONFIGS[key][0] - config_line.with_changes(value=cst.parse_expression(f"{secure_val}")) - - if isinstance(config_update := config_access[-1], cst.Call): - # If the last config access is of form `app.config.update...` - # reuse that line. - self.reuse_config_line(config_update, configs_to_write) + if not flask_codemod.flask_app_name: return tree - # todo: if there is an .update line, add values directly there - return self.insert_config_line_endof_mod( - tree, flask_visitor.flask_app_name, configs_to_write - ) - - def get_defined_configs(self, config_access): - all_defined_configs = {} - for config_line in config_access: - match config_line: - case cst.Call(): - # app.config.update(...) - defined_configs = self._get_configs(config_line) - all_defined_configs.update(defined_configs) - case cst.Assign(): - # app.config['...'] - defined_configs = self._get_config_from_slice(config_line) - all_defined_configs.update(defined_configs) - - return all_defined_configs - def _get_configs(self, config_line: cst.Call): - defined_configs = {} - for arg in config_line.args: - defined_configs[arg.keyword.value] = [true_value(arg.value)] - return defined_configs - - def _get_config_from_slice(self, config_line: cst.Assign): - defined_configs = {} - key = true_value(config_line.targets[0].target.slice[0].slice.value) - defined_configs[key] = [true_value(config_line.value)] - return defined_configs - - def reuse_config_line( - self, config_line: cst.Call, configs: dict - ) -> None: - if not configs: - return - # TODO: record change - # line_number is the end of the module where we will insert the new flag. - # pos_to_match = self.node_position(original_node) - # line_number = pos_to_match.end.line - # self.changes_in_file.append( - # Change(line_number, DjangoSessionCookieSecureOff.CHANGE_DESCRIPTION) - # ) - # self.file_context.codemod_changes.append( - # Change(line_number, self.CHANGE_DESCRIPTION) - # ) - # config_string = ", ".join( - # f"{key}='{value[0]}'" if isinstance(value[0], str) else f"{key}={value[0]}" - # for key, value in configs.items() - # if value and value[0] is not None - # ) - # final_line = cst.parse_statement(f"{app_name}.config.update({config_string})") - # new_body = original_node.body[:-1] + (final_line,) - from codemodder.codemods.api.helpers import NewArg - - to_add = [NewArg(name=key, value=str(vals[0]), add_if_missing=True) for key, vals in configs.items() if vals[0] is not None] - - new_args = self.replace_args( - config_line, to_add - ) - # self.update_arg_target(config_line, new_args) - config_line.with_changes(args=new_args) - def reuse_config_subscript_line( - self, original_node: cst.Module, app_name: str, configs: dict, defined_key: str + if flask_codemod.configs_to_write: + return self.insert_secure_configs( + tree, + result_tree, + flask_codemod.flask_app_name, + flask_codemod.configs_to_write, + ) + return result_tree + + def insert_secure_configs( + self, + original_node: cst.Module, + updated_node: cst.Module, + app_name: str, + configs: dict, ) -> cst.Module: if not configs: - return original_node - # TODO: record change - # line_number is the end of the module where we will insert the new flag. - # pos_to_match = self.node_position(original_node) - # line_number = pos_to_match.end.line - # self.changes_in_file.append( - # Change(line_number, DjangoSessionCookieSecureOff.CHANGE_DESCRIPTION) - # ) - # self.file_context.codemod_changes.append( - # Change(line_number, self.CHANGE_DESCRIPTION) - # ) - config_string = ", ".join( - f"{key}='{value[0]}'" if isinstance(value[0], str) else f"{key}={value[0]}" - for key, value in configs.items() - if value and value[0] is not None - ) - - final_line = cst.parse_statement(f"{app_name}.config.update({config_string})") - secure_val = ( - self.SECURE_SESSION_CONFIGS[defined_key][0] - or self.SECURE_SESSION_CONFIGS[defined_key][1] - ) - final_config_subscript_line = cst.parse_statement( - f"{app_name}.config['{defined_key}'] = {secure_val}" - ) - new_body = original_node.body[:-1] + ( - final_config_subscript_line, - final_line, - ) - return original_node.with_changes(body=new_body) + return updated_node - def insert_config_line_endof_mod( - self, original_node: cst.Module, app_name: str, configs: dict - ) -> cst.Module: - if not configs: - return original_node - # TODO: record change - # line_number is the end of the module where we will insert the new flag. - # pos_to_match = self.node_position(original_node) - # line_number = pos_to_match.end.line - # self.changes_in_file.append( - # Change(line_number, DjangoSessionCookieSecureOff.CHANGE_DESCRIPTION) - # ) - # self.file_context.codemod_changes.append( - # Change(line_number, self.CHANGE_DESCRIPTION) - # ) config_string = ", ".join( f"{key}='{value[0]}'" if isinstance(value[0], str) else f"{key}={value[0]}" for key, value in configs.items() if value and value[0] is not None ) if not config_string: - return original_node + return updated_node + + self.report_change_endof_module(original_node) final_line = cst.parse_statement(f"{app_name}.config.update({config_string})") - new_body = original_node.body + (final_line,) - return original_node.with_changes(body=new_body) + new_body = updated_node.body + (final_line,) + return updated_node.with_changes(body=new_body) + + def report_change_endof_module(self, original_node: cst.Module) -> None: + # line_number is the end of the module where we will insert the new line. + pos_to_match = self.node_position(original_node) + line_number = pos_to_match.end.line + self.file_context.codemod_changes.append( + Change(line_number, self.CHANGE_DESCRIPTION) + ) -class FindConfigCalls(ContextAwareVisitor, NameResolutionMixin): +class FixFlaskConfig(BaseTransformer, NameResolutionMixin): """ Visitor to find calls to flask.Flask and related `.config` accesses. """ - METADATA_DEPENDENCIES = (ParentNodeProvider,) + METADATA_DEPENDENCIES = (PositionProvider, ParentNodeProvider) + SECURE_SESSION_CONFIGS = dict( + # None value indicates unassigned, using default is safe + # values in order of precedence + SESSION_COOKIE_HTTPONLY=[None, True], + SESSION_COOKIE_SECURE=[True], + SESSION_COOKIE_SAMESITE=["Lax", "Strict"], + ) - def __init__(self, context: CodemodContext) -> None: - self.config_access: list = [] + def __init__(self, codemod_context: CodemodContext, file_context: FileContext): + super().__init__(codemod_context, []) self.flask_app_name = "" - super().__init__(context) - - def _find_config_accesses(self, flask_app_attr: cst.AnnAssign | cst.Assign): - assignments = self.find_assignments(flask_app_attr) - for assignment in assignments: - if assignment.references: - # Flask app instance is accessed - references_to_app = [x.node for x in assignment.references] - for node in references_to_app: - parent = self.get_metadata(ParentNodeProvider, node) - match parent: - case cst.Attribute(): - config = cst.Name(value="config") - if matchers.matches( - parent, matchers.Attribute(value=node, attr=config) - ): - gparent = self.get_metadata(ParentNodeProvider, parent) - ggparent = self.get_metadata( - ParentNodeProvider, gparent - ) - if matchers.matches(gparent, matchers.Subscript()): - gggparent = self.get_metadata( - ParentNodeProvider, ggparent - ) - self.config_access.append(gggparent) - else: - self.config_access.append(ggparent) - - def leave_Call(self, original_node: cst.Call) -> None: + self.configs_to_write = self.SECURE_SESSION_CONFIGS.copy() + self.file_context = file_context + + def _store_flask_app(self, original_node) -> None: + flask_app_parent = self.get_metadata(ParentNodeProvider, original_node) + match flask_app_parent: + case cst.AnnAssign() | cst.Assign(): + flask_app_attr = flask_app_parent.targets[0].target + self.flask_app_name = flask_app_attr.value + + def _remove_config(self, key): + try: + del self.configs_to_write[key] + except KeyError: + pass + + def _get_secure_config_val(self, key): + val = self.SECURE_SESSION_CONFIGS[key][0] or self.SECURE_SESSION_CONFIGS[key][1] + return cst.parse_expression(f'"{val}"' if isinstance(val, str) else f"{val}") + + @property + def flask_app_is_assigned(self): + return bool(self.flask_app_name) + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): true_name = self.find_base_name(original_node.func) if true_name == "flask.Flask": - flask_app_parent = self.get_metadata(ParentNodeProvider, original_node) - match flask_app_parent: - case cst.AnnAssign() | cst.Assign(): - flask_app_attr = flask_app_parent.targets[0].target - self.flask_app_name = flask_app_attr.value - self._find_config_accesses(flask_app_attr) + self._store_flask_app(original_node) + + if self.flask_app_is_assigned and self._is_config_update_call(original_node): + return self.call_node_with_secure_configs(original_node, updated_node) + return updated_node + + def call_node_with_secure_configs( + self, original_node: cst.Call, updated_node: cst.Call + ) -> cst.Call: + new_args = [] + for arg in updated_node.args: + if (key := arg.keyword.value) in self.SECURE_SESSION_CONFIGS: + self._remove_config(key) + if true_value(arg.value) not in self.SECURE_SESSION_CONFIGS[key]: + safe_value = self._get_secure_config_val(key) + arg = arg.with_changes(value=safe_value) + new_args.append(arg) + + if updated_node.args != new_args: + self.report_change(original_node) + return updated_node.with_changes(args=new_args) + + def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign): + if self.flask_app_is_assigned and self._is_config_subscript(original_node): + return self.assign_node_with_secure_config(original_node, updated_node) + return updated_node + + def assign_node_with_secure_config( + self, original_node: cst.Assign, updated_node: cst.Assign + ) -> cst.Assign: + key = true_value(updated_node.targets[0].target.slice[0].slice.value) + if key in self.SECURE_SESSION_CONFIGS: + self._remove_config(key) + if true_value(updated_node.value) not in self.SECURE_SESSION_CONFIGS[key]: + safe_value = self._get_secure_config_val(key) + self.report_change(original_node) + return updated_node.with_changes(value=safe_value) + return updated_node + + def _is_config_update_call(self, original_node: cst.Call): + config = cst.Name(value="config") + app_name = cst.Name(value=self.flask_app_name) + app_config_node = cst.Attribute(value=app_name, attr=config) + update = cst.Name(value="update") + return matchers.matches( + original_node.func, matchers.Attribute(value=app_config_node, attr=update) + ) + + def _is_config_subscript(self, original_node: cst.Assign): + config = cst.Name(value="config") + app_name = cst.Name(value=self.flask_app_name) + app_config_node = cst.Attribute(value=app_name, attr=config) + return matchers.matches( + original_node.targets[0].target, matchers.Subscript(value=app_config_node) + ) - return original_node + def report_change(self, original_node): + # TODO: GET POS TO WORK + + # line_number = self.lineno_for_node(original_node) + line_number = self.lineno_for_node(original_node) + self.file_context.codemod_changes.append( + Change(line_number, SecureFlaskSessionConfig.CHANGE_DESCRIPTION) + ) -def true_value(node: cst.Name | cst.SimpleString): +def true_value(node: cst.Name | cst.SimpleString) -> str | int | bool: # todo: move to a more general util from codemodder.project_analysis.file_parsers.utils import clean_simplestring @@ -268,11 +206,5 @@ def true_value(node: cst.Name | cst.SimpleString): return True elif val.lower() == "false": return False - try: - return int(val) - except ValueError: - try: - return float(val) - except ValueError: - # If no conversion worked, return the original string - return val + return val + return "" diff --git a/tests/codemods/test_secure_flask_session_config.py b/tests/codemods/test_secure_flask_session_config.py index f33cf9b8..6c7fd2a8 100644 --- a/tests/codemods/test_secure_flask_session_config.py +++ b/tests/codemods/test_secure_flask_session_config.py @@ -19,6 +19,7 @@ def test_no_flask_app(self, tmpdir): response.set_cookie("name", "value") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 def test_app_not_accessed(self, tmpdir): input_code = """\ @@ -37,7 +38,7 @@ def test_app_not_accessed(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax') """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - # assert len(self.file_context.codemod_changes) == 1 + assert len(self.file_context.codemod_changes) == 1 def test_app_defined_separate_module(self, tmpdir): # TODO: test this as an integration test with two real modules @@ -57,6 +58,7 @@ def test_app_not_assigned(self, tmpdir): print(1) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 def test_app_accessed_config_not_called(self, tmpdir): input_code = """\ @@ -75,7 +77,7 @@ def test_app_accessed_config_not_called(self, tmpdir): # more code """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - # assert len(self.file_context.codemod_changes) == 1 + assert len(self.file_context.codemod_changes) == 1 def test_from_import(self, tmpdir): input_code = """\ @@ -94,7 +96,7 @@ def test_from_import(self, tmpdir): # more code """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - # assert len(self.file_context.codemod_changes) == 1 + assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = f"""\ @@ -111,7 +113,7 @@ def test_import_alias(self, tmpdir): # more code """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - # assert len(self.file_context.codemod_changes) == 1 + assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "config_lines,expected_config_lines", @@ -135,17 +137,32 @@ def test_import_alias(self, tmpdir): """app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", """app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", ), + ( + """app.config.update(SECRET_KEY='123SOMEKEY') +var = 1""", + """app.config.update(SECRET_KEY='123SOMEKEY') +var = 1 +app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", + ), ( """app.config.update(SECRET_KEY='123SOMEKEY')""", - """app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax', SECRET_KEY='123SOMEKEY')""", + """app.config.update(SECRET_KEY='123SOMEKEY') +app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", ), ( """app.config.update(SESSION_COOKIE_SECURE=True)""", - """app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", + """app.config.update(SESSION_COOKIE_SECURE=True) +app.config.update(SESSION_COOKIE_SAMESITE='Lax')""", ), ( """app.config.update(SESSION_COOKIE_HTTPONLY=True)""", - """app.config.update(SESSION_COOKIE_HTTPONLY=True, SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", + """app.config.update(SESSION_COOKIE_HTTPONLY=True) +app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", + ), + ( + """app.config.update(SESSION_COOKIE_HTTPONLY=False)""", + """app.config.update(SESSION_COOKIE_HTTPONLY=True) +app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", ), ( """app.config['SESSION_COOKIE_SECURE'] = False""", @@ -158,40 +175,40 @@ def test_import_alias(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", ), ( - '''app.config["SESSION_COOKIE_SECURE"] = True + """app.config["SESSION_COOKIE_SECURE"] = True app.config["SESSION_COOKIE_SAMESITE"] = "Lax" -''', - '''app.config["SESSION_COOKIE_SECURE"] = True +""", + """app.config["SESSION_COOKIE_SECURE"] = True app.config["SESSION_COOKIE_SAMESITE"] = "Lax" -''', +""", ), ( - '''app.config["SESSION_COOKIE_SECURE"] = False + """app.config["SESSION_COOKIE_SECURE"] = False app.config["SESSION_COOKIE_SAMESITE"] = None -''', - '''app.config["SESSION_COOKIE_SECURE"] = True +""", + """app.config["SESSION_COOKIE_SECURE"] = True app.config["SESSION_COOKIE_SAMESITE"] = "Lax" -''', +""", + ), + ( + """app.config["SESSION_COOKIE_SECURE"] = False +app.config["SESSION_COOKIE_HTTPONLY"] = False +app.config["SESSION_COOKIE_SAMESITE"] = "Strict" +""", + """app.config["SESSION_COOKIE_SECURE"] = True +app.config["SESSION_COOKIE_HTTPONLY"] = True +app.config["SESSION_COOKIE_SAMESITE"] = "Strict" +""", + ), + ( + """app.config["SESSION_COOKIE_SECURE"] = False +app.config["SESSION_COOKIE_SECURE"] = True +""", + """app.config["SESSION_COOKIE_SECURE"] = True +app.config["SESSION_COOKIE_SECURE"] = True +app.config.update(SESSION_COOKIE_SAMESITE='Lax') +""", ), -# ( -# '''app.config["SESSION_COOKIE_SECURE"] = False -# app.config["SESSION_COOKIE_HTTPONLY"] = False -# app.config["SESSION_COOKIE_SAMESITE"] = "Strict" -# ''', -# '''app.config["SESSION_COOKIE_SECURE"] = True -# app.config["SESSION_COOKIE_HTTPONLY"] = True -# app.config["SESSION_COOKIE_SAMESITE"] = "Strict" -# ''', -# ), - # ( - # """app.config["SESSION_COOKIE_SECURE"] = False - # app.config["SESSION_COOKIE_SECURE"] = True - # """, - # """app.config["SESSION_COOKIE_SECURE"] = False - # app.config["SESSION_COOKIE_SECURE"] = True - # app.config.update(SESSION_COOKIE_SAMESITE='Lax') - # """, - # ), ], ) def test_config_accessed_variations( @@ -231,4 +248,4 @@ def configure(): # either within configure() call or after it's called """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - # assert len(self.file_context.codemod_changes) == 1 + assert len(self.file_context.codemod_changes) == 1