diff --git a/dbm-ui/backend/ticket/builders/common/constants.py b/dbm-ui/backend/ticket/builders/common/constants.py index 045263b37b..92c59e1a6b 100644 --- a/dbm-ui/backend/ticket/builders/common/constants.py +++ b/dbm-ui/backend/ticket/builders/common/constants.py @@ -149,3 +149,12 @@ class MongoDBScriptImportMode(str, StructuredEnum): FILE = EnumField("file", _("文件上传")) MANUAL = EnumField("manual", _("手动输入")) + + +class FlashbackBuildType(str, StructuredEnum): + """ + 闪回方式 + """ + + TABLE_FLASHBACK = EnumField("TABLE_FLASHBACK", _("库表闪回")) + RECORD_FLASHBACK = EnumField("RECORD_FLASHBACK", _("记录级闪回")) diff --git a/dbm-ui/backend/ticket/builders/mysql/mysql_flashback.py b/dbm-ui/backend/ticket/builders/mysql/mysql_flashback.py index 3be7be477f..f6a484f730 100644 --- a/dbm-ui/backend/ticket/builders/mysql/mysql_flashback.py +++ b/dbm-ui/backend/ticket/builders/mysql/mysql_flashback.py @@ -8,16 +8,24 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +import csv import datetime +from collections import defaultdict +from io import StringIO +from typing import Dict from django.utils import timezone from django.utils.translation import gettext_lazy as _ from rest_framework import serializers +from backend.components import DRSApi +from backend.db_meta.api.cluster.base.handler import ClusterHandler +from backend.db_meta.models import Cluster from backend.db_services.mysql.remote_service.handlers import RemoteServiceHandler +from backend.db_services.partition.constants import QUERY_DATABASE_FIELD_TYPE from backend.flow.engine.controller.mysql import MySQLController from backend.ticket import builders -from backend.ticket.builders.common.constants import MYSQL_BINLOG_ROLLBACK +from backend.ticket.builders.common.constants import MYSQL_BINLOG_ROLLBACK, FlashbackBuildType from backend.ticket.builders.common.field import DBTimezoneField from backend.ticket.builders.mysql.base import ( BaseMySQLHATicketFlowBuilder, @@ -41,9 +49,12 @@ class FlashbackSerializer(serializers.Serializer): help_text=_("flashback工具地址"), default=MYSQL_BINLOG_ROLLBACK, required=False ) recored_file = serializers.CharField(help_text=_("记录文件"), required=False, default="") + rows_filter = serializers.CharField(help_text=_("待闪回记录"), required=False, default="") + direct_write_back = serializers.BooleanField(help_text=_("是否覆盖原始数据"), required=False, default=False) infos = serializers.ListSerializer(help_text=_("flashback信息"), child=FlashbackSerializer(), allow_empty=False) force = serializers.BooleanField(help_text=_("是否强制执行"), required=False, default=False) + flashback_type = serializers.ChoiceField(help_text=_("闪回方式"), choices=FlashbackBuildType.get_choices()) def validate_flash_time(self, attrs): # 校验start time和end time的合法性 @@ -56,6 +67,77 @@ def validate_flash_time(self, attrs): _("flash的起止时间{}--{}不合法,请保证开始时间小于结束时间,并且二者不大于当前时间").format(start_time, end_time) ) + def validate_rows_filter(self, attrs): + if attrs["flashback_type"] not in FlashbackBuildType.get_values(): + raise serializers.ValidationError(_("不支持的闪回方式{}").format(attrs["flashback_type"])) + # 校验待闪回记录信息 + if attrs["flashback_type"] != FlashbackBuildType.RECORD_FLASHBACK: + for info in attrs["infos"]: + if info["rows_filter"]: + raise serializers.ValidationError(_("库表闪回不支持rows_filter参数")) + return attrs + + for info in attrs["infos"]: + if not info["rows_filter"]: + raise serializers.ValidationError(_("记录级闪回缺少rows_filter参数")) + info["rows_filter"] = info["rows_filter"].replace(" ", "") + try: + # 使用 StringIO 将字符串转换为类似文件的对象 + csv_file = StringIO(info["rows_filter"]) + csv_reader = csv.reader(csv_file) + # 获取头部并计算列数 + headers = next(csv_reader, None) + if headers is None: + raise serializers.ValidationError(_("CSV file is empty")) + # 校验字段名 + self.validate_table_fields(info, headers) + expected_column_count = len(headers) + # 验证每一行的数据长度是否与头部长度一致 + for row_number, row in enumerate(csv_reader, start=2): + if len(row) != expected_column_count: + raise serializers.ValidationError( + _("字段个数 {} 与数据列数不匹配 {}.").format(row_number, expected_column_count) + ) + + except csv.Error: + raise serializers.ValidationError(_("输入内容不符合csv格式")) + + def validate_table_fields(self, info, input_fild_names): + bk_biz_id = self.context["bk_biz_id"] + cluster_id = info["cluster_id"] + + # 获取集群的DRS查询地址,格式化库表过滤条件 + cluster = Cluster.objects.get(id=cluster_id) + address = ClusterHandler.get_exact_handler(bk_biz_id=bk_biz_id, cluster_id=cluster_id).get_remote_address() + + table_sts = "(" + " or ".join([f"table_name = '{table}'" for table in info["tables"]]) + ")" + db_sts = "(" + " or ".join([f"table_schema like '{db}'" for db in info["databases"]]) + ")" + fields_type_sql = QUERY_DATABASE_FIELD_TYPE.format(table_sts=table_sts, db_sts=db_sts) + + # 查询涉及的所有库表索引信息和字段类型信息 + rpc_results = DRSApi.rpc( + {"bk_cloud_id": cluster.bk_cloud_id, "addresses": [address], "cmds": [fields_type_sql]} + ) + if rpc_results[0]["cmd_results"] is None: + raise serializers.ValidationError(_("字段信息查询错误:{}").format(rpc_results[0]["error_msg"])) + + db_table_fields: Dict[str, Dict[str, list]] = defaultdict(lambda: defaultdict(list)) + for table_data in rpc_results[0]["cmd_results"][0]["table_data"]: + db_table_fields[table_data["table_schema"]][table_data["table_name"]].append(table_data["column_name"]) + + for db_name, table_info in db_table_fields.items(): + for table_name, fild_names in table_info.items(): + no_file_name = set(input_fild_names).difference(set(fild_names)) + if no_file_name: + raise serializers.ValidationError( + _("数据库【{}】表【{}】中不存在字段{}".format(db_name, table_name, no_file_name)) + ) + + def check_flashback_database_result(self, attrs): + for info in attrs["infos"]: + if info.get("message"): + raise serializers.ValidationError(_(info["message"])) + def validate(self, attrs): # 校验闪回的时间 self.validate_flash_time(attrs) @@ -63,6 +145,10 @@ def validate(self, attrs): super(MySQLFlashbackDetailSerializer, self).validate_cluster_can_access(attrs) # 校验库表是否存在 RemoteServiceHandler(bk_biz_id=self.context["bk_biz_id"]).check_flashback_database(attrs["infos"]) + # 库表校验结果判断 + self.check_flashback_database_result(attrs) + # 校验待闪回记录格式与字段是否存在 + self.validate_rows_filter(attrs) return attrs diff --git a/dbm-ui/backend/ticket/builders/tendbcluster/tendb_flashback.py b/dbm-ui/backend/ticket/builders/tendbcluster/tendb_flashback.py index 7c9df925c5..fb820e8251 100644 --- a/dbm-ui/backend/ticket/builders/tendbcluster/tendb_flashback.py +++ b/dbm-ui/backend/ticket/builders/tendbcluster/tendb_flashback.py @@ -31,6 +31,10 @@ def validate(self, attrs): super(TendbBaseOperateDetailSerializer, self).validate_cluster_can_access(attrs) # 校验flash的库表选择器 RemoteServiceHandler(bk_biz_id=self.context["bk_biz_id"]).check_flashback_database(attrs["infos"]) + # 库表校验结果判断 + self.check_flashback_database_result(attrs) + # 校验待闪回记录格式与字段是否存在 + self.validate_rows_filter(attrs) return attrs