From 1901239f949b4a3247ddafcc6f7c73f6624ebcda Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 4 Sep 2023 16:15:43 +0800 Subject: [PATCH] disktask: replace failure nodes with alive ones (#45935) ref pingcap/tidb#46258 --- Makefile | 23 +- disttask/framework/BUILD.bazel | 3 +- disttask/framework/dispatcher/BUILD.bazel | 1 + disttask/framework/dispatcher/dispatcher.go | 106 +++++++++- .../framework/dispatcher/dispatcher_test.go | 28 +-- disttask/framework/framework_ha_test.go | 199 ++++++++++++++++++ disttask/framework/framework_test.go | 38 +++- disttask/framework/mock/scheduler_mock.go | 39 ++-- disttask/framework/scheduler/BUILD.bazel | 1 + disttask/framework/scheduler/interface.go | 3 +- disttask/framework/scheduler/manager.go | 24 ++- disttask/framework/scheduler/manager_test.go | 6 - disttask/framework/scheduler/scheduler.go | 104 +++++++-- .../framework/scheduler/scheduler_test.go | 5 +- disttask/framework/storage/table_test.go | 27 +++ disttask/framework/storage/task_table.go | 48 +++++ disttask/importinto/dispatcher_test.go | 5 +- domain/infosync/mock_info.go | 20 ++ 18 files changed, 579 insertions(+), 101 deletions(-) create mode 100644 disttask/framework/framework_ha_test.go diff --git a/Makefile b/Makefile index 638a13aa52ac2..6ec54318f3a33 100644 --- a/Makefile +++ b/Makefile @@ -266,6 +266,9 @@ tools/bin/vfsgendev: tools/bin/gotestsum: GOBIN=$(shell pwd)/tools/bin $(GO) install gotest.tools/gotestsum@v1.8.1 +tools/bin/mockgen: + GOBIN=$(shell pwd)/tools/bin $(GO) install go.uber.org/mock/mockgen@v0.2.0 + # Usage: # # $ make vectorized-bench VB_FILE=Time VB_FUNC=builtinCurrentDateSig @@ -370,18 +373,18 @@ br_compatibility_test_prepare: br_compatibility_test: @cd br && tests/run_compatible.sh run -mock_s3iface: - @mockgen -package mock github.com/aws/aws-sdk-go/service/s3/s3iface S3API > br/pkg/mock/s3iface.go +mock_s3iface: tools/bin/mockgen + tools/bin/mockgen -package mock github.com/aws/aws-sdk-go/service/s3/s3iface S3API > br/pkg/mock/s3iface.go # mock interface for lightning and IMPORT INTO -mock_lightning: - @mockgen -package mock github.com/pingcap/tidb/br/pkg/lightning/backend Backend,EngineWriter,TargetInfoGetter,ChunkFlushStatus > br/pkg/mock/backend.go - @mockgen -package mock github.com/pingcap/tidb/br/pkg/lightning/backend/encode Encoder,EncodingBuilder,Rows,Row > br/pkg/mock/encode.go - @mockgen -package mocklocal github.com/pingcap/tidb/br/pkg/lightning/backend/local DiskUsage,TiKVModeSwitcher > br/pkg/mock/mocklocal/local.go - @mockgen -package mock github.com/pingcap/tidb/br/pkg/utils TaskRegister > br/pkg/mock/task_register.go - -gen_mock: - @mockgen -package mock github.com/pingcap/tidb/disttask/framework/scheduler TaskTable,SubtaskExecutor,Pool,Scheduler,InternalScheduler > disttask/framework/mock/scheduler_mock.go +mock_lightning: tools/bin/mockgen + tools/bin/mockgen -package mock github.com/pingcap/tidb/br/pkg/lightning/backend Backend,EngineWriter,TargetInfoGetter,ChunkFlushStatus > br/pkg/mock/backend.go + tools/bin/mockgen -package mock github.com/pingcap/tidb/br/pkg/lightning/backend/encode Encoder,EncodingBuilder,Rows,Row > br/pkg/mock/encode.go + tools/bin/mockgen -package mocklocal github.com/pingcap/tidb/br/pkg/lightning/backend/local DiskUsage,TiKVModeSwitcher > br/pkg/mock/mocklocal/local.go + tools/bin/mockgen -package mock github.com/pingcap/tidb/br/pkg/utils TaskRegister > br/pkg/mock/task_register.go + +gen_mock: tools/bin/mockgen + tools/bin/mockgen -package mock github.com/pingcap/tidb/disttask/framework/scheduler 
TaskTable,SubtaskExecutor,Pool,Scheduler,InternalScheduler > disttask/framework/mock/scheduler_mock.go # There is no FreeBSD environment for GitHub actions. So cross-compile on Linux # but that doesn't work with CGO_ENABLED=1, so disable cgo. The reason to have diff --git a/disttask/framework/BUILD.bazel b/disttask/framework/BUILD.bazel index b5a4ef73513df..a49591496b169 100644 --- a/disttask/framework/BUILD.bazel +++ b/disttask/framework/BUILD.bazel @@ -5,12 +5,13 @@ go_test( timeout = "short", srcs = [ "framework_err_handling_test.go", + "framework_ha_test.go", "framework_rollback_test.go", "framework_test.go", ], flaky = True, race = "on", - shard_count = 14, + shard_count = 22, deps = [ "//disttask/framework/dispatcher", "//disttask/framework/proto", diff --git a/disttask/framework/dispatcher/BUILD.bazel b/disttask/framework/dispatcher/BUILD.bazel index d64b387e65ad7..4ea32cf32dea4 100644 --- a/disttask/framework/dispatcher/BUILD.bazel +++ b/disttask/framework/dispatcher/BUILD.bazel @@ -19,6 +19,7 @@ go_library( "//sessionctx/variable", "//util", "//util/disttask", + "//util/intest", "//util/logutil", "//util/syncutil", "@com_github_pingcap_errors//:errors", diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 14ae3911b3281..94844985a8e2b 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -17,6 +17,7 @@ package dispatcher import ( "context" "fmt" + "math/rand" "time" "github.com/pingcap/errors" @@ -27,6 +28,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" disttaskutil "github.com/pingcap/tidb/util/disttask" + "github.com/pingcap/tidb/util/intest" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -36,6 +38,8 @@ const ( DefaultSubtaskConcurrency = 16 // MaxSubtaskConcurrency is the maximum concurrency for handling subtask. MaxSubtaskConcurrency = 256 + // DefaultLiveNodesCheckInterval is the tick interval of fetching all server infos from etcd. + DefaultLiveNodesCheckInterval = 2 ) var ( @@ -65,6 +69,17 @@ type dispatcher struct { logCtx context.Context serverID string impl Dispatcher + + // for HA + // liveNodes will fetch and store all live nodes every liveNodeInterval ticks. + liveNodes []*infosync.ServerInfo + liveNodeFetchInterval int + // liveNodeFetchTick is the tick variable. + liveNodeFetchTick int + // taskNodes stores the id of current scheduler nodes. + taskNodes []string + // rand is for generating random selection of nodes. + rand *rand.Rand } // MockOwnerChange mock owner change in tests. 
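
Note on the new HA fields above (not part of the patch itself): they drive a tick-gated refresh — the dispatcher only re-reads the server list every liveNodeFetchInterval onRunning ticks, and picks a random live node when a task node has vanished. A minimal standalone sketch of that pattern, assuming nothing beyond the standard library; nodeTracker, onTick, fetchServers and pickReplacement are illustrative names, not identifiers from this patch:

    package main

    import (
    	"fmt"
    	"math/rand"
    	"time"
    )

    // nodeTracker mirrors the tick-gated refresh enabled by the dispatcher's
    // new HA fields: liveNodes is only re-fetched every fetchInterval ticks.
    type nodeTracker struct {
    	liveNodes     []string
    	fetchInterval int
    	fetchTick     int
    	rnd           *rand.Rand
    }

    // onTick is called once per dispatch-loop iteration; fetchServers stands in
    // for GenerateSchedulerNodes filtered by GetEligibleInstances.
    func (t *nodeTracker) onTick(fetchServers func() []string) {
    	t.fetchTick++
    	if t.fetchTick == t.fetchInterval {
    		t.fetchTick = 0
    		t.liveNodes = fetchServers()
    	}
    }

    // pickReplacement chooses a random live node for a dead one, the same way
    // the dispatcher fills its replaceNodes map.
    func (t *nodeTracker) pickReplacement() (string, bool) {
    	if len(t.liveNodes) == 0 {
    		return "", false
    	}
    	return t.liveNodes[t.rnd.Intn(len(t.liveNodes))], true
    }

    func main() {
    	tr := &nodeTracker{fetchInterval: 2, rnd: rand.New(rand.NewSource(time.Now().UnixNano()))}
    	servers := func() []string { return []string{":4001", ":4002"} }
    	for i := 0; i < 4; i++ {
    		tr.onTick(servers) // the list is refreshed on every 2nd tick
    	}
    	if n, ok := tr.pickReplacement(); ok {
    		fmt.Println("replacement node:", n)
    	}
    }
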
@@ -74,12 +89,17 @@ func newDispatcher(ctx context.Context, taskMgr *storage.TaskManager, serverID s logPrefix := fmt.Sprintf("task_id: %d, task_type: %s, server_id: %s", task.ID, task.Type, serverID) impl := GetTaskDispatcher(task.Type) dsp := &dispatcher{ - ctx: ctx, - taskMgr: taskMgr, - task: task, - logCtx: logutil.WithKeyValue(context.Background(), "dispatcher", logPrefix), - serverID: serverID, - impl: impl, + ctx: ctx, + taskMgr: taskMgr, + task: task, + logCtx: logutil.WithKeyValue(context.Background(), "dispatcher", logPrefix), + serverID: serverID, + impl: impl, + liveNodes: nil, + liveNodeFetchInterval: DefaultLiveNodesCheckInterval, + liveNodeFetchTick: 0, + taskNodes: nil, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } if dsp.impl == nil { logutil.BgLogger().Warn("gen dispatcher impl failed, this type impl doesn't register") @@ -215,12 +235,70 @@ func (d *dispatcher) onRunning() error { logutil.Logger(d.logCtx).Info("previous stage finished, generate dist plan", zap.Int64("stage", d.task.Step)) return d.onNextStage() } + // Check if any node are down. + if err := d.replaceDeadNodesIfAny(); err != nil { + return err + } // Wait all subtasks in this stage finished. d.impl.OnTick(d.ctx, d.task) logutil.Logger(d.logCtx).Debug("on running state, this task keeps current state", zap.String("state", d.task.State)) return nil } +func (d *dispatcher) replaceDeadNodesIfAny() error { + if len(d.taskNodes) == 0 { + return errors.Errorf("len(d.taskNodes) == 0, onNextStage is not invoked before onRunning") + } + d.liveNodeFetchTick++ + if d.liveNodeFetchTick == d.liveNodeFetchInterval { + d.liveNodeFetchTick = 0 + serverInfos, err := GenerateSchedulerNodes(d.ctx) + if err != nil { + return err + } + eligibleServerInfos, err := d.impl.GetEligibleInstances(d.ctx, d.task) + if err != nil { + return err + } + newInfos := serverInfos[:0] + for _, m := range serverInfos { + found := false + for _, n := range eligibleServerInfos { + if m.ID == n.ID { + found = true + break + } + } + if found { + newInfos = append(newInfos, m) + } + } + d.liveNodes = newInfos + } + if len(d.liveNodes) > 0 { + replaceNodes := make(map[string]string) + for _, nodeID := range d.taskNodes { + if ok := disttaskutil.MatchServerInfo(d.liveNodes, nodeID); !ok { + n := d.liveNodes[d.rand.Int()%len(d.liveNodes)] //nolint:gosec + replaceNodes[nodeID] = disttaskutil.GenerateExecID(n.IP, n.Port) + } + } + if err := d.taskMgr.UpdateFailedSchedulerIDs(d.task.ID, replaceNodes); err != nil { + return err + } + // replace local cache. + for k, v := range replaceNodes { + for m, n := range d.taskNodes { + if n == k { + d.taskNodes[m] = v + break + } + } + } + } + return nil +} + func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) { prevState := d.task.State d.task.State = taskState @@ -331,6 +409,10 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error { if len(serverNodes) == 0 { return errors.New("no available TiDB node to dispatch subtasks") } + d.taskNodes = make([]string, len(serverNodes)) + for i := range serverNodes { + d.taskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port) + } subTasks := make([]*proto.Subtask, 0, len(metas)) for i, meta := range metas { // we assign the subtask to the instance in a round-robin way. @@ -353,8 +435,14 @@ func (d *dispatcher) handlePlanErr(err error) error { } // GenerateSchedulerNodes generate a eligible TiDB nodes. 
-func GenerateSchedulerNodes(ctx context.Context) ([]*infosync.ServerInfo, error) { - serverInfos, err := infosync.GetAllServerInfo(ctx) +func GenerateSchedulerNodes(ctx context.Context) (serverNodes []*infosync.ServerInfo, err error) { + var serverInfos map[string]*infosync.ServerInfo + _, etcd := ctx.Value("etcd").(bool) + if intest.InTest && !etcd { + serverInfos = infosync.MockGlobalServerInfoManagerEntry.GetAllServerInfo() + } else { + serverInfos, err = infosync.GetAllServerInfo(ctx) + } if err != nil { return nil, err } @@ -362,7 +450,7 @@ func GenerateSchedulerNodes(ctx context.Context) ([]*infosync.ServerInfo, error) return nil, errors.New("not found instance") } - serverNodes := make([]*infosync.ServerInfo, 0, len(serverInfos)) + serverNodes = make([]*infosync.ServerInfo, 0, len(serverInfos)) for _, serverInfo := range serverInfos { serverNodes = append(serverNodes, serverInfo) } diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index 819228e3b6397..0e8fff03d53db 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -106,7 +106,7 @@ func (*numberExampleDispatcher) IsRetryableErr(error) bool { } func MockDispatcherManager(t *testing.T, pool *pools.ResourcePool) (*dispatcher.Manager, *storage.TaskManager) { - ctx := context.Background() + ctx := context.WithValue(context.Background(), "etcd", true) mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool) storage.SetTaskManager(mgr) dsp, err := dispatcher.NewManager(util.WithInternalSourceType(ctx, "dispatcher"), mgr, "host:port") @@ -220,31 +220,19 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { // 3s cnt := 60 checkGetRunningTaskCnt := func(expected int) { - var retCnt int - for i := 0; i < cnt; i++ { - retCnt = dsp.GetRunningTaskCnt() - if retCnt == expected { - break - } - time.Sleep(time.Millisecond * 50) - } - require.Equal(t, retCnt, expected) + require.Eventually(t, func() bool { + return dsp.GetRunningTaskCnt() == expected + }, time.Second, 50*time.Millisecond) } checkTaskRunningCnt := func() []*proto.Task { - var retCnt int var tasks []*proto.Task - var err error - for i := 0; i < cnt; i++ { + require.Eventually(t, func() bool { + var err error tasks, err = mgr.GetGlobalTasksInStates(proto.TaskStateRunning) require.NoError(t, err) - retCnt = len(tasks) - if retCnt == taskCnt { - break - } - time.Sleep(time.Millisecond * 50) - } - require.Equal(t, retCnt, taskCnt) + return len(tasks) == taskCnt + }, time.Second, 50*time.Millisecond) return tasks } diff --git a/disttask/framework/framework_ha_test.go b/disttask/framework/framework_ha_test.go new file mode 100644 index 0000000000000..fb655613203f8 --- /dev/null +++ b/disttask/framework/framework_ha_test.go @@ -0,0 +1,199 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package framework_test + +import ( + "context" + "sync" + "testing" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/disttask/framework/dispatcher" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +type haTestFlowHandle struct{} + +var _ dispatcher.Dispatcher = (*haTestFlowHandle)(nil) + +func (*haTestFlowHandle) OnTick(_ context.Context, _ *proto.Task) { +} + +func (*haTestFlowHandle) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { + if gTask.State == proto.TaskStatePending { + gTask.Step = proto.StepOne + return [][]byte{ + []byte("task1"), + []byte("task2"), + []byte("task3"), + []byte("task4"), + []byte("task5"), + []byte("task6"), + []byte("task7"), + []byte("task8"), + []byte("task9"), + []byte("task10"), + }, nil + } + if gTask.Step == proto.StepOne { + gTask.Step = proto.StepTwo + return [][]byte{ + []byte("task11"), + []byte("task12"), + []byte("task13"), + []byte("task14"), + []byte("task15"), + }, nil + } + return nil, nil +} + +func (*haTestFlowHandle) OnErrStage(ctx context.Context, h dispatcher.TaskHandle, gTask *proto.Task, receiveErr []error) (subtaskMeta []byte, err error) { + return nil, nil +} + +func (*haTestFlowHandle) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { + return generateSchedulerNodes4Test() +} + +func (*haTestFlowHandle) IsRetryableErr(error) bool { + return true +} + +func RegisterHATaskMeta(m *sync.Map) { + dispatcher.ClearTaskDispatcher() + dispatcher.RegisterTaskDispatcher(proto.TaskTypeExample, &haTestFlowHandle{}) + scheduler.ClearSchedulers() + scheduler.RegisterTaskType(proto.TaskTypeExample) + scheduler.RegisterSchedulerConstructor(proto.TaskTypeExample, proto.StepOne, func(_ context.Context, _ int64, _ []byte, _ int64) (scheduler.Scheduler, error) { + return &testScheduler{}, nil + }) + scheduler.RegisterSchedulerConstructor(proto.TaskTypeExample, proto.StepTwo, func(_ context.Context, _ int64, _ []byte, _ int64) (scheduler.Scheduler, error) { + return &testScheduler{}, nil + }) + scheduler.RegisterSubtaskExectorConstructor(proto.TaskTypeExample, proto.StepOne, func(_ proto.MinimalTask, _ int64) (scheduler.SubtaskExecutor, error) { + return &testSubtaskExecutor{m: m}, nil + }) + scheduler.RegisterSubtaskExectorConstructor(proto.TaskTypeExample, proto.StepTwo, func(_ proto.MinimalTask, _ int64) (scheduler.SubtaskExecutor, error) { + return &testSubtaskExecutor1{m: m}, nil + }) +} + +func TestHABasic(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + RegisterHATaskMeta(&m) + distContext := testkit.NewDistExecutionContext(t, 4) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "4*return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown", "return(\":4000\")")) + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) + require.NoError(t, 
failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) + distContext.Close() +} + +func TestHAManyNodes(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + + RegisterHATaskMeta(&m) + distContext := testkit.NewDistExecutionContext(t, 30) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "30*return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown", "return(\":4000\")")) + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) + distContext.Close() +} + +func TestHAFailInDifferentStage(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + + RegisterHATaskMeta(&m) + distContext := testkit.NewDistExecutionContext(t, 6) + // stage1 : server num from 6 to 3. + // stage2 : server num from 3 to 2. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "6*return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown", "return(\":4000\")")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown2", "return()")) + + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown2")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) + distContext.Close() +} + +func TestHAFailInDifferentStageManyNodes(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + + RegisterHATaskMeta(&m) + distContext := testkit.NewDistExecutionContext(t, 30) + // stage1 : server num from 30 to 27. + // stage2 : server num from 27 to 26. 
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "30*return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown", "return(\":4000\")")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown2", "return()")) + + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown2")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) + distContext.Close() +} + +func TestHAReplacedButRunning(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + + RegisterHATaskMeta(&m) + distContext := testkit.NewDistExecutionContext(t, 4) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume", "10*return(true)")) + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume")) + distContext.Close() +} + +func TestHAReplacedButRunningManyNodes(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + + RegisterHATaskMeta(&m) + distContext := testkit.NewDistExecutionContext(t, 30) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume", "30*return(true)")) + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume")) + distContext.Close() +} diff --git a/disttask/framework/framework_test.go b/disttask/framework/framework_test.go index 0b977f9d185e8..ab53fd5beb4a4 100644 --- a/disttask/framework/framework_test.go +++ b/disttask/framework/framework_test.go @@ -156,7 +156,7 @@ func DispatchTask(taskKey string, t *testing.T) *proto.Task { var task *proto.Task for { - if time.Since(start) > 2*time.Minute { + if time.Since(start) > 10*time.Minute { require.FailNow(t, "timeout") } @@ -358,3 +358,39 @@ func TestFrameworkCancelThenSubmitSubTask(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelBeforeUpdate")) distContext.Close() } + +func TestSchedulerDownBasic(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + RegisterTaskMeta(&m, &testDispatcher{}) + + distContext := testkit.NewDistExecutionContext(t, 4) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "4*return(\":4000\")")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown", "return(\":4000\")")) + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, 
failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) + + distContext.Close() +} + +func TestSchedulerDownManyNodes(t *testing.T) { + defer dispatcher.ClearTaskDispatcher() + defer scheduler.ClearSchedulers() + var m sync.Map + RegisterTaskMeta(&m, &testDispatcher{}) + + distContext := testkit.NewDistExecutionContext(t, 30) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "30*return(\":4000\")")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown", "return(\":4000\")")) + DispatchTaskAndCheckSuccess("😊", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) + + distContext.Close() +} diff --git a/disttask/framework/mock/scheduler_mock.go b/disttask/framework/mock/scheduler_mock.go index 3510f7ffd785b..43f2212ffbed0 100644 --- a/disttask/framework/mock/scheduler_mock.go +++ b/disttask/framework/mock/scheduler_mock.go @@ -123,6 +123,21 @@ func (mr *MockTaskTableMockRecorder) HasSubtasksInStates(arg0, arg1, arg2 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSubtasksInStates", reflect.TypeOf((*MockTaskTable)(nil).HasSubtasksInStates), varargs...) } +// IsSchedulerCanceled mocks base method. +func (m *MockTaskTable) IsSchedulerCanceled(arg0 int64, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsSchedulerCanceled", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsSchedulerCanceled indicates an expected call of IsSchedulerCanceled. +func (mr *MockTaskTableMockRecorder) IsSchedulerCanceled(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsSchedulerCanceled", reflect.TypeOf((*MockTaskTable)(nil).IsSchedulerCanceled), arg0, arg1) +} + // StartSubtask mocks base method. func (m *MockTaskTable) StartSubtask(arg0 int64) error { m.ctrl.T.Helper() @@ -410,27 +425,3 @@ func (mr *MockInternalSchedulerMockRecorder) Run(arg0, arg1 interface{}) *gomock mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockInternalScheduler)(nil).Run), arg0, arg1) } - -// Start mocks base method. -func (m *MockInternalScheduler) Start() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Start") -} - -// Start indicates an expected call of Start. -func (mr *MockInternalSchedulerMockRecorder) Start() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockInternalScheduler)(nil).Start)) -} - -// Stop mocks base method. -func (m *MockInternalScheduler) Stop() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Stop") -} - -// Stop indicates an expected call of Stop. 
-func (mr *MockInternalSchedulerMockRecorder) Stop() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockInternalScheduler)(nil).Stop)) -} diff --git a/disttask/framework/scheduler/BUILD.bazel b/disttask/framework/scheduler/BUILD.bazel index d807c96616c58..273ee4e320f9f 100644 --- a/disttask/framework/scheduler/BUILD.bazel +++ b/disttask/framework/scheduler/BUILD.bazel @@ -13,6 +13,7 @@ go_library( deps = [ "//disttask/framework/proto", "//disttask/framework/storage", + "//domain/infosync", "//resourcemanager/pool/spool", "//resourcemanager/util", "//util/logutil", diff --git a/disttask/framework/scheduler/interface.go b/disttask/framework/scheduler/interface.go index 1b836b1fd84c1..d1063a4b88a05 100644 --- a/disttask/framework/scheduler/interface.go +++ b/disttask/framework/scheduler/interface.go @@ -31,6 +31,7 @@ type TaskTable interface { FinishSubtask(id int64, meta []byte) error HasSubtasksInStates(instanceID string, taskID int64, step int64, states ...interface{}) (bool, error) UpdateErrorToSubtask(tidbID string, err error) error + IsSchedulerCanceled(taskID int64, execID string) (bool, error) } // Pool defines the interface of a pool. @@ -42,8 +43,6 @@ type Pool interface { // InternalScheduler defines the interface of an internal scheduler. type InternalScheduler interface { - Start() - Stop() Run(context.Context, *proto.Task) error Rollback(context.Context, *proto.Task) error } diff --git a/disttask/framework/scheduler/manager.go b/disttask/framework/scheduler/manager.go index df07a317ffd00..34d30afff7b52 100644 --- a/disttask/framework/scheduler/manager.go +++ b/disttask/framework/scheduler/manager.go @@ -17,10 +17,13 @@ package scheduler import ( "context" "sync" + "sync/atomic" "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/resourcemanager/pool/spool" "github.com/pingcap/tidb/resourcemanager/util" "github.com/pingcap/tidb/util/logutil" @@ -253,6 +256,14 @@ func (m *Manager) filterAlreadyHandlingTasks(tasks []*proto.Task) []*proto.Task return tasks[:i] } +// TestContext only used in tests. +type TestContext struct { + TestSyncSubtaskRun chan struct{} + mockDown atomic.Bool +} + +var testContexts sync.Map + // onRunnableTask handles a runnable task. func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType string) { logutil.Logger(m.logCtx).Info("onRunnableTask", zap.Any("task_id", taskID), zap.Any("type", taskType)) @@ -262,14 +273,23 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str } // runCtx only used in scheduler.Run, cancel in m.fetchAndFastCancelTasks. 
scheduler := m.newScheduler(ctx, m.id, taskID, m.taskTable, m.subtaskExecutorPools[taskType]) - scheduler.Start() - defer scheduler.Stop() for { select { case <-ctx.Done(): return case <-time.After(checkTime): } + failpoint.Inject("mockStopManager", func() { + testContexts.Store(m.id, &TestContext{make(chan struct{}), atomic.Bool{}}) + go func() { + v, ok := testContexts.Load(m.id) + if ok { + <-v.(*TestContext).TestSyncSubtaskRun + m.Stop() + _ = infosync.MockGlobalServerInfoManagerEntry.DeleteByID(m.id) + } + }() + }) task, err := m.taskTable.GetGlobalTaskByID(taskID) if err != nil { m.onError(err) diff --git a/disttask/framework/scheduler/manager_test.go b/disttask/framework/scheduler/manager_test.go index a9a58dbebd9af..51c4b955c45e7 100644 --- a/disttask/framework/scheduler/manager_test.go +++ b/disttask/framework/scheduler/manager_test.go @@ -130,7 +130,6 @@ func TestOnRunnableTasks(t *testing.T) { mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil) mockPool.EXPECT().Run(gomock.Any()).DoAndReturn(runFn) - mockInternalScheduler.EXPECT().Start() mockTaskTable.EXPECT().GetGlobalTaskByID(taskID).Return(task, nil) mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil) @@ -151,7 +150,6 @@ func TestOnRunnableTasks(t *testing.T) { task3 := &proto.Task{ID: taskID, State: proto.TaskStateReverted, Step: proto.StepTwo} mockTaskTable.EXPECT().GetGlobalTaskByID(taskID).Return(task3, nil) - mockInternalScheduler.EXPECT().Stop() m.onRunnableTasks(context.Background(), []*proto.Task{task}) @@ -191,7 +189,6 @@ func TestManager(t *testing.T) { Return(true, nil) wg, runFn := getPoolRunFn() mockPool.EXPECT().Run(gomock.Any()).DoAndReturn(runFn) - mockInternalScheduler.EXPECT().Start() mockTaskTable.EXPECT().GetGlobalTaskByID(taskID1).Return(task1, nil).AnyTimes() mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID1, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}). @@ -200,13 +197,11 @@ func TestManager(t *testing.T) { mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID1, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}). Return(false, nil).AnyTimes() - mockInternalScheduler.EXPECT().Stop() // task2 mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID2, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}). Return(true, nil) mockPool.EXPECT().Run(gomock.Any()).DoAndReturn(runFn) - mockInternalScheduler.EXPECT().Start() mockTaskTable.EXPECT().GetGlobalTaskByID(taskID2).Return(task2, nil).AnyTimes() mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID2, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}). @@ -215,7 +210,6 @@ func TestManager(t *testing.T) { mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID2, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}). 
Return(false, nil).AnyTimes() - mockInternalScheduler.EXPECT().Stop() // once for scheduler pool, once for subtask pool mockPool.EXPECT().ReleaseAndWait().Do(func() { wg.Wait() diff --git a/disttask/framework/scheduler/scheduler.go b/disttask/framework/scheduler/scheduler.go index 88bccc110ac50..8a50ad1aab23d 100644 --- a/disttask/framework/scheduler/scheduler.go +++ b/disttask/framework/scheduler/scheduler.go @@ -24,23 +24,26 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/disttask/framework/proto" "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) +const ( + // DefaultCheckSubtaskCanceledInterval is the default check interval for cancel cancelled subtasks. + DefaultCheckSubtaskCanceledInterval = 2 * time.Second +) + // TestSyncChan is used to sync the test. var TestSyncChan = make(chan struct{}) // InternalSchedulerImpl is the implementation of InternalScheduler. type InternalSchedulerImpl struct { - ctx context.Context - cancel context.CancelFunc // id, it's the same as server id now, i.e. host:port. id string taskID int64 taskTable TaskTable pool Pool - wg sync.WaitGroup logCtx context.Context mu struct { @@ -54,7 +57,7 @@ type InternalSchedulerImpl struct { } // NewInternalScheduler creates a new InternalScheduler. -func NewInternalScheduler(ctx context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler { +func NewInternalScheduler(_ context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler { logPrefix := fmt.Sprintf("id: %s, task_id: %d", id, taskID) schedulerImpl := &InternalSchedulerImpl{ id: id, @@ -63,24 +66,35 @@ func NewInternalScheduler(ctx context.Context, id string, taskID int64, taskTabl pool: pool, logCtx: logutil.WithKeyValue(context.Background(), "scheduler", logPrefix), } - schedulerImpl.ctx, schedulerImpl.cancel = context.WithCancel(ctx) - return schedulerImpl } -// Start starts the scheduler. -func (*InternalSchedulerImpl) Start() { - // s.wg.Add(1) - // go func() { - // defer s.wg.Done() - // s.heartbeat() - // }() -} - -// Stop stops the scheduler. -func (s *InternalSchedulerImpl) Stop() { - s.cancel() - s.wg.Wait() +func (s *InternalSchedulerImpl) startCancelCheck(ctx context.Context, wg *sync.WaitGroup, cancelFn context.CancelFunc) { + wg.Add(1) + go func() { + defer wg.Done() + ticker := time.NewTicker(DefaultCheckSubtaskCanceledInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + logutil.Logger(s.logCtx).Info("scheduler exits", zap.Error(ctx.Err())) + return + case <-ticker.C: + canceled, err := s.taskTable.IsSchedulerCanceled(s.taskID, s.id) + logutil.Logger(s.logCtx).Info("scheduler before canceled") + if err != nil { + continue + } + if canceled { + logutil.Logger(s.logCtx).Info("scheduler canceled") + if cancelFn != nil { + cancelFn() + } + } + } + } + }() } // Run runs the scheduler task. 
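
The cancel check introduced above polls TaskTable.IsSchedulerCanceled on a fixed ticker and fires the run-step cancel function once the task's subtasks no longer carry this scheduler's exec_id. A self-contained, simplified sketch of that polling shape (pollCancel and isCanceled are illustrative stand-ins for the real scheduler and task table; unlike startCancelCheck, this version exits after the first cancellation):

    package main

    import (
    	"context"
    	"fmt"
    	"sync"
    	"time"
    )

    // pollCancel polls an "am I still the owner of these subtasks?" predicate on
    // a ticker and invokes cancelFn once it reports cancellation. Transient
    // errors from the predicate are simply retried on the next tick.
    func pollCancel(ctx context.Context, wg *sync.WaitGroup, interval time.Duration,
    	isCanceled func() (bool, error), cancelFn context.CancelFunc) {
    	wg.Add(1)
    	go func() {
    		defer wg.Done()
    		ticker := time.NewTicker(interval)
    		defer ticker.Stop()
    		for {
    			select {
    			case <-ctx.Done():
    				return
    			case <-ticker.C:
    				canceled, err := isCanceled()
    				if err != nil {
    					continue
    				}
    				if canceled {
    					cancelFn()
    					return
    				}
    			}
    		}
    	}()
    }

    func main() {
    	runCtx, runCancel := context.WithCancel(context.Background())
    	checkCtx, stopCheck := context.WithCancel(context.Background())
    	var wg sync.WaitGroup

    	calls := 0
    	pollCancel(checkCtx, &wg, 10*time.Millisecond, func() (bool, error) {
    		calls++
    		return calls >= 3, nil // report "canceled" on the third poll
    	}, runCancel)

    	<-runCtx.Done() // the subtask loop would observe this and stop
    	fmt.Println("run context canceled after", calls, "polls")
    	stopCheck()
    	wg.Wait()
    }
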
@@ -100,7 +114,6 @@ func (s *InternalSchedulerImpl) run(ctx context.Context, task *proto.Task) error runCtx, runCancel := context.WithCancel(ctx) defer runCancel() s.registerCancelFunc(runCancel) - s.resetError() logutil.Logger(s.logCtx).Info("scheduler run a step", zap.Any("step", task.Step), zap.Any("concurrency", task.Concurrency)) scheduler, err := createScheduler(ctx, task) @@ -116,11 +129,18 @@ func (s *InternalSchedulerImpl) run(ctx context.Context, task *proto.Task) error s.onError(err) return s.getError() } + + var wg sync.WaitGroup + cancelCtx, checkCancel := context.WithCancel(ctx) + s.startCancelCheck(cancelCtx, &wg, runCancel) + defer func() { err := scheduler.CleanupSubtaskExecEnv(runCtx) if err != nil { logutil.Logger(s.logCtx).Error("cleanup subtask exec env failed", zap.Error(err)) } + checkCancel() + wg.Wait() }() minimalTaskCh := make(chan func(), task.Concurrency) @@ -150,6 +170,15 @@ func (s *InternalSchedulerImpl) run(ctx context.Context, task *proto.Task) error if err := s.getError(); err != nil { break } + failpoint.Inject("mockCleanScheduler", func() { + v, ok := testContexts.Load(s.id) + if ok { + if v.(*TestContext).mockDown.Load() { + failpoint.Break() + } + } + }) + s.runSubtask(runCtx, scheduler, subtask, minimalTaskCh) } return s.getError() @@ -176,6 +205,36 @@ func (s *InternalSchedulerImpl) runSubtask(ctx context.Context, scheduler Schedu zap.Int64("subtask_id", subtask.ID), zap.Int64("subtask_step", subtask.Step)) + failpoint.Inject("mockTiDBDown", func(val failpoint.Value) { + if s.id == val.(string) || s.id == ":4001" || s.id == ":4002" { + v, ok := testContexts.Load(s.id) + if ok { + v.(*TestContext).TestSyncSubtaskRun <- struct{}{} + v.(*TestContext).mockDown.Store(true) + time.Sleep(2 * time.Second) + failpoint.Return() + } + } + }) + failpoint.Inject("mockTiDBDown2", func() { + if s.id == ":4003" && subtask.Step == proto.StepTwo { + v, ok := testContexts.Load(s.id) + if ok { + v.(*TestContext).TestSyncSubtaskRun <- struct{}{} + v.(*TestContext).mockDown.Store(true) + time.Sleep(2 * time.Second) + return + } + } + }) + + failpoint.Inject("mockTiDBPartitionThenResume", func(val failpoint.Value) { + if val.(bool) && (s.id == ":4000" || s.id == ":4001" || s.id == ":4002") { + _ = infosync.MockGlobalServerInfoManagerEntry.DeleteByID(s.id) + time.Sleep(20 * time.Second) + } + }) + var minimalTaskWg sync.WaitGroup for _, minimalTask := range minimalTasks { minimalTaskWg.Add(1) @@ -223,7 +282,6 @@ func (s *InternalSchedulerImpl) onSubtaskFinished(ctx context.Context, scheduler } func (s *InternalSchedulerImpl) runMinimalTask(minimalTaskCtx context.Context, minimalTask proto.MinimalTask, tp string, step int64) { - logutil.Logger(s.logCtx).Info("scheduler run a minimalTask", zap.Any("step", step), zap.Stringer("minimal_task", minimalTask)) select { case <-minimalTaskCtx.Done(): s.onError(minimalTaskCtx.Err()) @@ -233,12 +291,13 @@ func (s *InternalSchedulerImpl) runMinimalTask(minimalTaskCtx context.Context, m if s.getError() != nil { return } - + logutil.Logger(s.logCtx).Info("scheduler run a minimalTask", zap.Any("step", step), zap.Stringer("minimal_task", minimalTask)) executor, err := createSubtaskExecutor(minimalTask, tp, step) if err != nil { s.onError(err) return } + failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) { if val.(bool) { s.onError(errors.New("MockExecutorRunErr")) @@ -260,6 +319,7 @@ func (s *InternalSchedulerImpl) runMinimalTask(minimalTaskCtx context.Context, m if err = executor.Run(minimalTaskCtx); err != nil { 
s.onError(err) } + logutil.Logger(s.logCtx).Info("minimal task done", zap.Stringer("minimal_task", minimalTask)) } // Rollback rollbacks the scheduler task. diff --git a/disttask/framework/scheduler/scheduler_test.go b/disttask/framework/scheduler/scheduler_test.go index 3326b4c1541aa..587946251b7d3 100644 --- a/disttask/framework/scheduler/scheduler_test.go +++ b/disttask/framework/scheduler/scheduler_test.go @@ -70,6 +70,9 @@ func TestSchedulerRun(t *testing.T) { mockScheduler := mock.NewMockScheduler(ctrl) mockSubtaskExecutor := mock.NewMockSubtaskExecutor(ctrl) + // check cancel loop will call it once. + mockSubtaskTable.EXPECT().IsSchedulerCanceled(gomock.Any(), gomock.Any()).Return(false, nil).Times(1) + // 1. no scheduler constructor schedulerRegisterErr := errors.Errorf("constructor of scheduler for key %s not found", getKey(tp, proto.StepOne)) scheduler := NewInternalScheduler(ctx, "id", 1, mockSubtaskTable, mockPool) @@ -365,8 +368,6 @@ func TestScheduler(t *testing.T) { }) scheduler := NewInternalScheduler(ctx, "id", 1, mockSubtaskTable, mockPool) - scheduler.Start() - defer scheduler.Stop() poolWg, runWithConcurrencyFn := getRunWithConcurrencyFn() diff --git a/disttask/framework/storage/table_test.go b/disttask/framework/storage/table_test.go index de5c4dba0e0ba..a72ffb0a8235f 100644 --- a/disttask/framework/storage/table_test.go +++ b/disttask/framework/storage/table_test.go @@ -256,6 +256,33 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, subtask2.StartTime, subtask.StartTime) require.Greater(t, subtask2.UpdateTime, subtask.UpdateTime) + + // test UpdateFailedSchedulerIDs and IsSchedulerCanceled + canceled, err := sm.IsSchedulerCanceled(4, "for_test999") + require.NoError(t, err) + require.True(t, canceled) + canceled, err = sm.IsSchedulerCanceled(4, "for_test1") + require.NoError(t, err) + require.False(t, canceled) + canceled, err = sm.IsSchedulerCanceled(4, "for_test2") + require.NoError(t, err) + require.True(t, canceled) + + require.NoError(t, sm.UpdateSubtaskStateAndError(4, proto.TaskStateRunning, nil)) + require.NoError(t, sm.UpdateFailedSchedulerIDs(4, map[string]string{ + "for_test1": "for_test999", + "for_test2": "for_test999", + })) + + canceled, err = sm.IsSchedulerCanceled(4, "for_test1") + require.NoError(t, err) + require.True(t, canceled) + canceled, err = sm.IsSchedulerCanceled(4, "for_test2") + require.NoError(t, err) + require.True(t, canceled) + canceled, err = sm.IsSchedulerCanceled(4, "for_test999") + require.NoError(t, err) + require.False(t, canceled) } func TestBothGlobalAndSubTaskTable(t *testing.T) { diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go index 1171eaf5581dc..9238db6d972b4 100644 --- a/disttask/framework/storage/task_table.go +++ b/disttask/framework/storage/task_table.go @@ -511,6 +511,54 @@ func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) return instanceIDs, nil } +// IsSchedulerCanceled checks if subtask 'execID' of task 'taskID' has been canceled somehow. +func (stm *TaskManager) IsSchedulerCanceled(taskID int64, execID string) (bool, error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_background_subtask where task_key = %? and exec_id = %?", taskID, execID) + if err != nil { + return false, err + } + return len(rs) == 0, nil +} + +// UpdateFailedSchedulerIDs replace failed scheduler nodes with alive nodes. 
+func (stm *TaskManager) UpdateFailedSchedulerIDs(taskID int64, replaceNodes map[string]string) error { + // skip + if len(replaceNodes) == 0 { + return nil + } + + sql := new(strings.Builder) + if err := sqlexec.FormatSQL(sql, "update mysql.tidb_background_subtask set state = %? ,exec_id = (case ", proto.TaskStatePending); err != nil { + return err + } + for k, v := range replaceNodes { + if err := sqlexec.FormatSQL(sql, "when exec_id = %? then %? ", k, v); err != nil { + return err + } + } + if err := sqlexec.FormatSQL(sql, " end) where task_key = %? and state != \"succeed\" and exec_id in (", taskID); err != nil { + return err + } + i := 0 + for k := range replaceNodes { + if i != 0 { + if err := sqlexec.FormatSQL(sql, ","); err != nil { + return err + } + } + if err := sqlexec.FormatSQL(sql, "%?", k); err != nil { + return err + } + i++ + } + if err := sqlexec.FormatSQL(sql, ")"); err != nil { + return err + } + + _, err := stm.executeSQLWithNewSession(stm.ctx, sql.String()) + return err +} + // UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, prevState string) (bool, error) { retryable := true diff --git a/disttask/importinto/dispatcher_test.go b/disttask/importinto/dispatcher_test.go index 809348bf18652..721025926d3dc 100644 --- a/disttask/importinto/dispatcher_test.go +++ b/disttask/importinto/dispatcher_test.go @@ -62,8 +62,9 @@ func (s *importIntoSuite) TestDispatcherGetEligibleInstances() { dsp := importDispatcher{} gTask := &proto.Task{Meta: []byte("{}")} + ctx := context.WithValue(context.Background(), "etcd", true) s.enableFailPoint("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", mockedAllServerInfos) - eligibleInstances, err := dsp.GetEligibleInstances(context.Background(), gTask) + eligibleInstances, err := dsp.GetEligibleInstances(ctx, gTask) s.NoError(err) // order of slice is not stable, change to map resultMap := map[string]*infosync.ServerInfo{} @@ -73,7 +74,7 @@ func (s *importIntoSuite) TestDispatcherGetEligibleInstances() { s.Equal(serverInfoMap, resultMap) gTask.Meta = []byte(`{"EligibleInstances":[{"ip": "1.1.1.1", "listening_port": 4000}]}`) - eligibleInstances, err = dsp.GetEligibleInstances(context.Background(), gTask) + eligibleInstances, err = dsp.GetEligibleInstances(ctx, gTask) s.NoError(err) s.Equal([]*infosync.ServerInfo{{IP: "1.1.1.1", Port: 4000}}, eligibleInstances) } diff --git a/domain/infosync/mock_info.go b/domain/infosync/mock_info.go index 0532cc0420104..452267ab103cf 100644 --- a/domain/infosync/mock_info.go +++ b/domain/infosync/mock_info.go @@ -15,6 +15,7 @@ package infosync import ( + "fmt" "sync" "time" @@ -55,6 +56,25 @@ func (m *MockGlobalServerInfoManager) Delete(idx int) error { return nil } +// DeleteByID delete ServerInfo by host:port. +func (m *MockGlobalServerInfoManager) DeleteByID(id string) error { + m.mu.Lock() + defer m.mu.Unlock() + idx := -1 + for i := 0; i < len(m.infos); i++ { + name := fmt.Sprintf("%s:%d", m.infos[i].IP, m.infos[i].Port) + if name == id { + idx = i + break + } + } + if idx == -1 { + return nil + } + m.infos = append(m.infos[:idx], m.infos[idx+1:]...) + return nil +} + // GetAllServerInfo return all serverInfo in a map. func (m *MockGlobalServerInfoManager) GetAllServerInfo() map[string]*ServerInfo { m.mu.Lock()
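
For reference, the core of the storage change is the CASE-based UPDATE assembled by UpdateFailedSchedulerIDs in task_table.go above. The rough standalone reconstruction below only illustrates the shape of that statement; the quoting is purely for readability, since the real code binds every value through sqlexec.FormatSQL placeholders rather than string literals, and the state value comes from proto.TaskStatePending:

    package main

    import (
    	"fmt"
    	"sort"
    	"strings"
    )

    // buildReplaceSQL sketches the UPDATE that re-points every unfinished
    // subtask of a task from a dead exec_id to its chosen replacement and
    // flips it back to "pending" so the new node picks it up.
    func buildReplaceSQL(taskID int64, replaceNodes map[string]string) string {
    	var b strings.Builder
    	b.WriteString("update mysql.tidb_background_subtask set state = 'pending', exec_id = (case ")
    	dead := make([]string, 0, len(replaceNodes))
    	for k := range replaceNodes {
    		dead = append(dead, k)
    	}
    	sort.Strings(dead) // deterministic order, for the example only
    	for _, k := range dead {
    		fmt.Fprintf(&b, "when exec_id = '%s' then '%s' ", k, replaceNodes[k])
    	}
    	fmt.Fprintf(&b, "end) where task_key = %d and state != 'succeed' and exec_id in ('%s')",
    		taskID, strings.Join(dead, "','"))
    	return b.String()
    }

    func main() {
    	fmt.Println(buildReplaceSQL(4, map[string]string{
    		":4000": ":4003",
    		":4001": ":4003",
    	}))
    }
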