diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 9fb76a5a..729fa761 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -383,6 +383,16 @@ def is_attribute_value(self, node: cst.CSTNode) -> Optional[cst.Attribute]: return maybe_parent return None + def is_subscript_value(self, node: cst.CSTNode) -> Optional[cst.Subscript]: + """ + Checks if node is the value of an Attribute. + """ + maybe_parent = self.get_parent(node) + match maybe_parent: + case cst.Subscript(value=node): + return maybe_parent + return None + def find_immediate_function_def( self, node: cst.CSTNode ) -> Optional[cst.FunctionDef]: diff --git a/src/core_codemods/flask_json_response_type.py b/src/core_codemods/flask_json_response_type.py index 1ee7bacc..571d0e70 100644 --- a/src/core_codemods/flask_json_response_type.py +++ b/src/core_codemods/flask_json_response_type.py @@ -77,7 +77,7 @@ def leave_Return(self, original_node: cst.Return): self._fix_json_dumps(original_node.value), ) # make_response(...) - elif maybe_make_response := self._is_make_response_with_json( + elif maybe_make_response := self._is_make_response_with_json_with_unset_ct( original_node.value ): if maybe_dict := self._has_dict_with_headers_mr_call( @@ -125,9 +125,9 @@ def _is_tuple_with_json_string_response( case cst.Tuple(): elements = node.elements first = elements[0].value - if self._is_json_dumps_call(first) or self._is_make_response_with_json( + if self._is_json_dumps_call( first - ): + ) or self._is_make_response_with_json_with_unset_ct(first): return node return None @@ -168,9 +168,44 @@ def _is_json_dumps_call(self, node: cst.BaseExpression) -> Optional[cst.Call]: return expr return None - def _is_make_response_with_json( + def _has_content_type_set(self, node: cst.BaseExpression) -> bool: + if not isinstance(node, cst.Name): + return False + for access in self.find_accesses(node): + maybe_attr = self.is_attribute_value(access.node) + # is headers attribute? e.g. resp.headers + match maybe_attr: + case cst.Attribute(attr=cst.Name(value="headers")): + pass + case _: + return False + maybe_subscript = ( + self.is_subscript_value(maybe_attr) if maybe_attr else None + ) + maybe_assignment = ( + self.is_target_of_assignment(maybe_subscript) + if maybe_subscript + else None + ) + if maybe_assignment: + # is subscript content-type? + match maybe_subscript: + case cst.Subscript( + slice=[ + cst.SubscriptElement( + slice=cst.Index(value=cst.SimpleString() as index) + ) + ] + ): + if index.raw_value == "Content-Type": + return True + return False + + def _is_make_response_with_json_with_unset_ct( self, node: cst.BaseExpression ) -> Optional[cst.Call]: + if self._has_content_type_set(node): + return None expr = self.resolve_expression(node) match expr: case cst.Call(args=[cst.Arg(first_arg), *_]): diff --git a/tests/codemods/test_flask_json_response_type.py b/tests/codemods/test_flask_json_response_type.py index 1aaf1504..e0a85087 100644 --- a/tests/codemods/test_flask_json_response_type.py +++ b/tests/codemods/test_flask_json_response_type.py @@ -320,3 +320,19 @@ def foo(request): return bar(dict_response) """ self.run_and_assert(tmpdir, input_code, input_code) + + def test_simple_indirect_content_type_set(self, tmpdir): + input_code = """ + from flask import make_response, Flask + import json + + app = Flask(__name__) + + @app.route("/test") + def foo(request): + json_response = json.dumps({ "user_input": request.GET.get("input") }) + response = make_response(json_response) + response.headers['Content-Type'] = 'application/json' + return response + """ + self.run_and_assert(tmpdir, input_code, input_code)