diff --git a/internal/connectors/db/connector.go b/internal/connectors/db/connector.go index 2b96b78a..459d9640 100644 --- a/internal/connectors/db/connector.go +++ b/internal/connectors/db/connector.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "path" + "reflect" "time" "ydbcp/internal/config" @@ -14,6 +16,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3" "github.com/ydb-platform/ydb-go-sdk/v3/balancers" "github.com/ydb-platform/ydb-go-sdk/v3/table" + "github.com/ydb-platform/ydb-go-sdk/v3/table/options" "github.com/ydb-platform/ydb-go-sdk/v3/table/result" table_types "github.com/ydb-platform/ydb-go-sdk/v3/table/types" "go.uber.org/zap" @@ -33,6 +36,11 @@ var ( ), table.CommitTx(), ) + writeTxNoCommit = table.TxControl( + table.BeginTx( + table.WithSerializableReadWrite(), + ), + ) ) type DBConnector interface { @@ -49,7 +57,7 @@ type DBConnector interface { []*types.BackupSchedule, error, ) ActiveOperations(context.Context) ([]types.Operation, error) - UpdateOperation(context.Context, types.Operation) error + UpdateOperation(context.Context, types.Operation, types.OperationState) error CreateOperation(context.Context, types.Operation) (string, error) CreateBackup(context.Context, types.Backup) (string, error) UpdateBackup(context context.Context, id string, backupState string) error @@ -211,16 +219,80 @@ func (d *YdbConnector) ExecuteUpsert(ctx context.Context, queryBuilder queries.W } err = d.GetTableClient().Do( ctx, func(ctx context.Context, s table.Session) (err error) { - _, _, err = s.Execute( + var ( + txControl *table.TransactionControl + statsOption options.ExecuteDataQueryOption + ) + + needToCheckQueryStats := len(queryFormat.ExpectedUpdateStats) > 0 + if needToCheckQueryStats { + txControl = writeTxNoCommit + statsOption = options.WithCollectStatsModeBasic() + } else { + txControl = writeTx + statsOption = options.WithCollectStatsModeNone() + } + + tx, res, err := s.Execute( ctx, - writeTx, + txControl, queryFormat.QueryText, queryFormat.QueryParams, + statsOption, ) if err != nil { return err } - return nil + + defer func(res result.Result) { + err = res.Close() + if err != nil { + xlog.Error(ctx, "Error closing transaction result", zap.Error(err)) + } + }(res) // result must be closed + + if !needToCheckQueryStats { + return nil + } + + if stats := res.Stats(); stats != nil { + updateStats := make(map[string]uint64) + + for { + phaseStats, ok := stats.NextPhase() + if !ok { + break + } + + for { + tableStats, ook := phaseStats.NextTableAccess() + if !ook { + break + } + + _, tableName := path.Split(tableStats.Name) + // We can't receive modifications stats before commit. + // This query only contains modifications (upserts or updates), + // so we can be sure that all reads operations are related to updates. + // We can use this assumption to calculate updated rows count. + updateStats[tableName] += tableStats.Reads.Rows + } + } + + if !reflect.DeepEqual(updateStats, queryFormat.ExpectedUpdateStats) { + xlog.Error(ctx, "Expected updated rows count does not match actual", + zap.Any("expected", queryFormat.ExpectedUpdateStats), + zap.Any("actual", updateStats), + ) + return tx.Rollback(ctx) + } + } else { + xlog.Error(ctx, "Empty stats for upsert query") + return tx.Rollback(ctx) + } + + _, err = tx.CommitTx(ctx) + return err }, ) if err != nil { @@ -303,7 +375,7 @@ func (d *YdbConnector) ActiveOperations(ctx context.Context) ( } func (d *YdbConnector) UpdateOperation( - ctx context.Context, operation types.Operation, + ctx context.Context, operation types.Operation, prevState types.OperationState, ) error { if operation.GetAudit() != nil && operation.GetAudit().CompletedAt != nil { operation.SetUpdatedAt(operation.GetAudit().CompletedAt) @@ -311,7 +383,7 @@ func (d *YdbConnector) UpdateOperation( operation.SetUpdatedAt(timestamppb.Now()) } - return d.ExecuteUpsert(ctx, queries.NewWriteTableQuery().WithUpdateOperation(operation)) + return d.ExecuteUpsert(ctx, queries.NewWriteTableQuery().WithUpdateOperation(operation, prevState)) } func (d *YdbConnector) CreateOperation( diff --git a/internal/connectors/db/mock.go b/internal/connectors/db/mock.go index 8103840c..4ebfa131 100644 --- a/internal/connectors/db/mock.go +++ b/internal/connectors/db/mock.go @@ -147,7 +147,7 @@ func (c *MockDBConnector) ActiveOperations(_ context.Context) ( } func (c *MockDBConnector) UpdateOperation( - _ context.Context, op types.Operation, + _ context.Context, op types.Operation, prevState types.OperationState, ) error { c.guard.Lock() defer c.guard.Unlock() @@ -155,7 +155,12 @@ func (c *MockDBConnector) UpdateOperation( if _, exist := c.operations[op.GetID()]; !exist { return fmt.Errorf("update nonexistent operation %s", types.OperationToString(op)) } - c.operations[op.GetID()] = op.Copy() + if c.operations[op.GetID()].GetState() == prevState { + c.operations[op.GetID()] = op.Copy() + } else { + return fmt.Errorf("operation state was changed %s", types.OperationToString(op)) + } + return nil } diff --git a/internal/connectors/db/yql/queries/read.go b/internal/connectors/db/yql/queries/read.go index 9856bad1..abb04b88 100644 --- a/internal/connectors/db/yql/queries/read.go +++ b/internal/connectors/db/yql/queries/read.go @@ -27,8 +27,9 @@ type QueryFilter struct { } type FormatQueryResult struct { - QueryText string - QueryParams *table.QueryParameters + QueryText string + QueryParams *table.QueryParameters + ExpectedUpdateStats map[string]uint64 } type ReadTableQuery interface { diff --git a/internal/connectors/db/yql/queries/write.go b/internal/connectors/db/yql/queries/write.go index c6d09822..b6d08a17 100644 --- a/internal/connectors/db/yql/queries/write.go +++ b/internal/connectors/db/yql/queries/write.go @@ -21,7 +21,7 @@ type WriteTableQuery interface { WithCreateOperation(operation types.Operation) WriteTableQuery WithCreateBackupSchedule(schedule types.BackupSchedule) WriteTableQuery WithUpdateBackup(backup types.Backup) WriteTableQuery - WithUpdateOperation(operation types.Operation) WriteTableQuery + WithUpdateOperation(operation types.Operation, prevState types.OperationState) WriteTableQuery WithUpdateBackupSchedule(schedule types.BackupSchedule) WriteTableQuery } @@ -35,6 +35,8 @@ type WriteSingleTableQueryImpl struct { upsertFields []string tableQueryParams []table.ParameterOption updateParam *table.ParameterOption + filterFields []string + filterParams []table.ParameterOption } type WriteTableQueryImplOption func(*WriteTableQueryImpl) @@ -192,7 +194,7 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle return d } -func BuildUpdateOperationQuery(operation types.Operation, index int) WriteSingleTableQueryImpl { +func BuildUpdateOperationQuery(operation types.Operation, index int, prevState types.OperationState) WriteSingleTableQueryImpl { d := WriteSingleTableQueryImpl{ index: index, tableName: "Operations", @@ -217,6 +219,16 @@ func BuildUpdateOperationQuery(operation types.Operation, index int) WriteSingle table_types.TimestampValueFromTime(operation.GetUpdatedAt().AsTime()), ) } + + d.filterFields = append(d.filterFields, "status") + d.filterParams = append( + d.filterParams, + table.ValueParam( + fmt.Sprintf("%s_%d", "$prev_status", d.index), + table_types.StringValueFromString(prevState.String()), + ), + ) + return d } @@ -381,9 +393,9 @@ func (d *WriteTableQueryImpl) WithUpdateBackup(backup types.Backup) WriteTableQu return d } -func (d *WriteTableQueryImpl) WithUpdateOperation(operation types.Operation) WriteTableQuery { +func (d *WriteTableQueryImpl) WithUpdateOperation(operation types.Operation, prevState types.OperationState) WriteTableQuery { index := len(d.tableQueries) - d.tableQueries = append(d.tableQueries, BuildUpdateOperationQuery(operation, index)) + d.tableQueries = append(d.tableQueries, BuildUpdateOperationQuery(operation, index, prevState)) return d } @@ -439,25 +451,36 @@ func ProcessUpdateQuery( for i := range t.upsertFields { updates = append(updates, fmt.Sprintf("%s = %s", t.upsertFields[i], paramNames[i])) } + + filters := make([]string, 0) + filters = append(filters, keyParam) + for i := range t.filterFields { + filters = append(filters, fmt.Sprintf("%s = %s", t.filterFields[i], t.filterParams[i].Name())) + } + *queryStrings = append( *queryStrings, fmt.Sprintf( - "UPDATE %s SET %s WHERE %s", t.tableName, strings.Join(updates, ", "), keyParam, + `UPDATE %s SET %s WHERE %s`, + t.tableName, strings.Join(updates, ", "), strings.Join(filters, " AND "), ), ) *allParams = append(*allParams, *t.updateParam) *allParams = append(*allParams, t.tableQueryParams...) + *allParams = append(*allParams, t.filterParams...) return nil } func (d *WriteTableQueryImpl) FormatQuery(ctx context.Context) (*FormatQueryResult, error) { queryStrings := make([]string, 0) allParams := make([]table.ParameterOption, 0) + expectedUpdateStats := make(map[string]uint64) for _, t := range d.tableQueries { var err error if t.updateParam == nil { err = ProcessUpsertQuery(&queryStrings, &allParams, &t) } else { err = ProcessUpdateQuery(&queryStrings, &allParams, &t) + expectedUpdateStats[t.tableName]++ } if err != nil { return nil, err @@ -466,7 +489,8 @@ func (d *WriteTableQueryImpl) FormatQuery(ctx context.Context) (*FormatQueryResu res := strings.Join(queryStrings, ";\n") xlog.Debug(ctx, "write query", zap.String("yql", res)) return &FormatQueryResult{ - QueryText: res, - QueryParams: table.NewQueryParameters(allParams...), + QueryText: res, + QueryParams: table.NewQueryParameters(allParams...), + ExpectedUpdateStats: expectedUpdateStats, }, nil } diff --git a/internal/connectors/db/yql/queries/write_mock.go b/internal/connectors/db/yql/queries/write_mock.go index ae1eff79..7e98e50e 100644 --- a/internal/connectors/db/yql/queries/write_mock.go +++ b/internal/connectors/db/yql/queries/write_mock.go @@ -42,7 +42,7 @@ func (w *WriteTableQueryMock) WithUpdateBackup(backup types.Backup) WriteTableQu return w } -func (w *WriteTableQueryMock) WithUpdateOperation(operation types.Operation) WriteTableQuery { +func (w *WriteTableQueryMock) WithUpdateOperation(operation types.Operation, _ types.OperationState) WriteTableQuery { w.Operation = operation return w } diff --git a/internal/handlers/delete_backup.go b/internal/handlers/delete_backup.go index a4193cbb..110fca77 100644 --- a/internal/handlers/delete_backup.go +++ b/internal/handlers/delete_backup.go @@ -54,13 +54,14 @@ func DBOperationHandler( Status: types.BackupStateUnknown, } + prevState := operation.GetState() if deadlineExceeded(dbOp.Audit.CreatedAt, config) { backupToWrite.Status = types.BackupStateError operation.SetState(types.OperationStateError) operation.SetMessage("Operation deadline exceeded") operation.GetAudit().CompletedAt = timestamppb.Now() return db.ExecuteUpsert( - ctx, queryBulderFactory().WithUpdateOperation(operation).WithUpdateBackup(backupToWrite), + ctx, queryBulderFactory().WithUpdateOperation(operation, prevState).WithUpdateBackup(backupToWrite), ) } @@ -110,7 +111,7 @@ func DBOperationHandler( backupToWrite.Status = types.BackupStateDeleting operation.SetState(types.OperationStateRunning) err := db.ExecuteUpsert( - ctx, queryBulderFactory().WithUpdateOperation(operation).WithUpdateBackup(backupToWrite), + ctx, queryBulderFactory().WithUpdateOperation(operation, prevState).WithUpdateBackup(backupToWrite), ) if err != nil { return fmt.Errorf("can't update operation: %v", err) @@ -133,6 +134,6 @@ func DBOperationHandler( } return db.ExecuteUpsert( - ctx, queryBulderFactory().WithUpdateOperation(operation).WithUpdateBackup(backupToWrite), + ctx, queryBulderFactory().WithUpdateOperation(operation, prevState).WithUpdateBackup(backupToWrite), ) } diff --git a/internal/handlers/restore_backup.go b/internal/handlers/restore_backup.go index 9a8e10c5..ae20cea5 100644 --- a/internal/handlers/restore_backup.go +++ b/internal/handlers/restore_backup.go @@ -60,11 +60,13 @@ func RBOperationHandler( if err != nil { return err } + + prevState := operation.GetState() if ydbOpResponse.shouldAbortHandler { operation.SetState(ydbOpResponse.opState) operation.SetMessage(ydbOpResponse.opMessage) operation.GetAudit().CompletedAt = timestamppb.Now() - return db.UpdateOperation(ctx, operation) + return db.UpdateOperation(ctx, operation, prevState) } if ydbOpResponse.opResponse == nil { @@ -81,7 +83,7 @@ func RBOperationHandler( operation.SetMessage("Operation deadline exceeded") } - return db.UpdateOperation(ctx, operation) + return db.UpdateOperation(ctx, operation, prevState) } if opResponse.GetOperation().Status == Ydb.StatusIds_SUCCESS { operation.SetState(types.OperationStateDone) @@ -105,7 +107,7 @@ func RBOperationHandler( return err } - return db.UpdateOperation(ctx, operation) + return db.UpdateOperation(ctx, operation, prevState) } case types.OperationStateCancelling: { @@ -116,7 +118,7 @@ func RBOperationHandler( operation.GetAudit().CompletedAt = timestamppb.Now() } - return db.UpdateOperation(ctx, operation) + return db.UpdateOperation(ctx, operation, prevState) } if opResponse.GetOperation().Status == Ydb.StatusIds_SUCCESS { operation.SetState(types.OperationStateDone) @@ -160,5 +162,5 @@ func RBOperationHandler( } operation.GetAudit().CompletedAt = timestamppb.Now() - return db.UpdateOperation(ctx, operation) + return db.UpdateOperation(ctx, operation, prevState) } diff --git a/internal/handlers/take_backup.go b/internal/handlers/take_backup.go index 3d0513f7..4350df30 100644 --- a/internal/handlers/take_backup.go +++ b/internal/handlers/take_backup.go @@ -60,6 +60,7 @@ func TBOperationHandler( return err } + prevState := operation.GetState() now := timestamppb.Now() backupToWrite := types.Backup{ ID: tb.BackupID, @@ -75,7 +76,7 @@ func TBOperationHandler( backupToWrite.Message = operation.GetMessage() backupToWrite.AuditInfo.CompletedAt = now return db.ExecuteUpsert( - ctx, queryBuilderFactory().WithUpdateOperation(operation).WithUpdateBackup(backupToWrite), + ctx, queryBuilderFactory().WithUpdateOperation(operation, prevState).WithUpdateBackup(backupToWrite), ) } if ydbOpResponse.opResponse == nil { @@ -120,7 +121,7 @@ func TBOperationHandler( operation.SetState(types.OperationStateStartCancelling) operation.SetMessage("Operation deadline exceeded") } - return db.UpdateOperation(ctx, operation) + return db.UpdateOperation(ctx, operation, prevState) } else if opResponse.GetOperation().Status == Ydb.StatusIds_SUCCESS { size, err := getBackupSize(tb.BackupID) if err != nil { @@ -156,7 +157,7 @@ func TBOperationHandler( backupToWrite.Message = operation.GetMessage() backupToWrite.AuditInfo.CompletedAt = operation.GetAudit().CompletedAt return db.ExecuteUpsert( - ctx, queryBuilderFactory().WithUpdateOperation(operation).WithUpdateBackup(backupToWrite), + ctx, queryBuilderFactory().WithUpdateOperation(operation, prevState).WithUpdateBackup(backupToWrite), ) } case types.OperationStateCancelling: @@ -170,11 +171,11 @@ func TBOperationHandler( operation.GetAudit().CompletedAt = now backupToWrite.Message = operation.GetMessage() return db.ExecuteUpsert( - ctx, queryBuilderFactory().WithUpdateOperation(operation).WithUpdateBackup(backupToWrite), + ctx, queryBuilderFactory().WithUpdateOperation(operation, prevState).WithUpdateBackup(backupToWrite), ) } - return db.UpdateOperation(ctx, operation) + return db.UpdateOperation(ctx, operation, prevState) } if opResponse.GetOperation().Status == Ydb.StatusIds_SUCCESS { size, err := getBackupSize(tb.BackupID) @@ -221,6 +222,6 @@ func TBOperationHandler( backupToWrite.AuditInfo.CompletedAt = now operation.GetAudit().CompletedAt = now return db.ExecuteUpsert( - ctx, queryBuilderFactory().WithUpdateOperation(operation).WithUpdateBackup(backupToWrite), + ctx, queryBuilderFactory().WithUpdateOperation(operation, prevState).WithUpdateBackup(backupToWrite), ) } diff --git a/internal/processor/processor_test.go b/internal/processor/processor_test.go index beae9ccd..0ad25df3 100644 --- a/internal/processor/processor_test.go +++ b/internal/processor/processor_test.go @@ -43,9 +43,10 @@ func TestProcessor(t *testing.T) { ctx, "TB handler called for operation", zap.String("operation", types.OperationToString(op)), ) + prevState := op.GetState() op.SetState(types.OperationStateDone) op.SetMessage("Success") - db.UpdateOperation(ctx, op) + db.UpdateOperation(ctx, op, prevState) handlerCalled <- struct{}{} return nil }, diff --git a/internal/server/services/operation/operation_service.go b/internal/server/services/operation/operation_service.go index c97d3eaf..0574a4f0 100644 --- a/internal/server/services/operation/operation_service.go +++ b/internal/server/services/operation/operation_service.go @@ -151,15 +151,16 @@ func (s *OperationService) CancelOperation( } ctx = xlog.With(ctx, zap.String("SubjectID", subject)) - if operation.GetState() != types.OperationStatePending { - xlog.Error(ctx, "can't cancel operation with state != pending") + if operation.GetState() != types.OperationStatePending && operation.GetState() != types.OperationStateRunning { + xlog.Error(ctx, "can't cancel operation with state", zap.String("OperationState", operation.GetState().String())) return operation.Proto(), nil } + prevState := operation.GetState() operation.SetState(types.OperationStateStartCancelling) operation.SetMessage("Operation was cancelled via OperationService") - err = s.driver.UpdateOperation(ctx, operation) + err = s.driver.UpdateOperation(ctx, operation, prevState) if err != nil { xlog.Error(ctx, "error updating operation", zap.Error(err)) return nil, status.Error(codes.Internal, "error updating operation")