From 3f2129b8952ad6c9ff391bae5eaea960326c3059 Mon Sep 17 00:00:00 2001 From: ofekisr <35701650+ofekisr@users.noreply.github.com> Date: Wed, 17 Nov 2021 17:13:40 +0200 Subject: [PATCH] refactor: chartDataCommand - remove the responsibly of creating query context from command (#17461) --- superset/charts/data/api.py | 26 ++++++++++++++++++++------ superset/charts/data/commands.py | 14 +++----------- superset/tasks/async_queries.py | 21 ++++++++++++++++++--- 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 534101bae6be1..cf7d95acd08f2 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -38,6 +38,7 @@ ) from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader from superset.charts.post_processing import apply_post_process +from superset.charts.schemas import ChartDataQueryContextSchema from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.exceptions import QueryObjectValidationError from superset.extensions import event_logger @@ -49,6 +50,8 @@ if TYPE_CHECKING: from flask import Response + from superset.common.query_context import QueryContext + logger = logging.getLogger(__name__) @@ -130,8 +133,8 @@ def get_data(self, pk: int) -> Response: json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL) try: - command = ChartDataCommand() - query_context = command.set_query_context(json_body) + query_context = self._create_query_context_from_form(json_body) + command = ChartDataCommand(query_context) command.validate() except QueryObjectValidationError as error: return self.response_400(message=error.message) @@ -216,8 +219,8 @@ def data(self) -> Response: return self.response_400(message=_("Request is not JSON")) try: - command = ChartDataCommand() - query_context = command.set_query_context(json_body) + query_context = self._create_query_context_from_form(json_body) + command = ChartDataCommand(query_context) command.validate() except QueryObjectValidationError as error: return self.response_400(message=error.message) @@ -278,10 +281,10 @@ def data_from_cache(self, cache_key: str) -> Response: 500: $ref: '#/components/responses/500' """ - command = ChartDataCommand() try: cached_data = self._load_query_context_form_from_cache(cache_key) - command.set_query_context(cached_data) + query_context = self._create_query_context_from_form(cached_data) + command = ChartDataCommand(query_context) command.validate() except ChartDataCacheLoadError: return self.response_404() @@ -374,3 +377,14 @@ def _get_data_response( # pylint: disable=invalid-name, no-self-use def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]: return QueryContextCacheLoader.load(cache_key) + + # pylint: disable=no-self-use + def _create_query_context_from_form( + self, form_data: Dict[str, Any] + ) -> QueryContext: + try: + return ChartDataQueryContextSchema().load(form_data) + except KeyError as ex: + raise ValidationError("Request is incorrect") from ex + except ValidationError as error: + raise error diff --git a/superset/charts/data/commands.py b/superset/charts/data/commands.py index d434f79a17101..3fc02e260f350 100644 --- a/superset/charts/data/commands.py +++ b/superset/charts/data/commands.py @@ -18,13 +18,11 @@ from typing import Any, Dict, Optional from flask import Request -from marshmallow import ValidationError from superset.charts.commands.exceptions import ( ChartDataCacheLoadError, ChartDataQueryFailedError, ) -from superset.charts.schemas import ChartDataQueryContextSchema from superset.commands.base import BaseCommand from superset.common.query_context import QueryContext from superset.exceptions import CacheLoadError @@ -37,6 +35,9 @@ class ChartDataCommand(BaseCommand): _query_context: QueryContext + def __init__(self, query_context: QueryContext): + self._query_context = query_context + def run(self, **kwargs: Any) -> Dict[str, Any]: # caching is handled in query_context.get_df_payload # (also evals `force` property) @@ -63,15 +64,6 @@ def run(self, **kwargs: Any) -> Dict[str, Any]: return return_value - def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext: - try: - self._query_context = ChartDataQueryContextSchema().load(form_data) - except KeyError as ex: - raise ValidationError("Request is incorrect") from ex - except ValidationError as error: - raise error - return self._query_context - def validate(self) -> None: self._query_context.raise_for_access() diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index c50dbb9a94436..e916028b12a86 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -14,14 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import copy import logging -from typing import Any, cast, Dict, Optional +from typing import Any, cast, Dict, Optional, TYPE_CHECKING from celery.exceptions import SoftTimeLimitExceeded from flask import current_app, g +from marshmallow import ValidationError +from superset.charts.schemas import ChartDataQueryContextSchema from superset.exceptions import SupersetVizException from superset.extensions import ( async_query_manager, @@ -32,6 +35,9 @@ from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.views.utils import get_datasource_info, get_viz +if TYPE_CHECKING: + from superset.common.query_context import QueryContext + logger = logging.getLogger(__name__) query_timeout = current_app.config[ "SQLLAB_ASYNC_TIME_LIMIT_SEC" @@ -50,6 +56,15 @@ def set_form_data(form_data: Dict[str, Any]) -> None: g.form_data = form_data +def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext: + try: + return ChartDataQueryContextSchema().load(form_data) + except KeyError as ex: + raise ValidationError("Request is incorrect") from ex + except ValidationError as error: + raise error + + @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) def load_chart_data_into_cache( job_metadata: Dict[str, Any], form_data: Dict[str, Any], @@ -60,8 +75,8 @@ def load_chart_data_into_cache( try: ensure_user_is_set(job_metadata.get("user_id")) set_form_data(form_data) - command = ChartDataCommand() - command.set_query_context(form_data) + query_context = _create_query_context_from_form(form_data) + command = ChartDataCommand(query_context) result = command.run(cache=True) cache_key = result["cache_key"] result_url = f"/api/v1/chart/data/{cache_key}"