From 9c6fb6b40754db679c6233d6f17a058985e4bf3a Mon Sep 17 00:00:00 2001 From: Ulada Zakharava Date: Thu, 23 Nov 2023 14:33:45 +0000 Subject: [PATCH] Switch to Connexion 3 framework This is a huge PR being result of over a 100 commits made by a number of people in ##36052 and #37638. It switches to Connexion 3 as the driving backend implementation for both - Airflow REST APIs and Flask app that powers Airflow UI. It should be largely backwards compatible when it comes to behaviour of both APIs and Airflow Webserver views, however due to decisions made by Connexion 3 maintainers, it changes heavily the technology stack used under-the-hood: 1) Connexion App is an ASGI-compatible Open-API spec-first framework using ASGI as an interface between webserver and Python web application. ASGI is an asynchronous successor of WSGI. 2) Connexion itself is using Starlette to run asynchronous web services in Python. 3) We continue using gunicorn appliation server that still uses WSGI standard, which means that we can continue using Flask and we are usig standard Uvicorn ASGI webserver that converts the ASGI interface to WSGI interface of Gunicorn Some of the problems handled in this PR There were two problem was with session handling: * the get_session_cookie - did not get the right cookie - it returned "session" string. The right fix was to change cookie_jar into cookie.jar because this is where apparently TestClient of starlette is holding the cookies (visible when you debug) * The client does not accept "set_cookie" method - it accepts passing cookies via "cookies" dictionary - this is the usual httpx client - see https://www.starlette.io/testclient/ - so we have to set cookie directly in the get method to try it out Add "flask_client_with_login" for tests that neeed flask client Some tests require functionality not available to Starlette test client as they use Flask test client specific features - for those we have an option to get flask test client instead of starlette one. Fix error handling for new connection 3 approach Error handling for Connexion 3 integration needed to be reworked. The way it behaves is much the same as it works in main: * for API errors - we get application/problem+json responses * for UI erros - we have rendered views * for redirection - we have correct location header (it's been missing) * the api error handled was not added as available middleware in the www tests It should fix all test_views_base.py tests which were failing on lack of location header for redirection. Fix wrong response is tests_view_cluster_activity The problem in the test was that Starlette Test Client opens a new connection and start new session, while flask test client uses the same database session. The test did not show data because the data was not committed and session was not closed - which also failed sqlite local tests with "database is locked" error. Fix test_extra_links The tests were failing again because the dagrun created was not committed and session not closed. This worked with flask client that used the same session accidentally but did not work with test client from Starlette. Also it caused "database locked" in sqlite / local tests. Switch to non-deprecated auth manager Fix to test_views_log.py This PR partially fixes sessions and request parameter for test_views_log. Some tests are still failing but for different reasons - to be investigated. Fix views_custom_user_views tests The problem in those tests was that the check in security manager was based on the assumption that the security manager was shared between the client and test flask application - because they were coming from the same flask app. But when we use starlette, the call goes to a new process started and the user is deleted in the database - so the shortcut of checking the security manager did not work. The change is that we are now checking if the user is deleted by calling /users/show (we need a new users READ permission for that) - this way we go to the database and check if the user was indeed deleted. Fix test_task_instance_endpoint tests There were two reasons for the test failed: * when the Job was added to task instance, the task instance was not merged in session, which means that commit did not store the added Job * some of the tests were expecting a call with specific session and they failed because session was different. Replacing the session with mock.ANY tells pytest that this parameter can be anything - we will have different session when when the call will be made with ASGI/Starlette Fix parameter validation * added default value for limit parameter across the board. Connexion 3 does not like if the parameter had no default and we had not provided one - even if our custom decorated was adding it. Adding default value and updating our decorator to treat None as `default` fixed a number of problems where limits were not passed * swapped openapi specification for /datasets/{uri} and /dataset/events. Since `{uri}` was defined first, connection matched `events` with `{uri}` and chose parameter definitions from `{uri}` not events Fix test_log_enpoint tests The problem here was that some sessions should be committed/closed but also in order to run it standalone we wanted to create log templates in the database - as it relied implcitly on log templates created by other tests. Fix test_views_dagrun, test_views_tasks and test_views_log Fixed by switching to use flask client for testing rather than starlette. Starlette client in this case has some side effects that are also impacting Sqlite's session being created in a different thread and deleted with close_all_sessions fixture. Fix test_views_dagrun Fixed by switching to use flask client for testing rather than starlette. Starlette client in this case has some side effects that are also impacting Sqlite's session being created in a different thread and deleted with close_all_sessions fixture. Co-authored-by: sudipto baral Co-authored-by: satoshi-sh Co-authored-by: Maksim Yermakou Co-authored-by: Ulada Zakharava Better API initialization including vending of API specification. The way paths are added and initialized is better (for example FAB contributes their path via new method in Auth Manager. This also add back-compatibility to FAB auth manaager to continue working on Airflow 2.9. --- .github/workflows/basic-tests.yml | 2 + .../endpoints/connection_endpoint.py | 2 +- .../api_connexion/endpoints/dag_endpoint.py | 2 +- .../api_connexion/endpoints/dag_parsing.py | 11 +- .../endpoints/dag_warning_endpoint.py | 2 +- .../endpoints/dataset_endpoint.py | 6 +- .../endpoints/event_log_endpoint.py | 2 +- .../endpoints/import_error_endpoint.py | 2 +- .../api_connexion/endpoints/log_endpoint.py | 5 +- .../api_connexion/endpoints/pool_endpoint.py | 2 +- .../endpoints/task_instance_endpoint.py | 2 +- airflow/api_connexion/exceptions.py | 55 ++- airflow/api_connexion/openapi/v1.yaml | 61 +-- airflow/api_connexion/parameters.py | 14 +- airflow/auth/managers/base_auth_manager.py | 13 +- airflow/cli/commands/internal_api_command.py | 54 ++- airflow/cli/commands/webserver_command.py | 16 +- airflow/migrations/env.py | 4 +- .../fab/auth_manager/fab_auth_manager.py | 37 +- airflow/utils/json.py | 3 +- airflow/www/app.py | 56 ++- .../www/extensions/init_appbuilder_links.py | 2 +- airflow/www/extensions/init_views.py | 244 +++++++---- airflow/www/package.json | 2 +- airflow/www/static/js/types/api-generated.ts | 80 ++-- airflow/www/views.py | 18 +- airflow/www/yarn.lock | 20 +- .../src/airflow_breeze/global_constants.py | 2 +- .../core-concepts/auth-manager.rst | 4 +- hatch_build.py | 3 +- newsfragments/37638.significant.rst | 4 + pyproject.toml | 1 + tests/api_connexion/conftest.py | 37 +- .../endpoints/test_config_endpoint.py | 98 ++--- .../endpoints/test_connection_endpoint.py | 164 +++----- .../endpoints/test_dag_endpoint.py | 309 +++++++------- .../endpoints/test_dag_parsing.py | 29 +- .../endpoints/test_dag_run_endpoint.py | 387 +++++++++--------- .../endpoints/test_dag_source_endpoint.py | 60 ++- .../endpoints/test_dag_warning_endpoint.py | 57 ++- .../endpoints/test_dataset_endpoint.py | 178 ++++---- .../endpoints/test_event_log_endpoint.py | 134 +++--- .../endpoints/test_extra_link_endpoint.py | 61 +-- .../endpoints/test_forward_to_fab_endpoint.py | 63 +-- .../endpoints/test_health_endpoint.py | 14 +- .../endpoints/test_import_error_endpoint.py | 92 ++--- .../endpoints/test_log_endpoint.py | 132 +++--- .../test_mapped_task_instance_endpoint.py | 161 ++++---- .../endpoints/test_plugin_endpoint.py | 56 ++- .../endpoints/test_pool_endpoint.py | 150 ++++--- .../endpoints/test_provider_endpoint.py | 28 +- .../endpoints/test_task_endpoint.py | 79 ++-- .../endpoints/test_task_instance_endpoint.py | 343 ++++++++-------- .../endpoints/test_variable_endpoint.py | 124 +++--- .../endpoints/test_version_endpoint.py | 6 +- .../endpoints/test_xcom_endpoint.py | 66 +-- .../schemas/test_dag_run_schema.py | 2 +- .../test_role_and_permission_schema.py | 14 +- tests/api_connexion/test_auth.py | 55 +-- tests/api_connexion/test_cors.py | 45 +- tests/api_connexion/test_error_handling.py | 14 +- tests/api_connexion/test_security.py | 20 +- .../auth/backend/test_basic_auth.py | 12 +- tests/api_experimental/conftest.py | 2 +- .../endpoints/test_rpc_api_endpoint.py | 18 +- tests/auth/managers/test_base_auth_manager.py | 3 - .../cli/commands/test_internal_api_command.py | 9 +- tests/cli/commands/test_webserver_command.py | 12 +- tests/conftest.py | 31 +- .../auth/backend/test_kerberos_auth.py | 17 +- tests/plugins/test_plugins_manager.py | 15 +- .../aws/auth_manager/test_aws_auth_manager.py | 13 +- .../aws/auth_manager/views/test_auth.py | 22 +- .../api/auth/backend/test_basic_auth.py | 8 +- .../test_role_and_permission_endpoint.py | 160 ++++---- .../api_endpoints/test_user_endpoint.py | 210 +++++----- .../api_endpoints/test_user_schema.py | 17 +- tests/providers/fab/auth_manager/conftest.py | 14 +- .../fab/auth_manager/decorators/test_auth.py | 20 +- .../fab/auth_manager/test_security.py | 118 +++--- .../auth_manager/views/test_permissions.py | 8 +- .../fab/auth_manager/views/test_roles_list.py | 12 +- .../fab/auth_manager/views/test_user.py | 12 +- .../fab/auth_manager/views/test_user_edit.py | 12 +- .../fab/auth_manager/views/test_user_stats.py | 14 +- .../common/auth_backend/test_google_openid.py | 24 +- tests/sensors/test_external_task_sensor.py | 4 +- .../amazon/aws/tests/test_aws_auth_manager.py | 7 +- tests/test_utils/api_connexion_utils.py | 2 +- tests/test_utils/decorators.py | 10 +- tests/test_utils/mock_cors_middeleware.py | 35 ++ .../remote_user_api_auth_backend.py | 2 +- tests/test_utils/www.py | 18 +- tests/utils/test_helpers.py | 4 +- tests/www/api/experimental/conftest.py | 10 +- .../experimental/test_dag_runs_endpoint.py | 16 +- tests/www/api/experimental/test_endpoints.py | 91 ++-- tests/www/test_app.py | 20 +- tests/www/test_auth.py | 8 +- tests/www/test_security_manager.py | 4 +- tests/www/test_utils.py | 6 +- tests/www/views/conftest.py | 43 +- .../www/views/test_anonymous_as_admin_role.py | 5 +- tests/www/views/test_session.py | 49 ++- tests/www/views/test_views.py | 25 +- tests/www/views/test_views_acl.py | 82 ++-- tests/www/views/test_views_base.py | 58 +-- tests/www/views/test_views_blocked.py | 2 +- .../www/views/test_views_cluster_activity.py | 6 +- tests/www/views/test_views_connection.py | 4 +- .../www/views/test_views_custom_user_views.py | 64 +-- tests/www/views/test_views_dagrun.py | 82 ++-- tests/www/views/test_views_dataset.py | 40 +- tests/www/views/test_views_extra_links.py | 61 ++- tests/www/views/test_views_grid.py | 48 ++- tests/www/views/test_views_home.py | 18 +- tests/www/views/test_views_log.py | 62 +-- tests/www/views/test_views_mount.py | 8 +- tests/www/views/test_views_paused.py | 8 +- tests/www/views/test_views_pool.py | 2 +- tests/www/views/test_views_rate_limit.py | 24 +- tests/www/views/test_views_rendered.py | 6 +- tests/www/views/test_views_robots.py | 6 +- tests/www/views/test_views_task_norun.py | 4 +- tests/www/views/test_views_tasks.py | 95 +++-- tests/www/views/test_views_trigger_dag.py | 28 +- tests/www/views/test_views_variable.py | 24 +- 127 files changed, 2931 insertions(+), 2652 deletions(-) create mode 100644 newsfragments/37638.significant.rst create mode 100644 tests/test_utils/mock_cors_middeleware.py diff --git a/.github/workflows/basic-tests.yml b/.github/workflows/basic-tests.yml index db84bae38e2e2..3bf42b1ce815d 100644 --- a/.github/workflows/basic-tests.yml +++ b/.github/workflows/basic-tests.yml @@ -148,6 +148,8 @@ jobs: env: HATCH_ENV: "test" working-directory: ./clients/python + - name: Compile www assets + run: breeze compile-www-assets - name: "Install Airflow in editable mode with fab for webserver tests" run: pip install -e ".[fab]" - name: "Install Python client" diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index c17a9280d78f8..452ccb42cfbbe 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -91,7 +91,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API @provide_session def get_connections( *, - limit: int, + limit: int | None = None, offset: int = 0, order_by: str = "id", session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 1895bfeaec762..1efecbbbba5db 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -94,7 +94,7 @@ def get_dag_details( @provide_session def get_dags( *, - limit: int, + limit: int | None = None, offset: int = 0, tags: Collection[str] | None = None, dag_id_pattern: str | None = None, diff --git a/airflow/api_connexion/endpoints/dag_parsing.py b/airflow/api_connexion/endpoints/dag_parsing.py index 8c48888629b2b..c53b0d01c51f9 100644 --- a/airflow/api_connexion/endpoints/dag_parsing.py +++ b/airflow/api_connexion/endpoints/dag_parsing.py @@ -19,7 +19,8 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Sequence -from flask import Response, current_app +from connexion import NoContent +from flask import current_app from itsdangerous import BadSignature, URLSafeSerializer from sqlalchemy import exc, select @@ -39,7 +40,9 @@ @security.requires_access_dag("PUT") @provide_session -def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Response: +def reparse_dag_file( + *, file_token: str, session: Session = NEW_SESSION +) -> tuple[str | NoContent, HTTPStatus]: """Request re-parsing a DAG file.""" secret_key = current_app.config["SECRET_KEY"] auth_s = URLSafeSerializer(secret_key) @@ -65,5 +68,5 @@ def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Resp session.commit() except exc.IntegrityError: session.rollback() - return Response("Duplicate request", HTTPStatus.CREATED) - return Response(status=HTTPStatus.CREATED) + return "Duplicate request", HTTPStatus.CREATED + return NoContent, HTTPStatus.CREATED diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index d59db8c3d3082..f1eeddf0c8104 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -43,7 +43,7 @@ @provide_session def get_dag_warnings( *, - limit: int, + limit: int | None = None, dag_id: str | None = None, warning_type: str | None = None, offset: int | None = None, diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py index bfdb8d0a5e7ee..bbc91f85eac3a 100644 --- a/airflow/api_connexion/endpoints/dataset_endpoint.py +++ b/airflow/api_connexion/endpoints/dataset_endpoint.py @@ -82,7 +82,7 @@ def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: @provide_session def get_datasets( *, - limit: int, + limit: int | None = None, offset: int = 0, uri_pattern: str | None = None, dag_ids: str | None = None, @@ -113,11 +113,11 @@ def get_datasets( @security.requires_access_dataset("GET") -@provide_session @format_parameters({"limit": check_limit}) +@provide_session def get_dataset_events( *, - limit: int, + limit: int | None = None, offset: int = 0, order_by: str = "timestamp", dataset_id: int | None = None, diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py index 3b3dbe6efd490..23caee3755686 100644 --- a/airflow/api_connexion/endpoints/event_log_endpoint.py +++ b/airflow/api_connexion/endpoints/event_log_endpoint.py @@ -64,7 +64,7 @@ def get_event_logs( included_events: str | None = None, before: str | None = None, after: str | None = None, - limit: int, + limit: int | None = None, offset: int | None = None, order_by: str = "event_log_id", session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index 76b706eac1ae4..b63d0c30115d4 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -77,7 +77,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> @provide_session def get_import_errors( *, - limit: int, + limit: int | None = None, offset: int | None = None, order_by: str = "import_error_id", session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index 239f08ecdaf40..5493b6278d10b 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -107,7 +107,10 @@ def get_log( logs = logs[0] if task_try_number is not None else logs # we must have token here, so we can safely ignore it token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment] - return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs)) + return Response( + logs_schema.dumps(LogResponseObject(continuation_token=token, content=logs)), + headers={"Content-Type": "application/json"}, + ) # text/plain. Stream logs = task_log_reader.read_log_stream(ti, task_try_number, metadata) diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 553d50c7464b7..ef59ed21b6321 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -68,7 +68,7 @@ def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: @provide_session def get_pools( *, - limit: int, + limit: int | None = None, order_by: str = "id", offset: int | None = None, session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 9919162262191..145e28c1be56b 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -309,7 +309,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T, @provide_session def get_task_instances( *, - limit: int, + limit: int | None = None, dag_id: str | None = None, dag_run_id: str | None = None, execution_date_gte: str | None = None, diff --git a/airflow/api_connexion/exceptions.py b/airflow/api_connexion/exceptions.py index 75d9261ef6d44..fa2015a2dea1c 100644 --- a/airflow/api_connexion/exceptions.py +++ b/airflow/api_connexion/exceptions.py @@ -19,13 +19,12 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Any -import werkzeug -from connexion import FlaskApi, ProblemException, problem +from connexion import ProblemException, problem from airflow.utils.docs import get_docs_url if TYPE_CHECKING: - import flask + from connexion.lifecycle import ConnexionRequest, ConnexionResponse doc_link = get_docs_url("stable-rest-api-ref.html") @@ -40,37 +39,29 @@ } -def common_error_handler(exception: BaseException) -> flask.Response: +def problem_error_handler(_request: ConnexionRequest, exception: ProblemException) -> ConnexionResponse: """Use to capture connexion exceptions and add link to the type field.""" - if isinstance(exception, ProblemException): - link = EXCEPTIONS_LINK_MAP.get(exception.status) - if link: - response = problem( - status=exception.status, - title=exception.title, - detail=exception.detail, - type=link, - instance=exception.instance, - headers=exception.headers, - ext=exception.ext, - ) - else: - response = problem( - status=exception.status, - title=exception.title, - detail=exception.detail, - type=exception.type, - instance=exception.instance, - headers=exception.headers, - ext=exception.ext, - ) + link = EXCEPTIONS_LINK_MAP.get(exception.status) + if link: + return problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=link, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) else: - if not isinstance(exception, werkzeug.exceptions.HTTPException): - exception = werkzeug.exceptions.InternalServerError() - - response = problem(title=exception.name, detail=exception.description, status=exception.code) - - return FlaskApi.get_response(response) + return problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=exception.type, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) class NotFound(ProblemException): diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 273d69ab705db..9a23e38c5c158 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -1453,6 +1453,10 @@ paths: responses: "204": description: Success. + content: + text/html: + schema: + type: string "400": $ref: "#/components/responses/BadRequest" "401": @@ -1831,6 +1835,10 @@ paths: responses: "204": description: Success. + content: + text/html: + schema: + type: string "400": $ref: "#/components/responses/BadRequest" "401": @@ -1973,8 +1981,8 @@ paths: response = self.client.get( request_url, query_string={"token": token}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain","REMOTE_USER": "test"}, + ) continuation_token = response.json["continuation_token"] metadata = URLSafeSerializer(key).loads(continuation_token) @@ -2108,7 +2116,7 @@ paths: properties: content: type: string - plain/text: + text/plain: schema: type: string @@ -2194,29 +2202,6 @@ paths: "403": $ref: "#/components/responses/PermissionDenied" - /datasets/{uri}: - parameters: - - $ref: "#/components/parameters/DatasetURI" - get: - summary: Get a dataset - description: Get a dataset by uri. - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_dataset - tags: [Dataset] - responses: - "200": - description: Success. - content: - application/json: - schema: - $ref: "#/components/schemas/Dataset" - "401": - $ref: "#/components/responses/Unauthenticated" - "403": - $ref: "#/components/responses/PermissionDenied" - "404": - $ref: "#/components/responses/NotFound" - /datasets/events: get: summary: Get dataset events @@ -2274,6 +2259,30 @@ paths: '404': $ref: '#/components/responses/NotFound' + /datasets/{uri}: + parameters: + - $ref: "#/components/parameters/DatasetURI" + get: + summary: Get a dataset + description: Get a dataset by uri. + x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint + operationId: get_dataset + tags: [Dataset] + responses: + "200": + description: Success. + content: + application/json: + schema: + $ref: "#/components/schemas/Dataset" + "401": + $ref: "#/components/responses/Unauthenticated" + "403": + $ref: "#/components/responses/PermissionDenied" + "404": + $ref: "#/components/responses/NotFound" + + /config: get: summary: Get current configuration diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py index a05ded37614d4..79e34feecef3d 100644 --- a/airflow/api_connexion/parameters.py +++ b/airflow/api_connexion/parameters.py @@ -41,7 +41,7 @@ def validate_istimezone(value: datetime) -> None: raise BadRequest("Invalid datetime format", detail="Naive datetime is disallowed") -def format_datetime(value: str) -> datetime: +def format_datetime(value: str | None) -> datetime | None: """ Format datetime objects. @@ -50,6 +50,8 @@ def format_datetime(value: str) -> datetime: This should only be used within connection views because it raises 400 """ + if value is None: + return None value = value.strip() if value[-1] != "Z": value = value.replace(" ", "+") @@ -59,7 +61,7 @@ def format_datetime(value: str) -> datetime: raise BadRequest("Incorrect datetime argument", detail=str(err)) -def check_limit(value: int) -> int: +def check_limit(value: int | None) -> int: """ Check the limit does not exceed configured value. @@ -68,7 +70,8 @@ def check_limit(value: int) -> int: """ max_val = conf.getint("api", "maximum_page_limit") # user configured max page limit fallback = conf.getint("api", "fallback_page_limit") - + if value is None: + return fallback if value > max_val: log.warning( "The limit param value %s passed in API exceeds the configured maximum page limit %s", @@ -99,8 +102,9 @@ def format_parameters_decorator(func: T) -> T: @wraps(func) def wrapped_function(*args, **kwargs): for key, formatter in params_formatters.items(): - if key in kwargs: - kwargs[key] = formatter(kwargs[key]) + value = formatter(kwargs.get(key)) + if value: + kwargs[key] = value return func(*args, **kwargs) return cast(T, wrapped_function) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 86f0ebd6dc254..c8422d3dd2ed6 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -19,7 +19,7 @@ from abc import abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Container, Literal, Sequence +from typing import TYPE_CHECKING, Any, Container, Literal, Sequence from flask_appbuilder.menu import MenuItem from sqlalchemy import select @@ -82,7 +82,7 @@ def get_cli_commands() -> list[CLICommand]: return [] def get_api_endpoints(self) -> None | Blueprint: - """Return API endpoint(s) definition for the auth manager.""" + """Return API endpoint(s) definition for the auth manager for Airflow 2.9.""" return None def get_user_name(self) -> str: @@ -442,3 +442,12 @@ def security_manager(self) -> AirflowSecurityManagerV2: from airflow.www.security_manager import AirflowSecurityManagerV2 return AirflowSecurityManagerV2(self.appbuilder) + + def get_auth_manager_api_specification(self) -> tuple[str | None, dict[Any, Any]]: + """ + Return the mount point and specification (openapi) for auth manager contributed API (Airflow 2.10). + + By default is raises NotImplementedError which produces a warning in airflow logs when auth manager is + initialized, but you can return None, {} if the auth manager does not contribute API. + """ + raise NotImplementedError diff --git a/airflow/cli/commands/internal_api_command.py b/airflow/cli/commands/internal_api_command.py index 8c25d1fa5ae58..9cd96b5062ea5 100644 --- a/airflow/cli/commands/internal_api_command.py +++ b/airflow/cli/commands/internal_api_command.py @@ -28,9 +28,10 @@ from pathlib import Path from tempfile import gettempdir from time import sleep +from typing import TYPE_CHECKING +import connexion import psutil -from flask import Flask from flask_appbuilder import SQLA from flask_caching import Cache from flask_wtf.csrf import CSRFProtect @@ -54,7 +55,12 @@ from airflow.www.extensions.init_security import init_xframe_protection from airflow.www.extensions.init_views import init_api_internal, init_error_handlers +if TYPE_CHECKING: + from flask import Flask + + log = logging.getLogger(__name__) +connexion_app: connexion.FlaskApp | None = None app: Flask | None = None @@ -72,10 +78,10 @@ def internal_api(args): if args.debug: log.info("Starting the Internal API server on port %s and host %s.", args.port, args.hostname) - app = create_app(testing=conf.getboolean("core", "unit_test_mode")) + app = create_connexion_app(testing=conf.getboolean("core", "unit_test_mode")) app.run( - debug=True, # nosec - use_reloader=not app.config["TESTING"], + log_level="debug", + # reload=not app.app.config["TESTING"], port=args.port, host=args.hostname, ) @@ -102,7 +108,7 @@ def internal_api(args): "--workers", str(num_workers), "--worker-class", - str(args.workerclass), + "uvicorn.workers.UvicornWorker", "--timeout", str(worker_timeout), "--bind", @@ -125,7 +131,7 @@ def internal_api(args): if args.daemon: run_args += ["--daemon"] - run_args += ["airflow.cli.commands.internal_api_command:cached_app()"] + run_args += ["airflow.cli.commands.internal_api_command:cached_connexion_app()"] # To prevent different workers creating the web app and # all writing to the database at the same time, we use the --preload option. @@ -182,7 +188,7 @@ def start_and_monitor_gunicorn(args): if args.daemon: # This makes possible errors get reported before daemonization os.environ["SKIP_DAGS_PARSING"] = "True" - create_app(None) + create_connexion_app(None) os.environ.pop("SKIP_DAGS_PARSING") pid_file_path = Path(pid_file) @@ -196,9 +202,10 @@ def start_and_monitor_gunicorn(args): ) -def create_app(config=None, testing=False): +def create_connexion_app(config=None, testing=False): """Create a new instance of Airflow Internal API app.""" - flask_app = Flask(__name__) + connexion_app = connexion.FlaskApp(__name__) + flask_app = connexion_app.app flask_app.config["APP_NAME"] = "Airflow Internal API" flask_app.config["TESTING"] = testing @@ -243,16 +250,31 @@ def create_app(config=None, testing=False): with flask_app.app_context(): init_error_handlers(flask_app) - init_api_internal(flask_app, standalone_api=True) + init_api_internal(connexion_app, standalone_api=True) init_jinja_globals(flask_app) init_xframe_protection(flask_app) - return flask_app + return connexion_app -def cached_app(config=None, testing=False): - """Return cached instance of Airflow Internal API app.""" +def cached_connexion_app(config=None, testing=False) -> connexion.FlaskApp: + """Return cached instance of Airflow WWW app.""" + global connexion_app global app - if not app: - app = create_app(config=config, testing=testing) - return app + if not connexion_app: + connexion_app = create_connexion_app(config=config, testing=testing) + app = connexion_app.app + return connexion_app + + +def purge_cached_connexion_app(): + """Remove the cached version of the app in global state.""" + global connexion_app + global app + connexion_app = None + app = None + + +def cached_app(config=None, testing=False) -> Flask: + """Return flask app from connexion_app.""" + return cached_connexion_app(config=config, testing=testing).app diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py index 4285564e1fd17..86f689a556c41 100644 --- a/airflow/cli/commands/webserver_command.py +++ b/airflow/cli/commands/webserver_command.py @@ -350,17 +350,17 @@ def webserver(args): if ssl_cert and not ssl_key: raise AirflowException("An SSL key must also be provided for use with " + ssl_cert) - from airflow.www.app import create_app + from airflow.www.app import create_connexion_app if args.debug: print(f"Starting the web server on port {args.port} and host {args.hostname}.") - app = create_app(testing=conf.getboolean("core", "unit_test_mode")) + app = create_connexion_app(testing=conf.getboolean("core", "unit_test_mode")) app.run( - debug=True, - use_reloader=not app.config["TESTING"], + log_level="debug", port=args.port, host=args.hostname, - ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None, + ssl_keyfile=ssl_key if ssl_cert and ssl_key else None, + ssl_certfile=ssl_cert if ssl_cert and ssl_key else None, ) else: print( @@ -384,7 +384,7 @@ def webserver(args): "--workers", str(num_workers), "--worker-class", - str(args.workerclass), + "uvicorn.workers.UvicornWorker", "--timeout", str(worker_timeout), "--bind", @@ -412,7 +412,7 @@ def webserver(args): if ssl_cert: run_args += ["--certfile", ssl_cert, "--keyfile", ssl_key] - run_args += ["airflow.www.app:cached_app()"] + run_args += ["airflow.www.app:cached_connexion_app()"] if conf.getboolean("webserver", "reload_on_plugin_change", fallback=False): log.warning( @@ -477,7 +477,7 @@ def start_and_monitor_gunicorn(args): if args.daemon: # This makes possible errors get reported before daemonization os.environ["SKIP_DAGS_PARSING"] = "True" - create_app(None) + create_connexion_app(None) os.environ.pop("SKIP_DAGS_PARSING") pid_file_path = Path(pid_file) diff --git a/airflow/migrations/env.py b/airflow/migrations/env.py index c1558f6399671..3ede15813e2df 100644 --- a/airflow/migrations/env.py +++ b/airflow/migrations/env.py @@ -130,6 +130,6 @@ def process_revision_directives(context, revision, directives): if "airflow.www.app" in sys.modules: # Already imported, make sure we clear out any cached app - from airflow.www.app import purge_cached_app + from airflow.www.app import purge_cached_connexion_app - purge_cached_app() + purge_cached_connexion_app() diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index ffd5e5cab5d3e..6f78ee87dfba1 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -20,9 +20,8 @@ import argparse from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Container +from typing import TYPE_CHECKING, Any, Container -from connexion import FlaskApi from flask import Blueprint, url_for from sqlalchemy import select from sqlalchemy.orm import Session, joinedload @@ -83,7 +82,8 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.yaml import safe_load from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED -from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver + +FAB_API_MOUNT_POINT = "/auth/fab/v1" if TYPE_CHECKING: from airflow.auth.managers.models.base_user import BaseUser @@ -148,19 +148,48 @@ def get_cli_commands() -> list[CLICommand]: ] def get_api_endpoints(self) -> None | Blueprint: + """ + Airflow 2.9 compatible way of initializing Auth Manager API. + + This method is back-compatibility for Airflow 2.9 - this is the way API endpoints were + added in Connexion 2 for Airflow 2.9 and we want to keep that option to make FAB provider + continue working for Airflow 2.9. We should remove that method when min airflow version for + FAB providers is 2.10. + """ + # Legacy endpoint for Airflow 2.9 folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/ with folder.joinpath("openapi", "v1.yaml").open() as f: specification = safe_load(f) + from connexion import FlaskApi + + # This import is only available in Airflow 2.9 + from airflow.www.extensions.init_views import ( # type: ignore[attr-defined] + _CustomErrorRequestBodyValidator, + _LazyResolver, + ) + return FlaskApi( specification=specification, resolver=_LazyResolver(), - base_path="/auth/fab/v1", + base_path=FAB_API_MOUNT_POINT, options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, strict_validation=True, validate_responses=True, validator_map={"body": _CustomErrorRequestBodyValidator}, ).blueprint + def get_auth_manager_api_specification(self) -> tuple[str, dict[Any, Any]]: + """ + Get mount point for the Auth Manager contributed API. + + This method is a new method of retrieving mount point and specification + is used by Airflow 2.10 to register API endpoints for Auth managers. + """ + folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/ + with folder.joinpath("openapi", "v1.yaml").open() as f: + specification = safe_load(f) + return FAB_API_MOUNT_POINT, specification + def get_user_display_name(self) -> str: """Return the user's display name associated to the user in session.""" user = self.get_user() diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 4d89e340c1cd4..2540edf9a0cbb 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -37,7 +37,8 @@ class AirflowJsonProvider(JSONProvider): def dumps(self, obj, **kwargs): kwargs.setdefault("ensure_ascii", self.ensure_ascii) kwargs.setdefault("sort_keys", self.sort_keys) - return json.dumps(obj, **kwargs, cls=WebEncoder) + kwargs.setdefault("cls", WebEncoder) + return json.dumps(obj, **kwargs) def loads(self, s: str | bytes, **kwargs): return json.loads(s, **kwargs) diff --git a/airflow/www/app.py b/airflow/www/app.py index 78d51c64a6ab1..3027faab225bc 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -21,7 +21,8 @@ from datetime import timedelta from os.path import isabs -from flask import Flask +import connexion +from flask import Flask, request from flask_appbuilder import SQLA from flask_wtf.csrf import CSRFProtect from markupsafe import Markup @@ -51,12 +52,13 @@ ) from airflow.www.extensions.init_session import init_airflow_session_interface from airflow.www.extensions.init_views import ( - init_api_auth_provider, + init_api_auth_manager, init_api_connexion, init_api_error_handlers, init_api_experimental, init_api_internal, init_appbuilder_views, + init_cors_middleware, init_error_handlers, init_flash_views, init_plugins, @@ -64,15 +66,28 @@ from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware app: Flask | None = None - +connexion_app: connexion.FlaskApp | None = None # Initializes at the module level, so plugins can access it. # See: /docs/plugins.rst csrf = CSRFProtect() -def create_app(config=None, testing=False): +def create_connexion_app(config=None, testing=False): """Create a new instance of Airflow WWW app.""" - flask_app = Flask(__name__) + conn_app = connexion.FlaskApp(__name__) + + @conn_app.app.before_request + def before_request(): + """Exempts the view function associated with '/api/v1' requests from CSRF protection.""" + if request.path.startswith("/api/v1"): # TODO: make sure this path is correct + view_function = conn_app.app.view_functions.get(request.endpoint) + if view_function: + # Exempt the view function from CSRF protection + conn_app.app.extensions["csrf"].exempt(view_function) + + init_cors_middleware(conn_app) + + flask_app = conn_app.app flask_app.secret_key = conf.get("webserver", "SECRET_KEY") flask_app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(minutes=settings.get_session_lifetime_config()) @@ -80,6 +95,7 @@ def create_app(config=None, testing=False): flask_app.config["MAX_CONTENT_LENGTH"] = conf.getfloat("webserver", "allowed_payload_size") * 1024 * 1024 webserver_config = conf.get_mandatory_value("webserver", "config_file") + # Enable customizations in webserver_config.py to be applied via Flask.current_app. with flask_app.app_context(): flask_app.config.from_pyfile(webserver_config, silent=True) @@ -169,34 +185,44 @@ def create_app(config=None, testing=False): init_appbuilder_links(flask_app) init_plugins(flask_app) init_error_handlers(flask_app) - init_api_connexion(flask_app) + init_api_connexion(conn_app) if conf.getboolean("webserver", "run_internal_api", fallback=False): if not _ENABLE_AIP_44: raise RuntimeError("The AIP_44 is not enabled so you cannot use it.") - init_api_internal(flask_app) + init_api_internal(conn_app) init_api_experimental(flask_app) - init_api_auth_provider(flask_app) - init_api_error_handlers(flask_app) # needs to be after all api inits to let them add their path first get_auth_manager().init() + init_api_auth_manager(conn_app) + + init_api_error_handlers(conn_app) # needs to be after all api inits to let them add their path first init_jinja_globals(flask_app) init_xframe_protection(flask_app) init_cache_control(flask_app) init_airflow_session_interface(flask_app) init_check_user_active(flask_app) - return flask_app + return conn_app -def cached_app(config=None, testing=False): +def cached_connexion_app(config=None, testing=False) -> connexion.FlaskApp: """Return cached instance of Airflow WWW app.""" + global connexion_app global app - if not app: - app = create_app(config=config, testing=testing) - return app + if not connexion_app: + connexion_app = create_connexion_app(config=config, testing=testing) + app = connexion_app.app + return connexion_app -def purge_cached_app(): +def purge_cached_connexion_app(): """Remove the cached version of the app in global state.""" + global connexion_app global app + connexion_app = None app = None + + +def cached_app(config=None, testing=False) -> Flask: + """Return flask app from connexion_app.""" + return cached_connexion_app(config=config, testing=testing).app diff --git a/airflow/www/extensions/init_appbuilder_links.py b/airflow/www/extensions/init_appbuilder_links.py index 0d2f4e13e9293..effbca892a395 100644 --- a/airflow/www/extensions/init_appbuilder_links.py +++ b/airflow/www/extensions/init_appbuilder_links.py @@ -53,7 +53,7 @@ def init_appbuilder_links(app): appbuilder.add_link( name=RESOURCE_DOCS, label="REST API Reference (Swagger UI)", - href="/api/v1./api/v1_swagger_ui_index", + href="SwaggerView.swagger", category=RESOURCE_DOCS_MENU, ) appbuilder.add_link( diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py index c671aa6195215..2b4611dbe8d83 100644 --- a/airflow/www/extensions/init_views.py +++ b/airflow/www/extensions/init_views.py @@ -18,24 +18,29 @@ import logging import warnings -from functools import cached_property +from enum import Enum +from functools import cached_property, lru_cache from pathlib import Path from typing import TYPE_CHECKING +from urllib.parse import urlsplit -from connexion import FlaskApi, ProblemException, Resolver -from connexion.decorators.validation import RequestBodyValidator -from connexion.exceptions import BadRequestProblem -from flask import request +import connexion +import starlette.exceptions +from connexion import ProblemException, Resolver +from connexion.options import SwaggerUIOptions +from connexion.problem import problem -from airflow.api_connexion.exceptions import common_error_handler +from airflow.api_connexion.exceptions import problem_error_handler from airflow.configuration import conf -from airflow.exceptions import RemovedInAirflow3Warning +from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning from airflow.security import permissions from airflow.utils.yaml import safe_load from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: + import starlette.exceptions + from connexion.lifecycle import ConnexionRequest, ConnexionResponse from flask import Flask log = logging.getLogger(__name__) @@ -128,6 +133,8 @@ def init_appbuilder_views(app): # add_view_no_menu to change item position. # I added link in extensions.init_appbuilder_links.init_appbuilder_links appbuilder.add_view_no_menu(views.RedocView) + if conf.getboolean("webserver", "enable_swagger_ui", fallback=True): + appbuilder.add_view_no_menu(views.SwaggerView) # Development views appbuilder.add_view_no_menu(views.DevView) appbuilder.add_view_no_menu(views.DocsView) @@ -172,26 +179,6 @@ def init_error_handlers(app: Flask): from airflow.www import views app.register_error_handler(500, views.show_traceback) - app.register_error_handler(404, views.not_found) - - -def set_cors_headers_on_response(response): - """Add response headers.""" - allow_headers = conf.get("api", "access_control_allow_headers") - allow_methods = conf.get("api", "access_control_allow_methods") - allow_origins = conf.get("api", "access_control_allow_origins") - if allow_headers: - response.headers["Access-Control-Allow-Headers"] = allow_headers - if allow_methods: - response.headers["Access-Control-Allow-Methods"] = allow_methods - if allow_origins == "*": - response.headers["Access-Control-Allow-Origin"] = "*" - elif allow_origins: - allowed_origins = allow_origins.split(" ") - origin = request.environ.get("HTTP_ORIGIN", allowed_origins[0]) - if origin in allowed_origins: - response.headers["Access-Control-Allow-Origin"] = origin - return response class _LazyResolution: @@ -225,90 +212,135 @@ def resolve(self, operation): return _LazyResolution(self.resolve_function_from_operation_id, operation_id) -class _CustomErrorRequestBodyValidator(RequestBodyValidator): - """Custom request body validator that overrides error messages. +# contains map of base paths that have api endpoints - By default, Connextion emits a very generic *None is not of type 'object'* - error when receiving an empty request body (with the view specifying the - body as non-nullable). We overrides it to provide a more useful message. - """ - def validate_schema(self, data, url): - if not self.is_null_value_valid and data is None: - raise BadRequestProblem(detail="Request body must not be empty") - return super().validate_schema(data, url) +class BaseAPIPaths(Enum): + """Known Airflow API paths.""" + + REST_API = "/api/v1" + INTERNAL_API = "/internal_api/v1" + EXPERIMENTAL_API = "/api/experimental" + + +def get_base_url() -> str: + """Return base url to prepend to all API routes.""" + webserver_base_url = conf.get_mandatory_value("webserver", "BASE_URL", fallback="") + if webserver_base_url.endswith("/"): + raise AirflowConfigException("webserver.base_url conf cannot have a trailing slash.") + base_url = urlsplit(webserver_base_url)[2] + if not base_url or base_url == "/": + base_url = "" + return base_url + + +BASE_URL = get_base_url() + +auth_mgr_mount_point: str | None = None + + +@lru_cache(maxsize=1) +def get_enabled_api_paths() -> list[str]: + enabled_apis = [] + enabled_apis.append(BaseAPIPaths.REST_API.value) + if conf.getboolean("webserver", "run_internal_api", fallback=False): + enabled_apis.append(BaseAPIPaths.INTERNAL_API.value) + if conf.getboolean("api", "enable_experimental_api", fallback=False): + enabled_apis.append(BaseAPIPaths.EXPERIMENTAL_API.value) + if auth_mgr_mount_point: + enabled_apis.append(auth_mgr_mount_point) + return enabled_apis + +def is_current_request_on_api_path() -> bool: + from flask.globals import request -base_paths: list[str] = [] # contains the list of base paths that have api endpoints + return any([request.path.startswith(BASE_URL + p) for p in get_enabled_api_paths()]) -def init_api_error_handlers(app: Flask) -> None: +def init_api_error_handlers(connexion_app: connexion.FlaskApp) -> None: """Add error handlers for 404 and 405 errors for existing API paths.""" from airflow.www import views - @app.errorhandler(404) - def _handle_api_not_found(ex): - if any([request.path.startswith(p) for p in base_paths]): + def _handle_api_not_found(error) -> ConnexionResponse | str: + if is_current_request_on_api_path(): # 404 errors are never handled on the blueprint level # unless raised from a view func so actual 404 errors, # i.e. "no route for it" defined, need to be handled # here on the application level - return common_error_handler(ex) - else: - return views.not_found(ex) - - @app.errorhandler(405) - def _handle_method_not_allowed(ex): - if any([request.path.startswith(p) for p in base_paths]): - return common_error_handler(ex) - else: - return views.method_not_allowed(ex) - - app.register_error_handler(ProblemException, common_error_handler) - - -def init_api_connexion(app: Flask) -> None: + return connexion_app._http_exception(error) + return views.not_found(error) + + def _handle_api_method_not_allowed(error) -> ConnexionResponse | str: + if is_current_request_on_api_path(): + return connexion_app._http_exception(error) + return views.method_not_allowed(error) + + def _handle_redirect( + request: ConnexionRequest, ex: starlette.exceptions.HTTPException + ) -> ConnexionResponse: + return problem( + title=connexion.http_facts.HTTP_STATUS_CODES.get(ex.status_code), + detail=ex.detail, + headers={"Location": ex.detail}, + status=ex.status_code, + ) + + # in case of 404 and 405 we handle errors at the Flask APP level in order to have access to + # context and be able to render the error page for the UI + connexion_app.app.register_error_handler(404, _handle_api_not_found) + connexion_app.app.register_error_handler(405, _handle_api_method_not_allowed) + + # We should handle redirects at connexion_app level - the requests will be redirected to the target + # location - so they can return application/problem+json response with the Location header regardless + # ot the request path - does not matter if it is API or UI request + connexion_app.add_error_handler(301, _handle_redirect) + connexion_app.add_error_handler(302, _handle_redirect) + connexion_app.add_error_handler(307, _handle_redirect) + connexion_app.add_error_handler(308, _handle_redirect) + + # Everything else we handle at the connexion_app level by default error handler + connexion_app.add_error_handler(ProblemException, problem_error_handler) + + +def init_api_connexion(connexion_app: connexion.FlaskApp) -> None: """Initialize Stable API.""" - base_path = "/api/v1" - base_paths.append(base_path) - with ROOT_APP_DIR.joinpath("api_connexion", "openapi", "v1.yaml").open() as f: specification = safe_load(f) - api_bp = FlaskApi( + swagger_ui_options = SwaggerUIOptions( + swagger_ui=SWAGGER_ENABLED, + swagger_ui_template_dir=SWAGGER_BUNDLE, + ) + + connexion_app.add_api( specification=specification, resolver=_LazyResolver(), - base_path=base_path, - options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, + base_path=BASE_URL + BaseAPIPaths.REST_API.value, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - validator_map={"body": _CustomErrorRequestBodyValidator}, - ).blueprint - api_bp.after_request(set_cors_headers_on_response) - - app.register_blueprint(api_bp) - app.extensions["csrf"].exempt(api_bp) + ) -def init_api_internal(app: Flask, standalone_api: bool = False) -> None: +def init_api_internal(connexion_app: connexion.FlaskApp, standalone_api: bool = False) -> None: """Initialize Internal API.""" if not standalone_api and not conf.getboolean("webserver", "run_internal_api", fallback=False): return - base_paths.append("/internal_api/v1") with ROOT_APP_DIR.joinpath("api_internal", "openapi", "internal_api_v1.yaml").open() as f: specification = safe_load(f) - api_bp = FlaskApi( + swagger_ui_options = SwaggerUIOptions( + swagger_ui=SWAGGER_ENABLED, + swagger_ui_template_dir=SWAGGER_BUNDLE, + ) + + connexion_app.add_api( specification=specification, - base_path="/internal_api/v1", - options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, + base_path=BASE_URL + BaseAPIPaths.INTERNAL_API.value, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - ).blueprint - api_bp.after_request(set_cors_headers_on_response) - - app.register_blueprint(api_bp) - app.after_request_funcs.setdefault(api_bp.name, []).append(set_cors_headers_on_response) - app.extensions["csrf"].exempt(api_bp) + ) def init_api_experimental(app): @@ -324,16 +356,50 @@ def init_api_experimental(app): RemovedInAirflow3Warning, stacklevel=2, ) - base_paths.append("/api/experimental") - app.register_blueprint(endpoints.api_experimental, url_prefix="/api/experimental") + app.register_blueprint( + endpoints.api_experimental, url_prefix=BASE_URL + BaseAPIPaths.EXPERIMENTAL_API.value + ) app.extensions["csrf"].exempt(endpoints.api_experimental) -def init_api_auth_provider(app): +def init_api_auth_manager(connexion_app: connexion.FlaskApp): """Initialize the API offered by the auth manager.""" - auth_mgr = get_auth_manager() - blueprint = auth_mgr.get_api_endpoints() - if blueprint: - base_paths.append(blueprint.url_prefix) - app.register_blueprint(blueprint) - app.extensions["csrf"].exempt(blueprint) + global auth_mgr_mount_point + try: + auth_mgr_mount_point, specification = get_auth_manager().get_auth_manager_api_specification() + except NotImplementedError: + log.warning( + "Your Auth manager does not have a `get_auth_manager_api_specification` method which" + "means that it does not provide additional API. You can implement this method and " + "return None, {} tuple to get rid of this warning." + ) + return + if not auth_mgr_mount_point: + return + swagger_ui_options = SwaggerUIOptions( + swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True), + swagger_ui_template_dir=SWAGGER_BUNDLE, + ) + from airflow.www.extensions.init_views import BASE_URL, _LazyResolver + + connexion_app.add_api( + specification=specification, + resolver=_LazyResolver(), + base_path=BASE_URL + auth_mgr_mount_point, + swagger_ui_options=swagger_ui_options, + strict_validation=True, + validate_responses=True, + ) + + +def init_cors_middleware(connexion_app: connexion.FlaskApp): + from starlette.middleware.cors import CORSMiddleware + + connexion_app.add_middleware( + CORSMiddleware, + connexion.middleware.MiddlewarePosition.BEFORE_ROUTING, + allow_origins=conf.get("api", "access_control_allow_origins"), + allow_credentials=True, + allow_methods=conf.get("api", "access_control_allow_methods"), + allow_headers=conf.get("api", "access_control_allow_headers"), + ) diff --git a/airflow/www/package.json b/airflow/www/package.json index fd2a275198e5c..5783bdbe35bfe 100644 --- a/airflow/www/package.json +++ b/airflow/www/package.json @@ -142,7 +142,7 @@ "reactflow": "^11.7.4", "redoc": "^2.0.0-rc.72", "remark-gfm": "^3.0.1", - "swagger-ui-dist": "4.1.3", + "swagger-ui-dist": "5.11.8", "tsconfig-paths": "^3.14.2", "type-fest": "^2.17.0", "url-search-params-polyfill": "^8.1.0", diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index 1b82d07835ab2..a7b631bb1e5a3 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -626,8 +626,8 @@ export interface paths { * response = self.client.get( * request_url, * query_string={"token": token}, - * headers={"Accept": "text/plain"}, - * environ_overrides={"REMOTE_USER": "test"}, + * headers={"Accept": "text/plain","REMOTE_USER": "test"}, + * * ) * continuation_token = response.json["continuation_token"] * metadata = URLSafeSerializer(key).loads(continuation_token) @@ -725,6 +725,12 @@ export interface paths { "/datasets": { get: operations["get_datasets"]; }; + "/datasets/events": { + /** Get dataset events */ + get: operations["get_dataset_events"]; + /** Create dataset event */ + post: operations["create_dataset_event"]; + }; "/datasets/{uri}": { /** Get a dataset by uri. */ get: operations["get_dataset"]; @@ -735,12 +741,6 @@ export interface paths { }; }; }; - "/datasets/events": { - /** Get dataset events */ - get: operations["get_dataset_events"]; - /** Create dataset event */ - post: operations["create_dataset_event"]; - }; "/config": { get: operations["get_config"]; }; @@ -3845,7 +3845,11 @@ export interface operations { }; responses: { /** Success. */ - 204: never; + 204: { + content: { + "text/html": string; + }; + }; 400: components["responses"]["BadRequest"]; 401: components["responses"]["Unauthenticated"]; 403: components["responses"]["PermissionDenied"]; @@ -4333,7 +4337,11 @@ export interface operations { }; responses: { /** Success. */ - 204: never; + 204: { + content: { + "text/html": string; + }; + }; 400: components["responses"]["BadRequest"]; 401: components["responses"]["Unauthenticated"]; 403: components["responses"]["PermissionDenied"]; @@ -4488,8 +4496,8 @@ export interface operations { * response = self.client.get( * request_url, * query_string={"token": token}, - * headers={"Accept": "text/plain"}, - * environ_overrides={"REMOTE_USER": "test"}, + * headers={"Accept": "text/plain","REMOTE_USER": "test"}, + * * ) * continuation_token = response.json["continuation_token"] * metadata = URLSafeSerializer(key).loads(continuation_token) @@ -4636,7 +4644,7 @@ export interface operations { "application/json": { content?: string; }; - "plain/text": string; + "text/plain": string; }; }; 401: components["responses"]["Unauthenticated"]; @@ -4711,26 +4719,6 @@ export interface operations { 403: components["responses"]["PermissionDenied"]; }; }; - /** Get a dataset by uri. */ - get_dataset: { - parameters: { - path: { - /** The encoded Dataset URI */ - uri: components["parameters"]["DatasetURI"]; - }; - }; - responses: { - /** Success. */ - 200: { - content: { - "application/json": components["schemas"]["Dataset"]; - }; - }; - 401: components["responses"]["Unauthenticated"]; - 403: components["responses"]["PermissionDenied"]; - 404: components["responses"]["NotFound"]; - }; - }; /** Get dataset events */ get_dataset_events: { parameters: { @@ -4790,6 +4778,26 @@ export interface operations { }; }; }; + /** Get a dataset by uri. */ + get_dataset: { + parameters: { + path: { + /** The encoded Dataset URI */ + uri: components["parameters"]["DatasetURI"]; + }; + }; + responses: { + /** Success. */ + 200: { + content: { + "application/json": components["schemas"]["Dataset"]; + }; + }; + 401: components["responses"]["Unauthenticated"]; + 403: components["responses"]["PermissionDenied"]; + 404: components["responses"]["NotFound"]; + }; + }; get_config: { parameters: { query: { @@ -5686,15 +5694,15 @@ export type GetDagWarningsVariables = CamelCasedPropertiesDeep< export type GetDatasetsVariables = CamelCasedPropertiesDeep< operations["get_datasets"]["parameters"]["query"] >; -export type GetDatasetVariables = CamelCasedPropertiesDeep< - operations["get_dataset"]["parameters"]["path"] ->; export type GetDatasetEventsVariables = CamelCasedPropertiesDeep< operations["get_dataset_events"]["parameters"]["query"] >; export type CreateDatasetEventVariables = CamelCasedPropertiesDeep< operations["create_dataset_event"]["requestBody"]["content"]["application/json"] >; +export type GetDatasetVariables = CamelCasedPropertiesDeep< + operations["get_dataset"]["parameters"]["path"] +>; export type GetConfigVariables = CamelCasedPropertiesDeep< operations["get_config"]["parameters"]["query"] >; diff --git a/airflow/www/views.py b/airflow/www/views.py index 8311050bfff56..0ce15e3475143 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3373,7 +3373,6 @@ def historical_metrics_data(self): """Return cluster activity historical metrics.""" start_date = _safe_parse_datetime(request.args.get("start_date")) end_date = _safe_parse_datetime(request.args.get("end_date")) - with create_session() as session: # DagRuns dag_run_types = session.execute( @@ -3685,14 +3684,14 @@ def parse_dag(self, file_token: str): from airflow.api_connexion.endpoints.dag_parsing import reparse_dag_file with create_session() as session: - response = reparse_dag_file(file_token=file_token, session=session) + api_response = reparse_dag_file(file_token=file_token, session=session) response_messages = { 201: ["Reparsing request submitted successfully", "info"], 401: ["Unauthenticated request", "error"], 403: ["Permission Denied", "error"], 404: ["DAG not found", "error"], } - flash(response_messages[response.status_code][0], response_messages[response.status_code][1]) + flash(response_messages[api_response[1]][0], response_messages[api_response[1]][1]) redirect_url = get_safe_url(request.values.get("redirect_url")) return redirect(redirect_url) @@ -3775,10 +3774,21 @@ class RedocView(AirflowBaseView): @expose("/redoc") def redoc(self): """Redoc API documentation.""" - openapi_spec_url = url_for("/api/v1./api/v1_openapi_yaml") + openapi_spec_url = "api/v1/openapi.yaml" return self.render_template("airflow/redoc.html", openapi_spec_url=openapi_spec_url) +class SwaggerView(AirflowBaseView): + """Swagger UI View.""" + + default_view = "swagger" + + @expose("/swagger") + def swagger(self): + """Redoc API documentation.""" + return redirect("api/v1/ui/") + + ###################################################################################### # ModelViews ###################################################################################### diff --git a/airflow/www/yarn.lock b/airflow/www/yarn.lock index b1ab3485d04de..c8191de10c6e2 100644 --- a/airflow/www/yarn.lock +++ b/airflow/www/yarn.lock @@ -10534,6 +10534,18 @@ safe-regex-test@^1.0.0: resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a" integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== +sanitize-html@^2.12.1: + version "2.12.1" + resolved "https://registry.yarnpkg.com/sanitize-html/-/sanitize-html-2.12.1.tgz#280a0f5c37305222921f6f9d605be1f6558914c7" + integrity sha512-Plh+JAn0UVDpBRP/xEjsk+xDCoOvMBwQUf/K+/cBAVuTbtX8bj2VB7S1sL1dssVpykqp0/KPSesHrqXtokVBpA== + dependencies: + deepmerge "^4.2.2" + escape-string-regexp "^4.0.0" + htmlparser2 "^8.0.0" + is-plain-object "^5.0.0" + parse-srcset "^1.0.2" + postcss "^8.3.11" + sax@^1.2.4: version "1.2.4" resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" @@ -11163,10 +11175,10 @@ svgo@^2.7.0: picocolors "^1.0.0" stable "^0.1.8" -swagger-ui-dist@4.1.3: - version "4.1.3" - resolved "https://registry.yarnpkg.com/swagger-ui-dist/-/swagger-ui-dist-4.1.3.tgz#2be9f9de9b5c19132fa4a5e40933058c151563dc" - integrity sha512-WvfPSfAAMlE/sKS6YkW47nX/hA7StmhYnAHc6wWCXNL0oclwLj6UXv0hQCkLnDgvebi0MEV40SJJpVjKUgH1IQ== +swagger-ui-dist@5.11.8: + version "5.11.8" + resolved "https://registry.yarnpkg.com/swagger-ui-dist/-/swagger-ui-dist-5.11.8.tgz#5f92f1f4ca979a5df847da5df180c8b10ccc3e0c" + integrity sha512-IfPtCPdf6opT5HXrzHO4kjL1eco0/8xJCtcs7ilhKuzatrpF2j9s+3QbOag6G3mVFKf+g+Ca5UG9DquVUs2obA== swagger2openapi@^7.0.6: version "7.0.6" diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index c0c29a933b37b..09f92d40cec1c 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -505,7 +505,7 @@ def get_airflow_extras(): { "python-version": "3.8", "airflow-version": "2.9.1", - "remove-providers": "", + "remove-providers": "fab", # TODO fix tests! "run-tests": "true", }, ] diff --git a/docs/apache-airflow/core-concepts/auth-manager.rst b/docs/apache-airflow/core-concepts/auth-manager.rst index 521264fd78ba7..b044972a9ba4a 100644 --- a/docs/apache-airflow/core-concepts/auth-manager.rst +++ b/docs/apache-airflow/core-concepts/auth-manager.rst @@ -163,7 +163,9 @@ Auth managers may vend CLI commands which will be included in the ``airflow`` co Rest API ^^^^^^^^ -Auth managers may vend Rest API endpoints which will be included in the :doc:`/stable-rest-api-ref` by implementing the ``get_api_endpoints`` method. The endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. Endpoints are only vended for the currently configured auth manager. +Auth managers may vend Rest API endpoints which will be included in the :doc:`/stable-rest-api-ref` by implementing the ``set_api_endpoints`` method (Airflow 2.9) +or ``get_auth_manager_api_specification`` (Airflow 2.10+). +The endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. Endpoints are only vended for the currently configured auth manager. Next Steps ^^^^^^^^^^ diff --git a/hatch_build.py b/hatch_build.py index a51689d9e2679..6bad88e7138c0 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -424,7 +424,7 @@ # The usage was added in #30596, seemingly only to override and improve the default error message. # Either revert that change or find another way, preferably without using connexion internals. # This limit can be removed after https://github.com/apache/airflow/issues/35234 is fixed - "connexion[flask]>=2.14.2,<3.0", + "connexion[flask,uvicorn]>=3.0", "cron-descriptor>=1.2.24", "croniter>=2.0.2", "cryptography>=41.0.0", @@ -486,6 +486,7 @@ # The issue tracking it is https://github.com/apache/airflow/issues/28723 "sqlalchemy>=1.4.36,<2.0", "sqlalchemy-jsonfield>=1.0", + "starlette>=0.37.1", "tabulate>=0.7.5", "tenacity>=8.0.0,!=8.2.0", "termcolor>=1.1.0", diff --git a/newsfragments/37638.significant.rst b/newsfragments/37638.significant.rst new file mode 100644 index 0000000000000..7e498df5bb617 --- /dev/null +++ b/newsfragments/37638.significant.rst @@ -0,0 +1,4 @@ +Replaced test_should_respond_400_on_invalid_request with test_ignore_read_only_fields in the test_dag_endpoint.py. + +Connexion V3 request body validator doesn't raise the read-only property error and just ignore the read-only field. +You can find the detail about the change `here `_ diff --git a/pyproject.toml b/pyproject.toml index b9a84382481b8..d384a2ec74994 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -382,6 +382,7 @@ combine-as-imports = true "airflow/security/kerberos.py" = ["E402"] "airflow/security/utils.py" = ["E402"] "tests/providers/common/io/xcom/test_backend.py" = ["E402"] +"tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py" = ["E402"] "tests/providers/elasticsearch/log/elasticmock/__init__.py" = ["E402"] "tests/providers/elasticsearch/log/elasticmock/utilities/__init__.py" = ["E402"] "tests/providers/google/cloud/hooks/vertex_ai/test_batch_prediction_job.py" = ["E402"] diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index abd09fa1c02ec..53af22095d4e4 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -21,6 +21,7 @@ from airflow.www import app from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules +from tests.test_utils.mock_cors_middeleware import init_mock_cors_middleware @pytest.fixture(scope="session") @@ -30,6 +31,7 @@ def minimal_app_for_api(): "init_appbuilder", "init_api_experimental_auth", "init_api_connexion", + "init_jinja_globals", "init_api_error_handlers", "init_airflow_session_interface", "init_appbuilder_views", @@ -37,8 +39,33 @@ def minimal_app_for_api(): ) def factory(): with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): - _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - _app.config["AUTH_ROLE_PUBLIC"] = None + _app = app.create_connexion_app( + testing=True, + config={"WTF_CSRF_ENABLED": False, "AUTH_ROLE_PUBLIC": None}, + ) # type:ignore + init_mock_cors_middleware(_app, allow_origins=["http://apache.org", "http://example.com"]) + return _app + + return factory() + + +@pytest.fixture(scope="session") +def minimal_app_for_api_cors_allow_all(): + @dont_initialize_flask_app_submodules( + skip_all_except=[ + "init_appbuilder", + "init_api_experimental_auth", + "init_api_connexion", + "init_jinja_globals", + "init_api_error_handlers", + "init_airflow_session_interface", + "init_appbuilder_views", + ] + ) + def factory(): + with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): + _app = app.create_connexion_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + init_mock_cors_middleware(_app, allow_origins=["*"]) return _app return factory() @@ -63,9 +90,9 @@ def dagbag(): @pytest.fixture def set_auto_role_public(request): app = request.getfixturevalue("minimal_app_for_api") - auto_role_public = app.config["AUTH_ROLE_PUBLIC"] - app.config["AUTH_ROLE_PUBLIC"] = request.param + auto_role_public = app.app.config["AUTH_ROLE_PUBLIC"] + app.app.config["AUTH_ROLE_PUBLIC"] = request.param yield - app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + app.app.config["AUTH_ROLE_PUBLIC"] = auto_role_public diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index 3dd5814e5d79e..8384f8fb061cf 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -22,7 +22,7 @@ import pytest from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -49,33 +49,31 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore with conf_vars({("webserver", "expose_config"): "True"}): yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestGetConfig: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_text_plain(self, mock_as_dict): - response = self.client.get( - "/api/v1/config", headers={"Accept": "text/plain"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/config", headers={"Accept": "text/plain", "REMOTE_USER": "test"}) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 expected = textwrap.dedent( @@ -88,14 +86,12 @@ def test_should_respond_200_text_plain(self, mock_as_dict): smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) @conf_vars({("webserver", "expose_config"): "non-sensitive-only"}) def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dict): - response = self.client.get( - "/api/v1/config", headers={"Accept": "text/plain"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/config", headers={"Accept": "text/plain", "REMOTE_USER": "test"}) mock_as_dict.assert_called_with(display_source=False, display_sensitive=False) assert response.status_code == 200 expected = textwrap.dedent( @@ -108,14 +104,13 @@ def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dic smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_application_json(self, mock_as_dict): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 @@ -136,14 +131,13 @@ def test_should_respond_200_application_json(self, mock_as_dict): }, ] } - assert expected == response.json + assert response.json() == expected @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_single_section_as_text_plain(self, mock_as_dict): response = self.client.get( "/api/v1/config?section=smtp", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 @@ -154,14 +148,13 @@ def test_should_respond_200_single_section_as_text_plain(self, mock_as_dict): smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_single_section_as_json(self, mock_as_dict): response = self.client.get( "/api/v1/config?section=smtp", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 @@ -176,38 +169,35 @@ def test_should_respond_200_single_section_as_json(self, mock_as_dict): }, ] } - assert expected == response.json + assert expected == response.json() @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_404_when_section_not_exist(self, mock_as_dict): response = self.client.get( "/api/v1/config?section=smtp1", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert "section=smtp1 not found." in response.json["detail"] + assert "section=smtp1 not found." in response.json()["detail"] @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_406(self, mock_as_dict): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/octet-stream"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/octet-stream", "REMOTE_USER": "test"}, ) assert response.status_code == 406 def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/config", headers={"Accept": "application/json"}) - assert_401(response) + assert response.status_code == 401 def test_should_raises_403_unauthorized(self): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"Accept": "application/json", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -216,11 +206,10 @@ def test_should_raises_403_unauthorized(self): def test_should_respond_403_when_expose_config_off(self): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 403 - assert "chose not to expose" in response.json["detail"] + assert "chose not to expose" in response.json()["detail"] @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -236,15 +225,14 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestGetValue: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_text_plain(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 200 expected = textwrap.dedent( @@ -253,7 +241,7 @@ def test_should_respond_200_text_plain(self, mock_as_dict): smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch( "airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", @@ -272,8 +260,7 @@ def test_should_respond_200_text_plain(self, mock_as_dict): def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dict, section, option): response = self.client.get( f"/api/v1/config/section/{section}/option/{option}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 200 expected = textwrap.dedent( @@ -282,14 +269,13 @@ def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dic {option} = < hidden > """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_application_json(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 200 expected = { @@ -302,25 +288,23 @@ def test_should_respond_200_application_json(self, mock_as_dict): }, ] } - assert expected == response.json + assert expected == response.json() @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_404_when_option_not_exist(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from1", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert "The option [smtp/smtp_mail_from1] is not found in config." in response.json["detail"] + assert "The option [smtp/smtp_mail_from1] is not found in config." in response.json()["detail"] @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_406(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/octet-stream"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/octet-stream", "REMOTE_USER": "test"}, ) assert response.status_code == 406 @@ -329,13 +313,12 @@ def test_should_raises_401_unauthenticated(self): "/api/v1/config/section/smtp/option/smtp_mail_from", headers={"Accept": "application/json"} ) - assert_401(response) + assert response.status_code == 401 def test_should_raises_403_unauthorized(self): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"Accept": "application/json", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -344,11 +327,10 @@ def test_should_raises_403_unauthorized(self): def test_should_respond_403_when_expose_config_off(self): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 403 - assert "chose not to expose" in response.json["detail"] + assert "chose not to expose" in response.json()["detail"] @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index c88b8a56de9d5..8a209af3a1b93 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -26,7 +26,7 @@ from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.security import permissions from airflow.utils.session import provide_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_connections from tests.test_utils.www import _check_last_log @@ -36,9 +36,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -48,19 +48,19 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestConnectionEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore # we want only the connection created here for this test clear_db_connections(False) @@ -81,20 +81,16 @@ def test_delete_should_respond_204(self, session): session.commit() conn = session.query(Connection).all() assert len(conn) == 1 - response = self.client.delete( - "/api/v1/connections/test-connection", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/api/v1/connections/test-connection", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 connection = session.query(Connection).all() assert len(connection) == 0 _check_last_log(session, dag_id=None, event="api.connection.delete", execution_date=None) def test_delete_should_respond_404(self): - response = self.client.delete( - "/api/v1/connections/test-connection", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/api/v1/connections/test-connection", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "The Connection with connection_id: `test-connection` was not found", "status": 404, "title": "Connection not found", @@ -104,11 +100,11 @@ def test_delete_should_respond_404(self): def test_should_raises_401_unauthenticated(self): response = self.client.delete("/api/v1/connections/test-connection") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - "/api/v1/connections/test-connection-id", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/api/v1/connections/test-connection-id", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -145,11 +141,9 @@ def test_should_respond_200(self, session): session.commit() result = session.query(Connection).all() assert len(result) == 1 - response = self.client.get( - "/api/v1/connections/test-connection-id", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections/test-connection-id", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "connection_id": "test-connection-id", "conn_type": "mysql", "description": "test description", @@ -171,28 +165,24 @@ def test_should_mask_sensitive_values_in_extra(self, session): session.add(connection_model) session.commit() - response = self.client.get( - "/api/v1/connections/test-connection-id", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections/test-connection-id", headers={"REMOTE_USER": "test"}) - assert response.json["extra"] == '{"nonsensitive": "just_a_value", "api_token": "***"}' + assert response.json()["extra"] == '{"nonsensitive": "just_a_value", "api_token": "***"}' def test_should_respond_404(self): - response = self.client.get( - "/api/v1/connections/invalid-connection", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections/invalid-connection", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "The Connection with connection_id: `invalid-connection` was not found", "status": 404, "title": "Connection not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/connections/test-connection-id") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -229,9 +219,9 @@ def test_should_respond_200(self, session): session.commit() result = session.query(Connection).all() assert len(result) == 2 - response = self.client.get("/api/v1/connections", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "connections": [ { "connection_id": "test-connection-id-1", @@ -264,11 +254,11 @@ def test_should_respond_200_with_order_by(self, session): result = session.query(Connection).all() assert len(result) == 2 response = self.client.get( - "/api/v1/connections?order_by=-connection_id", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections?order_by=-connection_id", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 # Using - means descending - assert response.json == { + assert response.json() == { "connections": [ { "connection_id": "test-connection-id-2", @@ -295,7 +285,7 @@ def test_should_respond_200_with_order_by(self, session): def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/connections") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -352,10 +342,10 @@ def test_handle_limit_offset(self, url, expected_conn_ids, session): connections = self._create_connections(10) session.add_all(connections) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - conn_ids = [conn["connection_id"] for conn in response.json["connections"] if conn] + assert response.json()["total_entries"] == 10 + conn_ids = [conn["connection_id"] for conn in response.json()["connections"] if conn] assert conn_ids == expected_conn_ids def test_should_respect_page_size_limit_default(self, session): @@ -363,23 +353,21 @@ def test_should_respect_page_size_limit_default(self, session): session.add_all(connection_models) session.commit() - response = self.client.get("/api/v1/connections", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["connections"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["connections"]) == 100 def test_invalid_order_by_raises_400(self, session): connection_models = self._create_connections(200) session.add_all(connection_models) session.commit() - response = self.client.get( - "/api/v1/connections?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert ( - response.json["detail"] == "Ordering with 'invalid' is disallowed or" + response.json()["detail"] == "Ordering with 'invalid' is disallowed or" " the attribute does not exist on the model" ) @@ -388,11 +376,11 @@ def test_limit_of_zero_should_return_default(self, session): session.add_all(connection_models) session.commit() - response = self.client.get("/api/v1/connections?limit=0", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections?limit=0", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["connections"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["connections"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -400,9 +388,9 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): session.add_all(connection_models) session.commit() - response = self.client.get("/api/v1/connections?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["connections"]) == 150 + assert len(response.json()["connections"]) == 150 def _create_connections(self, count): return [ @@ -424,7 +412,7 @@ def test_patch_should_respond_200(self, payload, session): self._create_connection(session) response = self.client.patch( - "/api/v1/connections/test-connection-id", json=payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections/test-connection-id", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 _check_last_log(session, dag_id=None, event="api.connection.edit", execution_date=None) @@ -442,12 +430,12 @@ def test_patch_should_respond_200_with_update_mask(self, session): response = self.client.patch( "/api/v1/connections/test-connection-id?update_mask=port,login", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 connection = session.query(Connection).filter_by(conn_id=test_connection).first() assert connection.password is None - assert response.json == { + assert response.json() == { "connection_id": test_connection, # not updated "conn_type": "test_type", # Not updated "description": None, # Not updated @@ -513,10 +501,10 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask( response = self.client.patch( f"/api/v1/connections/test-connection-id?{update_mask}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == error_message + assert response.json()["detail"] == error_message @pytest.mark.parametrize( "payload, error_message", @@ -552,15 +540,15 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask( def test_patch_should_respond_400_for_invalid_update(self, payload, error_message, session): self._create_connection(session) response = self.client.patch( - "/api/v1/connections/test-connection-id", json=payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections/test-connection-id", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert error_message in response.json["detail"] + assert error_message in response.json()["detail"] def test_patch_should_respond_404_not_found(self): payload = {"connection_id": "test-connection-id", "conn_type": "test-type", "port": 90} response = self.client.patch( - "/api/v1/connections/test-connection-id", json=payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections/test-connection-id", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 assert { @@ -568,7 +556,7 @@ def test_patch_should_respond_404_not_found(self): "status": 404, "title": "Connection not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): self._create_connection(session) @@ -578,7 +566,7 @@ def test_should_raises_401_unauthenticated(self, session): json={"connection_id": "test-connection-id", "conn_type": "test_type", "extra": "{'key': 'var'}"}, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -599,9 +587,7 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestPostConnection(TestConnectionEndpoint): def test_post_should_respond_200(self, session): payload = {"connection_id": "test-connection-id", "conn_type": "test_type"} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 connection = session.query(Connection).all() assert len(connection) == 1 @@ -612,11 +598,9 @@ def test_post_should_respond_200(self, session): def test_post_should_respond_200_extra_null(self, session): payload = {"connection_id": "test-connection-id", "conn_type": "test_type", "extra": None} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["extra"] is None + assert response.json()["extra"] is None connection = session.query(Connection).all() assert len(connection) == 1 assert connection[0].conn_id == "test-connection-id" @@ -626,11 +610,9 @@ def test_post_should_respond_400_for_invalid_payload(self): payload = { "connection_id": "test-connection-id", } # conn_type missing - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'conn_type': ['Missing data for required field.']}", "status": 400, "title": "Bad Request", @@ -639,11 +621,9 @@ def test_post_should_respond_400_for_invalid_payload(self): def test_post_should_respond_400_for_invalid_conn_id(self): payload = {"connection_id": "****", "conn_type": "test_type"} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "The key '****' has to be made of " "alphanumeric characters, dashes, dots and underscores exclusively", "status": 400, @@ -653,16 +633,12 @@ def test_post_should_respond_400_for_invalid_conn_id(self): def test_post_should_respond_409_already_exist(self): payload = {"connection_id": "test-connection-id", "conn_type": "test_type"} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 # Another request - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 409 - assert response.json == { + assert response.json() == { "detail": "Connection already exist. ID: test-connection-id", "status": 409, "title": "Conflict", @@ -674,7 +650,7 @@ def test_should_raises_401_unauthenticated(self): "/api/v1/connections", json={"connection_id": "test-connection-id", "conn_type": "test_type"} ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -693,11 +669,9 @@ class TestConnection(TestConnectionEndpoint): @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) def test_should_respond_200(self): payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} - response = self.client.post( - "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "status": True, "message": "Connection successfully tested", } @@ -705,7 +679,7 @@ def test_should_respond_200(self): @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) def test_connection_env_is_cleaned_after_run(self): payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} - self.client.post("/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"}) + self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert not any([key.startswith(CONN_ENV_PREFIX) for key in os.environ.keys()]) @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) @@ -713,11 +687,9 @@ def test_post_should_respond_400_for_invalid_payload(self): payload = { "connection_id": "test-connection-id", } # conn_type missing - response = self.client.post( - "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'conn_type': ['Missing data for required field.']}", "status": 400, "title": "Bad Request", @@ -729,13 +701,11 @@ def test_should_raises_401_unauthenticated(self): "/api/v1/connections/test", json={"connection_id": "test-connection-id", "conn_type": "test_type"} ) - assert_401(response) + assert response.status_code == 401 def test_should_respond_403_by_default(self): payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} - response = self.client.post( - "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 403 assert response.text == ( "Testing connections is disabled in Airflow configuration. " diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index b514faba276d9..5804178a024e0 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -31,7 +31,7 @@ from airflow.security import permissions from airflow.utils.session import provide_session from airflow.utils.state import TaskInstanceState -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags from tests.test_utils.www import _check_last_log @@ -53,10 +53,10 @@ def current_file_token(url_safe_serializer) -> str: @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -65,13 +65,13 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_1", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_1", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) @@ -94,13 +94,13 @@ def configured_app(minimal_app_for_api): dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} - app.dag_bag = dag_bag + connexion_app.app.dag_bag = dag_bag - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_granular_permissions") # type: ignore class TestDagEndpoint: @@ -113,8 +113,9 @@ def clean_db(): @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: self.clean_db() - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = configured_app.app + self.client = self.connexion_app.test_client() # type:ignore self.dag_id = DAG_ID self.dag2_id = DAG2_ID self.dag3_id = DAG3_ID @@ -177,7 +178,7 @@ class TestGetDag(TestDagEndpoint): @conf_vars({("webserver", "secret_key"): "mysecret"}) def test_should_respond_200(self): self._create_dag_models(1) - response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/TEST_DAG_1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "dag_id": "TEST_DAG_1", @@ -208,7 +209,7 @@ def test_should_respond_200(self): "timetable_description": None, "has_import_errors": False, "pickle_id": None, - } == response.json + } == response.json() @conf_vars({("webserver", "secret_key"): "mysecret"}) def test_should_respond_200_with_schedule_interval_none(self, session): @@ -220,7 +221,7 @@ def test_should_respond_200_with_schedule_interval_none(self, session): ) session.add(dag_model) session.commit() - response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/TEST_DAG_1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "dag_id": "TEST_DAG_1", @@ -251,17 +252,17 @@ def test_should_respond_200_with_schedule_interval_none(self, session): "timetable_description": None, "has_import_errors": False, "pickle_id": None, - } == response.json + } == response.json() def test_should_respond_200_with_granular_dag_access(self): self._create_dag_models(1) response = self.client.get( - "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + "/api/v1/dags/TEST_DAG_1", headers={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 200 def test_should_respond_404(self): - response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/INVALID_DAG", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 def test_should_raises_401_unauthenticated(self): @@ -269,18 +270,18 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/dags/TEST_DAG_1") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test_no_permissions"} + f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 def test_should_respond_403_with_granular_access_for_different_dag(self): self._create_dag_models(3) response = self.client.get( - "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + "/api/v1/dags/TEST_DAG_2", headers={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 403 @@ -295,9 +296,9 @@ def test_should_respond_403_with_granular_access_for_different_dag(self): def test_should_return_specified_fields(self, fields): self._create_dag_models(1) response = self.client.get( - f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", headers={"REMOTE_USER": "test"} ) - res_json = response.json + res_json = response.json() assert len(res_json.keys()) == len(fields) for field in fields: assert field in res_json @@ -313,7 +314,7 @@ def test_should_return_specified_fields(self, fields): def test_should_respond_400_with_not_exists_fields(self, fields): self._create_dag_models(1) response = self.client.get( - f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -340,11 +341,9 @@ class TestGetDagDetails(TestDagEndpoint): def test_should_respond_200(self, url_safe_serializer): self._create_dag_model_for_details_endpoint(self.dag_id) current_file_token = url_safe_serializer.dumps("/tmp/dag.py") - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -396,16 +395,14 @@ def test_should_respond_200(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_respond_200_with_dataset_expression(self, url_safe_serializer): self._create_dag_model_for_details_endpoint_with_dataset_expression(self.dag_id) current_file_token = url_safe_serializer.dumps("/tmp/dag.py") - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -462,16 +459,14 @@ def test_should_respond_200_with_dataset_expression(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_response_200_with_doc_md_none(self, url_safe_serializer): current_file_token = url_safe_serializer.dumps("/tmp/dag.py") self._create_dag_model_for_details_endpoint(self.dag2_id) - response = self.client.get( - f"/api/v1/dags/{self.dag2_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag2_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -516,16 +511,14 @@ def test_should_response_200_with_doc_md_none(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_response_200_for_null_start_date(self, url_safe_serializer): current_file_token = url_safe_serializer.dumps("/tmp/dag.py") self._create_dag_model_for_details_endpoint(self.dag3_id) - response = self.client.get( - f"/api/v1/dags/{self.dag3_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag3_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -570,17 +563,17 @@ def test_should_response_200_for_null_start_date(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_respond_200_serialized(self, url_safe_serializer): current_file_token = url_safe_serializer.dumps("/tmp/dag.py") self._create_dag_model_for_details_endpoint(self.dag_id) # Get the dag out of the dagbag before we patch it to an empty one - SerializedDagModel.write_dag(self.app.dag_bag.get_dag(self.dag_id)) + SerializedDagModel.write_dag(self.flask_app.dag_bag.get_dag(self.dag_id)) # Create empty app with empty dagbag to check if DAG is read from db dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True) - patcher = unittest.mock.patch.object(self.app, "dag_bag", dag_bag) + patcher = unittest.mock.patch.object(self.flask_app, "dag_bag", dag_bag) patcher.start() expected = { @@ -633,19 +626,15 @@ def test_should_respond_200_serialized(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - expected.update({"last_parsed": response.json["last_parsed"]}) - assert response.json == expected + expected.update({"last_parsed": response.json()["last_parsed"]}) + assert response.json() == expected patcher.stop() - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 expected = { "catchup": True, @@ -697,20 +686,20 @@ def test_should_respond_200_serialized(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - expected.update({"last_parsed": response.json["last_parsed"]}) - assert response.json == expected + expected.update({"last_parsed": response.json()["last_parsed"]}) + assert response.json() == expected def test_should_raises_401_unauthenticated(self): response = self.client.get(f"/api/v1/dags/{self.dag_id}/details") - assert_401(response) + assert response.status_code == 401 def test_should_raise_404_when_dag_is_not_found(self): response = self.client.get( - "/api/v1/dags/non_existing_dag_id/details", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/dags/non_existing_dag_id/details", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "The DAG with dag_id: non_existing_dag_id was not found", "status": 404, "title": "DAG not found", @@ -729,10 +718,10 @@ def test_should_return_specified_fields(self, fields): self._create_dag_model_for_details_endpoint(self.dag2_id) response = self.client.get( f"/api/v1/dags/{self.dag2_id}/details?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - res_json = response.json + res_json = response.json() assert len(res_json.keys()) == len(fields) for field in fields: assert field in res_json @@ -742,7 +731,7 @@ def test_should_respond_400_with_not_exists_fields(self): self._create_dag_model_for_details_endpoint(self.dag2_id) response = self.client.get( f"/api/v1/dags/{self.dag2_id}/details?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -768,7 +757,7 @@ def test_should_respond_200(self, session, url_safe_serializer): dags_query = session.query(DagModel).filter(~DagModel.is_subdag) assert len(dags_query.all()) == 3 - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") file_token2 = url_safe_serializer.dumps("/tmp/dag_2.py") @@ -843,12 +832,12 @@ def test_should_respond_200(self, session, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_only_active_true_returns_active_dags(self, url_safe_serializer): self._create_dag_models(1) self._create_deactivated_dag() - response = self.client.get("api/v1/dags?only_active=True", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?only_active=True", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -888,12 +877,12 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer): } ], "total_entries": 1, - } == response.json + } == response.json() def test_only_active_false_returns_all_dags(self, url_safe_serializer): self._create_dag_models(1) self._create_deactivated_dag() - response = self.client.get("api/v1/dags?only_active=False", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?only_active=False", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") file_token_2 = url_safe_serializer.dumps("/tmp/dag_del_1.py") assert response.status_code == 200 @@ -967,7 +956,7 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() @pytest.mark.parametrize( "url, expected_dag_ids", @@ -989,9 +978,9 @@ def test_filter_dags_by_tags_works(self, url, expected_dag_ids): dag3.sync_to_db() dag4.sync_to_db() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids @@ -1017,20 +1006,18 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): dag3.sync_to_db() dag4.sync_to_db() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_ids = {dag["dag_id"] for dag in response.json["dags"]} + dag_ids = {dag["dag_id"] for dag in response.json()["dags"]} assert expected_dag_ids == dag_ids def test_should_respond_200_with_granular_dag_access(self): self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) + response = self.client.get("/api/v1/dags", headers={"REMOTE_USER": "test_granular_permissions"}) assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + assert len(response.json()["dags"]) == 1 + assert response.json()["dags"][0]["dag_id"] == "TEST_DAG_1" @pytest.mark.parametrize( "url, expected_dag_ids", @@ -1064,41 +1051,41 @@ def test_should_respond_200_with_granular_dag_access(self): def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids): self._create_dag_models(10) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids - assert 10 == response.json["total_entries"] + assert 10 == response.json()["total_entries"] def test_should_respond_200_default_limit(self): self._create_dag_models(101) - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert 100 == len(response.json["dags"]) - assert 101 == response.json["total_entries"] + assert 100 == len(response.json()["dags"]) + assert 101 == response.json()["total_entries"] def test_should_raises_401_unauthenticated(self): response = self.client.get("api/v1/dags") - assert_401(response) + assert response.status_code == 401 def test_should_respond_403_unauthorized(self): self._create_dag_models(1) - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test_no_permissions"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 def test_paused_true_returns_paused_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", is_paused=True) self._create_dag_models(1, dag_id_prefix="TEST_DAG_UNPAUSED", is_paused=False) - response = self.client.get("api/v1/dags?paused=True", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?paused=True", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -1138,12 +1125,12 @@ def test_paused_true_returns_paused_dags(self, url_safe_serializer): } ], "total_entries": 1, - } == response.json + } == response.json() def test_paused_false_returns_unpaused_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", is_paused=True) self._create_dag_models(1, dag_id_prefix="TEST_DAG_UNPAUSED", is_paused=False) - response = self.client.get("api/v1/dags?paused=False", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?paused=False", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -1183,12 +1170,12 @@ def test_paused_false_returns_unpaused_dags(self, url_safe_serializer): } ], "total_entries": 1, - } == response.json + } == response.json() def test_paused_none_returns_all_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", is_paused=True) self._create_dag_models(1, dag_id_prefix="TEST_DAG_UNPAUSED", is_paused=False) - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -1261,19 +1248,17 @@ def test_paused_none_returns_all_dags(self, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_should_return_specified_fields(self): self._create_dag_models(2) self._create_deactivated_dag() fields = ["dag_id", "file_token", "owners"] - response = self.client.get( - f"api/v1/dags?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"api/v1/dags?fields={','.join(fields)}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - res_json = response.json + res_json = response.json() for dag in res_json["dags"]: assert len(dag.keys()) == len(fields) for field in fields: @@ -1283,9 +1268,7 @@ def test_should_respond_400_with_not_exists_fields(self): self._create_dag_models(1) self._create_deactivated_dag() fields = ["#caw&c"] - response = self.client.get( - f"api/v1/dags?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"api/v1/dags?fields={','.join(fields)}", headers={"REMOTE_USER": "test"}) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -1314,7 +1297,7 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 expected_response = { @@ -1350,7 +1333,7 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio "has_import_errors": False, "pickle_id": None, } - assert response.json == expected_response + assert response.json() == expected_response _check_last_log( session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None, expected_extra=payload ) @@ -1362,28 +1345,26 @@ def test_should_respond_200_on_patch_with_granular_dag_access(self, session): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 200 _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) - def test_should_respond_400_on_invalid_request(self): + def test_ignore_read_only_fields(self): patch_body = { - "is_paused": True, + "is_paused": False, "schedule_interval": { "__type": "CronExpression", "value": "1 1 * * *", }, } dag_model = self._create_dag_model() - response = self.client.patch(f"/api/v1/dags/{dag_model.dag_id}", json=patch_body) - assert response.status_code == 400 - assert response.json == { - "detail": "Property is read-only - 'schedule_interval'", - "status": 400, - "title": "Bad Request", - "type": EXCEPTIONS_LINK_MAP[400], - } + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", json=patch_body, headers={"REMOTE_USER": "test"} + ) + assert response.status_code == 200 + assert response.json()["is_paused"] is False + assert response.json()["schedule_interval"] == {"__type": "CronExpression", "value": "2 2 * * *"} def test_validation_error_raises_400(self): patch_body = { @@ -1393,10 +1374,10 @@ def test_validation_error_raises_400(self): response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}", json=patch_body, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'ispaused': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1408,10 +1389,10 @@ def test_non_existing_dag_raises_not_found(self): "is_paused": True, } response = self.client.patch( - "/api/v1/dags/non_existing_dag", json=patch_body, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/dags/non_existing_dag", json=patch_body, headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "Dag with id: 'non_existing_dag' not found", @@ -1419,7 +1400,7 @@ def test_non_existing_dag_raises_not_found(self): } def test_should_respond_404(self): - response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/INVALID_DAG", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 @provide_session @@ -1439,7 +1420,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_respond_200_with_update_mask(self, url_safe_serializer): file_token = url_safe_serializer.dumps("/tmp/dag_1.py") @@ -1450,7 +1431,7 @@ def test_should_respond_200_with_update_mask(self, url_safe_serializer): response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}?update_mask=is_paused", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1487,7 +1468,7 @@ def test_should_respond_200_with_update_mask(self, url_safe_serializer): "has_import_errors": False, "pickle_id": None, } - assert response.json == expected_response + assert response.json() == expected_response @pytest.mark.parametrize( "payload, update_mask, error_message", @@ -1514,10 +1495,10 @@ def test_should_respond_400_for_invalid_fields_in_update_mask(self, payload, upd response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}?{update_mask}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == error_message + assert response.json()["detail"] == error_message def test_should_respond_403_unauthorized(self): dag_model = self._create_dag_model() @@ -1526,7 +1507,7 @@ def test_should_respond_403_unauthorized(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1566,7 +1547,7 @@ def test_should_respond_200_on_patch_is_paused(self, session, url_safe_serialize json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1640,7 +1621,7 @@ def test_should_respond_200_on_patch_is_paused(self, session, url_safe_serialize }, ], "total_entries": 2, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, url_safe_serializer): @@ -1657,7 +1638,7 @@ def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1731,7 +1712,7 @@ def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, }, ], "total_entries": 2, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) def test_wrong_value_as_update_mask_rasise(self, session): @@ -1746,11 +1727,11 @@ def test_wrong_value_as_update_mask_rasise(self, session): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "Only `is_paused` field can be updated through the REST API", "status": 400, "title": "Bad Request", @@ -1769,11 +1750,11 @@ def test_invalid_request_body_raises_badrequest(self, session): json={ "ispaused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'ispaused': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1789,7 +1770,7 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer, session json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert { @@ -1829,7 +1810,7 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer, session } ], "total_entries": 1, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): @@ -1841,7 +1822,7 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) file_token_2 = url_safe_serializer.dumps("/tmp/dag_del_1.py") @@ -1916,7 +1897,7 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): }, ], "total_entries": 2, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) @pytest.mark.parametrize( @@ -1943,10 +1924,10 @@ def test_filter_dags_by_tags_works(self, url, expected_dag_ids): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids @@ -1977,10 +1958,10 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - dag_ids = {dag["dag_id"] for dag in response.json["dags"]} + dag_ids = {dag["dag_id"] for dag in response.json()["dags"]} assert expected_dag_ids == dag_ids @@ -1991,11 +1972,11 @@ def test_should_respond_200_with_granular_dag_access(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + assert len(response.json()["dags"]) == 1 + assert response.json()["dags"][0]["dag_id"] == "TEST_DAG_1" @pytest.mark.parametrize( "url, expected_dag_ids", @@ -2034,15 +2015,15 @@ def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids - assert 10 == response.json["total_entries"] + assert 10 == response.json()["total_entries"] def test_should_respond_200_default_limit(self): self._create_dag_models(101) @@ -2052,13 +2033,13 @@ def test_should_respond_200_default_limit(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert 100 == len(response.json["dags"]) - assert 101 == response.json["total_entries"] + assert 100 == len(response.json()["dags"]) + assert 101 == response.json()["total_entries"] def test_should_raises_401_unauthenticated(self): response = self.client.patch( @@ -2068,7 +2049,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_respond_403_unauthorized(self): self._create_dag_models(1) @@ -2077,7 +2058,7 @@ def test_should_respond_403_unauthorized(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -2092,7 +2073,7 @@ def test_should_respond_200_and_pause_dags(self, url_safe_serializer): json={ "is_paused": True, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -2166,7 +2147,7 @@ def test_should_respond_200_and_pause_dags(self, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() @provide_session def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serializer): @@ -2179,7 +2160,7 @@ def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serial json={ "is_paused": True, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -2253,7 +2234,7 @@ def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serial }, ], "total_entries": 2, - } == response.json + } == response.json() dags_not_updated = session.query(DagModel).filter(~DagModel.is_paused) assert len(dags_not_updated.all()) == 8 @@ -2268,7 +2249,7 @@ def test_should_respond_200_and_reverse_ordering(self, session, url_safe_seriali response = self.client.get( "/api/v1/dags?order_by=-dag_id", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -2342,7 +2323,7 @@ def test_should_respond_200_and_reverse_ordering(self, session, url_safe_seriali }, ], "total_entries": 2, - } == response.json + } == response.json() def test_should_respons_400_dag_id_pattern_missing(self): self._create_dag_models(1) @@ -2351,7 +2332,7 @@ def test_should_respons_400_dag_id_pattern_missing(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 @@ -2385,7 +2366,7 @@ def test_that_dag_can_be_deleted(self, session): response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 204 _check_last_log(session, dag_id="TEST_DAG_1", event="api.delete_dag", execution_date=None) @@ -2393,10 +2374,10 @@ def test_that_dag_can_be_deleted(self, session): def test_raise_when_dag_is_not_found(self): response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "Dag with id: 'TEST_DAG_1' not found", @@ -2412,10 +2393,10 @@ def test_raises_when_task_instances_of_dag_is_still_running(self, dag_maker, ses session.flush() response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409 - assert response.json == { + assert response.json() == { "detail": "Task instances of dag with id: 'TEST_DAG_1' are still running", "status": 409, "title": "Conflict", @@ -2426,7 +2407,7 @@ def test_users_without_delete_permission_cannot_delete_dag(self): self._create_dag_models(1) response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py index 1155e1d8841ff..2e4e0d9ec8275 100644 --- a/tests/api_connexion/endpoints/test_dag_parsing.py +++ b/tests/api_connexion/endpoints/test_dag_parsing.py @@ -43,23 +43,23 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], # type: ignore ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore TEST_DAG_ID, access_control={"Test": [permissions.ACTION_CAN_EDIT]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestDagParsingRequest: @@ -82,26 +82,20 @@ def test_201_and_400_requests(self, url_safe_serializer, session): test_dag: DAG = dagbag.dags[TEST_DAG_ID] url = f"/api/v1/parseDagFile/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.put( - url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.put(url, headers={"Accept": "application/json", "REMOTE_USER": "test"}) assert 201 == response.status_code parsing_requests = session.scalars(select(DagPriorityParsingRequest)).all() assert parsing_requests[0].fileloc == test_dag.fileloc # Duplicate file parsing request - response = self.client.put( - url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.put(url, headers={"Accept": "application/json", "REMOTE_USER": "test"}) assert 201 == response.status_code parsing_requests = session.scalars(select(DagPriorityParsingRequest)).all() assert parsing_requests[0].fileloc == test_dag.fileloc def test_bad_file_request(self, url_safe_serializer, session): url = f"/api/v1/parseDagFile/{url_safe_serializer.dumps('/some/random/file.py')}" - response = self.client.put( - url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.put(url, headers={"Accept": "application/json", "REMOTE_USER": "test"}) assert response.status_code == 404 parsing_requests = session.scalars(select(DagPriorityParsingRequest)).all() @@ -111,8 +105,7 @@ def test_bad_user_request(self, url_safe_serializer, session): url = f"/api/v1/parseDagFile/{url_safe_serializer.dumps('/some/random/file.py')}" response = self.client.put( url, - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"Accept": "application/json", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 2d7c6ac0544d7..1ba603e9b10c8 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -35,7 +35,7 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags from tests.test_utils.www import _check_last_log @@ -45,10 +45,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -63,7 +63,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_dag_view_only", role_name="TestViewDags", permissions=[ @@ -75,7 +75,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_view_dags", role_name="TestViewDags", permissions=[ @@ -84,25 +84,25 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_granular_permissions", role_name="TestGranularDag", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_ID", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_view_only") # type: ignore - delete_user(app, username="test_view_dags") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_roles(app) + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_dag_view_only") # type: ignore + delete_user(connexion_app.app, username="test_view_dags") # type: ignore + delete_user(connexion_app.app, username="test_granular_permissions") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_roles(connexion_app.app) class TestDagRunEndpoint: @@ -112,8 +112,9 @@ class TestDagRunEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = configured_app.app + self.client = self.connexion_app.test_client() # type:ignore clear_db_runs() clear_db_serialized_dags() clear_db_dags() @@ -123,13 +124,14 @@ def teardown_method(self) -> None: clear_db_dags() clear_db_serialized_dags() - def _create_dag(self, dag_id): + def _create_dag(self, dag_id, is_active=True, has_import_errors=False): dag_instance = DagModel(dag_id=dag_id) - dag_instance.is_active = True + dag_instance.is_active = is_active + dag_instance.has_import_errors = has_import_errors with create_session() as session: session.add(dag_instance) dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) return dag_instance def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): @@ -177,21 +179,21 @@ def test_should_respond_204(self, session): session.add_all(self._create_test_dag_run()) session.commit() response = self.client.delete( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", headers={"REMOTE_USER": "test"} ) assert response.status_code == 204 # Check if the Dag Run is deleted from the database response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 def test_should_respond_404(self): response = self.client.delete( - "api/v1/dags/INVALID_DAG_RUN/dagRuns/INVALID_DAG_RUN", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/INVALID_DAG_RUN/dagRuns/INVALID_DAG_RUN", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "DAGRun with DAG ID: 'INVALID_DAG_RUN' and DagRun ID: 'INVALID_DAG_RUN' not found", "status": 404, "title": "Not Found", @@ -206,12 +208,12 @@ def test_should_raises_401_unauthenticated(self, session): "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -244,10 +246,10 @@ def test_should_respond_200(self, session): result = session.query(DagRun).all() assert len(result) == 1 response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "TEST_DAG_ID", "dag_run_id": "TEST_DAG_RUN_ID", "end_date": None, @@ -266,7 +268,7 @@ def test_should_respond_200(self, session): def test_should_respond_404(self): response = self.client.get( - "api/v1/dags/invalid-id/dagRuns/invalid-id", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/invalid-id/dagRuns/invalid-id", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 expected_resp = { @@ -275,7 +277,7 @@ def test_should_respond_404(self): "title": "DAGRun not found", "type": EXCEPTIONS_LINK_MAP[404], } - assert expected_resp == response.json + assert expected_resp == response.json() def test_should_raises_401_unauthenticated(self, session): dagrun_model = DagRun( @@ -291,7 +293,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "fields", @@ -316,11 +318,10 @@ def test_should_return_specified_fields(self, session, fields): assert len(result) == 1 response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - res_json = response.json - print("get dagRun", res_json) + res_json = response.json() assert len(res_json.keys()) == len(fields) for field in fields: assert field in res_json @@ -342,7 +343,7 @@ def test_should_respond_400_with_not_exists_fields(self, session): fields = ["#caw&c"] response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -375,11 +376,9 @@ def test_should_respond_200(self, session): self._create_test_dag_run() result = session.query(DagRun).all() assert len(result) == 2 - response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -422,22 +421,22 @@ def test_filter_by_state(self, session): self._create_test_dag_run(state="queued", idx_start=3) assert session.query(DagRun).count() == 4 response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns?state=running,queued", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns?state=running,queued", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == 4 - assert response.json["dag_runs"][0]["state"] == response.json["dag_runs"][1]["state"] == "running" - assert response.json["dag_runs"][2]["state"] == response.json["dag_runs"][3]["state"] == "queued" + assert response.json()["total_entries"] == 4 + assert response.json()["dag_runs"][0]["state"] == response.json()["dag_runs"][1]["state"] == "running" + assert response.json()["dag_runs"][2]["state"] == response.json()["dag_runs"][3]["state"] == "queued" def test_invalid_order_by_raises_400(self): self._create_test_dag_run() response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns?order_by=invalid", headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_return_correct_results_with_order_by(self, session): self._create_test_dag_run() @@ -445,13 +444,13 @@ def test_return_correct_results_with_order_by(self, session): assert len(result) == 2 response = self.client.get( "api/v1/dags/TEST_DAG_ID/dagRuns?order_by=-execution_date", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert self.default_time < self.default_time_2 # - means descending - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -492,19 +491,19 @@ def test_return_correct_results_with_order_by(self, session): def test_should_return_all_with_tilde_as_dag_id_and_all_dag_permissions(self): self._create_test_dag_run(extra_dag=True) expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID", "TEST_DAG_ID_3", "TEST_DAG_ID_4"] - response = self.client.get("api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags/~/dagRuns", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] + dag_run_ids = [dag_run["dag_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): self._create_test_dag_run(extra_dag=True) expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] response = self.client.get( - "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + "api/v1/dags/~/dagRuns", headers={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] + dag_run_ids = [dag_run["dag_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_raises_401_unauthenticated(self): @@ -512,7 +511,7 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "fields", @@ -527,10 +526,10 @@ def test_should_return_specified_fields(self, session, fields): assert len(result) == 2 response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - for dag_run in response.json["dag_runs"]: + for dag_run in response.json()["dag_runs"]: assert len(dag_run.keys()) == len(fields) for field in fields: assert field in dag_run @@ -540,7 +539,7 @@ def test_should_respond_400_with_not_exists_fields(self): fields = ["#caw&c"] response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -601,31 +600,29 @@ class TestGetDagRunsPagination(TestDagRunEndpoint): ) def test_handle_limit_and_offset(self, url, expected_dag_run_ids): self._create_dag_runs(10) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == 10 + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_respect_page_size_limit(self): self._create_dag_runs(200) - response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["dag_runs"]) == 100 # default is 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["dag_runs"]) == 100 # default is 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): self._create_dag_runs(200) response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns?limit=180", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns?limit=180", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert len(response.json["dag_runs"]) == 150 + assert len(response.json()["dag_runs"]) == 150 def _create_dag_runs(self, count): dag_runs = [ @@ -713,10 +710,10 @@ def test_date_filters_gte_and_lte(self, url, expected_dag_run_ids, session): d.updated_at = d.execution_date session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def _create_dag_runs(self): @@ -767,10 +764,10 @@ class TestGetDagRunsEndDateFilters(TestDagRunEndpoint): ) def test_end_date_gte_lte(self, url, expected_dag_run_ids): self._create_test_dag_run("success") # state==success, then end date is today - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"] if dag_run] assert dag_run_ids == expected_dag_run_ids @@ -780,10 +777,10 @@ def test_should_respond_200(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"]}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -826,10 +823,10 @@ def test_raises_validation_error_for_invalid_request(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dagids": ["TEST_DAG_ID"]}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'dagids': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -842,22 +839,22 @@ def test_filter_by_state(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"], "states": ["running", "queued"]}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 4 - assert response.json["dag_runs"][0]["state"] == response.json["dag_runs"][1]["state"] == "running" - assert response.json["dag_runs"][2]["state"] == response.json["dag_runs"][3]["state"] == "queued" + assert response.json()["total_entries"] == 4 + assert response.json()["dag_runs"][0]["state"] == response.json()["dag_runs"][1]["state"] == "running" + assert response.json()["dag_runs"][2]["state"] == response.json()["dag_runs"][3]["state"] == "queued" def test_order_by_descending_works(self): self._create_test_dag_run() response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"], "order_by": "-dag_run_id"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -900,21 +897,21 @@ def test_order_by_raises_for_invalid_attr(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"], "order_by": "-dag_ru"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 msg = "Ordering with 'dag_ru' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): self._create_test_dag_run(extra_dag=True) response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": []}, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -967,17 +964,17 @@ def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions def test_payload_validation(self, payload, error): self._create_test_dag_run() response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json.get("detail") == error + assert response.json()["detail"] == error def test_should_raises_401_unauthenticated(self): self._create_test_dag_run() response = self.client.post("api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"]}) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -1034,23 +1031,21 @@ class TestGetDagRunBatchPagination(TestDagRunEndpoint): def test_handle_limit_and_offset(self, payload, expected_dag_run_ids): self._create_dag_runs(10) response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == 10 + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_respect_page_size_limit(self): self._create_dag_runs(200) - response = self.client.post( - "api/v1/dags/~/dagRuns/list", json={}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("api/v1/dags/~/dagRuns/list", json={}, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["dag_runs"]) == 100 # default is 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["dag_runs"]) == 100 # default is 100 def _create_dag_runs(self, count): dag_runs = [ @@ -1115,11 +1110,11 @@ class TestGetDagRunBatchDateFilters(TestDagRunEndpoint): def test_date_filters_gte_and_lte(self, payload, expected_dag_run_ids): self._create_dag_runs() response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def _create_dag_runs(self): @@ -1187,10 +1182,10 @@ def test_naive_date_filters_raises_400(self, payload, expected_response): self._create_dag_runs() response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json["detail"] == expected_response + assert response.json()["detail"] == expected_response @pytest.mark.parametrize( "payload, expected_dag_run_ids", @@ -1208,11 +1203,11 @@ def test_naive_date_filters_raises_400(self, payload, expected_response): def test_end_date_gte_lte(self, payload, expected_dag_run_ids): self._create_test_dag_run("success") # state==success, then end date is today response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"] if dag_run] assert dag_run_ids == expected_dag_run_ids @@ -1268,7 +1263,7 @@ def test_should_respond_200( response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1288,7 +1283,7 @@ def test_should_respond_200( expected_data_interval_start = data_interval_start expected_data_interval_end = data_interval_end - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": "TEST_DAG_ID", "dag_run_id": expected_dag_run_id, @@ -1311,10 +1306,10 @@ def test_raises_validation_error_for_invalid_request(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"executiondate": "2020-11-10T08:25:56Z"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'executiondate': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1326,10 +1321,10 @@ def test_raises_validation_error_for_invalid_params(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"conf": {"validated_number": 5000}}, # DAG param must be between 1 and 10 - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert "Invalid input for param" in response.json["detail"] + assert "Invalid input for param" in response.json()["detail"] @mock.patch("airflow.api_connexion.endpoints.dag_run_endpoint.get_airflow_app") def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): @@ -1341,10 +1336,10 @@ def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"execution_date": "2020-11-10T08:25:56Z"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -1352,34 +1347,28 @@ def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): } def test_should_respond_404_if_a_dag_is_inactive(self, session): - dm = self._create_dag("TEST_INACTIVE_DAG_ID") - dm.is_active = False - session.add(dm) - session.flush() + self._create_dag("TEST_INACTIVE_DAG_ID", is_active=False) response = self.client.post( "api/v1/dags/TEST_INACTIVE_DAG_ID/dagRuns", json={}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 404 + assert response.json()["status"] == 404 def test_should_respond_400_if_a_dag_has_import_errors(self, session): """Test that if a dagmodel has import errors, dags won't be triggered""" - dm = self._create_dag("TEST_DAG_ID") - dm.has_import_errors = True - session.add(dm) - session.flush() + self._create_dag("TEST_DAG_ID", has_import_errors=True) response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert { + assert response.json() == { "detail": "DAG with dag_id: 'TEST_DAG_ID' has import errors", "status": 400, "title": "DAG cannot be triggered", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } def test_should_response_200_for_matching_execution_date_logical_date(self): execution_date = "2020-11-10T08:25:56.939143+00:00" @@ -1391,12 +1380,12 @@ def test_should_response_200_for_matching_execution_date_logical_date(self): "execution_date": execution_date, "logical_date": logical_date, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) dag_run_id = f"manual__{logical_date}" assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": "TEST_DAG_ID", "dag_run_id": dag_run_id, @@ -1420,11 +1409,11 @@ def test_should_response_400_for_conflicting_execution_date_logical_date(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"execution_date": execution_date, "logical_date": logical_date}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["title"] == "logical_date conflicts with execution_date" - assert response.json["detail"] == (f"'{logical_date}' != '{execution_date}'") + assert response.json()["title"] == "logical_date conflicts with execution_date" + assert response.json()["detail"] == (f"'{logical_date}' != '{execution_date}'") @pytest.mark.parametrize( "data_interval_start, data_interval_end, expected", @@ -1462,10 +1451,10 @@ def test_should_response_400_for_missing_start_date_or_end_date( "data_interval_start": data_interval_start, "data_interval_end": data_interval_end, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected @pytest.mark.parametrize( "data, expected", @@ -1491,10 +1480,10 @@ def test_should_response_400_for_missing_start_date_or_end_date( def test_should_response_400_for_naive_datetime_and_bad_datetime(self, data, expected): self._create_dag("TEST_DAG_ID") response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected @pytest.mark.parametrize( "data, expected", @@ -1512,16 +1501,16 @@ def test_should_response_400_for_naive_datetime_and_bad_datetime(self, data, exp def test_should_response_400_for_non_dict_dagrun_conf(self, data, expected): self._create_dag("TEST_DAG_ID") response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected def test_response_404(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"dag_run_id": "TEST_DAG_RUN", "execution_date": self.default_time}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 assert { @@ -1529,7 +1518,7 @@ def test_response_404(self): "status": 404, "title": "DAG not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() @pytest.mark.parametrize( "url, request_json, expected_response", @@ -1541,7 +1530,7 @@ def test_response_404(self): "execution_date": "2020-06-12T18:00:00+00:00", }, { - "detail": "Property is read-only - 'start_date'", + "detail": "{'start_date': ['Unknown field.']}", "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], @@ -1552,7 +1541,7 @@ def test_response_404(self): "api/v1/dags/TEST_DAG_ID/dagRuns", {"state": "failed", "execution_date": "2020-06-12T18:00:00+00:00"}, { - "detail": "Property is read-only - 'state'", + "detail": "{'state': ['Unknown field.']}", "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], @@ -1563,9 +1552,9 @@ def test_response_404(self): ) def test_response_400(self, url, request_json, expected_response): self._create_dag("TEST_DAG_ID") - response = self.client.post(url, json=request_json, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.post(url, json=request_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 400, response.data - assert expected_response == response.json + assert expected_response == response.json() def test_response_409(self): self._create_test_dag_run() @@ -1575,10 +1564,10 @@ def test_response_409(self): "dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time_3, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409, response.data - assert response.json == { + assert response.json() == { "detail": "DAGRun with DAG ID: 'TEST_DAG_ID' and " "DAGRun ID: 'TEST_DAG_RUN_ID_1' already exists", "status": 409, @@ -1595,11 +1584,11 @@ def test_response_409_when_execution_date_is_same(self): "dag_run_id": "TEST_DAG_RUN_ID_6", "execution_date": self.default_time, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409, response.data - assert response.json == { + assert response.json() == { "detail": "DAGRun with DAG ID: 'TEST_DAG_ID' and " "DAGRun logical date: '2020-06-11 18:00:00+00:00' already exists", "status": 409, @@ -1616,7 +1605,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "username", @@ -1630,7 +1619,7 @@ def test_should_raises_403_unauthorized(self, username): "dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time, }, - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, ) assert response.status_code == 403 @@ -1663,7 +1652,7 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=run_type) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1676,7 +1665,7 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): response = self.client.patch( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) if state != "queued": @@ -1685,7 +1674,7 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first() assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dag_id, "dag_run_id": dag_run_id, @@ -1706,17 +1695,21 @@ def test_schema_validation_error_raises(self, dag_maker, session): dag_id = "TEST_DAG_ID" dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: - EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) - dag_maker.create_dagrun(run_id=dag_run_id) + task = EmptyOperator(task_id="task_id", dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) + dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) + ti = dr.get_task_instance(task_id="task_id") + ti.task = task + ti.state = State.SUCCESS + session.merge(ti) + session.commit() response = self.client.patch( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", json={"states": "success"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'states': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1737,10 +1730,10 @@ def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, response = self.client.patch( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": f"'{invalid_state}' is not one of ['success', 'failed', 'queued'] - 'state'", "status": 400, "title": "Bad Request", @@ -1755,7 +1748,7 @@ def test_should_raises_401_unauthenticated(self, session): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.patch( @@ -1763,7 +1756,7 @@ def test_should_raise_403_forbidden(self): json={ "state": "success", }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1773,7 +1766,7 @@ def test_should_respond_404(self): json={ "state": "success", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -1787,7 +1780,7 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=DagRunType.SCHEDULED) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1809,7 +1802,7 @@ def test_should_respond_200(self, dag_maker, session): dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1822,12 +1815,12 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.post( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first() assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dag_id, "dag_run_id": dag_run_id, @@ -1851,16 +1844,20 @@ def test_schema_validation_error_raises_for_invalid_fields(self, dag_maker, sess dag_id = "TEST_DAG_ID" dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: - EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) - dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) + task = EmptyOperator(task_id="task_id", dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) + dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) + ti = dr.get_task_instance(task_id="task_id") + ti.task = task + ti.state = State.SUCCESS + session.merge(ti) + session.commit() response = self.client.post( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear", json={"dryrun": False}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'dryrun': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1873,7 +1870,7 @@ def test_dry_run(self, dag_maker, session): dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1886,11 +1883,11 @@ def test_dry_run(self, dag_maker, session): response = self.client.post( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "task_instances": [ { "dag_id": dag_id, @@ -1915,7 +1912,7 @@ def test_should_raises_401_unauthenticated(self, session): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.post( @@ -1923,7 +1920,7 @@ def test_should_raise_403_forbidden(self): json={ "dry_run": True, }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1933,7 +1930,7 @@ def test_should_respond_404(self): json={ "dry_run": True, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -1947,7 +1944,7 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=DagRunType.SCHEDULED) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1993,7 +1990,7 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.get( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamDatasetEvents", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 expected_response = { @@ -2024,12 +2021,12 @@ def test_should_respond_200(self, dag_maker, session): ], "total_entries": 1, } - assert response.json == expected_response + assert response.json() == expected_response def test_should_respond_404(self): response = self.client.get( "api/v1/dags/invalid-id/dagRuns/invalid-id/upstreamDatasetEvents", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 expected_resp = { @@ -2038,7 +2035,7 @@ def test_should_respond_404(self): "title": "DAGRun not found", "type": EXCEPTIONS_LINK_MAP[404], } - assert expected_resp == response.json + assert expected_resp == response.json() def test_should_raises_401_unauthenticated(self, session): dagrun_model = DagRun( @@ -2054,7 +2051,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamDatasetEvents") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -2103,13 +2100,13 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.patch( f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) dr = session.query(DagRun).filter(DagRun.run_id == created_dr.run_id).first() assert response.status_code == 200, response.text assert dr.note == new_note_value - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dr.dag_id, "dag_run_id": dr.run_id, @@ -2132,10 +2129,10 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.patch( f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dr.dag_id, "dag_run_id": dr.run_id, @@ -2170,10 +2167,10 @@ def test_schema_validation_error_raises(self, dag_maker, session): response = self.client.patch( f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", json={"notes": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'notes': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -2185,13 +2182,13 @@ def test_should_raises_401_unauthenticated(self, session): "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/setNote", json={"note": "I am setting a note while being unauthenticated."}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.patch( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/setNote", json={"note": "I am setting a note without the proper permissions."}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -2199,7 +2196,7 @@ def test_should_respond_404(self): response = self.client.patch( "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/setNote", json={"note": "I am setting a note on a DAG that doesn't exist."}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -2211,8 +2208,8 @@ def test_should_respond_404(self): def test_should_respond_200_with_anonymous_user(self, dag_maker, session): from airflow.www import app as application - app = application.create_app(testing=True) - app.config["AUTH_ROLE_PUBLIC"] = "Admin" + app = application.create_connexion_app(testing=True) + app.app.config["AUTH_ROLE_PUBLIC"] = "Admin" dag_runs = self._create_test_dag_run(DagRunState.SUCCESS) session.add_all(dag_runs) session.commit() diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index 14c7d1534d4dc..1688600fe2245 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -24,7 +24,7 @@ from airflow.models import DagBag from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags pytestmark = pytest.mark.db_test @@ -42,38 +42,38 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], # type: ignore ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore TEST_DAG_ID, access_control={"Test": [permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore EXAMPLE_DAG_ID, access_control={"Test": [permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore TEST_MULTIPLE_DAGS_ID, access_control={"Test": [permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestGetSource: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore self.clear_db() def teardown_method(self) -> None: @@ -100,12 +100,9 @@ def test_should_respond_200_text(self, url_safe_serializer): dag_docstring = self._get_dag_file_docstring(test_dag.fileloc) url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "text/plain"}, environ_overrides={"REMOTE_USER": "test"} - ) - + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert 200 == response.status_code - assert dag_docstring in response.data.decode() + assert dag_docstring in response.text assert "text/plain" == response.headers["Content-Type"] def test_should_respond_200_json(self, url_safe_serializer): @@ -115,12 +112,10 @@ def test_should_respond_200_json(self, url_safe_serializer): dag_docstring = self._get_dag_file_docstring(test_dag.fileloc) url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(url, headers={"Accept": "application/json", "REMOTE_USER": "test"}) assert 200 == response.status_code - assert dag_docstring in response.json["content"] + assert dag_docstring in response.json()["content"] assert "application/json" == response.headers["Content-Type"] def test_should_respond_406(self, url_safe_serializer): @@ -129,18 +124,14 @@ def test_should_respond_406(self, url_safe_serializer): test_dag: DAG = dagbag.dags[TEST_DAG_ID] url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(url, headers={"Accept": "image/webp", "REMOTE_USER": "test"}) assert 406 == response.status_code def test_should_respond_404(self): wrong_fileloc = "abcd1234" url = f"/api/v1/dagSources/{wrong_fileloc}" - response = self.client.get( - url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(url, headers={"Accept": "application/json", "REMOTE_USER": "test"}) assert 404 == response.status_code @@ -154,7 +145,7 @@ def test_should_raises_401_unauthenticated(self, url_safe_serializer): headers={"Accept": "text/plain"}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self, url_safe_serializer): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) @@ -163,8 +154,7 @@ def test_should_raise_403_forbidden(self, url_safe_serializer): response = self.client.get( f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -175,12 +165,11 @@ def test_should_respond_403_not_readable(self, url_safe_serializer): response = self.client.get( f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) read_dag = self.client.get( f"/api/v1/dags/{NOT_READABLE_DAG_ID}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 403 assert read_dag.status_code == 403 @@ -192,13 +181,12 @@ def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_se response = self.client.get( f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) read_dag = self.client.get( f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 403 assert read_dag.status_code == 200 diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index cc398329b9644..915f8b959000a 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -24,7 +24,7 @@ from airflow.models.dagwarning import DagWarning from airflow.security import permissions from airflow.utils.session import create_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags pytestmark = pytest.mark.db_test @@ -32,9 +32,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[ @@ -42,9 +42,9 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), ], # type: ignore ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test_with_dag2_read", role_name="TestWithDag2Read", permissions=[ @@ -53,11 +53,11 @@ def configured_app(minimal_app_for_api): ], # type: ignore ) - yield minimal_app_for_api + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_with_dag2_read") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_with_dag2_read") # type: ignore class TestBaseDagWarning: @@ -65,8 +65,8 @@ class TestBaseDagWarning: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self) -> None: clear_db_dag_warnings() @@ -95,11 +95,11 @@ def setup_method(self): def test_response_one(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, - query_string={"dag_id": "dag1", "warning_type": "non-existent pool"}, + headers={"REMOTE_USER": "test"}, + params={"dag_id": "dag1", "warning_type": "non-existent pool"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dag_warnings": [ { @@ -115,11 +115,11 @@ def test_response_one(self): def test_response_some(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, - query_string={"warning_type": "non-existent pool"}, + headers={"REMOTE_USER": "test"}, + params={"warning_type": "non-existent pool"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["dag_warnings"]) == 2 assert response_data == { "dag_warnings": ANY, @@ -129,11 +129,11 @@ def test_response_some(self): def test_response_none(self, session): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, - query_string={"dag_id": "missing_dag"}, + headers={"REMOTE_USER": "test"}, + params={"dag_id": "missing_dag"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dag_warnings": [], "total_entries": 0, @@ -142,11 +142,11 @@ def test_response_none(self, session): def test_response_all(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["dag_warnings"]) == 2 assert response_data == { "dag_warnings": ANY, @@ -155,19 +155,17 @@ def test_response_all(self): def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/dagWarnings") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/dagWarnings", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, - query_string={"dag_id": "dag1"}, + headers={"REMOTE_USER": "test_with_dag2_read"}, + params={"dag_id": "dag1"}, ) assert response.status_code == 403 @@ -178,7 +176,6 @@ def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): response = self.client.get( - "/api/v1/dagWarnings", - query_string={"dag_id": "dag1", "warning_type": "non-existent pool"}, + "/api/v1/dagWarnings?dag_id=dag1&warning_type=non-existent+pool", ) assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py index 5b6e2f24146e4..f49fa5e26ea2f 100644 --- a/tests/api_connexion/endpoints/test_dataset_endpoint.py +++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py @@ -37,7 +37,7 @@ from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_datasets, clear_db_runs @@ -48,9 +48,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -58,9 +58,9 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DATASET), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_queued_event", role_name="TestQueuedEvent", permissions=[ @@ -70,11 +70,11 @@ def configured_app(minimal_app_for_api): ], ) - yield app + yield connexion_app - delete_user(app, username="test_queued_event") # type: ignore - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_queued_event") # type: ignore class TestDatasetEndpoint: @@ -82,8 +82,8 @@ class TestDatasetEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() clear_db_datasets() clear_db_runs() @@ -112,10 +112,10 @@ def test_should_respond_200(self, session): with assert_queries_count(5): response = self.client.get( f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "id": 1, "uri": "s3://bucket/key", "extra": {"foo": "bar"}, @@ -128,7 +128,7 @@ def test_should_respond_200(self, session): def test_should_respond_404(self): response = self.client.get( f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 assert { @@ -136,12 +136,12 @@ def test_should_respond_404(self): "status": 404, "title": "Dataset not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): self._create_dataset(session) response = self.client.get(f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -177,10 +177,10 @@ def test_should_respond_200(self, session): assert session.query(DatasetModel).count() == 2 with assert_queries_count(8): - response = self.client.get("/api/v1/datasets", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -220,12 +220,12 @@ def test_order_by_raises_400_for_invalid_attr(self, session): assert session.query(DatasetModel).count() == 2 response = self.client.get( - "/api/v1/datasets?order_by=fake", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets?order_by=fake", headers={"REMOTE_USER": "test"} ) # missing attr assert response.status_code == 400 msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raises_401_unauthenticated(self, session): datasets = [ @@ -243,7 +243,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("/api/v1/datasets") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "url, expected_datasets", @@ -273,9 +273,9 @@ def test_filter_datasets_by_uri_pattern_works(self, url, expected_datasets, sess dataset4 = DatasetModel("wasb://some_dataset_bucket_/key") session.add_all([dataset1, dataset2, dataset3, dataset4]) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dataset_urls = {dataset["uri"] for dataset in response.json["datasets"]} + dataset_urls = {dataset["uri"] for dataset in response.json()["datasets"]} assert expected_datasets == dataset_urls @pytest.mark.parametrize("dag_ids, expected_num", [("dag1,dag2", 2), ("dag3", 1), ("dag2,dag3", 2)]) @@ -294,11 +294,9 @@ def test_filter_datasets_by_dag_ids_works(self, dag_ids, expected_num, session): task_ref1 = TaskOutletDatasetReference(dag_id="dag3", task_id="task1", dataset=dataset3) session.add_all([dataset1, dataset2, dataset3, dag1, dag2, dag3, dag_ref1, dag_ref2, task_ref1]) session.commit() - response = self.client.get( - f"/api/v1/datasets?dag_ids={dag_ids}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/datasets?dag_ids={dag_ids}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["datasets"]) == expected_num @pytest.mark.parametrize( @@ -323,10 +321,10 @@ def test_filter_datasets_by_dag_ids_and_uri_pattern_works( session.commit() response = self.client.get( f"/api/v1/datasets?dag_ids={dag_ids}&uri_pattern={uri_pattern}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["datasets"]) == expected_num @pytest.mark.parametrize( @@ -383,10 +381,10 @@ def test_limit_and_offset(self, url, expected_dataset_uris, session): session.add_all(datasets) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dataset_uris = [dataset["uri"] for dataset in response.json["datasets"]] + dataset_uris = [dataset["uri"] for dataset in response.json()["datasets"]] assert dataset_uris == expected_dataset_uris def test_should_respect_page_size_limit_default(self, session): @@ -402,10 +400,10 @@ def test_should_respect_page_size_limit_default(self, session): session.add_all(datasets) session.commit() - response = self.client.get("/api/v1/datasets", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["datasets"]) == 100 + assert len(response.json()["datasets"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -421,10 +419,10 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): session.add_all(datasets) session.commit() - response = self.client.get("/api/v1/datasets?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["datasets"]) == 150 + assert len(response.json()["datasets"]) == 150 class TestGetDatasetEvents(TestDatasetEndpoint): @@ -445,10 +443,10 @@ def test_should_respond_200(self, session): session.commit() assert session.query(DatasetEvent).count() == 2 - response = self.client.get("/api/v1/datasets/events", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets/events", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dataset_events": [ { @@ -507,12 +505,10 @@ def test_filtering(self, attr, value, session): session.commit() assert session.query(DatasetEvent).count() == 3 - response = self.client.get( - f"/api/v1/datasets/events?{attr}={value}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/datasets/events?{attr}={value}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dataset_events": [ { @@ -550,16 +546,16 @@ def test_order_by_raises_400_for_invalid_attr(self, session): assert session.query(DatasetEvent).count() == 2 response = self.client.get( - "/api/v1/datasets/events?order_by=fake", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events?order_by=fake", headers={"REMOTE_USER": "test"} ) # missing attr assert response.status_code == 400 msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raises_401_unauthenticated(self, session): response = self.client.get("/api/v1/datasets/events") - assert_401(response) + assert response.status_code == 401 def test_includes_created_dagrun(self, session): self._create_dataset(session) @@ -587,10 +583,10 @@ def test_includes_created_dagrun(self, session): event.created_dagruns.append(dagrun) session.commit() - response = self.client.get("/api/v1/datasets/events", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets/events", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dataset_events": [ { @@ -662,11 +658,11 @@ def test_should_respond_200(self, session): self._create_dataset(session) event_payload = {"dataset_uri": "s3://bucket/key", "extra": {"foo": "bar"}} response = self.client.post( - "/api/v1/datasets/events", json=event_payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events", json=event_payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "id": ANY, "created_dagruns": [], @@ -692,7 +688,7 @@ def test_should_mask_sensitive_extra_logs(self, session): self._create_dataset(session) event_payload = {"dataset_uri": "s3://bucket/key", "extra": {"password": "bar"}} response = self.client.post( - "/api/v1/datasets/events", json=event_payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events", json=event_payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 @@ -709,14 +705,14 @@ def test_order_by_raises_400_for_invalid_attr(self, session): self._create_dataset(session) event_invalid_payload = {"dataset_uri": "TEST_DATASET_URI", "extra": {"foo": "bar"}, "fake": {}} response = self.client.post( - "/api/v1/datasets/events", json=event_invalid_payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events", json=event_invalid_payload, headers={"REMOTE_USER": "test"} ) - assert response.status_code == 400 + assert response.json()["status"] == 400 def test_should_raises_401_unauthenticated(self, session): self._create_dataset(session) response = self.client.post("/api/v1/datasets/events", json={"dataset_uri": "TEST_DATASET_URI"}) - assert_401(response) + assert response.json()["status"] == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -775,10 +771,10 @@ def test_limit_and_offset(self, url, expected_event_runids, session): session.add_all(events) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - event_runids = [event["source_run_id"] for event in response.json["dataset_events"]] + event_runids = [event["source_run_id"] for event in response.json()["dataset_events"]] assert event_runids == expected_event_runids def test_should_respect_page_size_limit_default(self, session): @@ -797,10 +793,10 @@ def test_should_respect_page_size_limit_default(self, session): session.add_all(events) session.commit() - response = self.client.get("/api/v1/datasets/events", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets/events", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["dataset_events"]) == 100 + assert len(response.json()["dataset_events"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -819,12 +815,10 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): session.add_all(events) session.commit() - response = self.client.get( - "/api/v1/datasets/events?limit=180", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/datasets/events?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["dataset_events"]) == 150 + assert len(response.json()["dataset_events"]) == 150 class TestQueuedEventEndpoint(TestDatasetEndpoint): @@ -855,11 +849,11 @@ def test_should_respond_200(self, session, create_dummy_dag): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "created_at": self.default_time, "uri": "s3://bucket/key", "dag_id": "dag", @@ -871,7 +865,7 @@ def test_should_respond_404(self): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -880,7 +874,7 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" @@ -888,7 +882,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get(f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self, session): dag_id = "dummy" @@ -896,7 +890,7 @@ def test_should_raise_403_forbidden(self, session): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -938,7 +932,7 @@ def test_delete_should_respond_204(self, session, create_dummy_dag): response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 204 @@ -954,7 +948,7 @@ def test_should_respond_404(self): response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -963,20 +957,20 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" response = self.client.delete(f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self, session): dag_id = "dummy" dataset_uri = "dummy" response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -991,11 +985,11 @@ def test_should_respond_200(self, session, create_dummy_dag): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "queued_events": [ { "created_at": self.default_time, @@ -1011,7 +1005,7 @@ def test_should_respond_404(self): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1020,21 +1014,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dag_id = "dummy" response = self.client.get(f"/api/v1/dags/{dag_id}/datasets/queuedEvent") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dag_id = "dummy" response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1064,7 +1058,7 @@ def test_should_respond_404(self): response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1073,21 +1067,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dag_id = "dummy" response = self.client.delete(f"/api/v1/dags/{dag_id}/datasets/queuedEvent") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dag_id = "dummy" response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1129,11 +1123,11 @@ def test_should_respond_200(self, session, create_dummy_dag): response = self.client.get( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "queued_events": [ { "created_at": self.default_time, @@ -1149,7 +1143,7 @@ def test_should_respond_404(self): response = self.client.get( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1158,21 +1152,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" response = self.client.get(f"/api/v1/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dataset_uri = "not_exists" response = self.client.get( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1208,7 +1202,7 @@ def test_delete_should_respond_204(self, session, create_dummy_dag): response = self.client.delete( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 204 @@ -1221,7 +1215,7 @@ def test_should_respond_404(self): response = self.client.delete( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1230,21 +1224,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" response = self.client.delete(f"/api/v1/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dataset_uri = "not_exists" response = self.client.delete( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index 6738858ddd00f..aca91ae59a0e5 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -22,7 +22,7 @@ from airflow.models import Log from airflow.security import permissions from airflow.utils import timezone -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_logs @@ -31,34 +31,34 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore ) create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test_granular", role_name="TestGranular", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_ID_1", access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_ID_2", access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_granular") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_granular") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore @pytest.fixture @@ -91,7 +91,9 @@ def maker(event, when, **kwargs): log_model.dttm = when session.add(log_model) + session.commit() session.flush() + session.close() return log_model return maker @@ -100,8 +102,8 @@ def maker(event, when, **kwargs): class TestEventLogEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_logs() self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00") @@ -116,9 +118,7 @@ def teardown_method(self) -> None: ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, log_model): event_log_id = log_model.id - response = self.client.get( - f"/api/v1/eventLogs/{event_log_id}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/eventLogs/{event_log_id}", headers={"REMOTE_USER": "test"}) response = self.client.get("/api/v1/eventLogs") @@ -128,11 +128,9 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestGetEventLog(TestEventLogEndpoint): def test_should_respond_200(self, log_model): event_log_id = log_model.id - response = self.client.get( - f"/api/v1/eventLogs/{event_log_id}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/eventLogs/{event_log_id}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "event_log_id": event_log_id, "event": "TEST_EVENT", "dag_id": "TEST_DAG_ID", @@ -145,26 +143,24 @@ def test_should_respond_200(self, log_model): } def test_should_respond_404(self): - response = self.client.get("/api/v1/eventLogs/1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/eventLogs/1", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": None, "status": 404, "title": "Event Log not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, log_model): event_log_id = log_model.id response = self.client.get(f"/api/v1/eventLogs/{event_log_id}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/eventLogs", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -188,10 +184,12 @@ def test_should_respond_200(self, session, create_log_model): log_model_3.dttm = self.default_time_2 session.add(log_model_3) + session.commit() session.flush() - response = self.client.get("/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get("/api/v1/eventLogs", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "event_logs": [ { "event_log_id": log_model_1.id, @@ -236,12 +234,12 @@ def test_order_eventlogs_by_owner(self, create_log_model, session): log_model_3 = Log(event="cli_scheduler", owner="root", extra='{"host_name": "e24b454f002a"}') log_model_3.dttm = self.default_time_2 session.add(log_model_3) + session.commit() session.flush() - response = self.client.get( - "/api/v1/eventLogs?order_by=-owner", environ_overrides={"REMOTE_USER": "test"} - ) + session.close() + response = self.client.get("/api/v1/eventLogs?order_by=-owner", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "event_logs": [ { "event_log_id": log_model_2.id, @@ -283,7 +281,7 @@ def test_order_eventlogs_by_owner(self, create_log_model, session): def test_should_raises_401_unauthenticated(self, log_model): response = self.client.get("/api/v1/eventLogs") - assert_401(response) + assert response.status_code == 401 def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): eventlog1 = create_log_model( @@ -302,33 +300,36 @@ def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, s ) session.add_all([eventlog1, eventlog2]) session.commit() + session.close() for attr in ["dag_id", "task_id", "owner", "event"]: attr_value = f"TEST_{attr}_1".upper() response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} + f"/api/v1/eventLogs?{attr}={attr_value}", headers={"REMOTE_USER": "test_granular"} ) assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0][attr] == attr_value + assert {eventlog[attr] for eventlog in response.json()["event_logs"]} == {attr_value} + assert response.json()["total_entries"] == 1 + assert len(response.json()["event_logs"]) == 1 + assert response.json()["event_logs"][0][attr] == attr_value def test_should_filter_eventlogs_by_when(self, create_log_model, session): eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time) eventlog2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2) session.add_all([eventlog1, eventlog2]) session.commit() + session.close() for when_attr, expected_eventlog_event in { "before": "TEST_EVENT_1", "after": "TEST_EVENT_2", }.items(): response = self.client.get( f"/api/v1/eventLogs?{when_attr}=2020-06-10T20%3A00%3A01%2B00%3A00", # self.default_time + 1s - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0]["event"] == expected_eventlog_event + assert response.json()["total_entries"] == 1 + assert len(response.json()["event_logs"]) == 1 + assert response.json()["event_logs"][0]["event"] == expected_eventlog_event def test_should_filter_eventlogs_by_run_id(self, create_log_model, session): eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time, run_id="run_1") @@ -336,29 +337,30 @@ def test_should_filter_eventlogs_by_run_id(self, create_log_model, session): eventlog3 = create_log_model(event="TEST_EVENT_3", when=self.default_time, run_id="run_2") session.add_all([eventlog1, eventlog2, eventlog3]) session.commit() + session.close() for run_id, expected_eventlogs in { "run_1": {"TEST_EVENT_1"}, "run_2": {"TEST_EVENT_2", "TEST_EVENT_3"}, }.items(): response = self.client.get( f"/api/v1/eventLogs?run_id={run_id}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_eventlogs) - assert len(response.json["event_logs"]) == len(expected_eventlogs) - assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs - assert all({eventlog["run_id"] == run_id for eventlog in response.json["event_logs"]}) + assert response.json()["total_entries"] == len(expected_eventlogs) + assert len(response.json()["event_logs"]) == len(expected_eventlogs) + assert {eventlog["event"] for eventlog in response.json()["event_logs"]} == expected_eventlogs + assert all({eventlog["run_id"] == run_id for eventlog in response.json()["event_logs"]}) def test_should_filter_eventlogs_by_included_events(self, create_log_model): for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: create_log_model(event=event, when=self.default_time) response = self.client.get( "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, + headers={"REMOTE_USER": "test_granular"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["event_logs"]) == 2 assert response_data["total_entries"] == 2 assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} @@ -368,10 +370,10 @@ def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): create_log_model(event=event, when=self.default_time) response = self.client.get( "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, + headers={"REMOTE_USER": "test_granular"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["event_logs"]) == 1 assert response_data["total_entries"] == 1 assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} @@ -437,46 +439,48 @@ def test_handle_limit_and_offset(self, url, expected_events, task_instance, sess log_models = self._create_event_logs(task_instance, 10) session.add_all(log_models) session.commit() - - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - events = [event_log["event"] for event_log in response.json["event_logs"]] + assert response.json()["total_entries"] == 10 + events = [event_log["event"] for event_log in response.json()["event_logs"]] assert events == expected_events def test_should_respect_page_size_limit_default(self, task_instance, session): log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) + session.commit() session.flush() + session.close() - response = self.client.get("/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/eventLogs", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["event_logs"]) == 100 # default 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["event_logs"]) == 100 # default 100 def test_should_raise_400_for_invalid_order_by_name(self, task_instance, session): log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) + session.commit() session.flush() - - response = self.client.get( - "/api/v1/eventLogs?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + session.close() + response = self.client.get("/api/v1/eventLogs?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, task_instance, session): log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) + session.commit() session.flush() - - response = self.client.get("/api/v1/eventLogs?limit=180", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get("/api/v1/eventLogs?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["event_logs"]) == 150 + assert len(response.json()["event_logs"]) == 150 def _create_event_logs(self, task_instance, count): return [Log(event=f"TEST_EVENT_{i}", task_instance=task_instance) for i in range(1, count + 1)] diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index f6590b6d5a995..93e275e482589 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -42,10 +42,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -54,12 +54,12 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestGetExtraLinks: @@ -70,13 +70,13 @@ def setup_attrs(self, configured_app, session) -> None: clear_db_runs() clear_db_xcom() - self.app = configured_app + self.connexion_app = configured_app self.dag = self._create_dag() - self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.connexion_app.app.dag_bag = DagBag(os.devnull, include_examples=False) + self.connexion_app.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore + self.connexion_app.app.dag_bag.sync_to_db() # type: ignore self.dag.create_dagrun( run_id="TEST_DAG_RUN_ID", @@ -86,9 +86,10 @@ def setup_attrs(self, configured_app, session) -> None: session=session, data_interval=DataInterval(timezone.datetime(2020, 1, 1), timezone.datetime(2020, 1, 2)), ) + session.commit() session.flush() - - self.client = self.app.test_client() # type:ignore + session.close() + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self) -> None: clear_db_runs() @@ -124,7 +125,7 @@ def _create_dag(self): ], ) def test_should_respond_404(self, url, expected_title, expected_detail): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert 404 == response.status_code assert { @@ -132,12 +133,12 @@ def test_should_respond_404(self, url, expected_title, expected_detail): "status": 404, "title": expected_title, "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raise_403_forbidden(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -152,23 +153,23 @@ def test_should_respond_200(self): ) response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data + assert 200 == response.status_code assert { "BigQuery Console": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID" - } == response.json + } == response.json() @mock_plugin_manager(plugins=[]) def test_should_respond_200_missing_xcom(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data - assert {"BigQuery Console": None} == response.json + assert 200 == response.status_code + assert {"BigQuery Console": None} == response.json() @mock_plugin_manager(plugins=[]) def test_should_respond_200_multiple_links(self): @@ -181,24 +182,24 @@ def test_should_respond_200_multiple_links(self): ) response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_MULTIPLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data + assert 200 == response.status_code assert { "BigQuery Console #1": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_1", "BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_2", - } == response.json + } == response.json() @mock_plugin_manager(plugins=[]) def test_should_respond_200_multiple_links_missing_xcom(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_MULTIPLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data - assert {"BigQuery Console #1": None, "BigQuery Console #2": None} == response.json + assert 200 == response.status_code + assert {"BigQuery Console #1": None, "BigQuery Console #2": None} == response.json() def test_should_respond_200_support_plugins(self): class GoogleLink(BaseOperatorLink): @@ -229,10 +230,10 @@ class AirflowTestPlugin(AirflowPlugin): with mock_plugin_manager(plugins=[AirflowTestPlugin]): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data + assert 200 == response.status_code assert { "BigQuery Console": None, "Google": "https://www.google.com", @@ -240,4 +241,4 @@ class AirflowTestPlugin(AirflowPlugin): "https://s3.amazonaws.com/airflow-logs/" "TEST_DAG_ID/TEST_SINGLE_QUERY/2020-01-01T00%3A00%3A00%2B00%3A00" ), - } == response.json + } == response.json() diff --git a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py index a9f2d9ceb4691..3a71fc9d67e28 100644 --- a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py +++ b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py @@ -59,7 +59,7 @@ def autoclean_user_payload(autoclean_username, autoclean_email): @pytest.fixture def autoclean_admin_user(configured_app, autoclean_user_payload): - security_manager = configured_app.appbuilder.sm + security_manager = configured_app.app.appbuilder.sm return security_manager.add_user( role=security_manager.find_role("Admin"), **autoclean_user_payload, @@ -82,9 +82,9 @@ def autoclean_email(): @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -100,28 +100,29 @@ def configured_app(minimal_app_for_api): ], ) - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore class TestFABforwarding: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = configured_app.app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self): """ Delete all roles except these ones. Test and TestNoPermissions are deleted by delete_user above """ - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) roles = session.query(Role).filter(~Role.name.in_(existing_roles)).all() for role in roles: - delete_role(self.app, role.name) + delete_role(self.flask_app, role.name) users = session.query(User).filter(User.changed_on == timezone.parse(DEFAULT_TIME)) users.delete(synchronize_session=False) session.commit() @@ -130,31 +131,31 @@ def teardown_method(self): class TestFABRoleForwarding(TestFABforwarding): @mock.patch("airflow.api_connexion.endpoints.forward_to_fab_endpoint.get_auth_manager") def test_raises_400_if_manager_is_not_fab(self, mock_get_auth_manager): - mock_get_auth_manager.return_value = BaseAuthManager(self.app.appbuilder) - response = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"}) + mock_get_auth_manager.return_value = BaseAuthManager(self.flask_app.appbuilder) + response = self.client.get("api/v1/roles", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert ( - response.json["detail"] + response.json()["detail"] == "This endpoint is only available when using the default auth manager FabAuthManager." ) def test_get_role_forwards_to_fab(self): - resp = self.client.get("api/v1/roles/Test", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/roles/Test", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_get_roles_forwards_to_fab(self): - resp = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/roles", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_delete_role_forwards_to_fab(self): - role = create_role(self.app, "mytestrole") - resp = self.client.delete(f"api/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"}) + role = create_role(self.flask_app, "mytestrole") + resp = self.client.delete(f"api/v1/roles/{role.name}", headers={"REMOTE_USER": "test"}) assert resp.status_code == 204 def test_patch_role_forwards_to_fab(self): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") resp = self.client.patch( - f"api/v1/roles/{role.name}", json={"name": "Test2"}, environ_overrides={"REMOTE_USER": "test"} + f"api/v1/roles/{role.name}", json={"name": "Test2"}, headers={"REMOTE_USER": "test"} ) assert resp.status_code == 200 @@ -163,11 +164,11 @@ def test_post_role_forwards_to_fab(self): "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - resp = self.client.post("api/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.post("api/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_get_role_permissions_forwards_to_fab(self): - resp = self.client.get("api/v1/permissions", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/permissions", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 @@ -192,29 +193,29 @@ def _create_users(self, count, roles=None): def test_get_user_forwards_to_fab(self): users = self._create_users(1) - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session session.add_all(users) session.commit() - resp = self.client.get("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/users/TEST_USER1", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_get_users_forwards_to_fab(self): users = self._create_users(2) - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session session.add_all(users) session.commit() - resp = self.client.get("api/v1/users", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/users", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_post_user_forwards_to_fab(self, autoclean_username, autoclean_user_payload): response = self.client.post( "/api/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() - security_manager = self.app.appbuilder.sm + security_manager = self.flask_app.appbuilder.sm user = security_manager.find_user(autoclean_username) assert user is not None assert user.roles == [security_manager.find_role("Public")] @@ -225,14 +226,14 @@ def test_patch_user_forwards_to_fab(self, autoclean_username, autoclean_user_pay response = self.client.patch( f"/api/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() def test_delete_user_forwards_to_fab(self): users = self._create_users(1) - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session session.add_all(users) session.commit() - resp = self.client.delete("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.delete("api/v1/users/TEST_USER1", headers={"REMOTE_USER": "test"}) assert resp.status_code == 204 diff --git a/tests/api_connexion/endpoints/test_health_endpoint.py b/tests/api_connexion/endpoints/test_health_endpoint.py index 7d73b338e5105..3f68f75a4ba65 100644 --- a/tests/api_connexion/endpoints/test_health_endpoint.py +++ b/tests/api_connexion/endpoints/test_health_endpoint.py @@ -36,8 +36,8 @@ class TestHealthTestBase: @pytest.fixture(autouse=True) def setup_attrs(self, minimal_app_for_api) -> None: - self.app = minimal_app_for_api - self.client = self.app.test_client() # type:ignore + self.connexion_app = minimal_app_for_api + self.client = self.connexion_app.test_client() # type:ignore with create_session() as session: session.query(Job).delete() @@ -54,7 +54,8 @@ def test_healthy_scheduler_status(self, session): SchedulerJobRunner(job=job) session.add(job) session.commit() - resp_json = self.client.get("/api/v1/health").json + session.close() + resp_json = self.client.get("/api/v1/health").json() assert "healthy" == resp_json["metadatabase"]["status"] assert "healthy" == resp_json["scheduler"]["status"] assert ( @@ -69,7 +70,8 @@ def test_unhealthy_scheduler_is_slow(self, session): SchedulerJobRunner(job=job) session.add(job) session.commit() - resp_json = self.client.get("/api/v1/health").json + session.close() + resp_json = self.client.get("/api/v1/health").json() assert "healthy" == resp_json["metadatabase"]["status"] assert "unhealthy" == resp_json["scheduler"]["status"] assert ( @@ -78,7 +80,7 @@ def test_unhealthy_scheduler_is_slow(self, session): ) def test_unhealthy_scheduler_no_job(self): - resp_json = self.client.get("/api/v1/health").json + resp_json = self.client.get("/api/v1/health").json() assert "healthy" == resp_json["metadatabase"]["status"] assert "unhealthy" == resp_json["scheduler"]["status"] assert resp_json["scheduler"]["latest_scheduler_heartbeat"] is None @@ -86,6 +88,6 @@ def test_unhealthy_scheduler_no_job(self): @mock.patch.object(SchedulerJobRunner, "most_recent_job") def test_unhealthy_metadatabase_status(self, most_recent_job_mock): most_recent_job_mock.side_effect = Exception - resp_json = self.client.get("/api/v1/health").json + resp_json = self.client.get("/api/v1/health").json() assert "unhealthy" == resp_json["metadatabase"]["status"] assert resp_json["scheduler"]["latest_scheduler_heartbeat"] is None diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py index 4549b74ae9975..cc73fea8a42e9 100644 --- a/tests/api_connexion/endpoints/test_import_error_endpoint.py +++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py @@ -25,7 +25,7 @@ from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import ParseImportError from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_import_errors @@ -37,9 +37,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[ @@ -47,16 +47,16 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), ], # type: ignore ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test_single_dag", role_name="TestSingleDAG", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore ) + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore # For some reason, DAG level permissions are not synced when in the above list of perms, # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( + connexion_app.app.appbuilder.sm.bulk_sync_roles( [ { "role": "TestSingleDAG", @@ -65,11 +65,11 @@ def configured_app(minimal_app_for_api): ] ) - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_single_dag") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_single_dag") # type: ignore class TestBaseImportError: @@ -77,8 +77,8 @@ class TestBaseImportError: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_import_errors() clear_db_dags() @@ -103,12 +103,10 @@ def test_response_200(self, session): session.add(import_error) session.commit() - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() response_data["import_error_id"] = 1 assert { "filename": "Lorem_ipsum.py", @@ -118,14 +116,14 @@ def test_response_200(self, session): } == response_data def test_response_404(self): - response = self.client.get("/api/v1/importErrors/2", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/importErrors/2", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "The ImportError with import_error_id: `2` was not found", "status": 404, "title": "Import error not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): import_error = ParseImportError( @@ -138,12 +136,10 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get(f"/api/v1/importErrors/{import_error.id}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 def test_should_raise_403_forbidden_without_dag_read(self, session): @@ -156,7 +152,7 @@ def test_should_raise_403_forbidden_without_dag_read(self, session): session.commit() response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test_single_dag"} ) assert response.status_code == 403 @@ -173,11 +169,11 @@ def test_should_return_200_with_single_dag_read(self, session): session.commit() response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test_single_dag"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() response_data["import_error_id"] = 1 assert { "filename": "Lorem_ipsum.py", @@ -199,11 +195,11 @@ def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, sessio session.commit() response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test_single_dag"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() response_data["import_error_id"] = 1 assert { "filename": "Lorem_ipsum.py", @@ -226,10 +222,10 @@ def test_get_import_errors(self, session): session.add_all(import_error) session.commit() - response = self.client.get("/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -262,11 +258,11 @@ def test_get_import_errors_order_by(self, session): session.commit() response = self.client.get( - "/api/v1/importErrors?order_by=-timestamp", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/importErrors?order_by=-timestamp", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -298,13 +294,11 @@ def test_order_by_raises_400_for_invalid_attr(self, session): session.add_all(import_error) session.commit() - response = self.client.get( - "/api/v1/importErrors?order_by=timest", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/importErrors?order_by=timest", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'timest' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raises_401_unauthenticated(self, session): import_error = [ @@ -320,7 +314,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("/api/v1/importErrors") - assert_401(response) + assert response.status_code == 401 def test_get_import_errors_single_dag(self, session): for dag_id in TEST_DAG_IDS: @@ -335,12 +329,10 @@ def test_get_import_errors_single_dag(self, session): session.add(importerror) session.commit() - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test_single_dag"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -368,12 +360,10 @@ def test_get_import_errors_single_dag_in_dagfile(self, session): session.add(importerror) session.commit() - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test_single_dag"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -415,10 +405,10 @@ def test_limit_and_offset(self, url, expected_import_error_ids, session): session.add_all(import_errors) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - import_ids = [pool["filename"] for pool in response.json["import_errors"]] + import_ids = [pool["filename"] for pool in response.json()["import_errors"]] assert import_ids == expected_import_error_ids def test_should_respect_page_size_limit_default(self, session): @@ -432,9 +422,9 @@ def test_should_respect_page_size_limit_default(self, session): ] session.add_all(import_errors) session.commit() - response = self.client.get("/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["import_errors"]) == 100 + assert len(response.json()["import_errors"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -448,8 +438,6 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): ] session.add_all(import_errors) session.commit() - response = self.client.get( - "/api/v1/importErrors?limit=180", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/importErrors?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["import_errors"]) == 150 + assert len(response.json()["import_errors"]) == 150 diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index d472b6902b3b1..05fce4e381629 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -33,7 +33,7 @@ from airflow.security import permissions from airflow.utils import timezone from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_runs pytestmark = pytest.mark.db_test @@ -41,10 +41,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, + connexion_app.app, username="test", role_name="Test", permissions=[ @@ -52,12 +52,12 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") - yield app + yield connexion_app - delete_user(app, username="test") - delete_user(app, username="test_no_permissions") + delete_user(connexion_app.app, username="test") + delete_user(connexion_app.app, username="test_no_permissions") class TestGetLog: @@ -71,8 +71,9 @@ class TestGetLog: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app, configure_loggers, dag_maker, session) -> None: - self.app = configured_app - self.client = self.app.test_client() + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # Make sure that the configure_logging is not cached self.old_modules = dict(sys.modules) @@ -92,7 +93,7 @@ def add_one(x: int): start_date=timezone.parse(self.default_time), ) - configured_app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) # Add dummy dag for checking picking correct log with same task_id and different dag_id case. with dag_maker( @@ -105,13 +106,15 @@ def add_one(x: int): execution_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), ) - configured_app.dag_bag.bag_dag(dummy_dag, root_dag=dummy_dag) + self.flask_app.dag_bag.bag_dag(dummy_dag, root_dag=dummy_dag) for ti in dr.task_instances: ti.try_number = 1 ti.hostname = "localhost" self.ti = dr.task_instances[0] + session.commit() + session.close() @pytest.fixture def configure_loggers(self, tmp_path, create_log_template): @@ -145,6 +148,11 @@ def configure_loggers(self, tmp_path, create_log_template): logging.config.dictConfig(logging_config) + create_log_template( + "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/" + "{% if ti.map_index >= 0 %}map_index={{ ti.map_index }}/{% endif %}" + "attempt={{ try_number }}.log" + ) yield logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) @@ -153,23 +161,22 @@ def teardown_method(self): clear_db_runs() def test_should_respond_200_json(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) expected_filename = ( f"{self.log_dir}/dag_id={self.DAG_ID}/run_id={self.RUN_ID}/task_id={self.TASK_ID}/attempt=1.log" ) assert ( - response.json["content"] + response.json()["content"] == f"[('localhost', '*** Found local files:\\n*** * {expected_filename}\\nLog for testing.')]" ) - info = serializer.loads(response.json["continuation_token"]) + info = serializer.loads(response.json()["continuation_token"]) assert info == {"end_of_log": True, "log_pos": 16} assert 200 == response.status_code @@ -191,19 +198,18 @@ def test_should_respond_200_json(self): def test_should_respond_200_text_plain(self, request_url, expected_filename, extra_query_string): expected_filename = expected_filename.replace("LOG_DIR", str(self.log_dir)) - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( request_url, - query_string={"token": token, **extra_query_string}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token, **extra_query_string}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert 200 == response.status_code assert ( - response.data.decode("utf-8") + response.text == f"localhost\n*** Found local files:\n*** * {expected_filename}\nLog for testing.\n" ) @@ -226,40 +232,39 @@ def test_get_logs_of_removed_task(self, request_url, expected_filename, extra_qu expected_filename = expected_filename.replace("LOG_DIR", str(self.log_dir)) # Recreate DAG without tasks - dagbag = self.app.dag_bag + dagbag = self.flask_app.dag_bag dag = DAG(self.DAG_ID, start_date=timezone.parse(self.default_time)) del dagbag.dags[self.DAG_ID] dagbag.bag_dag(dag=dag, root_dag=dag) - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( request_url, - query_string={"token": token, **extra_query_string}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token, **extra_query_string}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert 200 == response.status_code assert ( - response.data.decode("utf-8") + response.text == f"localhost\n*** Found local files:\n*** * {expected_filename}\nLog for testing.\n" ) def test_get_logs_response_with_ti_equal_to_none(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/Invalid-Task-ID/logs/1", - query_string={"token": token}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "TaskInstance not found", @@ -277,43 +282,40 @@ def test_get_logs_with_metadata_as_download_large_file(self): response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/" f"taskInstances/{self.TASK_ID}/logs/1?full_content=True", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) - assert "1st line" in response.data.decode("utf-8") - assert "2nd line" in response.data.decode("utf-8") - assert "3rd line" in response.data.decode("utf-8") - assert "should never be read" not in response.data.decode("utf-8") + assert "1st line" in response.text + assert "2nd line" in response.text + assert "3rd line" in response.text + assert "should never be read" not in response.text @mock.patch("airflow.api_connexion.endpoints.log_endpoint.TaskLogReader") def test_get_logs_for_handler_without_read_method(self, mock_log_reader): type(mock_log_reader.return_value).supports_read = PropertyMock(return_value=False) - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) # check guessing response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Content-Type": "application/jso"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Content-Type": "application/json", "REMOTE_USER": "test"}, ) assert 400 == response.status_code - assert "Task log handler does not support read logs." in response.data.decode("utf-8") + assert "Task log handler does not support read logs." in response.text def test_bad_signature_raises(self): token = {"download_logs": False} response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) - assert response.json == { + assert response.json() == { "detail": None, "status": 400, "title": "Bad Signature. Please use only the tokens provided by the API.", @@ -324,11 +326,10 @@ def test_raises_404_for_invalid_dag_run_id(self): response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/NO_DAG_RUN/" # invalid run_id f"taskInstances/{self.TASK_ID}/logs/1?", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "TaskInstance not found", @@ -336,55 +337,52 @@ def test_raises_404_for_invalid_dag_run_id(self): } def test_should_raises_401_unauthenticated(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, + params={"token": token}, headers={"Accept": "application/json"}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + params={"token": token}, + headers={"Accept": "text/plain", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 def test_should_raise_404_when_missing_map_index_param_for_mapped_task(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.MAPPED_TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "TaskInstance not found" + assert response.json()["title"] == "TaskInstance not found" def test_should_raise_404_when_filtering_on_map_index_for_unmapped_task(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token, "map_index": 0}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token, "map_index": 0}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "TaskInstance not found" + assert response.json()["title"] == "TaskInstance not found" diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 78054f379efa4..dd835a99e595c 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -33,7 +33,7 @@ from airflow.utils.session import provide_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.timezone import datetime -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.mock_operators import MockOperator @@ -48,9 +48,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -61,13 +61,13 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_roles(app) + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_roles(connexion_app.app) class TestMappedTaskInstanceEndpoint: @@ -87,8 +87,9 @@ def setup_attrs(self, configured_app) -> None: "queue": "default_queue", "job_id": 0, } - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore clear_db_runs() clear_db_sla_miss() clear_rendered_ti_fields() @@ -132,9 +133,9 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) - self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.flask_app.dag_bag = DagBag(os.devnull, include_examples=False) + self.flask_app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore + self.flask_app.dag_bag.sync_to_db() # type: ignore session.flush() mapped.expand_mapped_task(dr.run_id, session=session) @@ -201,10 +202,10 @@ class TestNonExistent(TestMappedTaskInstanceEndpoint): def test_non_existent_task_instance(self, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "DAG mapped_tis not found" + assert response.json()["title"] == "DAG mapped_tis not found" class TestGetMappedTaskInstance(TestMappedTaskInstanceEndpoint): @@ -212,10 +213,10 @@ class TestGetMappedTaskInstance(TestMappedTaskInstanceEndpoint): def test_mapped_task_instances(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/0", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "mapped_tis", "dag_run_id": "run_mapped_tis", "duration": None, @@ -251,22 +252,22 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/1", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 def test_without_map_index_returns_custom_404(self, one_task_with_mapped_tis): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "Task instance is mapped, add the map_index value to the URL", "status": 404, "title": "Task instance not found", @@ -276,20 +277,20 @@ def test_without_map_index_returns_custom_404(self, one_task_with_mapped_tis): def test_one_mapped_task_works(self, one_task_with_single_mapped_ti): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/0", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "Task instance is mapped, add the map_index value to the URL", "status": 404, "title": "Task instance not found", @@ -302,71 +303,73 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint): def test_mapped_task_instances(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 @provide_session def test_mapped_task_instances_offset_limit(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?offset=4&limit=10", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 10 - assert list(range(4, 14)) == [ti["map_index"] for ti in response.json["task_instances"]] + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 10 + assert list(range(4, 14)) == [ti["map_index"] for ti in response.json()["task_instances"]] @provide_session def test_mapped_task_instances_order(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 - assert list(range(100)) == [ti["map_index"] for ti in response.json["task_instances"]] + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 + assert list(range(100)) == [ti["map_index"] for ti in response.json()["task_instances"]] @provide_session def test_mapped_task_instances_reverse_order(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=-map_index", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 - assert list(range(109, 9, -1)) == [ti["map_index"] for ti in response.json["task_instances"]] + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 + assert list(range(109, 9, -1)) == [ti["map_index"] for ti in response.json()["task_instances"]] @provide_session def test_mapped_task_instances_state_order(self, one_task_with_many_mapped_tis, session): + session.commit() + session.close() response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=-state", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 assert list(range(5)) + list(range(25, 110)) + list(range(5, 15)) == [ - ti["map_index"] for ti in response.json["task_instances"] + ti["map_index"] for ti in response.json()["task_instances"] ] # State ascending response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=state", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 assert list(range(5, 25)) + list(range(90, 110)) + list(range(25, 85)) == [ - ti["map_index"] for ti in response.json["task_instances"] + ti["map_index"] for ti in response.json()["task_instances"] ] @provide_session @@ -374,85 +377,85 @@ def test_mapped_task_instances_invalid_order(self, one_task_with_many_mapped_tis response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=unsupported", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "Ordering with 'unsupported' is not supported" + assert response.json()["detail"] == "Ordering with 'unsupported' is not supported" @provide_session def test_mapped_task_instances_with_date(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" f"?start_date_gte={QUOTED_DEFAULT_DATETIME_STR_1}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" f"?start_date_gte={QUOTED_DEFAULT_DATETIME_STR_2}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_state(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?state=success", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?state=running", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_pool(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?pool=default_pool", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?pool=test_pool", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_queue(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?queue=default", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?queue=test_queue", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_executor(self, one_task_with_mapped_tis, session): @@ -476,16 +479,16 @@ def test_mapped_task_instances_with_executor(self, one_task_with_mapped_tis, ses def test_mapped_task_instances_with_zero_mapped(self, one_task_with_zero_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] def test_should_raise_404_not_found_for_nonexistent_task(self): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/nonexistent_task/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "Task id nonexistent_task not found" + assert response.json()["title"] == "Task id nonexistent_task not found" diff --git a/tests/api_connexion/endpoints/test_plugin_endpoint.py b/tests/api_connexion/endpoints/test_plugin_endpoint.py index 0206c1ff0fc7d..c120f522f240c 100644 --- a/tests/api_connexion/endpoints/test_plugin_endpoint.py +++ b/tests/api_connexion/endpoints/test_plugin_endpoint.py @@ -28,7 +28,7 @@ from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import Timetable from airflow.utils.module_loading import qualname -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import BaseOperatorLink from tests.test_utils.config import conf_vars from tests.test_utils.mock_plugins import mock_plugin_manager @@ -103,19 +103,19 @@ class MockPlugin(AirflowPlugin): @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestPluginsEndpoint: @@ -124,8 +124,8 @@ def setup_attrs(self, configured_app) -> None: """ Setup For XCom endpoint TC """ - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore class TestGetPlugins(TestPluginsEndpoint): @@ -133,9 +133,9 @@ def test_get_plugins_return_200(self): mock_plugin = MockPlugin() mock_plugin.name = "test_plugin" with mock_plugin_manager(plugins=[mock_plugin]): - response = self.client.get("api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "plugins": [ { "appbuilder_menu_items": [appbuilder_menu_items], @@ -167,24 +167,22 @@ def test_get_plugins_works_with_more_plugins(self): mock_plugin_2 = AirflowPlugin() mock_plugin_2.name = "test_plugin2" with mock_plugin_manager(plugins=[mock_plugin, mock_plugin_2]): - response = self.client.get("api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 2 + assert response.json()["total_entries"] == 2 def test_get_plugins_return_200_if_no_plugins(self): with mock_plugin_manager(plugins=[]): - response = self.client.get("api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/plugins") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/plugins", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/plugins", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @@ -230,35 +228,35 @@ class TestGetPluginsPagination(TestPluginsEndpoint): def test_handle_limit_offset(self, url, expected_plugin_names): plugins = self._create_plugins(10) with mock_plugin_manager(plugins=plugins): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - plugin_names = [plugin["name"] for plugin in response.json["plugins"] if plugin] + assert response.json()["total_entries"] == 10 + plugin_names = [plugin["name"] for plugin in response.json()["plugins"] if plugin] assert plugin_names == expected_plugin_names def test_should_respect_page_size_limit_default(self): plugins = self._create_plugins(200) with mock_plugin_manager(plugins=plugins): - response = self.client.get("/api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["plugins"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["plugins"]) == 100 def test_limit_of_zero_should_return_default(self): plugins = self._create_plugins(200) with mock_plugin_manager(plugins=plugins): - response = self.client.get("/api/v1/plugins?limit=0", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/plugins?limit=0", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["plugins"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["plugins"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): plugins = self._create_plugins(200) with mock_plugin_manager(plugins=plugins): - response = self.client.get("/api/v1/plugins?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/plugins?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["plugins"]) == 150 + assert len(response.json()["plugins"]) == 150 def _create_plugins(self, count): plugins = [] diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py index f709bda9a1ed6..b7b56c59f5464 100644 --- a/tests/api_connexion/endpoints/test_pool_endpoint.py +++ b/tests/api_connexion/endpoints/test_pool_endpoint.py @@ -22,7 +22,7 @@ from airflow.models.pool import Pool from airflow.security import permissions from airflow.utils.session import provide_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools from tests.test_utils.www import _check_last_log @@ -32,10 +32,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -45,19 +45,19 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestBasePoolEndpoints: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_pools() def teardown_method(self) -> None: @@ -69,9 +69,10 @@ def test_response_200(self, session): pool_model = Pool(pool="test_pool_a", slots=3, include_deferred=True) session.add(pool_model) session.commit() + session.close() result = session.query(Pool).all() assert len(result) == 2 # accounts for the default pool as well - response = self.client.get("/api/v1/pools", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "pools": [ @@ -101,15 +102,16 @@ def test_response_200(self, session): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_response_200_with_order_by(self, session): pool_model = Pool(pool="test_pool_a", slots=3, include_deferred=True) session.add(pool_model) session.commit() + session.close() result = session.query(Pool).all() assert len(result) == 2 # accounts for the default pool as well - response = self.client.get("/api/v1/pools?order_by=slots", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools?order_by=slots", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "pools": [ @@ -139,15 +141,15 @@ def test_response_200_with_order_by(self, session): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/pools") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get("/api/v1/pools", environ_overrides={"REMOTE_USER": "test_no_permissions"}) + response = self.client.get("/api/v1/pools", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @@ -178,46 +180,48 @@ def test_limit_and_offset(self, url, expected_pool_ids, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 121)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 121 # accounts for default pool as well - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - pool_ids = [pool["name"] for pool in response.json["pools"]] + pool_ids = [pool["name"] for pool in response.json()["pools"]] assert pool_ids == expected_pool_ids def test_should_respect_page_size_limit_default(self, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 121)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 121 - response = self.client.get("/api/v1/pools", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["pools"]) == 100 + assert len(response.json()["pools"]) == 100 def test_should_raise_400_for_invalid_orderby(self, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 121)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 121 - response = self.client.get( - "/api/v1/pools?order_by=open_slots", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/pools?order_by=open_slots", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'open_slots' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 200)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 200 - response = self.client.get("/api/v1/pools?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["pools"]) == 150 + assert len(response.json()["pools"]) == 150 class TestGetPool(TestBasePoolEndpoints): @@ -225,7 +229,8 @@ def test_response_200(self, session): pool_model = Pool(pool="test_pool_a", slots=3, include_deferred=True) session.add(pool_model) session.commit() - response = self.client.get("/api/v1/pools/test_pool_a", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get("/api/v1/pools/test_pool_a", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "name": "test_pool_a", @@ -238,22 +243,22 @@ def test_response_200(self, session): "open_slots": 3, "description": None, "include_deferred": True, - } == response.json + } == response.json() def test_response_404(self): - response = self.client.get("/api/v1/pools/invalid_pool", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools/invalid_pool", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "Pool with name:'invalid_pool' not found", "status": 404, "title": "Not Found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/pools/default_pool") - assert_401(response) + assert response.status_code == 401 class TestDeletePool(TestBasePoolEndpoints): @@ -262,45 +267,57 @@ def test_response_204(self, session): pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) session.add(pool_instance) session.commit() - - response = self.client.delete(f"api/v1/pools/{pool_name}", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.delete(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 # Check if the pool is deleted from the db - response = self.client.get(f"api/v1/pools/{pool_name}", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 _check_last_log(session, dag_id=None, event="api.delete_pool", execution_date=None) def test_response_404(self): - response = self.client.delete("api/v1/pools/invalid_pool", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.delete("api/v1/pools/invalid_pool", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "Pool with name:'invalid_pool' not found", "status": 404, "title": "Not Found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): pool_name = "test_pool" pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) session.add(pool_instance) session.commit() - + session.close() response = self.client.delete(f"api/v1/pools/{pool_name}") - assert_401(response) + assert response.status_code == 401 # Should still exists - response = self.client.get(f"/api/v1/pools/{pool_name}", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(f"/api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 + def test_response_204(self, session): + pool_name = "test_pool" + pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) + session.add(pool_instance) + session.commit() + session.close() + response = self.client.delete(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) + assert response.status_code == 204 + # Check if the pool is deleted from the db + response = self.client.get(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) + assert response.status_code == 404 + class TestPostPool(TestBasePoolEndpoints): def test_response_200(self, session): response = self.client.post( "api/v1/pools", json={"name": "test_pool_a", "slots": 3, "description": "test pool", "include_deferred": True}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert { @@ -314,7 +331,7 @@ def test_response_200(self, session): "open_slots": 3, "description": "test pool", "include_deferred": True, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.post_pool", execution_date=None) def test_response_409(self, session): @@ -322,10 +339,11 @@ def test_response_409(self, session): pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) session.add(pool_instance) session.commit() + session.close() response = self.client.post( "api/v1/pools", json={"name": "test_pool_a", "slots": 3, "include_deferred": False}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409 assert { @@ -333,7 +351,7 @@ def test_response_409(self, session): "status": 409, "title": "Conflict", "type": EXCEPTIONS_LINK_MAP[409], - } == response.json + } == response.json() @pytest.mark.parametrize( "request_json, error_detail", @@ -361,21 +379,19 @@ def test_response_409(self, session): ], ) def test_response_400(self, request_json, error_detail): - response = self.client.post( - "api/v1/pools", json=request_json, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("api/v1/pools", json=request_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert { "detail": error_detail, "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.post("api/v1/pools", json={"name": "test_pool_a", "slots": 3}) - assert_401(response) + assert response.status_code == 401 class TestPatchPool(TestBasePoolEndpoints): @@ -383,10 +399,11 @@ def test_response_200(self, session): pool = Pool(pool="test_pool", slots=2, include_deferred=True) session.add(pool) session.commit() + session.close() response = self.client.patch( "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3, "include_deferred": False}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert { @@ -400,7 +417,7 @@ def test_response_200(self, session): "slots": 3, "description": None, "include_deferred": False, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_pool", execution_date=None) @pytest.mark.parametrize( @@ -422,8 +439,9 @@ def test_response_400(self, error_detail, request_json, session): pool = Pool(pool="test_pool", slots=2, include_deferred=False) session.add(pool) session.commit() + session.close() response = self.client.patch( - "api/v1/pools/test_pool", json=request_json, environ_overrides={"REMOTE_USER": "test"} + "api/v1/pools/test_pool", json=request_json, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 assert { @@ -431,13 +449,13 @@ def test_response_400(self, error_detail, request_json, session): "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() def test_not_found_when_no_pool_available(self): response = self.client.patch( "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 assert { @@ -445,31 +463,31 @@ def test_not_found_when_no_pool_available(self): "status": 404, "title": "Not Found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): pool = Pool(pool="test_pool", slots=2, include_deferred=False) session.add(pool) session.commit() - + session.close() response = self.client.patch( "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3}, ) - assert_401(response) + assert response.status_code == 401 class TestModifyDefaultPool(TestBasePoolEndpoints): def test_delete_400(self): - response = self.client.delete("api/v1/pools/default_pool", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.delete("api/v1/pools/default_pool", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert { "detail": "Default Pool can't be deleted", "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() @pytest.mark.parametrize( "status_code, url, json, expected_response", @@ -595,9 +613,9 @@ def test_delete_400(self): ], ) def test_patch(self, status_code, url, json, expected_response, session): - response = self.client.patch(url, json=json, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.patch(url, json=json, headers={"REMOTE_USER": "test"}) assert response.status_code == status_code - assert response.json == expected_response + assert response.json() == expected_response _check_last_log(session, dag_id=None, event="api.patch_pool", execution_date=None) @@ -649,7 +667,8 @@ def test_response_200( pool = Pool(pool="test_pool", slots=3, include_deferred=False) session.add(pool) session.commit() - response = self.client.patch(url, json=patch_json, environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.patch(url, json=patch_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "name": expected_name, @@ -662,20 +681,20 @@ def test_response_200( "open_slots": expected_slots, "description": None, "include_deferred": expected_include_deferred, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_pool", execution_date=None) @pytest.mark.parametrize( "error_detail, url, patch_json", [ pytest.param( - "Property is read-only - 'occupied_slots'", + "{'occupied_slots': ['Unknown field.']}", "api/v1/pools/test_pool?update_mask=slots, name, occupied_slots", {"name": "test_pool_a", "slots": 2, "occupied_slots": 1}, id="Patching read only field", ), pytest.param( - "Property is read-only - 'queued_slots'", + "{'queued_slots': ['Unknown field.']}", "api/v1/pools/test_pool?update_mask=slots, name, queued_slots", {"name": "test_pool_a", "slots": 2, "queued_slots": 1}, id="Patching read only field", @@ -699,11 +718,12 @@ def test_response_400(self, error_detail, url, patch_json, session): pool = Pool(pool="test_pool", slots=3, include_deferred=False) session.add(pool) session.commit() - response = self.client.patch(url, json=patch_json, environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.patch(url, json=patch_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert { "detail": error_detail, "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() diff --git a/tests/api_connexion/endpoints/test_provider_endpoint.py b/tests/api_connexion/endpoints/test_provider_endpoint.py index 7c973a9bb4132..fec203cdab1d3 100644 --- a/tests/api_connexion/endpoints/test_provider_endpoint.py +++ b/tests/api_connexion/endpoints/test_provider_endpoint.py @@ -52,26 +52,26 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestBaseProviderEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app, cleanup_providers_manager) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore class TestGetProviders(TestBaseProviderEndpoint): @@ -81,9 +81,9 @@ class TestGetProviders(TestBaseProviderEndpoint): return_value={}, ) def test_response_200_empty_list(self, mock_providers): - response = self.client.get("/api/v1/providers", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/providers", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == {"providers": [], "total_entries": 0} + assert response.json() == {"providers": [], "total_entries": 0} @mock.patch( "airflow.providers_manager.ProvidersManager.providers", @@ -91,9 +91,9 @@ def test_response_200_empty_list(self, mock_providers): return_value=MOCK_PROVIDERS, ) def test_response_200(self, mock_providers): - response = self.client.get("/api/v1/providers", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/providers", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "providers": [ { "description": "Amazon Web Services (AWS) https://aws.amazon.com/", @@ -114,7 +114,5 @@ def test_should_raises_401_unauthenticated(self): assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/providers", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/providers", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index 64b7f4a8c5b7d..c4b9fcb95be45 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -28,7 +28,7 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags pytestmark = pytest.mark.db_test @@ -36,9 +36,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -47,12 +47,12 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestTaskEndpoint: @@ -80,7 +80,7 @@ def setup_dag(self, configured_app): task1 >> task2 dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {dag.dag_id: dag, mapped_dag.dag_id: mapped_dag} - configured_app.dag_bag = dag_bag # type:ignore + configured_app.app.dag_bag = dag_bag # type:ignore @staticmethod def clean_db(): @@ -91,8 +91,9 @@ def clean_db(): @pytest.fixture(autouse=True) def setup_attrs(self, configured_app, setup_dag) -> None: self.clean_db() - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self) -> None: self.clean_db() @@ -140,10 +141,10 @@ def test_should_respond_200(self): "doc_md": None, } response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_mapped_task(self): expected = { @@ -177,17 +178,17 @@ def test_mapped_task(self): } response = self.client.get( f"/api/v1/dags/{self.mapped_dag_id}/tasks/{self.mapped_task_id}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_should_respond_200_serialized(self): # Get the dag out of the dagbag before we patch it to an empty one - SerializedDagModel.write_dag(self.app.dag_bag.get_dag(self.dag_id)) + SerializedDagModel.write_dag(self.flask_app.dag_bag.get_dag(self.dag_id)) dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True) - patcher = unittest.mock.patch.object(self.app, "dag_bag", dag_bag) + patcher = unittest.mock.patch.object(self.flask_app, "dag_bag", dag_bag) patcher.start() expected = { @@ -230,35 +231,35 @@ def test_should_respond_200_serialized(self): "doc_md": None, } response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected patcher.stop() def test_should_respond_404(self): task_id = "xxxx_not_existing" response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks/{task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.dag_id}/tasks/{task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 def test_should_respond_404_when_dag_not_found(self): dag_id = "xxxx_not_existing" response = self.client.get( - f"/api/v1/dags/{dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{dag_id}/tasks/{self.task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json["title"] == "DAG not found" + assert response.json()["title"] == "DAG not found" def test_should_raises_401_unauthenticated(self): response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks", environ_overrides={"REMOTE_USER": "test_no_permissions"} + f"/api/v1/dags/{self.dag_id}/tasks", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -341,11 +342,9 @@ def test_should_respond_200(self): ], "total_entries": 2, } - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_get_tasks_mapped(self): expected = { @@ -415,46 +414,48 @@ def test_get_tasks_mapped(self): "total_entries": 2, } response = self.client.get( - f"/api/v1/dags/{self.mapped_dag_id}/tasks", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.mapped_dag_id}/tasks", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_should_respond_200_ascending_order_by_start_date(self): response = self.client.get( f"/api/v1/dags/{self.dag_id}/tasks?order_by=start_date", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert self.task1_start_date < self.task2_start_date - assert response.json["tasks"][0]["task_id"] == self.task_id - assert response.json["tasks"][1]["task_id"] == self.task_id2 + assert response.json()["tasks"][0]["task_id"] == self.task_id + assert response.json()["tasks"][1]["task_id"] == self.task_id2 def test_should_respond_200_descending_order_by_start_date(self): response = self.client.get( f"/api/v1/dags/{self.dag_id}/tasks?order_by=-start_date", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 # - means is descending assert self.task1_start_date < self.task2_start_date - assert response.json["tasks"][0]["task_id"] == self.task_id2 - assert response.json["tasks"][1]["task_id"] == self.task_id + assert response.json()["tasks"][0]["task_id"] == self.task_id2 + assert response.json()["tasks"][1]["task_id"] == self.task_id def test_should_raise_400_for_invalid_order_by_name(self): response = self.client.get( f"/api/v1/dags/{self.dag_id}/tasks?order_by=invalid_task_colume_name", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "'EmptyOperator' object has no attribute 'invalid_task_colume_name'" + assert ( + response.json()["detail"] == "'EmptyOperator' object has no attribute 'invalid_task_colume_name'" + ) def test_should_respond_404(self): dag_id = "xxxx_not_existing" - response = self.client.get(f"/api/v1/dags/{dag_id}/tasks", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(f"/api/v1/dags/{dag_id}/tasks", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 def test_should_raises_401_unauthenticated(self): response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks") - assert_401(response) + assert response.status_code == 401 diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 62ae45c1bd077..aae47e8331f6a 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -36,7 +36,7 @@ from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.www import _check_last_log @@ -52,9 +52,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -67,7 +67,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_dag_read_only", role_name="TestDagReadOnly", permissions=[ @@ -78,7 +78,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_task_read_only", role_name="TestTaskReadOnly", permissions=[ @@ -89,7 +89,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_read_only_one_dag", role_name="TestReadOnlyOneDag", permissions=[ @@ -99,7 +99,7 @@ def configured_app(minimal_app_for_api): ) # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( + connexion_app.app.appbuilder.sm.bulk_sync_roles( [ { "role": "TestReadOnlyOneDag", @@ -107,16 +107,16 @@ def configured_app(minimal_app_for_api): } ] ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_read_only") # type: ignore - delete_user(app, username="test_task_read_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_read_only_one_dag") # type: ignore - delete_roles(app) + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_dag_read_only") # type: ignore + delete_user(connexion_app.app, username="test_task_read_only") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_read_only_one_dag") # type: ignore + delete_roles(connexion_app.app) class TestTaskInstanceEndpoint: @@ -136,8 +136,9 @@ def setup_attrs(self, configured_app, dagbag) -> None: "queue": "default_queue", "job_id": 0, } - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore clear_db_runs() clear_db_sla_miss() clear_rendered_ti_fields() @@ -196,6 +197,7 @@ def create_task_instances( tis.append(ti) session.commit() + session.close() return tis @@ -217,12 +219,13 @@ def test_should_respond_200(self, username, session): # https://github.com/apache/airflow/issues/14421 session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") session.commit() + session.close() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -264,12 +267,14 @@ def test_should_respond_200_with_task_state_in_deferred(self, session): ti.triggerer_job = Job() TriggererJobRunner(job=ti.triggerer_job) ti.triggerer_job.state = "running" + session.merge(ti) session.commit() + session.close() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - data = response.json + data = response.json() # this logic in effect replicates mock.ANY for these values values_to_ignore = { @@ -326,10 +331,10 @@ def test_should_respond_200_with_task_state_in_removed(self, session): self.create_task_instances(session, task_instances=[{"state": State.REMOVED}], update_extras=True) response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -374,13 +379,14 @@ def test_should_respond_200_task_instance_with_sla_and_rendered(self, session): rendered_fields = RTIF(tis[0], render_templates=False) session.add(rendered_fields) session.commit() + session.close() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -431,17 +437,18 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): setattr(ti, attr, getattr(old_ti, attr)) session.add(ti) session.commit() + session.close() # in each loop, we should get the right mapped TI back for map_index in (1, 2): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances" f"/print_the_context/{map_index}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -477,28 +484,28 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 def test_raises_404_for_nonexistent_task_instance(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/nonexistent_task", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "Task instance not found" + assert response.json()["title"] == "Task instance not found" def test_unmapped_map_index_should_return_404(self, session): self.create_task_instances(session) response = self.client.get( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/-1", - environ_overrides={"REMOTE_USER": "test"}, + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/-6", + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -508,7 +515,7 @@ def test_should_return_404_for_mapped_endpoint(self, session): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/print_the_context/{index}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -517,7 +524,7 @@ def test_should_return_404_for_list_mapped_endpoint(self, session): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/" "taskInstances/print_the_context/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -705,10 +712,10 @@ def test_should_respond_200(self, task_instances, update_extras, url, expected_t update_extras=update_extras, task_instances=task_instances, ) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti + assert response.json()["total_entries"] == expected_ti + assert len(response.json()["task_instances"]) == expected_ti @pytest.mark.parametrize( "task_instances, user, expected_ti", @@ -749,36 +756,34 @@ def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ ], dag_id=dag_id, ) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} - ) + response = self.client.get("/api/v1/dags/~/dagRuns/~/taskInstances", headers={"REMOTE_USER": user}) assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti + assert response.json()["total_entries"] == expected_ti + assert len(response.json()["task_instances"]) == expected_ti def test_should_respond_200_for_dag_id_filter(self, session): self.create_task_instances(session) self.create_task_instances(session, dag_id="example_skip_dag") response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 count = session.query(TaskInstance).filter(TaskInstance.dag_id == "example_python_operator").count() - assert count == response.json["total_entries"] - assert count == len(response.json["task_instances"]) + assert count == response.json()["total_entries"] + assert count == len(response.json()["task_instances"]) def test_should_raises_401_unauthenticated(self): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -951,12 +956,12 @@ def test_should_respond_200( ) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json=payload, ) - assert response.status_code == 200, response.json - assert expected_ti_count == response.json["total_entries"] - assert expected_ti_count == len(response.json["task_instances"]) + assert response.status_code == 200, response.json() + assert expected_ti_count == response.json()["total_entries"] + assert expected_ti_count == len(response.json()["task_instances"]) @pytest.mark.parametrize( "task_instances, payload, expected_ti_count", @@ -990,12 +995,12 @@ def test_should_respond_200_when_task_instance_properties_are_none( ) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) - assert response.status_code == 200, response.json - assert expected_ti_count == response.json["total_entries"] - assert expected_ti_count == len(response.json["task_instances"]) + assert response.status_code == 200, response.json() + assert expected_ti_count == response.json()["total_entries"] + assert expected_ti_count == len(response.json()["task_instances"]) @pytest.mark.parametrize( "payload, expected_ti, total_ti", @@ -1014,24 +1019,24 @@ def test_should_respond_200_dag_ids_filter(self, payload, expected_ti, total_ti, self.create_task_instances(session, dag_id="example_skip_dag") response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 - assert len(response.json["task_instances"]) == expected_ti - assert response.json["total_entries"] == total_ti + assert len(response.json()["task_instances"]) == expected_ti + assert response.json()["total_entries"] == total_ti def test_should_raises_401_unauthenticated(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", json={"dag_ids": ["example_python_operator", "example_skip_dag"]}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, json={"dag_ids": ["example_python_operator", "example_skip_dag"]}, ) assert response.status_code == 403 @@ -1043,11 +1048,11 @@ def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, sess response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, + headers={"REMOTE_USER": "test_read_only_one_dag"}, json=payload, ) assert response.status_code == 403 - assert response.json == { + assert response.json() == { "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", "status": 403, "title": "Forbidden", @@ -1057,19 +1062,19 @@ def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, sess def test_should_raise_400_for_no_json(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "Request body must not be empty" + assert response.json()["detail"] == "RequestBody is required" def test_should_raise_400_for_unknown_fields(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={"unknown_field": "unknown_value"}, ) assert response.status_code == 400 - assert response.json["detail"] == "{'unknown_field': ['Unknown field.']}" + assert response.json()["detail"] == "{'unknown_field': ['Unknown field.']}" @pytest.mark.parametrize( "payload, expected", @@ -1087,11 +1092,11 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se self.create_task_instances(session) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert expected in response.json["detail"] + assert expected in response.json()["detail"] class TestPostClearTaskInstances(TestTaskInstanceEndpoint): @@ -1287,14 +1292,14 @@ def test_should_respond_200(self, main_dag, task_instances, request_dag, payload task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{request_dag}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 - assert len(response.json["task_instances"]) == expected_ti + assert len(response.json()["task_instances"]) == expected_ti _check_last_log( session, dag_id=request_dag, @@ -1309,15 +1314,15 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, s self.create_task_instances(session) dag_id = "example_python_operator" payload = {"include_subdags": True, "reset_dag_runs": True, "dry_run": False} - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 mock_clearti.assert_called_once_with( - [], session, dag=self.app.dag_bag.get_dag(dag_id), dag_run_state=State.QUEUED + [], mock.ANY, dag=self.flask_app.dag_bag.get_dag(dag_id), dag_run_state=State.QUEUED ) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) @@ -1329,10 +1334,10 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, session): assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 @@ -1384,7 +1389,7 @@ def test_should_respond_200_with_reset_dag_run(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) @@ -1429,8 +1434,8 @@ def test_should_respond_200_with_reset_dag_run(self, session): }, ] for task_instance in expected_response: - assert task_instance in response.json["task_instances"] - assert 6 == len(response.json["task_instances"]) + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) assert 0 == failed_dag_runs, 0 _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) @@ -1477,7 +1482,7 @@ def test_should_respond_200_with_dag_run_id(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert 200 == response.status_code @@ -1489,8 +1494,8 @@ def test_should_respond_200_with_dag_run_id(self, session): "task_id": "print_the_context", }, ] - assert response.json["task_instances"] == expected_response - assert 1 == len(response.json["task_instances"]) + assert response.json()["task_instances"] == expected_response + assert 1 == len(response.json()["task_instances"]) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) def test_should_respond_200_with_include_past(self, session): @@ -1536,7 +1541,7 @@ def test_should_respond_200_with_include_past(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert 200 == response.status_code @@ -1579,8 +1584,8 @@ def test_should_respond_200_with_include_past(self, session): }, ] for task_instance in expected_response: - assert task_instance in response.json["task_instances"] - assert 6 == len(response.json["task_instances"]) + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) def test_should_respond_200_with_include_future(self, session): @@ -1625,7 +1630,7 @@ def test_should_respond_200_with_include_future(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) @@ -1669,8 +1674,8 @@ def test_should_respond_200_with_include_future(self, session): }, ] for task_instance in expected_response: - assert task_instance in response.json["task_instances"] - assert 6 == len(response.json["task_instances"]) + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) def test_should_respond_404_for_nonexistent_dagrun_id(self, session): @@ -1700,13 +1705,13 @@ def test_should_respond_404_for_nonexistent_dagrun_id(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert 404 == response.status_code assert ( - response.json["title"] + response.json()["title"] == "Dag Run id TEST_DAG_RUN_ID_100 not found in dag example_python_operator" ) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) @@ -1722,13 +1727,13 @@ def test_should_raises_401_unauthenticated(self): "include_subdags": True, }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) def test_should_raise_403_forbidden(self, username: str): response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json={ "dry_run": False, "reset_dag_runs": True, @@ -1763,19 +1768,19 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected def test_raises_404_for_non_existent_dag(self): response = self.client.post( "/api/v1/dags/non-existent-dag/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "reset_dag_runs": True, @@ -1785,7 +1790,7 @@ def test_raises_404_for_non_existent_dag(self): }, ) assert response.status_code == 404 - assert response.json["title"] == "Dag id non-existent-dag not found" + assert response.json()["title"] == "Dag id non-existent-dag not found" class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): @@ -1801,7 +1806,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi ) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1814,7 +1819,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "task_instances": [ { "dag_id": "example_python_operator", @@ -1835,7 +1840,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi state="failed", task_id="print_the_context", upstream=True, - session=session, + session=mock.ANY, ) @mock.patch("airflow.models.dag.DAG.set_task_instance_state") @@ -1850,7 +1855,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ ) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1863,7 +1868,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "task_instances": [ { "dag_id": "example_python_operator", @@ -1884,7 +1889,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ state="failed", task_id="print_the_context", upstream=True, - session=session, + session=mock.ANY, ) @pytest.mark.parametrize( @@ -1953,11 +1958,11 @@ def test_should_handle_errors(self, error, code, payload, session): self.create_task_instances(session) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == code - assert response.json["detail"] == error + assert response.json()["detail"] == error def test_should_raises_401_unauthenticated(self): response = self.client.post( @@ -1973,13 +1978,13 @@ def test_should_raises_401_unauthenticated(self): "new_state": "failed", }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) def test_should_raise_403_forbidden(self, username): response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1996,7 +2001,7 @@ def test_should_raise_403_forbidden(self, username): def test_should_raise_404_not_found_dag(self): response = self.client.post( "/api/v1/dags/INVALID_DAG/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -2016,7 +2021,7 @@ def test_should_raise_not_found_if_execution_date_is_wrong(self, mock_set_task_i date = DEFAULT_DATETIME_1 + dt.timedelta(days=1) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -2029,7 +2034,7 @@ def test_should_raise_not_found_if_execution_date_is_wrong(self, mock_set_task_i }, ) assert response.status_code == 404 - assert response.json["detail"] == ( + assert response.json()["detail"] == ( f"Task instance not found for task 'print_the_context' on execution_date {date}" ) assert mock_set_task_instance_state.call_count == 0 @@ -2037,7 +2042,7 @@ def test_should_raise_not_found_if_execution_date_is_wrong(self, mock_set_task_i def test_should_raise_404_not_found_task(self): response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "INVALID_TASK", @@ -2087,11 +2092,11 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se self.create_task_instances(session) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected class TestPatchTaskInstance(TestTaskInstanceEndpoint): @@ -2115,14 +2120,14 @@ def test_should_call_mocked_api(self, mock_set_task_instance_state, session): ) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": NEW_STATE, }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "dag_run_id": "TEST_DAG_RUN_ID", "execution_date": "2020-01-01T00:00:00+00:00", @@ -2135,7 +2140,7 @@ def test_should_call_mocked_api(self, mock_set_task_instance_state, session): map_indexes=[-1], state=NEW_STATE, commit=True, - session=session, + session=mock.ANY, ) _check_last_log( session, @@ -2160,14 +2165,14 @@ def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_sta ) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "new_state": NEW_STATE, }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "dag_run_id": "TEST_DAG_RUN_ID", "execution_date": "2020-01-01T00:00:00+00:00", @@ -2183,7 +2188,7 @@ def test_should_update_task_instance_state(self, session): self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": NEW_STATE, @@ -2192,11 +2197,10 @@ def test_should_update_task_instance_state(self, session): response2 = self.client.get( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, - json={}, + headers={"REMOTE_USER": "test"}, ) assert response2.status_code == 200 - assert response2.json["state"] == NEW_STATE + assert response2.json()["state"] == NEW_STATE def test_should_update_task_instance_state_default_dry_run_to_true(self, session): self.create_task_instances(session) @@ -2205,7 +2209,7 @@ def test_should_update_task_instance_state_default_dry_run_to_true(self, session self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "new_state": NEW_STATE, }, @@ -2213,11 +2217,10 @@ def test_should_update_task_instance_state_default_dry_run_to_true(self, session response2 = self.client.get( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, - json={}, + headers={"REMOTE_USER": "test"}, ) assert response2.status_code == 200 - assert response2.json["state"] == NEW_STATE + assert response2.json()["state"] == NEW_STATE def test_should_update_mapped_task_instance_state(self, session): NEW_STATE = "failed" @@ -2227,10 +2230,11 @@ def test_should_update_mapped_task_instance_state(self, session): ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) session.add(ti) session.commit() + session.close() self.client.patch( f"{self.ENDPOINT_URL}/{map_index}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": NEW_STATE, @@ -2239,11 +2243,10 @@ def test_should_update_mapped_task_instance_state(self, session): response2 = self.client.get( f"{self.ENDPOINT_URL}/{map_index}", - environ_overrides={"REMOTE_USER": "test"}, - json={}, + headers={"REMOTE_USER": "test"}, ) assert response2.status_code == 200 - assert response2.json["state"] == NEW_STATE + assert response2.json()["state"] == NEW_STATE @pytest.mark.parametrize( "error, code, payload", @@ -2261,51 +2264,51 @@ def test_should_update_mapped_task_instance_state(self, session): def test_should_handle_errors(self, error, code, payload, session): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == code - assert response.json["detail"] == error + assert response.json()["detail"] == error def test_should_raise_400_for_unknown_fields(self, session): self.create_task_instances(session) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dryrun": True, "new_state": "failed", }, ) assert response.status_code == 400 - assert response.json["detail"] == "{'dryrun': ['Unknown field.']}" + assert response.json()["detail"] == "{'dryrun': ['Unknown field.']}" def test_should_raise_404_for_non_existent_dag(self): response = self.client.patch( "/api/v1/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": "failed", }, ) assert response.status_code == 404 - assert response.json["title"] == "DAG not found" - assert response.json["detail"] == "DAG 'non-existent-dag' not found" + assert response.json()["title"] == "DAG not found" + assert response.json()["detail"] == "DAG 'non-existent-dag' not found" def test_should_raise_404_for_non_existent_task_in_dag(self): response = self.client.patch( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": "failed", }, ) assert response.status_code == 404 - assert response.json["title"] == "Task not found" + assert response.json()["title"] == "Task not found" assert ( - response.json["detail"] == "Task 'non_existent_task' not found in DAG 'example_python_operator'" + response.json()["detail"] == "Task 'non_existent_task' not found in DAG 'example_python_operator'" ) def test_should_raises_401_unauthenticated(self): @@ -2316,13 +2319,13 @@ def test_should_raises_401_unauthenticated(self): "new_state": "failed", }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) def test_should_raise_403_forbidden(self, username): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json={ "dry_run": True, "new_state": "failed", @@ -2333,7 +2336,7 @@ def test_should_raise_403_forbidden(self, username): def test_should_raise_404_not_found_dag(self): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "new_state": "failed", @@ -2344,7 +2347,7 @@ def test_should_raise_404_not_found_dag(self): def test_should_raise_404_not_found_task(self): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "new_state": "failed", @@ -2378,12 +2381,12 @@ def test_should_raise_400_for_invalid_task_instance_state(self, payload, expecte self.create_task_instances(session) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert response.json["detail"] == expected - assert response.json["detail"] == expected + assert response.json()["detail"] == expected + assert response.json()["detail"] == expected class TestSetTaskInstanceNote(TestTaskInstanceEndpoint): @@ -2401,10 +2404,10 @@ def test_should_respond_200(self, session): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -2452,6 +2455,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): setattr(ti, attr, getattr(old_ti, attr)) session.add(ti) session.commit() + session.close() # in each loop, we should get the right mapped TI back for map_index in (1, 2): @@ -2460,11 +2464,11 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" f"print_the_context/{map_index}/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -2502,15 +2506,16 @@ def test_should_respond_200_when_note_is_empty(self, session): ti.task_instance_note = None session.add(ti) session.commit() + session.close() new_note_value = "My super cool TaskInstance note." response = self.client.patch( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json["note"] == new_note_value + assert response.json()["note"] == new_note_value def test_should_raise_400_for_unknown_fields(self, session): self.create_task_instances(session) @@ -2518,10 +2523,10 @@ def test_should_raise_400_for_unknown_fields(self, session): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/setNote", json={"note": "a valid field", "not": "an unknown field"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "{'not': ['Unknown field.']}" + assert response.json()["detail"] == "{'not': ['Unknown field.']}" def test_should_raises_401_unauthenticated(self): for map_index in ["", "/0"]: @@ -2533,7 +2538,7 @@ def test_should_raises_401_unauthenticated(self): url, json={"note": "I am setting a note while being unauthenticated."}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): for map_index in ["", "/0"]: @@ -2541,7 +2546,7 @@ def test_should_raise_403_forbidden(self): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" f"print_the_context{map_index}/setNote", json={"note": "I am setting a note without the proper permissions."}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -2552,7 +2557,7 @@ def test_should_respond_404(self, session): f"api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" f"{map_index}/setNote", json={"note": "I am setting a note on a DAG that doesn't exist."}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -2570,10 +2575,10 @@ def test_should_respond_empty_non_scheduled(self, session): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/dependencies", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json == {"dependencies": []} + assert response.json() == {"dependencies": []} @pytest.mark.parametrize( "state, dependencies", @@ -2620,10 +2625,10 @@ def test_should_respond_dependencies(self, session, state, dependencies): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/dependencies", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json == dependencies + assert response.json() == dependencies def test_should_respond_dependencies_mapped(self, session): tis = self.create_task_instances( @@ -2638,7 +2643,7 @@ def test_should_respond_dependencies_mapped(self, session): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/0/dependencies", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text @@ -2651,16 +2656,16 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get( url, ) - assert_401(response) + assert response.status_code == 401 - def test_should_raise_403_forbidden(self): + def test_should_raise_404(self): for map_index in ["", "/0"]: response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" f"print_the_context{map_index}/dependencies", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 403 + assert response.status_code == 404 def test_should_respond_404(self, session): self.create_task_instances(session) @@ -2668,6 +2673,6 @@ def test_should_respond_404(self, session): response = self.client.get( f"api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" f"{map_index}/dependencies", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index 0e300b0a8f380..37cdbf42f3db9 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -23,7 +23,7 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Variable from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_variables from tests.test_utils.www import _check_last_log @@ -33,10 +33,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -47,7 +47,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_read_only", role_name="TestReadOnly", permissions=[ @@ -55,28 +55,28 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_delete_only", role_name="TestDeleteOnly", permissions=[ (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_read_only") # type: ignore - delete_user(app, username="test_delete_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_read_only") # type: ignore + delete_user(connexion_app.app, username="test_delete_only") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestVariableEndpoint: @pytest.fixture(autouse=True) def setup_method(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_variables() def teardown_method(self) -> None: @@ -87,22 +87,20 @@ class TestDeleteVariable(TestVariableEndpoint): def test_should_delete_variable(self, session): Variable.set("delete_var1", 1) # make sure variable is added - response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response = self.client.delete( - "/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 # make sure variable is deleted - response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 _check_last_log(session, dag_id=None, event="api.variable.delete", execution_date=None) def test_should_respond_404_if_key_does_not_exist(self): response = self.client.delete( - "/api/v1/variables/NONEXIST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/variables/NONEXIST_VARIABLE_KEY", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 @@ -111,17 +109,17 @@ def test_should_raises_401_unauthenticated(self): # make sure variable is added response = self.client.delete("/api/v1/variables/delete_var1") - assert_401(response) + assert response.status_code == 401 # make sure variable is not deleted - response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 def test_should_raise_403_forbidden(self): expected_value = '{"foo": 1}' Variable.set("TEST_VARIABLE_KEY", expected_value) response = self.client.get( - "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/api/v1/variables/TEST_VARIABLE_KEY", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -139,17 +137,17 @@ class TestGetVariable(TestVariableEndpoint): def test_read_variable(self, user, expected_status_code): expected_value = '{"foo": 1}' Variable.set("TEST_VARIABLE_KEY", expected_value) - response = self.client.get( - "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} - ) + response = self.client.get("/api/v1/variables/TEST_VARIABLE_KEY", headers={"REMOTE_USER": user}) assert response.status_code == expected_status_code if expected_status_code == 200: - assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None} + assert response.json() == { + "key": "TEST_VARIABLE_KEY", + "value": expected_value, + "description": None, + } def test_should_respond_404_if_not_found(self): - response = self.client.get( - "/api/v1/variables/NONEXIST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/variables/NONEXIST_VARIABLE_KEY", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 def test_should_raises_401_unauthenticated(self): @@ -157,17 +155,17 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/variables/TEST_VARIABLE_KEY") - assert_401(response) + assert response.status_code == 401 def test_should_handle_slashes_in_keys(self): expected_value = "hello" Variable.set("foo/bar", expected_value) response = self.client.get( f"/api/v1/variables/{urllib.parse.quote('foo/bar', safe='')}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == {"key": "foo/bar", "value": expected_value, "description": None} + assert response.json() == {"key": "foo/bar", "value": expected_value, "description": None} class TestGetVariables(TestVariableEndpoint): @@ -209,42 +207,40 @@ def test_should_get_list_variables(self, query, expected): Variable.set("var1", 1, "I am a variable") Variable.set("var2", "foo", "Another variable") Variable.set("var3", "[100, 101]") - response = self.client.get(query, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(query, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_should_respect_page_size_limit_default(self): for i in range(101): Variable.set(f"var{i}", i) - response = self.client.get("/api/v1/variables", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 101 - assert len(response.json["variables"]) == 100 + assert response.json()["total_entries"] == 101 + assert len(response.json()["variables"]) == 100 def test_should_raise_400_for_invalid_order_by(self): for i in range(101): Variable.set(f"var{i}", i) - response = self.client.get( - "/api/v1/variables?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/variables?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): for i in range(200): Variable.set(f"var{i}", i) - response = self.client.get("/api/v1/variables?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["variables"]) == 150 + assert len(response.json()["variables"]) == 150 def test_should_raises_401_unauthenticated(self): Variable.set("var1", 1) response = self.client.get("/api/v1/variables?limit=2&offset=0") - assert_401(response) + assert response.status_code == 401 class TestPatchVariable(TestVariableEndpoint): @@ -257,10 +253,10 @@ def test_should_update_variable(self, session): response = self.client.patch( "/api/v1/variables/var1", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == {"key": "var1", "value": "updated", "description": None} + assert response.json() == {"key": "var1", "value": "updated", "description": None} _check_last_log( session, dag_id=None, event="api.variable.edit", execution_date=None, expected_extra=payload ) @@ -270,10 +266,10 @@ def test_should_update_variable_with_mask(self, session): response = self.client.patch( "/api/v1/variables/var1?update_mask=description", json={"key": "var1", "value": "updated", "description": "after_update"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == {"key": "var1", "value": "foo", "description": "after_update"} + assert response.json() == {"key": "var1", "value": "foo", "description": "after_update"} _check_last_log(session, dag_id=None, event="api.variable.edit", execution_date=None) def test_should_reject_invalid_update(self): @@ -283,10 +279,10 @@ def test_should_reject_invalid_update(self): "key": "var1", "value": "foo", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "title": "Variable not found", "status": 404, "type": EXCEPTIONS_LINK_MAP[404], @@ -299,10 +295,10 @@ def test_should_reject_invalid_update(self): "key": "var2", "value": "updated", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "title": "Invalid post body", "status": 400, "type": EXCEPTIONS_LINK_MAP[400], @@ -314,9 +310,9 @@ def test_should_reject_invalid_update(self): json={ "key": "var2", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.json == { + assert response.json() == { "title": "Invalid Variable schema", "status": 400, "type": EXCEPTIONS_LINK_MAP[400], @@ -334,7 +330,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 class TestPostVariables(TestVariableEndpoint): @@ -353,14 +349,14 @@ def test_should_create_variable(self, description, session): response = self.client.post( "/api/v1/variables", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 _check_last_log( session, dag_id=None, event="api.variable.create", execution_date=None, expected_extra=payload ) - response = self.client.get("/api/v1/variables/var_create", environ_overrides={"REMOTE_USER": "test"}) - assert response.json == { + response = self.client.get("/api/v1/variables/var_create", headers={"REMOTE_USER": "test"}) + assert response.json() == { "key": "var_create", "value": "{}", "description": description, @@ -372,7 +368,7 @@ def test_should_create_masked_variable(self, session): response = self.client.post( "/api/v1/variables", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 expected_extra = { @@ -386,8 +382,8 @@ def test_should_create_masked_variable(self, session): execution_date=None, expected_extra=expected_extra, ) - response = self.client.get("/api/v1/variables/api_key", environ_overrides={"REMOTE_USER": "test"}) - assert response.json == payload + response = self.client.get("/api/v1/variables/api_key", headers={"REMOTE_USER": "test"}) + assert response.json() == payload def test_should_reject_invalid_request(self, session): response = self.client.post( @@ -396,10 +392,10 @@ def test_should_reject_invalid_request(self, session): "key": "var_create", "v": "{}", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "title": "Invalid Variable schema", "status": 400, "type": EXCEPTIONS_LINK_MAP[400], @@ -416,4 +412,4 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 diff --git a/tests/api_connexion/endpoints/test_version_endpoint.py b/tests/api_connexion/endpoints/test_version_endpoint.py index 6c21985a73584..f966347e04a8a 100644 --- a/tests/api_connexion/endpoints/test_version_endpoint.py +++ b/tests/api_connexion/endpoints/test_version_endpoint.py @@ -29,8 +29,8 @@ def setup_attrs(self, minimal_app_for_api) -> None: """ Setup For XCom endpoint TC """ - self.app = minimal_app_for_api - self.client = self.app.test_client() # type:ignore + self.connexion_app = minimal_app_for_api + self.client = self.connexion_app.test_client() # type:ignore @mock.patch("airflow.api_connexion.endpoints.version_endpoint.airflow.__version__", "MOCK_VERSION") @mock.patch( @@ -40,5 +40,5 @@ def test_should_respond_200(self, mock_get_airflow_get_commit): response = self.client.get("/api/v1/version") assert 200 == response.status_code - assert {"git_version": "GIT_COMMIT", "version": "MOCK_VERSION"} == response.json + assert {"git_version": "GIT_COMMIT", "version": "MOCK_VERSION"} == response.json() mock_get_airflow_get_commit.assert_called_once_with() diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 1e4dbb56780cf..d0727b5292c1a 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -31,7 +31,7 @@ from airflow.utils.session import create_session from airflow.utils.timezone import utcnow from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom @@ -49,10 +49,10 @@ def orm_deserialize_value(self): @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -61,23 +61,23 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_granular_permissions", role_name="TestGranularDag", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), ], ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "test-dag-id-1", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore def _compare_xcom_collections(collection1: dict, collection_2: dict): @@ -109,8 +109,8 @@ def setup_attrs(self, configured_app) -> None: """ Setup For XCom endpoint TC """ - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore # clear existing xcoms self.clean_db() @@ -132,11 +132,11 @@ def test_should_respond_200(self): self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - current_data = response.json + current_data = response.json() current_data["timestamp"] = "TIMESTAMP" assert current_data == { "dag_id": dag_id, @@ -158,10 +158,10 @@ def test_should_raise_404_for_non_existent_xcom(self): self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key) response = self.client.get( f"/api/v1/dags/nonexistentdagid/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 404 == response.status_code - assert response.json["title"] == "XCom entry not found" + assert response.json()["title"] == "XCom entry not found" def test_should_raises_401_unauthenticated(self): dag_id = "test-dag-id" @@ -175,7 +175,7 @@ def test_should_raises_401_unauthenticated(self): f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}" ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dag_id = "test-dag-id" @@ -188,7 +188,7 @@ def test_should_raise_403_forbidden(self): self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -262,13 +262,13 @@ def test_custom_xcom_deserialize(self, allowed: bool, query: str, expected_statu url = f"/api/v1/dags/dag/dagRuns/run/taskInstances/task/xcomEntries/key{query}" with mock.patch("airflow.api_connexion.endpoints.xcom_endpoint.XCom", XCom): with conf_vars({("api", "enable_xcom_deserialize_support"): str(allowed)}): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) if isinstance(expected_status_or_value, int): assert response.status_code == expected_status_or_value else: assert response.status_code == 200 - assert response.json["value"] == expected_status_or_value + assert response.json()["value"] == expected_status_or_value class TestGetXComEntries(TestXComEndpoint): @@ -282,11 +282,11 @@ def test_should_respond_200(self): self._create_xcom_entries(dag_id, run_id, execution_date_parsed, task_id) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" _compare_xcom_collections( @@ -329,11 +329,11 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self): response = self.client.get( "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" _compare_xcom_collections( @@ -392,11 +392,11 @@ def test_should_respond_200_with_tilde_and_granular_dag_access(self): self._create_invalid_xcom_entries(execution_date_parsed) response = self.client.get( "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" _compare_xcom_collections( @@ -436,11 +436,11 @@ def assert_expected_result(expected_entries, map_index=None): response = self.client.get( "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries" f"{('?map_index=' + str(map_index)) if map_index is not None else ''}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" assert response_data == { @@ -479,11 +479,11 @@ def test_should_respond_200_with_xcom_key(self): def assert_expected_result(expected_entries, key=None): response = self.client.get( f"/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries?xcom_key={key}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" assert response_data == { @@ -522,7 +522,7 @@ def test_should_raises_401_unauthenticated(self): f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries" ) - assert_401(response) + assert response.status_code == 401 def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id, mapped_ti=False): with create_session() as session: @@ -683,8 +683,8 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids): ) session.add(xcom) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - conn_ids = [conn["key"] for conn in response.json["xcom_entries"] if conn] + assert response.json()["total_entries"] == 10 + conn_ids = [conn["key"] for conn in response.json()["xcom_entries"] if conn] assert conn_ids == expected_xcom_ids diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py index ce187868c78f3..6e7ba21dae63f 100644 --- a/tests/api_connexion/schemas/test_dag_run_schema.py +++ b/tests/api_connexion/schemas/test_dag_run_schema.py @@ -129,7 +129,7 @@ def test_invalid_execution_date_raises(self): serialized_dagrun = {"execution_date": "mydate"} with pytest.raises(BadRequest) as ctx: dagrun_schema.load(serialized_dagrun) - assert str(ctx.value) == "Incorrect datetime argument" + assert str(ctx.value) == "400: Invalid date string: mydate" class TestDagRunCollection(TestDAGRunBase): diff --git a/tests/api_connexion/schemas/test_role_and_permission_schema.py b/tests/api_connexion/schemas/test_role_and_permission_schema.py index a8a4924216838..26cd87c976786 100644 --- a/tests/api_connexion/schemas/test_role_and_permission_schema.py +++ b/tests/api_connexion/schemas/test_role_and_permission_schema.py @@ -33,17 +33,17 @@ class TestRoleCollectionItemSchema: @pytest.fixture(scope="class") def role(self, minimal_app_for_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_api.app, # type: ignore name="Test", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test") + delete_role(minimal_app_for_api.app, "Test") @pytest.fixture(autouse=True) def _set_attrs(self, minimal_app_for_api, role): - self.app = minimal_app_for_api + self.connexion_app = minimal_app_for_api self.role = role def test_serialize(self): @@ -69,24 +69,24 @@ class TestRoleCollectionSchema: @pytest.fixture(scope="class") def role1(self, minimal_app_for_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_api.app, # type: ignore name="Test1", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test1") + delete_role(minimal_app_for_api.app, "Test1") @pytest.fixture(scope="class") def role2(self, minimal_app_for_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_api.app, # type: ignore name="Test2", permissions=[ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), ], ) - delete_role(minimal_app_for_api, "Test2") + delete_role(minimal_app_for_api.app, "Test2") def test_serialize(self, role1, role2): instance = RoleCollection([role1, role2], total_entries=2) diff --git a/tests/api_connexion/test_auth.py b/tests/api_connexion/test_auth.py index 869b69990f00c..2ec6187c4cc0c 100644 --- a/tests/api_connexion/test_auth.py +++ b/tests/api_connexion/test_auth.py @@ -19,7 +19,6 @@ from base64 import b64encode import pytest -from flask_login import current_user from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars @@ -32,9 +31,10 @@ class BaseTestAuth: @pytest.fixture(autouse=True) def set_attrs(self, minimal_app_for_api): - self.app = minimal_app_for_api + self.connexion_app = minimal_app_for_api + self.flask_app = self.connexion_app.app - sm = self.app.appbuilder.sm + sm = self.flask_app.appbuilder.sm tester = sm.find_user(username="test") if not tester: role_admin = sm.find_role("Admin") @@ -53,25 +53,28 @@ class TestBasicAuth(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: - with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.basic_auth"}): - init_api_experimental_auth(minimal_app_for_api) + with conf_vars( + {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} + ): + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_success(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert current_user.email == "test@fab.org" + # assert current_user.email == "test@fab.org" assert response.status_code == 200 - assert response.json == { + assert response.json() == { "pools": [ { "name": "default_pool", @@ -103,7 +106,7 @@ def test_success(self): ], ) def test_malformed_headers(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" @@ -120,7 +123,7 @@ def test_malformed_headers(self, token): ], ) def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" @@ -133,22 +136,23 @@ class TestSessionAuth(BaseTestAuth): def with_session_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.session"}): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_success(self): clear_db_pools() - admin_user = client_with_login(self.app, username="test", password="test") + admin_user = client_with_login(self.connexion_app, username="test", password="test") response = admin_user.get("/api/v1/pools") assert response.status_code == 200 - assert response.json == { + assert response.json() == { "pools": [ { "name": "default_pool", @@ -167,7 +171,7 @@ def test_success(self): } def test_failure(self): - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools") assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" @@ -179,7 +183,8 @@ class TestSessionWithBasicAuthFallback(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars( @@ -187,29 +192,29 @@ def with_basic_auth_backend(self, minimal_app_for_api): ( "api", "auth_backends", - ): "airflow.api.auth.backend.session,airflow.api.auth.backend.basic_auth" + ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" } ): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_basic_auth_fallback(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() # request uses session - admin_user = client_with_login(self.app, username="test", password="test") + admin_user = client_with_login(self.connexion_app, username="test", password="test") response = admin_user.get("/api/v1/pools") assert response.status_code == 200 # request uses basic auth - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 200 # request without session or basic auth header - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools") assert response.status_code == 401 diff --git a/tests/api_connexion/test_cors.py b/tests/api_connexion/test_cors.py index 4dc4950df9946..fb60eebb44e7b 100644 --- a/tests/api_connexion/test_cors.py +++ b/tests/api_connexion/test_cors.py @@ -28,10 +28,12 @@ class BaseTestAuth: @pytest.fixture(autouse=True) - def set_attrs(self, minimal_app_for_api): - self.app = minimal_app_for_api + def set_attrs(self, minimal_app_for_api, minimal_app_for_api_cors_allow_all): + self.connexion_app = minimal_app_for_api + self.connexion_app_cors_allow_all = minimal_app_for_api_cors_allow_all + self.flask_app = self.connexion_app.app - sm = self.app.appbuilder.sm + sm = self.flask_app.appbuilder.sm tester = sm.find_user(username="test") if not tester: role_admin = sm.find_role("Admin") @@ -50,20 +52,21 @@ class TestEmptyCors(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.basic_auth"}): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_empty_cors_headers(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 200 assert "Access-Control-Allow-Headers" not in response.headers @@ -76,29 +79,25 @@ class TestCorsOrigin(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars( { ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth", - ("api", "access_control_allow_origins"): "http://apache.org http://example.com", } ): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 200 - assert response.headers["Access-Control-Allow-Origin"] == "http://apache.org" - + with self.connexion_app.test_client() as test_client: response = test_client.get( "/api/v1/pools", headers={"Authorization": token, "Origin": "http://apache.org"} ) @@ -109,33 +108,35 @@ def test_cors_origin_reflection(self): "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} ) assert response.status_code == 200 + assert response.headers["Access-Control-Allow-Origin"] == "http://example.com" class TestCorsWildcard(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_api_cors_allow_all): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + self.connexion_app = minimal_app_for_api_cors_allow_all + flask_app = minimal_app_for_api_cors_allow_all.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars( { ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth", - ("api", "access_control_allow_origins"): "*", } ): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app_cors_allow_all.test_client() as test_client: response = test_client.get( "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} ) diff --git a/tests/api_connexion/test_error_handling.py b/tests/api_connexion/test_error_handling.py index d89515d05b68f..59a056bed1e88 100644 --- a/tests/api_connexion/test_error_handling.py +++ b/tests/api_connexion/test_error_handling.py @@ -31,8 +31,8 @@ def test_incorrect_endpoint_should_return_json(minimal_app_for_api): # Then we have parsable JSON as output - assert "Not Found" == resp.json["title"] - assert 404 == resp.json["status"] + assert "Not Found" == resp.json()["title"] + assert 404 == resp.json()["status"] assert 404 == resp.status_code @@ -45,8 +45,7 @@ def test_incorrect_endpoint_should_return_html(minimal_app_for_api): # Then we do not have JSON as response, rather standard HTML - assert resp.json is None - assert resp.mimetype == "text/html" + assert resp.headers["content-type"].startswith("text/html") assert resp.status_code == 404 @@ -60,8 +59,8 @@ def test_incorrect_method_should_return_json(minimal_app_for_api): # Then we have parsable JSON as output - assert "Method Not Allowed" == resp.json["title"] - assert 405 == resp.json["status"] + assert "Method Not Allowed" == resp.json()["title"] + assert 405 == resp.json()["status"] assert 405 == resp.status_code @@ -74,6 +73,5 @@ def test_incorrect_method_should_return_html(minimal_app_for_api): # Then we do not have JSON as response, rather standard HTML - assert resp.json is None - assert resp.mimetype == "text/html" + assert resp.headers["content-type"].startswith("text/html") assert resp.status_code == 405 diff --git a/tests/api_connexion/test_security.py b/tests/api_connexion/test_security.py index e75eba53e40f4..d0fa1988caaba 100644 --- a/tests/api_connexion/test_security.py +++ b/tests/api_connexion/test_security.py @@ -20,35 +20,37 @@ from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + flask_app = minimal_app_for_api.app create_user( - app, # type:ignore + flask_app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore ) - yield minimal_app_for_api + with conf_vars({("webserver", "expose_config"): "True"}): + yield minimal_app_for_api - delete_user(app, username="test") # type: ignore + delete_user(flask_app, username="test") # type: ignore class TestSession: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore def test_session_not_created_on_api_request(self): - self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) - assert all(cookie.name != "session" for cookie in self.client.cookie_jar) + self.client.get("/api/v1/dags", headers={"REMOTE_USER": "test"}) + assert all(cookie.name != "session" for cookie in self.client.cookies) def test_session_not_created_on_health_endpoint_request(self): self.client.get("health") - assert all(cookie.name != "session" for cookie in self.client.cookie_jar) + assert all(cookie.name != "session" for cookie in self.client.cookies) diff --git a/tests/api_experimental/auth/backend/test_basic_auth.py b/tests/api_experimental/auth/backend/test_basic_auth.py index a7f7a2a1cd8a8..ce1167e9dc317 100644 --- a/tests/api_experimental/auth/backend/test_basic_auth.py +++ b/tests/api_experimental/auth/backend/test_basic_auth.py @@ -30,9 +30,9 @@ class TestBasicAuth: @pytest.fixture(autouse=True) def set_attrs(self, minimal_app_for_experimental_api): - self.app = minimal_app_for_experimental_api + self.connexion_app = minimal_app_for_experimental_api - self.appbuilder = self.app.appbuilder + self.appbuilder = self.connexion_app.app.appbuilder role_admin = self.appbuilder.sm.find_role("Admin") tester = self.appbuilder.sm.find_user(username="test") if not tester: @@ -49,7 +49,7 @@ def test_success(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: with pytest.warns(RemovedInAirflow3Warning, match=r"Use Pool.get_pools\(\) instead"): # Experimental client itself deprecated, no reason to change to actual methods # It should be removed in the same time: Airflow 3.0 @@ -72,7 +72,7 @@ def test_success(self): ], ) def test_malformed_headers(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Basic" @@ -87,14 +87,14 @@ def test_malformed_headers(self, token): ], ) def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Basic" @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_experimental_api(self): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": "Basic"}) assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Basic" diff --git a/tests/api_experimental/conftest.py b/tests/api_experimental/conftest.py index f4cb79cbf8c74..e3e69abf7485a 100644 --- a/tests/api_experimental/conftest.py +++ b/tests/api_experimental/conftest.py @@ -38,6 +38,6 @@ def minimal_app_for_experimental_api(): def factory(): # Make sure we don't issue a warning in the test summary about deprecation with pytest.deprecated_call(): - return app.create_app(testing=True) # type:ignore + return app.create_connexion_app(testing=True) # type:ignore yield factory() diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py index 4c312da3a708d..3144c5e094a11 100644 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -57,7 +57,7 @@ def minimal_app_for_internal_api() -> Flask: ) def factory() -> Flask: with conf_vars({("webserver", "run_internal_api"): "true"}): - return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + return app.create_connexion_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore return factory() @@ -70,8 +70,8 @@ def equals(a, b) -> bool: class TestRpcApiEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: - self.app = minimal_app_for_internal_api - self.client = self.app.test_client() # type:ignore + self.connexion_app = minimal_app_for_internal_api + self.client = self.connexion_app.test_client() # type:ignore mock_test_method.reset_mock() mock_test_method.side_effect = None with mock.patch( @@ -85,7 +85,7 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: @pytest.mark.parametrize( "input_params, method_result, result_cmp_func, method_params", [ - ({}, None, lambda got, _: got == b"", {}), + ({}, None, lambda got, _: got == "", {}), ({}, "test_me", equals, {}), ( BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"}), @@ -123,9 +123,9 @@ def test_method(self, input_params, method_result, result_cmp_func, method_param ) assert response.status_code == 200 if method_result: - response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True) + response_data = BaseSerialization.deserialize(json.loads(response.text), use_pydantic_models=True) else: - response_data = response.data + response_data = response.text assert result_cmp_func(response_data, method_result) @@ -139,7 +139,7 @@ def test_method_with_exception(self): "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) assert response.status_code == 500 - assert response.data, b"Error executing method: test_method." + assert response.text, b"Error executing method: test_method." mock_test_method.assert_called_once() def test_unknown_method(self): @@ -149,7 +149,7 @@ def test_unknown_method(self): "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) assert response.status_code == 400 - assert response.data.startswith(b"Unrecognized method: i-bet-it-does-not-exist.") + assert response.text.startswith("Unrecognized method: i-bet-it-does-not-exist.") mock_test_method.assert_not_called() def test_invalid_jsonrpc(self): @@ -159,5 +159,5 @@ def test_invalid_jsonrpc(self): "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) assert response.status_code == 400 - assert response.data.startswith(b"Expected jsonrpc 2.0 request.") + assert response.text.startswith("Expected jsonrpc 2.0 request.") mock_test_method.assert_not_called() diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 7628924ad688f..d3a9b5bda0e8a 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -125,9 +125,6 @@ class TestBaseAuthManager: def test_get_cli_commands_return_empty_list(self, auth_manager): assert auth_manager.get_cli_commands() == [] - def test_get_api_endpoints_return_none(self, auth_manager): - assert auth_manager.get_api_endpoints() is None - def test_get_user_name(self, auth_manager): user = Mock() user.get_name.return_value = "test_username" diff --git a/tests/cli/commands/test_internal_api_command.py b/tests/cli/commands/test_internal_api_command.py index 9de857588a3fc..3d8eff83de7c5 100644 --- a/tests/cli/commands/test_internal_api_command.py +++ b/tests/cli/commands/test_internal_api_command.py @@ -152,7 +152,7 @@ def test_cli_internal_api_background(self, tmp_path): def test_cli_internal_api_debug(self, app): with mock.patch( - "airflow.cli.commands.internal_api_command.create_app", return_value=app + "airflow.cli.commands.internal_api_command.create_connexion_app", return_value=app ), mock.patch.object(app, "run") as app_run: args = self.parser.parse_args( [ @@ -163,8 +163,7 @@ def test_cli_internal_api_debug(self, app): internal_api_command.internal_api(args) app_run.assert_called_with( - debug=True, - use_reloader=False, + log_level="debug", port=9080, host="0.0.0.0", ) @@ -192,7 +191,7 @@ def test_cli_internal_api_args(self): "--workers", "4", "--worker-class", - "sync", + "uvicorn.workers.UvicornWorker", "--timeout", "120", "--bind", @@ -209,7 +208,7 @@ def test_cli_internal_api_args(self): "python:airflow.api_internal.gunicorn_config", "--access-logformat", "custom_log_format", - "airflow.cli.commands.internal_api_command:cached_app()", + "airflow.cli.commands.internal_api_command:cached_connexion_app()", "--preload", ], close_fds=True, diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py index 07d95a9e5f75a..16415b2cec07e 100644 --- a/tests/cli/commands/test_webserver_command.py +++ b/tests/cli/commands/test_webserver_command.py @@ -312,7 +312,7 @@ def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _): assert ctx.value.code == 1 def test_cli_webserver_debug(self, app): - with mock.patch("airflow.www.app.create_app", return_value=app), mock.patch.object( + with mock.patch("airflow.www.app.create_connexion_app", return_value=app), mock.patch.object( app, "run" ) as app_run: args = self.parser.parse_args( @@ -324,11 +324,11 @@ def test_cli_webserver_debug(self, app): webserver_command.webserver(args) app_run.assert_called_with( - debug=True, - use_reloader=False, + log_level="debug", port=8080, host="0.0.0.0", - ssl_context=None, + ssl_certfile=None, + ssl_keyfile=None, ) def test_cli_webserver_args(self): @@ -352,7 +352,7 @@ def test_cli_webserver_args(self): "--workers", "4", "--worker-class", - "sync", + "uvicorn.workers.UvicornWorker", "--timeout", "120", "--bind", @@ -369,7 +369,7 @@ def test_cli_webserver_args(self): "-", "--access-logformat", "custom_log_format", - "airflow.www.app:cached_app()", + "airflow.www.app:cached_connexion_app()", "--preload", ], close_fds=True, diff --git a/tests/conftest.py b/tests/conftest.py index 9027391575e4c..9db21ec4c7070 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,7 +99,7 @@ os.environ["_AIRFLOW_RUN_DB_TESTS_ONLY"] = "true" AIRFLOW_TESTS_DIR = Path(os.path.dirname(os.path.realpath(__file__))).resolve() -AIRFLOW_SOURCES_ROOT_DIR = AIRFLOW_TESTS_DIR.parent.parent +AIRFLOW_SOURCES_ROOT_DIR = AIRFLOW_TESTS_DIR.parent os.environ["AIRFLOW__CORE__PLUGINS_FOLDER"] = os.fspath(AIRFLOW_TESTS_DIR / "plugins") os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = os.fspath(AIRFLOW_TESTS_DIR / "dags") @@ -742,7 +742,7 @@ def app(): with conf_vars({("fab", "auth_rate_limited"): "False"}): from airflow.www import app - yield app.create_app(testing=True) + yield app.create_connexion_app(testing=True) @pytest.fixture @@ -1151,20 +1151,21 @@ def _get(dag_id): @pytest.fixture def create_log_template(request): - from airflow import settings from airflow.models.tasklog import LogTemplate - session = settings.Session() - def _create_log_template(filename_template, elasticsearch_id=""): - log_template = LogTemplate(filename=filename_template, elasticsearch_id=elasticsearch_id) - session.add(log_template) - session.commit() + from airflow.utils.session import create_session - def _delete_log_template(): - session.delete(log_template) + with create_session() as session: + log_template = LogTemplate(filename=filename_template, elasticsearch_id=elasticsearch_id) + session.add(log_template) session.commit() + def _delete_log_template(): + with create_session() as session: + session.delete(log_template) + session.commit() + request.addfinalizer(_delete_log_template) return _create_log_template @@ -1277,6 +1278,16 @@ def initialize_providers_manager(): ProvidersManager().initialize_providers_configuration() +@pytest.fixture(autouse=True) +def create_swagger_ui_dir_if_missing(): + """ + The directory needs to exist to satisfy starlette attempting to register it as middleware + :return: + """ + swagger_ui_dir = AIRFLOW_SOURCES_ROOT_DIR / "airflow" / "www" / "static" / "dist" / "swagger-ui" + swagger_ui_dir.mkdir(exist_ok=True, parents=True) + + @pytest.fixture(autouse=True) def close_all_sqlalchemy_sessions(): from sqlalchemy.orm import close_all_sessions diff --git a/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py b/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py index 030a26ea77f9d..3365dea0f0572 100644 --- a/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py +++ b/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py @@ -43,7 +43,7 @@ def app_for_kerberos(): ("api", "enable_experimental_api"): "true", } ): - yield app.create_app(testing=True) + yield app.create_connexion_app(testing=True) @pytest.fixture(scope="module") @@ -57,16 +57,16 @@ def dagbag_to_db(): class TestApiKerberos: @pytest.fixture(autouse=True) def _set_attrs(self, app_for_kerberos, dagbag_to_db): - self.app = app_for_kerberos + self.connexion_app = app_for_kerberos def test_trigger_dag(self): - with self.app.test_client() as client: + with self.connexion_app.app.test_client() as client: url_template = "/api/experimental/dags/{}/dag_runs" url_path = url_template.format("example_bash_operator") response = client.post( url_path, data=json.dumps(dict(run_id="my_run" + datetime.now().isoformat())), - content_type="application/json", + headers={"Content-Type": "application/json"}, ) assert 401 == response.status_code @@ -89,21 +89,22 @@ class Request: CLIENT_AUTH.handle_response(response) assert "Authorization" in response.request.headers + headers = response.request.headers + headers.update({"Content-Type": "application/json"}) response2 = client.post( url_template.format("example_bash_operator"), data=json.dumps(dict(run_id="my_run" + datetime.now().isoformat())), - content_type="application/json", - headers=response.request.headers, + headers=headers, ) assert 200 == response2.status_code def test_unauthorized(self): - with self.app.test_client() as client: + with self.connexion_app.app.test_client() as client: url_template = "/api/experimental/dags/{}/dag_runs" response = client.post( url_template.format("example_bash_operator"), data=json.dumps(dict(run_id="my_run" + datetime.now().isoformat())), - content_type="application/json", + headers={"Content-Type": "application/json"}, ) assert 401 == response.status_code diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index b60a9de17999d..36df0c2b38fff 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -77,8 +77,8 @@ def wrapper(*args, **kwargs): class TestPluginsRBAC: @pytest.fixture(autouse=True) def _set_attrs(self, app): - self.app = app - self.appbuilder = app.appbuilder + self.connexion_app = app + self.appbuilder = app.app.appbuilder def test_flaskappbuilder_views(self): from tests.plugins.test_plugin import v_appbuilder_package @@ -137,12 +137,15 @@ def test_app_blueprints(self): from tests.plugins.test_plugin import bp # Blueprint should be present in the app - assert "test_plugin" in self.app.blueprints - assert self.app.blueprints["test_plugin"].name == bp.name + assert "test_plugin" in self.connexion_app.app.blueprints + assert self.connexion_app.app.blueprints["test_plugin"].name == bp.name def test_app_static_folder(self): # Blueprint static folder should be properly set - assert AIRFLOW_SOURCES_ROOT / "airflow" / "www" / "static" == Path(self.app.static_folder).resolve() + assert ( + AIRFLOW_SOURCES_ROOT / "airflow" / "www" / "static" + == Path(self.connexion_app.app.static_folder).resolve() + ) @pytest.mark.db_test @@ -155,7 +158,7 @@ class AirflowNoMenuViewsPlugin(AirflowPlugin): appbuilder_class_name = str(v_nomenu_appbuilder_package["view"].__class__.__name__) with mock_plugin_manager(plugins=[AirflowNoMenuViewsPlugin()]): - appbuilder = application.create_app(testing=True).appbuilder + appbuilder = application.create_connexion_app(testing=True).app.appbuilder plugin_views = [view for view in appbuilder.baseviews if view.blueprint.name == appbuilder_class_name] diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py index 03684e6336384..8bcfbece91ff7 100644 --- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -23,7 +23,10 @@ from flask import Flask, session from flask_appbuilder.menu import MenuItem -from tests.test_utils.compat import AIRFLOW_V_2_8_PLUS +from tests.test_utils.compat import AIRFLOW_V_2_8_PLUS, AIRFLOW_V_2_10_PLUS + +pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+") + try: from airflow.auth.managers.models.resource_details import ( @@ -153,7 +156,7 @@ def client_admin(): "email": ["email"], } mock_init_saml_auth.return_value = auth - yield application.create_app(testing=True) + yield application.create_connexion_app(testing=True) class TestAwsAuthManager: @@ -165,7 +168,7 @@ def test_avp_facade(self, auth_manager): def test_get_user(self, mock_is_logged_in, auth_manager, app, test_user): mock_is_logged_in.return_value = True - with app.test_request_context(): + with app.app.test_request_context(): session["aws_user"] = test_user result = auth_manager.get_user() @@ -180,7 +183,7 @@ def test_get_user_return_none_when_not_logged_in(self, mock_is_logged_in, auth_m @pytest.mark.db_test def test_is_logged_in(self, auth_manager, app, test_user): - with app.test_request_context(): + with app.app.test_request_context(): session["aws_user"] = test_user result = auth_manager.is_logged_in() @@ -188,7 +191,7 @@ def test_is_logged_in(self, auth_manager, app, test_user): @pytest.mark.db_test def test_is_logged_in_return_false_when_no_user_in_session(self, auth_manager, app, test_user): - with app.test_request_context(): + with app.app.test_request_context(): result = auth_manager.is_logged_in() assert result is False diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py b/tests/providers/amazon/aws/auth_manager/views/test_auth.py index 7474d74727fd7..7a3f37966465a 100644 --- a/tests/providers/amazon/aws/auth_manager/views/test_auth.py +++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py @@ -23,12 +23,12 @@ from airflow.exceptions import AirflowException from airflow.www import app as application -from tests.test_utils.compat import AIRFLOW_V_2_8_PLUS +from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS from tests.test_utils.config import conf_vars pytest.importorskip("onelogin") -pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_8_PLUS, reason="Test requires Airflow 2.8+") +pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+") SAML_METADATA_URL = "/saml/metadata" @@ -68,25 +68,25 @@ def aws_app(): ) as mock_is_policy_store_schema_up_to_date: mock_is_policy_store_schema_up_to_date.return_value = True mock_parser.parse_remote.return_value = SAML_METADATA_PARSED - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.mark.db_test class TestAwsAuthManagerAuthenticationViews: def test_login_redirect_to_identity_center(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: response = client.get("/login") assert response.status_code == 302 assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/assertion/") def test_logout_redirect_to_identity_center(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: response = client.get("/logout") assert response.status_code == 302 assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/logout/") def test_login_metadata_return_xml_file(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: response = client.get("/login_metadata") assert response.status_code == 200 assert response.headers["Content-Type"] == "text/xml" @@ -120,8 +120,8 @@ def test_login_callback_set_user_in_session(self): "email": ["email"], } mock_init_saml_auth.return_value = auth - app = application.create_app(testing=True) - with app.test_client() as client: + connexion_app = application.create_connexion_app(testing=True) + with connexion_app.app.test_client() as client: response = client.get("/login_callback") assert response.status_code == 302 assert response.location == url_for("Airflow.index") @@ -152,12 +152,12 @@ def test_login_callback_raise_exception_if_errors(self): auth = Mock() auth.is_authenticated.return_value = False mock_init_saml_auth.return_value = auth - app = application.create_app(testing=True) - with app.test_client() as client: + connexion_app = application.create_connexion_app(testing=True) + with connexion_app.app.test_client() as client: with pytest.raises(AirflowException): client.get("/login_callback") def test_logout_callback_raise_not_implemented_error(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: with pytest.raises(NotImplementedError): client.get("/logout_callback") diff --git a/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py b/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py index 1f64b3181576d..34ac2a4975596 100644 --- a/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py +++ b/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py @@ -33,7 +33,7 @@ @pytest.fixture def app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture @@ -70,7 +70,7 @@ def setup_method(self) -> None: mock_call.reset_mock() def test_requires_authentication_with_no_header(self, app): - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: mock_context.request.authorization = None result = function_decorated() @@ -87,7 +87,7 @@ def test_requires_authentication_with_ldap( user = Mock() mock_sm.auth_user_ldap.return_value = user - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: mock_context.request.authorization = mock_authorization function_decorated() @@ -106,7 +106,7 @@ def test_requires_authentication_with_db( user = Mock() mock_sm.auth_user_db.return_value = user - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: mock_context.request.authorization = mock_authorization function_decorated() diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index a91a434412d9f..2a3c4c06ef555 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -28,7 +28,6 @@ from airflow.security import permissions from tests.test_utils.api_connexion_utils import ( - assert_401, create_role, create_user, delete_role, @@ -40,9 +39,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api + connexion_app = minimal_app_for_auth_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -53,58 +52,55 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestRoleEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self): """ Delete all roles except these ones. Test and TestNoPermissions are deleted by delete_user above """ - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) roles = session.query(Role).filter(~Role.name.in_(existing_roles)).all() for role in roles: - delete_role(self.app, role.name) + delete_role(self.flask_app, role.name) class TestGetRoleEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/roles/Admin", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["name"] == "Admin" + assert response.json()["name"] == "Admin" def test_should_respond_404(self): - response = self.client.get( - "/auth/fab/v1/roles/invalid-role", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/roles/invalid-role", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "Role with name 'invalid-role' was not found", "status": 404, "title": "Role not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/roles/Admin") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/roles/Admin", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -119,30 +115,26 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestGetRolesEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/roles", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/roles", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) - assert response.json["total_entries"] == len(existing_roles) - roles = {role["name"] for role in response.json["roles"]} + assert response.json()["total_entries"] == len(existing_roles) + roles = {role["name"] for role in response.json()["roles"]} assert roles == existing_roles def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/roles") - assert_401(response) + assert response.status_code == 401 def test_should_raises_400_for_invalid_order_by(self): - response = self.client.get( - "/auth/fab/v1/roles?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/roles?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/roles", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/roles", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -178,33 +170,31 @@ class TestGetRolesEndpointPaginationandFilter(TestRoleEndpoint): ], ) def test_can_handle_limit_and_offset(self, url, expected_roles): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) - assert response.json["total_entries"] == len(existing_roles) - roles = [role["name"] for role in response.json["roles"] if role] + assert response.json()["total_entries"] == len(existing_roles) + roles = [role["name"] for role in response.json()["roles"] if role] assert roles == expected_roles class TestGetPermissionsEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test"}) - actions = {i[0] for i in self.app.appbuilder.sm.get_all_permissions() if i} + response = self.client.get("/auth/fab/v1/permissions", headers={"REMOTE_USER": "test"}) + actions = {i[0] for i in self.flask_app.appbuilder.sm.get_all_permissions() if i} assert response.status_code == 200 - assert response.json["total_entries"] == len(actions) - returned_actions = {perm["name"] for perm in response.json["actions"]} + assert response.json()["total_entries"] == len(actions) + returned_actions = {perm["name"] for perm in response.json()["actions"]} assert actions == returned_actions def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/permissions") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/permissions", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -223,11 +213,9 @@ def test_post_should_respond_200(self): "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/auth/fab/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - role = self.app.appbuilder.sm.find_role("Test2") + role = self.flask_app.appbuilder.sm.find_role("Test2") assert role is not None @pytest.mark.parametrize( @@ -296,11 +284,9 @@ def test_post_should_respond_200(self): ], ) def test_post_should_respond_400_for_invalid_payload(self, payload, error_message): - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/auth/fab/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -312,11 +298,9 @@ def test_post_should_respond_409_already_exist(self): "name": "Test", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/auth/fab/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 409 - assert response.json == { + assert response.json() == { "detail": "Role with name 'Test' already exists; please update with the PATCH endpoint", "status": 409, "title": "Conflict", @@ -332,7 +316,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.post( @@ -341,7 +325,7 @@ def test_should_raise_403_forbidden(self): "name": "mytest2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -361,20 +345,16 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestDeleteRole(TestRoleEndpoint): def test_delete_should_respond_204(self, session): - role = create_role(self.app, "mytestrole") - response = self.client.delete( - f"/auth/fab/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"} - ) + role = create_role(self.flask_app, "mytestrole") + response = self.client.delete(f"/auth/fab/v1/roles/{role.name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 role_obj = session.query(Role).filter(Role.name == role.name).all() assert len(role_obj) == 0 def test_delete_should_respond_404(self): - response = self.client.delete( - "/auth/fab/v1/roles/invalidrolename", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/auth/fab/v1/roles/invalidrolename", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "Role with name 'invalidrolename' was not found", "status": 404, "title": "Role not found", @@ -384,11 +364,11 @@ def test_delete_should_respond_404(self): def test_should_raises_401_unauthenticated(self): response = self.client.delete("/auth/fab/v1/roles/test") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.delete( - "/auth/fab/v1/roles/test", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/auth/fab/v1/roles/test", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -398,7 +378,7 @@ def test_should_raise_403_forbidden(self): indirect=["set_auto_role_public"], ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.delete(f"/auth/fab/v1/roles/{role.name}") assert response.status_code == expected_status_code @@ -419,17 +399,17 @@ class TestPatchRole(TestRoleEndpoint): ], ) def test_patch_should_respond_200(self, payload, expected_name, expected_actions): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.patch( - f"/auth/fab/v1/roles/{role.name}", json=payload, environ_overrides={"REMOTE_USER": "test"} + f"/auth/fab/v1/roles/{role.name}", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["name"] == expected_name - assert response.json["actions"] == expected_actions + assert response.json()["name"] == expected_name + assert response.json()["actions"] == expected_actions def test_patch_should_update_correct_roles_permissions(self): - create_role(self.app, "role_to_change") - create_role(self.app, "already_exists") + create_role(self.flask_app, "role_to_change") + create_role(self.flask_app, "already_exists") response = self.client.patch( "/auth/fab/v1/roles/role_to_change", @@ -437,16 +417,16 @@ def test_patch_should_update_correct_roles_permissions(self): "name": "already_exists", "actions": [{"action": {"name": "can_delete"}, "resource": {"name": "XComs"}}], }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - updated_permissions = self.app.appbuilder.sm.find_role("role_to_change").permissions + updated_permissions = self.flask_app.appbuilder.sm.find_role("role_to_change").permissions assert len(updated_permissions) == 1 assert updated_permissions[0].resource.name == "XComs" assert updated_permissions[0].action.name == "can_delete" - assert len(self.app.appbuilder.sm.find_role("already_exists").permissions) == 0 + assert len(self.flask_app.appbuilder.sm.find_role("already_exists").permissions) == 0 @pytest.mark.parametrize( "update_mask, payload, expected_name, expected_actions", @@ -474,27 +454,27 @@ def test_patch_should_update_correct_roles_permissions(self): def test_patch_should_respond_200_with_update_mask( self, update_mask, payload, expected_name, expected_actions ): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") assert role.permissions == [] response = self.client.patch( f"/auth/fab/v1/roles/{role.name}{update_mask}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["name"] == expected_name - assert response.json["actions"] == expected_actions + assert response.json()["name"] == expected_name + assert response.json()["actions"] == expected_actions def test_patch_should_respond_400_for_invalid_fields_in_update_mask(self): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") payload = {"name": "testme"} response = self.client.patch( f"/auth/fab/v1/roles/{role.name}?update_mask=invalid_name", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "'invalid_name' in update_mask is unknown" + assert response.json()["detail"] == "'invalid_name' in update_mask is unknown" @pytest.mark.parametrize( "payload, expected_error", @@ -547,14 +527,14 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask(self): ], ) def test_patch_should_respond_400_for_invalid_update(self, payload, expected_error): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.patch( f"/auth/fab/v1/roles/{role.name}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == expected_error + assert response.json()["detail"] == expected_error def test_should_raises_401_unauthenticated(self): response = self.client.patch( @@ -565,7 +545,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.patch( @@ -574,7 +554,7 @@ def test_should_raise_403_forbidden(self): "name": "mytest2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -584,7 +564,7 @@ def test_should_raise_403_forbidden(self): indirect=["set_auto_role_public"], ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.patch( f"/auth/fab/v1/roles/{role.name}", json={"name": "mytest"}, diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py index e83d9fcf83736..0ae132684641e 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -30,20 +30,20 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models import User -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_role, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_role, delete_user from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test -DEFAULT_TIME = "2020-06-11T18:00:00+00:00" +DEFAULT_TIME = "2020-06-11T18:00:00" @pytest.fixture(scope="module") def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api + connexion_app = minimal_app_for_auth_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -53,21 +53,22 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_role(app, name="TestNoPermissions") + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_role(connexion_app.app, name="TestNoPermissions") class TestUserEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.session = self.app.appbuilder.get_session + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore + self.session = self.flask_app.appbuilder.get_session def teardown_method(self) -> None: # Delete users that have our custom default time @@ -100,9 +101,9 @@ def test_should_respond_200(self): users = self._create_users(1) self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/TEST_USER1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -128,9 +129,9 @@ def test_last_names_can_be_empty(self): ) self.session.add_all([prince]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/prince", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/prince", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -156,9 +157,9 @@ def test_first_names_can_be_empty(self): ) self.session.add_all([liberace]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/liberace", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/liberace", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -184,9 +185,9 @@ def test_both_first_and_last_names_can_be_empty(self): ) self.session.add_all([nameless]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/nameless", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/nameless", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -201,44 +202,40 @@ def test_both_first_and_last_names_can_be_empty(self): } def test_should_respond_404(self): - response = self.client.get( - "/auth/fab/v1/users/invalid-user", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/users/invalid-user", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "The User with username `invalid-user` was not found", "status": 404, "title": "User not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/users/TEST_USER1") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - "/auth/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/auth/fab/v1/users/TEST_USER1", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 class TestGetUsers(TestUserEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 2 - usernames = [user["username"] for user in response.json["users"] if user] + assert response.json()["total_entries"] == 2 + usernames = [user["username"] for user in response.json()["users"] if user] assert usernames == ["test", "test_no_permissions"] def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/users") - assert_401(response) + assert response.status_code def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/users", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @@ -289,10 +286,10 @@ def test_handle_limit_offset(self, url, expected_usernames): users = self._create_users(10) self.session.add_all(users) self.session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 12 - usernames = [user["username"] for user in response.json["users"] if user] + assert response.json()["total_entries"] == 12 + usernames = [user["username"] for user in response.json()["users"] if user] assert usernames == expected_usernames def test_should_respect_page_size_limit_default(self): @@ -300,33 +297,31 @@ def test_should_respect_page_size_limit_default(self): self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 # Explicitly add the 2 users on setUp - assert response.json["total_entries"] == 200 + len(["test", "test_no_permissions"]) - assert len(response.json["users"]) == 100 + assert response.json()["total_entries"] == 200 + len(["test", "test_no_permissions"]) + assert len(response.json()["users"]) == 100 def test_should_response_400_with_invalid_order_by(self): users = self._create_users(2) self.session.add_all(users) self.session.commit() - response = self.client.get( - "/auth/fab/v1/users?order_by=myname", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/users?order_by=myname", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'myname' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_limit_of_zero_should_return_default(self): users = self._create_users(200) self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users?limit=0", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users?limit=0", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 # Explicit add the 2 users on setUp - assert response.json["total_entries"] == 200 + len(["test", "test_no_permissions"]) - assert len(response.json["users"]) == 100 + assert response.json()["total_entries"] == 200 + len(["test", "test_no_permissions"]) + assert len(response.json()["users"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): @@ -334,9 +329,9 @@ def test_should_return_conf_max_if_req_max_above_conf(self): self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["users"]) == 150 + assert len(response.json()["users"]) == 150 EXAMPLE_USER_NAME = "example_user" @@ -349,6 +344,7 @@ def _delete_user(**filters): user = session.query(User).filter_by(**filters).first() if user is None: return + session.refresh(user) user.roles = [] session.delete(user) @@ -370,7 +366,7 @@ def autoclean_email(): @pytest.fixture def user_with_same_username(configured_app, autoclean_username): user = create_user( - configured_app, + configured_app.app, username=autoclean_username, email="another_user@example.com", role_name="TestNoPermissions", @@ -382,7 +378,7 @@ def user_with_same_username(configured_app, autoclean_username): @pytest.fixture def user_with_same_email(configured_app, autoclean_email): user = create_user( - configured_app, + configured_app.app, username="another_user", email=autoclean_email, role_name="TestNoPermissions", @@ -397,7 +393,7 @@ def user_different(configured_app): email = "another_user@example.com" _delete_user(username=username, email=email) - user = create_user(configured_app, username=username, email=email, role_name="TestNoPermissions") + user = create_user(configured_app.app, username=username, email=email, role_name="TestNoPermissions") assert user, "failed to create user 'another_user '" yield user _delete_user(username=username, email=email) @@ -416,7 +412,7 @@ def autoclean_user_payload(autoclean_username, autoclean_email): @pytest.fixture def autoclean_admin_user(configured_app, autoclean_user_payload): - security_manager = configured_app.appbuilder.sm + security_manager = configured_app.app.appbuilder.sm return security_manager.add_user( role=security_manager.find_role("Admin"), **autoclean_user_payload, @@ -425,27 +421,29 @@ def autoclean_admin_user(configured_app, autoclean_user_payload): class TestPostUser(TestUserEndpoint): def test_with_default_role(self, autoclean_username, autoclean_user_payload): + self.flask_app.config["AUTH_USER_REGISTRATION_ROLE"] = "Public" response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.json - security_manager = self.app.appbuilder.sm + security_manager = self.flask_app.appbuilder.sm user = security_manager.find_user(autoclean_username) assert user is not None assert user.roles == [security_manager.find_role("Public")] + self.flask_app.config["AUTH_USER_REGISTRATION_ROLE"] = None def test_with_custom_roles(self, autoclean_username, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json={"roles": [{"name": "User"}, {"name": "Viewer"}], **autoclean_user_payload}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.json - security_manager = self.app.appbuilder.sm + security_manager = self.flask_app.appbuilder.sm user = security_manager.find_user(autoclean_username) assert user is not None assert {r.name for r in user.roles} == {"User", "Viewer"} @@ -455,24 +453,24 @@ def test_with_existing_different_user(self, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json={"roles": [{"name": "User"}, {"name": "Viewer"}], **autoclean_user_payload}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() def test_unauthenticated(self, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, ) - assert response.status_code == 401, response.json + assert response.status_code == 401, response.json() def test_forbidden(self, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 403, response.json + assert response.status_code == 403, response.json() @pytest.mark.parametrize( "existing_user_fixture_name, error_detail_template", @@ -494,12 +492,12 @@ def test_already_exists( response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 409, response.json + assert response.status_code == 409, response.json() error_detail = error_detail_template.format(username=existing.username, email=existing.email) - assert response.json["detail"] == error_detail + assert response.json()["detail"] == error_detail @pytest.mark.parametrize( "payload_converter, error_message", @@ -530,10 +528,10 @@ def test_invalid_payload(self, autoclean_user_payload, payload_converter, error_ response = self.client.post( "/auth/fab/v1/users", json=payload_converter(autoclean_user_payload), - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400, response.json - assert response.json == { + assert response.status_code == 400, response.json() + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -541,13 +539,13 @@ def test_invalid_payload(self, autoclean_user_payload, payload_converter, error_ } def test_internal_server_error(self, autoclean_user_payload): - with unittest.mock.patch.object(self.app.appbuilder.sm, "add_user", return_value=None): + with unittest.mock.patch.object(self.flask_app.appbuilder.sm, "add_user", return_value=None): response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.json == { + assert response.json() == { "detail": "Failed to add user `example_user`.", "status": 500, "title": "Internal Server Error", @@ -562,12 +560,12 @@ def test_change(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() # The first name is changed. - data = response.json + data = response.json() assert data["first_name"] == "Changed" assert data["last_name"] == "" @@ -578,12 +576,12 @@ def test_change_with_update_mask(self, autoclean_username, autoclean_user_payloa response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}?update_mask=last_name", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() # The first name is changed, but the last name isn't since we masked it. - data = response.json + data = response.json() assert data["first_name"] == "Tester" assert data["last_name"] == "McTesterson" @@ -608,11 +606,11 @@ def test_patch_already_exists( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409, response.json - assert response.json["detail"] == error_message + assert response.json()["detail"] == error_message @pytest.mark.parametrize( "field", @@ -629,10 +627,10 @@ def test_required_fields( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, response.json - assert response.json["detail"] == f"{{'{field}': ['Missing data for required field.']}}" + assert response.json()["detail"] == f"{{'{field}': ['Missing data for required field.']}}" @pytest.mark.usefixtures("autoclean_admin_user") def test_username_can_be_updated(self, autoclean_user_payload, autoclean_username): @@ -641,10 +639,10 @@ def test_username_can_be_updated(self, autoclean_user_payload, autoclean_usernam response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) _delete_user(username=testusername) - assert response.json["username"] == testusername + assert response.json()["username"] == testusername @pytest.mark.usefixtures("autoclean_admin_user") @unittest.mock.patch( @@ -661,10 +659,10 @@ def test_password_hashed( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json - assert "password" not in response.json + assert response.status_code == 200, response.json() + assert "password" not in response.json() mock_generate_password_hash.assert_called_once_with("new-pass") @@ -680,10 +678,10 @@ def test_replace_roles(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}?update_mask=roles", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json - assert {d["name"] for d in response.json["roles"]} == {"User", "Viewer"} + assert response.status_code == 200, response.json() + assert {d["name"] for d in response.json()["roles"]} == {"User", "Viewer"} @pytest.mark.usefixtures("autoclean_admin_user") def test_unchanged(self, autoclean_username, autoclean_user_payload): @@ -691,12 +689,12 @@ def test_unchanged(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() expected = {k: v for k, v in autoclean_user_payload.items() if k != "password"} - assert {k: response.json[k] for k in expected} == expected + assert {k: response.json()[k] for k in expected} == expected @pytest.mark.usefixtures("autoclean_admin_user") def test_unauthenticated(self, autoclean_username, autoclean_user_payload): @@ -704,25 +702,25 @@ def test_unauthenticated(self, autoclean_username, autoclean_user_payload): f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, ) - assert response.status_code == 401, response.json + assert response.status_code == 401, response.json() @pytest.mark.usefixtures("autoclean_admin_user") def test_forbidden(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 403, response.json + assert response.status_code == 403, response.json() def test_not_found(self, autoclean_username, autoclean_user_payload): # This test does not populate autoclean_admin_user into the database. response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 404, response.json + assert response.status_code == 404, response.json() @pytest.mark.parametrize( "payload_converter, error_message", @@ -760,10 +758,10 @@ def test_invalid_payload( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=payload_converter(autoclean_user_payload), - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400, response.json - assert response.json == { + assert response.status_code == 400, response.json() + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -776,9 +774,9 @@ class TestDeleteUser(TestUserEndpoint): def test_delete(self, autoclean_username): response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 204, response.json # NO CONTENT. + assert response.status_code == 204, response.json() # NO CONTENT. assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 0 @pytest.mark.usefixtures("autoclean_admin_user") @@ -786,22 +784,22 @@ def test_unauthenticated(self, autoclean_username): response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", ) - assert response.status_code == 401, response.json + assert response.status_code == 401, response.json() assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 1 @pytest.mark.usefixtures("autoclean_admin_user") def test_forbidden(self, autoclean_username): response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 403, response.json + assert response.status_code == 403, response.json() assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 1 def test_not_found(self, autoclean_username): # This test does not populate autoclean_admin_user into the database. response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 404, response.json + assert response.status_code == 404, response.json() diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py index 265407622e269..d4a1e117371f8 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py @@ -37,24 +37,25 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api + connexion_app = minimal_app_for_auth_api create_role( - app, + connexion_app.app, name="TestRole", permissions=[], ) - yield app + yield connexion_app - delete_role(app, "TestRole") # type:ignore + delete_role(connexion_app.app, "TestRole") # type:ignore class TestUserBase: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.role = self.app.appbuilder.sm.find_role("TestRole") - self.session = self.app.appbuilder.get_session + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore + self.role = self.flask_app.appbuilder.sm.find_role("TestRole") + self.session = self.flask_app.appbuilder.get_session def teardown_method(self): user = self.session.query(User).filter(User.email == TEST_EMAIL).first() diff --git a/tests/providers/fab/auth_manager/conftest.py b/tests/providers/fab/auth_manager/conftest.py index 6b4feb143f4b5..e39ab78448afd 100644 --- a/tests/providers/fab/auth_manager/conftest.py +++ b/tests/providers/fab/auth_manager/conftest.py @@ -29,14 +29,16 @@ def minimal_app_for_auth_api(): skip_all_except=[ "init_appbuilder", "init_api_experimental_auth", - "init_api_auth_provider", + "init_api_auth_manager", "init_api_error_handlers", ] ) def factory(): with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): - _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - _app.config["AUTH_ROLE_PUBLIC"] = None + _app = app.create_connexion_app( + testing=True, + config={"WTF_CSRF_ENABLED": False, "AUTH_ROLE_PUBLIC": None}, + ) # type:ignore return _app return factory() @@ -45,9 +47,9 @@ def factory(): @pytest.fixture def set_auto_role_public(request): app = request.getfixturevalue("minimal_app_for_auth_api") - auto_role_public = app.config["AUTH_ROLE_PUBLIC"] - app.config["AUTH_ROLE_PUBLIC"] = request.param + auto_role_public = app.app.config["AUTH_ROLE_PUBLIC"] + app.app.config["AUTH_ROLE_PUBLIC"] = request.param yield - app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + app.app.config["AUTH_ROLE_PUBLIC"] = auto_role_public diff --git a/tests/providers/fab/auth_manager/decorators/test_auth.py b/tests/providers/fab/auth_manager/decorators/test_auth.py index 98f77a4f34271..26b3a54a1b951 100644 --- a/tests/providers/fab/auth_manager/decorators/test_auth.py +++ b/tests/providers/fab/auth_manager/decorators/test_auth.py @@ -34,7 +34,7 @@ @pytest.fixture(scope="module") def app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture @@ -59,7 +59,7 @@ def mock_auth_manager(mock_sm): @pytest.fixture def mock_app(mock_appbuilder): app = Mock() - app.appbuilder = mock_appbuilder + app.app.appbuilder = mock_appbuilder return app @@ -82,11 +82,11 @@ def setup_method(self) -> None: def test_requires_access_fab_sync_resource_permissions( self, mock_get_auth_manager, mock_sm, mock_appbuilder, mock_auth_manager, app ): - app.appbuilder = mock_appbuilder + app.app.appbuilder = mock_appbuilder mock_appbuilder.update_perms = True mock_get_auth_manager.return_value = mock_auth_manager - with app.test_request_context(): + with app.app.test_request_context(): @_requires_access_fab() def decorated_requires_access_fab(): @@ -102,7 +102,7 @@ def test_requires_access_fab_access_denied( mock_sm.check_authorization.return_value = False mock_get_auth_manager.return_value = mock_auth_manager - with app.test_request_context(): + with app.app.test_request_context(): @_requires_access_fab(permissions) def decorated_requires_access_fab(): @@ -123,7 +123,7 @@ def test_requires_access_fab_access_granted( mock_sm.check_authorization.return_value = True mock_get_auth_manager.return_value = mock_auth_manager - with app.test_request_context(): + with app.app.test_request_context(): @_requires_access_fab(permissions) def decorated_requires_access_fab(): @@ -137,8 +137,8 @@ def decorated_requires_access_fab(): @patch("airflow.providers.fab.auth_manager.decorators.auth._has_access") def test_has_access_fab_with_no_dags(self, mock_has_access, mock_sm, mock_appbuilder, app): - app.appbuilder = mock_appbuilder - with app.test_request_context(): + app.app.appbuilder = mock_appbuilder + with app.app.test_request_context(): decorated_has_access_fab() mock_sm.check_authorization.assert_called_once_with(permissions, None) @@ -149,8 +149,8 @@ def test_has_access_fab_with_no_dags(self, mock_has_access, mock_sm, mock_appbui def test_has_access_fab_with_multiple_dags_render_error( self, mock_has_access, mock_render_template, mock_sm, mock_appbuilder, app ): - app.appbuilder = mock_appbuilder - with app.test_request_context() as mock_context: + app.app.appbuilder = mock_appbuilder + with app.app.test_request_context() as mock_context: mock_context.request.args = {"dag_id": "dag1"} mock_context.request.form = {"dag_id": "dag2"} decorated_has_access_fab() diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index d3238ba0f1734..46b5794a1dbf8 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -171,17 +171,17 @@ def clear_db_before_test(): @pytest.fixture(scope="module") def app(): - _app = application.create_app(testing=True) - _app.config["WTF_CSRF_ENABLED"] = False + _app = application.create_connexion_app(testing=True) + _app.app.config["WTF_CSRF_ENABLED"] = False return _app @pytest.fixture(scope="module") def app_builder(app): - app_builder = app.appbuilder + app_builder = app.app.appbuilder app_builder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews") app_builder.add_view(SomeModelView, "SomeModelView", category="ModelViews") - return app.appbuilder + return app.app.appbuilder @pytest.fixture(scope="module") @@ -196,7 +196,7 @@ def session(app_builder): @pytest.fixture(scope="module") def db(app): - return SQLA(app) + return SQLA(app.app) @pytest.fixture @@ -208,7 +208,7 @@ def role(request, app, security_manager): security_manager.bulk_sync_roles(params["mock_roles"]) _role = security_manager.find_role(params["name"]) yield _role, params - delete_role(app, params["name"]) + delete_role(app.app, params["name"]) @pytest.fixture @@ -349,10 +349,10 @@ def test_verify_public_role_has_no_permissions(security_manager): def test_verify_default_anon_user_has_no_accessible_dag_ids( mock_is_logged_in, app, session, security_manager ): - with app.app_context(): + with app.app.app_context(): mock_is_logged_in.return_value = False user = AnonymousUser() - app.config["AUTH_ROLE_PUBLIC"] = "Public" + app.app.config["AUTH_ROLE_PUBLIC"] = "Public" assert security_manager.get_user_roles(user) == {security_manager.get_public_role()} with _create_dag_model_context("test_dag_id", session, security_manager): @@ -362,9 +362,9 @@ def test_verify_default_anon_user_has_no_accessible_dag_ids( def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, security_manager, has_dag_perm): - with app.app_context(): + with app.app.app_context(): user = AnonymousUser() - app.config["AUTH_ROLE_PUBLIC"] = "Public" + app.app.config["AUTH_ROLE_PUBLIC"] = "Public" assert security_manager.get_user_roles(user) == {security_manager.get_public_role()} dag_id = "test_dag_id" @@ -387,8 +387,8 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( mock_is_logged_in, app, security_manager, mock_dag_models ): test_dag_ids = mock_dag_models - with app.app_context(): - app.config["AUTH_ROLE_PUBLIC"] = "Admin" + with app.app.app_context(): + app.app.config["AUTH_ROLE_PUBLIC"] = "Admin" mock_is_logged_in.return_value = False user = AnonymousUser() @@ -402,9 +402,9 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( def test_verify_anon_user_with_admin_role_has_access_to_each_dag( app, session, security_manager, has_dag_perm ): - with app.app_context(): + with app.app.app_context(): user = AnonymousUser() - app.config["AUTH_ROLE_PUBLIC"] = "Admin" + app.app.config["AUTH_ROLE_PUBLIC"] = "Admin" # Call `.get_user_roles` bc `user` is a mock and the `user.roles` prop needs to be set. user.roles = security_manager.get_user_roles(user) @@ -462,9 +462,9 @@ def test_get_user_roles_for_anonymous_user(app, security_manager): (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS_MENU), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS), } - app.config["AUTH_ROLE_PUBLIC"] = "Viewer" + app.app.config["AUTH_ROLE_PUBLIC"] = "Viewer" - with app.app_context(): + with app.app.app_context(): user = AnonymousUser() perms_views = set() @@ -477,9 +477,9 @@ def test_get_current_user_permissions(app): action = "can_some_action" resource = "SomeBaseView" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="get_current_user_permissions", role_name="MyRole5", permissions=[ @@ -489,7 +489,7 @@ def test_get_current_user_permissions(app): assert user.perms == {(action, resource)} with create_user_scope( - app, + app.app, username="no_perms", ) as user: assert len(user.perms) == 0 @@ -502,9 +502,9 @@ def test_get_accessible_dag_ids(mock_is_logged_in, app, security_manager, sessio dag_id = "dag_id" username = "ElUser" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[ @@ -534,9 +534,9 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission( role_name = "MyRole1" permission_action = [permissions.ACTION_CAN_EDIT] dag_id = "dag_id" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[ @@ -575,9 +575,9 @@ def test_sync_perm_for_dag_creates_permissions_for_specified_roles(app, security test_dag_id = "TEST_DAG" test_role = "limited-role" security_manager.bulk_sync_roles([{"role": test_role, "perms": []}]) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -594,9 +594,9 @@ def test_sync_perm_for_dag_removes_existing_permissions_if_empty(app, security_m test_dag_id = "TEST_DAG" test_role = "limited-role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -632,9 +632,9 @@ def test_sync_perm_for_dag_removes_permissions_from_other_roles(app, security_ma test_dag_id = "TEST_DAG" test_role = "limited-role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -671,9 +671,9 @@ def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, s test_dag_id = "TEST_DAG" test_role = "limited-role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -704,35 +704,35 @@ def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, s def test_has_all_dag_access(app, security_manager): for role_name in ["Admin", "Viewer", "Op", "User"]: - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name=role_name, ) as user: assert _has_all_dags_access(user) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name="read_all", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)], ) as user: assert _has_all_dags_access(user) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name="edit_all", permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], ) as user: assert _has_all_dags_access(user) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name="nada", permissions=[], @@ -754,9 +754,9 @@ def test_access_control_with_non_existent_role(security_manager): def test_all_dag_access_doesnt_give_non_dag_access(app, security_manager): username = "dag_access_user" role_name = "dag_access_role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[ @@ -778,7 +778,7 @@ def test_access_control_with_invalid_permission(app, security_manager): username = "LaUser" rolename = "team-a" with create_user_scope( - app, + app.app, username=username, role_name=rolename, ): @@ -800,9 +800,9 @@ def test_access_control_is_set_on_init( username = "access_control_is_set_on_init" role_name = "team-a" negated_role = "NOT-team-a" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[], @@ -818,7 +818,7 @@ def test_access_control_is_set_on_init( ) security_manager.bulk_sync_roles([{"role": negated_role, "perms": []}]) - set_user_single_role(app, user, role_name=negated_role) + set_user_single_role(app.app, user, role_name=negated_role) assert_user_does_not_have_dag_perms( perms=["PUT", "GET"], dag_id="access_control_test", @@ -834,14 +834,14 @@ def test_access_control_stale_perms_are_revoked( ): username = "access_control_stale_perms_are_revoked" role_name = "team-a" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[], ) as user: - set_user_single_role(app, user, role_name="team-a") + set_user_single_role(app.app, user, role_name="team-a") security_manager._sync_dag_view_permissions( "access_control_test", access_control={"team-a": READ_WRITE} ) @@ -985,7 +985,7 @@ def test_parent_dag_access_applies_to_subdag(app, security_manager, assert_user_ parent_dag_name = "parent_dag" subdag_name = parent_dag_name + ".subdag" subsubdag_name = parent_dag_name + ".subdag.subsubdag" - with app.app_context(): + with app.app.app_context(): mock_roles = [ { "role": role_name, @@ -996,7 +996,7 @@ def test_parent_dag_access_applies_to_subdag(app, security_manager, assert_user_ } ] with create_user_scope( - app, + app.app, username=username, role_name=role_name, ) as user: @@ -1026,7 +1026,7 @@ def test_permissions_work_for_dags_with_dot_in_dagname( role_name = "dag_permission_role" dag_id = "dag_id_1" dag_id_2 = "dag_id_1.with_dot" - with app.app_context(): + with app.app.app_context(): mock_roles = [ { "role": role_name, @@ -1037,7 +1037,7 @@ def test_permissions_work_for_dags_with_dot_in_dagname( } ] with create_user_scope( - app, + app.app, username=username, role_name=role_name, ) as user: @@ -1126,14 +1126,14 @@ def test_update_user_auth_stat_subsequent_unsuccessful_auth(mock_security_manage def test_users_can_be_found(app, security_manager, session, caplog): """Test that usernames are case insensitive""" - create_user(app, "Test") - create_user(app, "test") - create_user(app, "TEST") - create_user(app, "TeSt") + create_user(app.app, "Test") + create_user(app.app, "test") + create_user(app.app, "TEST") + create_user(app.app, "TeSt") assert security_manager.find_user("Test") users = security_manager.get_all_users() assert len(users) == 1 - delete_user(app, "Test") + delete_user(app.app, "Test") assert "Error adding new user to database" in caplog.text @@ -1183,7 +1183,7 @@ def test_dag_id_consistency( dag_id_json: str | None, fail: bool, ): - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: from airflow.www.auth import has_access_dag mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {} @@ -1194,7 +1194,7 @@ def test_dag_id_consistency( mock_context.request._parsed_content_type = ["application/json"] with create_user_scope( - app, + app.app, username="test-user", role_name="limited-role", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)], diff --git a/tests/providers/fab/auth_manager/views/test_permissions.py b/tests/providers/fab/auth_manager/views/test_permissions.py index 0b1073df287fa..80cb48016aa92 100644 --- a/tests/providers/fab/auth_manager/views/test_permissions.py +++ b/tests/providers/fab/auth_manager/views/test_permissions.py @@ -32,13 +32,13 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture(scope="module") def user_permissions_reader(fab_app): yield create_user( - fab_app, + fab_app.app, username="user_permissions", role_name="role_permissions", permissions=[ @@ -49,12 +49,12 @@ def user_permissions_reader(fab_app): ], ) - delete_user(fab_app, "user_permissions") + delete_user(fab_app.app, "user_permissions") @pytest.fixture def client_permissions_reader(fab_app, user_permissions_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False + fab_app.app.config["WTF_CSRF_ENABLED"] = False return client_with_login( fab_app, username="user_permissions", diff --git a/tests/providers/fab/auth_manager/views/test_roles_list.py b/tests/providers/fab/auth_manager/views/test_roles_list.py index 156f07df41209..b2efcc5b6850f 100644 --- a/tests/providers/fab/auth_manager/views/test_roles_list.py +++ b/tests/providers/fab/auth_manager/views/test_roles_list.py @@ -23,7 +23,7 @@ from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -32,13 +32,13 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture(scope="module") def user_roles_reader(fab_app): yield create_user( - fab_app, + fab_app.app, username="user_roles", role_name="role_roles", permissions=[ @@ -47,13 +47,13 @@ def user_roles_reader(fab_app): ], ) - delete_user(fab_app, "user_roles") + delete_user(fab_app.app, "user_roles") @pytest.fixture def client_roles_reader(fab_app, user_roles_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_roles_reader", password="user_roles_reader", diff --git a/tests/providers/fab/auth_manager/views/test_user.py b/tests/providers/fab/auth_manager/views/test_user.py index 6660ab926d886..11b40c66b6960 100644 --- a/tests/providers/fab/auth_manager/views/test_user.py +++ b/tests/providers/fab/auth_manager/views/test_user.py @@ -23,7 +23,7 @@ from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -32,13 +32,13 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture(scope="module") def user_user_reader(fab_app): yield create_user( - fab_app, + fab_app.app, username="user_user", role_name="role_user", permissions=[ @@ -47,13 +47,13 @@ def user_user_reader(fab_app): ], ) - delete_user(fab_app, "user_user") + delete_user(fab_app.app, "user_user") @pytest.fixture def client_user_reader(fab_app, user_user_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_user_reader", password="user_user_reader", diff --git a/tests/providers/fab/auth_manager/views/test_user_edit.py b/tests/providers/fab/auth_manager/views/test_user_edit.py index 65937b6f83d33..d19412f1d3899 100644 --- a/tests/providers/fab/auth_manager/views/test_user_edit.py +++ b/tests/providers/fab/auth_manager/views/test_user_edit.py @@ -23,7 +23,7 @@ from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -32,13 +32,13 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture(scope="module") def user_user_reader(fab_app): yield create_user( - fab_app, + fab_app.app, username="user_user", role_name="role_user", permissions=[ @@ -47,13 +47,13 @@ def user_user_reader(fab_app): ], ) - delete_user(fab_app, "user_user") + delete_user(fab_app.app, "user_user") @pytest.fixture def client_user_reader(fab_app, user_user_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_user_reader", password="user_user_reader", diff --git a/tests/providers/fab/auth_manager/views/test_user_stats.py b/tests/providers/fab/auth_manager/views/test_user_stats.py index 8cb260fcf1ec4..ab79afdd687d4 100644 --- a/tests/providers/fab/auth_manager/views/test_user_stats.py +++ b/tests/providers/fab/auth_manager/views/test_user_stats.py @@ -23,7 +23,7 @@ from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -32,13 +32,13 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture(scope="module") def user_user_stats_reader(fab_app): yield create_user( - fab_app, + fab_app.app, username="user_user_stats", role_name="role_user_stats", permissions=[ @@ -47,13 +47,13 @@ def user_user_stats_reader(fab_app): ], ) - delete_user(fab_app, "user_user_stats") + delete_user(fab_app.app, "user_user_stats") @pytest.fixture def client_user_stats_reader(fab_app, user_user_stats_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_user_stats_reader", password="user_user_stats_reader", @@ -63,5 +63,5 @@ def client_user_stats_reader(fab_app, user_user_stats_reader): @pytest.mark.db_test class TestUserStats: def test_user_stats(self, client_user_stats_reader): - resp = client_user_stats_reader.get("/userstatschartview/chart", follow_redirects=True) + resp = client_user_stats_reader.get("/userstatschartview/chart/", follow_redirects=True) assert resp.status_code == 200 diff --git a/tests/providers/google/common/auth_backend/test_google_openid.py b/tests/providers/google/common/auth_backend/test_google_openid.py index d11613b5cf9f3..61637f532a85b 100644 --- a/tests/providers/google/common/auth_backend/test_google_openid.py +++ b/tests/providers/google/common/auth_backend/test_google_openid.py @@ -22,24 +22,28 @@ from flask_login import current_user from google.auth.exceptions import GoogleAuthError -from airflow.www.app import create_app +from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools +pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="This test is for Airflow 2.10+") + @pytest.fixture(scope="module") def google_openid_app(): + from airflow.www.app import create_connexion_app + confs = { ("api", "auth_backends"): "airflow.providers.google.common.auth_backend.google_openid", ("api", "enable_experimental_api"): "true", } with conf_vars(confs): - return create_app(testing=True) + return create_connexion_app(testing=True) @pytest.fixture(scope="module") def admin_user(google_openid_app): - appbuilder = google_openid_app.appbuilder + appbuilder = google_openid_app.app.appbuilder role_admin = appbuilder.sm.find_role("Admin") tester = appbuilder.sm.find_user(username="test") if not tester: @@ -58,7 +62,7 @@ def admin_user(google_openid_app): class TestGoogleOpenID: @pytest.fixture(autouse=True) def _set_attrs(self, google_openid_app, admin_user) -> None: - self.app = google_openid_app + self.connexion_app = google_openid_app self.admin_user = admin_user @mock.patch("google.oauth2.id_token.verify_token") @@ -70,7 +74,7 @@ def test_success(self, mock_verify_token): "email": "test@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) @@ -88,7 +92,7 @@ def test_malformed_headers(self, mock_verify_token, auth_header): "email": "test@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": auth_header}) assert 403 == response.status_code @@ -102,7 +106,7 @@ def test_invalid_iss_in_jwt_token(self, mock_verify_token): "email": "test@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) @@ -118,7 +122,7 @@ def test_user_not_exists(self, mock_verify_token): "email": "invalid@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) @@ -128,7 +132,7 @@ def test_user_not_exists(self, mock_verify_token): @conf_vars({("api", "auth_backends"): "airflow.providers.google.common.auth_backend.google_openid"}) def test_missing_id_token(self): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools") assert 403 == response.status_code @@ -139,7 +143,7 @@ def test_missing_id_token(self): def test_invalid_id_token(self, mock_verify_token): mock_verify_token.side_effect = GoogleAuthError("Invalid token") - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 28db24305faa6..fd14ed727980c 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -1094,8 +1094,8 @@ def test_external_task_sensor_extra_link( assert ti.task.external_task_id == expected_external_task_id assert ti.task.external_task_ids == [expected_external_task_id] - app.config["SERVER_NAME"] = "" - with app.app_context(): + app.app.config["SERVER_NAME"] = "" + with app.app.app_context(): url = ti.task.get_extra_links(ti, "External DAG") assert f"/dags/{expected_external_dag_id}/grid" in url diff --git a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py index 44c0bcecc3b49..87072f8a776d2 100644 --- a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py +++ b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py @@ -24,9 +24,12 @@ from airflow.www import app as application from tests.system.providers.amazon.aws.utils import set_env_id +from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.www import check_content_in_response +pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+") + pytest.importorskip("onelogin") SAML_METADATA_URL = "/saml/metadata" @@ -147,7 +150,7 @@ def client_no_permissions(base_app): "email": ["email"], } base_app.return_value = auth - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture @@ -160,7 +163,7 @@ def client_admin_permissions(base_app): "groups": ["Admin"], } base_app.return_value = auth - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.mark.system("amazon") diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py index af746b2d55468..bb3af262960ea 100644 --- a/tests/test_utils/api_connexion_utils.py +++ b/tests/test_utils/api_connexion_utils.py @@ -124,7 +124,7 @@ def delete_users(app): def assert_401(response): assert response.status_code == 401, f"Current code: {response.status_code}" - assert response.json == { + assert response.json() == { "detail": None, "status": 401, "title": "Unauthorized", diff --git a/tests/test_utils/decorators.py b/tests/test_utils/decorators.py index 5b028c694a8c6..49bf0dc7a63de 100644 --- a/tests/test_utils/decorators.py +++ b/tests/test_utils/decorators.py @@ -19,8 +19,6 @@ import functools from unittest.mock import patch -from airflow.www.app import purge_cached_app - def dont_initialize_flask_app_submodules(_func=None, *, skip_all_except=None): if not skip_all_except: @@ -40,7 +38,7 @@ def no_op(*args, **kwargs): "init_api_connexion", "init_api_internal", "init_api_experimental", - "init_api_auth_provider", + "init_api_auth_manager", "init_api_error_handlers", "init_jinja_globals", "init_xframe_protection", @@ -55,10 +53,12 @@ def func(*args, **kwargs): if method not in skip_all_except: patcher = patch(f"airflow.www.app.{method}", no_op) patcher.start() - purge_cached_app() + from airflow.www.app import purge_cached_connexion_app + + purge_cached_connexion_app() result = f(*args, **kwargs) patch.stopall() - purge_cached_app() + purge_cached_connexion_app() return result diff --git a/tests/test_utils/mock_cors_middeleware.py b/tests/test_utils/mock_cors_middeleware.py new file mode 100644 index 0000000000000..211f46a44639b --- /dev/null +++ b/tests/test_utils/mock_cors_middeleware.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import connexion + +from airflow.configuration import conf + + +def init_mock_cors_middleware(connexion_app: connexion.FlaskApp, allow_origins: list): + from starlette.middleware.cors import CORSMiddleware + + connexion_app.add_middleware( + CORSMiddleware, + connexion.middleware.MiddlewarePosition.BEFORE_ROUTING, + allow_origins=allow_origins, + allow_credentials=True, + allow_methods=conf.get("api", "access_control_allow_methods"), + allow_headers=conf.get("api", "access_control_allow_headers"), + ) diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index b7714e5192e6a..5be8a2bf9da0a 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -62,7 +62,7 @@ def requires_authentication(function: T): @wraps(function) def decorated(*args, **kwargs): - user_id = request.remote_user + user_id = request.headers.get("REMOTE-USER") if not user_id: log.debug("Missing REMOTE_USER.") return Response("Forbidden", 403) diff --git a/tests/test_utils/www.py b/tests/test_utils/www.py index 0a19c312fba4e..d8ff0f1abaf65 100644 --- a/tests/test_utils/www.py +++ b/tests/test_utils/www.py @@ -23,19 +23,29 @@ from airflow.models import Log -def client_with_login(app, expected_response_code=302, **kwargs): +def client_with_login(app, expected_path=b"/home", **kwargs): patch_path = "airflow.providers.fab.auth_manager.security_manager.override.check_password_hash" with mock.patch(patch_path) as check_password_hash: check_password_hash.return_value = True client = app.test_client() resp = client.post("/login/", data=kwargs) + assert resp.url.raw_path == expected_path + return client + + +def flask_client_with_login(app, expected_response_code=302, **kwargs): + patch_path = "airflow.providers.fab.auth_manager.security_manager.override.check_password_hash" + with mock.patch(patch_path) as check_password_hash: + check_password_hash.return_value = True + client = app.app.test_client() + resp = client.post("/login/", data=kwargs) assert resp.status_code == expected_response_code return client def client_without_login(app): # Anonymous users can only view if AUTH_ROLE_PUBLIC is set to non-Public - app.config["AUTH_ROLE_PUBLIC"] = "Viewer" + app.app.config["AUTH_ROLE_PUBLIC"] = "Viewer" client = app.test_client() return client @@ -48,7 +58,7 @@ def client_without_login_as_admin(app): def check_content_in_response(text, resp, resp_code=200): - resp_html = resp.data.decode("utf-8") + resp_html = resp.text assert resp_code == resp.status_code if isinstance(text, list): for line in text: @@ -58,7 +68,7 @@ def check_content_in_response(text, resp, resp_code=200): def check_content_not_in_response(text, resp, resp_code=200): - resp_html = resp.data.decode("utf-8") + resp_html = resp.text assert resp_code == resp.status_code if isinstance(text, list): for line in text: diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 27ef5a76db5b7..356ec5178adbc 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -171,8 +171,8 @@ def test_build_airflow_url_with_query(self): """ Test query generated with dag_id and params """ - query = {"dag_id": "test_dag", "param": "key/to.encode"} - expected_url = "/dags/test_dag/graph?param=key%2Fto.encode" + query = {"dag_id": "test_dag", "param": "key to.encode"} + expected_url = "/dags/test_dag/graph?param=key+to.encode" from airflow.www.app import cached_app diff --git a/tests/www/api/experimental/conftest.py b/tests/www/api/experimental/conftest.py index 59c6e13357c85..32a0400502f6c 100644 --- a/tests/www/api/experimental/conftest.py +++ b/tests/www/api/experimental/conftest.py @@ -39,11 +39,11 @@ def experiemental_api_app(): ] ) def factory(): - app = application.create_app(testing=True) - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" - app.config["SECRET_KEY"] = "secret_key" - app.config["CSRF_ENABLED"] = False - app.config["WTF_CSRF_ENABLED"] = False + app = application.create_connexion_app(testing=True) + app.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" + app.app.config["SECRET_KEY"] = "secret_key" + app.app.config["CSRF_ENABLED"] = False + app.app.config["WTF_CSRF_ENABLED"] = False return app return factory() diff --git a/tests/www/api/experimental/test_dag_runs_endpoint.py b/tests/www/api/experimental/test_dag_runs_endpoint.py index 9f4bbf30bc41e..246e58b0df8fb 100644 --- a/tests/www/api/experimental/test_dag_runs_endpoint.py +++ b/tests/www/api/experimental/test_dag_runs_endpoint.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -import json - import pytest from airflow.api.common.trigger_dag import trigger_dag @@ -59,7 +57,7 @@ def test_get_dag_runs_success(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 1 @@ -74,7 +72,7 @@ def test_get_dag_runs_success_with_state_parameter(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 1 @@ -89,7 +87,7 @@ def test_get_dag_runs_success_with_capital_state_parameter(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 1 @@ -102,8 +100,8 @@ def test_get_dag_runs_success_with_state_no_result(self): # Create DagRun trigger_dag(dag_id=dag_id, run_id="test_get_dag_runs_success") - with pytest.raises(ValueError): - self.app.get(url_template.format(dag_id)) + resp = self.app.get(url_template.format(dag_id)) + assert 500 == resp.status_code def test_get_dag_runs_invalid_dag_id(self): url_template = "/api/experimental/dags/{}/dag_runs" @@ -111,7 +109,7 @@ def test_get_dag_runs_invalid_dag_id(self): response = self.app.get(url_template.format(dag_id)) assert 400 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert not isinstance(data, list) @@ -121,7 +119,7 @@ def test_get_dag_runs_no_runs(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 0 diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py index d78bc8fb37232..c7ac0abe5e0c7 100644 --- a/tests/www/api/experimental/test_endpoints.py +++ b/tests/www/api/experimental/test_endpoints.py @@ -53,7 +53,7 @@ class TestBase: @pytest.fixture(autouse=True) def _setup_attrs_base(self, experiemental_api_app, configured_session): self.app = experiemental_api_app - self.appbuilder = self.app.appbuilder + self.appbuilder = self.app.app.appbuilder self.client = self.app.test_client() self.session = configured_session @@ -92,7 +92,7 @@ def test_info(self): url = "/api/experimental/info" resp_raw = self.client.get(url) - resp = json.loads(resp_raw.data.decode("utf-8")) + resp = resp_raw.json() assert version == resp["version"] self.assert_deprecated(resp_raw) @@ -103,16 +103,16 @@ def test_task_info(self): response = self.client.get(url_template.format("example_bash_operator", "runme_0")) self.assert_deprecated(response) - assert '"email"' in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert '"email"' in response.text + assert "error" not in response.json() assert 200 == response.status_code response = self.client.get(url_template.format("example_bash_operator", "does-not-exist")) - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() assert 404 == response.status_code response = self.client.get(url_template.format("does-not-exist", "does-not-exist")) - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() assert 404 == response.status_code def test_get_dag_code(self): @@ -120,7 +120,7 @@ def test_get_dag_code(self): response = self.client.get(url_template.format("example_bash_operator")) self.assert_deprecated(response) - assert "BashOperator(" in response.data.decode("utf-8") + assert "BashOperator(" in response.text assert 200 == response.status_code response = self.client.get(url_template.format("xyz")) @@ -133,22 +133,22 @@ def test_dag_paused(self): response = self.client.get(pause_url_template.format("example_bash_operator", "true")) self.assert_deprecated(response) - assert "ok" in response.data.decode("utf-8") + assert "ok" == response.json()["response"] assert 200 == response.status_code paused_response = self.client.get(paused_url) assert 200 == paused_response.status_code - assert {"is_paused": True} == paused_response.json + assert {"is_paused": True} == paused_response.json() response = self.client.get(pause_url_template.format("example_bash_operator", "false")) - assert "ok" in response.data.decode("utf-8") + assert "ok" in response.text assert 200 == response.status_code paused_response = self.client.get(paused_url) assert 200 == paused_response.status_code - assert {"is_paused": False} == paused_response.json + assert {"is_paused": False} == paused_response.json() def test_trigger_dag(self): url_template = "/api/experimental/dags/{}/dag_runs" @@ -156,7 +156,8 @@ def test_trigger_dag(self): # Test error for nonexistent dag response = self.client.post( - url_template.format("does_not_exist_dag"), data=json.dumps({}), content_type="application/json" + url_template.format("does_not_exist_dag"), + data=json.dumps({}), ) assert 404 == response.status_code @@ -164,7 +165,6 @@ def test_trigger_dag(self): response = self.client.post( url_template.format("example_bash_operator"), data=json.dumps({"conf": "This is a string not a dict"}), - content_type="application/json", ) assert 400 == response.status_code @@ -172,16 +172,15 @@ def test_trigger_dag(self): response = self.client.post( url_template.format("example_bash_operator"), data=json.dumps({"run_id": run_id, "conf": {"param": "value"}}), - content_type="application/json", ) self.assert_deprecated(response) assert 200 == response.status_code - response_execution_date = parse_datetime(json.loads(response.data.decode("utf-8"))["execution_date"]) + response_execution_date = parse_datetime(response.json()["execution_date"]) assert 0 == response_execution_date.microsecond # Check execution_date is correct - response = json.loads(response.data.decode("utf-8")) + response = response.json() dagbag = DagBag() dag = dagbag.get_dag("example_bash_operator") dag_run = dag.get_dagrun(response_execution_date) @@ -199,11 +198,10 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({"execution_date": datetime_string}), - content_type="application/json", ) self.assert_deprecated(response) assert 200 == response.status_code - assert datetime_string == json.loads(response.data.decode("utf-8"))["execution_date"] + assert datetime_string == response.json()["execution_date"] dagbag = DagBag() dag = dagbag.get_dag(dag_id) @@ -214,10 +212,9 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({"execution_date": datetime_string, "replace_microseconds": "true"}), - content_type="application/json", ) assert 200 == response.status_code - response_execution_date = parse_datetime(json.loads(response.data.decode("utf-8"))["execution_date"]) + response_execution_date = parse_datetime(response.json()["execution_date"]) assert 0 == response_execution_date.microsecond dagbag = DagBag() @@ -229,7 +226,6 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format("does_not_exist_dag"), data=json.dumps({"execution_date": datetime_string}), - content_type="application/json", ) assert 404 == response.status_code @@ -237,7 +233,6 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({"execution_date": "not_a_datetime"}), - content_type="application/json", ) assert 400 == response.status_code @@ -256,30 +251,30 @@ def test_task_instance_info(self): response = self.client.get(url_template.format(dag_id, datetime_string, task_id)) self.assert_deprecated(response) assert 200 == response.status_code - assert "state" in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert "state" in response.json() + assert "error" not in response.json() # Test error for nonexistent dag response = self.client.get( url_template.format("does_not_exist_dag", datetime_string, task_id), ) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent task response = self.client.get(url_template.format(dag_id, datetime_string, "does_not_exist_task")) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent dag run (wrong execution_date) response = self.client.get(url_template.format(dag_id, wrong_datetime_string, task_id)) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for bad datetime format response = self.client.get(url_template.format(dag_id, "not_a_datetime", task_id)) assert 400 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() def test_dagrun_status(self): url_template = "/api/experimental/dags/{}/dag_runs/{}" @@ -295,25 +290,25 @@ def test_dagrun_status(self): response = self.client.get(url_template.format(dag_id, datetime_string)) self.assert_deprecated(response) assert 200 == response.status_code - assert "state" in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert "state" in response.json() + assert "error" not in response.json() # Test error for nonexistent dag response = self.client.get( url_template.format("does_not_exist_dag", datetime_string), ) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent dag run (wrong execution_date) response = self.client.get(url_template.format(dag_id, wrong_datetime_string)) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for bad datetime format response = self.client.get(url_template.format(dag_id, "not_a_datetime")) assert 400 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() class TestLineageApiExperimental(TestBase): @@ -354,25 +349,25 @@ def test_lineage_info(self): response = self.client.get(url_template.format(dag_id, datetime_string)) self.assert_deprecated(response) assert 200 == response.status_code - assert "task_ids" in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert "task_ids" in response.json() + assert "error" not in response.json() # Test error for nonexistent dag response = self.client.get( url_template.format("does_not_exist_dag", datetime_string), ) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent dag run (wrong execution_date) response = self.client.get(url_template.format(dag_id, wrong_datetime_string)) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for bad datetime format response = self.client.get(url_template.format(dag_id, "not_a_datetime")) assert 400 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() class TestPoolApiExperimental(TestBase): @@ -399,7 +394,7 @@ def _setup_attrs(self, _setup_attrs_base): def _get_pool_count(self): response = self.client.get("/api/experimental/pools") assert response.status_code == 200 - return len(json.loads(response.data.decode("utf-8"))) + return len(response.json()) def test_get_pool(self): response = self.client.get( @@ -407,18 +402,18 @@ def test_get_pool(self): ) self.assert_deprecated(response) assert response.status_code == 200 - assert json.loads(response.data.decode("utf-8")) == self.pool.to_json() + assert response.json() == self.pool.to_json() def test_get_pool_non_existing(self): response = self.client.get("/api/experimental/pools/foo") assert response.status_code == 404 - assert json.loads(response.data.decode("utf-8"))["error"] == "Pool 'foo' doesn't exist" + assert response.json()["error"] == "Pool 'foo' doesn't exist" def test_get_pools(self): response = self.client.get("/api/experimental/pools") self.assert_deprecated(response) assert response.status_code == 200 - pools = json.loads(response.data.decode("utf-8")) + pools = response.json() assert len(pools) == self.TOTAL_POOL_COUNT for i, pool in enumerate(sorted(pools, key=lambda p: p["pool"])): assert pool == self.pools[i].to_json() @@ -433,11 +428,10 @@ def test_create_pool(self): "description": "", } ), - content_type="application/json", ) self.assert_deprecated(response) assert response.status_code == 200 - pool = json.loads(response.data.decode("utf-8")) + pool = response.json() assert pool["pool"] == "foo" assert pool["slots"] == 1 assert pool["description"] == "" @@ -455,10 +449,9 @@ def test_create_pool_with_bad_name(self): "description": "", } ), - content_type="application/json", ) assert response.status_code == 400 - assert json.loads(response.data.decode("utf-8"))["error"] == "Pool name shouldn't be empty" + assert response.json()["error"] == "Pool name shouldn't be empty" assert self._get_pool_count() == self.TOTAL_POOL_COUNT def test_delete_pool(self): @@ -467,7 +460,7 @@ def test_delete_pool(self): ) self.assert_deprecated(response) assert response.status_code == 200 - assert json.loads(response.data.decode("utf-8")) == self.pool.to_json() + assert response.json() == self.pool.to_json() assert self._get_pool_count() == self.TOTAL_POOL_COUNT - 1 def test_delete_pool_non_existing(self): @@ -475,7 +468,7 @@ def test_delete_pool_non_existing(self): "/api/experimental/pools/foo", ) assert response.status_code == 404 - assert json.loads(response.data.decode("utf-8"))["error"] == "Pool 'foo' doesn't exist" + assert response.json()["error"] == "Pool 'foo' doesn't exist" def test_delete_default_pool(self): clear_db_pools() @@ -483,4 +476,4 @@ def test_delete_default_pool(self): "/api/experimental/pools/default_pool", ) assert response.status_code == 400 - assert json.loads(response.data.decode("utf-8"))["error"] == "default_pool cannot be deleted" + assert response.json()["error"] == "default_pool cannot be deleted" diff --git a/tests/www/test_app.py b/tests/www/test_app.py index 1e7bd67c9ae04..58ce32acfea7c 100644 --- a/tests/www/test_app.py +++ b/tests/www/test_app.py @@ -54,8 +54,8 @@ def setup_class(cls) -> None: ) @dont_initialize_flask_app_submodules def test_should_respect_proxy_fix(self): - app = application.cached_app(testing=True) - app.url_map.add(Rule("/debug", endpoint="debug")) + flask_app = application.cached_app(testing=True) + flask_app.url_map.add(Rule("/debug", endpoint="debug")) def debug_view(): from flask import request @@ -68,7 +68,7 @@ def debug_view(): return Response("success") - app.view_functions["debug"] = debug_view + flask_app.view_functions["debug"] = debug_view new_environ = { "PATH_INFO": "/debug", @@ -82,7 +82,7 @@ def debug_view(): } environ = create_environ(environ_overrides=new_environ) - response = Response.from_app(app, environ) + response = Response.from_app(flask_app, environ) assert b"success" == response.get_data() assert response.status_code == 200 @@ -224,16 +224,16 @@ def debug_view(): ) @dont_initialize_flask_app_submodules def test_should_set_permanent_session_timeout(self): - app = application.cached_app(testing=True) - assert app.config["PERMANENT_SESSION_LIFETIME"] == timedelta(minutes=3600) + flask_app = application.cached_app(testing=True) + assert flask_app.config["PERMANENT_SESSION_LIFETIME"] == timedelta(minutes=3600) @conf_vars({("webserver", "cookie_samesite"): ""}) @dont_initialize_flask_app_submodules def test_correct_default_is_set_for_cookie_samesite(self): """An empty 'cookie_samesite' should be corrected to 'Lax' with a deprecation warning.""" with pytest.deprecated_call(): - app = application.cached_app(testing=True) - assert app.config["SESSION_COOKIE_SAMESITE"] == "Lax" + flask_app = application.cached_app(testing=True) + assert flask_app.config["SESSION_COOKIE_SAMESITE"] == "Lax" @pytest.mark.parametrize( "hash_method, result", @@ -282,5 +282,5 @@ def test_app_can_json_serialize_k8s_pod(): k8s = pytest.importorskip("kubernetes.client.models") pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")])) - app = application.cached_app(testing=True) - assert app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}' + flask_app = application.cached_app(testing=True) + assert flask_app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}' diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py index f21973a8b6782..0c67aa40c15f2 100644 --- a/tests/www/test_auth.py +++ b/tests/www/test_auth.py @@ -101,7 +101,7 @@ def test_has_access_no_details_when_not_logged_in( auth_manager.get_url_login.return_value = "login_url" mock_get_auth_manager.return_value = auth_manager - with app.test_request_context(): + with app.app.test_request_context(): result = getattr(auth, decorator_name)("GET")(self.method_test)() mock_call.assert_not_called() @@ -171,7 +171,7 @@ def test_has_access_with_details_when_unauthorized( setattr(auth_manager, is_authorized_method_name, is_authorized_method) mock_get_auth_manager.return_value = auth_manager - with app.test_request_context(): + with app.app.test_request_context(): result = getattr(auth, decorator_name)("GET")(self.method_test)(None, items) mock_call.assert_not_called() @@ -215,7 +215,7 @@ def test_has_access_dag_entities_when_unauthorized(self, mock_get_auth_manager, mock_get_auth_manager.return_value = auth_manager items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")] - with app.test_request_context(): + with app.app.test_request_context(): result = auth.has_access_dag_entities("GET", dag_access_entity)(self.method_test)(None, items) mock_call.assert_not_called() @@ -231,7 +231,7 @@ def test_has_access_dag_entities_when_logged_out(self, mock_get_auth_manager, ap mock_get_auth_manager.return_value = auth_manager items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")] - with app.test_request_context(): + with app.app.test_request_context(): result = auth.has_access_dag_entities("GET", dag_access_entity)(self.method_test)(None, items) mock_call.assert_not_called() diff --git a/tests/www/test_security_manager.py b/tests/www/test_security_manager.py index ff66864188270..38b3e0625a0e7 100644 --- a/tests/www/test_security_manager.py +++ b/tests/www/test_security_manager.py @@ -35,12 +35,12 @@ @pytest.fixture def app(): - return application.create_app(testing=True) + return application.create_connexion_app(testing=True) @pytest.fixture def app_builder(app): - return app.appbuilder + return app.app.appbuilder @pytest.fixture diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index a90d9246998d6..a5e4e1835d9d7 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -230,7 +230,7 @@ def test_task_instance_link(self): ) ) - assert "%3Ca%261%3E" in html + assert "%3Ca&1%3E" in html assert "%3Cb2%3E" in html assert "map_index" in html assert "" not in html @@ -249,7 +249,7 @@ def test_dag_link(self): with cached_app(testing=True).test_request_context(): html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) - assert "%3Ca%261%3E" in html + assert "%3Ca&1%3E" in html assert "" not in html @pytest.mark.db_test @@ -272,7 +272,7 @@ def test_dag_run_link(self): utils.dag_run_link({"dag_id": "", "run_id": "", "execution_date": datetime.now()}) ) - assert "%3Ca%261%3E" in html + assert "%3Ca&1%3E" in html assert "%3Cb2%3E" in html assert "" not in html assert "" not in html diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 821f541ef0c43..e31205126c0a0 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -26,11 +26,16 @@ from airflow import settings from airflow.models import DagBag -from airflow.www.app import create_app +from airflow.www.app import create_connexion_app from tests.test_utils.api_connexion_utils import delete_user from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules -from tests.test_utils.www import client_with_login, client_without_login, client_without_login_as_admin +from tests.test_utils.www import ( + client_with_login, + client_without_login, + client_without_login_as_admin, + flask_client_with_login, +) @pytest.fixture(autouse=True, scope="module") @@ -52,6 +57,7 @@ def app(examples_dag_bag): @dont_initialize_flask_app_submodules( skip_all_except=[ "init_api_connexion", + "init_api_error_handlers", "init_appbuilder", "init_appbuilder_links", "init_appbuilder_views", @@ -64,14 +70,14 @@ def app(examples_dag_bag): ) def factory(): with conf_vars({("fab", "auth_rate_limited"): "False"}): - return create_app(testing=True) + return create_connexion_app(testing=True) app = factory() - app.config["WTF_CSRF_ENABLED"] = False - app.dag_bag = examples_dag_bag - app.jinja_env.undefined = jinja2.StrictUndefined + app.app.config["WTF_CSRF_ENABLED"] = False + app.app.dag_bag = examples_dag_bag + app.app.jinja_env.undefined = jinja2.StrictUndefined - security_manager = app.appbuilder.sm + security_manager = app.app.appbuilder.sm test_users = [ { @@ -107,7 +113,7 @@ def factory(): yield app for user_dict in test_users: - delete_user(app, user_dict["username"]) + delete_user(app.app, user_dict["username"]) @pytest.fixture @@ -115,6 +121,11 @@ def admin_client(app): return client_with_login(app, username="test_admin", password="test_admin") +@pytest.fixture +def flask_admin_client(app): + return flask_client_with_login(app, username="test_admin", password="test_admin") + + @pytest.fixture def viewer_client(app): return client_with_login(app, username="test_viewer", password="test_viewer") @@ -125,6 +136,11 @@ def user_client(app): return client_with_login(app, username="test_user", password="test_user") +@pytest.fixture +def flask_user_client(app): + return flask_client_with_login(app, username="test_user", password="test_user") + + @pytest.fixture def anonymous_client(app): return client_without_login(app) @@ -132,7 +148,12 @@ def anonymous_client(app): @pytest.fixture def anonymous_client_as_admin(app): - return client_without_login_as_admin(app) + return client_without_login_as_admin(app.app) + + +@pytest.fixture +def admin_flask_client(app): + return flask_client_with_login(app, username="test_admin", password="test_admin") class _TemplateWithContext(NamedTuple): @@ -198,11 +219,11 @@ def manager() -> Generator[list[_TemplateWithContext], None, None]: def record(sender, template, context, **extra): recorded.append(_TemplateWithContext(template, context)) - flask.template_rendered.connect(record, app) # type: ignore + flask.template_rendered.connect(record, app.app) # type: ignore try: yield recorded finally: - flask.template_rendered.disconnect(record, app) # type: ignore + flask.template_rendered.disconnect(record, app.app) # type: ignore assert recorded, "Failed to catch the templates" diff --git a/tests/www/views/test_anonymous_as_admin_role.py b/tests/www/views/test_anonymous_as_admin_role.py index b7603d1eae5bb..64ce1b1a42592 100644 --- a/tests/www/views/test_anonymous_as_admin_role.py +++ b/tests/www/views/test_anonymous_as_admin_role.py @@ -55,8 +55,9 @@ def factory(**values): def test_delete_pool_anonymous_user_no_role(anonymous_client, pool_factory): pool = pool_factory() resp = anonymous_client.post(f"pool/delete/{pool.id}") - assert 302 == resp.status_code - assert f"/login/?next={quote_plus(f'http://localhost/pool/delete/{pool.id}')}" == resp.headers["Location"] + expected_path = f"/login/?next={quote_plus(f'http://testserver/pool/delete/{pool.id}', safe='/:?')}" + assert expected_path.encode("utf-8") == resp.url.raw_path + assert 200 == resp.status_code def test_delete_pool_anonymous_user_as_admin(anonymous_client_as_admin, pool_factory): diff --git a/tests/www/views/test_session.py b/tests/www/views/test_session.py index 0ec219aaeb4b3..e8fe4c2bb56e3 100644 --- a/tests/www/views/test_session.py +++ b/tests/www/views/test_session.py @@ -29,7 +29,7 @@ def get_session_cookie(client): - return next((cookie for cookie in client.cookie_jar if cookie.name == "session"), None) + return next((cookie for cookie in client.cookies.jar if cookie.name == "session"), None) def test_session_cookie_created_on_login(user_client): @@ -40,13 +40,25 @@ def test_session_inaccessible_after_logout(user_client): session_cookie = get_session_cookie(user_client) assert session_cookie is not None + # correctly logs in + resp = user_client.get("/home") + assert resp.status_code == 200 + assert resp.url.raw_path == b"/home" + + # Same with cookies overwritten + user_client.get("/home", cookies={"session": session_cookie.value}) + assert resp.status_code == 200 + assert resp.url.raw_path == b"/home" + + # logs out resp = user_client.post("/logout/") - assert resp.status_code == 302 + assert resp.status_code == 200 + assert resp.url.raw_path == b"/login/?next=http://testserver/home" - # Try to access /home with the session cookie from earlier - user_client.set_cookie("session", session_cookie.value) - user_client.get("/home/") - assert resp.status_code == 302 + # Try to access /home with the session cookie from earlier call + user_client.get("/home", cookies={"session": session_cookie.value}) + assert resp.status_code == 200 + assert resp.url.raw_path == b"/login/?next=http://testserver/home" def test_invalid_session_backend_option(): @@ -64,7 +76,7 @@ def test_invalid_session_backend_option(): ) def poorly_configured_app_factory(): with conf_vars({("webserver", "session_backend"): "invalid_value_for_session_backend"}): - return app.create_app(testing=True) + return app.create_connexion_app(testing=True) expected_exc_regex = ( "^Unrecognized session backend specified in web_server_session_backend: " @@ -78,14 +90,16 @@ def test_session_id_rotates(app, user_client): old_session_cookie = get_session_cookie(user_client) assert old_session_cookie is not None - resp = user_client.post("/logout/") - assert resp.status_code == 302 + resp = user_client.post("/logout/", follow_redirects=True) + assert resp.status_code == 200 patch_path = "airflow.providers.fab.auth_manager.security_manager.override.check_password_hash" with mock.patch(patch_path) as check_password_hash: check_password_hash.return_value = True - resp = user_client.post("/login/", data={"username": "test_user", "password": "test_user"}) - assert resp.status_code == 302 + resp = user_client.post( + "/login/", data={"username": "test_user", "password": "test_user"}, follow_redirects=True + ) + assert resp.status_code == 200 new_session_cookie = get_session_cookie(user_client) assert new_session_cookie is not None @@ -93,17 +107,16 @@ def test_session_id_rotates(app, user_client): def test_check_active_user(app, user_client): - user = app.appbuilder.sm.find_user(username="test_user") + user = app.app.appbuilder.sm.find_user(username="test_user") user.active = False resp = user_client.get("/home") - assert resp.status_code == 302 - assert "/login/?next=http%3A%2F%2Flocalhost%2Fhome" in resp.headers.get("Location") + assert resp.url.raw_path == b"/home" -def test_check_deactivated_user_redirected_to_login(app, user_client): - with app.test_request_context(): - user = app.appbuilder.sm.find_user(username="test_user") +def test_check_deactivated_user_redirected_to_login(app, flask_user_client): + with app.app.test_request_context(): + user = app.app.appbuilder.sm.find_user(username="test_user") user.active = False - resp = user_client.get("/home", follow_redirects=True) + resp = flask_user_client.get("/home", follow_redirects=True) assert resp.status_code == 200 assert "/login" in resp.request.url diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py index 067f556bb7fee..114540d907c58 100644 --- a/tests/www/views/test_views.py +++ b/tests/www/views/test_views.py @@ -94,7 +94,7 @@ def test_redoc_should_render_template(capture_templates, admin_client): assert templates[0].name == "airflow/redoc.html" assert templates[0].local_context == { "config_test_connection": "Disabled", - "openapi_spec_url": "/api/v1/openapi.yaml", + "openapi_spec_url": "api/v1/openapi.yaml", "rest_api_enabled": True, "get_docs_url": get_docs_url, "excluded_events_raw": "", @@ -229,7 +229,7 @@ def test_task_dag_id_equals_filter(admin_client, url, content): @mock.patch("airflow.www.views.url_for") def test_get_safe_url(mock_url_for, app, test_url, expected_url): mock_url_for.return_value = "/home" - with app.test_request_context(base_url="http://localhost:8080"): + with app.app.test_request_context(base_url="http://localhost:8080"): assert get_safe_url(test_url) == expected_url @@ -237,7 +237,7 @@ def test_get_safe_url(mock_url_for, app, test_url, expected_url): def test_app(): from airflow.www import app - return app.create_app(testing=True) + return app.create_connexion_app(testing=True) def test_mark_task_instance_state(test_app): @@ -297,10 +297,10 @@ def get_task_instance(session, task): session.commit() - test_app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) - test_app.dag_bag.bag_dag(dag=dag, root_dag=dag) + test_app.app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) + test_app.app.dag_bag.bag_dag(dag=dag, root_dag=dag) - with test_app.test_request_context(): + with test_app.app.test_request_context(): view = Airflow() view._mark_task_instance_state( @@ -399,10 +399,10 @@ def get_task_instance(session, task): session.commit() - test_app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) - test_app.dag_bag.bag_dag(dag=dag, root_dag=dag) + test_app.app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) + test_app.app.dag_bag.bag_dag(dag=dag, root_dag=dag) - with test_app.test_request_context(): + with test_app.app.test_request_context(): view = Airflow() view._mark_task_group_state( @@ -486,7 +486,9 @@ def test_get_task_stats_from_query(): assert data == expected_data -INVALID_DATETIME_RESPONSE = re.compile(r"Invalid datetime: &#x?\d+;invalid&#x?\d+;") +# After upgrading to connexion v3, test client returns JSON response instead of HTML response. +# Returned JSON does not contain the previous pattern. +INVALID_DATETIME_RESPONSE = re.compile(r"Invalid datetime: 'invalid'") @pytest.mark.parametrize( @@ -525,9 +527,8 @@ def test_get_task_stats_from_query(): def test_invalid_dates(app, admin_client, url, content): """Test invalid date format doesn't crash page.""" resp = admin_client.get(url, follow_redirects=True) - assert resp.status_code == 400 - assert re.search(content, resp.get_data().decode()) + assert re.search(content, resp.json()["detail"]) @pytest.mark.parametrize("enabled, dags_count", [(False, 5), (True, 5)]) diff --git a/tests/www/views/test_views_acl.py b/tests/www/views/test_views_acl.py index 51bf56adac854..2742ac3400dfc 100644 --- a/tests/www/views/test_views_acl.py +++ b/tests/www/views/test_views_acl.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import json import urllib.parse import pytest @@ -32,7 +31,12 @@ from airflow.www.views import FILTER_STATUS_COOKIE from tests.test_utils.api_connexion_utils import create_user_scope from tests.test_utils.db import clear_db_runs -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + client_with_login, + flask_client_with_login, +) pytestmark = pytest.mark.db_test @@ -81,7 +85,7 @@ @pytest.fixture(scope="module") def acl_app(app): - security_manager = app.appbuilder.sm + security_manager = app.app.appbuilder.sm for username, (role_name, kwargs) in USER_DATA.items(): if not security_manager.find_user(username=username): role = security_manager.add_role(role_name) @@ -138,7 +142,7 @@ def reset_dagruns(): @pytest.fixture(autouse=True) def init_dagruns(acl_app, reset_dagruns): - acl_app.dag_bag.get_dag("example_bash_operator").create_dagrun( + acl_app.app.dag_bag.get_dag("example_bash_operator").create_dagrun( run_id=DEFAULT_RUN_ID, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -146,7 +150,7 @@ def init_dagruns(acl_app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - acl_app.dag_bag.get_dag("example_subdag_operator").create_dagrun( + acl_app.app.dag_bag.get_dag("example_subdag_operator").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, start_date=timezone.utcnow(), @@ -159,7 +163,9 @@ def init_dagruns(acl_app, reset_dagruns): @pytest.fixture def dag_test_client(acl_app): - return client_with_login(acl_app, username="dag_test", password="dag_test") + return client_with_login( + acl_app, expected_path=b"/login/?next=/home", username="dag_test", password="dag_test" + ) @pytest.fixture @@ -179,7 +185,7 @@ def all_dag_user_client(acl_app): @pytest.fixture(scope="module") def user_edit_one_dag(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_edit_one_dag", role_name="role_edit_one_dag", permissions=[ @@ -192,8 +198,8 @@ def user_edit_one_dag(acl_app): @pytest.mark.usefixtures("user_edit_one_dag") def test_permission_exist(acl_app): - perms_views = acl_app.appbuilder.sm.get_resource_permissions( - acl_app.appbuilder.sm.get_resource("DAG:example_bash_operator"), + perms_views = acl_app.app.appbuilder.sm.get_resource_permissions( + acl_app.app.appbuilder.sm.get_resource("DAG:example_bash_operator"), ) assert len(perms_views) == 3 @@ -205,7 +211,7 @@ def test_permission_exist(acl_app): @pytest.mark.usefixtures("user_edit_one_dag") def test_role_permission_associate(acl_app): - test_role = acl_app.appbuilder.sm.find_role("role_edit_one_dag") + test_role = acl_app.app.appbuilder.sm.find_role("role_edit_one_dag") perms = {str(perm) for perm in test_role.permissions} assert "can edit on DAG:example_bash_operator" in perms assert "can read on DAG:example_bash_operator" in perms @@ -214,7 +220,7 @@ def test_role_permission_associate(acl_app): @pytest.fixture(scope="module") def user_all_dags(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags", role_name="role_all_dags", permissions=[ @@ -234,6 +240,15 @@ def client_all_dags(acl_app, user_all_dags): ) +@pytest.fixture +def flask_client_all_dags(acl_app, user_all_dags): + return flask_client_with_login( + acl_app, + username="user_all_dags", + password="user_all_dags", + ) + + def test_index_for_all_dag_user(client_all_dags): # The all dag user can access/view all dags. resp = client_all_dags.get("/", follow_redirects=True) @@ -265,7 +280,7 @@ def test_dag_autocomplete_success(client_all_dags): {"name": "tutorial_taskflow_api_virtualenv", "type": "dag", "dag_display_name": None}, ] - assert resp.json == expected + assert resp.json() == expected @pytest.mark.parametrize( @@ -282,13 +297,13 @@ def test_dag_autocomplete_empty(client_all_dags, query, expected): if query is not None: url = f"{url}?query={query}" resp = client_all_dags.get(url, follow_redirects=False) - assert resp.json == expected + assert resp.json() == expected def test_dag_autocomplete_dag_display_name(client_all_dags): url = "dagmodel/autocomplete?query=Sample" resp = client_all_dags.get(url, follow_redirects=False) - assert resp.json == [ + assert resp.json() == [ {"name": "example_display_name", "type": "dag", "dag_display_name": "Sample DAG with Display Name"} ] @@ -312,10 +327,11 @@ def setup_paused_dag(): ], ) @pytest.mark.usefixtures("setup_paused_dag") -def test_dag_autocomplete_status(client_all_dags, status, expected, unexpected): - with client_all_dags.session_transaction() as flask_session: +def test_dag_autocomplete_status(flask_client_all_dags, status, expected, unexpected): + with flask_client_all_dags.session_transaction() as flask_session: flask_session[FILTER_STATUS_COOKIE] = status - resp = client_all_dags.get( + + resp = flask_client_all_dags.get( "dagmodel/autocomplete?query=example_branch_", follow_redirects=False, ) @@ -326,7 +342,7 @@ def test_dag_autocomplete_status(client_all_dags, status, expected, unexpected): @pytest.fixture(scope="module") def user_all_dags_dagruns(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_dagruns", role_name="role_all_dags_dagruns", permissions=[ @@ -350,7 +366,7 @@ def client_all_dags_dagruns(acl_app, user_all_dags_dagruns): def test_dag_stats_success(client_all_dags_dagruns): resp = client_all_dags_dagruns.post("dag_stats", follow_redirects=True) check_content_in_response("example_bash_operator", resp) - assert set(next(iter(resp.json.items()))[1][0].keys()) == {"state", "count"} + assert set(next(iter(resp.json().items()))[1][0].keys()) == {"state", "count"} def test_task_stats_failure(dag_test_client): @@ -367,7 +383,7 @@ def test_dag_stats_success_for_all_dag_user(client_all_dags_dagruns): @pytest.fixture(scope="module") def user_all_dags_dagruns_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_dagruns_tis", role_name="role_all_dags_dagruns_tis", permissions=[ @@ -420,7 +436,7 @@ def test_task_stats_success( assert resp.status_code == 200 for dag_id in unexpected_dag_ids: check_content_not_in_response(dag_id, resp) - stats = json.loads(resp.data.decode()) + stats = resp.json() for dag_id in dags_to_run: assert dag_id in stats @@ -428,7 +444,7 @@ def test_task_stats_success( @pytest.fixture(scope="module") def user_all_dags_codes(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_codes", role_name="role_all_dags_codes", permissions=[ @@ -484,7 +500,7 @@ def test_dag_details_success_for_all_dag_user(client_all_dags_dagruns, dag_id): @pytest.fixture(scope="module") def user_all_dags_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_tis", role_name="role_all_dags_tis", permissions=[ @@ -509,7 +525,7 @@ def client_all_dags_tis(acl_app, user_all_dags_tis): @pytest.fixture(scope="module") def user_all_dags_tis_xcom(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_tis_xcom", role_name="role_all_dags_tis_xcom", permissions=[ @@ -534,7 +550,7 @@ def client_all_dags_tis_xcom(acl_app, user_all_dags_tis_xcom): @pytest.fixture(scope="module") def user_dags_tis_logs(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_dags_tis_logs", role_name="role_dags_tis_logs", permissions=[ @@ -683,7 +699,7 @@ def test_blocked_success_when_selecting_dags( assert resp.status_code == 200 for dag_id in unexpected_dag_ids: check_content_not_in_response(dag_id, resp) - blocked_dags = {blocked["dag_id"] for blocked in json.loads(resp.data.decode())} + blocked_dags = {blocked["dag_id"] for blocked in resp.json()} for dag_id in dags_to_block: assert dag_id in blocked_dags @@ -691,7 +707,7 @@ def test_blocked_success_when_selecting_dags( @pytest.fixture(scope="module") def user_all_dags_edit_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_edit_tis", role_name="role_all_dags_edit_tis", permissions=[ @@ -735,7 +751,7 @@ def test_paused_post_success(dag_test_client): @pytest.fixture(scope="module") def user_only_dags_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_only_dags_tis", role_name="role_only_dags_tis", permissions=[ @@ -767,7 +783,7 @@ def test_success_fail_for_read_only_task_instance_access(client_only_dags_tis): past="false", ) resp = client_only_dags_tis.post("success", data=form) - check_content_not_in_response("Wait a minute", resp, resp_code=302) + check_content_not_in_response("Wait a minute", resp, resp_code=200) GET_LOGS_WITH_METADATA_URL = ( @@ -798,7 +814,7 @@ def test_get_logs_with_metadata_failure(dag_faker_client): @pytest.fixture(scope="module") def user_no_roles(acl_app): - with create_user_scope(acl_app, username="no_roles_user", role_name="no_roles_user_role") as user: + with create_user_scope(acl_app.app, username="no_roles_user", role_name="no_roles_user_role") as user: user.roles = [] yield user @@ -815,7 +831,7 @@ def client_no_roles(acl_app, user_no_roles): @pytest.fixture(scope="module") def user_no_permissions(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="no_permissions_user", role_name="no_permissions_role", ) as user: @@ -853,7 +869,7 @@ def test_no_roles_permissions(request, client, url, status_code, expected_conten @pytest.fixture(scope="module") def user_dag_level_access_with_ti_edit(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_dag_level_access_with_ti_edit", role_name="role_dag_level_access_with_ti_edit", permissions=[ @@ -895,7 +911,7 @@ def test_success_edit_ti_with_dag_level_access_only(client_dag_level_access_with @pytest.fixture(scope="module") def user_ti_edit_without_dag_level_access(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_ti_edit_without_dag_level_access", role_name="role_ti_edit_without_dag_level_access", permissions=[ diff --git a/tests/www/views/test_views_base.py b/tests/www/views/test_views_base.py index a125ca2d72835..889e81b4c4b7b 100644 --- a/tests/www/views/test_views_base.py +++ b/tests/www/views/test_views_base.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import json import pytest @@ -36,8 +35,9 @@ def test_index_redirect(admin_client): resp = admin_client.get("/") - assert resp.status_code == 302 - assert "/home" in resp.headers.get("Location") + # Starlette TestCliente used by connexion v3 responds after following the redirect + # therefore, the status code is 200 + assert resp.url.raw_path == b"/home" resp = admin_client.get("/", follow_redirects=True) check_content_in_response("DAGs", resp) @@ -57,7 +57,7 @@ def test_doc_urls(admin_client, monkeypatch): resp = admin_client.get("/", follow_redirects=True) check_content_in_response("!!DOCS_URL!!", resp) - check_content_in_response("/api/v1/ui", resp) + check_content_in_response("/swagger", resp) @pytest.fixture @@ -122,7 +122,7 @@ def test_health(request, admin_client, heartbeat): # Load the corresponding fixture by name. scheduler_status, last_scheduler_heartbeat = request.getfixturevalue(heartbeat) resp = admin_client.get("health", follow_redirects=True) - resp_json = json.loads(resp.data.decode("utf-8")) + resp_json = resp.json() assert "healthy" == resp_json["metadatabase"]["status"] assert scheduler_status == resp_json["scheduler"]["status"] assert last_scheduler_heartbeat == resp_json["scheduler"]["latest_scheduler_heartbeat"] @@ -150,8 +150,8 @@ def test_roles_read_unauthorized(viewer_client): @pytest.fixture(scope="module") def delete_role_if_exists(app): def func(role_name): - if app.appbuilder.sm.find_role(role_name): - app.appbuilder.sm.delete_role(role_name) + if app.app.appbuilder.sm.find_role(role_name): + app.app.appbuilder.sm.delete_role(role_name) return func @@ -167,32 +167,32 @@ def non_exist_role_name(delete_role_if_exists): @pytest.fixture def exist_role_name(app, delete_role_if_exists): role_name = "test_roles_create_role_new" - app.appbuilder.sm.add_role(role_name) + app.app.appbuilder.sm.add_role(role_name) yield role_name delete_role_if_exists(role_name) @pytest.fixture def exist_role(app, exist_role_name): - return app.appbuilder.sm.find_role(exist_role_name) + return app.app.appbuilder.sm.find_role(exist_role_name) def test_roles_create(app, admin_client, non_exist_role_name): admin_client.post("roles/add", data={"name": non_exist_role_name}, follow_redirects=True) - assert app.appbuilder.sm.find_role(non_exist_role_name) is not None + assert app.app.appbuilder.sm.find_role(non_exist_role_name) is not None def test_roles_create_unauthorized(app, viewer_client, non_exist_role_name): resp = viewer_client.post("roles/add", data={"name": non_exist_role_name}, follow_redirects=True) check_content_in_response("Access is Denied", resp) - assert app.appbuilder.sm.find_role(non_exist_role_name) is None + assert app.app.appbuilder.sm.find_role(non_exist_role_name) is None def test_roles_edit(app, admin_client, non_exist_role_name, exist_role): admin_client.post( f"roles/edit/{exist_role.id}", data={"name": non_exist_role_name}, follow_redirects=True ) - updated_role = app.appbuilder.sm.find_role(non_exist_role_name) + updated_role = app.app.appbuilder.sm.find_role(non_exist_role_name) assert exist_role.id == updated_role.id @@ -201,19 +201,19 @@ def test_roles_edit_unauthorized(app, viewer_client, non_exist_role_name, exist_ f"roles/edit/{exist_role.id}", data={"name": non_exist_role_name}, follow_redirects=True ) check_content_in_response("Access is Denied", resp) - assert app.appbuilder.sm.find_role(exist_role_name) - assert app.appbuilder.sm.find_role(non_exist_role_name) is None + assert app.app.appbuilder.sm.find_role(exist_role_name) + assert app.app.appbuilder.sm.find_role(non_exist_role_name) is None def test_roles_delete(app, admin_client, exist_role_name, exist_role): admin_client.post(f"roles/delete/{exist_role.id}", follow_redirects=True) - assert app.appbuilder.sm.find_role(exist_role_name) is None + assert app.app.appbuilder.sm.find_role(exist_role_name) is None def test_roles_delete_unauthorized(app, viewer_client, exist_role, exist_role_name): resp = viewer_client.post(f"roles/delete/{exist_role.id}", follow_redirects=True) check_content_in_response("Access is Denied", resp) - assert app.appbuilder.sm.find_role(exist_role_name) + assert app.app.appbuilder.sm.find_role(exist_role_name) @pytest.mark.parametrize( @@ -253,7 +253,7 @@ def test_views_get(request, url, client, content): def _check_task_stats_json(resp): - return set(next(iter(resp.json.items()))[1][0]) == {"state", "count"} + return set(next(iter(resp.json().items()))[1][0]) == {"state", "count"} @pytest.mark.parametrize( @@ -281,7 +281,7 @@ def test_views_post(admin_client, url, check_response): ids=["my-viewer", "pk-admin", "pk-viewer"], ) def test_resetmypasswordview_edit(app, request, url, client, content, username): - user = app.appbuilder.sm.find_user(username) + user = app.app.appbuilder.sm.find_user(username) resp = request.getfixturevalue(client).post( url.format(user.id), data={"password": "blah", "conf_password": "blah"}, follow_redirects=True ) @@ -321,13 +321,13 @@ def test_views_post_access_denied(viewer_client, url): @pytest.fixture def non_exist_username(app): username = "fake_username" - user = app.appbuilder.sm.find_user(username) + user = app.app.appbuilder.sm.find_user(username) if user is not None: - app.appbuilder.sm.del_register_user(user) + app.app.appbuilder.sm.del_register_user(user) yield username - user = app.appbuilder.sm.find_user(username) + user = app.app.appbuilder.sm.find_user(username) if user is not None: - app.appbuilder.sm.del_register_user(user) + app.app.appbuilder.sm.del_register_user(user) def test_create_user(app, admin_client, non_exist_username): @@ -345,13 +345,13 @@ def test_create_user(app, admin_client, non_exist_username): follow_redirects=True, ) check_content_in_response("Added Row", resp) - assert app.appbuilder.sm.find_user(non_exist_username) + assert app.app.appbuilder.sm.find_user(non_exist_username) @pytest.fixture def exist_username(app, exist_role): username = "test_edit_user_user" - app.appbuilder.sm.add_user( + app.app.appbuilder.sm.add_user( username, "first_name", "last_name", @@ -360,12 +360,12 @@ def exist_username(app, exist_role): password="password", ) yield username - if app.appbuilder.sm.find_user(username): - app.appbuilder.sm.del_register_user(username) + if app.app.appbuilder.sm.find_user(username): + app.app.appbuilder.sm.del_register_user(username) def test_edit_user(app, admin_client, exist_username): - user = app.appbuilder.sm.find_user(exist_username) + user = app.app.appbuilder.sm.find_user(exist_username) resp = admin_client.post( f"users/edit/{user.id}", data={"first_name": "new_first_name"}, @@ -375,7 +375,7 @@ def test_edit_user(app, admin_client, exist_username): def test_delete_user(app, admin_client, exist_username): - user = app.appbuilder.sm.find_user(exist_username) + user = app.app.appbuilder.sm.find_user(exist_username) resp = admin_client.post( f"users/delete/{user.id}", follow_redirects=True, @@ -419,5 +419,5 @@ def test_page_instance_name_with_markup(admin_client): @conf_vars(instance_name_with_markup_conf) def test_page_instance_name_with_markup_title(): - appbuilder = application.create_app(testing=True).appbuilder + appbuilder = application.create_connexion_app(testing=True).app.appbuilder assert appbuilder.app_name == "Bold Site Title Test" diff --git a/tests/www/views/test_views_blocked.py b/tests/www/views/test_views_blocked.py index c3e8cd4e88cf1..d0b44c77b6eb1 100644 --- a/tests/www/views/test_views_blocked.py +++ b/tests/www/views/test_views_blocked.py @@ -81,7 +81,7 @@ def test_blocked_subdag_success(admin_client, running_subdag): """ resp = admin_client.post("/blocked", data={"dag_ids": [running_subdag.dag_id]}) assert resp.status_code == 200 - assert resp.json == [ + assert resp.json() == [ { "dag_id": running_subdag.dag_id, "active_dag_run": 1, diff --git a/tests/www/views/test_views_cluster_activity.py b/tests/www/views/test_views_cluster_activity.py index 011b7aa071c99..354f4c607d599 100644 --- a/tests/www/views/test_views_cluster_activity.py +++ b/tests/www/views/test_views_cluster_activity.py @@ -96,7 +96,9 @@ def make_dag_runs(dag_maker, session, time_machine): time_machine.move_to("2023-07-02T00:00:00+00:00", tick=False) + session.commit() session.flush() + session.close() @pytest.mark.usefixtures("freeze_time_for_dagruns", "make_dag_runs") @@ -106,7 +108,7 @@ def test_historical_metrics_data(admin_client, session, time_machine): follow_redirects=True, ) assert resp.status_code == 200 - assert resp.json == { + assert resp.json() == { "dag_run_states": {"failed": 1, "queued": 0, "running": 1, "success": 1}, "dag_run_types": {"backfill": 0, "dataset_triggered": 1, "manual": 0, "scheduled": 2}, "task_instance_states": { @@ -135,7 +137,7 @@ def test_historical_metrics_data_date_filters(admin_client, session): follow_redirects=True, ) assert resp.status_code == 200 - assert resp.json == { + assert resp.json() == { "dag_run_states": {"failed": 1, "queued": 0, "running": 0, "success": 0}, "dag_run_types": {"backfill": 0, "dataset_triggered": 1, "manual": 0, "scheduled": 0}, "task_instance_states": { diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py index a209cdfc2be8a..507d2d1e5afaf 100644 --- a/tests/www/views/test_views_connection.py +++ b/tests/www/views/test_views_connection.py @@ -424,7 +424,7 @@ def test_connection_form_widgets_testable_types(mock_pm_hooks, admin_client): assert ["first"] == ConnectionFormWidget().testable_connection_types -def test_process_form_invalid_extra_removed(admin_client): +def test_process_form_invalid_extra_removed(flask_admin_client): """ Test that when an invalid json `extra` is passed in the form, it is removed and _not_ saved over the existing extras. @@ -437,7 +437,7 @@ def test_process_form_invalid_extra_removed(admin_client): session.add(conn) data = {**conn_details, "extra": "Invalid"} - resp = admin_client.post("/connection/edit/1", data=data, follow_redirects=True) + resp = flask_admin_client.post("/connection/edit/1", data=data, follow_redirects=True) assert resp.status_code == 200 with create_session() as session: diff --git a/tests/www/views/test_views_custom_user_views.py b/tests/www/views/test_views_custom_user_views.py index ae6d0132827c2..75a7449ce5b85 100644 --- a/tests/www/views/test_views_custom_user_views.py +++ b/tests/www/views/test_views_custom_user_views.py @@ -28,7 +28,11 @@ from airflow.security import permissions from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user, delete_role -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + client_with_login, +) pytestmark = pytest.mark.db_test @@ -67,23 +71,24 @@ def setup_method(self): # an exception because app context teardown is removed and if even single request is run via app # it cannot be re-intialized again by passing it as constructor to SQLA # This makes the tests slightly slower (but they work with Flask 2.1 and 2.2 - self.app = application.create_app(testing=True) - self.appbuilder = self.app.appbuilder - self.app.config["WTF_CSRF_ENABLED"] = False + self.connexion_app = application.create_connexion_app(testing=True) + self.flask_app = self.connexion_app.app + self.appbuilder = self.flask_app.appbuilder + self.flask_app.config["WTF_CSRF_ENABLED"] = False self.security_manager = self.appbuilder.sm self.delete_roles() - self.db = SQLA(self.app) + self.db = SQLA(self.flask_app) - self.client = self.app.test_client() # type:ignore + self.client = self.connexion_app.test_client() # type:ignore def delete_roles(self): for role_name in ["role_edit_one_dag"]: - delete_role(self.app, role_name) + delete_role(self.flask_app, role_name) @pytest.mark.parametrize("url, _, expected_text", PERMISSIONS_TESTS_PARAMS) def test_user_model_view_with_access(self, url, expected_text, _): user_without_access = create_user( - self.app, + self.flask_app, username="no_access", role_name="role_no_access", permissions=[ @@ -91,7 +96,7 @@ def test_user_model_view_with_access(self, url, expected_text, _): ], ) client = client_with_login( - self.app, + self.connexion_app, username="no_access", password="no_access", ) @@ -101,14 +106,14 @@ def test_user_model_view_with_access(self, url, expected_text, _): @pytest.mark.parametrize("url, permission, expected_text", PERMISSIONS_TESTS_PARAMS) def test_user_model_view_without_access(self, url, permission, expected_text): user_with_access = create_user( - self.app, + self.flask_app, username="has_access", role_name="role_has_access", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), permission], ) client = client_with_login( - self.app, + self.connexion_app, username="has_access", password="has_access", ) @@ -117,22 +122,23 @@ def test_user_model_view_without_access(self, url, permission, expected_text): def test_user_model_view_without_delete_access(self): user_to_delete = create_user( - self.app, + self.flask_app, username="user_to_delete", role_name="user_to_delete", ) create_user( - self.app, + self.flask_app, username="no_access", role_name="role_no_access", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) client = client_with_login( - self.app, + self.connexion_app, username="no_access", password="no_access", ) @@ -140,27 +146,29 @@ def test_user_model_view_without_delete_access(self): response = client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=True) check_content_not_in_response("Deleted Row", response) - assert bool(self.security_manager.get_user_by_id(user_to_delete.id)) is True + response = client.get(f"/users/show/{user_to_delete.id}", follow_redirects=True) + assert response.status_code == 200 def test_user_model_view_with_delete_access(self): user_to_delete = create_user( - self.app, + self.flask_app, username="user_to_delete", role_name="user_to_delete", ) create_user( - self.app, + self.flask_app, username="has_access", role_name="role_has_access", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER), ], ) client = client_with_login( - self.app, + self.connexion_app, username="has_access", password="has_access", ) @@ -168,7 +176,8 @@ def test_user_model_view_with_delete_access(self): response = client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=True) check_content_in_response("Deleted Row", response) check_content_not_in_response(user_to_delete.username, response) - assert bool(self.security_manager.get_user_by_id(user_to_delete.id)) is False + response = client.get(f"/users/show/{user_to_delete.id}", follow_redirects=True) + assert response.status_code == 404 # type: ignore[attr-defined] @@ -184,11 +193,12 @@ def setup_method(self): # an exception because app context teardown is removed and if even single request is run via app # it cannot be re-intialized again by passing it as constructor to SQLA # This makes the tests slightly slower (but they work with Flask 2.1 and 2.2 - self.app = application.create_app(testing=True) - self.appbuilder = self.app.appbuilder - self.app.config["WTF_CSRF_ENABLED"] = False + self.connexion_app = application.create_connexion_app(testing=True) + self.flask_app = self.connexion_app.app + self.appbuilder = self.flask_app.appbuilder + self.flask_app.config["WTF_CSRF_ENABLED"] = False self.security_manager = self.appbuilder.sm - self.interface = self.app.session_interface + self.interface = self.flask_app.session_interface self.model = self.interface.sql_session_model self.serializer = self.interface.serializer self.db = self.interface.db @@ -196,12 +206,12 @@ def setup_method(self): self.db.session.commit() self.db.session.flush() self.user_1 = create_user( - self.app, + self.flask_app, username="user_to_delete_1", role_name="user_to_delete", ) self.user_2 = create_user( - self.app, + self.flask_app, username="user_to_delete_2", role_name="user_to_delete", ) @@ -277,7 +287,7 @@ def test_refuse_delete(self, _mock_has_context, flash_mock): "airflow.providers.fab.auth_manager.security_manager.override.has_request_context", return_value=True ) def test_warn_securecookie(self, _mock_has_context, flash_mock): - self.app.session_interface = SecureCookieSessionInterface() + self.flask_app.session_interface = SecureCookieSessionInterface() self.security_manager.reset_password(self.user_1.id, "new_password") assert flash_mock.called assert ( @@ -309,7 +319,7 @@ def test_refuse_delete_cli(self, log_mock): @mock.patch("airflow.providers.fab.auth_manager.security_manager.override.log") def test_warn_securecookie_cli(self, log_mock): - self.app.session_interface = SecureCookieSessionInterface() + self.flask_app.session_interface = SecureCookieSessionInterface() self.security_manager.reset_password(self.user_1.id, "new_password") assert log_mock.warning.called assert ( diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index b7e048e0eaf21..705b6ff3d7ec6 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -25,22 +25,25 @@ from airflow.utils.session import create_session from airflow.www.views import DagRunModelView from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + flask_client_with_login, +) from tests.www.views.test_views_tasks import _get_appbuilder_pk_string pytestmark = pytest.mark.db_test @pytest.fixture(scope="module") -def client_dr_without_dag_edit(app): +def flask_client_dr_without_dag_run_create(app): create_user( - app, - username="all_dr_permissions_except_dag_edit", - role_name="all_dr_permissions_except_dag_edit", + app.app, + username="all_dr_permissions_except_dag_run_create", + role_name="all_dr_permissions_except_dag_run_create", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), @@ -48,25 +51,26 @@ def client_dr_without_dag_edit(app): ], ) - yield client_with_login( + yield flask_client_with_login( app, - username="all_dr_permissions_except_dag_edit", - password="all_dr_permissions_except_dag_edit", + username="all_dr_permissions_except_dag_run_create", + password="all_dr_permissions_except_dag_run_create", ) - delete_user(app, username="all_dr_permissions_except_dag_edit") # type: ignore - delete_roles(app) + delete_user(app.app, username="all_dr_permissions_except_dag_run_create") # type: ignore + delete_roles(app.app) @pytest.fixture(scope="module") -def client_dr_without_dag_run_create(app): +def flask_client_dr_without_dag_edit(app): create_user( - app, - username="all_dr_permissions_except_dag_run_create", - role_name="all_dr_permissions_except_dag_run_create", + app.app, + username="all_dr_permissions_except_dag_edit", + role_name="all_dr_permissions_except_dag_edit", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), @@ -74,14 +78,14 @@ def client_dr_without_dag_run_create(app): ], ) - yield client_with_login( + yield flask_client_with_login( app, - username="all_dr_permissions_except_dag_run_create", - password="all_dr_permissions_except_dag_run_create", + username="all_dr_permissions_except_dag_edit", + password="all_dr_permissions_except_dag_edit", ) - delete_user(app, username="all_dr_permissions_except_dag_run_create") # type: ignore - delete_roles(app) + delete_user(app.app, username="all_dr_permissions_except_dag_edit") # type: ignore + delete_roles(app.app) @pytest.fixture(scope="module", autouse=True) @@ -103,14 +107,16 @@ def reset_dagrun(): session.query(TaskInstance).delete() -def test_get_dagrun_can_view_dags_without_edit_perms(session, running_dag_run, client_dr_without_dag_edit): +def test_get_dagrun_can_view_dags_without_edit_perms( + session, running_dag_run, flask_client_dr_without_dag_edit +): """Test that a user without dag_edit but with dag_read permission can view the records""" assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 - resp = client_dr_without_dag_edit.get("/dagrun/list/", follow_redirects=True) + resp = flask_client_dr_without_dag_edit.get("/dagrun/list/", follow_redirects=True) check_content_in_response(running_dag_run.dag_id, resp) -def test_create_dagrun_permission_denied(session, client_dr_without_dag_run_create): +def test_create_dagrun_permission_denied(session, flask_client_dr_without_dag_run_create): data = { "state": "running", "dag_id": "example_bash_operator", @@ -119,7 +125,7 @@ def test_create_dagrun_permission_denied(session, client_dr_without_dag_run_crea "conf": '{"include": "me"}', } - resp = client_dr_without_dag_run_create.post("/dagrun/add", data=data, follow_redirects=True) + resp = flask_client_dr_without_dag_run_create.post("/dagrun/add", data=data, follow_redirects=True) check_content_in_response("Access is Denied", resp) @@ -169,18 +175,18 @@ def completed_dag_run_with_missing_task(session): return dag, dr -def test_delete_dagrun(session, admin_client, running_dag_run): +def test_delete_dagrun(session, flask_admin_client, running_dag_run): composite_key = _get_appbuilder_pk_string(DagRunModelView, running_dag_run) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 - admin_client.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) + flask_admin_client.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 0 -def test_delete_dagrun_permission_denied(session, running_dag_run, client_dr_without_dag_edit): +def test_delete_dagrun_permission_denied(session, running_dag_run, flask_client_dr_without_dag_edit): composite_key = _get_appbuilder_pk_string(DagRunModelView, running_dag_run) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 - resp = client_dr_without_dag_edit.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) + resp = flask_client_dr_without_dag_edit.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) check_content_in_response("Access is Denied", resp) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 @@ -218,13 +224,13 @@ def test_delete_dagrun_permission_denied(session, running_dag_run, client_dr_wit ) def test_set_dag_runs_action( session, - admin_client, + flask_admin_client, running_dag_run, action, expected_ti_states, expected_message, ): - resp = admin_client.post( + resp = flask_admin_client.post( "/dagrun/action_post", data={"action": action, "rowid": [running_dag_run.id]}, follow_redirects=True, @@ -244,8 +250,8 @@ def test_set_dag_runs_action( ], ids=["clear", "success", "failed", "running", "queued"], ) -def test_set_dag_runs_action_fails(admin_client, action, expected_message): - resp = admin_client.post( +def test_set_dag_runs_action_fails(flask_admin_client, action, expected_message): + resp = flask_admin_client.post( "/dagrun/action_post", data={"action": action, "rowid": ["0"]}, follow_redirects=True, @@ -253,9 +259,9 @@ def test_set_dag_runs_action_fails(admin_client, action, expected_message): check_content_in_response(expected_message, resp) -def test_muldelete_dag_runs_action(session, admin_client, running_dag_run): +def test_muldelete_dag_runs_action(session, flask_admin_client, running_dag_run): dag_run_id = running_dag_run.id - resp = admin_client.post( + resp = flask_admin_client.post( "/dagrun/action_post", data={"action": "muldelete", "rowid": [dag_run_id]}, follow_redirects=True, @@ -270,9 +276,9 @@ def test_muldelete_dag_runs_action(session, admin_client, running_dag_run): ["clear", "set_success", "set_failed", "set_running"], ids=["clear", "success", "failed", "running"], ) -def test_set_dag_runs_action_permission_denied(client_dr_without_dag_edit, running_dag_run, action): +def test_set_dag_runs_action_permission_denied(flask_client_dr_without_dag_edit, running_dag_run, action): running_dag_id = running_dag_run.id - resp = client_dr_without_dag_edit.post( + resp = flask_client_dr_without_dag_edit.post( "/dagrun/action_post", data={"action": action, "rowid": [str(running_dag_id)]}, follow_redirects=True, @@ -280,9 +286,9 @@ def test_set_dag_runs_action_permission_denied(client_dr_without_dag_edit, runni check_content_in_response("Access is Denied", resp) -def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_run_with_missing_task): +def test_dag_runs_queue_new_tasks_action(session, flask_admin_client, completed_dag_run_with_missing_task): dag, dag_run = completed_dag_run_with_missing_task - resp = admin_client.post( + resp = flask_admin_client.post( "/dagrun_queued", data={"dag_id": dag.dag_id, "dag_run_id": dag_run.run_id, "confirmed": False}, ) diff --git a/tests/www/views/test_views_dataset.py b/tests/www/views/test_views_dataset.py index d67ed80f385e5..01771bf0a97a2 100644 --- a/tests/www/views/test_views_dataset.py +++ b/tests/www/views/test_views_dataset.py @@ -55,7 +55,7 @@ def test_should_respond_200(self, admin_client, session): response = admin_client.get("/object/datasets_summary") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -89,7 +89,7 @@ def test_order_by_raises_400_for_invalid_attr(self, admin_client, session): assert response.status_code == 400 msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_order_by_raises_400_for_invalid_datetimes(self, admin_client, session): datasets = [ @@ -139,15 +139,15 @@ def test_filter_by_datetimes(self, admin_client, session): response = admin_client.get(f"/object/datasets_summary?updated_after={cutoff}") assert response.status_code == 200 - assert response.json["total_entries"] == 2 - assert [json_dict["id"] for json_dict in response.json["datasets"]] == [2, 3] + assert response.json()["total_entries"] == 2 + assert [json_dict["id"] for json_dict in response.json()["datasets"]] == [2, 3] cutoff = today.add(days=-1).add(minutes=5).to_iso8601_string() response = admin_client.get(f"/object/datasets_summary?updated_before={cutoff}") assert response.status_code == 200 - assert response.json["total_entries"] == 2 - assert [json_dict["id"] for json_dict in response.json["datasets"]] == [1, 2] + assert response.json()["total_entries"] == 2 + assert [json_dict["id"] for json_dict in response.json()["datasets"]] == [1, 2] @pytest.mark.parametrize( "order_by, ordered_dataset_ids", @@ -188,8 +188,8 @@ def test_order_by(self, admin_client, session, order_by, ordered_dataset_ids): response = admin_client.get(f"/object/datasets_summary?order_by={order_by}") assert response.status_code == 200 - assert ordered_dataset_ids == [json_dict["id"] for json_dict in response.json["datasets"]] - assert response.json["total_entries"] == len(ordered_dataset_ids) + assert ordered_dataset_ids == [json_dict["id"] for json_dict in response.json()["datasets"]] + assert response.json()["total_entries"] == len(ordered_dataset_ids) def test_search_uri_pattern(self, admin_client, session): datasets = [ @@ -207,7 +207,7 @@ def test_search_uri_pattern(self, admin_client, session): response = admin_client.get(f"/object/datasets_summary?uri_pattern={uri_pattern}") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -224,7 +224,7 @@ def test_search_uri_pattern(self, admin_client, session): response = admin_client.get(f"/object/datasets_summary?uri_pattern={uri_pattern}") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -289,7 +289,7 @@ def test_correct_counts_update(self, admin_client, session, dag_maker, app, monk ): EmptyOperator(task_id="task1", outlets=[datasets[4]]) - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() @@ -342,7 +342,7 @@ def test_correct_counts_update(self, admin_client, session, dag_maker, app, monk response = admin_client.get("/object/datasets_summary") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -408,7 +408,7 @@ def test_limit_and_offset(self, admin_client, session, url, expected_dataset_uri response = admin_client.get(url) assert response.status_code == 200 - dataset_uris = [dataset["uri"] for dataset in response.json["datasets"]] + dataset_uris = [dataset["uri"] for dataset in response.json()["datasets"]] assert dataset_uris == expected_dataset_uris def test_should_respect_page_size_limit_default(self, admin_client, session): @@ -425,7 +425,7 @@ def test_should_respect_page_size_limit_default(self, admin_client, session): response = admin_client.get("/object/datasets_summary") assert response.status_code == 200 - assert len(response.json["datasets"]) == 25 + assert len(response.json()["datasets"]) == 25 def test_should_return_max_if_req_above(self, admin_client, session): datasets = [ @@ -441,15 +441,19 @@ def test_should_return_max_if_req_above(self, admin_client, session): response = admin_client.get("/object/datasets_summary?limit=180") assert response.status_code == 200 - assert len(response.json["datasets"]) == 50 + assert len(response.json()["datasets"]) == 50 class TestGetDatasetNextRunSummary(TestDatasetEndpoint): - def test_next_run_dataset_summary(self, dag_maker, admin_client): - with dag_maker(dag_id="upstream", schedule=[Dataset(uri="s3://bucket/key/1")], serialized=True): + def test_next_run_dataset_summary(self, dag_maker, admin_client, session): + with dag_maker( + dag_id="upstream", schedule=[Dataset(uri="s3://bucket/key/1")], serialized=True, session=session + ): EmptyOperator(task_id="task1") + session.commit() + session.close() response = admin_client.post("/next_run_datasets_summary", data={"dag_ids": ["upstream"]}) assert response.status_code == 200 - assert response.json == {"upstream": {"ready": 0, "total": 1, "uri": "s3://bucket/key/1"}} + assert response.json() == {"upstream": {"ready": 0, "total": 1, "uri": "s3://bucket/key/1"}} diff --git a/tests/www/views/test_views_extra_links.py b/tests/www/views/test_views_extra_links.py index d5b70caba586d..50692038f199a 100644 --- a/tests/www/views/test_views_extra_links.py +++ b/tests/www/views/test_views_extra_links.py @@ -79,13 +79,17 @@ def dag(): @pytest.fixture(scope="module") def create_dag_run(dag): def _create_dag_run(*, execution_date, session): - return dag.create_dagrun( - state=DagRunState.RUNNING, - execution_date=execution_date, - data_interval=(execution_date, execution_date), - run_type=DagRunType.MANUAL, - session=session, - ) + try: + return dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=execution_date, + data_interval=(execution_date, execution_date), + run_type=DagRunType.MANUAL, + session=session, + ) + finally: + session.commit() + session.close() return _create_dag_run @@ -97,7 +101,7 @@ def dag_run(create_dag_run, session): @pytest.fixture(scope="module", autouse=True) def patched_app(app, dag): - with mock.patch.object(app, "dag_bag") as mock_dag_bag: + with mock.patch.object(app.app, "dag_bag") as mock_dag_bag: mock_dag_bag.get_dag.return_value = dag yield @@ -140,7 +144,7 @@ def test_extra_links_works(dag_run, task_1, viewer_client, session): ) assert response.status_code == 200 - assert json.loads(response.data.decode()) == { + assert json.loads(response.text) == { "url": "http://www.example.com/some_dummy_task/foo-bar/manual__2017-01-01T00:00:00+00:00", "error": None, } @@ -154,7 +158,7 @@ def test_global_extra_links_works(dag_run, task_1, viewer_client, session): ) assert response.status_code == 200 - assert json.loads(response.data.decode()) == { + assert json.loads(response.text) == { "url": "https://github.com/apache/airflow", "error": None, } @@ -168,10 +172,7 @@ def test_operator_extra_link_override_global_extra_link(dag_run, task_1, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org", "error": None} + assert json.loads(response.text) == {"url": "https://airflow.apache.org", "error": None} def test_extra_links_error_raised(dag_run, task_1, viewer_client): @@ -182,10 +183,7 @@ def test_extra_links_error_raised(dag_run, task_1, viewer_client): ) assert 404 == response.status_code - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": None, "error": "This is an error"} + assert json.loads(response.text) == {"url": None, "error": "This is an error"} def test_extra_links_no_response(dag_run, task_1, viewer_client): @@ -196,10 +194,7 @@ def test_extra_links_no_response(dag_run, task_1, viewer_client): ) assert response.status_code == 404 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": None, "error": "No URL found for no_response"} + assert json.loads(response.text) == {"url": None, "error": "No URL found for no_response"} def test_operator_extra_link_override_plugin(dag_run, task_2, viewer_client): @@ -217,10 +212,8 @@ def test_operator_extra_link_override_plugin(dag_run, task_2, viewer_client): ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + + assert json.loads(response.text) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_client): @@ -239,10 +232,8 @@ def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + + assert json.loads(response.text) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} response = viewer_client.get( f"{ENDPOINT}?dag_id={task_3.dag_id}&task_id={task_3.task_id}" @@ -251,10 +242,7 @@ def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + assert json.loads(response.text) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} # Also check that the other Operator Link defined for this operator exists response = viewer_client.get( @@ -264,7 +252,4 @@ def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://www.google.com", "error": None} + assert json.loads(response.text) == {"url": "https://www.google.com", "error": None} diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 3d13dea4d1248..42264f39552ff 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -81,7 +81,7 @@ def mapped_task_group(arg1): with TaskGroup(group_id="group"): MockOperator.partial(task_id="mapped").expand(arg1=["a", "b", "c", "d"]) - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) yield dag_maker @@ -96,14 +96,15 @@ def dag_with_runs(dag_without_runs): run_type=DagRunType.SCHEDULED, execution_date=date + timedelta(days=1), ) - return run_1, run_2 -def test_no_runs(admin_client, dag_without_runs): +def test_no_runs(admin_client, dag_without_runs, session): + session.commit() + session.close() resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) - assert resp.status_code == 200, resp.json - assert resp.json == { + assert resp.status_code == 200, resp.json() + assert resp.json() == { "dag_runs": [], "groups": { "children": [ @@ -163,7 +164,9 @@ def test_no_runs(admin_client, dag_without_runs): } -def test_grid_data_filtered_on_run_type_and_run_state(admin_client, dag_with_runs): +def test_grid_data_filtered_on_run_type_and_run_state(admin_client, dag_with_runs, session): + session.commit() + session.close() for uri_params, expected_run_types, expected_run_states in [ ("run_state=success&run_state=queued", ["scheduled"], ["success"]), ("run_state=running&run_state=failed", ["scheduled"], ["running"]), @@ -177,9 +180,9 @@ def test_grid_data_filtered_on_run_type_and_run_state(admin_client, dag_with_run ), ]: resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}&{uri_params}", follow_redirects=True) - assert resp.status_code == 200, resp.json - actual_run_types = list(map(lambda x: x["run_type"], resp.json["dag_runs"])) - actual_run_states = list(map(lambda x: x["state"], resp.json["dag_runs"])) + assert resp.status_code == 200, resp.json() + actual_run_types = list(map(lambda x: x["run_type"], resp.json()["dag_runs"])) + actual_run_states = list(map(lambda x: x["state"], resp.json()["dag_runs"])) assert actual_run_types == expected_run_types assert actual_run_states == expected_run_states @@ -199,7 +202,6 @@ def test_one_run(admin_client, dag_with_runs: list[DagRun], session): - One TI not yet finished """ run1, run2 = dag_with_runs - for ti in run1.task_instances: ti.state = TaskInstanceState.SUCCESS for ti in sorted(run2.task_instances, key=lambda ti: (ti.task_id, ti.map_index)): @@ -214,14 +216,14 @@ def test_one_run(admin_client, dag_with_runs: list[DagRun], session): ti.state = TaskInstanceState.RUNNING ti.start_date = pendulum.DateTime(2021, 7, 1, 2, 3, 4, tzinfo=pendulum.UTC) ti.end_date = None - + session.commit() session.flush() - + session.close() resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) - assert resp.status_code == 200, resp.json + assert resp.status_code == 200, resp.json() - assert resp.json == { + assert resp.json() == { "dag_runs": [ { "conf": None, @@ -429,7 +431,9 @@ def test_has_outlet_dataset_flag(admin_client, dag_maker, session, app, monkeypa EmptyOperator(task_id="task3", outlets=[Dataset("foo"), lineagefile]) EmptyOperator(task_id="task4", outlets=[Dataset("foo")]) - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + session.commit() + session.close() resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) def _expected_task_details(task_id, has_outlet_datasets): @@ -444,8 +448,8 @@ def _expected_task_details(task_id, has_outlet_datasets): "trigger_rule": "all_success", } - assert resp.status_code == 200, resp.json - assert resp.json == { + assert resp.status_code == 200, resp.json() + assert resp.json() == { "dag_runs": [], "groups": { "children": [ @@ -470,7 +474,7 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): with dag_maker(dag_id=DAG_ID, schedule=datasets, serialized=True, session=session): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() @@ -500,8 +504,8 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): resp = admin_client.get(f"/object/next_run_datasets/{DAG_ID}", follow_redirects=True) - assert resp.status_code == 200, resp.json - assert resp.json == { + assert resp.status_code == 200, resp.json() + assert resp.json() == { "dataset_expression": {"all": ["s3://bucket/key/1", "s3://bucket/key/2"]}, "events": [ {"id": ds1_id, "uri": "s3://bucket/key/1", "lastUpdate": "2022-08-02T02:00:00+00:00"}, @@ -512,5 +516,5 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): def test_next_run_datasets_404(admin_client): resp = admin_client.get("/object/next_run_datasets/missingdag", follow_redirects=True) - assert resp.status_code == 404, resp.json - assert resp.json == {"error": "can't find dag missingdag"} + assert resp.status_code == 404, resp.json() + assert resp.json() == {"error": "can't find dag missingdag"} diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index cffe0844b012e..e73aa5a5e8c2f 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -85,13 +85,15 @@ def call_kwargs(): update_stmt = update(DagModel).where(DagModel.dag_id == "filter_test_1").values(is_active=False) session.execute(update_stmt) + session.commit() + session.close() admin_client.get("home", follow_redirects=True) assert call_kwargs()["status_count_all"] == 3 -def test_home_status_filter_cookie(admin_client): - with admin_client: +def test_home_status_filter_cookie(admin_flask_client): + with admin_flask_client as admin_client: admin_client.get("home", follow_redirects=True) assert "all" == flask.session[FILTER_STATUS_COOKIE] @@ -118,7 +120,7 @@ def test_home_status_filter_cookie(admin_client): def user_no_importerror(app): """Create User that cannot access Import Errors""" return create_user( - app, + app.app, username="user_no_importerrors", role_name="role_no_importerrors", permissions=[ @@ -142,7 +144,7 @@ def client_no_importerror(app, user_no_importerror): def user_single_dag(app): """Create User that can only access the first DAG from TEST_FILTER_DAG_IDS""" return create_user( - app, + app.app, username="user_single_dag", role_name="role_single_dag", permissions=[ @@ -167,7 +169,7 @@ def client_single_dag(app, user_single_dag): def user_single_dag_edit(app): """Create User that can edit DAG resource only a single DAG""" return create_user( - app, + app.app, username="user_single_dag_edit", role_name="role_single_dag", permissions=[ @@ -278,8 +280,8 @@ def broken_dags_after_working(tmp_path): _process_file(path, session) -def test_home_filter_tags(working_dags, admin_client): - with admin_client: +def test_home_filter_tags(working_dags, admin_flask_client): + with admin_flask_client as admin_client: admin_client.get("home?tags=example&tags=data", follow_redirects=True) assert "example,data" == flask.session[FILTER_TAGS_COOKIE] @@ -451,7 +453,7 @@ def test_dashboard_flash_messages_type(user_client): ) def test_sorting_home_view(url, lower_key, greater_key, user_client, working_dags): resp = user_client.get(url, follow_redirects=True) - resp_html = resp.data.decode("utf-8") + resp_html = resp.text lower_index = resp_html.find(lower_key) greater_index = resp_html.find(greater_key) assert lower_index < greater_index diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index 2607317c5fccc..80b698b35e275 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -39,11 +39,11 @@ from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunType -from airflow.www.app import create_app +from airflow.www.app import create_connexion_app from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs from tests.test_utils.decorators import dont_initialize_flask_app_submodules -from tests.test_utils.www import client_with_login +from tests.test_utils.www import client_with_login, flask_client_with_login pytestmark = pytest.mark.db_test @@ -83,10 +83,10 @@ def log_app(backup_modules, log_path): } ) def factory(): - app = create_app(testing=True) - app.config["WTF_CSRF_ENABLED"] = False + app = create_connexion_app(testing=True) + app.app.config["WTF_CSRF_ENABLED"] = False settings.configure_orm() - security_manager = app.appbuilder.sm + security_manager = app.app.appbuilder.sm if not security_manager.find_user(username="test"): security_manager.add_user( username="test", @@ -142,7 +142,7 @@ def dags(log_app, create_dummy_dag, session): bag.bag_dag(dag=dag, root_dag=dag) bag.bag_dag(dag=dag_removed, root_dag=dag_removed) bag.sync_to_db(session=session) - log_app.dag_bag = bag + log_app.app.dag_bag = bag yield dag, dag_removed @@ -174,6 +174,9 @@ def tis(dags, session): (ti_removed_dag,) = dagrun_removed.task_instances ti_removed_dag.try_number = 1 + session.commit() + session.close() + yield ti, ti_removed_dag clear_db_runs() @@ -198,6 +201,11 @@ def create_expected_log_file(try_number): shutil.rmtree(sub_path) +@pytest.fixture +def flask_log_admin_client(log_app): + return flask_client_with_login(log_app, username="test", password="test") + + @pytest.fixture def log_admin_client(log_app): return client_with_login(log_app, username="test", password="test") @@ -233,12 +241,12 @@ def test_get_file_task_log(log_admin_client, tis, state, try_number, num_logs): response = log_admin_client.get( ENDPOINT, - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert response.status_code == 200 - data = response.data.decode() + data = response.text assert "Log by attempts" in data for num in range(1, num_logs + 1): assert f"log-group-{num}" in data @@ -271,7 +279,7 @@ def test_get_logs_with_metadata_as_download_file(log_admin_client, create_expect in content_disposition ) assert 200 == response.status_code - content = response.data.decode("utf-8") + content = response.text assert "Log for testing." in content assert "localhost\n" in content @@ -314,7 +322,7 @@ def test_get_logs_for_changed_filename_format_db( # Should find the log under corresponding db entry. assert 200 == response.status_code - assert "Log for testing." in response.data.decode("utf-8") + assert "Log for testing." in response.text content_disposition = response.headers["Content-Disposition"] expected_filename = ( f"{dag_run_with_log_filename.dag_id}/{dag_run_with_log_filename.run_id}/{TASK_ID}/{try_number}.log" @@ -348,7 +356,7 @@ def test_get_logs_with_metadata_as_download_large_file(_, log_admin_client): ) response = log_admin_client.get(url) - data = response.data.decode() + data = response.text assert "1st line" in data assert "2nd line" in data assert "3rd line" in data @@ -368,12 +376,12 @@ def test_get_logs_with_metadata(log_admin_client, metadata, create_expected_log_ try_number, metadata, ), - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert 200 == response.status_code - data = response.data.decode() + data = response.text assert '"message":' in data assert '"metadata":' in data assert "Log for testing." in data @@ -391,12 +399,12 @@ def test_get_logs_with_invalid_metadata(log_admin_client): 1, metadata, ), - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert response.status_code == 400 - assert response.json == {"error": "Invalid JSON metadata"} + assert response.json() == {"error": "Invalid JSON metadata"} @unittest.mock.patch( @@ -413,12 +421,12 @@ def test_get_logs_with_metadata_for_removed_dag(_, log_admin_client): 1, "{}", ), - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert 200 == response.status_code - data = response.data.decode() + data = response.text assert '"message":' in data assert '"metadata":' in data assert "airflow log line" in data @@ -440,7 +448,7 @@ def test_get_logs_response_with_ti_equal_to_none(log_admin_client): ) response = log_admin_client.get(url) - data = response.json + data = response.json() assert "message" in data assert "error" in data assert "*** Task instance did not exist in the DB\n" == data["message"] @@ -464,9 +472,9 @@ def test_get_logs_with_json_response_format(log_admin_client, create_expected_lo response = log_admin_client.get(url) assert 200 == response.status_code - assert "message" in response.json - assert "metadata" in response.json - assert "Log for testing." in response.json["message"][0][1] + assert "message" in response.json() + assert "metadata" in response.json() + assert "Log for testing." in response.json()["message"][0][1] def test_get_logs_invalid_execution_data_format(log_admin_client): @@ -485,7 +493,7 @@ def test_get_logs_invalid_execution_data_format(log_admin_client): ) response = log_admin_client.get(url) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "error": ( "Given execution date 'Tuesday February 27, 2024' could not be identified as a date. " "Example date format: 2015-11-16T14:34:15+00:00" @@ -512,7 +520,7 @@ def test_get_logs_for_handler_without_read_method(mock_reader, log_admin_client) response = log_admin_client.get(url) assert 200 == response.status_code - data = response.json + data = response.json() assert "message" in data assert "metadata" in data assert "Task log handler does not support read logs." in data["message"] @@ -530,8 +538,8 @@ def test_redirect_to_external_log_with_local_log_handler(log_admin_client, task_ try_number, ) response = log_admin_client.get(url) - assert 302 == response.status_code - assert "/home" == response.headers["Location"] + assert 200 == response.status_code + assert "/home" == response.url.path class _ExternalHandler(ExternalLoggingMixin): @@ -554,7 +562,7 @@ def supports_external_link(self) -> bool: new_callable=unittest.mock.PropertyMock, return_value=_ExternalHandler(), ) -def test_redirect_to_external_log_with_external_log_handler(_, log_admin_client): +def test_redirect_to_external_log_with_external_log_handler(_, flask_log_admin_client): url_template = "redirect_to_external_log?dag_id={}&task_id={}&execution_date={}&try_number={}" try_number = 1 url = url_template.format( @@ -563,6 +571,6 @@ def test_redirect_to_external_log_with_external_log_handler(_, log_admin_client) urllib.parse.quote_plus(DEFAULT_DATE.isoformat()), try_number, ) - response = log_admin_client.get(url) + response = flask_log_admin_client.get(url) assert 302 == response.status_code assert _ExternalHandler.EXTERNAL_URL == response.headers["Location"] diff --git a/tests/www/views/test_views_mount.py b/tests/www/views/test_views_mount.py index f0c052294b60a..87f72efec6acb 100644 --- a/tests/www/views/test_views_mount.py +++ b/tests/www/views/test_views_mount.py @@ -21,7 +21,7 @@ import werkzeug.test import werkzeug.wrappers -from airflow.www.app import create_app +from airflow.www.app import create_connexion_app from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -31,16 +31,16 @@ def app(): @conf_vars({("webserver", "base_url"): "http://localhost/test"}) def factory(): - return create_app(testing=True) + return create_connexion_app(testing=True) app = factory() - app.config["WTF_CSRF_ENABLED"] = False + app.app.config["WTF_CSRF_ENABLED"] = False return app @pytest.fixture def client(app): - return werkzeug.test.Client(app, werkzeug.wrappers.response.Response) + return werkzeug.test.Client(app.app, werkzeug.wrappers.response.Response) def test_mount(client): diff --git a/tests/www/views/test_views_paused.py b/tests/www/views/test_views_paused.py index 46b0a3aa03f1a..e54fe0ad253cc 100644 --- a/tests/www/views/test_views_paused.py +++ b/tests/www/views/test_views_paused.py @@ -34,17 +34,17 @@ def dags(create_dummy_dag): clear_db_dags() -def test_logging_pause_dag(admin_client, dags, session): +def test_logging_pause_dag(flask_admin_client, dags, session): dag, _ = dags # is_paused=false mean pause the dag - admin_client.post(f"/paused?is_paused=false&dag_id={dag.dag_id}", follow_redirects=True) + flask_admin_client.post(f"/paused?is_paused=false&dag_id={dag.dag_id}", follow_redirects=True) dag_query = session.query(Log).filter(Log.dag_id == dag.dag_id) assert '{"is_paused": true}' in dag_query.first().extra -def test_logging_unpause_dag(admin_client, dags, session): +def test_logging_unpause_dag(flask_admin_client, dags, session): _, paused_dag = dags # is_paused=true mean unpause the dag - admin_client.post(f"/paused?is_paused=true&dag_id={paused_dag.dag_id}", follow_redirects=True) + flask_admin_client.post(f"/paused?is_paused=true&dag_id={paused_dag.dag_id}", follow_redirects=True) dag_query = session.query(Log).filter(Log.dag_id == paused_dag.dag_id) assert '{"is_paused": false}' in dag_query.first().extra diff --git a/tests/www/views/test_views_pool.py b/tests/www/views/test_views_pool.py index 3fcacbbbf8bed..4b38c5f32ac9e 100644 --- a/tests/www/views/test_views_pool.py +++ b/tests/www/views/test_views_pool.py @@ -83,7 +83,7 @@ def test_list(app, admin_client, pool_factory): resp = admin_client.get("/pool/list/") # We should see this link - with app.test_request_context(): + with app.app.test_request_context(): description_tag = markupsafe.Markup("{description}").format( description="test-pool-description" ) diff --git a/tests/www/views/test_views_rate_limit.py b/tests/www/views/test_views_rate_limit.py index fa4502a275315..032b0ffda7f27 100644 --- a/tests/www/views/test_views_rate_limit.py +++ b/tests/www/views/test_views_rate_limit.py @@ -19,10 +19,10 @@ import pytest -from airflow.www.app import create_app +from airflow.www.app import create_connexion_app from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules -from tests.test_utils.www import client_with_login +from tests.test_utils.www import client_with_login, flask_client_with_login pytestmark = pytest.mark.db_test @@ -44,23 +44,25 @@ def app_with_rate_limit_one(examples_dag_bag): ) def factory(): with conf_vars({("fab", "auth_rate_limited"): "True", ("fab", "auth_rate_limit"): "1 per 20 second"}): - return create_app(testing=True) + return create_connexion_app(testing=True) app = factory() - app.config["WTF_CSRF_ENABLED"] = False + app.app.config["WTF_CSRF_ENABLED"] = False return app def test_rate_limit_one(app_with_rate_limit_one): - client_with_login( + flask_client_with_login( app_with_rate_limit_one, expected_response_code=302, username="test_admin", password="test_admin" ) - client_with_login( - app_with_rate_limit_one, expected_response_code=429, username="test_admin", password="test_admin" - ) - client_with_login( - app_with_rate_limit_one, expected_response_code=429, username="test_admin", password="test_admin" - ) + from starlette.exceptions import HTTPException + + with pytest.raises(HTTPException) as ex: + flask_client_with_login(app_with_rate_limit_one, username="test_admin", password="test_admin") + assert ex.value.status_code == 429 + with pytest.raises(HTTPException) as ex: + flask_client_with_login(app_with_rate_limit_one, username="test_admin", password="test_admin") + assert ex.value.status_code == 429 def test_rate_limit_disabled(app): diff --git a/tests/www/views/test_views_rendered.py b/tests/www/views/test_views_rendered.py index 842f1010138d4..3d26cb9f9bf03 100644 --- a/tests/www/views/test_views_rendered.py +++ b/tests/www/views/test_views_rendered.py @@ -161,7 +161,7 @@ def _create_dag_run(*, execution_date, session): @pytest.fixture def patch_app(app, dag): - with mock.patch.object(app, "dag_bag") as mock_dag_bag: + with mock.patch.object(app.app, "dag_bag") as mock_dag_bag: mock_dag_bag.get_dag.return_value = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) yield app @@ -215,7 +215,7 @@ def test_user_defined_filter_and_macros_raise_error(admin_client, create_dag_run resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - resp_html: str = resp.data.decode("utf-8") + resp_html: str = resp.text assert "echo Hello Apache Airflow" not in resp_html assert ( "Webserver does not have access to User-defined Macros or Filters when " @@ -323,7 +323,7 @@ def test_rendered_task_detail_env_secret(patch_app, admin_client, request, env, Variable.set("plain_var", "banana") Variable.set("secret_var", "monkey") - dag: DAG = patch_app.dag_bag.get_dag("testdag") + dag: DAG = patch_app.app.dag_bag.get_dag("testdag") task_secret: BashOperator = dag.get_task(task_id="task1") task_secret.env = env date = quote_plus(str(DEFAULT_DATE)) diff --git a/tests/www/views/test_views_robots.py b/tests/www/views/test_views_robots.py index 03d8547c04d4b..319fba3a7efcd 100644 --- a/tests/www/views/test_views_robots.py +++ b/tests/www/views/test_views_robots.py @@ -25,16 +25,16 @@ def test_robots(viewer_client): resp = viewer_client.get("/robots.txt", follow_redirects=True) - assert resp.data.decode("utf-8") == "User-agent: *\nDisallow: /\n" + assert resp.text == "User-agent: *\nDisallow: /\n" def test_deployment_warning_config(admin_client): warn_text = "webserver.warn_deployment_exposure" admin_client.get("/robots.txt", follow_redirects=True) resp = admin_client.get("", follow_redirects=True) - assert warn_text in resp.data.decode("utf-8") + assert warn_text in resp.text with conf_vars({("webserver", "warn_deployment_exposure"): "False"}): admin_client.get("/robots.txt", follow_redirects=True) resp = admin_client.get("/robots.txt", follow_redirects=True) - assert warn_text not in resp.data.decode("utf-8") + assert warn_text not in resp.text diff --git a/tests/www/views/test_views_task_norun.py b/tests/www/views/test_views_task_norun.py index a0709c4303d99..7001f141bb11d 100644 --- a/tests/www/views/test_views_task_norun.py +++ b/tests/www/views/test_views_task_norun.py @@ -41,7 +41,7 @@ def test_task_view_no_task_instance(admin_client): url = f"/task?task_id=runme_0&dag_id=example_bash_operator&execution_date={DEFAULT_VAL}" resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - html = resp.data.decode("utf-8") + html = resp.text assert "
No Task Instance Available
" in html assert "
Task Instance Attributes
" not in html @@ -50,5 +50,5 @@ def test_rendered_templates_view_no_task_instance(admin_client): url = f"/rendered-templates?task_id=runme_0&dag_id=example_bash_operator&execution_date={DEFAULT_VAL}" resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - html = resp.data.decode("utf-8") + html = resp.text assert "Rendered Template" in html diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index de86d9227bd64..ce17faaf8535c 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -18,11 +18,11 @@ from __future__ import annotations import html -import json import unittest.mock import urllib.parse from getpass import getuser +import httpx import pendulum import pytest import time_machine @@ -47,7 +47,12 @@ from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs, clear_db_xcom -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + client_with_login, + flask_client_with_login, +) pytestmark = pytest.mark.db_test @@ -68,7 +73,7 @@ def reset_dagruns(): @pytest.fixture(autouse=True) def init_dagruns(app, reset_dagruns): with time_machine.travel(DEFAULT_DATE, tick=False): - app.dag_bag.get_dag("example_bash_operator").create_dagrun( + app.app.dag_bag.get_dag("example_bash_operator").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -83,7 +88,7 @@ def init_dagruns(app, reset_dagruns): dag_id="example_bash_operator", run_id=DEFAULT_DAGRUN, ) - app.dag_bag.get_dag("example_subdag_operator").create_dagrun( + app.app.dag_bag.get_dag("example_subdag_operator").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -91,7 +96,7 @@ def init_dagruns(app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - app.dag_bag.get_dag("example_xcom").create_dagrun( + app.app.dag_bag.get_dag("example_xcom").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -99,7 +104,7 @@ def init_dagruns(app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - app.dag_bag.get_dag("latest_only").create_dagrun( + app.app.dag_bag.get_dag("latest_only").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -107,7 +112,7 @@ def init_dagruns(app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - app.dag_bag.get_dag("example_task_group").create_dagrun( + app.app.dag_bag.get_dag("example_task_group").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -121,9 +126,9 @@ def init_dagruns(app, reset_dagruns): @pytest.fixture(scope="module") -def client_ti_without_dag_edit(app): +def flask_client_ti_without_dag_edit(app): create_user( - app, + app.app, username="all_ti_permissions_except_dag_edit", role_name="all_ti_permissions_except_dag_edit", permissions=[ @@ -138,14 +143,14 @@ def client_ti_without_dag_edit(app): ], ) - yield client_with_login( + yield flask_client_with_login( app, username="all_ti_permissions_except_dag_edit", password="all_ti_permissions_except_dag_edit", ) - delete_user(app, username="all_ti_permissions_except_dag_edit") # type: ignore - delete_roles(app) + delete_user(app.app, username="all_ti_permissions_except_dag_edit") # type: ignore + delete_roles(app.app) @pytest.mark.parametrize( @@ -367,7 +372,7 @@ def test_xcom_return_value_is_not_bytes(admin_client): def test_rendered_task_view(admin_client): url = f"task?task_id=runme_0&dag_id=example_bash_operator&execution_date={DEFAULT_VAL}" resp = admin_client.get(url, follow_redirects=True) - resp_html = resp.data.decode("utf-8") + resp_html = resp.text assert resp.status_code == 200 assert "_try_number" not in resp_html assert "try_number" in resp_html @@ -388,7 +393,7 @@ def test_rendered_k8s_without_k8s(admin_client): def test_tree_trigger_origin_tree_view(app, admin_client): - app.dag_bag.get_dag("test_tree_view").create_dagrun( + app.app.dag_bag.get_dag("test_tree_view").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -399,12 +404,12 @@ def test_tree_trigger_origin_tree_view(app, admin_client): url = "tree?dag_id=test_tree_view" resp = admin_client.get(url, follow_redirects=True) params = {"origin": "/dags/test_tree_view/grid"} - href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params))}" + href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params, safe='/:?'))}" check_content_in_response(href, resp) def test_graph_trigger_origin_grid_view(app, admin_client): - app.dag_bag.get_dag("test_tree_view").create_dagrun( + app.app.dag_bag.get_dag("test_tree_view").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -415,12 +420,12 @@ def test_graph_trigger_origin_grid_view(app, admin_client): url = "/dags/test_tree_view/graph" resp = admin_client.get(url, follow_redirects=True) params = {"origin": "/dags/test_tree_view/grid?tab=graph"} - href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params))}" + href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params, safe='/:?'))}" check_content_in_response(href, resp) def test_gantt_trigger_origin_grid_view(app, admin_client): - app.dag_bag.get_dag("test_tree_view").create_dagrun( + app.app.dag_bag.get_dag("test_tree_view").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -431,7 +436,7 @@ def test_gantt_trigger_origin_grid_view(app, admin_client): url = "/dags/test_tree_view/gantt" resp = admin_client.get(url, follow_redirects=True) params = {"origin": "/dags/test_tree_view/grid?tab=gantt"} - href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params))}" + href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params, safe='/:?'))}" check_content_in_response(href, resp) @@ -439,16 +444,15 @@ def test_graph_view_without_dag_permission(app, one_dag_perm_user_client): url = "/dags/example_bash_operator/graph" resp = one_dag_perm_user_client.get(url, follow_redirects=True) assert resp.status_code == 200 - assert ( - resp.request.url - == "http://localhost/dags/example_bash_operator/grid?tab=graph&dag_run_id=TEST_DAGRUN" + assert resp.request.url == httpx.URL( + "http://testserver/dags/example_bash_operator/grid?tab=graph&dag_run_id=TEST_DAGRUN" ) check_content_in_response("example_bash_operator", resp) url = "/dags/example_xcom/graph" resp = one_dag_perm_user_client.get(url, follow_redirects=True) assert resp.status_code == 200 - assert resp.request.url == "http://localhost/home" + assert resp.request.url == httpx.URL("http://testserver/home") check_content_in_response("Access is Denied", resp) @@ -462,7 +466,7 @@ def test_last_dagruns_success_when_selecting_dags(admin_client): "last_dagruns", data={"dag_ids": ["example_subdag_operator"]}, follow_redirects=True ) assert resp.status_code == 200 - stats = json.loads(resp.data.decode("utf-8")) + stats = resp.text assert "example_bash_operator" not in stats assert "example_subdag_operator" in stats @@ -472,7 +476,7 @@ def test_last_dagruns_success_when_selecting_dags(admin_client): data={"dag_ids": ["example_subdag_operator", "example_bash_operator"]}, follow_redirects=True, ) - stats = json.loads(resp.data.decode("utf-8")) + stats = resp.text assert "example_bash_operator" in stats assert "example_subdag_operator" in stats check_content_not_in_response("example_xcom", resp) @@ -628,18 +632,20 @@ def new_dag_to_delete(): dag = DAG("new_dag_to_delete", is_paused_upon_creation=True) session = settings.Session() dag.sync_to_db(session=session) + session.commit() + session.close() return dag @pytest.fixture def per_dag_perm_user_client(app, new_dag_to_delete): - sm = app.appbuilder.sm + sm = app.app.appbuilder.sm perm = f"{permissions.RESOURCE_DAG_PREFIX}{new_dag_to_delete.dag_id}" sm.create_permission(permissions.ACTION_CAN_DELETE, perm) create_user( - app, + app.app, username="test_user_per_dag_perms", role_name="User with some perms", permissions=[ @@ -657,21 +663,21 @@ def per_dag_perm_user_client(app, new_dag_to_delete): password="test_user_per_dag_perms", ) - delete_user(app, username="test_user_per_dag_perms") # type: ignore - delete_roles(app) + delete_user(app.app, username="test_user_per_dag_perms") # type: ignore + delete_roles(app.app) @pytest.fixture def one_dag_perm_user_client(app): username = "test_user_one_dag_perm" dag_id = "example_bash_operator" - sm = app.appbuilder.sm + sm = app.app.appbuilder.sm perm = f"{permissions.RESOURCE_DAG_PREFIX}{dag_id}" sm.create_permission(permissions.ACTION_CAN_READ, perm) create_user( - app, + app.app, username=username, role_name="User with permission to access only one dag", permissions=[ @@ -691,8 +697,8 @@ def one_dag_perm_user_client(app): password=username, ) - delete_user(app, username=username) # type: ignore - delete_roles(app) + delete_user(app.app, username=username) # type: ignore + delete_roles(app.app) def test_delete_just_dag_per_dag_permissions(new_dag_to_delete, per_dag_perm_user_client): @@ -790,7 +796,7 @@ def _get_appbuilder_pk_string(model_view_cls, instance) -> str: return model_view_cls._serialize_pk_if_composite(model_view_cls, pk_value) -def test_task_instance_delete(session, admin_client, create_task_instance): +def test_task_instance_delete(session, flask_admin_client, create_task_instance): task_instance_to_delete = create_task_instance( task_id="test_task_instance_delete", execution_date=timezone.utcnow(), @@ -800,11 +806,13 @@ def test_task_instance_delete(session, admin_client, create_task_instance): task_id = task_instance_to_delete.task_id assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1 - admin_client.post(f"/taskinstance/delete/{composite_key}", follow_redirects=True) + flask_admin_client.post(f"/taskinstance/delete/{composite_key}", follow_redirects=True) assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 0 -def test_task_instance_delete_permission_denied(session, client_ti_without_dag_edit, create_task_instance): +def test_task_instance_delete_permission_denied( + session, flask_client_ti_without_dag_edit, create_task_instance +): task_instance_to_delete = create_task_instance( task_id="test_task_instance_delete_permission_denied", execution_date=timezone.utcnow(), @@ -812,11 +820,14 @@ def test_task_instance_delete_permission_denied(session, client_ti_without_dag_e session=session, ) session.commit() + session.close() composite_key = _get_appbuilder_pk_string(TaskInstanceModelView, task_instance_to_delete) task_id = task_instance_to_delete.task_id assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1 - resp = client_ti_without_dag_edit.post(f"/taskinstance/delete/{composite_key}", follow_redirects=True) + resp = flask_client_ti_without_dag_edit.post( + f"/taskinstance/delete/{composite_key}", follow_redirects=True + ) check_content_in_response("Access is Denied", resp) assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1 @@ -1005,7 +1016,9 @@ def test_action_muldelete_task_instance(session, admin_client, task_search_tuple for task in tasks_to_delete ] session.bulk_save_objects(trs) + session.commit() session.flush() + session.close() # run the function to test resp = admin_client.post( @@ -1030,7 +1043,7 @@ def test_action_muldelete_task_instance(session, admin_client, task_search_tuple assert session.query(TaskReschedule).count() == 0 -def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client): +def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, flask_admin_client): """Test that the graph view doesn't fail on a recursion error.""" from airflow.models.baseoperator import chain @@ -1043,10 +1056,10 @@ def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client) for i in range(1, 1000 + 1) ] chain(*tasks) - with unittest.mock.patch.object(app, "dag_bag") as mocked_dag_bag: + with unittest.mock.patch.object(app.app, "dag_bag") as mocked_dag_bag: mocked_dag_bag.get_dag.return_value = dag url = f"/dags/{dag.dag_id}/graph" - resp = admin_client.get(url, follow_redirects=True) + resp = flask_admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 @@ -1057,7 +1070,7 @@ def test_task_instances(admin_client): follow_redirects=True, ) assert resp.status_code == 200 - assert resp.json == { + assert resp.json() == { "also_run_this": { "custom_operator_name": None, "dag_id": "example_bash_operator", diff --git a/tests/www/views/test_views_trigger_dag.py b/tests/www/views/test_views_trigger_dag.py index c53213c3e68ea..3f3cba44fc170 100644 --- a/tests/www/views/test_views_trigger_dag.py +++ b/tests/www/views/test_views_trigger_dag.py @@ -48,8 +48,8 @@ def initialize_one_dag(): def test_trigger_dag_button_normal_exist(admin_client): resp = admin_client.get("/", follow_redirects=True) - assert "/dags/example_bash_operator/trigger" in resp.data.decode("utf-8") - assert "return confirmDeleteDag(this, 'example_bash_operator')" in resp.data.decode("utf-8") + assert "/dags/example_bash_operator/trigger" in resp.text + assert "return confirmDeleteDag(this, 'example_bash_operator')" in resp.text # test trigger button with and without run_id @@ -174,10 +174,10 @@ def test_trigger_dag_form(admin_client): ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "http://localhost/graph?dag_id=example_bash_operator"), ], ) -def test_trigger_dag_form_origin_url(admin_client, test_origin, expected_origin): +def test_trigger_dag_form_origin_url(admin_flask_client, test_origin, expected_origin): test_dag_id = "example_bash_operator" - resp = admin_client.get(f"dags/{test_dag_id}/trigger?origin={test_origin}") + resp = admin_flask_client.get(f"dags/{test_dag_id}/trigger?origin={test_origin}") check_content_in_response(f'Cancel', resp) @@ -210,7 +210,7 @@ def test_trigger_dag_params_conf(admin_client, request_conf, expected_conf): check_content_in_response(str(expected_conf[key]), resp) -def test_trigger_dag_params_render(admin_client, dag_maker, session, app, monkeypatch): +def test_trigger_dag_params_render(admin_flask_client, dag_maker, session, app, monkeypatch): """ Test that textarea in Trigger DAG UI is pre-populated with param value set in DAG. @@ -236,8 +236,8 @@ def test_trigger_dag_params_render(admin_client, dag_maker, session, app, monkey with dag_maker(dag_id=DAG_ID, serialized=True, session=session, params={"accounts": param}): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) - resp = admin_client.get(f"dags/{DAG_ID}/trigger") + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + resp = admin_flask_client.get(f"dags/{DAG_ID}/trigger") check_content_in_response( f'', @@ -246,7 +246,7 @@ def test_trigger_dag_params_render(admin_client, dag_maker, session, app, monkey @pytest.mark.parametrize("allow_html", [False, True]) -def test_trigger_dag_html_allow(admin_client, dag_maker, session, app, monkeypatch, allow_html): +def test_trigger_dag_html_allow(admin_flask_client, dag_maker, session, app, monkeypatch, allow_html): """ Test that HTML is escaped per default in description. """ @@ -277,8 +277,8 @@ def test_trigger_dag_html_allow(admin_client, dag_maker, session, app, monkeypat ): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) - resp = admin_client.get(f"dags/{DAG_ID}/trigger") + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + resp = admin_flask_client.get(f"dags/{DAG_ID}/trigger") if expect_escape: check_content_in_response(escape(HTML_DESCRIPTION1), resp) @@ -309,7 +309,7 @@ def test_viewer_cant_trigger_dag(app): Test that the test_viewer user can't trigger DAGs. """ with create_test_client( - app, + app.app, user_name="test_user", role_name="test_role", permissions=[ @@ -324,7 +324,7 @@ def test_viewer_cant_trigger_dag(app): assert "Access is Denied" in response_data -def test_trigger_dag_params_array_value_none_render(admin_client, dag_maker, session, app, monkeypatch): +def test_trigger_dag_params_array_value_none_render(admin_flask_client, dag_maker, session, app, monkeypatch): """ Test that textarea in Trigger DAG UI is pre-populated with param value None and type ["null", "array"] set in DAG. @@ -341,8 +341,8 @@ def test_trigger_dag_params_array_value_none_render(admin_client, dag_maker, ses with dag_maker(dag_id=DAG_ID, serialized=True, session=session, params={"dag_param": param}): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) - resp = admin_client.get(f"dags/{DAG_ID}/trigger") + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + resp = admin_flask_client.get(f"dags/{DAG_ID}/trigger") check_content_in_response( f'', diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py index fcdad2bdb0bdd..494af8a519af7 100644 --- a/tests/www/views/test_views_variable.py +++ b/tests/www/views/test_views_variable.py @@ -52,7 +52,7 @@ def clear_variables(): def user_variable_reader(app): """Create User that can only read variables""" return create_user( - app, + app.app, username="user_variable_reader", role_name="role_variable_reader", permissions=[ @@ -103,7 +103,7 @@ def test_import_variables_no_file(admin_client): check_content_in_response("Missing file or syntax error.", resp) -def test_import_variables_failed(session, admin_client): +def test_import_variables_failed(session, admin_flask_client): content = '{"str_key": "str_value"}' with mock.patch("airflow.models.Variable.set") as set_mock: @@ -112,32 +112,32 @@ def test_import_variables_failed(session, admin_client): bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True ) check_content_in_response("1 variable(s) failed to be updated.", resp) -def test_import_variables_success(session, admin_client): +def test_import_variables_success(session, admin_flask_client): assert session.query(Variable).count() == 0 content = '{"str_key": "str_value", "int_key": 60, "list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}' bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True ) check_content_in_response("4 variable(s) successfully updated.", resp) _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None) -def test_import_variables_override_existing_variables_if_set(session, admin_client, caplog): +def test_import_variables_override_existing_variables_if_set(session, admin_flask_client, caplog): assert session.query(Variable).count() == 0 Variable.set("str_key", "str_value") content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json"), "action_if_exist": "overwrite"}, follow_redirects=True, @@ -146,13 +146,13 @@ def test_import_variables_override_existing_variables_if_set(session, admin_clie _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None) -def test_import_variables_skips_update_if_set(session, admin_client, caplog): +def test_import_variables_skips_update_if_set(session, admin_flask_client, caplog): assert session.query(Variable).count() == 0 Variable.set("str_key", "str_value") content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json"), "action_if_exists": "skip"}, follow_redirects=True, @@ -166,13 +166,13 @@ def test_import_variables_skips_update_if_set(session, admin_client, caplog): assert "Variable: str_key already exists, skipping." in caplog.text -def test_import_variables_fails_if_action_if_exists_is_fail(session, admin_client, caplog): +def test_import_variables_fails_if_action_if_exists_is_fail(session, admin_flask_client, caplog): assert session.query(Variable).count() == 0 Variable.set("str_key", "str_value") content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists bytes_content = BytesIO(bytes(content, encoding="utf-8")) - admin_client.post( + admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json"), "action_if_exists": "fail"}, follow_redirects=True, @@ -244,7 +244,7 @@ def test_action_export(admin_client, variable): assert resp.status_code == 200 assert resp.headers["Content-Type"] == "application/json; charset=utf-8" assert resp.headers["Content-Disposition"] == "attachment; filename=variables.json" - assert resp.json == {"test_key": "text_val"} + assert resp.json() == {"test_key": "text_val"} def test_action_muldelete(session, admin_client, variable):