diff --git a/docker/.env b/docker/.env index 57575da76edc1..c8b2413ae6a2c 100644 --- a/docker/.env +++ b/docker/.env @@ -59,7 +59,6 @@ MAPBOX_API_KEY='' # Make sure you set this to a unique secure random value on production SUPERSET_SECRET_KEY=TEST_NON_DEV_SECRET - ENABLE_PLAYWRIGHT=false PUPPETEER_SKIP_CHROMIUM_DOWNLOAD=true BUILD_SUPERSET_FRONTEND_IN_DOCKER=true diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 89e47a9dcb881..31611570ad811 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -1550,6 +1550,7 @@ class ImportV1ChartSchema(Schema): dataset_uuid = fields.UUID(required=True) is_managed_externally = fields.Boolean(allow_none=True, dump_default=False) external_url = fields.String(allow_none=True) + tags = fields.List(fields.String(), allow_none=True) class ChartCacheWarmUpRequestSchema(Schema): diff --git a/superset/commands/chart/export.py b/superset/commands/chart/export.py index a84dfcca147db..50b06cb30de18 100644 --- a/superset/commands/chart/export.py +++ b/superset/commands/chart/export.py @@ -26,10 +26,13 @@ from superset.daos.chart import ChartDAO from superset.commands.dataset.export import ExportDatasetsCommand from superset.commands.export.models import ExportModelsCommand +from superset.commands.tag.export import ExportTagsCommand from superset.models.slice import Slice +from superset.tags.models import TagType from superset.utils.dict_import_export import EXPORT_VERSION from superset.utils.file import get_filename from superset.utils import json +from superset.extensions import feature_flag_manager logger = logging.getLogger(__name__) @@ -71,9 +74,23 @@ def _file_content(model: Slice) -> str: if model.table: payload["dataset_uuid"] = str(model.table.uuid) + # Fetch tags from the database if TAGGING_SYSTEM is enabled + if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + tags = getattr(model, "tags", []) + payload["tags"] = [tag.name for tag in tags if tag.type == TagType.custom] file_content = yaml.safe_dump(payload, sort_keys=False) return file_content + _include_tags: bool = True # Default to True + + @classmethod + def disable_tag_export(cls) -> None: + cls._include_tags = False + + @classmethod + def enable_tag_export(cls) -> None: + cls._include_tags = True + @staticmethod def _export( model: Slice, export_related: bool = True @@ -85,3 +102,12 @@ def _export( if model.table and export_related: yield from ExportDatasetsCommand([model.table.id]).run() + + # Check if the calling class is ExportDashboardCommands + if ( + export_related + and ExportChartsCommand._include_tags + and feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM") + ): + chart_id = model.id + yield from ExportTagsCommand().export(chart_ids=[chart_id]) diff --git a/superset/commands/chart/importers/v1/__init__.py b/superset/commands/chart/importers/v1/__init__.py index dc5a7079669ae..2a475216e41ad 100644 --- a/superset/commands/chart/importers/v1/__init__.py +++ b/superset/commands/chart/importers/v1/__init__.py @@ -14,23 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session # noqa: F401 +from superset import db from superset.charts.schemas import ImportV1ChartSchema from superset.commands.chart.exceptions import ChartImportError from superset.commands.chart.importers.v1.utils import import_chart from superset.commands.database.importers.v1.utils import import_database from superset.commands.dataset.importers.v1.utils import import_dataset from superset.commands.importers.v1 import ImportModelsCommand +from superset.commands.importers.v1.utils import import_tag from superset.commands.utils import update_chart_config_dataset from superset.connectors.sqla.models import SqlaTable from superset.daos.chart import ChartDAO from superset.databases.schemas import ImportV1DatabaseSchema from superset.datasets.schemas import ImportV1DatasetSchema +from superset.extensions import feature_flag_manager class ImportChartsCommand(ImportModelsCommand): @@ -47,7 +51,13 @@ class ImportChartsCommand(ImportModelsCommand): import_error = ChartImportError @staticmethod - def _import(configs: dict[str, Any], overwrite: bool = False) -> None: + def _import( + configs: dict[str, Any], + overwrite: bool = False, + contents: dict[str, Any] | None = None, + ) -> None: + if contents is None: + contents = {} # discover datasets associated with charts dataset_uuids: set[str] = set() for file_name, config in configs.items(): @@ -93,4 +103,12 @@ def _import(configs: dict[str, Any], overwrite: bool = False) -> None: "datasource_name": dataset.table_name, } config = update_chart_config_dataset(config, dataset_dict) - import_chart(config, overwrite=overwrite) + chart = import_chart(config, overwrite=overwrite) + + # Handle tags using import_tag function + if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + if "tags" in config: + new_tag_names = config["tags"] + import_tag( + new_tag_names, contents, chart.id, "chart", db.session + ) diff --git a/superset/commands/dashboard/export.py b/superset/commands/dashboard/export.py index 93cc490ad73de..27dbbfe79dd72 100644 --- a/superset/commands/dashboard/export.py +++ b/superset/commands/dashboard/export.py @@ -25,6 +25,7 @@ import yaml from superset.commands.chart.export import ExportChartsCommand +from superset.commands.tag.export import ExportTagsCommand from superset.commands.dashboard.exceptions import DashboardNotFoundError from superset.commands.dashboard.importers.v1.utils import find_chart_uuids from superset.daos.dashboard import DashboardDAO @@ -33,9 +34,11 @@ from superset.daos.dataset import DatasetDAO from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.tags.models import TagType from superset.utils.dict_import_export import EXPORT_VERSION from superset.utils.file import get_filename from superset.utils import json +from superset.extensions import feature_flag_manager # Import the feature flag manager logger = logging.getLogger(__name__) @@ -159,6 +162,11 @@ def _file_content(model: Dashboard) -> str: payload["version"] = EXPORT_VERSION + # Check if the TAGGING_SYSTEM feature is enabled + if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + tags = model.tags if hasattr(model, "tags") else [] + payload["tags"] = [tag.name for tag in tags if tag.type == TagType.custom] + file_content = yaml.safe_dump(payload, sort_keys=False) return file_content @@ -173,7 +181,14 @@ def _export( if export_related: chart_ids = [chart.id for chart in model.slices] + dashboard_ids = model.id + ExportChartsCommand.disable_tag_export() yield from ExportChartsCommand(chart_ids).run() + ExportChartsCommand.enable_tag_export() + if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + yield from ExportTagsCommand.export( + dashboard_ids=dashboard_ids, chart_ids=chart_ids + ) payload = model.export_to_dict( recursive=False, diff --git a/superset/commands/dashboard/importers/v1/__init__.py b/superset/commands/dashboard/importers/v1/__init__.py index 18cbb7da8407c..9021ff913f4ec 100644 --- a/superset/commands/dashboard/importers/v1/__init__.py +++ b/superset/commands/dashboard/importers/v1/__init__.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import Any from marshmallow import Schema @@ -34,11 +36,13 @@ from superset.commands.database.importers.v1.utils import import_database from superset.commands.dataset.importers.v1.utils import import_dataset from superset.commands.importers.v1 import ImportModelsCommand +from superset.commands.importers.v1.utils import import_tag from superset.commands.utils import update_chart_config_dataset from superset.daos.dashboard import DashboardDAO from superset.dashboards.schemas import ImportV1DashboardSchema from superset.databases.schemas import ImportV1DatabaseSchema from superset.datasets.schemas import ImportV1DatasetSchema +from superset.extensions import feature_flag_manager from superset.migrations.shared.native_filters import migrate_dashboard from superset.models.dashboard import Dashboard, dashboard_slices @@ -58,9 +62,15 @@ class ImportDashboardsCommand(ImportModelsCommand): import_error = DashboardImportError # TODO (betodealmeida): refactor to use code from other commands - # pylint: disable=too-many-branches, too-many-locals + # pylint: disable=too-many-branches, too-many-locals, too-many-statements @staticmethod - def _import(configs: dict[str, Any], overwrite: bool = False) -> None: + def _import( + configs: dict[str, Any], + overwrite: bool = False, + contents: dict[str, Any] | None = None, + ) -> None: + if contents is None: + contents = {} # discover charts and datasets associated with dashboards chart_uuids: set[str] = set() dataset_uuids: set[str] = set() @@ -120,6 +130,14 @@ def _import(configs: dict[str, Any], overwrite: bool = False) -> None: charts.append(chart) chart_ids[str(chart.uuid)] = chart.id + # Handle tags using import_tag function + if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + if "tags" in config: + new_tag_names = config["tags"] + import_tag( + new_tag_names, contents, chart.id, "chart", db.session + ) + # store the existing relationship between dashboards and charts existing_relationships = db.session.execute( select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id]) @@ -140,6 +158,18 @@ def _import(configs: dict[str, Any], overwrite: bool = False) -> None: if (dashboard.id, chart_id) not in existing_relationships: dashboard_chart_ids.append((dashboard.id, chart_id)) + # Handle tags using import_tag function + if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + if "tags" in config: + new_tag_names = config["tags"] + import_tag( + new_tag_names, + contents, + dashboard.id, + "dashboard", + db.session, + ) + # set ref in the dashboard_slices table values = [ {"dashboard_id": dashboard_id, "slice_id": chart_id} diff --git a/superset/commands/database/importers/v1/__init__.py b/superset/commands/database/importers/v1/__init__.py index c8684bc5eb428..4269cac6e50b0 100644 --- a/superset/commands/database/importers/v1/__init__.py +++ b/superset/commands/database/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any +from typing import Any, Optional from marshmallow import Schema from sqlalchemy.orm import Session # noqa: F401 @@ -42,7 +42,11 @@ class ImportDatabasesCommand(ImportModelsCommand): import_error = DatabaseImportError @staticmethod - def _import(configs: dict[str, Any], overwrite: bool = False) -> None: + def _import( + configs: dict[str, Any], + overwrite: bool = False, + contents: Optional[dict[str, Any]] = None, + ) -> None: # first import databases database_ids: dict[str, int] = {} for file_name, config in configs.items(): diff --git a/superset/commands/dataset/importers/v1/__init__.py b/superset/commands/dataset/importers/v1/__init__.py index c7ecba122725d..5cc562e8a4079 100644 --- a/superset/commands/dataset/importers/v1/__init__.py +++ b/superset/commands/dataset/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any +from typing import Any, Optional from marshmallow import Schema from sqlalchemy.orm import Session # noqa: F401 @@ -42,7 +42,13 @@ class ImportDatasetsCommand(ImportModelsCommand): import_error = DatasetImportError @staticmethod - def _import(configs: dict[str, Any], overwrite: bool = False) -> None: + def _import( + configs: dict[str, Any], + overwrite: bool = False, + contents: Optional[dict[str, Any]] = None, + ) -> None: + if contents is None: + contents = {} # discover databases associated with datasets database_uuids: set[str] = set() for file_name, config in configs.items(): diff --git a/superset/commands/export/assets.py b/superset/commands/export/assets.py index ff76dab03dae5..2acd19d89204b 100644 --- a/superset/commands/export/assets.py +++ b/superset/commands/export/assets.py @@ -53,6 +53,7 @@ def run(self) -> Iterator[tuple[str, Callable[[], str]]]: ExportDashboardsCommand, ExportSavedQueriesCommand, ] + for command in commands: ids = [model.id for model in command.dao.find_all()] for file_name, file_content in command(ids, export_related=False).run(): diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index f90708acf51f1..951851a213523 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Optional +from __future__ import annotations + +from typing import Any from marshmallow import Schema, validate # noqa: F401 from marshmallow.exceptions import ValidationError @@ -61,7 +63,11 @@ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self._configs: dict[str, Any] = {} @staticmethod - def _import(configs: dict[str, Any], overwrite: bool = False) -> None: + def _import( + configs: dict[str, Any], + overwrite: bool = False, + contents: dict[str, Any] | None = None, + ) -> None: raise NotImplementedError("Subclasses MUST implement _import") @classmethod @@ -73,7 +79,7 @@ def run(self) -> None: self.validate() try: - self._import(self._configs, self.overwrite) + self._import(self._configs, self.overwrite, self.contents) except CommandException: raise except Exception as ex: @@ -84,7 +90,7 @@ def validate(self) -> None: # noqa: F811 # verify that the metadata file is present and valid try: - metadata: Optional[dict[str, str]] = load_metadata(self.contents) + metadata: dict[str, str] | None = load_metadata(self.contents) except ValidationError as exc: exceptions.append(exc) metadata = None diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py index bcf6b5062fb9b..5b4f46c1a65c0 100644 --- a/superset/commands/importers/v1/examples.py +++ b/superset/commands/importers/v1/examples.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any +from typing import Any, Optional from marshmallow import Schema from sqlalchemy.exc import MultipleResultsFound @@ -90,6 +90,7 @@ def _get_uuids(cls) -> set[str]: def _import( # pylint: disable=too-many-locals, too-many-branches configs: dict[str, Any], overwrite: bool = False, + contents: Optional[dict[str, Any]] = None, force_data: bool = False, ) -> None: # import databases @@ -129,7 +130,7 @@ def _import( # pylint: disable=too-many-locals, too-many-branches dataset = import_dataset( config, overwrite=overwrite, - force_data=force_data, + force_data=bool(force_data), ignore_permissions=True, ) except MultipleResultsFound: diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py index 51ab99271c82c..9a4a6cad07541 100644 --- a/superset/commands/importers/v1/utils.py +++ b/superset/commands/importers/v1/utils.py @@ -21,11 +21,15 @@ import yaml from marshmallow import fields, Schema, validate from marshmallow.exceptions import ValidationError +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session from superset import db from superset.commands.importers.exceptions import IncorrectVersionError from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.extensions import feature_flag_manager from superset.models.core import Database +from superset.tags.models import Tag, TaggedObject from superset.utils.core import check_is_safe_zip METADATA_FILE_NAME = "metadata.yaml" @@ -214,3 +218,80 @@ def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]: for file_name in bundle.namelist() if is_valid_config(file_name) } + + +# pylint: disable=consider-using-transaction +def import_tag( + new_tag_names: list[str], + contents: dict[str, Any], + object_id: int, + object_type: str, + db_session: Session, +) -> list[int]: + """Handles the import logic for tags for charts and dashboards""" + + if not feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + return [] + + tag_descriptions = {} + new_tag_ids = [] + if "tags.yaml" in contents: + try: + tags_config = yaml.safe_load(contents["tags.yaml"]) + except yaml.YAMLError as err: + logger.error("Error parsing tags.yaml: %s", err) + tags_config = {} + + for tag_info in tags_config.get("tags", []): + tag_name = tag_info.get("tag_name") + description = tag_info.get("description", None) + if tag_name: + tag_descriptions[tag_name] = description + existing_tags = ( + db_session.query(TaggedObject) + .filter_by(object_id=object_id, object_type=object_type) + .all() + ) + + for tag_name in new_tag_names: + try: + tag = db_session.query(Tag).filter_by(name=tag_name).first() + if tag is None: + # If tag does not exist, create it with the provided description + description = tag_descriptions.get(tag_name, None) + tag = Tag(name=tag_name, description=description, type="custom") + db_session.add(tag) + db_session.commit() + + # Ensure the association with the object + tagged_object = ( + db_session.query(TaggedObject) + .filter_by(object_id=object_id, object_type=object_type, tag_id=tag.id) + .first() + ) + if not tagged_object: + new_tagged_object = TaggedObject( + tag_id=tag.id, object_id=object_id, object_type=object_type + ) + db_session.add(new_tagged_object) + + new_tag_ids.append(tag.id) + + except SQLAlchemyError as err: # Catching specific database exceptions + logger.error( + "Error processing tag '%s' for %s ID %d: %s", + tag_name, + object_type, + object_id, + err, # Used lazy logging + ) + continue # Continue to the next tag if there's an error + + # Remove old tags not in the new config + for tag in existing_tags: + if tag.tag_id not in new_tag_ids: + db_session.delete(tag) + + db_session.commit() + + return new_tag_ids diff --git a/superset/commands/query/importers/v1/__init__.py b/superset/commands/query/importers/v1/__init__.py index 3dc25d93a194d..1f7290d6f7c6b 100644 --- a/superset/commands/query/importers/v1/__init__.py +++ b/superset/commands/query/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any +from typing import Any, Optional from marshmallow import Schema from sqlalchemy.orm import Session # noqa: F401 @@ -43,7 +43,11 @@ class ImportSavedQueriesCommand(ImportModelsCommand): import_error = SavedQueryImportError @staticmethod - def _import(configs: dict[str, Any], overwrite: bool = False) -> None: + def _import( + configs: dict[str, Any], + overwrite: bool = False, + contents: Optional[dict[str, Any]] = None, + ) -> None: # discover databases associated with saved queries database_uuids: set[str] = set() for file_name, config in configs.items(): diff --git a/superset/commands/tag/export.py b/superset/commands/tag/export.py new file mode 100644 index 0000000000000..3e2f041939123 --- /dev/null +++ b/superset/commands/tag/export.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file + + +from typing import Any, Callable, List, Optional, Union +from collections.abc import Iterator + +import yaml +from superset.daos.chart import ChartDAO +from superset.daos.dashboard import DashboardDAO +from superset.extensions import feature_flag_manager +from superset.tags.models import TagType +from superset.commands.tag.exceptions import TagNotFoundError + + +# pylint: disable=too-few-public-methods +class ExportTagsCommand: + not_found = TagNotFoundError + + @staticmethod + def _file_name() -> str: + # Use the model to determine the filename + return "tags.yaml" + + @staticmethod + def _merge_tags( + dashboard_tags: List[dict[str, Any]], chart_tags: List[dict[str, Any]] + ) -> List[dict[str, Any]]: + # Create a dictionary to prevent duplicates based on tag name + tags_dict = {tag["tag_name"]: tag for tag in dashboard_tags} + + # Add chart tags, preserving unique tag names + for tag in chart_tags: + if tag["tag_name"] not in tags_dict: + tags_dict[tag["tag_name"]] = tag + + # Return merged tags as a list + return list(tags_dict.values()) + + @staticmethod + def _file_content( + dashboard_ids: Optional[Union[int, List[Union[int, str]]]] = None, + chart_ids: Optional[Union[int, List[Union[int, str]]]] = None, + ) -> str: + payload: dict[str, list[dict[str, Any]]] = {"tags": []} + + dashboard_tags = [] + chart_tags = [] + + # Fetch dashboard tags if provided + if dashboard_ids: + # Ensure dashboard_ids is a list + if isinstance(dashboard_ids, int): + dashboard_ids = [ + dashboard_ids + ] # Convert single int to list for consistency + + dashboards = [ + dashboard + for dashboard in ( + DashboardDAO.find_by_id(dashboard_id) + for dashboard_id in dashboard_ids + ) + if dashboard is not None + ] + + for dashboard in dashboards: + tags = dashboard.tags if hasattr(dashboard, "tags") else [] + filtered_tags = [ + {"tag_name": tag.name, "description": tag.description} + for tag in tags + if tag.type == TagType.custom + ] + dashboard_tags.extend(filtered_tags) + + # Fetch chart tags if provided + if chart_ids: + # Ensure chart_ids is a list + if isinstance(chart_ids, int): + chart_ids = [chart_ids] # Convert single int to list for consistency + + charts = [ + chart + for chart in (ChartDAO.find_by_id(chart_id) for chart_id in chart_ids) + if chart is not None + ] + + for chart in charts: + tags = chart.tags if hasattr(chart, "tags") else [] + filtered_tags = [ + {"tag_name": tag.name, "description": tag.description} + for tag in tags + if "type:" not in tag.name and "owner:" not in tag.name + ] + chart_tags.extend(filtered_tags) + + # Merge the tags from both dashboards and charts + merged_tags = ExportTagsCommand._merge_tags(dashboard_tags, chart_tags) + payload["tags"].extend(merged_tags) + + # Convert to YAML format + file_content = yaml.safe_dump(payload, sort_keys=False) + return file_content + + @staticmethod + def export( + dashboard_ids: Optional[Union[int, List[Union[int, str]]]] = None, + chart_ids: Optional[Union[int, List[Union[int, str]]]] = None, + ) -> Iterator[tuple[str, Callable[[], str]]]: + if not feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + yield from iter([]) + + yield ( + ExportTagsCommand._file_name(), + lambda: ExportTagsCommand._file_content(dashboard_ids, chart_ids), + ) diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index d855e22b87725..714bacabe8fcf 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -454,6 +454,7 @@ class ImportV1DashboardSchema(Schema): certified_by = fields.String(allow_none=True) certification_details = fields.String(allow_none=True) published = fields.Boolean(allow_none=True) + tags = fields.List(fields.String(), allow_none=True) class EmbeddedDashboardConfigSchema(Schema): diff --git a/tests/unit_tests/charts/commands/importers/v1/import_test.py b/tests/unit_tests/charts/commands/importers/v1/import_test.py index 8284c8565d04b..7ca0fbfb44cb6 100644 --- a/tests/unit_tests/charts/commands/importers/v1/import_test.py +++ b/tests/unit_tests/charts/commands/importers/v1/import_test.py @@ -18,8 +18,10 @@ import copy from collections.abc import Generator +from unittest.mock import patch import pytest +import yaml from flask_appbuilder.security.sqla.models import Role, User from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session @@ -27,8 +29,11 @@ from superset import security_manager from superset.commands.chart.importers.v1.utils import import_chart from superset.commands.exceptions import ImportFailedError +from superset.commands.importers.v1.utils import import_tag from superset.connectors.sqla.models import Database, SqlaTable +from superset.extensions import feature_flag_manager from superset.models.slice import Slice +from superset.tags.models import TaggedObject from superset.utils.core import override_user from tests.integration_tests.fixtures.importexport import chart_config @@ -231,3 +236,43 @@ def test_import_existing_chart_with_permission( # Assert that the can write to chart was checked mock_can_access.assert_called_once_with("can_write", "Chart") mock_can_access_chart.assert_called_once_with(slice) + + +def test_import_tag_logic_for_charts(session_with_schema: Session): + contents = { + "tags.yaml": yaml.dump( + {"tags": [{"tag_name": "tag_1", "description": "Description for tag_1"}]} + ) + } + + object_id = 1 + object_type = "chart" + + with patch.object(feature_flag_manager, "is_feature_enabled", return_value=True): + new_tag_ids = import_tag( + ["tag_1"], contents, object_id, object_type, session_with_schema + ) + assert len(new_tag_ids) > 0 + assert ( + session_with_schema.query(TaggedObject) + .filter_by(object_id=object_id, object_type=object_type) + .count() + > 0 + ) + + session_with_schema.query(TaggedObject).filter_by( + object_id=object_id, object_type=object_type + ).delete() + session_with_schema.commit() + + with patch.object(feature_flag_manager, "is_feature_enabled", return_value=False): + new_tag_ids_disabled = import_tag( + ["tag_1"], contents, object_id, object_type, session_with_schema + ) + assert len(new_tag_ids_disabled) == 0 + associated_tags = ( + session_with_schema.query(TaggedObject) + .filter_by(object_id=object_id, object_type=object_type) + .all() + ) + assert len(associated_tags) == 0 diff --git a/tests/unit_tests/commands/export_test.py b/tests/unit_tests/commands/export_test.py index 4ce3545738bd0..74981dcda82e0 100644 --- a/tests/unit_tests/commands/export_test.py +++ b/tests/unit_tests/commands/export_test.py @@ -16,9 +16,15 @@ # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel +from unittest.mock import patch + +import pytest +import yaml from freezegun import freeze_time from pytest_mock import MockerFixture +from superset.extensions import feature_flag_manager + def test_export_assets_command(mocker: MockerFixture) -> None: """ @@ -80,7 +86,6 @@ def test_export_assets_command(mocker: MockerFixture) -> None: with freeze_time("2022-01-01T00:00:00Z"): command = ExportAssetsCommand() output = [(file[0], file[1]()) for file in list(command.run())] - assert output == [ ( "metadata.yaml", @@ -92,3 +97,61 @@ def test_export_assets_command(mocker: MockerFixture) -> None: ("dashboards/sales.yaml", ""), ("queries/example/metric.yaml", ""), ] + + +@pytest.fixture +def mock_export_tags_command_charts_dashboards(mocker): + ExportTagsCommand = mocker.patch("superset.commands.tag.export.ExportTagsCommand") + + def _mock_export(dashboard_ids=None, chart_ids=None): + if not feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): + return iter([]) + return [ + ( + "tags.yaml", + lambda: yaml.dump( + { + "tags": [ + { + "tag_name": "tag_1", + "description": "Description for tag_1", + } + ] + }, + sort_keys=False, + ), + ), + ("charts/pie.yaml", lambda: "tag:\n- tag_1"), + ] + + ExportTagsCommand.return_value._export.side_effect = _mock_export + return ExportTagsCommand + + +def test_export_tags_with_charts_dashboards( + mock_export_tags_command_charts_dashboards, mocker +): + with patch.object(feature_flag_manager, "is_feature_enabled", return_value=True): + command = mock_export_tags_command_charts_dashboards() + result = list(command._export(chart_ids=[1])) + + file_name, file_content_func = result[0] + file_content = file_content_func() + assert file_name == "tags.yaml" + payload = yaml.safe_load(file_content) + assert payload["tags"] == [ + {"tag_name": "tag_1", "description": "Description for tag_1"} + ] + + file_name, file_content_func = result[1] + file_content = file_content_func() + assert file_name == "charts/pie.yaml" + assert file_content == "tag:\n- tag_1" + + with patch.object(feature_flag_manager, "is_feature_enabled", return_value=False): + command = mock_export_tags_command_charts_dashboards() + result = list(command._export(chart_ids=[1])) + assert not any(file_name == "tags.yaml" for file_name, _ in result) + assert all( + file_content_func() != "tag:\n- tag_1" for _, file_content_func in result + ) diff --git a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py index c311f1b3906c0..d365ae72b28b5 100644 --- a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py +++ b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py @@ -18,8 +18,10 @@ import copy from collections.abc import Generator +from unittest.mock import patch import pytest +import yaml from flask_appbuilder.security.sqla.models import Role, User from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session @@ -27,7 +29,10 @@ from superset import security_manager from superset.commands.dashboard.importers.v1.utils import import_dashboard from superset.commands.exceptions import ImportFailedError +from superset.commands.importers.v1.utils import import_tag +from superset.extensions import feature_flag_manager from superset.models.dashboard import Dashboard +from superset.tags.models import TaggedObject from superset.utils.core import override_user from tests.integration_tests.fixtures.importexport import dashboard_config @@ -189,3 +194,43 @@ def test_import_existing_dashboard_with_permission( # Assert that the can write to dashboard was checked mock_can_access.assert_called_once_with("can_write", "Dashboard") mock_can_access_dashboard.assert_called_once_with(dashboard) + + +def test_import_tag_logic_for_dashboards(session_with_schema: Session): + contents = { + "tags.yaml": yaml.dump( + {"tags": [{"tag_name": "tag_1", "description": "Description for tag_1"}]} + ) + } + + object_id = 1 + object_type = "dashboards" + + with patch.object(feature_flag_manager, "is_feature_enabled", return_value=True): + new_tag_ids = import_tag( + ["tag_1"], contents, object_id, object_type, session_with_schema + ) + assert len(new_tag_ids) > 0 + assert ( + session_with_schema.query(TaggedObject) + .filter_by(object_id=object_id, object_type=object_type) + .count() + > 0 + ) + + session_with_schema.query(TaggedObject).filter_by( + object_id=object_id, object_type=object_type + ).delete() + session_with_schema.commit() + + with patch.object(feature_flag_manager, "is_feature_enabled", return_value=False): + new_tag_ids_disabled = import_tag( + ["tag_1"], contents, object_id, object_type, session_with_schema + ) + assert len(new_tag_ids_disabled) == 0 + associated_tags = ( + session_with_schema.query(TaggedObject) + .filter_by(object_id=object_id, object_type=object_type) + .all() + ) + assert len(associated_tags) == 0