From 28944f57e120c0ecb05f16617f7ad835b64bd043 Mon Sep 17 00:00:00 2001 From: ofekisr <35701650+ofekisr@users.noreply.github.com> Date: Thu, 11 Nov 2021 15:34:23 +0200 Subject: [PATCH] refactor move chart_data_apis from ChartRestApi to ChartDataRestApi in separate module (#17400) --- superset/charts/api.py | 291 +-------------------- superset/charts/data/__init__.py | 16 ++ superset/charts/data/api.py | 365 +++++++++++++++++++++++++++ superset/initialization/__init__.py | 2 + superset/reports/commands/execute.py | 2 +- 5 files changed, 385 insertions(+), 291 deletions(-) create mode 100644 superset/charts/data/__init__.py create mode 100644 superset/charts/data/api.py diff --git a/superset/charts/api.py b/superset/charts/api.py index e94ad7c326e40..f44a9016153b7 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines import json import logging from datetime import datetime @@ -27,7 +26,7 @@ from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.hooks import before_request from flask_appbuilder.models.sqla.interface import SQLAInterface -from flask_babel import gettext as _, ngettext +from flask_babel import ngettext from marshmallow import ValidationError from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper @@ -35,13 +34,10 @@ from superset import is_feature_enabled, thumbnail_cache from superset.charts.commands.bulk_delete import BulkDeleteChartCommand from superset.charts.commands.create import CreateChartCommand -from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.delete import DeleteChartCommand from superset.charts.commands.exceptions import ( ChartBulkDeleteFailedError, ChartCreateFailedError, - ChartDataCacheLoadError, - ChartDataQueryFailedError, 
ChartDeleteFailedError, ChartForbiddenError, ChartInvalidError, @@ -69,11 +65,9 @@ from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod -from superset.exceptions import QueryObjectValidationError from superset.extensions import event_logger, security_manager from superset.models.slice import Slice from superset.tasks.thumbnails import cache_chart_thumbnail -from superset.utils.async_query_manager import AsyncQueryTokenException from superset.utils.core import json_int_dttm_ser from superset.utils.screenshots import ChartScreenshot from superset.utils.urls import get_url_path @@ -105,9 +99,6 @@ def ensure_thumbnails_enabled(self) -> Optional[Response]: RouteMethod.IMPORT, RouteMethod.RELATED, "bulk_delete", # not using RouteMethod since locally defined - "data", - "get_data", - "data_from_cache", "viz_types", "favorite_status", "thumbnail", @@ -525,286 +516,6 @@ def send_chart_response( return self.response_400(message=f"Unsupported result_format: {result_format}") - def get_data_response( - self, - command: ChartDataCommand, - force_cached: bool = False, - form_data: Optional[Dict[str, Any]] = None, - ) -> Response: - try: - result = command.run(force_cached=force_cached) - except ChartDataCacheLoadError as exc: - return self.response_422(message=exc.message) - except ChartDataQueryFailedError as exc: - return self.response_400(message=exc.message) - - return self.send_chart_response(result, form_data) - - @expose("//data/", methods=["GET"]) - @protect() - @statsd_metrics - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", - log_to_statsd=False, - ) - def get_data(self, pk: int) -> Response: - """ - Takes a chart ID and uses the query context stored when the chart was saved - to return payload data response. 
- --- - get: - description: >- - Takes a chart ID and uses the query context stored when the chart was saved - to return payload data response. - parameters: - - in: path - schema: - type: integer - name: pk - description: The chart ID - - in: query - name: format - description: The format in which the data should be returned - schema: - type: string - - in: query - name: type - description: The type in which the data should be returned - schema: - type: string - responses: - 200: - description: Query result - content: - application/json: - schema: - $ref: "#/components/schemas/ChartDataResponseSchema" - 202: - description: Async job details - content: - application/json: - schema: - $ref: "#/components/schemas/ChartDataAsyncResponseSchema" - 400: - $ref: '#/components/responses/400' - 401: - $ref: '#/components/responses/401' - 500: - $ref: '#/components/responses/500' - """ - chart = self.datamodel.get(pk, self._base_filters) - if not chart: - return self.response_404() - - try: - json_body = json.loads(chart.query_context) - except (TypeError, json.decoder.JSONDecodeError): - json_body = None - - if json_body is None: - return self.response_400( - message=_( - "Chart has no query context saved. Please save the chart again." 
- ) - ) - - # override saved query context - json_body["result_format"] = request.args.get( - "format", ChartDataResultFormat.JSON - ) - json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL) - - try: - command = ChartDataCommand() - query_context = command.set_query_context(json_body) - command.validate() - except QueryObjectValidationError as error: - return self.response_400(message=error.message) - except ValidationError as error: - return self.response_400( - message=_( - "Request is incorrect: %(error)s", error=error.normalized_messages() - ) - ) - - # TODO: support CSV, SQL query and other non-JSON types - if ( - is_feature_enabled("GLOBAL_ASYNC_QUERIES") - and query_context.result_format == ChartDataResultFormat.JSON - and query_context.result_type == ChartDataResultType.FULL - ): - return self._run_async(command) - - try: - form_data = json.loads(chart.params) - except (TypeError, json.decoder.JSONDecodeError): - form_data = {} - - return self.get_data_response(command, form_data=form_data) - - @expose("/data", methods=["POST"]) - @protect() - @statsd_metrics - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", - log_to_statsd=False, - ) - def data(self) -> Response: - """ - Takes a query context constructed in the client and returns payload - data response for the given query. - --- - post: - description: >- - Takes a query context constructed in the client and returns payload data - response for the given query. - requestBody: - description: >- - A query context consists of a datasource from which to fetch data - and one or many query objects. 
- required: true - content: - application/json: - schema: - $ref: "#/components/schemas/ChartDataQueryContextSchema" - responses: - 200: - description: Query result - content: - application/json: - schema: - $ref: "#/components/schemas/ChartDataResponseSchema" - 202: - description: Async job details - content: - application/json: - schema: - $ref: "#/components/schemas/ChartDataAsyncResponseSchema" - 400: - $ref: '#/components/responses/400' - 401: - $ref: '#/components/responses/401' - 500: - $ref: '#/components/responses/500' - """ - json_body = None - if request.is_json: - json_body = request.json - elif request.form.get("form_data"): - # CSV export submits regular form data - try: - json_body = json.loads(request.form["form_data"]) - except (TypeError, json.JSONDecodeError): - pass - - if json_body is None: - return self.response_400(message=_("Request is not JSON")) - - try: - command = ChartDataCommand() - query_context = command.set_query_context(json_body) - command.validate() - except QueryObjectValidationError as error: - return self.response_400(message=error.message) - except ValidationError as error: - return self.response_400( - message=_( - "Request is incorrect: %(error)s", error=error.normalized_messages() - ) - ) - - # TODO: support CSV, SQL query and other non-JSON types - if ( - is_feature_enabled("GLOBAL_ASYNC_QUERIES") - and query_context.result_format == ChartDataResultFormat.JSON - and query_context.result_type == ChartDataResultType.FULL - ): - return self._run_async(command) - - return self.get_data_response(command) - - def _run_async(self, command: ChartDataCommand) -> Response: - """ - Execute command as an async query. - """ - # First, look for the chart query results in the cache. - try: - result = command.run(force_cached=True) - except ChartDataCacheLoadError: - result = None # type: ignore - - already_cached_result = result is not None - - # If the chart query has already been cached, return it immediately. 
- if already_cached_result: - return self.send_chart_response(result) - - # Otherwise, kick off a background job to run the chart query. - # Clients will either poll or be notified of query completion, - # at which point they will call the /data/ endpoint - # to retrieve the results. - try: - command.validate_async_request(request) - except AsyncQueryTokenException: - return self.response_401() - - result = command.run_async(g.user.get_id()) - return self.response(202, **result) - - @expose("/data/", methods=["GET"]) - @protect() - @statsd_metrics - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" - f".data_from_cache", - log_to_statsd=False, - ) - def data_from_cache(self, cache_key: str) -> Response: - """ - Takes a query context cache key and returns payload - data response for the given query. - --- - get: - description: >- - Takes a query context cache key and returns payload data - response for the given query. - parameters: - - in: path - schema: - type: string - name: cache_key - responses: - 200: - description: Query result - content: - application/json: - schema: - $ref: "#/components/schemas/ChartDataResponseSchema" - 400: - $ref: '#/components/responses/400' - 401: - $ref: '#/components/responses/401' - 404: - $ref: '#/components/responses/404' - 422: - $ref: '#/components/responses/422' - 500: - $ref: '#/components/responses/500' - """ - command = ChartDataCommand() - try: - cached_data = command.load_query_context_from_cache(cache_key) - command.set_query_context(cached_data) - command.validate() - except ChartDataCacheLoadError: - return self.response_404() - except ValidationError as error: - return self.response_400( - message=_("Request is incorrect: %(error)s", error=error.messages) - ) - - return self.get_data_response(command, True) - @expose("//cache_screenshot/", methods=["GET"]) @protect() @rison(screenshot_query_schema) diff --git a/superset/charts/data/__init__.py 
b/superset/charts/data/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/charts/data/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py new file mode 100644 index 0000000000000..c68760edbd5a2 --- /dev/null +++ b/superset/charts/data/api.py @@ -0,0 +1,365 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import logging +from typing import Any, Dict, Optional, TYPE_CHECKING + +import simplejson +from flask import g, make_response, request +from flask_appbuilder.api import expose, protect +from flask_babel import gettext as _ +from marshmallow import ValidationError + +from superset import is_feature_enabled, security_manager +from superset.charts.api import ChartRestApi +from superset.charts.commands.data import ChartDataCommand +from superset.charts.commands.exceptions import ( + ChartDataCacheLoadError, + ChartDataQueryFailedError, +) +from superset.charts.post_processing import apply_post_process +from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType +from superset.exceptions import QueryObjectValidationError +from superset.extensions import event_logger +from superset.utils.async_query_manager import AsyncQueryTokenException +from superset.utils.core import json_int_dttm_ser +from superset.views.base import CsvResponse, generate_download_headers +from superset.views.base_api import statsd_metrics + +if TYPE_CHECKING: + from flask import Response + +logger = logging.getLogger(__name__) + + +class ChartDataRestApi(ChartRestApi): + include_route_methods = {"get_data", "data", "data_from_cache"} + + @expose("//data/", methods=["GET"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", + log_to_statsd=False, + ) + def get_data(self, pk: int) -> Response: + """ + Takes a chart ID and uses the query context stored when the chart was saved + to return payload data response. + --- + get: + description: >- + Takes a chart ID and uses the query context stored when the chart was saved + to return payload data response. 
+ parameters: + - in: path + schema: + type: integer + name: pk + description: The chart ID + - in: query + name: format + description: The format in which the data should be returned + schema: + type: string + - in: query + name: type + description: The type in which the data should be returned + schema: + type: string + responses: + 200: + description: Query result + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataResponseSchema" + 202: + description: Async job details + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataAsyncResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + """ + chart = self.datamodel.get(pk, self._base_filters) + if not chart: + return self.response_404() + + try: + json_body = json.loads(chart.query_context) + except (TypeError, json.decoder.JSONDecodeError): + json_body = None + + if json_body is None: + return self.response_400( + message=_( + "Chart has no query context saved. Please save the chart again." 
+ ) + ) + + # override saved query context + json_body["result_format"] = request.args.get( + "format", ChartDataResultFormat.JSON + ) + json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL) + + try: + command = ChartDataCommand() + query_context = command.set_query_context(json_body) + command.validate() + except QueryObjectValidationError as error: + return self.response_400(message=error.message) + except ValidationError as error: + return self.response_400( + message=_( + "Request is incorrect: %(error)s", error=error.normalized_messages() + ) + ) + + # TODO: support CSV, SQL query and other non-JSON types + if ( + is_feature_enabled("GLOBAL_ASYNC_QUERIES") + and query_context.result_format == ChartDataResultFormat.JSON + and query_context.result_type == ChartDataResultType.FULL + ): + return self._run_async(command) + + try: + form_data = json.loads(chart.params) + except (TypeError, json.decoder.JSONDecodeError): + form_data = {} + + return self.get_data_response(command, form_data=form_data) + + @expose("/data", methods=["POST"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", + log_to_statsd=False, + ) + def data(self) -> Response: + """ + Takes a query context constructed in the client and returns payload + data response for the given query. + --- + post: + description: >- + Takes a query context constructed in the client and returns payload data + response for the given query. + requestBody: + description: >- + A query context consists of a datasource from which to fetch data + and one or many query objects. 
+ required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataQueryContextSchema" + responses: + 200: + description: Query result + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataResponseSchema" + 202: + description: Async job details + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataAsyncResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + """ + json_body = None + if request.is_json: + json_body = request.json + elif request.form.get("form_data"): + # CSV export submits regular form data + try: + json_body = json.loads(request.form["form_data"]) + except (TypeError, json.JSONDecodeError): + pass + + if json_body is None: + return self.response_400(message=_("Request is not JSON")) + + try: + command = ChartDataCommand() + query_context = command.set_query_context(json_body) + command.validate() + except QueryObjectValidationError as error: + return self.response_400(message=error.message) + except ValidationError as error: + return self.response_400( + message=_( + "Request is incorrect: %(error)s", error=error.normalized_messages() + ) + ) + + # TODO: support CSV, SQL query and other non-JSON types + if ( + is_feature_enabled("GLOBAL_ASYNC_QUERIES") + and query_context.result_format == ChartDataResultFormat.JSON + and query_context.result_type == ChartDataResultType.FULL + ): + return self._run_async(command) + + return self.get_data_response(command) + + @expose("/data/", methods=["GET"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".data_from_cache", + log_to_statsd=False, + ) + def data_from_cache(self, cache_key: str) -> Response: + """ + Takes a query context cache key and returns payload + data response for the given query. 
+ --- + get: + description: >- + Takes a query context cache key and returns payload data + response for the given query. + parameters: + - in: path + schema: + type: string + name: cache_key + responses: + 200: + description: Query result + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + command = ChartDataCommand() + try: + cached_data = command.load_query_context_from_cache(cache_key) + command.set_query_context(cached_data) + command.validate() + except ChartDataCacheLoadError: + return self.response_404() + except ValidationError as error: + return self.response_400( + message=_("Request is incorrect: %(error)s", error=error.messages) + ) + + return self.get_data_response(command, True) + + def _run_async(self, command: ChartDataCommand) -> Response: + """ + Execute command as an async query. + """ + # First, look for the chart query results in the cache. + try: + result = command.run(force_cached=True) + except ChartDataCacheLoadError: + result = None # type: ignore + + already_cached_result = result is not None + + # If the chart query has already been cached, return it immediately. + if already_cached_result: + return self.send_chart_response(result) + + # Otherwise, kick off a background job to run the chart query. + # Clients will either poll or be notified of query completion, + # at which point they will call the /data/ endpoint + # to retrieve the results. 
+ try: + command.validate_async_request(request) + except AsyncQueryTokenException: + return self.response_401() + + result = command.run_async(g.user.get_id()) + return self.response(202, **result) + + def send_chart_response( + self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None, + ) -> Response: + result_type = result["query_context"].result_type + result_format = result["query_context"].result_format + + # Post-process the data so it matches the data presented in the chart. + # This is needed for sending reports based on text charts that do the + # post-processing of data, eg, the pivot table. + if result_type == ChartDataResultType.POST_PROCESSED: + result = apply_post_process(result, form_data) + + if result_format == ChartDataResultFormat.CSV: + # Verify user has permission to export CSV file + if not security_manager.can_access("can_csv", "Superset"): + return self.response_403() + + # return the first result + data = result["queries"][0]["data"] + return CsvResponse(data, headers=generate_download_headers("csv")) + + if result_format == ChartDataResultFormat.JSON: + response_data = simplejson.dumps( + {"result": result["queries"]}, + default=json_int_dttm_ser, + ignore_nan=True, + ) + resp = make_response(response_data, 200) + resp.headers["Content-Type"] = "application/json; charset=utf-8" + return resp + + return self.response_400(message=f"Unsupported result_format: {result_format}") + + def get_data_response( + self, + command: ChartDataCommand, + force_cached: bool = False, + form_data: Optional[Dict[str, Any]] = None, + ) -> Response: + try: + result = command.run(force_cached=force_cached) + except ChartDataCacheLoadError as exc: + return self.response_422(message=exc.message) + except ChartDataQueryFailedError as exc: + return self.response_400(message=exc.message) + + return self.send_chart_response(result, form_data) diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 
d2ecf7f7674c3..56e8b1bec0872 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -117,6 +117,7 @@ def init_views(self) -> None: from superset.async_events.api import AsyncEventsRestApi from superset.cachekeys.api import CacheRestApi from superset.charts.api import ChartRestApi + from superset.charts.data.api import ChartDataRestApi from superset.connectors.druid.views import ( Druid, DruidClusterModelView, @@ -199,6 +200,7 @@ def init_views(self) -> None: appbuilder.add_api(AsyncEventsRestApi) appbuilder.add_api(CacheRestApi) appbuilder.add_api(ChartRestApi) + appbuilder.add_api(ChartDataRestApi) appbuilder.add_api(CssTemplateRestApi) appbuilder.add_api(DashboardRestApi) appbuilder.add_api(DatabaseRestApi) diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index 985cdbf9cf4dd..297316b8a7512 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -151,7 +151,7 @@ def _get_url( ChartDataResultFormat.JSON, }: return get_url_path( - "ChartRestApi.get_data", + "ChartDataRestApi.get_data", pk=self._report_schedule.chart_id, format=result_format.value, type=ChartDataResultType.POST_PROCESSED.value,