diff --git a/sqle/api/controller/v1/rule.go b/sqle/api/controller/v1/rule.go index 848343c44..f3581877c 100644 --- a/sqle/api/controller/v1/rule.go +++ b/sqle/api/controller/v1/rule.go @@ -643,10 +643,12 @@ func getRuleTemplateTips(c echo.Context, projectId string, filterDBType string) ruleTemplateTipsRes := make([]RuleTemplateTipResV1, 0, len(ruleTemplates)) for _, roleTemplate := range ruleTemplates { + isDefaultRuleTemplate := roleTemplate.Name == fmt.Sprintf("default_%s", roleTemplate.DBType) ruleTemplateTipRes := RuleTemplateTipResV1{ - ID: roleTemplate.GetIDStr(), - Name: roleTemplate.Name, - DBType: roleTemplate.DBType, + ID: roleTemplate.GetIDStr(), + Name: roleTemplate.Name, + DBType: roleTemplate.DBType, + IsDefaultRuleTemplate: isDefaultRuleTemplate, } ruleTemplateTipsRes = append(ruleTemplateTipsRes, ruleTemplateTipRes) } diff --git a/sqle/api/controller/v1/sql_audit_record.go b/sqle/api/controller/v1/sql_audit_record.go index 58185bc9e..c4ded45c9 100644 --- a/sqle/api/controller/v1/sql_audit_record.go +++ b/sqle/api/controller/v1/sql_audit_record.go @@ -96,6 +96,19 @@ func CreateSQLAuditRecord(c echo.Context) error { } s := model.GetStorage() + + var ruleTemplateID uint + if req.RuleTemplateName != nil { + ruleTemplate, exist, err := s.GetGlobalAndProjectRuleTemplateByNameAndProjectId(*req.RuleTemplateName, projectUid) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + if !exist { + return controller.JSONBaseErrorReq(c, errors.New(errors.DataNotExist, fmt.Errorf("rule template %v not exist", *req.RuleTemplateName))) + } + ruleTemplateID = ruleTemplate.ID + } + sqls := getSQLFromFileResp{} user, err := controller.GetCurrentUser(c, dms.GetUser) if err != nil { @@ -125,6 +138,7 @@ func CreateSQLAuditRecord(c echo.Context) error { return controller.JSONBaseErrorReq(c, err) } } + task.RuleTemplateID = ruleTemplateID // if task instance is not nil, gorm will update instance when save task. task.Instance = nil diff --git a/sqle/model/task.go b/sqle/model/task.go index 6abc28fc1..38490d1a3 100644 --- a/sqle/model/task.go +++ b/sqle/model/task.go @@ -57,17 +57,26 @@ type Task struct { Status string `json:"status" gorm:"default:\"initialized\";type:varchar(255)"` GroupId uint `json:"group_id" gorm:"column:group_id"` CreateUserId uint64 + RuleTemplateID uint `json:"rule_template_id" gorm:"column:rule_template_id"` ExecStartAt *time.Time ExecEndAt *time.Time ExecMode string `json:"exec_mode" gorm:"default:'sqls';type:varchar(255)" example:"sqls"` EnableBackup bool `gorm:"column:enable_backup"` FileOrderMethod string `json:"file_order_method" gorm:"column:file_order_method;type:varchar(255)"` Instance *Instance `json:"-" gorm:"-"` + RuleTemplate *RuleTemplate `json:"-" gorm:"foreignkey:RuleTemplateID"` ExecuteSQLs []*ExecuteSQL `json:"-" gorm:"foreignkey:TaskId"` RollbackSQLs []*RollbackSQL `json:"-" gorm:"foreignkey:TaskId"` AuditFiles []*AuditFile `json:"-" gorm:"foreignkey:TaskId"` } +func (t *Task) RuleTemplateName() string { + if t.RuleTemplate != nil { + return t.RuleTemplate.Name + } + return "" +} + func (t *Task) InstanceName() string { if t.Instance != nil { return t.Instance.Name @@ -445,7 +454,7 @@ func (s *Storage) GetTasksByIds(taskIds []uint) (tasks []*Task, foundAllIds bool func (s *Storage) GetTaskDetailById(taskId string) (*Task, bool, error) { task := &Task{} err := s.db.Where("id = ?", taskId). - Preload("ExecuteSQLs").Preload("RollbackSQLs").First(task).Error + Preload("RuleTemplate").Preload("ExecuteSQLs").Preload("RollbackSQLs").First(task).Error if err == gorm.ErrRecordNotFound { return nil, false, nil } diff --git a/sqle/server/sqled.go b/sqle/server/sqled.go index b7f24aa8e..080c5d74d 100644 --- a/sqle/server/sqled.go +++ b/sqle/server/sqled.go @@ -114,7 +114,7 @@ func (s *Sqled) addTask(projectId string, taskId string, typ int) (*action, erro action.task = task // plugin will be closed by drvMgr in Sqled.do(). - rules, customRules, err = st.GetAllRulesByTmpNameAndProjectIdInstanceDBType("", "", task.Instance, task.DBType) + rules, customRules, err = st.GetAllRulesByTmpNameAndProjectIdInstanceDBType(task.RuleTemplateName(), projectId, task.Instance, task.DBType) if err != nil { goto Error }