From fba2fb521fd3f06fd92bec04a51860335e4bfb28 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Mon, 9 Dec 2024 22:34:06 +0200 Subject: [PATCH 01/16] introduce nclprotocol transport layer --- pkg/compute/watchers/ncl_message_creator.go | 4 +- pkg/config/types/compute.go | 1 + pkg/models/messages/constants.go | 8 + pkg/models/node_info.go | 6 + pkg/node/compute.go | 97 ++-- pkg/node/constants.go | 8 - pkg/node/ncl.go | 58 -- pkg/node/requester.go | 97 ++-- pkg/orchestrator/watchers/event_logger.go | 4 +- .../watchers/ncl_message_creator.go | 39 +- pkg/transport/bprotocol/compute/transport.go | 6 + pkg/transport/bprotocol/errors.go | 8 + .../bprotocol/orchestrator/server.go | 10 + pkg/transport/forwarder/forwarder.go | 96 ---- pkg/transport/forwarder/forwarder_e2e_test.go | 289 ---------- pkg/transport/forwarder/forwarder_test.go | 155 ------ pkg/transport/mocks.go | 51 -- pkg/transport/nclprotocol/compute/config.go | 126 +++++ .../nclprotocol/compute/controlplane.go | 240 +++++++++ .../nclprotocol/compute/dataplane.go | 183 +++++++ pkg/transport/nclprotocol/compute/errors.go | 3 + .../nclprotocol/compute/health_tracker.go | 78 +++ pkg/transport/nclprotocol/compute/manager.go | 498 ++++++++++++++++++ .../{ => nclprotocol}/dispatcher/config.go | 4 +- .../dispatcher/config_test.go | 0 .../{ => nclprotocol}/dispatcher/constants.go | 0 .../dispatcher/dispatcher.go | 15 +- .../dispatcher/dispatcher_e2e_test.go | 2 +- .../dispatcher/dispatcher_test.go | 6 +- .../{ => nclprotocol}/dispatcher/errors.go | 0 .../{ => nclprotocol}/dispatcher/handler.go | 6 +- .../dispatcher/handler_test.go | 6 +- .../{ => nclprotocol}/dispatcher/recovery.go | 0 .../dispatcher/recovery_test.go | 0 .../{ => nclprotocol}/dispatcher/state.go | 2 +- .../dispatcher/state_test.go | 0 .../{ => nclprotocol}/dispatcher/utils.go | 0 pkg/transport/nclprotocol/mocks.go | 141 +++++ .../nclprotocol/orchestrator/config.go | 113 ++++ .../nclprotocol/orchestrator/dataplane.go | 263 +++++++++ 
.../nclprotocol/orchestrator/manager.go | 308 +++++++++++ pkg/transport/nclprotocol/registry.go | 42 ++ pkg/transport/nclprotocol/subjects.go | 29 + pkg/transport/nclprotocol/tracker.go | 52 ++ pkg/transport/nclprotocol/types.go | 76 +++ pkg/transport/types.go | 28 - 46 files changed, 2318 insertions(+), 840 deletions(-) delete mode 100644 pkg/node/ncl.go create mode 100644 pkg/transport/bprotocol/errors.go delete mode 100644 pkg/transport/forwarder/forwarder.go delete mode 100644 pkg/transport/forwarder/forwarder_e2e_test.go delete mode 100644 pkg/transport/forwarder/forwarder_test.go delete mode 100644 pkg/transport/mocks.go create mode 100644 pkg/transport/nclprotocol/compute/config.go create mode 100644 pkg/transport/nclprotocol/compute/controlplane.go create mode 100644 pkg/transport/nclprotocol/compute/dataplane.go create mode 100644 pkg/transport/nclprotocol/compute/errors.go create mode 100644 pkg/transport/nclprotocol/compute/health_tracker.go create mode 100644 pkg/transport/nclprotocol/compute/manager.go rename pkg/transport/{ => nclprotocol}/dispatcher/config.go (97%) rename pkg/transport/{ => nclprotocol}/dispatcher/config_test.go (100%) rename pkg/transport/{ => nclprotocol}/dispatcher/constants.go (100%) rename pkg/transport/{ => nclprotocol}/dispatcher/dispatcher.go (95%) rename pkg/transport/{ => nclprotocol}/dispatcher/dispatcher_e2e_test.go (99%) rename pkg/transport/{ => nclprotocol}/dispatcher/dispatcher_test.go (96%) rename pkg/transport/{ => nclprotocol}/dispatcher/errors.go (100%) rename pkg/transport/{ => nclprotocol}/dispatcher/handler.go (90%) rename pkg/transport/{ => nclprotocol}/dispatcher/handler_test.go (95%) rename pkg/transport/{ => nclprotocol}/dispatcher/recovery.go (100%) rename pkg/transport/{ => nclprotocol}/dispatcher/recovery_test.go (100%) rename pkg/transport/{ => nclprotocol}/dispatcher/state.go (98%) rename pkg/transport/{ => nclprotocol}/dispatcher/state_test.go (100%) rename pkg/transport/{ => 
nclprotocol}/dispatcher/utils.go (100%) create mode 100644 pkg/transport/nclprotocol/mocks.go create mode 100644 pkg/transport/nclprotocol/orchestrator/config.go create mode 100644 pkg/transport/nclprotocol/orchestrator/dataplane.go create mode 100644 pkg/transport/nclprotocol/orchestrator/manager.go create mode 100644 pkg/transport/nclprotocol/registry.go create mode 100644 pkg/transport/nclprotocol/subjects.go create mode 100644 pkg/transport/nclprotocol/tracker.go create mode 100644 pkg/transport/nclprotocol/types.go delete mode 100644 pkg/transport/types.go diff --git a/pkg/compute/watchers/ncl_message_creator.go b/pkg/compute/watchers/ncl_message_creator.go index 5da3c7ef91..09f5ae9653 100644 --- a/pkg/compute/watchers/ncl_message_creator.go +++ b/pkg/compute/watchers/ncl_message_creator.go @@ -8,7 +8,7 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" - "github.com/bacalhau-project/bacalhau/pkg/transport" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" ) type NCLMessageCreator struct { @@ -67,4 +67,4 @@ func (d *NCLMessageCreator) CreateMessage(event watcher.Event) (*envelope.Messag } // compile-time check that NCLMessageCreator implements dispatcher.MessageCreator -var _ transport.MessageCreator = &NCLMessageCreator{} +var _ nclprotocol.MessageCreator = &NCLMessageCreator{} diff --git a/pkg/config/types/compute.go b/pkg/config/types/compute.go index 089643086c..074a956818 100644 --- a/pkg/config/types/compute.go +++ b/pkg/config/types/compute.go @@ -32,6 +32,7 @@ type Heartbeat struct { // InfoUpdateInterval specifies the time between updates of non-resource information to the orchestrator. InfoUpdateInterval Duration `yaml:"InfoUpdateInterval,omitempty" json:"InfoUpdateInterval,omitempty"` // ResourceUpdateInterval specifies the time between updates of resource information to the orchestrator. 
+ // Deprecated: only used by legacy transport, will be removed in the future. ResourceUpdateInterval Duration `yaml:"ResourceUpdateInterval,omitempty" json:"ResourceUpdateInterval,omitempty"` // Interval specifies the time between heartbeat signals sent to the orchestrator. Interval Duration `yaml:"Interval,omitempty" json:"Interval,omitempty"` diff --git a/pkg/models/messages/constants.go b/pkg/models/messages/constants.go index c931b665fe..f3dfe70313 100644 --- a/pkg/models/messages/constants.go +++ b/pkg/models/messages/constants.go @@ -9,4 +9,12 @@ const ( BidResultMessageType = "BidResult" RunResultMessageType = "RunResult" ComputeErrorMessageType = "ComputeError" + + HandshakeRequestMessageType = "transport.HandshakeRequest" + HeartbeatRequestMessageType = "transport.HeartbeatRequest" + NodeInfoUpdateRequestMessageType = "transport.UpdateNodeInfoRequest" + + HandshakeResponseType = "transport.HandshakeResponse" + HeartbeatResponseType = "transport.HeartbeatResponse" + NodeInfoUpdateResponseType = "transport.UpdateNodeInfoResponse" ) diff --git a/pkg/models/node_info.go b/pkg/models/node_info.go index eadd2dbd3f..c8a4d34924 100644 --- a/pkg/models/node_info.go +++ b/pkg/models/node_info.go @@ -107,6 +107,12 @@ func (n NodeInfo) IsComputeNode() bool { return n.NodeType == NodeTypeCompute } +// HasNodeInfoChanged returns true if the node info has changed compared to the previous call +// TODO: implement this function +func HasNodeInfoChanged(prev, current NodeInfo) bool { + return false +} + // ComputeNodeInfo contains metadata about the current state and abilities of a compute node. Compute Nodes share // this state with Requester nodes by including it in the NodeInfo they share across the network. 
type ComputeNodeInfo struct { diff --git a/pkg/node/compute.go b/pkg/node/compute.go index 6842d30ec4..4e8b4b7ac5 100644 --- a/pkg/node/compute.go +++ b/pkg/node/compute.go @@ -20,7 +20,6 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/compute/watchers" "github.com/bacalhau-project/bacalhau/pkg/executor" executor_util "github.com/bacalhau-project/bacalhau/pkg/executor/util" - "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/nats" @@ -29,14 +28,14 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/publisher" "github.com/bacalhau-project/bacalhau/pkg/storage" bprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/bprotocol/compute" - "github.com/bacalhau-project/bacalhau/pkg/transport/dispatcher" + transportcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" ) type Compute struct { // Visible for testing ID string LocalEndpoint compute.Endpoint - LogstreamServer logstream.Server Capacity capacity.Tracker ExecutionStore store.ExecutionStore Executors executor.ExecProvider @@ -139,12 +138,6 @@ func NewComputeNode( }, }) - // logging server - logserver := logstream.NewServer(logstream.ServerParams{ - ExecutionStore: executionStore, - Executors: executors, - }) - bidder := NewBidder(cfg, allocatedResources, publishers, @@ -204,54 +197,46 @@ func NewComputeNode( err = nil } - // compute -> orchestrator ncl publisher - natsConn, err := clientFactory.CreateClient(ctx) - if err != nil { - return nil, err - } - messageRegistry := MustCreateMessageRegistry() - nclPublisher, err := ncl.NewOrderedPublisher(natsConn, ncl.OrderedPublisherConfig{ - Name: cfg.NodeID, - Destination: computeOutSubject(cfg.NodeID), - MessageRegistry: messageRegistry, + // connection manager + connectionManager, err := 
transportcompute.NewConnectionManager(transportcompute.Config{ + NodeID: cfg.NodeID, + ClientFactory: clientFactory, + NodeInfoProvider: nodeInfoProvider, + HeartbeatInterval: cfg.BacalhauConfig.Compute.Heartbeat.Interval.AsTimeDuration(), + NodeInfoUpdateInterval: cfg.BacalhauConfig.Compute.Heartbeat.InfoUpdateInterval.AsTimeDuration(), + DataPlaneMessageHandler: compute.NewMessageHandler(executionStore), + DataPlaneMessageCreator: watchers.NewNCLMessageCreator(), + EventStore: executionStore.GetEventStore(), + Checkpointer: executionStore, + DispatcherConfig: dispatcher.DefaultConfig(), + LogStreamServer: logstream.NewServer(logstream.ServerParams{ + ExecutionStore: executionStore, + Executors: executors, + }), }) if err != nil { return nil, err } - // orchestrator -> compute ncl subscriber - nclSubscriber, err := ncl.NewSubscriber(natsConn, ncl.SubscriberConfig{ - Name: cfg.NodeID, - MessageRegistry: messageRegistry, - MessageHandler: compute.NewMessageHandler(executionStore), - }) - if err != nil { - return nil, err - } - if err = nclSubscriber.Subscribe(ctx, computeInSubscription(cfg.NodeID)); err != nil { - return nil, err + if err = connectionManager.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start connection manager: %w", err) } - watcherRegistry, nclDispatcher, err := setupComputeWatchers( - ctx, executionStore, nclPublisher, bufferRunner, bidder) + watcherRegistry, err := setupComputeWatchers( + ctx, executionStore, bufferRunner, bidder) if err != nil { return nil, err } // A single Cleanup function to make sure the order of closing dependencies is correct cleanupFunc := func(ctx context.Context) { - if err = nclSubscriber.Close(ctx); err != nil { - log.Error().Err(err).Msg("failed to close ncl subscriber") - } - if nclDispatcher != nil { - if err = nclDispatcher.Stop(ctx); err != nil { - log.Error().Err(err).Msg("failed to stop dispatcher") - } - } if err = watcherRegistry.Stop(ctx); err != nil { log.Error().Err(err).Msg("failed to stop 
watcher registry") } legacyConnectionManager.Stop(ctx) + if err = connectionManager.Close(ctx); err != nil { + log.Error().Err(err).Msg("failed to stop connection manager") + } if err = executionStore.Close(ctx); err != nil { log.Error().Err(err).Msg("failed to close execution store") } @@ -263,7 +248,6 @@ func NewComputeNode( return &Compute{ ID: cfg.NodeID, LocalEndpoint: baseEndpoint, - LogstreamServer: logserver, Capacity: runningCapacityTracker, ExecutionStore: executionStore, Executors: executors, @@ -352,19 +336,19 @@ func NewBidder( func setupComputeWatchers( ctx context.Context, executionStore store.ExecutionStore, - nclPublisher ncl.OrderedPublisher, bufferRunner *compute.ExecutorBuffer, bidder compute.Bidder, -) (watcher.Manager, *dispatcher.Dispatcher, error) { +) (watcher.Manager, error) { watcherRegistry := watcher.NewManager(executionStore.GetEventStore()) // Set up execution logger watcher _, err := watcherRegistry.Create(ctx, computeExecutionLoggerWatcherID, watcher.WithHandler(watchers.NewExecutionLogger(log.Logger)), + watcher.WithEphemeral(), watcher.WithAutoStart(), watcher.WithInitialEventIterator(watcher.LatestIterator())) if err != nil { - return nil, nil, fmt.Errorf("failed to setup execution logger watcher: %w", err) + return nil, fmt.Errorf("failed to setup execution logger watcher: %w", err) } // Set up execution handler watcher @@ -378,29 +362,8 @@ func setupComputeWatchers( watcher.WithMaxRetries(3), watcher.WithInitialEventIterator(watcher.LatestIterator())) if err != nil { - return nil, nil, fmt.Errorf("failed to setup execution handler watcher: %w", err) - } - - // setup ncl dispatcher - nclDispatcherWatcher, err := watcherRegistry.Create(ctx, computeNCLDispatcherWatcherID, - watcher.WithFilter(watcher.EventFilter{ - ObjectTypes: []string{compute.EventObjectExecutionUpsert}, - }), - watcher.WithRetryStrategy(watcher.RetryStrategyBlock), - watcher.WithInitialEventIterator(watcher.LatestIterator())) - if err != nil { - return nil, 
nil, fmt.Errorf("failed to setup ncl dispatcher watcher: %w", err) - } - - nclDispatcher, err := dispatcher.New( - nclPublisher, nclDispatcherWatcher, watchers.NewNCLMessageCreator(), dispatcher.DefaultConfig()) - if err != nil { - return nil, nil, fmt.Errorf("failed to create dispatcher: %w", err) - } - - if err = nclDispatcher.Start(ctx); err != nil { - return nil, nil, fmt.Errorf("failed to start dispatcher: %w", err) + return nil, fmt.Errorf("failed to setup execution handler watcher: %w", err) } - return watcherRegistry, nclDispatcher, nil + return watcherRegistry, nil } diff --git a/pkg/node/constants.go b/pkg/node/constants.go index 6bdbd6c8e6..6a6371b4a5 100644 --- a/pkg/node/constants.go +++ b/pkg/node/constants.go @@ -5,18 +5,10 @@ const ( // and handles them locally by triggering the executor or bidder for example. computeExecutionHandlerWatcherID = "execution-handler" - // computeNCLDispatcherWatcherID is the ID of the watcher that listens for execution events - // and forwards them to the NCL dispatcher. - computeNCLDispatcherWatcherID = "compute-ncl-dispatcher" - // computeExecutionLoggerWatcherID is the ID of the watcher that listens for execution events // and logs them. computeExecutionLoggerWatcherID = "compute-logger" - // orchestratorNCLDispatcherWatcherID is the ID of the watcher that listens for execution events - // and forwards them to the NCL dispatcher. 
- orchestratorNCLDispatcherWatcherID = "orchestrator-ncl-dispatcher" - // orchestratorExecutionCancellerWatcherID is the ID of the watcher that listens for execution events // and cancels them the execution's observed state orchestratorExecutionCancellerWatcherID = "execution-canceller" diff --git a/pkg/node/ncl.go b/pkg/node/ncl.go deleted file mode 100644 index d76fafeefb..0000000000 --- a/pkg/node/ncl.go +++ /dev/null @@ -1,58 +0,0 @@ -package node - -import ( - "errors" - "fmt" - - "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" - "github.com/bacalhau-project/bacalhau/pkg/models/messages" -) - -// CreateMessageSerDeRegistry creates a new payload registry. -func CreateMessageSerDeRegistry() (*envelope.Registry, error) { - reg := envelope.NewRegistry() - err := errors.Join( - reg.Register(messages.AskForBidMessageType, messages.AskForBidRequest{}), - reg.Register(messages.BidAcceptedMessageType, messages.BidAcceptedRequest{}), - reg.Register(messages.BidRejectedMessageType, messages.BidRejectedRequest{}), - reg.Register(messages.CancelExecutionMessageType, messages.CancelExecutionRequest{}), - reg.Register(messages.BidResultMessageType, messages.BidResult{}), - reg.Register(messages.RunResultMessageType, messages.RunResult{}), - reg.Register(messages.ComputeErrorMessageType, messages.ComputeError{}), - ) - return reg, err -} - -// MustCreateMessageRegistry creates a new payload registry. -func MustCreateMessageRegistry() *envelope.Registry { - reg, err := CreateMessageSerDeRegistry() - if err != nil { - panic(err) - } - return reg -} - -// orchestratorSubjectSub returns the subject to subscribe to for orchestrator messages. -// it subscribes to outgoing messages from all compute nodes. -func orchestratorInSubscription() string { - return "bacalhau.global.compute.*.out.msgs" -} - -// orchestratorOutSubject returns the subject to publish orchestrator messages to. -// it publishes to the incoming subject of a specific compute node. 
-func orchestratorOutSubject(computeNodeID string) string { - return fmt.Sprintf("bacalhau.global.compute.%s.in.msgs", computeNodeID) -} - -// computeInSubscription returns the subject to subscribe to for compute messages. -// it subscribes to incoming messages directed to its own node. -func computeInSubscription(nodeID string) string { - return fmt.Sprintf("bacalhau.global.compute.%s.in.msgs", nodeID) -} - -// computeOutSubject returns the subject to publish compute messages to. -// it publishes to the outgoing subject of a specific compute node, which the -// orchestrator subscribes to. -func computeOutSubject(nodeID string) string { - return fmt.Sprintf("bacalhau.global.compute.%s.out.msgs", nodeID) -} diff --git a/pkg/node/requester.go b/pkg/node/requester.go index c6146759ad..cdea80e0a4 100644 --- a/pkg/node/requester.go +++ b/pkg/node/requester.go @@ -10,12 +10,11 @@ import ( "go.opentelemetry.io/otel/attribute" "github.com/bacalhau-project/bacalhau/pkg/bacerrors" - "github.com/bacalhau-project/bacalhau/pkg/compute" "github.com/bacalhau-project/bacalhau/pkg/jobstore" boltjobstore "github.com/bacalhau-project/bacalhau/pkg/jobstore/boltdb" - "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" + natsutil "github.com/bacalhau-project/bacalhau/pkg/nats" "github.com/bacalhau-project/bacalhau/pkg/nats/proxy" nats_transport "github.com/bacalhau-project/bacalhau/pkg/nats/transport" "github.com/bacalhau-project/bacalhau/pkg/node/metrics" @@ -38,7 +37,8 @@ import ( s3helper "github.com/bacalhau-project/bacalhau/pkg/s3" "github.com/bacalhau-project/bacalhau/pkg/system" bprotocolorchestrator "github.com/bacalhau-project/bacalhau/pkg/transport/bprotocol/orchestrator" - "github.com/bacalhau-project/bacalhau/pkg/transport/forwarder" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + transportorchestrator 
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/orchestrator" ) var ( @@ -263,32 +263,29 @@ func NewRequesterNode( return nil, pkgerrors.Wrap(err, "failed to start connection manager") } - // nclPublisher - messageSerDeRegistry := MustCreateMessageRegistry() - nclPublisher, err := ncl.NewOrderedPublisher(natsConn, ncl.OrderedPublisherConfig{ - Name: cfg.NodeID, - MessageRegistry: messageSerDeRegistry, + // connection manager + connectionManager, err := transportorchestrator.NewComputeManager(transportorchestrator.Config{ + NodeID: cfg.NodeID, + ClientFactory: natsutil.ClientFactoryFunc(transportLayer.CreateClient), + NodeManager: nodesManager, + HeartbeatTimeout: cfg.BacalhauConfig.Orchestrator.NodeManager.DisconnectTimeout.AsTimeDuration(), + DataPlaneMessageHandler: orchestrator.NewMessageHandler(jobStore), + DataPlaneMessageCreatorFactory: watchers.NewNCLMessageCreatorFactory(watchers.NCLMessageCreatorFactoryParams{ + ProtocolRouter: protocolRouter, + SubjectFn: nclprotocol.NatsSubjectComputeInMsgs, + }), + EventStore: jobStore.GetEventStore(), }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create connection manager: %w", err) } - watcherRegistry, nclForwarder, err := setupOrchestratorWatchers( - ctx, jobStore, nclPublisher, evalBroker, protocolRouter) - if err != nil { - return nil, err + if err = connectionManager.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start connection manager: %w", err) } - // ncl subscriber - nclSubscriber, err := ncl.NewSubscriber(natsConn, ncl.SubscriberConfig{ - Name: cfg.NodeID, - MessageRegistry: messageSerDeRegistry, - MessageHandler: orchestrator.NewMessageHandler(jobStore), - }) + watcherRegistry, err := setupOrchestratorWatchers(ctx, jobStore, evalBroker) if err != nil { - return nil, pkgerrors.Wrap(err, "failed to create ncl subscriber") - } - if err = nclSubscriber.Subscribe(ctx, orchestratorInSubscription()); err != nil { return nil, err } @@ -298,14 +295,9 @@ func 
NewRequesterNode( // stop the legacy connection manager legacyConnectionManager.Stop(ctx) - // close the ncl subscriber - cleanupErr = nclSubscriber.Close(ctx) - if cleanupErr != nil { - logDebugIfContextCancelled(ctx, cleanupErr, "failed to cleanly shutdown ncl subscriber") - } - - if cleanupErr = nclForwarder.Stop(ctx); cleanupErr != nil { - logDebugIfContextCancelled(ctx, cleanupErr, "failed to cleanly shutdown ncl forwarder") + // stop the connection manager + if cleanupErr = connectionManager.Stop(ctx); cleanupErr != nil { + logDebugIfContextCancelled(ctx, cleanupErr, "failed to cleanly shutdown connection manager") } if cleanupErr = watcherRegistry.Stop(ctx); cleanupErr != nil { @@ -410,10 +402,8 @@ func createNodeManager(ctx context.Context, cfg NodeConfig, natsConn *nats.Conn) func setupOrchestratorWatchers( ctx context.Context, jobStore jobstore.Store, - nclPublisher ncl.OrderedPublisher, evalBroker orchestrator.EvaluationBroker, - protocolRouter *watchers.ProtocolRouter, -) (watcher.Manager, *forwarder.Forwarder, error) { +) (watcher.Manager, error) { watcherRegistry := watcher.NewManager(jobStore.GetEventStore()) // Start watching for evaluation events using latest iterator @@ -427,63 +417,40 @@ func setupOrchestratorWatchers( }), ) if err != nil { - return nil, nil, fmt.Errorf("failed to start evaluation watcher: %w", err) + return nil, fmt.Errorf("failed to start evaluation watcher: %w", err) } // Set up execution logger watcher _, err = watcherRegistry.Create(ctx, orchestratorExecutionLoggerWatcherID, watcher.WithHandler(watchers.NewExecutionLogger(log.Logger)), + watcher.WithEphemeral(), watcher.WithAutoStart(), + watcher.WithInitialEventIterator(watcher.LatestIterator()), + watcher.WithRetryStrategy(watcher.RetryStrategySkip), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{jobstore.EventObjectExecutionUpsert}, }), - watcher.WithInitialEventIterator(watcher.LatestIterator()), ) if err != nil { - return nil, nil, fmt.Errorf("failed to 
setup orchestrator logger watcher: %w", err) + return nil, fmt.Errorf("failed to setup orchestrator logger watcher: %w", err) } // Set up execution canceller watcher _, err = watcherRegistry.Create(ctx, orchestratorExecutionCancellerWatcherID, watcher.WithHandler(watchers.NewExecutionCanceller(jobStore)), watcher.WithAutoStart(), - watcher.WithFilter(watcher.EventFilter{ - ObjectTypes: []string{jobstore.EventObjectExecutionUpsert}, - }), watcher.WithInitialEventIterator(watcher.LatestIterator()), watcher.WithRetryStrategy(watcher.RetryStrategySkip), watcher.WithMaxRetries(3), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to setup orchestrator canceller watcher: %w", err) - } - - // setup ncl dispatcher - nclDispatcherWatcher, err := watcherRegistry.Create(ctx, orchestratorNCLDispatcherWatcherID, watcher.WithFilter(watcher.EventFilter{ - ObjectTypes: []string{compute.EventObjectExecutionUpsert}, + ObjectTypes: []string{jobstore.EventObjectExecutionUpsert}, }), - watcher.WithRetryStrategy(watcher.RetryStrategyBlock), - watcher.WithInitialEventIterator(watcher.LatestIterator())) - if err != nil { - return nil, nil, fmt.Errorf("failed to setup ncl dispatcher watcher: %w", err) - } - - nclMessageCreator := watchers.NewNCLMessageCreator(watchers.NCLMessageCreatorParams{ - ProtocolRouter: protocolRouter, - SubjectFn: orchestratorOutSubject, - }) - - nclForwarder, err := forwarder.New(nclPublisher, nclDispatcherWatcher, nclMessageCreator) + ) if err != nil { - return nil, nil, fmt.Errorf("failed to create forwarder: %w", err) - } - - if err = nclForwarder.Start(ctx); err != nil { - return nil, nil, fmt.Errorf("failed to start forwarder: %w", err) + return nil, fmt.Errorf("failed to setup orchestrator canceller watcher: %w", err) } - return watcherRegistry, nclForwarder, nil + return watcherRegistry, nil } func (r *Requester) cleanup(ctx context.Context) { diff --git a/pkg/orchestrator/watchers/event_logger.go b/pkg/orchestrator/watchers/event_logger.go index 
b61f9bb1a5..9db1979c84 100644 --- a/pkg/orchestrator/watchers/event_logger.go +++ b/pkg/orchestrator/watchers/event_logger.go @@ -41,7 +41,9 @@ func (e *ExecutionLogger) HandleEvent(ctx context.Context, event watcher.Event) } // Create base log event with common fields - logEvent := e.logger.Debug().Str("execution_id", upsert.Current.ID) + logEvent := e.logger.Debug(). + Str("execution_id", upsert.Current.ID). + Uint64("sequence_number", event.SeqNum) // Add state transition information if this is an update if upsert.Previous != nil { diff --git a/pkg/orchestrator/watchers/ncl_message_creator.go b/pkg/orchestrator/watchers/ncl_message_creator.go index 5551c56f6e..f7628a161e 100644 --- a/pkg/orchestrator/watchers/ncl_message_creator.go +++ b/pkg/orchestrator/watchers/ncl_message_creator.go @@ -11,15 +11,43 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" - "github.com/bacalhau-project/bacalhau/pkg/transport" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" ) +type NCLMessageCreatorFactory struct { + protocolRouter *ProtocolRouter + subjectFn func(nodeID string) string +} + +type NCLMessageCreatorFactoryParams struct { + ProtocolRouter *ProtocolRouter + SubjectFn func(nodeID string) string +} + +// NewNCLMessageCreatorFactory creates a new NCL protocol dispatcher factory +func NewNCLMessageCreatorFactory(params NCLMessageCreatorFactoryParams) *NCLMessageCreatorFactory { + return &NCLMessageCreatorFactory{ + protocolRouter: params.ProtocolRouter, + subjectFn: params.SubjectFn, + } +} + +func (f *NCLMessageCreatorFactory) CreateMessageCreator(ctx context.Context, nodeID string) nclprotocol.MessageCreator { + return NewNCLMessageCreator(NCLMessageCreatorParams{ + NodeID: nodeID, + ProtocolRouter: f.protocolRouter, + SubjectFn: f.subjectFn, + }) +} + type NCLMessageCreator struct { + nodeID string protocolRouter *ProtocolRouter 
subjectFn func(nodeID string) string } type NCLMessageCreatorParams struct { + NodeID string ProtocolRouter *ProtocolRouter SubjectFn func(nodeID string) string } @@ -27,6 +55,7 @@ type NCLMessageCreatorParams struct { // NewNCLMessageCreator creates a new NCL protocol dispatcher func NewNCLMessageCreator(params NCLMessageCreatorParams) *NCLMessageCreator { return &NCLMessageCreator{ + nodeID: params.NodeID, protocolRouter: params.ProtocolRouter, subjectFn: params.SubjectFn, } @@ -47,6 +76,12 @@ func (d *NCLMessageCreator) CreateMessage(event watcher.Event) (*envelope.Messag return nil, bacerrors.New("upsert.Current is nil"). WithComponent(nclDispatcherErrComponent) } + + // Filter events not meant for the node this dispatcher is handling + if upsert.Current.NodeID != d.nodeID { + return nil, nil + } + execution := upsert.Current preferredProtocol, err := d.protocolRouter.PreferredProtocol(context.Background(), execution) if err != nil { @@ -129,4 +164,4 @@ func (d *NCLMessageCreator) createCancelMessage(upsert models.ExecutionUpsert) * } // compile-time check that NCLMessageCreator implements dispatcher.MessageCreator -var _ transport.MessageCreator = &NCLMessageCreator{} +var _ nclprotocol.MessageCreator = &NCLMessageCreator{} diff --git a/pkg/transport/bprotocol/compute/transport.go b/pkg/transport/bprotocol/compute/transport.go index ac94df4234..c2f4366396 100644 --- a/pkg/transport/bprotocol/compute/transport.go +++ b/pkg/transport/bprotocol/compute/transport.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/nats-io/nats.go" + "github.com/rs/zerolog/log" "github.com/bacalhau-project/bacalhau/pkg/compute" "github.com/bacalhau-project/bacalhau/pkg/compute/watchers" @@ -144,6 +145,11 @@ func (cm *ConnectionManager) Start(ctx context.Context) error { HeartbeatConfig: cm.config.HeartbeatConfig, }) if err = managementClient.RegisterNode(ctx); err != nil { + if errors.As(err, &bprotocol.ErrUpgradeAvailable) { + log.Info().Msg("Disabling bprotocol management client due to 
upgrade available") + cm.Stop(ctx) + return nil + } return fmt.Errorf("failed to register node with requester: %s", err) } diff --git a/pkg/transport/bprotocol/errors.go b/pkg/transport/bprotocol/errors.go new file mode 100644 index 0000000000..1e164a0b43 --- /dev/null +++ b/pkg/transport/bprotocol/errors.go @@ -0,0 +1,8 @@ +package bprotocol + +import ( + "fmt" +) + +// ErrUpgradeAvailable indicates that the orchestrator supports the NCLv1 protocol +var ErrUpgradeAvailable = fmt.Errorf("node supports NCLv1 protocol - legacy protocol disabled") diff --git a/pkg/transport/bprotocol/orchestrator/server.go b/pkg/transport/bprotocol/orchestrator/server.go index 41b6578e1a..5dc1c45ef6 100644 --- a/pkg/transport/bprotocol/orchestrator/server.go +++ b/pkg/transport/bprotocol/orchestrator/server.go @@ -52,6 +52,16 @@ func (h *Server) HandleMessage(ctx context.Context, message *envelope.Message) e // Register handles compute node registration requests func (h *Server) Register(ctx context.Context, request legacy.RegisterRequest) (*legacy.RegisterResponse, error) { + // Check if the node supports NCLv1 protocol + for _, protocol := range request.Info.SupportedProtocols { + if protocol == models.ProtocolNCLV1 { + return &legacy.RegisterResponse{ + Accepted: false, + Reason: bprotocol.ErrUpgradeAvailable.Error(), + }, nil + } + } + resp, err := h.nodeManager.Handshake(ctx, messages.HandshakeRequest{ NodeInfo: request.Info, StartTime: time.Now(), diff --git a/pkg/transport/forwarder/forwarder.go b/pkg/transport/forwarder/forwarder.go deleted file mode 100644 index 59f83fb43a..0000000000 --- a/pkg/transport/forwarder/forwarder.go +++ /dev/null @@ -1,96 +0,0 @@ -package forwarder - -import ( - "context" - "fmt" - "sync" - - "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/transport" -) - -// Forwarder forwards events from a watcher to a destination in order. 
-// Unlike Dispatcher, it provides no delivery guarantees or recovery mechanisms. -type Forwarder struct { - watcher watcher.Watcher - creator transport.MessageCreator - publisher ncl.OrderedPublisher - - running bool - mu sync.RWMutex -} - -func New( - publisher ncl.OrderedPublisher, watcher watcher.Watcher, creator transport.MessageCreator) (*Forwarder, error) { - if publisher == nil { - return nil, fmt.Errorf("publisher cannot be nil") - } - if watcher == nil { - return nil, fmt.Errorf("watcher cannot be nil") - } - if creator == nil { - return nil, fmt.Errorf("message creator cannot be nil") - } - - f := &Forwarder{ - watcher: watcher, - creator: creator, - publisher: publisher, - } - - if err := watcher.SetHandler(f); err != nil { - return nil, fmt.Errorf("failed to set handler: %w", err) - } - - return f, nil -} - -func (f *Forwarder) Start(ctx context.Context) error { - f.mu.Lock() - if f.running { - f.mu.Unlock() - return fmt.Errorf("forwarder already running") - } - f.running = true - f.mu.Unlock() - - return f.watcher.Start(ctx) -} - -func (f *Forwarder) Stop(ctx context.Context) error { - f.mu.Lock() - defer f.mu.Unlock() - if !f.running { - return nil - } - f.running = false - f.watcher.Stop(ctx) - return nil -} - -func (f *Forwarder) HandleEvent(ctx context.Context, event watcher.Event) error { - message, err := f.creator.CreateMessage(event) - if err != nil { - return fmt.Errorf("creaddte message failed: %w", err) - } - if message == nil { - return nil - } - - // Add sequence number for ordering - message.WithMetadataValue(transport.KeySeqNum, fmt.Sprint(event.SeqNum)) - message.WithMetadataValue(ncl.KeyMessageID, transport.GenerateMsgID(event)) - - // Publish request - request := ncl.NewPublishRequest(message) - if message.Metadata.Has(ncl.KeySubject) { - request = request.WithSubject(message.Metadata.Get(ncl.KeySubject)) - } - - if err = f.publisher.Publish(ctx, request); err != nil { - return err - } - - return nil -} diff --git 
a/pkg/transport/forwarder/forwarder_e2e_test.go b/pkg/transport/forwarder/forwarder_e2e_test.go deleted file mode 100644 index e2a48a99e8..0000000000 --- a/pkg/transport/forwarder/forwarder_e2e_test.go +++ /dev/null @@ -1,289 +0,0 @@ -//go:build unit || !integration - -package forwarder_test - -import ( - "context" - "fmt" - "reflect" - "testing" - "time" - - "github.com/nats-io/nats-server/v2/server" - "github.com/nats-io/nats.go" - "github.com/stretchr/testify/suite" - - "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" - "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/boltdb" - watchertest "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/test" - "github.com/bacalhau-project/bacalhau/pkg/logger" - testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" - "github.com/bacalhau-project/bacalhau/pkg/transport" - "github.com/bacalhau-project/bacalhau/pkg/transport/forwarder" -) - -type ForwarderE2ETestSuite struct { - suite.Suite - ctx context.Context - cancel context.CancelFunc - natsServer *server.Server - nc *nats.Conn - store watcher.EventStore - registry *envelope.Registry - subscriber ncl.Subscriber - received []*envelope.Message - cleanupFuncs []func() -} - -func (s *ForwarderE2ETestSuite) SetupTest() { - logger.ConfigureTestLogging(s.T()) - s.ctx, s.cancel = context.WithCancel(context.Background()) - - // Start NATS server - s.natsServer, s.nc = testutils.StartNats(s.T()) - - // Create boltdb store and watcher - eventObjectSerializer := watcher.NewJSONSerializer() - s.Require().NoError(eventObjectSerializer.RegisterType("test", reflect.TypeOf(""))) - store, err := boltdb.NewEventStore( - watchertest.CreateBoltDB(s.T()), - boltdb.WithEventSerializer(eventObjectSerializer), - ) - s.Require().NoError(err) - s.store = store - - // Create registry - s.registry = envelope.NewRegistry() - 
s.Require().NoError(s.registry.Register("test", "string")) - - // Create subscriber - s.received = make([]*envelope.Message, 0) - var msgHandler ncl.MessageHandlerFunc = func(_ context.Context, msg *envelope.Message) error { - payload, _ := msg.GetPayload("string") - s.T().Logf("Received message: %s", payload) - s.received = append(s.received, msg) - return nil - } - subscriber, err := ncl.NewSubscriber(s.nc, ncl.SubscriberConfig{ - Name: "test-subscriber", - MessageRegistry: s.registry, - MessageSerializer: envelope.NewSerializer(), - MessageHandler: msgHandler, - }) - s.Require().NoError(err) - s.subscriber = subscriber - s.Require().NoError(s.subscriber.Subscribe(s.ctx, "test")) -} - -func (s *ForwarderE2ETestSuite) TearDownTest() { - for i := len(s.cleanupFuncs) - 1; i >= 0; i-- { - s.cleanupFuncs[i]() - } - s.cleanupFuncs = nil - - if s.subscriber != nil { - s.Require().NoError(s.subscriber.Close(context.Background())) - } - if s.nc != nil { - s.nc.Close() - } - if s.natsServer != nil && s.natsServer.Running() { - s.natsServer.Shutdown() - } - s.cancel() -} - -func (s *ForwarderE2ETestSuite) startForwarder() *forwarder.Forwarder { - w, err := watcher.New(s.ctx, "test-watcher", s.store) - s.Require().NoError(err) - - publisher, err := ncl.NewOrderedPublisher(s.nc, ncl.OrderedPublisherConfig{ - Name: "test-publisher", - Destination: "test", - MessageRegistry: s.registry, - MessageSerializer: envelope.NewSerializer(), - AckMode: ncl.NoAck, - }) - s.Require().NoError(err) - - f, err := forwarder.New(publisher, w, &testMessageCreator{}) - s.Require().NoError(err) - - s.Require().NoError(f.Start(s.ctx)) - - s.cleanupFuncs = append(s.cleanupFuncs, func() { - s.Require().NoError(f.Stop(s.ctx)) - s.Require().NoError(publisher.Close(s.ctx)) - }) - - return f -} - -func (s *ForwarderE2ETestSuite) TestEventFlow() { - // Create forwarder - s.startForwarder() - - // Store some events - s.storeEvents(5) - - // Wait for processing - s.Eventually(func() bool { - return 
len(s.received) == 5 - }, time.Second, 10*time.Millisecond) - s.Require().Equal(5, len(s.received)) - - // Verify messages were published in order - for i, msg := range s.received { - s.verifyMsg(msg, i+1) - } -} - -func (s *ForwarderE2ETestSuite) TestReconnection() { - // Create forwarder - s.startForwarder() - - // Store an event - s.storeEvent(1) - - // Wait for event to be published - s.Eventually(func() bool { - return len(s.received) == 1 - }, time.Second, 10*time.Millisecond) - - // Verify first message - s.Require().Lenf(s.received, 1, "received: %v", s.received) - s.verifyMsg(s.received[0], 1) - - // Stop NATS server - s.natsServer.Shutdown() - s.natsServer.WaitForShutdown() - - // wait for client to be disconnected - s.Eventually(func() bool { - return !s.nc.IsConnected() - }, time.Second, 10*time.Millisecond) - - // Store another event - should not be lost - s.storeEvent(2) - - // Restart NATS server - s.natsServer, _ = testutils.RestartNatsServer(s.T(), s.natsServer) - s.Eventually(func() bool { - return s.nc.IsConnected() - }, 5*time.Second, 10*time.Millisecond) - - // Store another event after reconnection - s.storeEvent(3) - - // Wait for new event - s.Eventually(func() bool { - return len(s.received) >= 3 - }, time.Second, 10*time.Millisecond) - - // Should've received all 3 events - s.Require().Lenf(s.received, 3, "received: %v", s.received) - s.verifyMsg(s.received[0], 1) - s.verifyMsg(s.received[1], 2) - s.verifyMsg(s.received[2], 3) -} - -func (s *ForwarderE2ETestSuite) TestNoResponders() { - // Create forwarder - s.startForwarder() - - // Stop subscriber - s.Require().NoError(s.subscriber.Close(s.ctx)) - - // Store events - s.storeEvent(1) - s.storeEvent(2) - - // sleep and verify no messages were received - time.Sleep(100 * time.Millisecond) - s.Require().Empty(s.received) - - // Restart subscriber - s.Require().NoError(s.subscriber.Subscribe(s.ctx, "test")) - - // Store more events - s.storeEvent(3) - s.storeEvent(4) - - // Wait for event to 
be published - s.Eventually(func() bool { - return len(s.received) >= 2 - }, time.Second, 10*time.Millisecond) - - // Verify the messages - s.Require().Lenf(s.received, 2, "received: %v", s.received) - s.verifyMsg(s.received[0], 3) - s.verifyMsg(s.received[1], 4) -} - -func (s *ForwarderE2ETestSuite) TestRestart() { - // Create forwarder - f := s.startForwarder() - - // Store some events - s.storeEvents(3) - - // Wait for events - s.Eventually(func() bool { - return len(s.received) == 3 - }, time.Second, 10*time.Millisecond) - - // Stop forwarder - s.Require().NoError(f.Stop(s.ctx)) - s.received = s.received[:0] - - // Store more events while stopped - s.storeEvent(4) - s.storeEvent(5) - - // Start new forwarder - should process all events from beginning - s.startForwarder() - - // Should receive all events since forwarder doesn't checkpoint - s.Eventually(func() bool { - return len(s.received) == 5 - }, time.Second, 10*time.Millisecond) - - for i, msg := range s.received { - s.verifyMsg(msg, i+1) - } -} - -func (s *ForwarderE2ETestSuite) storeEvent(index int) { - err := s.store.StoreEvent(s.ctx, watcher.StoreEventRequest{ - Operation: watcher.OperationCreate, - ObjectType: "test", - Object: fmt.Sprintf("event-%d", index), - }) - s.Require().NoError(err) -} - -func (s *ForwarderE2ETestSuite) storeEvents(count int) { - for i := 1; i <= count; i++ { - s.storeEvent(i) - } -} - -func (s *ForwarderE2ETestSuite) verifyMsg(msg *envelope.Message, i int) { - payload, ok := msg.GetPayload("") - s.Require().True(ok, "payload missing or not a string") - s.Contains(payload, fmt.Sprintf("event-%d", i)) - s.Require().Equal(fmt.Sprintf("%d", i), msg.Metadata.Get(transport.KeySeqNum)) -} - -// Helper implementation -type testMessageCreator struct{} - -func (c *testMessageCreator) CreateMessage(event watcher.Event) (*envelope.Message, error) { - return envelope.NewMessage(event.Object), nil -} - -func TestForwarderE2ETestSuite(t *testing.T) { - suite.Run(t, 
new(ForwarderE2ETestSuite)) -} diff --git a/pkg/transport/forwarder/forwarder_test.go b/pkg/transport/forwarder/forwarder_test.go deleted file mode 100644 index 2f3e79e3a4..0000000000 --- a/pkg/transport/forwarder/forwarder_test.go +++ /dev/null @@ -1,155 +0,0 @@ -package forwarder - -import ( - "context" - "fmt" - "testing" - - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/transport" -) - -type ForwarderUnitTestSuite struct { - suite.Suite - ctrl *gomock.Controller - ctx context.Context - publisher *ncl.MockOrderedPublisher - watcher *watcher.MockWatcher - creator *transport.MockMessageCreator -} - -func (s *ForwarderUnitTestSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.ctx = context.Background() - s.publisher = ncl.NewMockOrderedPublisher(s.ctrl) - s.watcher = watcher.NewMockWatcher(s.ctrl) - s.creator = transport.NewMockMessageCreator(s.ctrl) -} - -func (s *ForwarderUnitTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func (s *ForwarderUnitTestSuite) TestNewForwarder() { - tests := []struct { - name string - setup func() (*Forwarder, error) - expectError string - }{ - { - name: "nil publisher", - setup: func() (*Forwarder, error) { - return New(nil, s.watcher, s.creator) - }, - expectError: "publisher cannot be nil", - }, - { - name: "nil watcher", - setup: func() (*Forwarder, error) { - return New(s.publisher, nil, s.creator) - }, - expectError: "watcher cannot be nil", - }, - { - name: "nil message creator", - setup: func() (*Forwarder, error) { - return New(s.publisher, s.watcher, nil) - }, - expectError: "message creator cannot be nil", - }, - { - name: "handler setup failure", - setup: func() (*Forwarder, error) { - s.watcher.EXPECT().SetHandler(gomock.Any()).Return(fmt.Errorf("handler error")) - return New(s.publisher, s.watcher, s.creator) - }, - expectError: 
"failed to set handler", - }, - { - name: "success", - setup: func() (*Forwarder, error) { - s.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil) - return New(s.publisher, s.watcher, s.creator) - }, - }, - } - - for _, tc := range tests { - s.Run(tc.name, func() { - f, err := tc.setup() - if tc.expectError != "" { - s.Error(err) - s.ErrorContains(err, tc.expectError) - s.Nil(f) - } else { - s.NoError(err) - s.NotNil(f) - } - }) - } -} - -func (s *ForwarderUnitTestSuite) TestStartupFailure() { - s.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil) - - f, err := New(s.publisher, s.watcher, s.creator) - s.Require().NoError(err) - - startErr := fmt.Errorf("start failed") - s.watcher.EXPECT().Start(gomock.Any()).Return(startErr) - - err = f.Start(s.ctx) - s.Error(err) - s.ErrorIs(err, startErr) -} - -func (s *ForwarderUnitTestSuite) TestDoubleStart() { - s.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil) - s.watcher.EXPECT().Start(gomock.Any()).Return(nil) - - f, err := New(s.publisher, s.watcher, s.creator) - s.Require().NoError(err) - - err = f.Start(s.ctx) - s.NoError(err) - - err = f.Start(s.ctx) - s.Error(err) - s.Contains(err.Error(), "already running") -} - -func (s *ForwarderUnitTestSuite) TestStopNonStarted() { - s.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil) - - f, err := New(s.publisher, s.watcher, s.creator) - s.Require().NoError(err) - - err = f.Stop(s.ctx) - s.NoError(err) -} - -func (s *ForwarderUnitTestSuite) TestDoubleStop() { - s.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil) - s.watcher.EXPECT().Start(gomock.Any()).Return(nil) - - f, err := New(s.publisher, s.watcher, s.creator) - s.Require().NoError(err) - - err = f.Start(s.ctx) - s.NoError(err) - - s.watcher.EXPECT().Stop(gomock.Any()) - err = f.Stop(s.ctx) - s.NoError(err) - - err = f.Stop(s.ctx) - s.NoError(err) -} - -func TestForwarderUnitTestSuite(t *testing.T) { - suite.Run(t, new(ForwarderUnitTestSuite)) -} diff --git a/pkg/transport/mocks.go b/pkg/transport/mocks.go 
deleted file mode 100644 index cf17d7bdea..0000000000 --- a/pkg/transport/mocks.go +++ /dev/null @@ -1,51 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: types.go - -// Package transport is a generated GoMock package. -package transport - -import ( - reflect "reflect" - - envelope "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" - watcher "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - gomock "go.uber.org/mock/gomock" -) - -// MockMessageCreator is a mock of MessageCreator interface. -type MockMessageCreator struct { - ctrl *gomock.Controller - recorder *MockMessageCreatorMockRecorder -} - -// MockMessageCreatorMockRecorder is the mock recorder for MockMessageCreator. -type MockMessageCreatorMockRecorder struct { - mock *MockMessageCreator -} - -// NewMockMessageCreator creates a new mock instance. -func NewMockMessageCreator(ctrl *gomock.Controller) *MockMessageCreator { - mock := &MockMessageCreator{ctrl: ctrl} - mock.recorder = &MockMessageCreatorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMessageCreator) EXPECT() *MockMessageCreatorMockRecorder { - return m.recorder -} - -// CreateMessage mocks base method. -func (m *MockMessageCreator) CreateMessage(event watcher.Event) (*envelope.Message, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateMessage", event) - ret0, _ := ret[0].(*envelope.Message) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateMessage indicates an expected call of CreateMessage. 
-func (mr *MockMessageCreatorMockRecorder) CreateMessage(event interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateMessage", reflect.TypeOf((*MockMessageCreator)(nil).CreateMessage), event) -} diff --git a/pkg/transport/nclprotocol/compute/config.go b/pkg/transport/nclprotocol/compute/config.go new file mode 100644 index 0000000000..a9b4f9ccbf --- /dev/null +++ b/pkg/transport/nclprotocol/compute/config.go @@ -0,0 +1,126 @@ +package compute + +import ( + "errors" + "time" + + "github.com/benbjohnson/clock" + + "github.com/bacalhau-project/bacalhau/pkg/compute/logstream" + "github.com/bacalhau-project/bacalhau/pkg/lib/backoff" + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/validate" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + "github.com/bacalhau-project/bacalhau/pkg/models" + "github.com/bacalhau-project/bacalhau/pkg/nats" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" +) + +type Config struct { + NodeID string + ClientFactory nats.ClientFactory + NodeInfoProvider models.NodeInfoProvider + + MessageSerializer envelope.MessageSerializer + MessageRegistry *envelope.Registry + + // Control plane config + ReconnectInterval time.Duration + HeartbeatInterval time.Duration + HeartbeatMissFactor int + NodeInfoUpdateInterval time.Duration + RequestTimeout time.Duration + ReconnectBackoff backoff.Backoff + + // Data plane config + DataPlaneMessageHandler ncl.MessageHandler // Handles incoming messages + DataPlaneMessageCreator nclprotocol.MessageCreator // Creates messages for sending + EventStore watcher.EventStore + DispatcherConfig dispatcher.Config + LogStreamServer logstream.Server + + // Checkpoint config + Checkpointer nclprotocol.Checkpointer + CheckpointInterval time.Duration + + 
Clock clock.Clock +} + +// Validate checks if the config is valid +func (c *Config) Validate() error { + return errors.Join( + validate.NotBlank(c.NodeID, "nodeID cannot be blank"), + validate.NotNil(c.ClientFactory, "client factory cannot be nil"), + validate.NotNil(c.MessageSerializer, "message serializer cannot be nil"), + validate.NotNil(c.MessageRegistry, "message registry cannot be nil"), + validate.NotNil(c.NodeInfoProvider, "node info provider cannot be nil"), + validate.NotNil(c.DataPlaneMessageHandler, "data plane message handler cannot be nil"), + validate.NotNil(c.DataPlaneMessageCreator, "data plane message creator cannot be nil"), + + // validations for timing configs + validate.IsGreaterThanZero(c.HeartbeatInterval, "heartbeat interval must be positive"), + validate.IsGreaterThanZero(c.HeartbeatMissFactor, "heartbeat miss factor must be positive"), + validate.IsGreaterThanZero(c.NodeInfoUpdateInterval, "node info update interval must be positive"), + validate.IsGreaterThanZero(c.RequestTimeout, "request timeout must be positive"), + validate.IsGreaterThanZero(c.ReconnectInterval, "reconnect interval must be positive"), + validate.IsGreaterThanZero(c.CheckpointInterval, "checkpoint interval must be positive"), + + // validations for data plane components + validate.NotNil(c.EventStore, "event store cannot be nil"), + validate.NotNil(c.ReconnectBackoff, "backoff cannot be nil"), + validate.NotNil(c.Checkpointer, "checkpointer cannot be nil"), + + // Validate dispatcher config + c.DispatcherConfig.Validate(), + ) +} + +// DefaultConfig returns a new Config with default values +func DefaultConfig() Config { + // defaults for heartbeatInterval and nodeInfoUpdateInterval are provided by BacalhauConfig, + // and equal to 15 seconds and 1 minute respectively + return Config{ + HeartbeatMissFactor: 5, // allow up to 5 missed heartbeats before marking a node as disconnected + RequestTimeout: 10 * time.Second, + ReconnectInterval: 10 * time.Second, + 
CheckpointInterval: 30 * time.Second, + ReconnectBackoff: backoff.NewExponential(10*time.Second, 2*time.Minute), + MessageSerializer: envelope.NewSerializer(), + MessageRegistry: nclprotocol.MustCreateMessageRegistry(), + DispatcherConfig: dispatcher.DefaultConfig(), + Clock: clock.New(), + } +} + +func (c *Config) setDefaults() { + defaults := DefaultConfig() + if c.HeartbeatMissFactor == 0 { + c.HeartbeatMissFactor = defaults.HeartbeatMissFactor + } + if c.RequestTimeout == 0 { + c.RequestTimeout = defaults.RequestTimeout + } + if c.ReconnectInterval == 0 { + c.ReconnectInterval = defaults.ReconnectInterval + } + if c.CheckpointInterval == 0 { + c.CheckpointInterval = defaults.CheckpointInterval + } + if c.MessageSerializer == nil { + c.MessageSerializer = defaults.MessageSerializer + } + if c.MessageRegistry == nil { + c.MessageRegistry = defaults.MessageRegistry + } + if c.ReconnectBackoff == nil { + c.ReconnectBackoff = defaults.ReconnectBackoff + } + if c.DispatcherConfig == (dispatcher.Config{}) { + c.DispatcherConfig = defaults.DispatcherConfig + } + if c.Clock == nil { + c.Clock = defaults.Clock + } +} diff --git a/pkg/transport/nclprotocol/compute/controlplane.go b/pkg/transport/nclprotocol/compute/controlplane.go new file mode 100644 index 0000000000..18e9ab1722 --- /dev/null +++ b/pkg/transport/nclprotocol/compute/controlplane.go @@ -0,0 +1,240 @@ +package compute + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/rs/zerolog/log" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/models" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" +) + +// ControlPlane manages the periodic control operations between a compute node and +// the orchestrator. 
It is responsible for: +// - Sending periodic heartbeats to indicate node health +// - Updating node information when changes occur +// - Maintaining checkpoints of message processing progress +type ControlPlane struct { + cfg Config // Global configuration for the control plane + + // Core components + requester ncl.Publisher // Used to send messages to orchestrator + healthTracker *HealthTracker // Tracks node health status + incomingSeqTracker *nclprotocol.SequenceTracker // Tracks processed message sequences + checkpointName string // Identifier for checkpoint storage + + // State tracking + latestNodeInfo models.NodeInfo // Cache of most recent node information + lastCheckpoint uint64 // Last checkpointed sequence number + + // Lifecycle management + stopCh chan struct{} // Signals background goroutines to stop + wg sync.WaitGroup // Tracks active background goroutines + mu sync.RWMutex // Protects state changes + running bool // Indicates if control plane is active +} + +// ControlPlaneParams encapsulates all dependencies needed to create a new ControlPlane +type ControlPlaneParams struct { + Config Config + Requester ncl.Publisher // For sending control messages + HealthTracker *HealthTracker // For health monitoring + IncomingSeqTracker *nclprotocol.SequenceTracker // For sequence tracking + CheckpointName string // For checkpoint identification +} + +// NewControlPlane creates a new ControlPlane instance with the provided parameters. +// It initializes the control plane but does not start any background operations. +func NewControlPlane(params ControlPlaneParams) (*ControlPlane, error) { + return &ControlPlane{ + cfg: params.Config, + requester: params.Requester, + healthTracker: params.HealthTracker, + incomingSeqTracker: params.IncomingSeqTracker, + checkpointName: params.CheckpointName, + lastCheckpoint: params.IncomingSeqTracker.GetLastSeqNum(), + stopCh: make(chan struct{}), + }, nil +} + +// Start begins the control plane operations. 
It launches a background goroutine +// that manages periodic tasks: +// - Heartbeat sending +// - Node info updates +// - Progress checkpointing +func (cp *ControlPlane) Start(ctx context.Context) error { + cp.mu.Lock() + defer cp.mu.Unlock() + + if cp.running { + return fmt.Errorf("control plane already running") + } + + cp.wg.Add(1) + go cp.run(ctx) + + cp.running = true + return nil +} + +// run is the main control loop that manages periodic operations. +// It uses separate timers for each operation type to ensure consistent intervals. +func (cp *ControlPlane) run(ctx context.Context) { + defer cp.wg.Done() + + // Initialize timers for periodic operations + heartbeat := time.NewTimer(cp.cfg.HeartbeatInterval) + nodeInfo := time.NewTimer(cp.cfg.NodeInfoUpdateInterval) + checkpoint := time.NewTimer(cp.cfg.CheckpointInterval) + + defer func() { + heartbeat.Stop() + nodeInfo.Stop() + checkpoint.Stop() + }() + + for { + select { + case <-ctx.Done(): + return + case <-cp.stopCh: + return + + case <-heartbeat.C: + if err := cp.heartbeat(ctx); err != nil { + log.Error().Err(err).Msg("Failed to send heartbeat") + } + heartbeat.Reset(cp.cfg.HeartbeatInterval) + + case <-nodeInfo.C: + if err := cp.updateNodeInfo(ctx); err != nil { + log.Error().Err(err).Msg("Failed to update node info") + } + nodeInfo.Reset(cp.cfg.NodeInfoUpdateInterval) + + case <-checkpoint.C: + if err := cp.checkpointProgress(ctx); err != nil { + log.Error().Err(err).Msg("Failed to checkpoint progress") + } + checkpoint.Reset(cp.cfg.CheckpointInterval) + } + } +} + +// heartbeat sends a heartbeat message to the orchestrator to indicate the node is alive +// and healthy. It includes: +// - Current available compute capacity +// - Queue usage information +// - Latest processed message sequence number +// Updates health tracking on successful heartbeat. 
+func (cp *ControlPlane) heartbeat(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, cp.cfg.RequestTimeout) + defer cancel() + + // Get latest node info for capacity reporting + nodeInfo := cp.cfg.NodeInfoProvider.GetNodeInfo(ctx) + cp.latestNodeInfo = nodeInfo + + msg := envelope.NewMessage(messages.HeartbeatRequest{ + NodeID: cp.latestNodeInfo.NodeID, + AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, + QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, + LastOrchestratorSeqNum: cp.incomingSeqTracker.GetLastSeqNum(), + }).WithMetadataValue(envelope.KeyMessageType, messages.HeartbeatRequestMessageType) + + response, err := cp.requester.Request(ctx, ncl.NewPublishRequest(msg)) + if err != nil { + return fmt.Errorf("heartbeat request failed: %w", err) + } + + payload, ok := response.GetPayload(messages.HeartbeatResponse{}) + if !ok { + return fmt.Errorf("invalid heartbeat response payload. expected messages.HeartbeatResponse, got %T", payload) + } + + cp.healthTracker.HeartbeatSuccess() + return nil +} + +// updateNodeInfo checks for changes in node information and sends updates to the +// orchestrator when changes are detected. This includes changes to: +// - Node capacity +// - Supported features +// - Configuration +// - Labels +// Updates health tracking on successful updates. 
+func (cp *ControlPlane) updateNodeInfo(ctx context.Context) error { + // Only send updates when node info has changed + prevNodeInfo := cp.latestNodeInfo + cp.latestNodeInfo = cp.cfg.NodeInfoProvider.GetNodeInfo(ctx) + if !models.HasNodeInfoChanged(prevNodeInfo, cp.latestNodeInfo) { + return nil + } + + log.Debug().Msg("Node info changed, sending update") + + ctx, cancel := context.WithTimeout(ctx, cp.cfg.RequestTimeout) + defer cancel() + + msg := envelope.NewMessage(messages.UpdateNodeInfoRequest{ + NodeInfo: cp.latestNodeInfo, + }).WithMetadataValue(envelope.KeyMessageType, messages.NodeInfoUpdateRequestMessageType) + + _, err := cp.requester.Request(ctx, ncl.NewPublishRequest(msg)) + if err != nil { + return fmt.Errorf("node info update request failed: %w", err) + } + + cp.healthTracker.UpdateSuccess() + return nil +} + +// checkpointProgress saves the latest processed message sequence number if it has +// changed since the last checkpoint. This allows for resuming message processing +// from the last known point after node restarts. +func (cp *ControlPlane) checkpointProgress(ctx context.Context) error { + newCheckpoint := cp.incomingSeqTracker.GetLastSeqNum() + if newCheckpoint == cp.lastCheckpoint { + return nil + } + if err := cp.cfg.Checkpointer.Checkpoint(ctx, cp.checkpointName, newCheckpoint); err != nil { + log.Error().Err(err).Msg("failed to checkpoint incoming sequence number") + } else { + cp.lastCheckpoint = newCheckpoint + } + return nil +} + +// Stop gracefully shuts down the control plane and waits for background operations +// to complete or until the context is cancelled. 
+func (cp *ControlPlane) Stop(ctx context.Context) error { + cp.mu.Lock() + if !cp.running { + cp.mu.Unlock() + return nil + } + + cp.running = false + close(cp.stopCh) + cp.mu.Unlock() + + // Wait for graceful shutdown + done := make(chan struct{}) + go func() { + cp.wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/pkg/transport/nclprotocol/compute/dataplane.go b/pkg/transport/nclprotocol/compute/dataplane.go new file mode 100644 index 0000000000..2c5ade09f8 --- /dev/null +++ b/pkg/transport/nclprotocol/compute/dataplane.go @@ -0,0 +1,183 @@ +package compute + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/nats-io/nats.go" + "github.com/rs/zerolog/log" + + "github.com/bacalhau-project/bacalhau/pkg/compute" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + "github.com/bacalhau-project/bacalhau/pkg/nats/proxy" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" +) + +// watcherID is the unique identifier for the data plane event watcher +const watcherID = "compute-ncl-dispatcher" + +// DataPlane manages the data transfer operations between a compute node and the orchestrator. 
+// It is responsible for: +// - Setting up and managing the log streaming server +// - Reliable message publishing through ordered publisher +// - Event watching and dispatching +// - Maintaining message sequence ordering +type DataPlane struct { + config Config // Global configuration + + // Core messaging components + Client *nats.Conn // NATS connection for messaging + publisher ncl.OrderedPublisher // Handles ordered message publishing + dispatcher *dispatcher.Dispatcher // Manages event watching and dispatch + + // Sequence tracking + lastReceivedSeqNum uint64 // Last sequence number received from orchestrator + + // State management + mu sync.RWMutex // Protects state changes + running bool // Indicates if data plane is active +} + +// DataPlaneParams encapsulates the parameters needed to create a new DataPlane +type DataPlaneParams struct { + Config Config + Client *nats.Conn // NATS client connection + LastReceivedSeqNum uint64 // Initial sequence number for message ordering +} + +// NewDataPlane creates a new DataPlane instance with the provided parameters. +// It initializes the data plane but does not start any operations - Start() must be called. +func NewDataPlane(params DataPlaneParams) (*DataPlane, error) { + dp := &DataPlane{ + config: params.Config, + Client: params.Client, + lastReceivedSeqNum: params.LastReceivedSeqNum, + } + return dp, nil +} + +// Start initializes and begins data plane operations. This includes: +// 1. Setting up the log stream server for job output streaming +// 2. Creating an ordered publisher for reliable message delivery +// 3. Setting up event watching and dispatching +// 4. Starting the dispatcher +// +// Note that message subscriber and handler are not started here, as they must be started +// during the handshake and before the data plane is started to avoid message loss. +// +// If any component fails to initialize, cleanup is performed before returning error. 
+func (dp *DataPlane) Start(ctx context.Context) error { + dp.mu.Lock() + defer dp.mu.Unlock() + + if dp.running { + return fmt.Errorf("data plane already running") + } + + var err error + defer func() { + if err != nil { + if cleanupErr := dp.cleanup(ctx); cleanupErr != nil { + log.Warn().Err(cleanupErr).Msg("failed to cleanup after start error") + } + } + }() + + // Set up log streaming for job output + _, err = proxy.NewLogStreamHandler(ctx, proxy.LogStreamHandlerParams{ + Name: dp.config.NodeID, + Conn: dp.Client, + LogstreamServer: dp.config.LogStreamServer, + }) + + // Initialize ordered publisher for reliable message delivery + dp.publisher, err = ncl.NewOrderedPublisher(dp.Client, ncl.OrderedPublisherConfig{ + Name: dp.config.NodeID, + MessageRegistry: dp.config.MessageRegistry, + MessageSerializer: dp.config.MessageSerializer, + Destination: nclprotocol.NatsSubjectComputeOutMsgs(dp.config.NodeID), + }) + if err != nil { + return fmt.Errorf("failed to create publisher: %w", err) + } + + // Create event watcher starting from last known sequence + var dispatcherWatcher watcher.Watcher + dispatcherWatcher, err = watcher.New(ctx, watcherID, dp.config.EventStore, + watcher.WithRetryStrategy(watcher.RetryStrategyBlock), + watcher.WithInitialEventIterator(watcher.AfterSequenceNumberIterator(dp.lastReceivedSeqNum)), + watcher.WithFilter(watcher.EventFilter{ + ObjectTypes: []string{compute.EventObjectExecutionUpsert}, + }), + ) + if err != nil { + return fmt.Errorf("failed to create dispatcher watcher: %w", err) + } + + // Initialize dispatcher to handle event watching and publishing + dp.dispatcher, err = dispatcher.New( + dp.publisher, + dispatcherWatcher, + dp.config.DataPlaneMessageCreator, + dp.config.DispatcherConfig, + ) + if err != nil { + return fmt.Errorf("failed to create dispatcher: %w", err) + } + + // Start the dispatcher + if err = dp.dispatcher.Start(ctx); err != nil { + return fmt.Errorf("failed to start dispatcher: %w", err) + } + + dp.running = 
true + return nil +} + +// Stop gracefully shuts down all data plane operations. +// It ensures proper cleanup of resources by: +// 1. Stopping the dispatcher +// 2. Closing the publisher +// Any errors during cleanup are collected and returned. +func (dp *DataPlane) Stop(ctx context.Context) error { + dp.mu.Lock() + defer dp.mu.Unlock() + + if !dp.running { + return nil + } + + dp.running = false + return dp.cleanup(ctx) +} + +// cleanup handles the orderly shutdown of data plane components. +// It ensures resources are released in the correct order and collects any errors. +func (dp *DataPlane) cleanup(ctx context.Context) error { + var errs error + + // Stop dispatcher first to prevent new messages + if dp.dispatcher != nil { + if err := dp.dispatcher.Stop(ctx); err != nil { + errs = errors.Join(errs, err) + } + dp.dispatcher = nil + } + + // Then close the publisher + if dp.publisher != nil { + if err := dp.publisher.Close(ctx); err != nil { + errs = errors.Join(errs, err) + } + dp.publisher = nil + } + + if errs != nil { + return fmt.Errorf("failed to cleanup data plane: %w", errs) + } + return nil +} diff --git a/pkg/transport/nclprotocol/compute/errors.go b/pkg/transport/nclprotocol/compute/errors.go new file mode 100644 index 0000000000..8228b45be5 --- /dev/null +++ b/pkg/transport/nclprotocol/compute/errors.go @@ -0,0 +1,3 @@ +package compute + +const errComponent = "ComputeConnection" diff --git a/pkg/transport/nclprotocol/compute/health_tracker.go b/pkg/transport/nclprotocol/compute/health_tracker.go new file mode 100644 index 0000000000..05c8b99d88 --- /dev/null +++ b/pkg/transport/nclprotocol/compute/health_tracker.go @@ -0,0 +1,78 @@ +package compute + +import ( + "sync" + + "github.com/benbjohnson/clock" + + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" +) + +// HealthTracker monitors connection health and maintains status metrics. +// Thread-safe and uses an injectable clock for testing. 
+type HealthTracker struct { + health nclprotocol.ConnectionHealth + mu sync.RWMutex + clock clock.Clock +} + +// NewHealthTracker creates a new health tracker with the given clock +func NewHealthTracker(clock clock.Clock) *HealthTracker { + return &HealthTracker{ + health: nclprotocol.ConnectionHealth{ + StartTime: clock.Now(), + }, + clock: clock, + } +} + +// MarkConnected updates status when connection is established +func (ht *HealthTracker) MarkConnected() { + ht.mu.Lock() + defer ht.mu.Unlock() + + ht.health.CurrentState = nclprotocol.Connected + ht.health.ConnectedSince = ht.clock.Now() + ht.health.LastSuccessfulHeartbeat = ht.clock.Now() + ht.health.ConsecutiveFailures = 0 + ht.health.LastError = nil + ht.health.CurrentState = nclprotocol.Connected +} + +// MarkDisconnected updates status when connection is lost +func (ht *HealthTracker) MarkDisconnected(err error) { + ht.mu.Lock() + defer ht.mu.Unlock() + + ht.health.CurrentState = nclprotocol.Disconnected + ht.health.LastError = err + ht.health.ConsecutiveFailures++ +} + +// HeartbeatSuccess records successful heartbeat +func (ht *HealthTracker) HeartbeatSuccess() { + ht.mu.Lock() + defer ht.mu.Unlock() + ht.health.LastSuccessfulHeartbeat = ht.clock.Now() +} + +// UpdateSuccess records successful node info update +func (ht *HealthTracker) UpdateSuccess() { + ht.mu.Lock() + defer ht.mu.Unlock() + ht.health.LastSuccessfulUpdate = ht.clock.Now() +} + +// GetState returns current connection state +func (ht *HealthTracker) GetState() nclprotocol.ConnectionState { + ht.mu.RLock() + defer ht.mu.RUnlock() + return ht.health.CurrentState +} + +// GetHealth returns a copy of current health status +func (ht *HealthTracker) GetHealth() nclprotocol.ConnectionHealth { + ht.mu.RLock() + defer ht.mu.RUnlock() + return ht.health +} diff --git a/pkg/transport/nclprotocol/compute/manager.go b/pkg/transport/nclprotocol/compute/manager.go new file mode 100644 index 0000000000..90183d312e --- /dev/null +++ 
b/pkg/transport/nclprotocol/compute/manager.go @@ -0,0 +1,498 @@ +package compute + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/rs/zerolog/log" + + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" +) + +const stateChangesBuffer = 32 + +// ConnectionManager handles the lifecycle of a compute node's connection to the orchestrator. +// It manages the complete connection lifecycle including: +// - Initial connection and handshake +// - Connection health monitoring +// - Automated reconnection with backoff +// - Control and data plane management +// - Connection state transitions +type ConnectionManager struct { + // Configuration for the connection manager + config Config + + // Active NATS connection + natsConn *nats.Conn + + // Core messaging components + subscriber ncl.Subscriber // Handles incoming data plane messages + controlPlane *ControlPlane // Manages periodic operations when connected + dataPlane *DataPlane // Handles outgoing message dispatch + + // Checkpointing configuration + incomingCheckpointName string // Name used for checkpoint storage + incomingSeqTracker *nclprotocol.SequenceTracker // Tracks processed message sequences + + // Health monitoring + healthTracker *HealthTracker // Tracks connection health and state + + // Lifecycle management + running bool // Whether the manager is currently running + stopCh chan struct{} // Signals shutdown to background goroutines + wg sync.WaitGroup // Tracks active background goroutines + + // Event handling + stateHandlers []nclprotocol.ConnectionStateHandler // Callbacks for state transitions + stateHandlersMu sync.RWMutex + stateChanges chan stateChange // Channel for ordered state change notifications + mu sync.RWMutex 
// Protects shared state access +} + +type stateChange struct { + state nclprotocol.ConnectionState + err error +} + +// NewConnectionManager creates a new connection manager with the given configuration. +// It initializes the manager but does not start any connections - Start() must be called. +func NewConnectionManager(cfg Config) (*ConnectionManager, error) { + cfg.setDefaults() + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + cm := &ConnectionManager{ + config: cfg, + healthTracker: NewHealthTracker(cfg.Clock), + incomingCheckpointName: fmt.Sprintf("incoming-%s", cfg.NodeID), + stopCh: make(chan struct{}), + stateChanges: make(chan stateChange, stateChangesBuffer), // buffered to avoid blocking + } + + return cm, nil +} + +// Start begins the connection management process. It launches background goroutines for: +// - Connection maintenance +// - Heartbeat sending +// - Node info updates +func (cm *ConnectionManager) Start(ctx context.Context) error { + cm.mu.Lock() + defer cm.mu.Unlock() + + if cm.running { + return bacerrors.New("connection manager already running"). + WithCode(bacerrors.BadRequestError). + WithComponent(errComponent) + } + + log.Info(). + Str("node_id", cm.config.NodeID). + Time("start_time", cm.healthTracker.GetHealth().StartTime). 
+ Msg("Starting connection manager") + + // initialize sequence tracker + checkpoint, err := cm.config.Checkpointer.GetCheckpoint(ctx, cm.incomingCheckpointName) + if err != nil { + return fmt.Errorf("failed to get last checkpoint: %w", err) + } + cm.incomingSeqTracker = nclprotocol.NewSequenceTracker().WithLastSeqNum(checkpoint) + cm.config.NodeInfoProvider.GetNodeInfo(ctx) + + // create new channels in case the connection manager is restarted + cm.stopCh = make(chan struct{}) + cm.stateChanges = make(chan stateChange, stateChangesBuffer) + + // Start connection management in background + cm.wg.Add(1) + go cm.maintainConnection(context.TODO()) + + // Start state change notification handler + cm.wg.Add(1) + go cm.handleStateChanges() + + cm.running = true + return nil +} + +// Close gracefully shuts down the connection manager and all its components. +// It waits for background goroutines to complete or until the context is cancelled. +func (cm *ConnectionManager) Close(ctx context.Context) error { + cm.mu.Lock() + if !cm.running { + cm.mu.Unlock() + return nil + } + cm.running = false + close(cm.stopCh) + close(cm.stateChanges) + cm.mu.Unlock() + + // Wait for graceful shutdown + done := make(chan struct{}) + go func() { + cm.wg.Wait() + close(done) + }() + + select { + case <-done: + cm.cleanup(ctx) + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// cleanup performs orderly cleanup of connection manager components: +// 1. Stops the data plane +// 2. Cleans up control plane +// 3. 
Closes NATS connection +func (cm *ConnectionManager) cleanup(ctx context.Context) { + // Clean up data plane subscriber + if cm.subscriber != nil { + if err := cm.subscriber.Close(ctx); err != nil { + log.Error().Err(err).Msg("Failed to close subscriber") + } + cm.subscriber = nil + } + + // Clean up data plane + if cm.dataPlane != nil { + if err := cm.dataPlane.Stop(ctx); err != nil { + log.Error().Err(err).Msg("Failed to stop data plane") + } + cm.dataPlane = nil + } + + // Clean up control plane + if cm.controlPlane != nil { + if err := cm.controlPlane.Stop(ctx); err != nil { + log.Error().Err(err).Msg("Failed to stop control plane") + } + cm.controlPlane = nil + } + + // Clean up NATS connection last + if cm.natsConn != nil { + cm.natsConn.Close() + cm.natsConn = nil + } +} + +// connect attempts to establish a connection to the orchestrator. It follows these steps: +// 1. Creates NATS connection and transport components +// 2. Performs initial handshake with orchestrator +// 3. Sets up and starts control and data planes +func (cm *ConnectionManager) connect(ctx context.Context) error { + cm.mu.Lock() + defer cm.mu.Unlock() + + if cm.getState() == nclprotocol.Connected { + return nil + } + + log.Info().Str("node_id", cm.config.NodeID).Msg("Attempting to establish connection") + cm.transitionState(nclprotocol.Connecting, nil) + + var err error + defer func() { + if err != nil { + cm.cleanup(ctx) + cm.transitionState(nclprotocol.Disconnected, err) + } + }() + + if err = cm.setupTransport(ctx); err != nil { + return fmt.Errorf("failed to setup transport: %w", err) + } + + if err = cm.setupSubscriber(ctx); err != nil { + return fmt.Errorf("failed to setup subscriber: %w", err) + } + + requester, err := cm.setupRequester(ctx) + if err != nil { + return fmt.Errorf("failed to setup requester: %w", err) + } + + handshakeResponse, err := cm.performHandshake(ctx, requester) + if err != nil { + return fmt.Errorf("handshake failed: %w", err) + } + + if err = 
cm.setupControlPlane(ctx, requester); err != nil { + return fmt.Errorf("failed to setup control plane: %w", err) + } + + if err = cm.setupDataPlane(ctx, handshakeResponse); err != nil { + return fmt.Errorf("failed to setup data plane: %w", err) + } + + cm.transitionState(nclprotocol.Connected, nil) + return nil +} + +// setupTransport creates the NATS connection +func (cm *ConnectionManager) setupTransport(ctx context.Context) error { + var err error + cm.natsConn, err = cm.config.ClientFactory.CreateClient(ctx) + if err != nil { + return fmt.Errorf("failed to connect to NATS: %w", err) + } + return nil +} + +// setupRequester creates the control plane publisher +func (cm *ConnectionManager) setupRequester(ctx context.Context) (ncl.Publisher, error) { + return ncl.NewPublisher(cm.natsConn, ncl.PublisherConfig{ + Name: cm.config.NodeID, + Destination: nclprotocol.NatsSubjectComputeOutCtrl(cm.config.NodeID), + MessageSerializer: cm.config.MessageSerializer, + MessageRegistry: cm.config.MessageRegistry, + }) +} + +// setupSubscriber creates and starts the data plane message subscriber +func (cm *ConnectionManager) setupSubscriber(ctx context.Context) error { + var err error + cm.subscriber, err = ncl.NewSubscriber(cm.natsConn, ncl.SubscriberConfig{ + Name: cm.config.NodeID, + MessageRegistry: cm.config.MessageRegistry, + MessageSerializer: cm.config.MessageSerializer, + MessageHandler: cm.config.DataPlaneMessageHandler, + ProcessingNotifier: cm.incomingSeqTracker, + }) + if err != nil { + return fmt.Errorf("failed to create subscriber: %w", err) + } + + if err = cm.subscriber.Subscribe(ctx, nclprotocol.NatsSubjectComputeInMsgs(cm.config.NodeID)); err != nil { + return fmt.Errorf("failed to subscribe: %w", err) + } + return nil +} + +// performHandshake executes the initial handshake with the orchestrator +// sending node information and start time +func (cm *ConnectionManager) performHandshake( + ctx context.Context, requester ncl.Publisher) 
(messages.HandshakeResponse, error) { + ctx, cancel := context.WithTimeout(ctx, cm.config.RequestTimeout) + defer cancel() + + handshake := messages.HandshakeRequest{ + NodeInfo: cm.config.NodeInfoProvider.GetNodeInfo(ctx), + StartTime: cm.GetHealth().StartTime, + LastOrchestratorSeqNum: cm.incomingSeqTracker.GetLastSeqNum(), + } + + // Send handshake + msg := envelope.NewMessage(handshake). + WithMetadataValue(envelope.KeyMessageType, messages.HandshakeRequestMessageType) + + response, err := requester.Request(ctx, ncl.NewPublishRequest(msg)) + if err != nil { + return messages.HandshakeResponse{}, fmt.Errorf("handshake request failed: %w", err) + } + + payload, ok := response.GetPayload(messages.HandshakeResponse{}) + if !ok { + return messages.HandshakeResponse{}, fmt.Errorf( + "invalid handshake response payload. expected messages.HandshakeResponse, got %T", payload) + } + + return payload.(messages.HandshakeResponse), nil +} + +// setupControlPlane creates and starts the control plane +func (cm *ConnectionManager) setupControlPlane(ctx context.Context, requester ncl.Publisher) error { + var err error + cm.controlPlane, err = NewControlPlane(ControlPlaneParams{ + Config: cm.config, + Requester: requester, + HealthTracker: cm.healthTracker, + IncomingSeqTracker: cm.incomingSeqTracker, + CheckpointName: cm.incomingCheckpointName, + }) + if err != nil { + return fmt.Errorf("failed to create control plane: %w", err) + } + + if err = cm.controlPlane.Start(ctx); err != nil { + return fmt.Errorf("failed to start control plane: %w", err) + } + + return nil +} + +// setupDataPlane creates and starts the data plane +func (cm *ConnectionManager) setupDataPlane(ctx context.Context, handshake messages.HandshakeResponse) error { + var err error + cm.dataPlane, err = NewDataPlane(DataPlaneParams{ + Config: cm.config, + Client: cm.natsConn, + LastReceivedSeqNum: handshake.LastComputeSeqNum, + }) + if err != nil { + return fmt.Errorf("failed to create data plane: %w", err) + } 
+ + if err = cm.dataPlane.Start(ctx); err != nil { + return fmt.Errorf("failed to start data plane: %w", err) + } + + return nil +} + +// maintainConnection runs a periodic loop that manages the connection lifecycle. +// It handles initial connection, health monitoring, and reconnection with backoff. +func (cm *ConnectionManager) maintainConnection(ctx context.Context) { + defer cm.wg.Done() + + // Create timer that fires immediately for first connection + timer := time.NewTimer(0) + defer timer.Stop() + + for { + select { + case <-cm.stopCh: + return + + case <-timer.C: + switch cm.getState() { + case nclprotocol.Disconnected: + if err := cm.connect(ctx); err != nil { + failures := cm.GetHealth().ConsecutiveFailures + backoffDuration := cm.config.ReconnectBackoff.BackoffDuration(failures) + + log.Error(). + Err(err). + Int("consecutiveFailures", failures). + Str("backoffDuration", backoffDuration.String()). + Msg("Connection attempt failed") + + cm.config.ReconnectBackoff.Backoff(ctx, failures) + } + + case nclprotocol.Connected: + cm.checkConnectionHealth() + + default: + // Ignore other states, such as connecting + } + + // Reset timer for next interval + timer.Reset(cm.config.ReconnectInterval) + } + } +} + +// checkConnectionHealth verifies the connection is healthy by checking: +// - Recent heartbeat activity +// - NATS connection status +func (cm *ConnectionManager) checkConnectionHealth() { + cm.mu.Lock() + defer cm.mu.Unlock() + + if cm.getState() != nclprotocol.Connected { + return + } + + // Consider connection unhealthy if: + // 1. No heartbeat succeeded within HeartbeatMissFactor intervals + // 2. 
NATS connection is closed/draining + now := cm.config.Clock.Now() + heartbeatDeadline := now.Add(-time.Duration(cm.config.HeartbeatMissFactor) * cm.config.HeartbeatInterval) + + var reason string + var unhealthy bool + if cm.GetHealth().LastSuccessfulHeartbeat.Before(heartbeatDeadline) { + reason = fmt.Sprintf("no heartbeat for %d intervals", cm.config.HeartbeatMissFactor) + unhealthy = true + } else if cm.natsConn.IsClosed() { + reason = "NATS connection closed" + unhealthy = true + } + + if unhealthy { + log.Warn(). + Time("lastHeartbeat", cm.GetHealth().LastSuccessfulHeartbeat). + Time("deadline", heartbeatDeadline). + Int("heartbeatMissFactor", cm.config.HeartbeatMissFactor). + Str("reason", reason). + Msg("Connection unhealthy, initiating reconnect") + cm.transitionState(nclprotocol.Disconnected, fmt.Errorf("connection unhealthy: %s", reason)) + } +} + +// transitionState handles state transitions between Connected/Disconnected/Connecting. +// It updates health metrics and notifies registered state change handlers. +func (cm *ConnectionManager) transitionState(newState nclprotocol.ConnectionState, err error) { + oldState := cm.getState() + if oldState == newState { + return + } + + // Update state tracking + if newState == nclprotocol.Connected { + cm.healthTracker.MarkConnected() + } else if newState == nclprotocol.Disconnected { + cm.healthTracker.MarkDisconnected(err) + } + + // Queue state change notification + select { + case cm.stateChanges <- stateChange{state: newState, err: err}: + log.Debug(). + Str("oldState", oldState.String()). + Str("newState", newState.String()). + Err(err). 
+ Msg("Connection state changed") + default: + log.Error().Msg("State change notification channel full") + } +} + +func (cm *ConnectionManager) handleStateChanges() { + defer cm.wg.Done() + + for change := range cm.stateChanges { + cm.stateHandlersMu.RLock() + handlers := make([]nclprotocol.ConnectionStateHandler, len(cm.stateHandlers)) + copy(handlers, cm.stateHandlers) + cm.stateHandlersMu.RUnlock() + + for _, handler := range handlers { + handler(change.state) + } + } +} + +// OnStateChange registers a new handler to be called when the connection +// state changes. Handlers are called synchronously when state transitions occur. +func (cm *ConnectionManager) OnStateChange(handler nclprotocol.ConnectionStateHandler) { + cm.stateHandlersMu.Lock() + defer cm.stateHandlersMu.Unlock() + cm.stateHandlers = append(cm.stateHandlers, handler) +} + +// GetHealth returns the current health status of the connection including: +// - Timestamps of last successful operations +// - Current connection state +// - Error counts and details +func (cm *ConnectionManager) GetHealth() nclprotocol.ConnectionHealth { + return cm.healthTracker.GetHealth() +} + +// getState returns the current connection state. 
+func (cm *ConnectionManager) getState() nclprotocol.ConnectionState { + return cm.GetHealth().CurrentState +} diff --git a/pkg/transport/dispatcher/config.go b/pkg/transport/nclprotocol/dispatcher/config.go similarity index 97% rename from pkg/transport/dispatcher/config.go rename to pkg/transport/nclprotocol/dispatcher/config.go index ba97783218..f26e8fb1bb 100644 --- a/pkg/transport/dispatcher/config.go +++ b/pkg/transport/nclprotocol/dispatcher/config.go @@ -9,7 +9,7 @@ import ( const ( // Default checkpoint settings - defaultCheckpointInterval = 5 * time.Second + defaultCheckpointInterval = 30 * time.Second defaultCheckpointTimeout = 5 * time.Second // Default stall detection settings @@ -30,6 +30,7 @@ const ( type Config struct { // CheckpointInterval determines how often the dispatcher saves its progress. // Lower values provide better durability at the cost of more IO operations. + // Negative values disable checkpointing. // Default: 5 seconds CheckpointInterval time.Duration @@ -92,7 +93,6 @@ func DefaultConfig() Config { func (c *Config) Validate() error { return errors.Join( // Intervals must be positive - validate.IsGreaterThanZero(c.CheckpointInterval, "CheckpointInterval must be positive"), validate.IsGreaterThanZero(c.ProcessInterval, "ProcessInterval must be positive"), validate.IsGreaterThanZero(c.StallCheckInterval, "StallCheckInterval must be positive"), diff --git a/pkg/transport/dispatcher/config_test.go b/pkg/transport/nclprotocol/dispatcher/config_test.go similarity index 100% rename from pkg/transport/dispatcher/config_test.go rename to pkg/transport/nclprotocol/dispatcher/config_test.go diff --git a/pkg/transport/dispatcher/constants.go b/pkg/transport/nclprotocol/dispatcher/constants.go similarity index 100% rename from pkg/transport/dispatcher/constants.go rename to pkg/transport/nclprotocol/dispatcher/constants.go diff --git a/pkg/transport/dispatcher/dispatcher.go b/pkg/transport/nclprotocol/dispatcher/dispatcher.go similarity index 
95% rename from pkg/transport/dispatcher/dispatcher.go rename to pkg/transport/nclprotocol/dispatcher/dispatcher.go index beb290a33c..7483e1d960 100644 --- a/pkg/transport/dispatcher/dispatcher.go +++ b/pkg/transport/nclprotocol/dispatcher/dispatcher.go @@ -14,7 +14,7 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/lib/validate" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/transport" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" ) // Dispatcher handles reliable delivery of events from a watcher to NATS. @@ -42,7 +42,7 @@ type Dispatcher struct { // Returns an error if any dependencies are nil or if config validation fails. func New(publisher ncl.OrderedPublisher, watcher watcher.Watcher, - messageCreator transport.MessageCreator, config Config) (*Dispatcher, error) { + messageCreator nclprotocol.MessageCreator, config Config) (*Dispatcher, error) { err := errors.Join( validate.NotNil(publisher, "publisher cannot be nil"), validate.NotNil(watcher, "watcher cannot be nil"), @@ -91,12 +91,17 @@ func (d *Dispatcher) Start(ctx context.Context) error { d.running = true d.mu.Unlock() - d.routinesWg.Add(3) // For the three goroutines - // Start background processing + d.routinesWg.Add(1) go d.processPublishResults(ctx) + + d.routinesWg.Add(1) go d.checkStalledMessages(ctx) - go d.checkpointLoop(ctx) + + if d.config.CheckpointInterval > 0 { + d.routinesWg.Add(1) + go d.checkpointLoop(ctx) + } // Start the watcher return d.watcher.Start(ctx) diff --git a/pkg/transport/dispatcher/dispatcher_e2e_test.go b/pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go similarity index 99% rename from pkg/transport/dispatcher/dispatcher_e2e_test.go rename to pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go index 9e0a633b69..fc5d8b37b9 100644 --- a/pkg/transport/dispatcher/dispatcher_e2e_test.go +++ 
b/pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go @@ -20,7 +20,7 @@ import ( watchertest "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/test" "github.com/bacalhau-project/bacalhau/pkg/logger" testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" - "github.com/bacalhau-project/bacalhau/pkg/transport/dispatcher" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" ) type DispatcherE2ETestSuite struct { diff --git a/pkg/transport/dispatcher/dispatcher_test.go b/pkg/transport/nclprotocol/dispatcher/dispatcher_test.go similarity index 96% rename from pkg/transport/dispatcher/dispatcher_test.go rename to pkg/transport/nclprotocol/dispatcher/dispatcher_test.go index a9da40996c..ce5496873a 100644 --- a/pkg/transport/dispatcher/dispatcher_test.go +++ b/pkg/transport/nclprotocol/dispatcher/dispatcher_test.go @@ -13,7 +13,7 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/transport" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" ) type DispatcherTestSuite struct { @@ -22,7 +22,7 @@ type DispatcherTestSuite struct { ctx context.Context publisher *ncl.MockOrderedPublisher watcher *watcher.MockWatcher - creator *transport.MockMessageCreator + creator *nclprotocol.MockMessageCreator config Config handler watcher.EventHandler } @@ -32,7 +32,7 @@ func (suite *DispatcherTestSuite) SetupTest() { suite.ctx = context.Background() suite.publisher = ncl.NewMockOrderedPublisher(suite.ctrl) suite.watcher = watcher.NewMockWatcher(suite.ctrl) - suite.creator = transport.NewMockMessageCreator(suite.ctrl) + suite.creator = nclprotocol.NewMockMessageCreator(suite.ctrl) suite.config = DefaultConfig() } diff --git a/pkg/transport/dispatcher/errors.go b/pkg/transport/nclprotocol/dispatcher/errors.go similarity index 100% rename from pkg/transport/dispatcher/errors.go rename to 
pkg/transport/nclprotocol/dispatcher/errors.go diff --git a/pkg/transport/dispatcher/handler.go b/pkg/transport/nclprotocol/dispatcher/handler.go similarity index 90% rename from pkg/transport/dispatcher/handler.go rename to pkg/transport/nclprotocol/dispatcher/handler.go index ee93fd5d28..91cc18986a 100644 --- a/pkg/transport/dispatcher/handler.go +++ b/pkg/transport/nclprotocol/dispatcher/handler.go @@ -8,16 +8,16 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/transport" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" ) type messageHandler struct { - creator transport.MessageCreator + creator nclprotocol.MessageCreator publisher ncl.OrderedPublisher state *dispatcherState } -func newMessageHandler(creator transport.MessageCreator, publisher ncl.OrderedPublisher, state *dispatcherState) *messageHandler { +func newMessageHandler(creator nclprotocol.MessageCreator, publisher ncl.OrderedPublisher, state *dispatcherState) *messageHandler { return &messageHandler{ creator: creator, publisher: publisher, diff --git a/pkg/transport/dispatcher/handler_test.go b/pkg/transport/nclprotocol/dispatcher/handler_test.go similarity index 95% rename from pkg/transport/dispatcher/handler_test.go rename to pkg/transport/nclprotocol/dispatcher/handler_test.go index da8de27069..c6bed41185 100644 --- a/pkg/transport/dispatcher/handler_test.go +++ b/pkg/transport/nclprotocol/dispatcher/handler_test.go @@ -13,14 +13,14 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/transport" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" ) type HandlerTestSuite struct { suite.Suite ctrl *gomock.Controller ctx context.Context - 
creator *transport.MockMessageCreator + creator *nclprotocol.MockMessageCreator publisher *ncl.MockOrderedPublisher state *dispatcherState handler *messageHandler @@ -29,7 +29,7 @@ type HandlerTestSuite struct { func (suite *HandlerTestSuite) SetupTest() { suite.ctrl = gomock.NewController(suite.T()) suite.ctx = context.Background() - suite.creator = transport.NewMockMessageCreator(suite.ctrl) + suite.creator = nclprotocol.NewMockMessageCreator(suite.ctrl) suite.publisher = ncl.NewMockOrderedPublisher(suite.ctrl) suite.state = newDispatcherState() suite.handler = newMessageHandler(suite.creator, suite.publisher, suite.state) diff --git a/pkg/transport/dispatcher/recovery.go b/pkg/transport/nclprotocol/dispatcher/recovery.go similarity index 100% rename from pkg/transport/dispatcher/recovery.go rename to pkg/transport/nclprotocol/dispatcher/recovery.go diff --git a/pkg/transport/dispatcher/recovery_test.go b/pkg/transport/nclprotocol/dispatcher/recovery_test.go similarity index 100% rename from pkg/transport/dispatcher/recovery_test.go rename to pkg/transport/nclprotocol/dispatcher/recovery_test.go diff --git a/pkg/transport/dispatcher/state.go b/pkg/transport/nclprotocol/dispatcher/state.go similarity index 98% rename from pkg/transport/dispatcher/state.go rename to pkg/transport/nclprotocol/dispatcher/state.go index f8cbdc53d3..2365a06231 100644 --- a/pkg/transport/dispatcher/state.go +++ b/pkg/transport/nclprotocol/dispatcher/state.go @@ -55,7 +55,7 @@ func (s *dispatcherState) getCheckpointSeqNum() uint64 { checkpointTarget = s.lastAckedSeqNum } - log.Debug().Uint64("lastCheckpoint", s.lastCheckpoint). + log.Trace().Uint64("lastCheckpoint", s.lastCheckpoint). Uint64("checkpointTarget", checkpointTarget). Uint64("lastObserved", s.lastObservedSeq). Uint64("lastAcked", s.lastAckedSeqNum). 
diff --git a/pkg/transport/dispatcher/state_test.go b/pkg/transport/nclprotocol/dispatcher/state_test.go similarity index 100% rename from pkg/transport/dispatcher/state_test.go rename to pkg/transport/nclprotocol/dispatcher/state_test.go diff --git a/pkg/transport/dispatcher/utils.go b/pkg/transport/nclprotocol/dispatcher/utils.go similarity index 100% rename from pkg/transport/dispatcher/utils.go rename to pkg/transport/nclprotocol/dispatcher/utils.go diff --git a/pkg/transport/nclprotocol/mocks.go b/pkg/transport/nclprotocol/mocks.go new file mode 100644 index 0000000000..dbc761c680 --- /dev/null +++ b/pkg/transport/nclprotocol/mocks.go @@ -0,0 +1,141 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: types.go + +// Package nclprotocol is a generated GoMock package. +package nclprotocol + +import ( + context "context" + reflect "reflect" + + envelope "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + watcher "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + gomock "go.uber.org/mock/gomock" +) + +// MockCheckpointer is a mock of Checkpointer interface. +type MockCheckpointer struct { + ctrl *gomock.Controller + recorder *MockCheckpointerMockRecorder +} + +// MockCheckpointerMockRecorder is the mock recorder for MockCheckpointer. +type MockCheckpointerMockRecorder struct { + mock *MockCheckpointer +} + +// NewMockCheckpointer creates a new mock instance. +func NewMockCheckpointer(ctrl *gomock.Controller) *MockCheckpointer { + mock := &MockCheckpointer{ctrl: ctrl} + mock.recorder = &MockCheckpointerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCheckpointer) EXPECT() *MockCheckpointerMockRecorder { + return m.recorder +} + +// Checkpoint mocks base method. 
+func (m *MockCheckpointer) Checkpoint(ctx context.Context, name string, sequenceNumber uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Checkpoint", ctx, name, sequenceNumber) + ret0, _ := ret[0].(error) + return ret0 +} + +// Checkpoint indicates an expected call of Checkpoint. +func (mr *MockCheckpointerMockRecorder) Checkpoint(ctx, name, sequenceNumber interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Checkpoint", reflect.TypeOf((*MockCheckpointer)(nil).Checkpoint), ctx, name, sequenceNumber) +} + +// GetCheckpoint mocks base method. +func (m *MockCheckpointer) GetCheckpoint(ctx context.Context, name string) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCheckpoint", ctx, name) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCheckpoint indicates an expected call of GetCheckpoint. +func (mr *MockCheckpointerMockRecorder) GetCheckpoint(ctx, name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCheckpoint", reflect.TypeOf((*MockCheckpointer)(nil).GetCheckpoint), ctx, name) +} + +// MockMessageCreator is a mock of MessageCreator interface. +type MockMessageCreator struct { + ctrl *gomock.Controller + recorder *MockMessageCreatorMockRecorder +} + +// MockMessageCreatorMockRecorder is the mock recorder for MockMessageCreator. +type MockMessageCreatorMockRecorder struct { + mock *MockMessageCreator +} + +// NewMockMessageCreator creates a new mock instance. +func NewMockMessageCreator(ctrl *gomock.Controller) *MockMessageCreator { + mock := &MockMessageCreator{ctrl: ctrl} + mock.recorder = &MockMessageCreatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMessageCreator) EXPECT() *MockMessageCreatorMockRecorder { + return m.recorder +} + +// CreateMessage mocks base method. 
+func (m *MockMessageCreator) CreateMessage(event watcher.Event) (*envelope.Message, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateMessage", event) + ret0, _ := ret[0].(*envelope.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateMessage indicates an expected call of CreateMessage. +func (mr *MockMessageCreatorMockRecorder) CreateMessage(event interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateMessage", reflect.TypeOf((*MockMessageCreator)(nil).CreateMessage), event) +} + +// MockMessageCreatorFactory is a mock of MessageCreatorFactory interface. +type MockMessageCreatorFactory struct { + ctrl *gomock.Controller + recorder *MockMessageCreatorFactoryMockRecorder +} + +// MockMessageCreatorFactoryMockRecorder is the mock recorder for MockMessageCreatorFactory. +type MockMessageCreatorFactoryMockRecorder struct { + mock *MockMessageCreatorFactory +} + +// NewMockMessageCreatorFactory creates a new mock instance. +func NewMockMessageCreatorFactory(ctrl *gomock.Controller) *MockMessageCreatorFactory { + mock := &MockMessageCreatorFactory{ctrl: ctrl} + mock.recorder = &MockMessageCreatorFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMessageCreatorFactory) EXPECT() *MockMessageCreatorFactoryMockRecorder { + return m.recorder +} + +// CreateMessageCreator mocks base method. +func (m *MockMessageCreatorFactory) CreateMessageCreator(ctx context.Context, nodeID string) MessageCreator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateMessageCreator", ctx, nodeID) + ret0, _ := ret[0].(MessageCreator) + return ret0 +} + +// CreateMessageCreator indicates an expected call of CreateMessageCreator. 
+func (mr *MockMessageCreatorFactoryMockRecorder) CreateMessageCreator(ctx, nodeID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateMessageCreator", reflect.TypeOf((*MockMessageCreatorFactory)(nil).CreateMessageCreator), ctx, nodeID) +} diff --git a/pkg/transport/nclprotocol/orchestrator/config.go b/pkg/transport/nclprotocol/orchestrator/config.go new file mode 100644 index 0000000000..337009f477 --- /dev/null +++ b/pkg/transport/nclprotocol/orchestrator/config.go @@ -0,0 +1,113 @@ +package orchestrator + +import ( + "errors" + "time" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/validate" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + natsutil "github.com/bacalhau-project/bacalhau/pkg/nats" + "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" +) + +// Config defines the configuration for the orchestrator's transport layer. +// It contains settings for both the control plane (node management, heartbeats) +// and data plane (message handling, event dispatching) components. 
+type Config struct { + // NodeID uniquely identifies this orchestrator instance + NodeID string + + // ClientFactory creates NATS clients for transport connections + ClientFactory natsutil.ClientFactory + + // NodeManager handles compute node lifecycle and state management + NodeManager nodes.Manager + + // Message serialization and type registration + MessageRegistry *envelope.Registry // Registry of message types for serialization + MessageSerializer envelope.MessageSerializer // Handles message envelope serialization + + // Control plane timeouts and intervals + HeartbeatTimeout time.Duration // Maximum time to wait for node heartbeat before considering it disconnected + NodeCleanupInterval time.Duration // How often to check for and cleanup disconnected nodes + RequestHandlerTimeout time.Duration // Timeout for handling individual control plane requests + + // Data plane configuration + DataPlaneMessageHandler ncl.MessageHandler // Handles incoming messages from compute nodes + DataPlaneMessageCreatorFactory nclprotocol.MessageCreatorFactory // Creates message creators for outgoing messages + EventStore watcher.EventStore // Store for watching and dispatching events + DispatcherConfig dispatcher.Config // Configuration for the event dispatcher +} + +// Validate checks if the configuration is valid by verifying: +// - Required fields are set +// - Timeouts and intervals are positive +// - Component configurations are valid +func (c *Config) Validate() error { + return errors.Join( + validate.NotBlank(c.NodeID, "node ID cannot be blank"), + validate.NotNil(c.ClientFactory, "client factory cannot be nil"), + validate.NotNil(c.NodeManager, "node manager cannot be nil"), + validate.NotNil(c.MessageRegistry, "message registry cannot be nil"), + validate.NotNil(c.MessageSerializer, "message serializer cannot be nil"), + validate.IsGreaterThanZero(c.HeartbeatTimeout, "heartbeat timeout must be positive"), + validate.IsGreaterThanZero(c.NodeCleanupInterval, "node 
cleanup interval must be positive"), + validate.IsGreaterThanZero(c.RequestHandlerTimeout, "request handler timeout must be positive"), + validate.NotNil(c.DataPlaneMessageHandler, "data plane message handler cannot be nil"), + validate.NotNil(c.DataPlaneMessageCreatorFactory, "data plane message creator factory cannot be nil"), + validate.NotNil(c.EventStore, "event store cannot be nil"), + + // Validate nested dispatcher config + c.DispatcherConfig.Validate(), + ) +} + +func DefaultConfig() Config { + return Config{ + // Default timeouts and intervals + HeartbeatTimeout: 2 * time.Minute, // Time before node considered disconnected + NodeCleanupInterval: 30 * time.Second, // Check for disconnected nodes every 30s + RequestHandlerTimeout: 2 * time.Second, // Individual request timeout + + // Default message handling + MessageSerializer: envelope.NewSerializer(), + MessageRegistry: nclprotocol.MustCreateMessageRegistry(), + + // Default dispatcher configuration + DispatcherConfig: dispatcher.DefaultConfig(), + } +} + +// setDefaults applies default values to any unset fields in the config. +// It does not override values that are already set. 
+func (c *Config) setDefaults() { + defaults := DefaultConfig() + + // Apply default timeouts if not set + if c.HeartbeatTimeout == 0 { + c.HeartbeatTimeout = defaults.HeartbeatTimeout + } + if c.NodeCleanupInterval == 0 { + c.NodeCleanupInterval = defaults.NodeCleanupInterval + } + if c.RequestHandlerTimeout == 0 { + c.RequestHandlerTimeout = defaults.RequestHandlerTimeout + } + + // Apply default message handling if not set + if c.MessageSerializer == nil { + c.MessageSerializer = defaults.MessageSerializer + } + if c.MessageRegistry == nil { + c.MessageRegistry = defaults.MessageRegistry + } + + // Apply default dispatcher config if not set + if c.DispatcherConfig == (dispatcher.Config{}) { + c.DispatcherConfig = defaults.DispatcherConfig + } +} diff --git a/pkg/transport/nclprotocol/orchestrator/dataplane.go b/pkg/transport/nclprotocol/orchestrator/dataplane.go new file mode 100644 index 0000000000..a7c4454b15 --- /dev/null +++ b/pkg/transport/nclprotocol/orchestrator/dataplane.go @@ -0,0 +1,263 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/nats-io/nats.go" + "github.com/rs/zerolog/log" + + "github.com/bacalhau-project/bacalhau/pkg/jobstore" + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/validate" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" +) + +// DataPlane manages the message flow between orchestrator and a single compute node. +// It handles: +// - Reliable message delivery through ordered publisher +// - Sequence tracking for both incoming and outgoing messages +// - Event watching and dispatching +// Each DataPlane instance corresponds to one compute node connection. 
+type DataPlane struct { + config DataPlaneConfig + + // Core messaging components + subscriber ncl.Subscriber // Handles incoming messages from compute node + publisher ncl.OrderedPublisher // Sends messages to compute node + dispatcher *dispatcher.Dispatcher // Manages event watching and dispatch + + // Sequence Trackers + incomingSequenceTracker *nclprotocol.SequenceTracker + + // State management + mu sync.RWMutex // Protects state changes + running bool // Indicates if data plane is active +} + +// DataPlaneConfig defines the configuration for a DataPlane instance. +// Each config corresponds to a single compute node connection. +type DataPlaneConfig struct { + NodeID string // ID of the compute node this data plane serves + Client *nats.Conn // NATS connection + + // Message handling + MessageHandler ncl.MessageHandler + MessageCreatorFactory nclprotocol.MessageCreatorFactory + MessageRegistry *envelope.Registry + MessageSerializer envelope.MessageSerializer + + // Event tracking + EventStore watcher.EventStore + StartSeqNum uint64 // Initial sequence for event watching + + // Dispatcher settings + DispatcherConfig dispatcher.Config +} + +func (c *DataPlaneConfig) Validate() error { + return errors.Join( + validate.NotBlank(c.NodeID, "nodeID required"), + validate.NotNil(c.Client, "client required"), + validate.NotNil(c.EventStore, "event store required"), + validate.NotNil(c.MessageHandler, "message handler required"), + validate.NotNil(c.MessageCreatorFactory, "message creator factory required"), + validate.NotNil(c.MessageRegistry, "message registry required"), + validate.NotNil(c.MessageSerializer, "message serializer required"), + ) +} + +// NewDataPlane creates a new DataPlane instance for a compute node. 
+func NewDataPlane(config DataPlaneConfig) (*DataPlane, error) { + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + return &DataPlane{ + config: config, + incomingSequenceTracker: nclprotocol.NewSequenceTracker(), + }, nil +} + +// Start initializes and begins data plane operations. This includes: +// 1. Creating subscriber for compute node messages +// 2. Creating ordered publisher for reliable delivery +// 3. Setting up event watching and dispatching +// 4. Starting all components in correct order +func (dp *DataPlane) Start(ctx context.Context) error { + dp.mu.Lock() + defer dp.mu.Unlock() + + if dp.running { + return fmt.Errorf("data plane already running") + } + + var err error + defer func() { + if err != nil { + if cleanupErr := dp.cleanup(ctx); cleanupErr != nil { + log.Warn().Err(cleanupErr).Msg("failed to cleanup after start error") + } + } + }() + + // Define NATS subjects for this compute node + inSubject := nclprotocol.NatsSubjectOrchestratorInMsgs(dp.config.NodeID) + outSubject := nclprotocol.NatsSubjectOrchestratorOutMsgs(dp.config.NodeID) + + // Set up subscriber for incoming messages + if err = dp.setupSubscriber(ctx, inSubject); err != nil { + return fmt.Errorf("failed to setup subscriber: %w", err) + } + + // Set up publisher for outgoing messages + if err = dp.setupPublisher(outSubject); err != nil { + return fmt.Errorf("failed to setup publisher: %w", err) + } + + // Set up dispatcher for event watching + if err = dp.setupDispatcher(ctx); err != nil { + return fmt.Errorf("failed to setup dispatcher: %w", err) + } + + log.Debug(). + Str("nodeID", dp.config.NodeID). + Str("incomingSubject", inSubject). + Str("outgoingSubject", outSubject). + Uint64("startSeqNum", dp.config.StartSeqNum). 
+ Msg("Data plane started") + + dp.running = true + return nil +} + +func (dp *DataPlane) setupSubscriber(ctx context.Context, subject string) error { + var err error + dp.subscriber, err = ncl.NewSubscriber(dp.config.Client, ncl.SubscriberConfig{ + Name: fmt.Sprintf("orchestrator-%s", dp.config.NodeID), + MessageRegistry: dp.config.MessageRegistry, + MessageSerializer: dp.config.MessageSerializer, + MessageHandler: dp.config.MessageHandler, + ProcessingNotifier: dp.incomingSequenceTracker, + }) + if err != nil { + return fmt.Errorf("create subscriber: %w", err) + } + + return dp.subscriber.Subscribe(ctx, subject) +} + +func (dp *DataPlane) setupPublisher(subject string) error { + var err error + dp.publisher, err = ncl.NewOrderedPublisher(dp.config.Client, ncl.OrderedPublisherConfig{ + Name: fmt.Sprintf("orchestrator-%s", dp.config.NodeID), + MessageRegistry: dp.config.MessageRegistry, + MessageSerializer: dp.config.MessageSerializer, + Destination: subject, + }) + if err != nil { + return fmt.Errorf("create publisher: %w", err) + } + return nil +} + +func (dp *DataPlane) setupDispatcher(ctx context.Context) error { + // Create watcher starting from specified sequence + dispatcherWatcher, err := watcher.New(ctx, + fmt.Sprintf("orchestrator-dispatcher-%s", dp.config.NodeID), + dp.config.EventStore, + watcher.WithRetryStrategy(watcher.RetryStrategyBlock), + watcher.WithInitialEventIterator(watcher.AfterSequenceNumberIterator(dp.config.StartSeqNum)), + watcher.WithFilter(watcher.EventFilter{ + ObjectTypes: []string{jobstore.EventObjectExecutionUpsert}, + }), + ) + if err != nil { + return fmt.Errorf("create watcher: %w", err) + } + + // Create message creator for this compute node + messageCreator := dp.config.MessageCreatorFactory.CreateMessageCreator( + ctx, dp.config.NodeID) + + // Disable checkpointing in dispatcher since we handle it elsewhere + config := dp.config.DispatcherConfig + config.CheckpointInterval = -1 + + // Create and start dispatcher + 
dp.dispatcher, err = dispatcher.New( + dp.publisher, + dispatcherWatcher, + messageCreator, + config, + ) + if err != nil { + return fmt.Errorf("create dispatcher: %w", err) + } + + if err = dp.dispatcher.Start(ctx); err != nil { + return fmt.Errorf("start dispatcher: %w", err) + } + + return nil +} + +// Stop gracefully shuts down all data plane operations. +// It ensures proper cleanup of resources by stopping components +// in correct order: dispatcher -> subscriber -> publisher +func (dp *DataPlane) Stop(ctx context.Context) error { + dp.mu.Lock() + defer dp.mu.Unlock() + + if !dp.running { + return nil + } + + dp.running = false + return dp.cleanup(ctx) +} + +// cleanup handles orderly shutdown of data plane components +func (dp *DataPlane) cleanup(ctx context.Context) error { + var errs error + + // Stop dispatcher first to prevent new messages + if dp.dispatcher != nil { + if err := dp.dispatcher.Stop(ctx); err != nil { + errs = errors.Join(errs, fmt.Errorf("stop dispatcher: %w", err)) + } + dp.dispatcher = nil + } + + // Then clean up subscriber + if dp.subscriber != nil { + if err := dp.subscriber.Close(ctx); err != nil { + errs = errors.Join(errs, fmt.Errorf("close subscriber: %w", err)) + } + dp.subscriber = nil + } + + // Finally clean up publisher + if dp.publisher != nil { + if err := dp.publisher.Close(ctx); err != nil { + errs = errors.Join(errs, fmt.Errorf("close publisher: %w", err)) + } + dp.publisher = nil + } + + if errs != nil { + return fmt.Errorf("cleanup failed: %w", errs) + } + return nil +} + +// GetLastProcessedSequence returns the last sequence number processed +// from incoming messages from this compute node +func (dp *DataPlane) GetLastProcessedSequence() uint64 { + return dp.incomingSequenceTracker.GetLastSeqNum() +} diff --git a/pkg/transport/nclprotocol/orchestrator/manager.go b/pkg/transport/nclprotocol/orchestrator/manager.go new file mode 100644 index 0000000000..3b3aec9247 --- /dev/null +++ 
b/pkg/transport/nclprotocol/orchestrator/manager.go @@ -0,0 +1,308 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/nats-io/nats.go" + "github.com/rs/zerolog/log" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/models" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" +) + +// ComputeManager handles the lifecycle and state management of all compute nodes +// connected to this orchestrator. It is responsible for: +// - Processing compute node handshakes and connections +// - Managing individual node data planes +// - Coordinating message flow between orchestrator and compute nodes +// - Tracking node health and connection state +type ComputeManager struct { + config Config + + // Core components + natsConn *nats.Conn // NATS connection + responder ncl.Responder // Handles control plane requests + dataPlanes sync.Map // map[string]*DataPlane + + // Node management + nodeManager nodes.Manager // Tracks node state and health + + // Lifecycle management + stopCh chan struct{} // Signals background goroutines to stop + wg sync.WaitGroup // Tracks active background goroutines +} + +// NewComputeManager creates a new compute manager with the given configuration. +// The manager must be started with Start() before it begins processing connections. +func NewComputeManager(cfg Config) (*ComputeManager, error) { + cfg.setDefaults() + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + return &ComputeManager{ + config: cfg, + nodeManager: cfg.NodeManager, + stopCh: make(chan struct{}), + }, nil +} + +// Start initializes the manager and begins processing compute node connections. +// This includes: +// 1. Creating NATS connection +// 2. 
Setting up control plane responder +// 3. Registering message handlers +// 4. Setting up node state change handling +func (cm *ComputeManager) Start(ctx context.Context) error { + var err error + + // Create NATS connection + if err = cm.setupTransport(ctx); err != nil { + return err + } + + // Set up control plane responder + if err = cm.setupControlPlane(ctx); err != nil { + return err + } + + // Register for node state changes + cm.nodeManager.OnConnectionStateChange(cm.handleConnectionStateChange) + + return nil +} + +func (cm *ComputeManager) setupTransport(ctx context.Context) error { + var err error + cm.natsConn, err = cm.config.ClientFactory.CreateClient(ctx) + if err != nil { + return fmt.Errorf("connect to NATS: %w", err) + } + return nil +} + +func (cm *ComputeManager) setupControlPlane(ctx context.Context) error { + var err error + + // Create responder for control messages + cm.responder, err = ncl.NewResponder(cm.natsConn, ncl.ResponderConfig{ + Name: "orchestrator-control", + MessageRegistry: cm.config.MessageRegistry, + MessageSerializer: cm.config.MessageSerializer, + Subject: nclprotocol.NatsSubjectOrchestratorInCtrl(), + }) + if err != nil { + return fmt.Errorf("create control responder: %w", err) + } + + // Register control message handlers + return errors.Join( + cm.responder.Listen(ctx, messages.HandshakeRequestMessageType, + ncl.RequestHandlerFunc(cm.handleHandshakeRequest)), + cm.responder.Listen(ctx, messages.HeartbeatRequestMessageType, + ncl.RequestHandlerFunc(cm.handleHeartbeatRequest)), + cm.responder.Listen(ctx, messages.NodeInfoUpdateRequestMessageType, + ncl.RequestHandlerFunc(cm.handleNodeInfoUpdateRequest)), + ) +} + +// Stop gracefully shuts down the manager and all compute node connections. +// It ensures proper cleanup by: +// 1. Stopping the control plane responder +// 2. Stopping all data planes +// 3. 
Waiting for background goroutines to complete +func (cm *ComputeManager) Stop(ctx context.Context) error { + close(cm.stopCh) + + var errs error + + // Stop responder first to prevent new connections + if cm.responder != nil { + if err := cm.responder.Close(ctx); err != nil { + errs = errors.Join(errs, fmt.Errorf("close responder: %w", err)) + } + cm.responder = nil + } + + // Stop all data planes + cm.dataPlanes.Range(func(key, value interface{}) bool { + nodeID := key.(string) + if dataPlane, ok := value.(*DataPlane); ok { + if err := dataPlane.Stop(ctx); err != nil { + errs = errors.Join(errs, + fmt.Errorf("stop data plane for node %s: %w", nodeID, err)) + } + } + return true + }) + + // Clean up NATS connection + if cm.natsConn != nil { + cm.natsConn.Close() + cm.natsConn = nil + } + + // Wait for goroutines with timeout + done := make(chan struct{}) + go func() { + cm.wg.Wait() + close(done) + }() + + select { + case <-done: + if errs != nil { + return fmt.Errorf("shutdown errors: %w", errs) + } + return nil + case <-ctx.Done(): + return fmt.Errorf("shutdown timeout: %w", ctx.Err()) + } +} + +// handleHandshakeRequest processes incoming handshake requests from compute nodes. +// For each new node, it: +// 1. Validates the request through node manager +// 2. Creates a new data plane if accepted +// 3. 
Returns handshake response with connection details +func (cm *ComputeManager) handleHandshakeRequest(ctx context.Context, msg *envelope.Message) (*envelope.Message, error) { + request := msg.Payload.(*messages.HandshakeRequest) + + // Process handshake through node manager + response, err := cm.nodeManager.Handshake(ctx, *request) + if err != nil { + return nil, err + } + + if !response.Accepted { + return envelope.NewMessage(response), nil + } + + // Create data plane for accepted node + if err = cm.setupDataPlane(ctx, request.NodeInfo, request.LastOrchestratorSeqNum); err != nil { + return nil, fmt.Errorf("setup data plane failed: %w", err) + } + + return envelope.NewMessage(response), nil +} + +// setupDataPlane creates and starts a new data plane for a compute node. +// If a data plane already exists for the node, it is gracefully stopped +// and replaced with the new one. +func (cm *ComputeManager) setupDataPlane( + ctx context.Context, + nodeInfo models.NodeInfo, + lastReceivedSeqNum uint64, +) error { + // Create new data plane configuration + dataPlane, err := NewDataPlane(DataPlaneConfig{ + NodeID: nodeInfo.ID(), + Client: cm.natsConn, + MessageRegistry: cm.config.MessageRegistry, + MessageSerializer: cm.config.MessageSerializer, + MessageHandler: cm.config.DataPlaneMessageHandler, + MessageCreatorFactory: cm.config.DataPlaneMessageCreatorFactory, + EventStore: cm.config.EventStore, + StartSeqNum: lastReceivedSeqNum, + DispatcherConfig: cm.config.DispatcherConfig, + }) + if err != nil { + return err + } + + // Atomically replace old with new, stopping old if it exists + if existing, loaded := cm.dataPlanes.Swap(nodeInfo.ID(), dataPlane); loaded { + if dp, ok := existing.(*DataPlane); ok { + if err = dp.Stop(context.TODO()); err != nil { + log.Error(). + Err(err). + Str("nodeID", nodeInfo.ID()). 
+ Msg("Failed to stop existing data plane") + } + } + } + + // Start new data plane + if err = dataPlane.Start(ctx); err != nil { + cm.dataPlanes.Delete(nodeInfo.ID()) + return fmt.Errorf("start data plane: %w", err) + } + + return nil +} + +// handleHeartbeatRequest processes heartbeat messages from compute nodes. +// It verifies the node has an active data plane and updates health tracking. +func (cm *ComputeManager) handleHeartbeatRequest(ctx context.Context, msg *envelope.Message) (*envelope.Message, error) { + request := msg.Payload.(*messages.HeartbeatRequest) + + // Verify data plane exists + dataPlane, exists := cm.getDataPlane(request.NodeID) + if !exists { + return nil, fmt.Errorf("no active data plane for node %s - handshake required", request.NodeID) + } + + // Process through node manager with sequence info + response, err := cm.nodeManager.Heartbeat(ctx, nodes.ExtendedHeartbeatRequest{ + HeartbeatRequest: *request, + LastComputeSeqNum: dataPlane.GetLastProcessedSequence(), + }) + if err != nil { + return nil, err + } + + return envelope.NewMessage(response), nil +} + +// handleNodeInfoUpdateRequest processes node info updates from compute nodes. +// It verifies the node has an active data plane before accepting updates. 
+func (cm *ComputeManager) handleNodeInfoUpdateRequest(ctx context.Context, msg *envelope.Message) (*envelope.Message, error) { + request := msg.Payload.(*messages.UpdateNodeInfoRequest) + + // Verify data plane exists + if _, ok := cm.dataPlanes.Load(request.NodeInfo.ID()); !ok { + // Return error asking node to reconnect since it has no active data plane + return nil, fmt.Errorf("no active data plane - handshake required") + } + + // Process through node manager + response, err := cm.nodeManager.UpdateNodeInfo(ctx, *request) + if err != nil { + return nil, err + } + + return envelope.NewMessage(response), nil +} + +// handleConnectionStateChange responds to node connection state changes +func (cm *ComputeManager) handleConnectionStateChange(event nodes.NodeConnectionEvent) { + // If node disconnected, stop and remove data plane + if event.Current == models.NodeStates.DISCONNECTED { + if dataPlane, ok := cm.dataPlanes.LoadAndDelete(event.NodeID); ok { + if dp, ok := dataPlane.(*DataPlane); ok { + if err := dp.Stop(context.Background()); err != nil { + log.Error().Err(err). + Str("nodeID", event.NodeID). + Msg("Failed to stop data plane for disconnected node") + } + } + } + } +} + +// getDataPlane safely retrieves the data plane for a node if it exists +func (cm *ComputeManager) getDataPlane(nodeID string) (*DataPlane, bool) { + if value, ok := cm.dataPlanes.Load(nodeID); ok { + if dataPlane, ok := value.(*DataPlane); ok { + return dataPlane, true + } + } + return nil, false +} diff --git a/pkg/transport/nclprotocol/registry.go b/pkg/transport/nclprotocol/registry.go new file mode 100644 index 0000000000..00594e3430 --- /dev/null +++ b/pkg/transport/nclprotocol/registry.go @@ -0,0 +1,42 @@ +package nclprotocol + +import ( + "errors" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" +) + +// CreateMessageRegistry creates a new payload registry. 
+func CreateMessageRegistry() (*envelope.Registry, error) {
+	reg := envelope.NewRegistry()
+	err := errors.Join(
+		reg.Register(messages.AskForBidMessageType, messages.AskForBidRequest{}),
+		reg.Register(messages.BidAcceptedMessageType, messages.BidAcceptedRequest{}),
+		reg.Register(messages.BidRejectedMessageType, messages.BidRejectedRequest{}),
+		reg.Register(messages.CancelExecutionMessageType, messages.CancelExecutionRequest{}),
+		reg.Register(messages.BidResultMessageType, messages.BidResult{}),
+		reg.Register(messages.RunResultMessageType, messages.RunResult{}),
+		reg.Register(messages.ComputeErrorMessageType, messages.ComputeError{}),
+
+		// Control plane messages
+		reg.Register(messages.HandshakeRequestMessageType, messages.HandshakeRequest{}),
+		reg.Register(messages.HeartbeatRequestMessageType, messages.HeartbeatRequest{}),
+		reg.Register(messages.NodeInfoUpdateRequestMessageType, messages.UpdateNodeInfoRequest{}),
+
+		// Control plane responses
+		reg.Register(messages.HandshakeResponseType, messages.HandshakeResponse{}),
+		reg.Register(messages.HeartbeatResponseType, messages.HeartbeatResponse{}),
+		reg.Register(messages.NodeInfoUpdateResponseType, messages.UpdateNodeInfoResponse{}),
+	)
+	return reg, err
+}
+
+// MustCreateMessageRegistry creates a new payload registry; it panics if any message type fails to register.
+func MustCreateMessageRegistry() *envelope.Registry { + reg, err := CreateMessageRegistry() + if err != nil { + panic(err) + } + return reg +} diff --git a/pkg/transport/nclprotocol/subjects.go b/pkg/transport/nclprotocol/subjects.go new file mode 100644 index 0000000000..642d6da4bb --- /dev/null +++ b/pkg/transport/nclprotocol/subjects.go @@ -0,0 +1,29 @@ +package nclprotocol + +import ( + "fmt" +) + +func NatsSubjectOrchestratorInCtrl() string { + return "bacalhau.global.compute.*.out.ctrl" +} + +func NatsSubjectOrchestratorInMsgs(computeNodeID string) string { + return fmt.Sprintf("bacalhau.global.compute.%s.out.msgs", computeNodeID) +} + +func NatsSubjectOrchestratorOutMsgs(computeNodeID string) string { + return fmt.Sprintf("bacalhau.global.compute.%s.in.msgs", computeNodeID) +} + +func NatsSubjectComputeInMsgs(computeNodeID string) string { + return fmt.Sprintf("bacalhau.global.compute.%s.in.msgs", computeNodeID) +} + +func NatsSubjectComputeOutCtrl(computeNodeID string) string { + return fmt.Sprintf("bacalhau.global.compute.%s.out.ctrl", computeNodeID) +} + +func NatsSubjectComputeOutMsgs(computeNodeID string) string { + return fmt.Sprintf("bacalhau.global.compute.%s.out.msgs", computeNodeID) +} diff --git a/pkg/transport/nclprotocol/tracker.go b/pkg/transport/nclprotocol/tracker.go new file mode 100644 index 0000000000..0ac493c3e6 --- /dev/null +++ b/pkg/transport/nclprotocol/tracker.go @@ -0,0 +1,52 @@ +package nclprotocol + +import ( + "context" + "sync/atomic" + + "github.com/rs/zerolog/log" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" +) + +// SequenceTracker tracks the last successfully processed message sequence number. +// Used by connection managers to checkpoint progress and resume message processing +// after restarts. Thread-safe through atomic operations. 
+type SequenceTracker struct { + lastSeqNum atomic.Uint64 +} + +// NewSequenceTracker creates a new sequence tracker starting at sequence 0 +func NewSequenceTracker() *SequenceTracker { + return &SequenceTracker{} +} + +// WithLastSeqNum sets the initial sequence number for resuming processing +func (s *SequenceTracker) WithLastSeqNum(seqNum uint64) *SequenceTracker { + s.lastSeqNum.Store(seqNum) + return s +} + +// UpdateLastSeqNum updates the latest processed sequence number atomically +func (s *SequenceTracker) UpdateLastSeqNum(seqNum uint64) { + s.lastSeqNum.Store(seqNum) +} + +// GetLastSeqNum returns the last processed sequence number atomically +func (s *SequenceTracker) GetLastSeqNum() uint64 { + return s.lastSeqNum.Load() +} + +// OnProcessed implements ncl.ProcessingNotifier to track message sequence numbers. +// Called after each successful message processing operation. +func (s *SequenceTracker) OnProcessed(ctx context.Context, message *envelope.Message) { + if message.Metadata.Has(KeySeqNum) { + s.UpdateLastSeqNum(message.Metadata.GetUint64(KeySeqNum)) + } else { + log.Trace().Msgf("No sequence number found in message metadata %v", message.Metadata) + } +} + +// Ensure SequenceTracker implements ProcessingNotifier +var _ ncl.ProcessingNotifier = &SequenceTracker{} diff --git a/pkg/transport/nclprotocol/types.go b/pkg/transport/nclprotocol/types.go new file mode 100644 index 0000000000..5b2fdace58 --- /dev/null +++ b/pkg/transport/nclprotocol/types.go @@ -0,0 +1,76 @@ +//go:generate mockgen --source types.go --destination mocks.go --package nclprotocol + +package nclprotocol + +import ( + "context" + "fmt" + "time" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" +) + +// ConnectionState represents the current state of a connection +type ConnectionState int + +const ( + Disconnected ConnectionState = iota + Connecting + Connected +) + +// String returns the string representation of the 
connection state +func (c ConnectionState) String() string { + switch c { + case Disconnected: + return "Disconnected" + case Connecting: + return "Connecting" + case Connected: + return "Connected" + default: + return "Unknown" + } +} + +type Checkpointer interface { + Checkpoint(ctx context.Context, name string, sequenceNumber uint64) error + GetCheckpoint(ctx context.Context, name string) (uint64, error) +} + +// ConnectionStateHandler is called when connection state changes +type ConnectionStateHandler func(ConnectionState) + +type ConnectionHealth struct { + StartTime time.Time + LastSuccessfulHeartbeat time.Time + LastSuccessfulUpdate time.Time + CurrentState ConnectionState + ConsecutiveFailures int + LastError error + ConnectedSince time.Time +} + +const ( + KeySeqNum = "Bacalhau-SeqNum" +) + +// MessageCreator defines how events from the watcher are converted into +// messages for publishing. This is the primary extension point for customizing +// transport behavior. +type MessageCreator interface { + // CreateMessage converts a watcher event into a message envelope. + // Returns nil if no message should be published for this event. + // Any error will halt event processing. 
+ CreateMessage(event watcher.Event) (*envelope.Message, error) +} + +type MessageCreatorFactory interface { + CreateMessageCreator(ctx context.Context, nodeID string) MessageCreator +} + +// GenerateMsgID Message ID generation helper +func GenerateMsgID(event watcher.Event) string { + return fmt.Sprintf("seq-%d", event.SeqNum) +} diff --git a/pkg/transport/types.go b/pkg/transport/types.go deleted file mode 100644 index 18c545195d..0000000000 --- a/pkg/transport/types.go +++ /dev/null @@ -1,28 +0,0 @@ -//go:generate mockgen --source types.go --destination mocks.go --package transport -package transport - -import ( - "fmt" - - "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" -) - -const ( - KeySeqNum = "Bacalhau-SeqNum" -) - -// MessageCreator defines how events from the watcher are converted into -// messages for publishing. This is the primary extension point for customizing -// transport behavior. -type MessageCreator interface { - // CreateMessage converts a watcher event into a message envelope. - // Returns nil if no message should be published for this event. - // Any error will halt event processing. 
- CreateMessage(event watcher.Event) (*envelope.Message, error) -} - -// GenerateMsgID Message ID generation helper -func GenerateMsgID(event watcher.Event) string { - return fmt.Sprintf("seq-%d", event.SeqNum) -} From a4880c7b61d3a5a51208f0233f8a366d352a26e0 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 08:42:15 +0200 Subject: [PATCH 02/16] Update pkg/transport/nclprotocol/compute/health_tracker.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/transport/nclprotocol/compute/health_tracker.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/transport/nclprotocol/compute/health_tracker.go b/pkg/transport/nclprotocol/compute/health_tracker.go index 05c8b99d88..980749b92f 100644 --- a/pkg/transport/nclprotocol/compute/health_tracker.go +++ b/pkg/transport/nclprotocol/compute/health_tracker.go @@ -36,7 +36,6 @@ func (ht *HealthTracker) MarkConnected() { ht.health.LastSuccessfulHeartbeat = ht.clock.Now() ht.health.ConsecutiveFailures = 0 ht.health.LastError = nil - ht.health.CurrentState = nclprotocol.Connected } // MarkDisconnected updates status when connection is lost From 89ce65ce86a56b2c044503a6e89891371f795f0f Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 08:42:23 +0200 Subject: [PATCH 03/16] Update pkg/transport/bprotocol/compute/transport.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/transport/bprotocol/compute/transport.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/transport/bprotocol/compute/transport.go b/pkg/transport/bprotocol/compute/transport.go index c2f4366396..0414aa0559 100644 --- a/pkg/transport/bprotocol/compute/transport.go +++ b/pkg/transport/bprotocol/compute/transport.go @@ -145,7 +145,8 @@ func (cm *ConnectionManager) Start(ctx context.Context) error { HeartbeatConfig: cm.config.HeartbeatConfig, }) if err = managementClient.RegisterNode(ctx); err != nil { - if 
errors.As(err, &bprotocol.ErrUpgradeAvailable) { + var upgradeErr *bprotocol.ErrUpgradeAvailable + if errors.As(err, &upgradeErr) { log.Info().Msg("Disabling bprotocol management client due to upgrade available") cm.Stop(ctx) return nil From ff764c203bd97f33ddb3e0f8bcd30410c9305435 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 08:43:32 +0200 Subject: [PATCH 04/16] Update pkg/transport/nclprotocol/compute/dataplane.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/transport/nclprotocol/compute/dataplane.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/transport/nclprotocol/compute/dataplane.go b/pkg/transport/nclprotocol/compute/dataplane.go index 2c5ade09f8..78517206c1 100644 --- a/pkg/transport/nclprotocol/compute/dataplane.go +++ b/pkg/transport/nclprotocol/compute/dataplane.go @@ -93,7 +93,9 @@ func (dp *DataPlane) Start(ctx context.Context) error { Conn: dp.Client, LogstreamServer: dp.config.LogStreamServer, }) - + if err != nil { + return fmt.Errorf("failed to set up log stream handler: %w", err) + } // Initialize ordered publisher for reliable message delivery dp.publisher, err = ncl.NewOrderedPublisher(dp.Client, ncl.OrderedPublisherConfig{ Name: dp.config.NodeID, From 40931b63204c0465982eae23e9f6de7f8bd1f514 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 09:09:05 +0200 Subject: [PATCH 05/16] coderabbit fixes --- pkg/transport/bprotocol/compute/transport.go | 3 +- pkg/transport/nclprotocol/compute/manager.go | 42 +++++++++++++++---- .../nclprotocol/orchestrator/dataplane.go | 2 +- .../nclprotocol/orchestrator/manager.go | 9 +++- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/pkg/transport/bprotocol/compute/transport.go b/pkg/transport/bprotocol/compute/transport.go index 0414aa0559..fc775ae9d0 100644 --- a/pkg/transport/bprotocol/compute/transport.go +++ b/pkg/transport/bprotocol/compute/transport.go @@ -145,8 +145,7 @@ func 
(cm *ConnectionManager) Start(ctx context.Context) error { HeartbeatConfig: cm.config.HeartbeatConfig, }) if err = managementClient.RegisterNode(ctx); err != nil { - var upgradeErr *bprotocol.ErrUpgradeAvailable - if errors.As(err, &upgradeErr) { + if errors.Is(err, bprotocol.ErrUpgradeAvailable) { log.Info().Msg("Disabling bprotocol management client due to upgrade available") cm.Stop(ctx) return nil diff --git a/pkg/transport/nclprotocol/compute/manager.go b/pkg/transport/nclprotocol/compute/manager.go index 90183d312e..07d81c6db4 100644 --- a/pkg/transport/nclprotocol/compute/manager.go +++ b/pkg/transport/nclprotocol/compute/manager.go @@ -133,13 +133,13 @@ func (cm *ConnectionManager) Close(ctx context.Context) error { } cm.running = false close(cm.stopCh) - close(cm.stateChanges) cm.mu.Unlock() // Wait for graceful shutdown done := make(chan struct{}) go func() { cm.wg.Wait() + close(cm.stateChanges) close(done) }() @@ -464,18 +464,42 @@ func (cm *ConnectionManager) transitionState(newState nclprotocol.ConnectionStat func (cm *ConnectionManager) handleStateChanges() { defer cm.wg.Done() - for change := range cm.stateChanges { - cm.stateHandlersMu.RLock() - handlers := make([]nclprotocol.ConnectionStateHandler, len(cm.stateHandlers)) - copy(handlers, cm.stateHandlers) - cm.stateHandlersMu.RUnlock() - - for _, handler := range handlers { - handler(change.state) + for { + select { + case <-cm.stopCh: + // Process any remaining state changes before exiting + for { + select { + case change, ok := <-cm.stateChanges: + if !ok { + return + } + cm.processStateChange(change) + default: + return + } + } + case change, ok := <-cm.stateChanges: + if !ok { + return + } + cm.processStateChange(change) } } } +// processStateChange handles a single state change notification +func (cm *ConnectionManager) processStateChange(change stateChange) { + cm.stateHandlersMu.RLock() + handlers := make([]nclprotocol.ConnectionStateHandler, len(cm.stateHandlers)) + copy(handlers, 
cm.stateHandlers) + cm.stateHandlersMu.RUnlock() + + for _, handler := range handlers { + handler(change.state) + } +} + // OnStateChange registers a new handler to be called when the connection // state changes. Handlers are called synchronously when state transitions occur. func (cm *ConnectionManager) OnStateChange(handler nclprotocol.ConnectionStateHandler) { diff --git a/pkg/transport/nclprotocol/orchestrator/dataplane.go b/pkg/transport/nclprotocol/orchestrator/dataplane.go index a7c4454b15..b12c120a76 100644 --- a/pkg/transport/nclprotocol/orchestrator/dataplane.go +++ b/pkg/transport/nclprotocol/orchestrator/dataplane.go @@ -200,7 +200,7 @@ func (dp *DataPlane) setupDispatcher(ctx context.Context) error { return fmt.Errorf("create dispatcher: %w", err) } - if err = dp.dispatcher.Start(ctx); err != nil { + if err = dp.dispatcher.Start(context.TODO()); err != nil { return fmt.Errorf("start dispatcher: %w", err) } diff --git a/pkg/transport/nclprotocol/orchestrator/manager.go b/pkg/transport/nclprotocol/orchestrator/manager.go index 3b3aec9247..74a1e8e44b 100644 --- a/pkg/transport/nclprotocol/orchestrator/manager.go +++ b/pkg/transport/nclprotocol/orchestrator/manager.go @@ -220,7 +220,7 @@ func (cm *ComputeManager) setupDataPlane( // Atomically replace old with new, stopping old if it exists if existing, loaded := cm.dataPlanes.Swap(nodeInfo.ID(), dataPlane); loaded { if dp, ok := existing.(*DataPlane); ok { - if err = dp.Stop(context.TODO()); err != nil { + if err = dp.Stop(ctx); err != nil { log.Error(). Err(err). Str("nodeID", nodeInfo.ID()). 
@@ -283,6 +283,13 @@ func (cm *ComputeManager) handleNodeInfoUpdateRequest(ctx context.Context, msg * // handleConnectionStateChange responds to node connection state changes func (cm *ComputeManager) handleConnectionStateChange(event nodes.NodeConnectionEvent) { + // Check if we're shutting down + select { + case <-cm.stopCh: + return + default: + } + // If node disconnected, stop and remove data plane if event.Current == models.NodeStates.DISCONNECTED { if dataPlane, ok := cm.dataPlanes.LoadAndDelete(event.NodeID); ok { From 4a523a8f0d39214ceea19c3fd20efeba3edd2801 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 10:11:25 +0200 Subject: [PATCH 06/16] make ncl preferred + fix tests --- pkg/lib/validate/general.go | 8 +- pkg/lib/validate/general_test.go | 82 ++++++++++++++ pkg/models/protocol.go | 25 +---- pkg/node/compute.go | 1 - .../watchers/ncl_message_creator.go | 26 ++++- .../watchers/ncl_message_creator_test.go | 103 +++++++++++++++++- .../watchers/protocol_router_test.go | 22 ---- pkg/transport/bprotocol/compute/transport.go | 3 +- .../nclprotocol/dispatcher/config_test.go | 5 - .../nclprotocol/orchestrator/dataplane.go | 5 +- pkg/transport/nclprotocol/types.go | 2 +- 11 files changed, 219 insertions(+), 63 deletions(-) diff --git a/pkg/lib/validate/general.go b/pkg/lib/validate/general.go index 43f0fa6608..557e96bd3a 100644 --- a/pkg/lib/validate/general.go +++ b/pkg/lib/validate/general.go @@ -11,8 +11,12 @@ func NotNil(value any, msg string, args ...any) error { // Use reflection to handle cases where value is a nil pointer wrapped in an interface val := reflect.ValueOf(value) - if val.Kind() == reflect.Ptr && val.IsNil() { - return createError(msg, args...) + switch val.Kind() { + case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Func: + if val.IsNil() { + return createError(msg, args...) 
+ } + default: } return nil } diff --git a/pkg/lib/validate/general_test.go b/pkg/lib/validate/general_test.go index 04b14facfc..09b33d131c 100644 --- a/pkg/lib/validate/general_test.go +++ b/pkg/lib/validate/general_test.go @@ -4,6 +4,14 @@ package validate import "testing" +type doer struct{} + +func (d doer) Do() {} + +type Doer interface { + Do() +} + // TestIsNotNil tests the NotNil function for various scenarios. func TestIsNotNil(t *testing.T) { t.Run("NilValue", func(t *testing.T) { @@ -35,4 +43,78 @@ func TestIsNotNil(t *testing.T) { t.Errorf("NotNil failed: unexpected error for non-nil pointer") } }) + + t.Run("NilFunc", func(t *testing.T) { + var nilFunc func() + err := NotNil(nilFunc, "value should not be nil") + if err == nil { + t.Errorf("NotNil failed: expected error for nil func") + } + }) + + t.Run("NonNilFunc", func(t *testing.T) { + nonNilFunc := func() {} + err := NotNil(nonNilFunc, "value should not be nil") + if err != nil { + t.Errorf("NotNil failed: unexpected error for non-nil func") + } + }) + + t.Run("NilSlice", func(t *testing.T) { + var nilSlice []int + err := NotNil(nilSlice, "value should not be nil") + if err == nil { + t.Errorf("NotNil failed: expected error for nil slice") + } + }) + + t.Run("NonNilSlice", func(t *testing.T) { + nonNilSlice := make([]int, 0) + err := NotNil(nonNilSlice, "value should not be nil") + if err != nil { + t.Errorf("NotNil failed: unexpected error for non-nil slice") + } + }) + + t.Run("NilMap", func(t *testing.T) { + var nilMap map[string]int + err := NotNil(nilMap, "value should not be nil") + if err == nil { + t.Errorf("NotNil failed: expected error for nil map") + } + }) + + t.Run("NonNilMap", func(t *testing.T) { + nonNilMap := make(map[string]int) + err := NotNil(nonNilMap, "value should not be nil") + if err != nil { + t.Errorf("NotNil failed: unexpected error for non-nil map") + } + }) + + t.Run("NilInterface", func(t *testing.T) { + var nilInterface Doer + err := NotNil(nilInterface, "value 
should not be nil") + if err == nil { + t.Errorf("NotNil failed: expected error for nil interface") + } + }) + + t.Run("NonNilInterface", func(t *testing.T) { + var nonNilInterface Doer = doer{} + err := NotNil(nonNilInterface, "value should not be nil") + if err != nil { + t.Errorf("NotNil failed: unexpected error for non-nil interface") + } + }) + + t.Run("FormattedMessage", func(t *testing.T) { + err := NotNil(nil, "value %s should not be nil", "test") + if err == nil { + t.Errorf("NotNil failed: expected error for nil value with formatted message") + } + if err.Error() != "value test should not be nil" { + t.Errorf("NotNil failed: unexpected error message, got %q", err.Error()) + } + }) } diff --git a/pkg/models/protocol.go b/pkg/models/protocol.go index 5ee69d05f7..cebe1919f9 100644 --- a/pkg/models/protocol.go +++ b/pkg/models/protocol.go @@ -1,9 +1,5 @@ package models -import ( - "os" -) - type Protocol string const ( @@ -14,22 +10,13 @@ const ( // ProtocolBProtocolV2 is nats based request/response protocol. // Currently the default protocol while NCL is under development. ProtocolBProtocolV2 Protocol = "bprotocol/v2" - - // EnvPreferNCL is the environment variable to prefer NCL protocol usage. - // This can be used to test NCL protocol while it's still in development. - EnvPreferNCL = "BACALHAU_PREFER_NCL_PROTOCOL" ) var ( // preferredProtocols is the order of protocols based on preference. - // NOTE: While NCL protocol (ProtocolNCLV1) is under active development, - // we maintain ProtocolBProtocolV2 as the default choice for stability. - // NCL can be enabled via BACALHAU_PREFER_NCL_PROTOCOL env var for testing - // and development purposes. Once NCL reaches stable status, it will become - // the default protocol. 
preferredProtocols = []Protocol{ - ProtocolBProtocolV2, ProtocolNCLV1, + ProtocolBProtocolV2, } ) @@ -41,16 +28,6 @@ func (p Protocol) String() string { // GetPreferredProtocol accepts a slice of available protocols and returns the // preferred protocol based on the order of preference along with any error func GetPreferredProtocol(availableProtocols []Protocol) Protocol { - // Check if NCL is preferred via environment variable - if os.Getenv(EnvPreferNCL) == "true" { - // If NCL is available when preferred, use it - for _, p := range availableProtocols { - if p == ProtocolNCLV1 { - return ProtocolNCLV1 - } - } - } - for _, preferred := range preferredProtocols { for _, available := range availableProtocols { if preferred == available { diff --git a/pkg/node/compute.go b/pkg/node/compute.go index 4e8b4b7ac5..753e39f821 100644 --- a/pkg/node/compute.go +++ b/pkg/node/compute.go @@ -194,7 +194,6 @@ func NewComputeNode( } if err = legacyConnectionManager.Start(ctx); err != nil { log.Warn().Err(err).Msg("failed to start legacy connection manager. 
continuing without it") - err = nil } // connection manager diff --git a/pkg/orchestrator/watchers/ncl_message_creator.go b/pkg/orchestrator/watchers/ncl_message_creator.go index f7628a161e..ce6e4d59bf 100644 --- a/pkg/orchestrator/watchers/ncl_message_creator.go +++ b/pkg/orchestrator/watchers/ncl_message_creator.go @@ -2,12 +2,14 @@ package watchers import ( "context" + "errors" "github.com/rs/zerolog/log" "github.com/bacalhau-project/bacalhau/pkg/bacerrors" "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/validate" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" @@ -32,7 +34,8 @@ func NewNCLMessageCreatorFactory(params NCLMessageCreatorFactoryParams) *NCLMess } } -func (f *NCLMessageCreatorFactory) CreateMessageCreator(ctx context.Context, nodeID string) nclprotocol.MessageCreator { +func (f *NCLMessageCreatorFactory) CreateMessageCreator( + ctx context.Context, nodeID string) (nclprotocol.MessageCreator, error) { return NewNCLMessageCreator(NCLMessageCreatorParams{ NodeID: nodeID, ProtocolRouter: f.protocolRouter, @@ -53,12 +56,28 @@ type NCLMessageCreatorParams struct { } // NewNCLMessageCreator creates a new NCL protocol dispatcher -func NewNCLMessageCreator(params NCLMessageCreatorParams) *NCLMessageCreator { +func NewNCLMessageCreator(params NCLMessageCreatorParams) (*NCLMessageCreator, error) { + err := errors.Join( + validate.NotBlank(params.NodeID, "nodeID cannot be blank"), + validate.NotNil(params.ProtocolRouter, "protocol router cannot be nil"), + validate.NotNil(params.SubjectFn, "subject function cannot be nil"), + ) + if params.SubjectFn != nil { + // verify the subject function is provided and that it returns a non-empty string + // by just validating against the current NodeID + err = errors.Join(err, + 
validate.NotBlank(params.SubjectFn(params.NodeID), "subject function returned empty")) + } + + if err != nil { + return nil, bacerrors.Wrap(err, "failed to create NCLMessageCreator"). + WithComponent(nclDispatcherErrComponent) + } return &NCLMessageCreator{ nodeID: params.NodeID, protocolRouter: params.ProtocolRouter, subjectFn: params.SubjectFn, - } + }, nil } func (d *NCLMessageCreator) CreateMessage(event watcher.Event) (*envelope.Message, error) { @@ -165,3 +184,4 @@ func (d *NCLMessageCreator) createCancelMessage(upsert models.ExecutionUpsert) * // compile-time check that NCLMessageCreator implements dispatcher.MessageCreator var _ nclprotocol.MessageCreator = &NCLMessageCreator{} +var _ nclprotocol.MessageCreatorFactory = &NCLMessageCreatorFactory{} diff --git a/pkg/orchestrator/watchers/ncl_message_creator_test.go b/pkg/orchestrator/watchers/ncl_message_creator_test.go index 9a8a5992ec..cc91931a8a 100644 --- a/pkg/orchestrator/watchers/ncl_message_creator_test.go +++ b/pkg/orchestrator/watchers/ncl_message_creator_test.go @@ -3,6 +3,7 @@ package watchers import ( + "context" "testing" "github.com/stretchr/testify/suite" @@ -44,16 +45,93 @@ func (s *NCLMessageCreatorTestSuite) SetupTest() { return "test." 
+ nodeID } - s.creator = NewNCLMessageCreator(NCLMessageCreatorParams{ + s.creator, err = NewNCLMessageCreator(NCLMessageCreatorParams{ + NodeID: "test-node", ProtocolRouter: s.protocolRouter, SubjectFn: s.subjectFn, }) + s.Require().NoError(err) } func (s *NCLMessageCreatorTestSuite) TearDownTest() { s.ctrl.Finish() } +func (s *NCLMessageCreatorTestSuite) TestNewNCLMessageCreator() { + tests := []struct { + name string + params NCLMessageCreatorParams + expectError string + }{ + { + name: "valid params", + params: NCLMessageCreatorParams{ + NodeID: "test-node", + ProtocolRouter: s.protocolRouter, + SubjectFn: s.subjectFn, + }, + }, + { + name: "missing nodeID", + params: NCLMessageCreatorParams{ + ProtocolRouter: s.protocolRouter, + SubjectFn: s.subjectFn, + }, + expectError: "nodeID cannot be blank", + }, + { + name: "missing protocol router", + params: NCLMessageCreatorParams{ + NodeID: "test-node", + SubjectFn: s.subjectFn, + }, + expectError: "protocol router cannot be nil", + }, + { + name: "missing subject function", + params: NCLMessageCreatorParams{ + NodeID: "test-node", + ProtocolRouter: s.protocolRouter, + }, + expectError: "subject function cannot be nil", + }, + { + name: "blank subject function", + params: NCLMessageCreatorParams{ + NodeID: "test-node", + ProtocolRouter: s.protocolRouter, + SubjectFn: func(nodeID string) string { return "" }, + }, + expectError: "subject function returned empty", + }, + } + + for _, tc := range tests { + s.Run(tc.name, func() { + creator, err := NewNCLMessageCreator(tc.params) + if tc.expectError != "" { + s.Error(err) + s.ErrorContains(err, tc.expectError) + s.Nil(creator) + } else { + s.NoError(err) + s.NotNil(creator) + } + }) + } +} + +func (s *NCLMessageCreatorTestSuite) TestMessageCreatorFactory() { + factory := NewNCLMessageCreatorFactory(NCLMessageCreatorFactoryParams{ + ProtocolRouter: s.protocolRouter, + SubjectFn: s.subjectFn, + }) + + creator, err := factory.CreateMessageCreator(context.Background(), 
"test-node") + s.Require().NoError(err) + s.NotNil(creator) +} + func (s *NCLMessageCreatorTestSuite) TestCreateMessage_InvalidObject() { msg, err := s.creator.CreateMessage(watcher.Event{ Object: "not an execution upsert", @@ -63,9 +141,11 @@ func (s *NCLMessageCreatorTestSuite) TestCreateMessage_InvalidObject() { } func (s *NCLMessageCreatorTestSuite) TestCreateMessage_NoStateChange() { + execution := mock.Execution() + execution.NodeID = "test-node" upsert := models.ExecutionUpsert{ - Previous: mock.Execution(), - Current: mock.Execution(), + Previous: execution, + Current: execution, } msg, err := s.creator.CreateMessage(createExecutionEvent(upsert)) @@ -73,11 +153,24 @@ func (s *NCLMessageCreatorTestSuite) TestCreateMessage_NoStateChange() { s.Nil(msg) } +func (s *NCLMessageCreatorTestSuite) TestCreateMessage_WrongNode() { + upsert := setupNewExecution( + models.ExecutionDesiredStatePending, + models.ExecutionStateNew, + ) + upsert.Current.NodeID = "different-node" + + msg, err := s.creator.CreateMessage(createExecutionEvent(upsert)) + s.NoError(err) + s.Nil(msg) +} + func (s *NCLMessageCreatorTestSuite) TestCreateMessage_UnsupportedProtocol() { upsert := setupNewExecution( models.ExecutionDesiredStatePending, models.ExecutionStateNew, ) + upsert.Current.NodeID = "test-node" // Mock node only supporting BProtocol s.nodeStore.EXPECT().Get(gomock.Any(), upsert.Current.NodeID).Return( @@ -113,6 +206,7 @@ func (s *NCLMessageCreatorTestSuite) TestCreateMessage_AskForBid() { tc.desiredState, models.ExecutionStateNew, ) + upsert.Current.NodeID = "test-node" // Mock node supporting NCL s.nodeStore.EXPECT().Get(gomock.Any(), upsert.Current.NodeID).Return( @@ -145,6 +239,7 @@ func (s *NCLMessageCreatorTestSuite) TestCreateMessage_BidAccepted() { models.ExecutionDesiredStateRunning, models.ExecutionStateAskForBidAccepted, ) + upsert.Current.NodeID = "test-node" // Mock node supporting NCL s.nodeStore.EXPECT().Get(gomock.Any(), upsert.Current.NodeID).Return( @@ -176,6 
+271,7 @@ func (s *NCLMessageCreatorTestSuite) TestCreateMessage_BidRejected() { models.ExecutionDesiredStateStopped, models.ExecutionStateAskForBidAccepted, ) + upsert.Current.NodeID = "test-node" // Mock node supporting NCL s.nodeStore.EXPECT().Get(gomock.Any(), upsert.Current.NodeID).Return( @@ -206,6 +302,7 @@ func (s *NCLMessageCreatorTestSuite) TestCreateMessage_Cancel() { models.ExecutionDesiredStateStopped, models.ExecutionStateRunning, ) + upsert.Current.NodeID = "test-node" // Mock node supporting NCL s.nodeStore.EXPECT().Get(gomock.Any(), upsert.Current.NodeID).Return( diff --git a/pkg/orchestrator/watchers/protocol_router_test.go b/pkg/orchestrator/watchers/protocol_router_test.go index 04b7ed1c85..6a29b81d11 100644 --- a/pkg/orchestrator/watchers/protocol_router_test.go +++ b/pkg/orchestrator/watchers/protocol_router_test.go @@ -101,29 +101,7 @@ func (s *ProtocolRouterTestSuite) TestPreferredProtocol_NodeStoreError() { s.Empty(protocol) } -func (s *ProtocolRouterTestSuite) TestPreferredProtocol_PreferBProtocol() { - s.T().Setenv(models.EnvPreferNCL, "false") - execution := mock.Execution() - - // Node supports both protocols - nodeState := models.NodeState{ - Info: models.NodeInfo{ - SupportedProtocols: []models.Protocol{ - models.ProtocolNCLV1, - models.ProtocolBProtocolV2, - }, - }, - } - - s.nodeStore.EXPECT().Get(s.ctx, execution.NodeID).Return(nodeState, nil) - - protocol, err := s.router.PreferredProtocol(s.ctx, execution) - s.NoError(err) - s.Equal(models.ProtocolBProtocolV2, protocol) -} - func (s *ProtocolRouterTestSuite) TestPreferredProtocol_PreferNCL() { - s.T().Setenv(models.EnvPreferNCL, "true") execution := mock.Execution() // Node supports both protocols diff --git a/pkg/transport/bprotocol/compute/transport.go b/pkg/transport/bprotocol/compute/transport.go index fc775ae9d0..f92d40898d 100644 --- a/pkg/transport/bprotocol/compute/transport.go +++ b/pkg/transport/bprotocol/compute/transport.go @@ -7,6 +7,7 @@ import ( "context" "errors" 
"fmt" + "strings" "github.com/nats-io/nats.go" "github.com/rs/zerolog/log" @@ -145,7 +146,7 @@ func (cm *ConnectionManager) Start(ctx context.Context) error { HeartbeatConfig: cm.config.HeartbeatConfig, }) if err = managementClient.RegisterNode(ctx); err != nil { - if errors.Is(err, bprotocol.ErrUpgradeAvailable) { + if strings.Contains(err.Error(), bprotocol.ErrUpgradeAvailable.Error()) { log.Info().Msg("Disabling bprotocol management client due to upgrade available") cm.Stop(ctx) return nil diff --git a/pkg/transport/nclprotocol/dispatcher/config_test.go b/pkg/transport/nclprotocol/dispatcher/config_test.go index cfe3ec8dbf..f48343daea 100644 --- a/pkg/transport/nclprotocol/dispatcher/config_test.go +++ b/pkg/transport/nclprotocol/dispatcher/config_test.go @@ -52,11 +52,6 @@ func (suite *ConfigTestSuite) TestConfigValidation() { mutate: func(c *Config) { *c = Config{} }, expectError: "must be positive", }, - { - name: "zero checkpoint interval", - mutate: func(c *Config) { c.CheckpointInterval = 0 }, - expectError: "CheckpointInterval must be positive", - }, { name: "zero stall timeout", mutate: func(c *Config) { c.StallTimeout = 0 }, diff --git a/pkg/transport/nclprotocol/orchestrator/dataplane.go b/pkg/transport/nclprotocol/orchestrator/dataplane.go index b12c120a76..07d35340ad 100644 --- a/pkg/transport/nclprotocol/orchestrator/dataplane.go +++ b/pkg/transport/nclprotocol/orchestrator/dataplane.go @@ -182,8 +182,11 @@ func (dp *DataPlane) setupDispatcher(ctx context.Context) error { } // Create message creator for this compute node - messageCreator := dp.config.MessageCreatorFactory.CreateMessageCreator( + messageCreator, err := dp.config.MessageCreatorFactory.CreateMessageCreator( ctx, dp.config.NodeID) + if err != nil { + return fmt.Errorf("create message creator: %w", err) + } // Disable checkpointing in dispatcher since we handle it elsewhere config := dp.config.DispatcherConfig diff --git a/pkg/transport/nclprotocol/types.go 
b/pkg/transport/nclprotocol/types.go index 5b2fdace58..358dc5093e 100644 --- a/pkg/transport/nclprotocol/types.go +++ b/pkg/transport/nclprotocol/types.go @@ -67,7 +67,7 @@ type MessageCreator interface { } type MessageCreatorFactory interface { - CreateMessageCreator(ctx context.Context, nodeID string) MessageCreator + CreateMessageCreator(ctx context.Context, nodeID string) (MessageCreator, error) } // GenerateMsgID Message ID generation helper From 59dda1ae75954fbf13b9894050f74c1da9bdc6ba Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 10:22:07 +0200 Subject: [PATCH 07/16] deprecate ResourceUpdateInterval --- pkg/config/defaults.go | 5 ++--- pkg/config/migrate.go | 5 ++--- pkg/config/types/compute.go | 3 +-- pkg/test/compute/resourcelimits_test.go | 2 +- pkg/test/devstack/oversubscription_test.go | 2 +- pkg/transport/bprotocol/compute/management_client.go | 2 +- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 07dcc6bc1c..d2dffc2554 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -47,9 +47,8 @@ var Default = types.Bacalhau{ Enabled: false, Orchestrators: []string{"nats://127.0.0.1:4222"}, Heartbeat: types.Heartbeat{ - InfoUpdateInterval: types.Minute, - ResourceUpdateInterval: 30 * types.Second, - Interval: 15 * types.Second, + InfoUpdateInterval: types.Minute, + Interval: 15 * types.Second, }, AllocatedCapacity: types.ResourceScaler{ CPU: "70%", diff --git a/pkg/config/migrate.go b/pkg/config/migrate.go index 5a0b5d26eb..217b4329fb 100644 --- a/pkg/config/migrate.go +++ b/pkg/config/migrate.go @@ -38,9 +38,8 @@ func MigrateV1(in v1types.BacalhauConfig) (types.Bacalhau, error) { }), Orchestrators: in.Node.Network.Orchestrators, Heartbeat: types.Heartbeat{ - Interval: types.Duration(in.Node.Compute.ControlPlaneSettings.HeartbeatFrequency), - ResourceUpdateInterval: types.Duration(in.Node.Compute.ControlPlaneSettings.ResourceUpdateFrequency), - 
InfoUpdateInterval: types.Duration(in.Node.Compute.ControlPlaneSettings.InfoUpdateFrequency), + Interval: types.Duration(in.Node.Compute.ControlPlaneSettings.HeartbeatFrequency), + InfoUpdateInterval: types.Duration(in.Node.Compute.ControlPlaneSettings.InfoUpdateFrequency), }, AllowListedLocalPaths: in.Node.AllowListedLocalPaths, Auth: types.ComputeAuth{Token: in.Node.Network.AuthSecret}, diff --git a/pkg/config/types/compute.go b/pkg/config/types/compute.go index 074a956818..a7e04d31c4 100644 --- a/pkg/config/types/compute.go +++ b/pkg/config/types/compute.go @@ -31,8 +31,7 @@ type ComputeTLS struct { type Heartbeat struct { // InfoUpdateInterval specifies the time between updates of non-resource information to the orchestrator. InfoUpdateInterval Duration `yaml:"InfoUpdateInterval,omitempty" json:"InfoUpdateInterval,omitempty"` - // ResourceUpdateInterval specifies the time between updates of resource information to the orchestrator. - // Deprecated: only used by legacy transport, will be removed in the future. + // Deprecated: use Interval instead ResourceUpdateInterval Duration `yaml:"ResourceUpdateInterval,omitempty" json:"ResourceUpdateInterval,omitempty"` // Interval specifies the time between heartbeat signals sent to the orchestrator. 
Interval Duration `yaml:"Interval,omitempty" json:"Interval,omitempty"` diff --git a/pkg/test/compute/resourcelimits_test.go b/pkg/test/compute/resourcelimits_test.go index f4635b271d..9ff3bbe455 100644 --- a/pkg/test/compute/resourcelimits_test.go +++ b/pkg/test/compute/resourcelimits_test.go @@ -305,7 +305,7 @@ func (suite *ComputeNodeResourceLimitsSuite) TestParallelGPU() { GPU: "1", }), Heartbeat: types.Heartbeat{ - ResourceUpdateInterval: types.Duration(50 * time.Millisecond), + Interval: types.Duration(50 * time.Millisecond), }, }, }) diff --git a/pkg/test/devstack/oversubscription_test.go b/pkg/test/devstack/oversubscription_test.go index 5a3af2889c..1416b701b5 100644 --- a/pkg/test/devstack/oversubscription_test.go +++ b/pkg/test/devstack/oversubscription_test.go @@ -67,7 +67,7 @@ func (s *OverSubscriptionTestSuite) setupStack(overSubscriptionFactor float64) { Compute: types.Compute{ AllocatedCapacity: types.ResourceScalerFromModelsResourceConfig(s.jobResources), Heartbeat: types.Heartbeat{ - ResourceUpdateInterval: types.Duration(s.resourceUpdateFrequency), + Interval: types.Duration(s.resourceUpdateFrequency), }, }, JobDefaults: types.JobDefaults{ diff --git a/pkg/transport/bprotocol/compute/management_client.go b/pkg/transport/bprotocol/compute/management_client.go index 916cdca5cc..80da6af607 100644 --- a/pkg/transport/bprotocol/compute/management_client.go +++ b/pkg/transport/bprotocol/compute/management_client.go @@ -113,7 +113,7 @@ func (m *ManagementClient) heartbeat(ctx context.Context, seq uint64) { func (m *ManagementClient) Start(ctx context.Context) { infoTicker := time.NewTicker(m.heartbeatConfig.InfoUpdateInterval.AsTimeDuration()) - resourceTicker := time.NewTicker(m.heartbeatConfig.ResourceUpdateInterval.AsTimeDuration()) + resourceTicker := time.NewTicker(m.heartbeatConfig.Interval.AsTimeDuration()) // The heartbeat ticker will be used to send heartbeats to the requester node and // should be configured to be about half of the heartbeat 
frequency of the requester node From 90e0fddccd9f5ec478e80c7cf218292cb170cfe6 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 10:24:31 +0200 Subject: [PATCH 08/16] Update pkg/transport/nclprotocol/compute/manager.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/transport/nclprotocol/compute/manager.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/transport/nclprotocol/compute/manager.go b/pkg/transport/nclprotocol/compute/manager.go index 07d81c6db4..0710310723 100644 --- a/pkg/transport/nclprotocol/compute/manager.go +++ b/pkg/transport/nclprotocol/compute/manager.go @@ -414,7 +414,8 @@ func (cm *ConnectionManager) checkConnectionHealth() { var reason string var unhealthy bool - if cm.GetHealth().LastSuccessfulHeartbeat.Before(heartbeatDeadline) { + health := cm.GetHealth() + if health.LastSuccessfulHeartbeat.Before(heartbeatDeadline) { reason = fmt.Sprintf("no heartbeat for %d intervals", cm.config.HeartbeatMissFactor) unhealthy = true } else if cm.natsConn.IsClosed() { From 4ee32882e18b592125c7a43da53f222dcd3275b1 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 10:26:33 +0200 Subject: [PATCH 09/16] reuse health --- pkg/transport/nclprotocol/compute/manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/transport/nclprotocol/compute/manager.go b/pkg/transport/nclprotocol/compute/manager.go index 0710310723..ef4362583f 100644 --- a/pkg/transport/nclprotocol/compute/manager.go +++ b/pkg/transport/nclprotocol/compute/manager.go @@ -425,7 +425,7 @@ func (cm *ConnectionManager) checkConnectionHealth() { if unhealthy { log.Warn(). - Time("lastHeartbeat", cm.GetHealth().LastSuccessfulHeartbeat). + Time("lastHeartbeat", health.LastSuccessfulHeartbeat). Time("deadline", heartbeatDeadline). Int("heartbeatMissFactor", cm.config.HeartbeatMissFactor). Str("reason", reason). 
From 2bb1bd43aac23cd883e4c2f0ac5f66baa1c84d82 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Tue, 10 Dec 2024 10:29:29 +0200 Subject: [PATCH 10/16] fix lint, spellcheck and swagger --- .cspell/custom-dictionary.txt | 1 + pkg/swagger/docs.go | 2 +- pkg/swagger/swagger.json | 2 +- pkg/transport/nclprotocol/compute/config.go | 2 ++ webui/lib/api/generated/types.gen.ts | 2 +- webui/lib/api/schema/swagger.json | 2 +- 6 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.cspell/custom-dictionary.txt b/.cspell/custom-dictionary.txt index 3355fb742b..d55c3f3e8e 100644 --- a/.cspell/custom-dictionary.txt +++ b/.cspell/custom-dictionary.txt @@ -20,6 +20,7 @@ boltdb booga boxo bprotocol +nclprotocol BRSNW BUCKETNAME buildx diff --git a/pkg/swagger/docs.go b/pkg/swagger/docs.go index 72b3a6d297..6c35df11a4 100644 --- a/pkg/swagger/docs.go +++ b/pkg/swagger/docs.go @@ -2560,7 +2560,7 @@ const docTemplate = `{ "type": "integer" }, "ResourceUpdateInterval": { - "description": "ResourceUpdateInterval specifies the time between updates of resource information to the orchestrator.", + "description": "Deprecated: use Interval instead", "type": "integer" } } diff --git a/pkg/swagger/swagger.json b/pkg/swagger/swagger.json index a773ef7b8f..52e2856112 100644 --- a/pkg/swagger/swagger.json +++ b/pkg/swagger/swagger.json @@ -2556,7 +2556,7 @@ "type": "integer" }, "ResourceUpdateInterval": { - "description": "ResourceUpdateInterval specifies the time between updates of resource information to the orchestrator.", + "description": "Deprecated: use Interval instead", "type": "integer" } } diff --git a/pkg/transport/nclprotocol/compute/config.go b/pkg/transport/nclprotocol/compute/config.go index a9b4f9ccbf..9ed9ceb296 100644 --- a/pkg/transport/nclprotocol/compute/config.go +++ b/pkg/transport/nclprotocol/compute/config.go @@ -78,6 +78,8 @@ func (c *Config) Validate() error { } // DefaultConfig returns a new Config with default values +// +//nolint:mnd func DefaultConfig() Config 
{ // defaults for heartbeatInterval and nodeInfoUpdateInterval are provided by BacalhauConfig, // and equal to 15 seconds and 1 minute respectively diff --git a/webui/lib/api/generated/types.gen.ts b/webui/lib/api/generated/types.gen.ts index 61cdb3a60c..6178d9f252 100644 --- a/webui/lib/api/generated/types.gen.ts +++ b/webui/lib/api/generated/types.gen.ts @@ -987,7 +987,7 @@ export type types_Heartbeat = { */ Interval?: number; /** - * ResourceUpdateInterval specifies the time between updates of resource information to the orchestrator. + * Deprecated: use Interval instead */ ResourceUpdateInterval?: number; }; diff --git a/webui/lib/api/schema/swagger.json b/webui/lib/api/schema/swagger.json index a773ef7b8f..52e2856112 100644 --- a/webui/lib/api/schema/swagger.json +++ b/webui/lib/api/schema/swagger.json @@ -2556,7 +2556,7 @@ "type": "integer" }, "ResourceUpdateInterval": { - "description": "ResourceUpdateInterval specifies the time between updates of resource information to the orchestrator.", + "description": "Deprecated: use Interval instead", "type": "integer" } } From 8c56ee71715eef6a3d1048ea549c775be129aef5 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 11 Dec 2024 10:03:37 +0200 Subject: [PATCH 11/16] add tests --- .cspell/custom-dictionary.txt | 4 +- go.mod | 1 + pkg/models/buildversion.go | 9 + pkg/models/node_info.go | 61 +++- pkg/models/node_info_test.go | 322 +++++++++++++++++ pkg/models/resource.go | 18 + pkg/node/compute.go | 4 +- pkg/test/utils/nats.go | 64 ++++ pkg/test/utils/utils.go | 47 --- pkg/test/utils/watcher.go | 78 ++++ pkg/transport/nclprotocol/compute/config.go | 2 +- .../nclprotocol/compute/config_test.go | 149 ++++++++ .../nclprotocol/compute/controlplane.go | 21 +- .../nclprotocol/compute/controlplane_test.go | 277 ++++++++++++++ .../nclprotocol/compute/dataplane.go | 35 +- .../nclprotocol/compute/dataplane_test.go | 242 +++++++++++++ .../nclprotocol/compute/health_tracker.go | 8 + .../compute/health_tracker_test.go | 123 
+++++++ pkg/transport/nclprotocol/compute/manager.go | 80 +++-- .../nclprotocol/compute/manager_test.go | 269 ++++++++++++++ .../dispatcher/dispatcher_e2e_test.go | 18 +- pkg/transport/nclprotocol/mocks.go | 5 +- .../orchestrator/dataplane_test.go | 294 +++++++++++++++ .../nclprotocol/test/control_plane.go | 340 ++++++++++++++++++ .../nclprotocol/test/message_creation.go | 95 +++++ .../nclprotocol/test/message_handling.go | 63 ++++ pkg/transport/nclprotocol/test/nodes.go | 40 +++ pkg/transport/nclprotocol/test/utils.go | 17 + 28 files changed, 2551 insertions(+), 135 deletions(-) create mode 100644 pkg/models/node_info_test.go create mode 100644 pkg/test/utils/nats.go create mode 100644 pkg/test/utils/watcher.go create mode 100644 pkg/transport/nclprotocol/compute/config_test.go create mode 100644 pkg/transport/nclprotocol/compute/controlplane_test.go create mode 100644 pkg/transport/nclprotocol/compute/dataplane_test.go create mode 100644 pkg/transport/nclprotocol/compute/health_tracker_test.go create mode 100644 pkg/transport/nclprotocol/compute/manager_test.go create mode 100644 pkg/transport/nclprotocol/orchestrator/dataplane_test.go create mode 100644 pkg/transport/nclprotocol/test/control_plane.go create mode 100644 pkg/transport/nclprotocol/test/message_creation.go create mode 100644 pkg/transport/nclprotocol/test/message_handling.go create mode 100644 pkg/transport/nclprotocol/test/nodes.go create mode 100644 pkg/transport/nclprotocol/test/utils.go diff --git a/.cspell/custom-dictionary.txt b/.cspell/custom-dictionary.txt index d55c3f3e8e..7f86071cc8 100644 --- a/.cspell/custom-dictionary.txt +++ b/.cspell/custom-dictionary.txt @@ -442,4 +442,6 @@ tlsca Lenf traefik bprotocolcompute -bprotocolorchestrator \ No newline at end of file +bprotocolorchestrator +nclprotocolcompute +ncltest \ No newline at end of file diff --git a/go.mod b/go.mod index a3a767ccaa..a39dfb3320 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/ghodss/yaml v1.0.0 
github.com/go-playground/validator/v10 v10.16.0 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/hashicorp/go-retryablehttp v0.7.7 diff --git a/pkg/models/buildversion.go b/pkg/models/buildversion.go index b3a59fbcc8..6f72ece968 100644 --- a/pkg/models/buildversion.go +++ b/pkg/models/buildversion.go @@ -14,3 +14,12 @@ type BuildVersionInfo struct { GOOS string `json:"GOOS" example:"linux"` GOARCH string `json:"GOARCH" example:"amd64"` } + +func (b *BuildVersionInfo) Copy() *BuildVersionInfo { + if b == nil { + return nil + } + newB := new(BuildVersionInfo) + *newB = *b + return newB +} diff --git a/pkg/models/node_info.go b/pkg/models/node_info.go index c8a4d34924..5db51d451f 100644 --- a/pkg/models/node_info.go +++ b/pkg/models/node_info.go @@ -4,8 +4,11 @@ package models import ( "context" "fmt" + "slices" "strings" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/exp/maps" ) @@ -107,10 +110,41 @@ func (n NodeInfo) IsComputeNode() bool { return n.NodeType == NodeTypeCompute } -// HasNodeInfoChanged returns true if the node info has changed compared to the previous call -// TODO: implement this function -func HasNodeInfoChanged(prev, current NodeInfo) bool { - return false +// Copy returns a deep copy of the NodeInfo +func (n *NodeInfo) Copy() *NodeInfo { + if n == nil { + return nil + } + cpy := new(NodeInfo) + *cpy = *n + + // Deep copy maps + cpy.Labels = maps.Clone(n.Labels) + cpy.SupportedProtocols = slices.Clone(n.SupportedProtocols) + cpy.ComputeNodeInfo = *n.ComputeNodeInfo.Copy() + cpy.BacalhauVersion = *n.BacalhauVersion.Copy() + return cpy +} + +// HasStaticConfigChanged returns true if the static/configuration aspects of this node +// have changed compared to other. It ignores dynamic operational fields like queue capacity +// and execution counts that change frequently during normal operation. 
+func (n NodeInfo) HasStaticConfigChanged(other NodeInfo) bool { + // Define which fields to ignore in the comparison + opts := []cmp.Option{ + cmpopts.IgnoreFields(ComputeNodeInfo{}, + "QueueUsedCapacity", + "AvailableCapacity", + "RunningExecutions", + "EnqueuedExecutions", + ), + // Ignore ordering in slices + cmpopts.SortSlices(func(a, b string) bool { return a < b }), + cmpopts.SortSlices(func(a, b Protocol) bool { return string(a) < string(b) }), + cmpopts.SortSlices(func(a, b GPU) bool { return a.Less(b) }), // Sort GPUs by all fields for stable comparison + } + + return !cmp.Equal(n, other, opts...) } // ComputeNodeInfo contains metadata about the current state and abilities of a compute node. Compute Nodes share @@ -126,3 +160,22 @@ type ComputeNodeInfo struct { RunningExecutions int `json:"RunningExecutions"` EnqueuedExecutions int `json:"EnqueuedExecutions"` } + +// Copy returns a deep copy of the ComputeNodeInfo +func (c *ComputeNodeInfo) Copy() *ComputeNodeInfo { + if c == nil { + return nil + } + cpy := new(ComputeNodeInfo) + *cpy = *c + + // Deep copy slices + cpy.ExecutionEngines = slices.Clone(c.ExecutionEngines) + cpy.Publishers = slices.Clone(c.Publishers) + cpy.StorageSources = slices.Clone(c.StorageSources) + cpy.MaxCapacity = *c.MaxCapacity.Copy() + cpy.QueueUsedCapacity = *c.QueueUsedCapacity.Copy() + cpy.AvailableCapacity = *c.AvailableCapacity.Copy() + cpy.MaxJobRequirements = *c.MaxJobRequirements.Copy() + return cpy +} diff --git a/pkg/models/node_info_test.go b/pkg/models/node_info_test.go new file mode 100644 index 0000000000..46dc330372 --- /dev/null +++ b/pkg/models/node_info_test.go @@ -0,0 +1,322 @@ +//go:build unit || !integration + +package models + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type NodeInfoTestSuite struct { + suite.Suite +} + +func TestNodeInfoTestSuite(t *testing.T) { + suite.Run(t, new(NodeInfoTestSuite)) +} + +func (s *NodeInfoTestSuite) TestHasNodeInfoChanged() { + 
baseNodeInfo := &NodeInfo{ + NodeID: "node-1", + NodeType: NodeTypeCompute, + Labels: map[string]string{ + "zone": "us-east-1", + "env": "prod", + }, + SupportedProtocols: []Protocol{ProtocolNCLV1, ProtocolBProtocolV2}, + BacalhauVersion: BuildVersionInfo{Major: "1", Minor: "0"}, + ComputeNodeInfo: ComputeNodeInfo{ + ExecutionEngines: []string{"docker", "wasm"}, + Publishers: []string{"ipfs"}, + StorageSources: []string{"s3", "ipfs"}, + MaxCapacity: Resources{ + CPU: 4, + Memory: 8192, + Disk: 100, + GPU: 1, + GPUs: []GPU{ + { + Index: 0, + Name: "Tesla T4", + Vendor: GPUVendorNvidia, + Memory: 16384, + PCIAddress: "0000:00:1e.0", + }, + }, + }, + MaxJobRequirements: Resources{ + CPU: 2, + Memory: 4096, + Disk: 50, + GPU: 0, + }, + // Dynamic fields that should be ignored + QueueUsedCapacity: Resources{ + CPU: 1, + Memory: 2048, + Disk: 25, + }, + AvailableCapacity: Resources{ + CPU: 3, + Memory: 6144, + Disk: 75, + }, + RunningExecutions: 2, + EnqueuedExecutions: 1, + }, + } + + testCases := []struct { + name string + changeFunction func(info *NodeInfo) *NodeInfo + expectChanged bool + }{ + { + name: "identical nodes", + changeFunction: func(info *NodeInfo) *NodeInfo { return info.Copy() }, + expectChanged: false, + }, + { + name: "different node ID", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.NodeID = "node-2" + return info + }, + expectChanged: true, + }, + { + name: "different node type", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.NodeType = NodeTypeRequester + return info + }, + expectChanged: true, + }, + { + name: "different labels - new label", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.Labels["new"] = "value" + return info + }, + expectChanged: true, + }, + { + name: "different labels - changed value", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.Labels["zone"] = "us-west-1" + return info + }, + expectChanged: true, 
+ }, + { + name: "different labels - removed label", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + delete(info.Labels, "zone") + return info + }, + expectChanged: true, + }, + { + name: "different protocols - added", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.SupportedProtocols = append(info.SupportedProtocols, Protocol("NewProtocol")) + return info + }, + expectChanged: true, + }, + { + name: "different version", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.BacalhauVersion.Minor = "1" + return info + }, + expectChanged: true, + }, + { + name: "different execution engines", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.ExecutionEngines = append(info.ComputeNodeInfo.ExecutionEngines, "kubernetes") + return info + }, + expectChanged: true, + }, + { + name: "different publishers", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.Publishers = append(info.ComputeNodeInfo.Publishers, "s3") + return info + }, + expectChanged: true, + }, + { + name: "different storage sources", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.StorageSources = []string{"s3"} + return info + }, + expectChanged: true, + }, + { + name: "different max capacity", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.MaxCapacity.CPU = 8 + return info + }, + expectChanged: true, + }, + { + name: "different max job requirements", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.MaxJobRequirements.Memory = 8192 + return info + }, + expectChanged: true, + }, + { + name: "changed queue capacity only", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.QueueUsedCapacity.CPU = 2 + return info + }, + expectChanged: false, + }, + { + name: 
"changed available capacity only", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.AvailableCapacity.Memory = 4096 + return info + }, + expectChanged: false, + }, + { + name: "changed running executions only", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.RunningExecutions = 5 + return info + }, + expectChanged: false, + }, + { + name: "changed enqueued executions only", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.EnqueuedExecutions = 3 + return info + }, + expectChanged: false, + }, + { + name: "multiple dynamic field changes only", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.RunningExecutions = 5 + info.ComputeNodeInfo.EnqueuedExecutions = 3 + info.ComputeNodeInfo.QueueUsedCapacity.CPU = 2 + info.ComputeNodeInfo.AvailableCapacity.Memory = 4096 + return info + }, + expectChanged: false, + }, + { + name: "same labels different order", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + // Recreate labels in different order + info.Labels = map[string]string{ + "env": "prod", + "zone": "us-east-1", + } + return info + }, + expectChanged: false, + }, + { + name: "same engines different order", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.ExecutionEngines = []string{"wasm", "docker"} + return info + }, + expectChanged: false, + }, + { + name: "same storage sources different order", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.StorageSources = []string{"ipfs", "s3"} + return info + }, + expectChanged: false, + }, + { + name: "same protocols different order", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.SupportedProtocols = []Protocol{ProtocolBProtocolV2, ProtocolNCLV1} + return info + }, + expectChanged: false, + }, + { + 
name: "different max capacity GPUs", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.MaxCapacity.GPUs = []GPU{ + {Index: 0, Name: "RTX 3080", Vendor: GPUVendorNvidia, Memory: 10240, PCIAddress: "0000:00:1e.0"}, // Different GPU spec + } + return info + }, + expectChanged: true, + }, + { + name: "same GPUs different order", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + // Set the exact same GPUs but in different order + info.ComputeNodeInfo.MaxCapacity.GPUs = []GPU{ + {Index: 0, Name: "Tesla T4", Vendor: GPUVendorNvidia, Memory: 16384, PCIAddress: "0000:00:1e.0"}, + } + return info + }, + expectChanged: false, + }, + { + name: "changed GPU count only", + changeFunction: func(info *NodeInfo) *NodeInfo { + info = info.Copy() + info.ComputeNodeInfo.MaxCapacity.GPU = 2 + return info + }, + expectChanged: true, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + current := tc.changeFunction(baseNodeInfo) + changed := baseNodeInfo.HasStaticConfigChanged(*current) + + if tc.expectChanged { + s.True(changed, "Expected node info to have changed") + } else { + s.False(changed, "Expected node info to remain unchanged") + } + }) + } +} diff --git a/pkg/models/resource.go b/pkg/models/resource.go index a4e5c313cd..23a0d04e7b 100644 --- a/pkg/models/resource.go +++ b/pkg/models/resource.go @@ -126,6 +126,24 @@ type GPU struct { PCIAddress string } +// Less compares this GPU with another for sorting/ordering purposes +// The comparison order is: Index, Name, Vendor, Memory, PCIAddress +func (g GPU) Less(other GPU) bool { + if g.Index != other.Index { + return g.Index < other.Index + } + if g.Name != other.Name { + return g.Name < other.Name + } + if g.Vendor != other.Vendor { + return g.Vendor < other.Vendor + } + if g.Memory != other.Memory { + return g.Memory < other.Memory + } + return g.PCIAddress < other.PCIAddress +} + type Resources struct { // CPU units CPU float64 
`json:"CPU,omitempty"` diff --git a/pkg/node/compute.go b/pkg/node/compute.go index 753e39f821..e37472e661 100644 --- a/pkg/node/compute.go +++ b/pkg/node/compute.go @@ -28,7 +28,7 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/publisher" "github.com/bacalhau-project/bacalhau/pkg/storage" bprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/bprotocol/compute" - transportcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" + nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" ) @@ -197,7 +197,7 @@ func NewComputeNode( } // connection manager - connectionManager, err := transportcompute.NewConnectionManager(transportcompute.Config{ + connectionManager, err := nclprotocolcompute.NewConnectionManager(nclprotocolcompute.Config{ NodeID: cfg.NodeID, ClientFactory: clientFactory, NodeInfoProvider: nodeInfoProvider, diff --git a/pkg/test/utils/nats.go b/pkg/test/utils/nats.go new file mode 100644 index 0000000000..47a7f23205 --- /dev/null +++ b/pkg/test/utils/nats.go @@ -0,0 +1,64 @@ +package testutils + +import ( + "net/url" + "strconv" + "testing" + "time" + + natsserver "github.com/nats-io/nats-server/v2/server" + natstest "github.com/nats-io/nats-server/v2/test" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/require" + + "github.com/bacalhau-project/bacalhau/pkg/lib/network" +) + +// startNatsOnPort will start a NATS server on a specific port and return the server instance +func startNatsOnPort(t *testing.T, port int) *natsserver.Server { + t.Helper() + opts := &natstest.DefaultTestOptions + opts.Port = port + + natsServer := natstest.RunServer(opts) + return natsServer +} + +func StartNatsServer(t *testing.T) *natsserver.Server { + t.Helper() + port, err := network.GetFreePort() + require.NoError(t, err) + + return startNatsOnPort(t, port) +} + +func CreateNatsClient(t *testing.T, 
url string) *nats.Conn { + nc, err := nats.Connect(url, + nats.ReconnectBufSize(-1), // disable reconnect buffer so client fails fast if disconnected + nats.ReconnectWait(200*time.Millisecond), //nolint:mnd // reduce reconnect wait to fail fast in tests + nats.FlusherTimeout(100*time.Millisecond), //nolint:mnd // reduce flusher timeout to speed up tests + ) + require.NoError(t, err) + return nc +} + +// StartNats will start a NATS server on a random port and return a server and client instances +func StartNats(t *testing.T) (*natsserver.Server, *nats.Conn) { + natsServer := StartNatsServer(t) + return natsServer, CreateNatsClient(t, natsServer.ClientURL()) +} + +// RestartNatsServer will restart the NATS server and return a new server and client using the same port +func RestartNatsServer(t *testing.T, natsServer *natsserver.Server) (*natsserver.Server, *nats.Conn) { + t.Helper() + natsServer.Shutdown() + + u, err := url.Parse(natsServer.ClientURL()) + require.NoError(t, err, "Failed to parse NATS server URL %s", natsServer.ClientURL()) + + port, err := strconv.Atoi(u.Port()) + require.NoError(t, err, "Failed to convert port %s to int", u.Port()) + + natsServer = startNatsOnPort(t, port) + return natsServer, CreateNatsClient(t, natsServer.ClientURL()) +} diff --git a/pkg/test/utils/utils.go b/pkg/test/utils/utils.go index b9e2d7b691..6b365592ac 100644 --- a/pkg/test/utils/utils.go +++ b/pkg/test/utils/utils.go @@ -2,20 +2,12 @@ package testutils import ( "context" - "net/url" "regexp" - "strconv" "testing" - "time" - natsserver "github.com/nats-io/nats-server/v2/server" - natstest "github.com/nats-io/nats-server/v2/test" - - "github.com/nats-io/nats.go" "github.com/stretchr/testify/require" "github.com/bacalhau-project/bacalhau/pkg/config/types" - "github.com/bacalhau-project/bacalhau/pkg/lib/network" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/publicapi/apimodels" clientv2 
"github.com/bacalhau-project/bacalhau/pkg/publicapi/client/v2" @@ -54,42 +46,3 @@ func MustHaveIPFS(t testing.TB, cfg types.Bacalhau) { func IsIPFSEnabled(ipfsConnect string) bool { return ipfsConnect != "" } - -// startNatsOnPort will start a NATS server on a specific port and return a server and client instances -func startNatsOnPort(t *testing.T, port int) (*natsserver.Server, *nats.Conn) { - t.Helper() - opts := &natstest.DefaultTestOptions - opts.Port = port - - natsServer := natstest.RunServer(opts) - nc, err := nats.Connect(natsServer.ClientURL(), - nats.ReconnectBufSize(-1), // disable reconnect buffer so client fails fast if disconnected - nats.ReconnectWait(200*time.Millisecond), //nolint:mnd // reduce reconnect wait to fail fast in tests - nats.FlusherTimeout(100*time.Millisecond), //nolint:mnd // reduce flusher timeout to speed up tests - ) - require.NoError(t, err) - return natsServer, nc -} - -// StartNats will start a NATS server on a random port and return a server and client instances -func StartNats(t *testing.T) (*natsserver.Server, *nats.Conn) { - t.Helper() - port, err := network.GetFreePort() - require.NoError(t, err) - - return startNatsOnPort(t, port) -} - -// RestartNatsServer will restart the NATS server and return a new server and client using the same port -func RestartNatsServer(t *testing.T, natsServer *natsserver.Server) (*natsserver.Server, *nats.Conn) { - t.Helper() - natsServer.Shutdown() - - u, err := url.Parse(natsServer.ClientURL()) - require.NoError(t, err, "Failed to parse NATS server URL %s", natsServer.ClientURL()) - - port, err := strconv.Atoi(u.Port()) - require.NoError(t, err, "Failed to convert port %s to int", u.Port()) - - return startNatsOnPort(t, port) -} diff --git a/pkg/test/utils/watcher.go b/pkg/test/utils/watcher.go new file mode 100644 index 0000000000..f4dccc3716 --- /dev/null +++ b/pkg/test/utils/watcher.go @@ -0,0 +1,78 @@ +package testutils + +import ( + "errors" + "reflect" + "testing" + + 
"github.com/stretchr/testify/require" + + "github.com/bacalhau-project/bacalhau/pkg/compute" + "github.com/bacalhau-project/bacalhau/pkg/jobstore" + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + boltdb_watcher "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/boltdb" + watchertest "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/test" + "github.com/bacalhau-project/bacalhau/pkg/models" +) + +const ( + TypeString = "string" +) + +func CreateComputeEventStore(t *testing.T) watcher.EventStore { + eventObjectSerializer := watcher.NewJSONSerializer() + err := errors.Join( + eventObjectSerializer.RegisterType(compute.EventObjectExecutionUpsert, reflect.TypeOf(models.ExecutionUpsert{})), + eventObjectSerializer.RegisterType(compute.EventObjectExecutionEvent, reflect.TypeOf(models.Event{})), + ) + require.NoError(t, err) + + database := watchertest.CreateBoltDB(t) + + eventStore, err := boltdb_watcher.NewEventStore(database, + boltdb_watcher.WithEventsBucket("events"), + boltdb_watcher.WithCheckpointBucket("checkpoints"), + boltdb_watcher.WithEventSerializer(eventObjectSerializer), + ) + require.NoError(t, err) + return eventStore +} + +func CreateJobEventStore(t *testing.T) watcher.EventStore { + eventObjectSerializer := watcher.NewJSONSerializer() + err := errors.Join( + eventObjectSerializer.RegisterType(jobstore.EventObjectExecutionUpsert, reflect.TypeOf(models.ExecutionUpsert{})), + eventObjectSerializer.RegisterType(jobstore.EventObjectEvaluation, reflect.TypeOf(models.Event{})), + ) + require.NoError(t, err) + + database := watchertest.CreateBoltDB(t) + + eventStore, err := boltdb_watcher.NewEventStore(database, + boltdb_watcher.WithEventsBucket("events"), + boltdb_watcher.WithCheckpointBucket("checkpoints"), + boltdb_watcher.WithEventSerializer(eventObjectSerializer), + ) + require.NoError(t, err) + return eventStore +} + +func CreateStringEventStore(t *testing.T) (watcher.EventStore, 
*envelope.Registry) { + eventObjectSerializer := watcher.NewJSONSerializer() + require.NoError(t, eventObjectSerializer.RegisterType(TypeString, reflect.TypeOf(""))) + + database := watchertest.CreateBoltDB(t) + + eventStore, err := boltdb_watcher.NewEventStore(database, + boltdb_watcher.WithEventsBucket("events"), + boltdb_watcher.WithCheckpointBucket("checkpoints"), + boltdb_watcher.WithEventSerializer(eventObjectSerializer), + ) + require.NoError(t, err) + + registry := envelope.NewRegistry() + require.NoError(t, registry.Register(TypeString, "")) + + return eventStore, registry +} diff --git a/pkg/transport/nclprotocol/compute/config.go b/pkg/transport/nclprotocol/compute/config.go index 9ed9ceb296..59e8f1f943 100644 --- a/pkg/transport/nclprotocol/compute/config.go +++ b/pkg/transport/nclprotocol/compute/config.go @@ -96,7 +96,7 @@ func DefaultConfig() Config { } } -func (c *Config) setDefaults() { +func (c *Config) SetDefaults() { defaults := DefaultConfig() if c.HeartbeatMissFactor == 0 { c.HeartbeatMissFactor = defaults.HeartbeatMissFactor diff --git a/pkg/transport/nclprotocol/compute/config_test.go b/pkg/transport/nclprotocol/compute/config_test.go new file mode 100644 index 0000000000..629fcae0c7 --- /dev/null +++ b/pkg/transport/nclprotocol/compute/config_test.go @@ -0,0 +1,149 @@ +//go:build unit || !integration + +package compute_test + +import ( + "context" + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + + "github.com/bacalhau-project/bacalhau/pkg/lib/backoff" + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + natsutil "github.com/bacalhau-project/bacalhau/pkg/nats" + testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + nclprotocolcompute 
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" + ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test" +) + +type ConfigTestSuite struct { + suite.Suite + ctrl *gomock.Controller + nodeInfoProvider *ncltest.MockNodeInfoProvider + messageHandler *ncl.MockMessageHandler + checkpointer *nclprotocol.MockCheckpointer +} + +func TestConfigTestSuite(t *testing.T) { + suite.Run(t, new(ConfigTestSuite)) +} + +func (s *ConfigTestSuite) SetupTest() { + s.ctrl = gomock.NewController(s.T()) + s.nodeInfoProvider = ncltest.NewMockNodeInfoProvider() + s.messageHandler = ncl.NewMockMessageHandler(s.ctrl) + s.checkpointer = nclprotocol.NewMockCheckpointer(s.ctrl) +} + +func (s *ConfigTestSuite) TearDownTest() { + s.ctrl.Finish() +} + +func (s *ConfigTestSuite) getValidConfig() nclprotocolcompute.Config { + return nclprotocolcompute.Config{ + NodeID: "test-node", + ClientFactory: natsutil.ClientFactoryFunc(func(ctx context.Context) (*nats.Conn, error) { + return nil, nil + }), + NodeInfoProvider: s.nodeInfoProvider, + MessageSerializer: envelope.NewSerializer(), + MessageRegistry: nclprotocol.MustCreateMessageRegistry(), + HeartbeatInterval: time.Second, + HeartbeatMissFactor: 5, + NodeInfoUpdateInterval: time.Second, + RequestTimeout: time.Second, + ReconnectInterval: time.Second, + ReconnectBackoff: backoff.NewExponential(time.Second, 2*time.Second), + DataPlaneMessageHandler: s.messageHandler, + DataPlaneMessageCreator: &ncltest.MockMessageCreator{}, + EventStore: testutils.CreateComputeEventStore(s.T()), + LogStreamServer: &ncltest.MockLogStreamServer{}, + Checkpointer: s.checkpointer, + CheckpointInterval: time.Second, + Clock: clock.New(), + DispatcherConfig: dispatcher.DefaultConfig(), + } +} + +func (s *ConfigTestSuite) TestValidation() { + testCases := []struct { + name string + mutate func(*nclprotocolcompute.Config) + expectError string + }{ + { + 
name: "valid config", + mutate: func(c *nclprotocolcompute.Config) {}, + expectError: "", + }, + { + name: "missing node ID", + mutate: func(c *nclprotocolcompute.Config) { c.NodeID = "" }, + expectError: "nodeID cannot be blank", + }, + { + name: "missing required dependencies", + mutate: func(c *nclprotocolcompute.Config) { c.ClientFactory = nil; c.NodeInfoProvider = nil }, + expectError: "cannot be nil", + }, + { + name: "invalid intervals", + mutate: func(c *nclprotocolcompute.Config) { + c.HeartbeatInterval = 0 + c.NodeInfoUpdateInterval = 0 + }, + expectError: "must be positive", + }, + { + name: "invalid dispatcher config", + mutate: func(c *nclprotocolcompute.Config) { + c.DispatcherConfig.StallTimeout = time.Second + c.DispatcherConfig.StallCheckInterval = 2 * time.Second + }, + expectError: "StallCheckInterval must be less than StallTimeout", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + cfg := s.getValidConfig() + tc.mutate(&cfg) + err := cfg.Validate() + + if tc.expectError == "" { + s.NoError(err) + } else { + s.Error(err) + s.Contains(err.Error(), tc.expectError) + } + }) + } +} + +func (s *ConfigTestSuite) TestSetDefaults() { + emptyConfig := nclprotocolcompute.Config{} + emptyConfig.SetDefaults() + defaults := nclprotocolcompute.DefaultConfig() + + s.Equal(defaults.HeartbeatMissFactor, emptyConfig.HeartbeatMissFactor) + s.Equal(defaults.RequestTimeout, emptyConfig.RequestTimeout) + s.NotNil(emptyConfig.MessageSerializer) + s.NotNil(emptyConfig.ReconnectBackoff) + s.NotEqual(dispatcher.Config{}, emptyConfig.DispatcherConfig) + + // Existing values should not be overwritten + customConfig := nclprotocolcompute.Config{ + HeartbeatMissFactor: 10, + RequestTimeout: 20 * time.Second, + } + customConfig.SetDefaults() + s.Equal(10, customConfig.HeartbeatMissFactor) + s.Equal(20*time.Second, customConfig.RequestTimeout) +} diff --git a/pkg/transport/nclprotocol/compute/controlplane.go 
b/pkg/transport/nclprotocol/compute/controlplane.go index 18e9ab1722..2634d85d6f 100644 --- a/pkg/transport/nclprotocol/compute/controlplane.go +++ b/pkg/transport/nclprotocol/compute/controlplane.go @@ -76,6 +76,8 @@ func (cp *ControlPlane) Start(ctx context.Context) error { return fmt.Errorf("control plane already running") } + cp.latestNodeInfo = cp.cfg.NodeInfoProvider.GetNodeInfo(ctx) + cp.wg.Add(1) go cp.run(ctx) @@ -89,15 +91,9 @@ func (cp *ControlPlane) run(ctx context.Context) { defer cp.wg.Done() // Initialize timers for periodic operations - heartbeat := time.NewTimer(cp.cfg.HeartbeatInterval) - nodeInfo := time.NewTimer(cp.cfg.NodeInfoUpdateInterval) - checkpoint := time.NewTimer(cp.cfg.CheckpointInterval) - - defer func() { - heartbeat.Stop() - nodeInfo.Stop() - checkpoint.Stop() - }() + heartbeat := time.NewTicker(cp.cfg.HeartbeatInterval) + nodeInfo := time.NewTicker(cp.cfg.NodeInfoUpdateInterval) + checkpoint := time.NewTicker(cp.cfg.CheckpointInterval) for { select { @@ -110,19 +106,14 @@ func (cp *ControlPlane) run(ctx context.Context) { if err := cp.heartbeat(ctx); err != nil { log.Error().Err(err).Msg("Failed to send heartbeat") } - heartbeat.Reset(cp.cfg.HeartbeatInterval) - case <-nodeInfo.C: if err := cp.updateNodeInfo(ctx); err != nil { log.Error().Err(err).Msg("Failed to update node info") } - nodeInfo.Reset(cp.cfg.NodeInfoUpdateInterval) - case <-checkpoint.C: if err := cp.checkpointProgress(ctx); err != nil { log.Error().Err(err).Msg("Failed to checkpoint progress") } - checkpoint.Reset(cp.cfg.CheckpointInterval) } } } @@ -173,7 +164,7 @@ func (cp *ControlPlane) updateNodeInfo(ctx context.Context) error { // Only send updates when node info has changed prevNodeInfo := cp.latestNodeInfo cp.latestNodeInfo = cp.cfg.NodeInfoProvider.GetNodeInfo(ctx) - if !models.HasNodeInfoChanged(prevNodeInfo, cp.latestNodeInfo) { + if !prevNodeInfo.HasStaticConfigChanged(cp.latestNodeInfo) { return nil } diff --git 
a/pkg/transport/nclprotocol/compute/controlplane_test.go b/pkg/transport/nclprotocol/compute/controlplane_test.go new file mode 100644 index 0000000000..13b86ba1ab --- /dev/null +++ b/pkg/transport/nclprotocol/compute/controlplane_test.go @@ -0,0 +1,277 @@ +//go:build unit || !integration + +package compute_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" + ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test" +) + +type ControlPlaneTestSuite struct { + suite.Suite + ctrl *gomock.Controller + ctx context.Context + cancel context.CancelFunc + clock clock.Clock + requester *ncl.MockPublisher + nodeInfoProvider *ncltest.MockNodeInfoProvider + healthTracker *nclprotocolcompute.HealthTracker + checkpointer *ncltest.MockCheckpointer + seqTracker *nclprotocol.SequenceTracker + config nclprotocolcompute.Config +} + +func TestControlPlaneTestSuite(t *testing.T) { + suite.Run(t, new(ControlPlaneTestSuite)) +} + +func (s *ControlPlaneTestSuite) SetupTest() { + s.ctrl = gomock.NewController(s.T()) + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.clock = clock.New() // tickers didn't work properly with mock clock + + // Create mocks + s.requester = ncl.NewMockPublisher(s.ctrl) + s.nodeInfoProvider = ncltest.NewMockNodeInfoProvider() + s.checkpointer = ncltest.NewMockCheckpointer() + + // Create real components + s.healthTracker = nclprotocolcompute.NewHealthTracker(s.clock) + s.seqTracker = nclprotocol.NewSequenceTracker() + + // Setup basic config with short intervals for testing + s.config = 
nclprotocolcompute.Config{ + NodeID: "test-node", + NodeInfoProvider: s.nodeInfoProvider, + Checkpointer: s.checkpointer, + HeartbeatInterval: 50 * time.Millisecond, + NodeInfoUpdateInterval: 100 * time.Millisecond, + CheckpointInterval: 150 * time.Millisecond, + RequestTimeout: 50 * time.Millisecond, + Clock: s.clock, + } +} + +func (s *ControlPlaneTestSuite) createControlPlane( + heartbeatInterval time.Duration, + nodeInfoInterval time.Duration, + checkpointInterval time.Duration, +) *nclprotocolcompute.ControlPlane { + config := nclprotocolcompute.Config{ + NodeID: "test-node", + NodeInfoProvider: s.nodeInfoProvider, + Checkpointer: s.checkpointer, + HeartbeatInterval: heartbeatInterval, + NodeInfoUpdateInterval: nodeInfoInterval, + CheckpointInterval: checkpointInterval, + RequestTimeout: 50 * time.Millisecond, + Clock: s.clock, + } + + cp, err := nclprotocolcompute.NewControlPlane(nclprotocolcompute.ControlPlaneParams{ + Config: config, + Requester: s.requester, + HealthTracker: s.healthTracker, + IncomingSeqTracker: s.seqTracker, + CheckpointName: "test-checkpoint", + }) + s.Require().NoError(err) + return cp +} + +func (s *ControlPlaneTestSuite) TearDownTest() { + s.cancel() + s.ctrl.Finish() +} + +func (s *ControlPlaneTestSuite) TestLifecycle() { + controlPlane := s.createControlPlane( + 50*time.Millisecond, + 100*time.Millisecond, + 150*time.Millisecond) + defer s.Require().NoError(controlPlane.Stop(s.ctx)) + + testCases := []struct { + name string + operation func() error + expectError bool + errorMsg string + }{ + { + name: "first start succeeds", + operation: func() error { return controlPlane.Start(s.ctx) }, + expectError: false, + }, + { + name: "second start fails", + operation: func() error { return controlPlane.Start(s.ctx) }, + expectError: true, + errorMsg: "already running", + }, + { + name: "first stop succeeds", + operation: func() error { return controlPlane.Stop(s.ctx) }, + expectError: false, + }, + { + name: "second stop is noop", + 
operation: func() error { return controlPlane.Stop(s.ctx) }, + expectError: false, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + err := tc.operation() + if tc.expectError { + s.Require().Error(err) + s.Require().Contains(err.Error(), tc.errorMsg) + } else { + s.Require().NoError(err) + } + }) + } +} + +func (s *ControlPlaneTestSuite) TestHeartbeat() { + // Create control plane with only heartbeat enabled + controlPlane := s.createControlPlane( + 50*time.Millisecond, // heartbeat + 1*time.Hour, // node info - disabled + 1*time.Hour, // checkpoint - disabled + ) + defer s.Require().NoError(controlPlane.Stop(s.ctx)) + + nodeInfo := s.nodeInfoProvider.GetNodeInfo(s.ctx) + heartbeatMsg := envelope.NewMessage(messages.HeartbeatRequest{ + NodeID: nodeInfo.NodeID, + AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, + QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, + LastOrchestratorSeqNum: s.seqTracker.GetLastSeqNum(), + }).WithMetadataValue(envelope.KeyMessageType, messages.HeartbeatRequestMessageType) + + s.requester.EXPECT(). + Request(gomock.Any(), ncl.NewPublishRequest(heartbeatMsg)). + Return(envelope.NewMessage(messages.HeartbeatResponse{}), nil). 
+ Times(1) + + s.Require().Zero(s.healthTracker.GetHealth().LastSuccessfulHeartbeat) + + s.Require().NoError(controlPlane.Start(s.ctx)) + time.Sleep(50 * time.Millisecond) + + s.Require().Eventually(func() bool { + health := s.healthTracker.GetHealth() + return !health.LastSuccessfulHeartbeat.IsZero() + }, 100*time.Millisecond, 10*time.Millisecond, "Heartbeat did not succeed") +} + +func (s *ControlPlaneTestSuite) TestNodeInfoUpdate() { + // Create control plane with only checkpointing enabled + controlPlane := s.createControlPlane( + 1*time.Hour, // heartbeat - disabled + 50*time.Millisecond, // node info + 1*time.Hour, // checkpoint - disabled + ) + defer s.Require().NoError(controlPlane.Stop(s.ctx)) + + // Start control plane + s.Require().NoError(controlPlane.Start(s.ctx)) + + // update node info after start + oldInfo := s.nodeInfoProvider.GetNodeInfo(s.ctx) + newInfo := *oldInfo.Copy() + newInfo.Labels["new"] = "value" + s.nodeInfoProvider.SetNodeInfo(newInfo) + + // expect a node info update + updateMsg := envelope.NewMessage(messages.UpdateNodeInfoRequest{ + NodeInfo: newInfo, + }).WithMetadataValue(envelope.KeyMessageType, messages.NodeInfoUpdateRequestMessageType) + + s.requester.EXPECT(). + Request(gomock.Any(), ncl.NewPublishRequest(updateMsg)). + Return(envelope.NewMessage(messages.UpdateNodeInfoResponse{}), nil). 
+ Times(1) + + // Advance clock to trigger update + time.Sleep(s.config.NodeInfoUpdateInterval) + time.Sleep(50 * time.Millisecond) // Allow goroutine to process + + // Verify health tracker state + health := s.healthTracker.GetHealth() + s.Require().NotZero(health.LastSuccessfulUpdate) + + // Verify no more updates are sent + time.Sleep(s.config.NodeInfoUpdateInterval) + time.Sleep(50 * time.Millisecond) // Allow goroutine to process +} + +func (s *ControlPlaneTestSuite) TestCheckpointing() { + // Create control plane with only checkpointing enabled + controlPlane := s.createControlPlane( + 1*time.Hour, // heartbeat - disabled + 1*time.Hour, // node info - disabled + 50*time.Millisecond, // checkpoint + ) + defer s.Require().NoError(controlPlane.Stop(s.ctx)) + + // Set sequence number to checkpoint + s.seqTracker.UpdateLastSeqNum(42) + + // Track checkpoint calls + var checkpointCalled bool + s.checkpointer.OnCheckpointSet(func(name string, value uint64) { + s.Equal("test-checkpoint", name) + s.Equal(uint64(42), value) + checkpointCalled = true + }) + + s.Require().NoError(controlPlane.Start(s.ctx)) + // Wait for checkpoint to be called + s.Eventually(func() bool { + return checkpointCalled + }, 100*time.Millisecond, 10*time.Millisecond) + + // Verify checkpoint was stored + value, err := s.checkpointer.GetStoredCheckpoint("test-checkpoint") + s.Require().NoError(err) + s.Equal(uint64(42), value) +} + +func (s *ControlPlaneTestSuite) TestErrorHandling() { + // Create control plane with only heartbeat enabled + controlPlane := s.createControlPlane( + 50*time.Millisecond, // heartbeat + 1*time.Hour, // node info - disabled + 1*time.Hour, // checkpoint - disabled + ) + defer s.Require().NoError(controlPlane.Stop(s.ctx)) + + // Setup error response + s.requester.EXPECT(). + Request(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("network error")). 
+ Times(1) + + // Start control plane + s.Require().NoError(controlPlane.Start(s.ctx)) + time.Sleep(70 * time.Millisecond) + + // Verify health tracker reflects failure + health := s.healthTracker.GetHealth() + s.Require().Zero(health.LastSuccessfulHeartbeat) +} diff --git a/pkg/transport/nclprotocol/compute/dataplane.go b/pkg/transport/nclprotocol/compute/dataplane.go index 78517206c1..9980f939f2 100644 --- a/pkg/transport/nclprotocol/compute/dataplane.go +++ b/pkg/transport/nclprotocol/compute/dataplane.go @@ -31,8 +31,8 @@ type DataPlane struct { // Core messaging components Client *nats.Conn // NATS connection for messaging - publisher ncl.OrderedPublisher // Handles ordered message publishing - dispatcher *dispatcher.Dispatcher // Manages event watching and dispatch + Publisher ncl.OrderedPublisher // Handles ordered message publishing + Dispatcher *dispatcher.Dispatcher // Manages event watching and dispatch // Sequence tracking lastReceivedSeqNum uint64 // Last sequence number received from orchestrator @@ -78,6 +78,10 @@ func (dp *DataPlane) Start(ctx context.Context) error { return fmt.Errorf("data plane already running") } + if ctx.Err() != nil { + return ctx.Err() + } + var err error defer func() { if err != nil { @@ -97,7 +101,7 @@ func (dp *DataPlane) Start(ctx context.Context) error { return fmt.Errorf("failed to set up log stream handler: %w", err) } // Initialize ordered publisher for reliable message delivery - dp.publisher, err = ncl.NewOrderedPublisher(dp.Client, ncl.OrderedPublisherConfig{ + dp.Publisher, err = ncl.NewOrderedPublisher(dp.Client, ncl.OrderedPublisherConfig{ Name: dp.config.NodeID, MessageRegistry: dp.config.MessageRegistry, MessageSerializer: dp.config.MessageSerializer, @@ -121,8 +125,8 @@ func (dp *DataPlane) Start(ctx context.Context) error { } // Initialize dispatcher to handle event watching and publishing - dp.dispatcher, err = dispatcher.New( - dp.publisher, + dp.Dispatcher, err = dispatcher.New( + dp.Publisher, 
dispatcherWatcher, dp.config.DataPlaneMessageCreator, dp.config.DispatcherConfig, @@ -132,7 +136,7 @@ func (dp *DataPlane) Start(ctx context.Context) error { } // Start the dispatcher - if err = dp.dispatcher.Start(ctx); err != nil { + if err = dp.Dispatcher.Start(ctx); err != nil { return fmt.Errorf("failed to start dispatcher: %w", err) } @@ -157,25 +161,32 @@ func (dp *DataPlane) Stop(ctx context.Context) error { return dp.cleanup(ctx) } +// IsRunning returns true if the data plane is currently running. +func (dp *DataPlane) IsRunning() bool { + dp.mu.RLock() + defer dp.mu.RUnlock() + return dp.running +} + // cleanup handles the orderly shutdown of data plane components. // It ensures resources are released in the correct order and collects any errors. func (dp *DataPlane) cleanup(ctx context.Context) error { var errs error // Stop dispatcher first to prevent new messages - if dp.dispatcher != nil { - if err := dp.dispatcher.Stop(ctx); err != nil { + if dp.Dispatcher != nil { + if err := dp.Dispatcher.Stop(ctx); err != nil { errs = errors.Join(errs, err) } - dp.dispatcher = nil + dp.Dispatcher = nil } // Then close the publisher - if dp.publisher != nil { - if err := dp.publisher.Close(ctx); err != nil { + if dp.Publisher != nil { + if err := dp.Publisher.Close(ctx); err != nil { errs = errors.Join(errs, err) } - dp.publisher = nil + dp.Publisher = nil } if errs != nil { diff --git a/pkg/transport/nclprotocol/compute/dataplane_test.go b/pkg/transport/nclprotocol/compute/dataplane_test.go new file mode 100644 index 0000000000..6470244a3b --- /dev/null +++ b/pkg/transport/nclprotocol/compute/dataplane_test.go @@ -0,0 +1,242 @@ +//go:build unit || !integration + +package compute_test + +import ( + "context" + "testing" + "time" + + natsserver "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/suite" + + "github.com/bacalhau-project/bacalhau/pkg/compute" + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" 
+ "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + "github.com/bacalhau-project/bacalhau/pkg/models" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" + ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test" +) + +type DataPlaneTestSuite struct { + suite.Suite + ctx context.Context + cancel context.CancelFunc + natsConn *nats.Conn + natsServer *natsserver.Server + dataPlane *nclprotocolcompute.DataPlane + config nclprotocolcompute.Config + logServer *ncltest.MockLogStreamServer + msgChan chan *envelope.Message + sub ncl.Subscriber +} + +func TestDataPlaneTestSuite(t *testing.T) { + suite.Run(t, new(DataPlaneTestSuite)) +} + +func (s *DataPlaneTestSuite) SetupTest() { + s.ctx, s.cancel = context.WithCancel(context.Background()) + + // Start NATS server and get client connection + s.natsServer, s.natsConn = testutils.StartNats(s.T()) + + // Create basic config + s.config = nclprotocolcompute.Config{ + NodeID: "test-node", + LogStreamServer: &ncltest.MockLogStreamServer{}, + DataPlaneMessageCreator: &ncltest.MockMessageCreator{}, + EventStore: testutils.CreateComputeEventStore(s.T()), + } + s.config.SetDefaults() + + s.setupSubscriber() + + // Create data plane + dp, err := nclprotocolcompute.NewDataPlane(nclprotocolcompute.DataPlaneParams{ + Config: s.config, + Client: s.natsConn, + LastReceivedSeqNum: 0, + }) + s.Require().NoError(err) + s.dataPlane = dp +} + +func (s *DataPlaneTestSuite) setupSubscriber() { + s.msgChan = make(chan *envelope.Message, 10) // Buffer multiple messages + sub, err := ncl.NewSubscriber(s.natsConn, ncl.SubscriberConfig{ + Name: "test-subscriber", + MessageRegistry: s.config.MessageRegistry, + MessageHandler: 
ncl.MessageHandlerFunc(func(ctx context.Context, msg *envelope.Message) error { + s.msgChan <- msg + return nil + }), + }) + s.Require().NoError(err) + + err = sub.Subscribe(s.ctx, nclprotocol.NatsSubjectComputeOutMsgs(s.config.NodeID)) + s.Require().NoError(err) + s.sub = sub +} + +func (s *DataPlaneTestSuite) TearDownTest() { + if s.sub != nil { + s.sub.Close(context.Background()) + } + if s.cancel != nil { + s.cancel() + } + if s.dataPlane != nil { + s.dataPlane.Stop(context.Background()) + } + if s.natsConn != nil { + s.natsConn.Close() + } + if s.natsServer != nil { + s.natsServer.Shutdown() + } + close(s.msgChan) +} + +func (s *DataPlaneTestSuite) TestLifecycle() { + testCases := []struct { + name string + operation func() error + verifyState func() bool + expectError bool + errorMsg string + }{ + { + name: "first start succeeds", + operation: func() error { return s.dataPlane.Start(s.ctx) }, + verifyState: func() bool { return s.dataPlane.IsRunning() }, + expectError: false, + }, + { + name: "second start fails", + operation: func() error { return s.dataPlane.Start(s.ctx) }, + verifyState: func() bool { return s.dataPlane.IsRunning() }, + expectError: true, + errorMsg: "already running", + }, + { + name: "first stop succeeds", + operation: func() error { return s.dataPlane.Stop(s.ctx) }, + verifyState: func() bool { return !s.dataPlane.IsRunning() }, + expectError: false, + }, + { + name: "second stop is noop", + operation: func() error { return s.dataPlane.Stop(s.ctx) }, + verifyState: func() bool { return !s.dataPlane.IsRunning() }, + expectError: false, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + err := tc.operation() + if tc.expectError { + s.Require().Error(err) + s.Require().Contains(err.Error(), tc.errorMsg) + } else { + s.Require().NoError(err) + } + s.Require().True(tc.verifyState()) + }) + } +} + +func (s *DataPlaneTestSuite) TestStartupFailureCleanup() { + testCases := []struct { + name string + preStart func() + 
verifyFail func(error) + expectError string + }{ + { + name: "NATS connection failure", + preStart: func() { s.natsConn.Close() }, + expectError: "connection closed", + }, + { + name: "context cancellation", + preStart: func() { + s.cancel() + time.Sleep(10 * time.Millisecond) // Allow cancellation to propagate + }, + expectError: "context canceled", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + tc.preStart() + err := s.dataPlane.Start(s.ctx) + s.Require().Error(err) + s.Require().Contains(err.Error(), tc.expectError) + s.Require().False(s.dataPlane.IsRunning()) + s.Require().Nil(s.dataPlane.Publisher) + s.Require().Nil(s.dataPlane.Dispatcher) + }) + } +} + +func (s *DataPlaneTestSuite) TestMessageHandling() { + s.Require().NoError(s.dataPlane.Start(s.ctx)) + + testCases := []struct { + name string + event models.ExecutionUpsert + expectMessage bool + expectedMsgTyp string + }{ + { + name: "valid execution upsert", + event: models.ExecutionUpsert{ + Current: &models.Execution{ + ID: "test-job-1", + NodeID: "test-node", + }, + }, + expectMessage: true, + expectedMsgTyp: messages.BidResultMessageType, + }, + { + name: "another execution upsert", + event: models.ExecutionUpsert{ + Current: &models.Execution{ + ID: "test-job-2", + NodeID: "test-node", + }, + }, + expectMessage: true, + expectedMsgTyp: messages.BidResultMessageType, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + err := s.config.EventStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ + Operation: watcher.OperationCreate, + ObjectType: compute.EventObjectExecutionUpsert, + Object: tc.event, + }) + s.Require().NoError(err) + + if tc.expectMessage { + select { + case msg := <-s.msgChan: + s.Require().Equal(tc.expectedMsgTyp, msg.Metadata.Get(envelope.KeyMessageType)) + case <-time.After(time.Second): + s.Require().Fail("Timeout waiting for message") + } + } + }) + } +} diff --git a/pkg/transport/nclprotocol/compute/health_tracker.go 
b/pkg/transport/nclprotocol/compute/health_tracker.go index 980749b92f..873bb8bedf 100644 --- a/pkg/transport/nclprotocol/compute/health_tracker.go +++ b/pkg/transport/nclprotocol/compute/health_tracker.go @@ -48,6 +48,14 @@ func (ht *HealthTracker) MarkDisconnected(err error) { ht.health.ConsecutiveFailures++ } +// MarkConnecting update status when connection is in progress +func (ht *HealthTracker) MarkConnecting() { + ht.mu.Lock() + defer ht.mu.Unlock() + + ht.health.CurrentState = nclprotocol.Connecting +} + // HeartbeatSuccess records successful heartbeat func (ht *HealthTracker) HeartbeatSuccess() { ht.mu.Lock() diff --git a/pkg/transport/nclprotocol/compute/health_tracker_test.go b/pkg/transport/nclprotocol/compute/health_tracker_test.go new file mode 100644 index 0000000000..9fe680d0ad --- /dev/null +++ b/pkg/transport/nclprotocol/compute/health_tracker_test.go @@ -0,0 +1,123 @@ +//go:build unit || !integration + +package compute + +import ( + "fmt" + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/stretchr/testify/suite" + + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" +) + +type HealthTrackerTestSuite struct { + suite.Suite + clock *clock.Mock + tracker *HealthTracker +} + +func TestHealthTrackerTestSuite(t *testing.T) { + suite.Run(t, new(HealthTrackerTestSuite)) +} + +func (s *HealthTrackerTestSuite) SetupTest() { + s.clock = clock.NewMock() + s.tracker = NewHealthTracker(s.clock) +} + +func (s *HealthTrackerTestSuite) TestInitialState() { + startTime := s.clock.Now() + health := s.tracker.GetHealth() + + s.Require().Equal(nclprotocol.Disconnected, health.CurrentState) + s.Require().Equal(startTime, health.StartTime) + s.Require().True(health.LastSuccessfulHeartbeat.IsZero()) + s.Require().True(health.LastSuccessfulUpdate.IsZero()) + s.Require().Equal(0, health.ConsecutiveFailures) + s.Require().Nil(health.LastError) + s.Require().True(health.ConnectedSince.IsZero()) +} + +func (s *HealthTrackerTestSuite) 
TestMarkConnected() { + // Advance clock to have distinct timestamps + s.clock.Add(time.Second) + connectedTime := s.clock.Now() + + s.tracker.MarkConnected() + health := s.tracker.GetHealth() + + s.Require().Equal(nclprotocol.Connected, health.CurrentState) + s.Require().Equal(connectedTime, health.ConnectedSince) + s.Require().Equal(connectedTime, health.LastSuccessfulHeartbeat) + s.Require().Equal(0, health.ConsecutiveFailures) + s.Require().Nil(health.LastError) +} + +func (s *HealthTrackerTestSuite) TestMarkDisconnected() { + // Set up initial connected state + s.tracker.MarkConnected() + + // Simulate disconnection + expectedErr := fmt.Errorf("connection lost") + s.tracker.MarkDisconnected(expectedErr) + health := s.tracker.GetHealth() + + s.Require().Equal(nclprotocol.Disconnected, health.CurrentState) + s.Require().Equal(expectedErr, health.LastError) + s.Require().Equal(1, health.ConsecutiveFailures) + + // Multiple disconnections should increment failure count + s.tracker.MarkDisconnected(expectedErr) + health = s.tracker.GetHealth() + s.Require().Equal(2, health.ConsecutiveFailures) +} + +func (s *HealthTrackerTestSuite) TestSuccessfulOperations() { + // Initial timestamps + s.clock.Add(time.Second) + s.tracker.MarkConnected() + + // Test heartbeat success + s.clock.Add(time.Second) + heartbeatTime := s.clock.Now() + s.tracker.HeartbeatSuccess() + + // Test update success + s.clock.Add(time.Second) + updateTime := s.clock.Now() + s.tracker.UpdateSuccess() + + // Verify timestamps + health := s.tracker.GetHealth() + s.Require().Equal(heartbeatTime, health.LastSuccessfulHeartbeat) + s.Require().Equal(updateTime, health.LastSuccessfulUpdate) +} + +func (s *HealthTrackerTestSuite) TestConnectionStateTransitions() { + // Test full connection lifecycle + states := []struct { + operation func() + expected nclprotocol.ConnectionState + }{ + { + operation: func() { s.tracker.MarkConnected() }, + expected: nclprotocol.Connected, + }, + { + operation: func() { 
s.tracker.MarkDisconnected(fmt.Errorf("error")) }, + expected: nclprotocol.Disconnected, + }, + { + operation: func() { s.tracker.MarkConnected() }, + expected: nclprotocol.Connected, + }, + } + + for _, tc := range states { + tc.operation() + s.Require().Equal(tc.expected, s.tracker.GetState()) + } +} diff --git a/pkg/transport/nclprotocol/compute/manager.go b/pkg/transport/nclprotocol/compute/manager.go index ef4362583f..dec42147ed 100644 --- a/pkg/transport/nclprotocol/compute/manager.go +++ b/pkg/transport/nclprotocol/compute/manager.go @@ -64,7 +64,7 @@ type stateChange struct { // NewConnectionManager creates a new connection manager with the given configuration. // It initializes the manager but does not start any connections - Start() must be called. func NewConnectionManager(cfg Config) (*ConnectionManager, error) { - cfg.setDefaults() + cfg.SetDefaults() if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("invalid config: %w", err) } @@ -105,7 +105,6 @@ func (cm *ConnectionManager) Start(ctx context.Context) error { return fmt.Errorf("failed to get last checkpoint: %w", err) } cm.incomingSeqTracker = nclprotocol.NewSequenceTracker().WithLastSeqNum(checkpoint) - cm.config.NodeInfoProvider.GetNodeInfo(ctx) // create new channels in case the connection manager is restarted cm.stopCh = make(chan struct{}) @@ -113,7 +112,7 @@ func (cm *ConnectionManager) Start(ctx context.Context) error { // Start connection management in background cm.wg.Add(1) - go cm.maintainConnection(context.TODO()) + go cm.maintainConnectionLoop(context.TODO()) // Start state change notification handler cm.wg.Add(1) @@ -309,7 +308,12 @@ func (cm *ConnectionManager) performHandshake( "invalid handshake response payload. 
expected messages.HandshakeResponse, got %T", payload) } - return payload.(messages.HandshakeResponse), nil + handshakeResponse := payload.(messages.HandshakeResponse) + if !handshakeResponse.Accepted { + return messages.HandshakeResponse{}, fmt.Errorf( + "handshake rejected by orchestrator due to %s", handshakeResponse.Reason) + } + return handshakeResponse, nil } // setupControlPlane creates and starts the control plane @@ -352,46 +356,49 @@ func (cm *ConnectionManager) setupDataPlane(ctx context.Context, handshake messa return nil } -// maintainConnection runs a periodic loop that manages the connection lifecycle. +// maintainConnectionLoop runs a periodic loop that manages the connection lifecycle. // It handles initial connection, health monitoring, and reconnection with backoff. -func (cm *ConnectionManager) maintainConnection(ctx context.Context) { +func (cm *ConnectionManager) maintainConnectionLoop(ctx context.Context) { defer cm.wg.Done() - // Create timer that fires immediately for first connection - timer := time.NewTimer(0) - defer timer.Stop() + // Initial connection attempt + cm.maintainConnection(ctx) + + // Start periodic connection maintenance + ticker := cm.config.Clock.Ticker(cm.config.ReconnectInterval) + defer ticker.Stop() for { select { case <-cm.stopCh: return + case <-ticker.C: + cm.maintainConnection(ctx) + } + } +} - case <-timer.C: - switch cm.getState() { - case nclprotocol.Disconnected: - if err := cm.connect(ctx); err != nil { - failures := cm.GetHealth().ConsecutiveFailures - backoffDuration := cm.config.ReconnectBackoff.BackoffDuration(failures) - - log.Error(). - Err(err). - Int("consecutiveFailures", failures). - Str("backoffDuration", backoffDuration.String()). 
- Msg("Connection attempt failed") - - cm.config.ReconnectBackoff.Backoff(ctx, failures) - } - - case nclprotocol.Connected: - cm.checkConnectionHealth() +func (cm *ConnectionManager) maintainConnection(ctx context.Context) { + switch cm.getState() { + case nclprotocol.Disconnected: + if err := cm.connect(ctx); err != nil { + failures := cm.GetHealth().ConsecutiveFailures + backoffDuration := cm.config.ReconnectBackoff.BackoffDuration(failures) + + log.Error(). + Err(err). + Int("consecutiveFailures", failures). + Str("backoffDuration", backoffDuration.String()). + Msg("Connection attempt failed") + + cm.config.ReconnectBackoff.Backoff(ctx, failures) + } - default: - // Ignore other states, such as connecting - } + case nclprotocol.Connected: + cm.checkConnectionHealth() - // Reset timer for next interval - timer.Reset(cm.config.ReconnectInterval) - } + default: + // Ignore other states, such as connecting } } @@ -443,9 +450,12 @@ func (cm *ConnectionManager) transitionState(newState nclprotocol.ConnectionStat } // Update state tracking - if newState == nclprotocol.Connected { + switch newState { + case nclprotocol.Connecting: + cm.healthTracker.MarkConnecting() + case nclprotocol.Connected: cm.healthTracker.MarkConnected() - } else if newState == nclprotocol.Disconnected { + case nclprotocol.Disconnected: cm.healthTracker.MarkDisconnected(err) } diff --git a/pkg/transport/nclprotocol/compute/manager_test.go b/pkg/transport/nclprotocol/compute/manager_test.go new file mode 100644 index 0000000000..f3c19f9ae8 --- /dev/null +++ b/pkg/transport/nclprotocol/compute/manager_test.go @@ -0,0 +1,269 @@ +//go:build unit || !integration + +package compute_test + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/suite" + + "github.com/bacalhau-project/bacalhau/pkg/lib/backoff" + 
"github.com/bacalhau-project/bacalhau/pkg/models" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + natsutil "github.com/bacalhau-project/bacalhau/pkg/nats" + testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" + ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test" +) + +type ConnectionManagerTestSuite struct { + suite.Suite + ctx context.Context + cancel context.CancelFunc + clock clock.Clock + clientFactory natsutil.ClientFactory + nodeInfoProvider *ncltest.MockNodeInfoProvider + messageHandler *ncltest.MockMessageHandler + checkpointer *ncltest.MockCheckpointer + manager *nclprotocolcompute.ConnectionManager + mockResponder *ncltest.MockResponder + config nclprotocolcompute.Config + natsServer *server.Server + natsConn *nats.Conn +} + +func (s *ConnectionManagerTestSuite) SetupTest() { + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.clock = clock.New() // tickers didn't work properly with mock clock + + // Setup NATS server and client + s.natsServer, s.natsConn = testutils.StartNats(s.T()) + + // Fresh client with each call + s.clientFactory = natsutil.ClientFactoryFunc(func(ctx context.Context) (*nats.Conn, error) { + return testutils.CreateNatsClient(s.T(), s.natsServer.ClientURL()), nil + }) + + // Create mocks + s.nodeInfoProvider = ncltest.NewMockNodeInfoProvider() + s.messageHandler = ncltest.NewMockMessageHandler() + s.checkpointer = ncltest.NewMockCheckpointer() + + // Setup base configuration + s.config = nclprotocolcompute.Config{ + NodeID: "test-node", + NodeInfoProvider: s.nodeInfoProvider, + ClientFactory: s.clientFactory, + Checkpointer: s.checkpointer, + EventStore: testutils.CreateComputeEventStore(s.T()), + LogStreamServer: 
&ncltest.MockLogStreamServer{}, + + DataPlaneMessageHandler: s.messageHandler, + DataPlaneMessageCreator: &ncltest.MockMessageCreator{}, + + Clock: s.clock, + HeartbeatInterval: 100 * time.Millisecond, + HeartbeatMissFactor: 3, + NodeInfoUpdateInterval: 100 * time.Millisecond, + CheckpointInterval: 1 * time.Second, + ReconnectInterval: 100 * time.Millisecond, + RequestTimeout: 50 * time.Millisecond, + ReconnectBackoff: backoff.NewExponential(50*time.Millisecond, 100*time.Millisecond), + DispatcherConfig: dispatcher.DefaultConfig(), + } + + // Setup mock responder + mockResponder, err := ncltest.NewMockResponder(s.natsConn, nil) + s.Require().NoError(err) + s.mockResponder = mockResponder + + // Create manager + manager, err := nclprotocolcompute.NewConnectionManager(s.config) + s.Require().NoError(err) + s.manager = manager +} + +// TearDownTest +func (s *ConnectionManagerTestSuite) TearDownTest() { + if s.manager != nil { + s.Require().NoError(s.manager.Close(context.Background())) + } + if s.mockResponder != nil { + s.Require().NoError(s.mockResponder.Close(context.Background())) + } + if s.natsConn != nil { + s.natsConn.Close() + } + if s.natsServer != nil { + s.natsServer.Shutdown() + } + + s.cancel() +} + +func (s *ConnectionManagerTestSuite) TestSuccessfulConnection() { + // Setup initial checkpoint + lastOrchestratorSeqNum := uint64(124) + s.checkpointer.SetCheckpoint("incoming-test-node", lastOrchestratorSeqNum) + + err := s.manager.Start(s.ctx) + s.Require().NoError(err) + + s.Require().Eventually(func() bool { + return len(s.mockResponder.GetHandshakes()) > 0 + }, time.Second, 10*time.Millisecond, "handshake not received") + + // Verify handshake request + handshakes := s.mockResponder.GetHandshakes() + s.Require().Len(handshakes, 1) + s.Require().Equal(s.config.NodeID, handshakes[0].NodeInfo.ID()) + s.Require().Equal(lastOrchestratorSeqNum, handshakes[0].LastOrchestratorSeqNum) + + // Verify connection established + s.Require().Eventually(func() bool { + 
health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Connected + }, time.Second, 10*time.Millisecond, "manager did not connect") + + // verify no heartbeats yet + s.Require().Empty(s.mockResponder.GetHeartbeats()) + + // trigger heartbeat + previousTick := s.manager.GetHealth().LastSuccessfulHeartbeat + time.Sleep(s.config.HeartbeatInterval) + + // wait for some heartbeats + s.Require().Eventually(func() bool { + return len(s.mockResponder.GetHeartbeats()) > 0 + }, time.Second, 10*time.Millisecond, "manager did not send heartbeats") + + // Verify heartbeat content + nodeInfo := s.nodeInfoProvider.GetNodeInfo(s.ctx) + heartbeats := s.mockResponder.GetHeartbeats() + s.Require().Len(heartbeats, 1) + s.Require().Equal(messages.HeartbeatRequest{ + NodeID: nodeInfo.NodeID, + AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, + QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, + LastOrchestratorSeqNum: lastOrchestratorSeqNum, + }, heartbeats[0]) + + // verify state + s.Require().Greater(s.manager.GetHealth().LastSuccessfulHeartbeat, previousTick) + + // update node info and heartbeat again + nodeInfo.ComputeNodeInfo.AvailableCapacity = models.Resources{CPU: 100, Memory: 1000, GPU: 3} + nodeInfo.ComputeNodeInfo.QueueUsedCapacity = models.Resources{CPU: 10, Memory: 100, GPU: 1} + s.nodeInfoProvider.SetNodeInfo(nodeInfo) + + // trigger heartbeat + time.Sleep(s.config.HeartbeatInterval) + s.Require().Eventually(func() bool { + lastHeartbeat := s.mockResponder.GetHeartbeats()[len(s.mockResponder.GetHeartbeats())-1] + return reflect.DeepEqual(lastHeartbeat, messages.HeartbeatRequest{ + NodeID: nodeInfo.NodeID, + AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, + QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, + LastOrchestratorSeqNum: lastOrchestratorSeqNum, + }) + }, time.Second, 10*time.Millisecond, "manager did not send heartbeats") + +} + +func (s *ConnectionManagerTestSuite) 
TestRejectedHandshake() { + // Configure responder to reject handshake + s.mockResponder.Behaviour().HandshakeResponse.Response = messages.HandshakeResponse{ + Accepted: false, + Reason: "node not allowed", + } + + err := s.manager.Start(s.ctx) + s.Require().NoError(err) + + // Verify disconnected state + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Disconnected && + health.LastError != nil && + health.ConsecutiveFailures > 0 + }, time.Second, 10*time.Millisecond) + + // Allow handshake and verify reconnection + s.mockResponder.Behaviour().HandshakeResponse.Response = messages.HandshakeResponse{ + Accepted: true, + } + + // Retry handshake + time.Sleep(s.config.ReconnectInterval) + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Connected + }, time.Second, 10*time.Millisecond, "manager should be connected") +} + +func (s *ConnectionManagerTestSuite) TestHeartbeatFailure() { + err := s.manager.Start(s.ctx) + s.Require().NoError(err) + + // Wait for initial connection + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Connected + }, time.Second, 10*time.Millisecond) + + // Configure heartbeat failure + s.mockResponder.Behaviour().HeartbeatResponse.Error = fmt.Errorf("heartbeat failed") + + // Wait for disconnect after missed heartbeats + time.Sleep(s.config.HeartbeatInterval * time.Duration(s.config.HeartbeatMissFactor+1)) + + // Should disconnect after missing heartbeats + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Disconnected && + health.LastError != nil + }, time.Second, 10*time.Millisecond) +} + +func (s *ConnectionManagerTestSuite) TestNodeInfoUpdates() { + // Configure heartbeat callback to trigger node info updates + s.mockResponder.Behaviour().OnHeartbeat = func(req 
messages.HeartbeatRequest) { + newInfo := s.nodeInfoProvider.GetNodeInfo(s.ctx) + newInfo.Labels = map[string]string{"heartbeat": time.Now().String()} + s.nodeInfoProvider.SetNodeInfo(newInfo) + } + + err := s.manager.Start(s.ctx) + s.Require().NoError(err) + + // Wait for connection + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Connected + }, time.Second, 10*time.Millisecond) + + // Verify node info updates received + s.Require().Eventually(func() bool { + return len(s.mockResponder.GetNodeInfos()) > 0 + }, time.Second, 10*time.Millisecond) + + nodeInfos := s.mockResponder.GetNodeInfos() + s.Require().Len(nodeInfos, 1) + s.Require().Equal( + s.nodeInfoProvider.GetNodeInfo(s.ctx).ID(), + nodeInfos[0].NodeInfo.ID(), + ) +} + +func TestConnectionManagerTestSuite(t *testing.T) { + suite.Run(t, new(ConnectionManagerTestSuite)) +} diff --git a/pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go b/pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go index fc5d8b37b9..ea2014ef36 100644 --- a/pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go +++ b/pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go @@ -5,7 +5,6 @@ package dispatcher_test import ( "context" "fmt" - "reflect" "testing" "time" @@ -16,8 +15,6 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/boltdb" - watchertest "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/test" "github.com/bacalhau-project/bacalhau/pkg/logger" testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/dispatcher" @@ -44,18 +41,7 @@ func (s *DispatcherE2ETestSuite) SetupTest() { s.natsServer, s.nc = testutils.StartNats(s.T()) // Create boltdb store and watcher - eventObjectSerializer := 
watcher.NewJSONSerializer() - s.Require().NoError(eventObjectSerializer.RegisterType("test", reflect.TypeOf(""))) - store, err := boltdb.NewEventStore( - watchertest.CreateBoltDB(s.T()), - boltdb.WithEventSerializer(eventObjectSerializer), - ) - s.Require().NoError(err) - s.store = store - - // Create registry - s.registry = envelope.NewRegistry() - s.Require().NoError(s.registry.Register("test", "string")) + s.store, s.registry = testutils.CreateStringEventStore(s.T()) // Create subscriber s.received = make([]*envelope.Message, 0) @@ -245,7 +231,7 @@ func (s *DispatcherE2ETestSuite) TestCheckpointingAndRestart() { func (s *DispatcherE2ETestSuite) storeEvent(index int) { err := s.store.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, - ObjectType: "test", + ObjectType: "string", Object: fmt.Sprintf("event-%d", index), }) s.Require().NoError(err) diff --git a/pkg/transport/nclprotocol/mocks.go b/pkg/transport/nclprotocol/mocks.go index dbc761c680..69093ed310 100644 --- a/pkg/transport/nclprotocol/mocks.go +++ b/pkg/transport/nclprotocol/mocks.go @@ -127,11 +127,12 @@ func (m *MockMessageCreatorFactory) EXPECT() *MockMessageCreatorFactoryMockRecor } // CreateMessageCreator mocks base method. -func (m *MockMessageCreatorFactory) CreateMessageCreator(ctx context.Context, nodeID string) MessageCreator { +func (m *MockMessageCreatorFactory) CreateMessageCreator(ctx context.Context, nodeID string) (MessageCreator, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateMessageCreator", ctx, nodeID) ret0, _ := ret[0].(MessageCreator) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // CreateMessageCreator indicates an expected call of CreateMessageCreator. 
diff --git a/pkg/transport/nclprotocol/orchestrator/dataplane_test.go b/pkg/transport/nclprotocol/orchestrator/dataplane_test.go new file mode 100644 index 0000000000..dc3951e297 --- /dev/null +++ b/pkg/transport/nclprotocol/orchestrator/dataplane_test.go @@ -0,0 +1,294 @@ +//go:build unit || !integration + +package orchestrator_test + +import ( + "context" + "fmt" + "testing" + "time" + + natsserver "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/suite" + + "github.com/bacalhau-project/bacalhau/pkg/jobstore" + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + "github.com/bacalhau-project/bacalhau/pkg/models" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/orchestrator" + ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test" +) + +// TestMessage represents a test message to be sent/received +type TestMessage struct { + Name string + Message interface{} + Type string + Sequence uint64 + ExpectProcessed bool +} + +type DataPlaneTestSuite struct { + suite.Suite + ctx context.Context + cancel context.CancelFunc + natsConn *nats.Conn + natsServer *natsserver.Server + dataPlane *orchestrator.DataPlane + config orchestrator.DataPlaneConfig + msgHandler *ncltest.MockMessageHandler + msgCreatorFactory *ncltest.MockMessageCreatorFactory + msgCreator *ncltest.MockMessageCreator + + // Test message passing + publisher ncl.Publisher // For sending test messages + consumer ncl.Subscriber // For receiving published messages +} + +func (s *DataPlaneTestSuite) SetupTest() { + s.ctx, s.cancel = context.WithCancel(context.Background()) + + // Start NATS server and get client 
connection + s.natsServer, s.natsConn = testutils.StartNats(s.T()) + + // Create mocks + s.msgHandler = ncltest.NewMockMessageHandler() + s.msgCreatorFactory = ncltest.NewMockMessageCreatorFactory("test-node") + s.msgCreator = s.msgCreatorFactory.GetCreator() + + // Create basic config + s.config = orchestrator.DataPlaneConfig{ + NodeID: "test-node", + Client: s.natsConn, + MessageHandler: s.msgHandler, + MessageCreatorFactory: s.msgCreatorFactory, + MessageRegistry: nclprotocol.MustCreateMessageRegistry(), + MessageSerializer: envelope.NewSerializer(), + EventStore: testutils.CreateJobEventStore(s.T()), + } + + // Setup test message passing + s.setupMessagePassing() + + // Create data plane + dp, err := orchestrator.NewDataPlane(s.config) + s.Require().NoError(err) + s.dataPlane = dp +} + +func (s *DataPlaneTestSuite) setupMessagePassing() { + var err error + + // Create publisher for sending test messages + s.publisher, err = ncl.NewPublisher(s.natsConn, ncl.PublisherConfig{ + Name: "test-publisher", + MessageRegistry: s.config.MessageRegistry, + Destination: nclprotocol.NatsSubjectOrchestratorInMsgs(s.config.NodeID), + }) + s.Require().NoError(err) + + // Create subscriber for consuming outgoing messages + s.consumer, err = ncl.NewSubscriber(s.natsConn, ncl.SubscriberConfig{ + Name: "test-consumer", + MessageRegistry: s.config.MessageRegistry, + MessageSerializer: s.config.MessageSerializer, + MessageHandler: s.msgHandler, + }) + s.Require().NoError(err) + + err = s.consumer.Subscribe(s.ctx, nclprotocol.NatsSubjectOrchestratorOutMsgs(s.config.NodeID)) + s.Require().NoError(err) +} + +func (s *DataPlaneTestSuite) TearDownTest() { + if s.consumer != nil { + s.consumer.Close(context.Background()) + } + if s.cancel != nil { + s.cancel() + } + if s.dataPlane != nil { + s.dataPlane.Stop(context.Background()) + } + if s.natsConn != nil { + s.natsConn.Close() + } + if s.natsServer != nil { + s.natsServer.Shutdown() + } +} + +func TestDataPlaneTestSuite(t *testing.T) { + 
suite.Run(t, new(DataPlaneTestSuite)) +} + +func (s *DataPlaneTestSuite) TestLifecycle() { + testCases := []struct { + name string + operation func() error + expectError bool + errorMsg string + }{ + { + name: "first start succeeds", + operation: func() error { return s.dataPlane.Start(s.ctx) }, + expectError: false, + }, + { + name: "second start fails", + operation: func() error { return s.dataPlane.Start(s.ctx) }, + expectError: true, + errorMsg: "already running", + }, + { + name: "first stop succeeds", + operation: func() error { return s.dataPlane.Stop(s.ctx) }, + expectError: false, + }, + { + name: "second stop is noop", + operation: func() error { return s.dataPlane.Stop(s.ctx) }, + expectError: false, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + err := tc.operation() + if tc.expectError { + s.Error(err) + s.Contains(err.Error(), tc.errorMsg) + } else { + s.NoError(err) + } + }) + } +} + +func (s *DataPlaneTestSuite) TestIncomingMessageProcessing() { + s.Require().NoError(s.dataPlane.Start(s.ctx)) + + testMessages := []TestMessage{ + { + Name: "bid result message", + Message: messages.BidResult{ + BaseResponse: messages.BaseResponse{ExecutionID: "test-1"}, + }, + Type: messages.BidResultMessageType, + Sequence: 1, + ExpectProcessed: true, + }, + { + Name: "run result message", + Message: messages.RunResult{ + BaseResponse: messages.BaseResponse{ExecutionID: "test-2"}, + }, + Type: messages.RunResultMessageType, + Sequence: 2, + ExpectProcessed: true, + }, + } + + // Send test messages + for _, tm := range testMessages { + s.Run(tm.Name, func() { + s.sendTestMessage(tm) + + if tm.ExpectProcessed { + s.verifyMessageProcessed(tm) + } + }) + } + + // Verify sequence tracking + s.Equal(uint64(2), s.dataPlane.GetLastProcessedSequence()) +} + +func (s *DataPlaneTestSuite) TestEventToMessageDispatch() { + s.Require().NoError(s.dataPlane.Start(s.ctx)) + + execution := &models.Execution{ + ID: "test-1", + NodeID: s.config.NodeID, + } + + 
// Configure message to be created + expectedMsg := envelope.NewMessage(messages.BidResult{ + BaseResponse: messages.BaseResponse{ExecutionID: execution.ID}, + }).WithMetadataValue(envelope.KeyMessageType, messages.BidResultMessageType) + s.msgCreator.SetNextMessage(expectedMsg) + + // Store execution event + err := s.config.EventStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ + Operation: watcher.OperationCreate, + ObjectType: jobstore.EventObjectExecutionUpsert, + Object: models.ExecutionUpsert{ + Current: execution, + }, + }) + s.Require().NoError(err) + + // Wait for message to be published + s.Eventually(func() bool { + msgs := s.msgHandler.GetMessages() + return len(msgs) > 0 && msgs[0].Metadata.Get(envelope.KeyMessageType) == messages.BidResultMessageType + }, time.Second, 10*time.Millisecond) +} + +func (s *DataPlaneTestSuite) TestSequenceTracking() { + s.Require().NoError(s.dataPlane.Start(s.ctx)) + + // Send messages with sequential sequence numbers + numMessages := 5 + for i := 1; i <= numMessages; i++ { + msg := envelope.NewMessage(messages.BidResult{ + BaseResponse: messages.BaseResponse{ExecutionID: fmt.Sprintf("test-%d", i)}, + }). + WithMetadataValue(envelope.KeyMessageType, messages.BidResultMessageType). + WithMetadataValue(nclprotocol.KeySeqNum, fmt.Sprint(i)) + + err := s.publisher.Publish(s.ctx, ncl.NewPublishRequest(msg)) + s.Require().NoError(err) + } + + // Verify final sequence number + s.Eventually(func() bool { + return s.dataPlane.GetLastProcessedSequence() == uint64(numMessages) + }, time.Second, 10*time.Millisecond) + + // Verify messages were processed in order + msgs := s.msgHandler.GetMessages() + s.Len(msgs, numMessages) + for i, msg := range msgs { + s.Equal(fmt.Sprint(i+1), msg.Metadata.Get(nclprotocol.KeySeqNum)) + } +} + +// Helper methods + +func (s *DataPlaneTestSuite) sendTestMessage(tm TestMessage) { + msg := envelope.NewMessage(tm.Message). + WithMetadataValue(envelope.KeyMessageType, tm.Type). 
+ WithMetadataValue(nclprotocol.KeySeqNum, fmt.Sprint(tm.Sequence)) + + err := s.publisher.Publish(s.ctx, ncl.NewPublishRequest(msg)) + s.Require().NoError(err) +} + +func (s *DataPlaneTestSuite) verifyMessageProcessed(tm TestMessage) { + s.Eventually(func() bool { + msgs := s.msgHandler.GetMessages() + for _, msg := range msgs { + if msg.Metadata.Get(envelope.KeyMessageType) == tm.Type && + msg.Metadata.Get(nclprotocol.KeySeqNum) == fmt.Sprint(tm.Sequence) { + return true + } + } + return false + }, time.Second, 10*time.Millisecond) +} diff --git a/pkg/transport/nclprotocol/test/control_plane.go b/pkg/transport/nclprotocol/test/control_plane.go new file mode 100644 index 0000000000..c9474171cb --- /dev/null +++ b/pkg/transport/nclprotocol/test/control_plane.go @@ -0,0 +1,340 @@ +package test + +import ( + "context" + "fmt" + "sync" + + "github.com/nats-io/nats.go" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" +) + +// MockResponderBehavior configures how the mock responder handles different request types. +// It allows customizing responses and errors for testing different scenarios. 
+type MockResponderBehavior struct { + // HandshakeResponse controls behavior for handshake requests + HandshakeResponse struct { + Error error // Error to return, if any + Response messages.HandshakeResponse // Response to return if no error + } + + // HeartbeatResponse controls behavior for heartbeat requests + HeartbeatResponse struct { + Error error // Error to return, if any + Response messages.HeartbeatResponse // Response to return if no error + } + + // NodeInfoResponse controls behavior for node info update requests + NodeInfoResponse struct { + Error error // Error to return, if any + Response messages.UpdateNodeInfoResponse // Response to return if no error + } + + // Callbacks for request inspection + OnHandshake func(messages.HandshakeRequest) // Called when handshake received + OnHeartbeat func(messages.HeartbeatRequest) // Called when heartbeat received + OnNodeInfo func(messages.UpdateNodeInfoRequest) // Called when node info update received +} + +// MockResponder provides a configurable mock implementation of the control plane responder. +// It tracks requests received and provides configurable responses for testing. +type MockResponder struct { + behavior *MockResponderBehavior + responder ncl.Responder + mu sync.RWMutex + + // Request history + handshakes []messages.HandshakeRequest + heartbeats []messages.HeartbeatRequest + nodeInfos []messages.UpdateNodeInfoRequest +} + +// NewMockResponder creates a new mock responder with the given behavior. +// If behavior is nil, default success responses are used. 
+func NewMockResponder(conn *nats.Conn, behavior *MockResponderBehavior) (*MockResponder, error) { + if behavior == nil { + behavior = &MockResponderBehavior{ + HandshakeResponse: struct { + Error error + Response messages.HandshakeResponse + }{ + Response: messages.HandshakeResponse{Accepted: true}, + }, + HeartbeatResponse: struct { + Error error + Response messages.HeartbeatResponse + }{ + Response: messages.HeartbeatResponse{}, + }, + NodeInfoResponse: struct { + Error error + Response messages.UpdateNodeInfoResponse + }{ + Response: messages.UpdateNodeInfoResponse{Accepted: true}, + }, + } + } + + responder, err := ncl.NewResponder(conn, ncl.ResponderConfig{ + Name: "mock-responder", + MessageRegistry: nclprotocol.MustCreateMessageRegistry(), + MessageSerializer: envelope.NewSerializer(), + Subject: nclprotocol.NatsSubjectOrchestratorInCtrl(), + }) + if err != nil { + return nil, fmt.Errorf("create responder: %w", err) + } + + mr := &MockResponder{ + behavior: behavior, + responder: responder, + } + + if err := mr.setupHandlers(context.Background()); err != nil { + responder.Close(context.Background()) + return nil, err + } + + return mr, nil +} + +func (m *MockResponder) setupHandlers(ctx context.Context) error { + // Handshake handler + if err := m.responder.Listen(ctx, messages.HandshakeRequestMessageType, + ncl.RequestHandlerFunc(func(ctx context.Context, msg *envelope.Message) (*envelope.Message, error) { + req := *msg.Payload.(*messages.HandshakeRequest) + m.recordHandshake(req) + + if m.behavior.HandshakeResponse.Error != nil { + return nil, m.behavior.HandshakeResponse.Error + } + return envelope.NewMessage(m.behavior.HandshakeResponse.Response). 
+ WithMetadataValue(envelope.KeyMessageType, messages.HandshakeResponseType), nil + })); err != nil { + return err + } + + // Heartbeat handler + if err := m.responder.Listen(ctx, messages.HeartbeatRequestMessageType, + ncl.RequestHandlerFunc(func(ctx context.Context, msg *envelope.Message) (*envelope.Message, error) { + req := *msg.Payload.(*messages.HeartbeatRequest) + m.recordHeartbeat(req) + + if m.behavior.HeartbeatResponse.Error != nil { + return nil, m.behavior.HeartbeatResponse.Error + } + return envelope.NewMessage(m.behavior.HeartbeatResponse.Response). + WithMetadataValue(envelope.KeyMessageType, messages.HeartbeatResponseType), nil + })); err != nil { + return err + } + + // Node info handler + if err := m.responder.Listen(ctx, messages.NodeInfoUpdateRequestMessageType, + ncl.RequestHandlerFunc(func(ctx context.Context, msg *envelope.Message) (*envelope.Message, error) { + req := *msg.Payload.(*messages.UpdateNodeInfoRequest) + m.recordNodeInfo(req) + + if m.behavior.NodeInfoResponse.Error != nil { + return nil, m.behavior.NodeInfoResponse.Error + } + return envelope.NewMessage(m.behavior.NodeInfoResponse.Response). 
+ WithMetadataValue(envelope.KeyMessageType, messages.NodeInfoUpdateResponseType), nil + })); err != nil { + return err + } + + return nil +} + +// Record methods for inspection +func (m *MockResponder) recordHandshake(req messages.HandshakeRequest) { + m.mu.Lock() + defer m.mu.Unlock() + m.handshakes = append(m.handshakes, req) + if m.behavior.OnHandshake != nil { + m.behavior.OnHandshake(req) + } +} + +func (m *MockResponder) recordHeartbeat(req messages.HeartbeatRequest) { + m.mu.Lock() + defer m.mu.Unlock() + m.heartbeats = append(m.heartbeats, req) + if m.behavior.OnHeartbeat != nil { + m.behavior.OnHeartbeat(req) + } +} + +func (m *MockResponder) recordNodeInfo(req messages.UpdateNodeInfoRequest) { + m.mu.Lock() + defer m.mu.Unlock() + m.nodeInfos = append(m.nodeInfos, req) + if m.behavior.OnNodeInfo != nil { + m.behavior.OnNodeInfo(req) + } +} + +// GetHandshakes returns a copy of all handshake requests received +func (m *MockResponder) GetHandshakes() []messages.HandshakeRequest { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]messages.HandshakeRequest, len(m.handshakes)) + copy(result, m.handshakes) + return result +} + +// GetHeartbeats returns a copy of all heartbeat requests received +func (m *MockResponder) GetHeartbeats() []messages.HeartbeatRequest { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]messages.HeartbeatRequest, len(m.heartbeats)) + copy(result, m.heartbeats) + return result +} + +// GetNodeInfos returns a copy of all node info update requests received +func (m *MockResponder) GetNodeInfos() []messages.UpdateNodeInfoRequest { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]messages.UpdateNodeInfoRequest, len(m.nodeInfos)) + copy(result, m.nodeInfos) + return result +} + +// Behaviour returns the behavior configuration +func (m *MockResponder) Behaviour() *MockResponderBehavior { + return m.behavior +} + +// Close shuts down the responder +func (m *MockResponder) Close(ctx context.Context) error { + return 
m.responder.Close(ctx) +} + +// MockCheckpointer provides a thread-safe mock implementation of Checkpointer for testing. +// It tracks checkpoints and allows configuring errors and validation behavior. +type MockCheckpointer struct { + mu sync.RWMutex + checkpoints map[string]uint64 // Stored checkpoint values by name + setErrors map[string]error // Errors to return for Checkpoint calls by name + getErrors map[string]error // Errors to return for GetCheckpoint calls by name + onSet func(string, uint64) // Optional callback when checkpoint is set + onGet func(string) // Optional callback when checkpoint is retrieved +} + +// NewMockCheckpointer creates a new mock checkpointer instance +func NewMockCheckpointer() *MockCheckpointer { + return &MockCheckpointer{ + checkpoints: make(map[string]uint64), + setErrors: make(map[string]error), + getErrors: make(map[string]error), + } +} + +// Checkpoint implements the Checkpointer interface +func (m *MockCheckpointer) Checkpoint(ctx context.Context, name string, sequenceNumber uint64) error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check for configured error + if err := m.setErrors[name]; err != nil { + return err + } + + // Store checkpoint + m.checkpoints[name] = sequenceNumber + + // Call optional callback + if m.onSet != nil { + m.onSet(name, sequenceNumber) + } + + return nil +} + +// GetCheckpoint implements the Checkpointer interface +func (m *MockCheckpointer) GetCheckpoint(ctx context.Context, name string) (uint64, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + // Check for configured error + if err := m.getErrors[name]; err != nil { + return 0, err + } + + // Call optional callback + if m.onGet != nil { + m.onGet(name) + } + + // Return stored value or 0 if not found + return m.checkpoints[name], nil +} + +// Helper methods for test configuration + +// SetCheckpoint directly sets a checkpoint value +func (m *MockCheckpointer) SetCheckpoint(name string, value uint64) { + m.mu.Lock() + defer m.mu.Unlock() + 
m.checkpoints[name] = value +} + +// SetCheckpointError configures an error to be returned by Checkpoint +func (m *MockCheckpointer) SetCheckpointError(name string, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.setErrors[name] = err +} + +// SetGetCheckpointError configures an error to be returned by GetCheckpoint +func (m *MockCheckpointer) SetGetCheckpointError(name string, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getErrors[name] = err +} + +// OnCheckpointSet sets a callback to be called when Checkpoint is called +func (m *MockCheckpointer) OnCheckpointSet(callback func(name string, value uint64)) { + m.mu.Lock() + defer m.mu.Unlock() + m.onSet = callback +} + +// OnCheckpointGet sets a callback to be called when GetCheckpoint is called +func (m *MockCheckpointer) OnCheckpointGet(callback func(name string)) { + m.mu.Lock() + defer m.mu.Unlock() + m.onGet = callback +} + +// GetStoredCheckpoint returns the currently stored checkpoint value +func (m *MockCheckpointer) GetStoredCheckpoint(name string) (uint64, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + value, exists := m.checkpoints[name] + if !exists { + return 0, fmt.Errorf("no checkpoint found for %s", name) + } + return value, nil +} + +// Reset clears all stored checkpoints and configured errors +func (m *MockCheckpointer) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.checkpoints = make(map[string]uint64) + m.setErrors = make(map[string]error) + m.getErrors = make(map[string]error) + m.onSet = nil + m.onGet = nil +} + +// compile-time check for interface implementation +var _ nclprotocol.Checkpointer = &MockCheckpointer{} diff --git a/pkg/transport/nclprotocol/test/message_creation.go b/pkg/transport/nclprotocol/test/message_creation.go new file mode 100644 index 0000000000..cc2a736370 --- /dev/null +++ b/pkg/transport/nclprotocol/test/message_creation.go @@ -0,0 +1,95 @@ +package test + +import ( + "context" + "fmt" + "sync" + + 
"github.com/bacalhau-project/bacalhau/pkg/lib/envelope" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" +) + +// MockMessageCreator provides a configurable implementation of MessageCreator for testing. +// It allows setting predefined messages or errors to be returned. +type MockMessageCreator struct { + // Error if set, CreateMessage will return this error + Error error + + // Message if set, CreateMessage will return this message + // If nil and Error is nil, a default BidResult message is returned + Message *envelope.Message +} + +// CreateMessage implements nclprotocol.MessageCreator +func (c *MockMessageCreator) CreateMessage(event watcher.Event) (*envelope.Message, error) { + if c.Error != nil { + return nil, c.Error + } + if c.Message != nil { + return c.Message, nil + } + // Return default message if no specific behavior configured + return envelope.NewMessage(messages.BidResult{}). + WithMetadataValue(envelope.KeyMessageType, messages.BidResultMessageType), nil +} + +// SetNextMessage configures the next message to be returned by CreateMessage +func (c *MockMessageCreator) SetNextMessage(msg *envelope.Message) { + c.Message = msg +} + +// MockMessageCreatorFactory manages MockMessageCreator instances for testing. +// It provides a thread-safe way to create and configure message creators per node. 
+type MockMessageCreatorFactory struct { + nodeID string // Expected node ID for validation + mockCreator *MockMessageCreator // Shared mock creator instance + createError error // Error to return from CreateMessageCreator + mu sync.RWMutex // Protects concurrent access +} + +// NewMockMessageCreatorFactory creates a new factory that validates against the given nodeID +func NewMockMessageCreatorFactory(nodeID string) *MockMessageCreatorFactory { + return &MockMessageCreatorFactory{ + nodeID: nodeID, + mockCreator: &MockMessageCreator{}, + } +} + +// CreateMessageCreator implements nclprotocol.MessageCreatorFactory. +// Returns the mock creator if nodeID matches, error otherwise. +func (f *MockMessageCreatorFactory) CreateMessageCreator(ctx context.Context, nodeID string) (nclprotocol.MessageCreator, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.createError != nil { + return nil, f.createError + } + + if nodeID != f.nodeID { + return nil, fmt.Errorf("unknown node ID: %s", nodeID) + } + + return f.mockCreator, nil +} + +// GetCreator provides access to the underlying mock creator for configuration +func (f *MockMessageCreatorFactory) GetCreator() *MockMessageCreator { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mockCreator +} + +// SetCreateError configures an error to be returned by CreateMessageCreator +func (f *MockMessageCreatorFactory) SetCreateError(err error) { + f.mu.Lock() + defer f.mu.Unlock() + f.createError = err +} + +// Ensure interface implementations +var ( + _ nclprotocol.MessageCreator = &MockMessageCreator{} + _ nclprotocol.MessageCreatorFactory = &MockMessageCreatorFactory{} +) diff --git a/pkg/transport/nclprotocol/test/message_handling.go b/pkg/transport/nclprotocol/test/message_handling.go new file mode 100644 index 0000000000..1307508c12 --- /dev/null +++ b/pkg/transport/nclprotocol/test/message_handling.go @@ -0,0 +1,63 @@ +package test + +import ( + "context" + "sync" + + "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" 
+ "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" +) + +// MockMessageHandler provides a configurable mock message handler for testing. +// It records messages received and allows configuring errors and processing behavior. +type MockMessageHandler struct { + mu sync.RWMutex + messages []envelope.Message + shouldProcess bool + error error +} + +// NewMockMessageHandler creates a new mock message handler +func NewMockMessageHandler() *MockMessageHandler { + return &MockMessageHandler{ + messages: make([]envelope.Message, 0), + shouldProcess: true, + } +} + +// HandleMessage implements ncl.MessageHandler +func (m *MockMessageHandler) HandleMessage(ctx context.Context, msg *envelope.Message) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.error != nil { + return m.error + } + m.messages = append(m.messages, *msg) + return nil +} + +// ShouldProcess implements ncl.MessageHandler +func (m *MockMessageHandler) ShouldProcess(ctx context.Context, msg *envelope.Message) bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.shouldProcess +} + +// GetMessages returns a copy of all messages received +func (m *MockMessageHandler) GetMessages() []envelope.Message { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]envelope.Message, len(m.messages)) + copy(result, m.messages) + return result +} + +// SetError configures an error to be returned by HandleMessage +func (m *MockMessageHandler) SetError(err error) { + m.mu.Lock() + m.error = err + m.mu.Unlock() +} + +// compile-time check for interface implementation +var _ ncl.MessageHandler = &MockMessageHandler{} diff --git a/pkg/transport/nclprotocol/test/nodes.go b/pkg/transport/nclprotocol/test/nodes.go new file mode 100644 index 0000000000..4830ac7187 --- /dev/null +++ b/pkg/transport/nclprotocol/test/nodes.go @@ -0,0 +1,40 @@ +package test + +import ( + "context" + "sync" + + "github.com/bacalhau-project/bacalhau/pkg/models" +) + +type MockNodeInfoProvider struct { + nodeInfo models.NodeInfo + mu sync.RWMutex +} 
+ +// NewMockNodeInfoProvider creates a new mock node info provider +func NewMockNodeInfoProvider() *MockNodeInfoProvider { + return &MockNodeInfoProvider{ + nodeInfo: models.NodeInfo{ + NodeID: "test-node", + NodeType: models.NodeTypeCompute, + Labels: map[string]string{}, + ComputeNodeInfo: models.ComputeNodeInfo{ + AvailableCapacity: models.Resources{CPU: 4}, + QueueUsedCapacity: models.Resources{CPU: 1}, + }, + }, + } +} + +func (m *MockNodeInfoProvider) GetNodeInfo(ctx context.Context) models.NodeInfo { + m.mu.RLock() + defer m.mu.RUnlock() + return m.nodeInfo +} + +func (m *MockNodeInfoProvider) SetNodeInfo(nodeInfo models.NodeInfo) { + m.mu.Lock() + defer m.mu.Unlock() + m.nodeInfo = nodeInfo +} diff --git a/pkg/transport/nclprotocol/test/utils.go b/pkg/transport/nclprotocol/test/utils.go new file mode 100644 index 0000000000..2ac3cf210c --- /dev/null +++ b/pkg/transport/nclprotocol/test/utils.go @@ -0,0 +1,17 @@ +package test + +import ( + "context" + + "github.com/bacalhau-project/bacalhau/pkg/lib/concurrency" + "github.com/bacalhau-project/bacalhau/pkg/models" + "github.com/bacalhau-project/bacalhau/pkg/models/messages" +) + +// MockLogStreamServer implements a minimal logstream.Server for testing +type MockLogStreamServer struct{} + +func (m *MockLogStreamServer) GetLogStream(ctx context.Context, request messages.ExecutionLogsRequest) ( + <-chan *concurrency.AsyncResult[models.ExecutionLog], error) { + return nil, nil +} From 66f38939510a9890ad6b753dc878e71dce77f99a Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 11 Dec 2024 10:28:19 +0200 Subject: [PATCH 12/16] Update pkg/transport/nclprotocol/compute/dataplane_test.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/transport/nclprotocol/compute/dataplane_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pkg/transport/nclprotocol/compute/dataplane_test.go b/pkg/transport/nclprotocol/compute/dataplane_test.go index 
6470244a3b..463e4d5594 100644 --- a/pkg/transport/nclprotocol/compute/dataplane_test.go +++ b/pkg/transport/nclprotocol/compute/dataplane_test.go @@ -168,7 +168,12 @@ func (s *DataPlaneTestSuite) TestStartupFailureCleanup() { name: "context cancellation", preStart: func() { s.cancel() - time.Sleep(10 * time.Millisecond) // Allow cancellation to propagate + select { + case <-s.ctx.Done(): + // Context cancellation has propagated + case <-time.After(100 * time.Millisecond): + s.Require().Fail("Timeout waiting for context cancellation") + } }, expectError: "context canceled", }, From 003b729b5ddda87af0ec22b4bcd1286e44df448d Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 11 Dec 2024 10:30:13 +0200 Subject: [PATCH 13/16] Update pkg/transport/nclprotocol/test/utils.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/transport/nclprotocol/test/utils.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/transport/nclprotocol/test/utils.go b/pkg/transport/nclprotocol/test/utils.go index 2ac3cf210c..52b7a0b2d3 100644 --- a/pkg/transport/nclprotocol/test/utils.go +++ b/pkg/transport/nclprotocol/test/utils.go @@ -13,5 +13,7 @@ type MockLogStreamServer struct{} func (m *MockLogStreamServer) GetLogStream(ctx context.Context, request messages.ExecutionLogsRequest) ( <-chan *concurrency.AsyncResult[models.ExecutionLog], error) { - return nil, nil + ch := make(chan *concurrency.AsyncResult[models.ExecutionLog]) + close(ch) + return ch, nil } From 52327f1ee29a990b2ced60366d85871447f0acef Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 11 Dec 2024 10:32:03 +0200 Subject: [PATCH 14/16] Update pkg/transport/nclprotocol/compute/dataplane.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/transport/nclprotocol/compute/dataplane.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/transport/nclprotocol/compute/dataplane.go 
b/pkg/transport/nclprotocol/compute/dataplane.go index 9980f939f2..cfe45b4273 100644 --- a/pkg/transport/nclprotocol/compute/dataplane.go +++ b/pkg/transport/nclprotocol/compute/dataplane.go @@ -52,6 +52,12 @@ type DataPlaneParams struct { // NewDataPlane creates a new DataPlane instance with the provided parameters. // It initializes the data plane but does not start any operations - Start() must be called. func NewDataPlane(params DataPlaneParams) (*DataPlane, error) { + if params.Client == nil { + return nil, fmt.Errorf("NATS client is required") + } + if params.Config.NodeID == "" { + return nil, fmt.Errorf("node ID is required") + } dp := &DataPlane{ config: params.Config, Client: params.Client, From 3b84def7c4a9bc620970858b39a96ac43db294f8 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 11 Dec 2024 10:32:38 +0200 Subject: [PATCH 15/16] coderabbit fixes --- pkg/models/node_info.go | 12 +++--- pkg/models/utils.go | 9 +++++ pkg/test/utils/watcher.go | 40 ++++++++----------- .../nclprotocol/compute/manager_test.go | 2 +- .../nclprotocol/test/control_plane.go | 6 +-- pkg/transport/nclprotocol/test/nodes.go | 2 +- 6 files changed, 36 insertions(+), 35 deletions(-) diff --git a/pkg/models/node_info.go b/pkg/models/node_info.go index 5db51d451f..c889618216 100644 --- a/pkg/models/node_info.go +++ b/pkg/models/node_info.go @@ -121,8 +121,8 @@ func (n *NodeInfo) Copy() *NodeInfo { // Deep copy maps cpy.Labels = maps.Clone(n.Labels) cpy.SupportedProtocols = slices.Clone(n.SupportedProtocols) - cpy.ComputeNodeInfo = *n.ComputeNodeInfo.Copy() - cpy.BacalhauVersion = *n.BacalhauVersion.Copy() + cpy.ComputeNodeInfo = copyOrZero(n.ComputeNodeInfo.Copy()) + cpy.BacalhauVersion = copyOrZero(n.BacalhauVersion.Copy()) return cpy } @@ -173,9 +173,9 @@ func (c *ComputeNodeInfo) Copy() *ComputeNodeInfo { cpy.ExecutionEngines = slices.Clone(c.ExecutionEngines) cpy.Publishers = slices.Clone(c.Publishers) cpy.StorageSources = slices.Clone(c.StorageSources) - cpy.MaxCapacity = 
*c.MaxCapacity.Copy() - cpy.QueueUsedCapacity = *c.QueueUsedCapacity.Copy() - cpy.AvailableCapacity = *c.AvailableCapacity.Copy() - cpy.MaxJobRequirements = *c.MaxJobRequirements.Copy() + cpy.MaxCapacity = copyOrZero(c.MaxCapacity.Copy()) + cpy.QueueUsedCapacity = copyOrZero(c.QueueUsedCapacity.Copy()) + cpy.AvailableCapacity = copyOrZero(c.AvailableCapacity.Copy()) + cpy.MaxJobRequirements = copyOrZero(c.MaxJobRequirements.Copy()) return cpy } diff --git a/pkg/models/utils.go b/pkg/models/utils.go index 33d8702f36..d86243e1fa 100644 --- a/pkg/models/utils.go +++ b/pkg/models/utils.go @@ -22,3 +22,12 @@ func ValidateSlice[T Validatable](slice []T) error { } return nil } + +// Helper function for copying or getting zero value +func copyOrZero[T any](v *T) T { + var zero T // Create zero value + if v == nil { + return zero + } + return *v +} diff --git a/pkg/test/utils/watcher.go b/pkg/test/utils/watcher.go index f4dccc3716..bc87c02b62 100644 --- a/pkg/test/utils/watcher.go +++ b/pkg/test/utils/watcher.go @@ -28,14 +28,7 @@ func CreateComputeEventStore(t *testing.T) watcher.EventStore { ) require.NoError(t, err) - database := watchertest.CreateBoltDB(t) - - eventStore, err := boltdb_watcher.NewEventStore(database, - boltdb_watcher.WithEventsBucket("events"), - boltdb_watcher.WithCheckpointBucket("checkpoints"), - boltdb_watcher.WithEventSerializer(eventObjectSerializer), - ) - require.NoError(t, err) + eventStore := createEventStore(t, eventObjectSerializer) return eventStore } @@ -43,36 +36,35 @@ func CreateJobEventStore(t *testing.T) watcher.EventStore { eventObjectSerializer := watcher.NewJSONSerializer() err := errors.Join( eventObjectSerializer.RegisterType(jobstore.EventObjectExecutionUpsert, reflect.TypeOf(models.ExecutionUpsert{})), - eventObjectSerializer.RegisterType(jobstore.EventObjectEvaluation, reflect.TypeOf(models.Event{})), + eventObjectSerializer.RegisterType(jobstore.EventObjectEvaluation, reflect.TypeOf(models.Evaluation{})), ) require.NoError(t, 
err) - database := watchertest.CreateBoltDB(t) - - eventStore, err := boltdb_watcher.NewEventStore(database, - boltdb_watcher.WithEventsBucket("events"), - boltdb_watcher.WithCheckpointBucket("checkpoints"), - boltdb_watcher.WithEventSerializer(eventObjectSerializer), - ) - require.NoError(t, err) + eventStore := createEventStore(t, eventObjectSerializer) return eventStore } +// CreateStringEventStore creates a new event store for string events using BoltDB +// and returns both the event store and an envelope registry. +// The returned EventStore must be closed by the caller when no longer needed. func CreateStringEventStore(t *testing.T) (watcher.EventStore, *envelope.Registry) { eventObjectSerializer := watcher.NewJSONSerializer() require.NoError(t, eventObjectSerializer.RegisterType(TypeString, reflect.TypeOf(""))) - database := watchertest.CreateBoltDB(t) + eventStore := createEventStore(t, eventObjectSerializer) + registry := envelope.NewRegistry() + require.NoError(t, registry.Register(TypeString, "")) + return eventStore, registry +} + +func createEventStore(t *testing.T, serializer *watcher.JSONSerializer) watcher.EventStore { + database := watchertest.CreateBoltDB(t) eventStore, err := boltdb_watcher.NewEventStore(database, boltdb_watcher.WithEventsBucket("events"), boltdb_watcher.WithCheckpointBucket("checkpoints"), - boltdb_watcher.WithEventSerializer(eventObjectSerializer), + boltdb_watcher.WithEventSerializer(serializer), ) require.NoError(t, err) - - registry := envelope.NewRegistry() - require.NoError(t, registry.Register(TypeString, "")) - - return eventStore, registry + return eventStore } diff --git a/pkg/transport/nclprotocol/compute/manager_test.go b/pkg/transport/nclprotocol/compute/manager_test.go index f3c19f9ae8..c4f40fa4b2 100644 --- a/pkg/transport/nclprotocol/compute/manager_test.go +++ b/pkg/transport/nclprotocol/compute/manager_test.go @@ -82,7 +82,7 @@ func (s *ConnectionManagerTestSuite) SetupTest() { } // Setup mock responder - 
mockResponder, err := ncltest.NewMockResponder(s.natsConn, nil) + mockResponder, err := ncltest.NewMockResponder(s.ctx, s.natsConn, nil) s.Require().NoError(err) s.mockResponder = mockResponder diff --git a/pkg/transport/nclprotocol/test/control_plane.go b/pkg/transport/nclprotocol/test/control_plane.go index c9474171cb..00b2dcac9a 100644 --- a/pkg/transport/nclprotocol/test/control_plane.go +++ b/pkg/transport/nclprotocol/test/control_plane.go @@ -55,7 +55,7 @@ type MockResponder struct { // NewMockResponder creates a new mock responder with the given behavior. // If behavior is nil, default success responses are used. -func NewMockResponder(conn *nats.Conn, behavior *MockResponderBehavior) (*MockResponder, error) { +func NewMockResponder(ctx context.Context, conn *nats.Conn, behavior *MockResponderBehavior) (*MockResponder, error) { if behavior == nil { behavior = &MockResponderBehavior{ HandshakeResponse: struct { @@ -94,8 +94,8 @@ func NewMockResponder(conn *nats.Conn, behavior *MockResponderBehavior) (*MockRe responder: responder, } - if err := mr.setupHandlers(context.Background()); err != nil { - responder.Close(context.Background()) + if err := mr.setupHandlers(ctx); err != nil { + responder.Close(ctx) return nil, err } diff --git a/pkg/transport/nclprotocol/test/nodes.go b/pkg/transport/nclprotocol/test/nodes.go index 4830ac7187..a6ea73bcfe 100644 --- a/pkg/transport/nclprotocol/test/nodes.go +++ b/pkg/transport/nclprotocol/test/nodes.go @@ -30,7 +30,7 @@ func NewMockNodeInfoProvider() *MockNodeInfoProvider { func (m *MockNodeInfoProvider) GetNodeInfo(ctx context.Context) models.NodeInfo { m.mu.RLock() defer m.mu.RUnlock() - return m.nodeInfo + return *m.nodeInfo.Copy() } func (m *MockNodeInfoProvider) SetNodeInfo(nodeInfo models.NodeInfo) { From 97500e2889383f438920f5f1ab025692217e2c07 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 11 Dec 2024 11:58:47 +0200 Subject: [PATCH 16/16] readme --- pkg/transport/nclprotocol/README.md | 383 
++++++++++++++++++++++++++++ 1 file changed, 383 insertions(+) create mode 100644 pkg/transport/nclprotocol/README.md diff --git a/pkg/transport/nclprotocol/README.md b/pkg/transport/nclprotocol/README.md new file mode 100644 index 0000000000..148f105162 --- /dev/null +++ b/pkg/transport/nclprotocol/README.md @@ -0,0 +1,383 @@ +# NCL Protocol Documentation + +The NCL (NATS Client Library) Protocol manages reliable bidirectional communication between compute nodes and orchestrators in the Bacalhau network. It provides ordered async message delivery, connection health monitoring, and automatic recovery from failures. + +## Table of Contents +1. [Definitions & Key Concepts](#definitions--key-concepts) +2. [Architecture Overview](#architecture-overview) +3. [Message Sequencing](#message-sequencing) +4. [Connection Lifecycle](#connection-lifecycle) +5. [Message Contracts](#message-contracts) +6. [Communication Flows](#communication-flows) +7. [Component Dependencies](#component-dependencies) +8. [Configuration](#configuration) +9. 
[Glossary](#glossary)
+
+## Definitions & Key Concepts
+
+### Events and Messages
+- **Event**: An immutable record of a state change in the local system
+- **Message**: A communication packet sent between nodes derived from events
+- **Sequence Number**: A monotonically increasing identifier for ordering events and messages
+
+### Node Information
+- **Node ID**: Unique identifier for each compute node
+- **Resources**: Computational resources like CPU, Memory, GPU
+- **Available Capacity**: Currently free resources on a node
+- **Queue Used Capacity**: Resources allocated to queued jobs
+
+### Connection States
+- **Disconnected**: No active connection, no message processing
+- **Connecting**: Attempting to establish connection
+- **Connected**: Active message processing and health monitoring
+
+Transitions between states occur based on:
+- Successful/failed handshakes
+- Missing heartbeats
+- Network failures
+- Explicit disconnection
+
+## Architecture Overview
+
+The protocol consists of two main planes:
+
+### Control Plane
+- Handles connection establishment and health monitoring
+- Manages periodic heartbeats and node info updates
+- Maintains connection state and health metrics
+- Handles checkpointing for recovery
+
+### Data Plane
+- Provides reliable, ordered message delivery
+- Manages event watching and dispatching
+- Tracks message sequences for both sides
+- Handles recovery from network failures
+
+### NATS Subject Structure
+```
+bacalhau.global.compute.<node_id>.in.msgs  - Messages to compute node
+bacalhau.global.compute.<node_id>.out.msgs - Messages from compute node
+bacalhau.global.compute.<node_id>.out.ctrl - Control messages from compute
+bacalhau.global.compute.*.out.ctrl         - Global control channel
+```
+
+## Message Sequencing
+
+### Overview
+
+The NCL protocol integrates with a local event watcher system to decouple event processing from message delivery. Each node maintains its own ordered ledger of events that the protocol watches and selectively publishes. 
This decoupling provides several benefits: + +- Clean separation between business logic and message transport +- Reliable local event ordering +- Simple checkpointing and recovery +- Built-in replay capabilities + +### Event Flow Architecture + +``` +Local Event Store NCL Protocol Remote Node +┌──────────────┐ ┌─────────────────────┐ ┌──────────────┐ +│ │ │ 1. Watch Events │ │ │ +│ Ordered │◄───┤ 2. Filter Relevant │ │ │ +│ Event │ │ 3. Create Messages │───►│ Receive │ +│ Ledger │ │ 4. Track Sequences │ │ Process │ +│ │ │ 5. Checkpoint │ │ │ +└──────────────┘ └─────────────────────┘ └──────────────┘ +``` + +### Key Components + +1. **Event Store** + - Maintains ordered sequence of all local events + - Each event has unique monotonic sequence number + - Supports seeking and replay from any position + +2. **Event Watcher** + - Watches event store for new entries + - Filters events relevant for transport + - Supports resuming from checkpoint + +3. **Message Dispatcher** + - Creates messages from events + - Manages reliable delivery + - Tracks publish acknowledgments + + +## Connection Lifecycle + +### Initial Connection + +1. **Handshake** + - Compute node initiates connection by sending HandshakeRequest + - Includes node info, start time, and last processed sequence number + - Orchestrator validates request and accepts/rejects connection + - On acceptance, orchestrator creates dedicated data plane for node + +2. **Data Plane Setup** + - Both sides establish message subscriptions + - Create ordered publishers for reliable delivery + - Initialize event watchers and dispatchers + - Set up sequence tracking + +### Ongoing Communication + +1. **Health Monitoring** + - Compute nodes send periodic heartbeats + - Include current capacity and last processed sequence + - Orchestrator tracks node health and connection state + - Missing heartbeats trigger disconnection + +2. 
**Node Info Updates** + - Compute nodes send updates when configuration changes + - Includes updated capacity, features, labels + - Orchestrator maintains current node state + +3. **Message Flow** + - Data flows through separate control/data subjects + - Messages include sequence numbers for ordering + - Both sides track processed sequences + - Failed deliveries trigger automatic recovery + +## Message Contracts + +### Handshake Messages + +```typescript +// Request sent by compute node to initiate connection +HandshakeRequest { + NodeInfo: models.NodeInfo + StartTime: Time + LastOrchestratorSeqNum: uint64 +} + +// Response from orchestrator +HandshakeResponse { + Accepted: boolean + Reason: string // Only set if not accepted + LastComputeSeqNum: uint64 +} +``` + +### Heartbeat Messages + +```typescript +// Periodic heartbeat from compute node +HeartbeatRequest { + NodeID: string + AvailableCapacity: Resources + QueueUsedCapacity: Resources + LastOrchestratorSeqNum: uint64 +} + +// Acknowledgment from orchestrator +HeartbeatResponse { + LastComputeSeqNum: uint64 +} +``` + +### Node Info Update Messages + +```typescript +// Node info update notification +UpdateNodeInfoRequest { + NodeInfo: NodeInfo // Same structure as in HandshakeRequest +} + +UpdateNodeInfoResponse { + Accepted: boolean + Reason: string // Only set if not accepted +} +``` + +## Communication Flows + +### Initial Connection and Handshake +The following sequence shows the initial connection establishment between compute node and orchestrator: +```mermaid +sequenceDiagram + participant C as Compute Node + participant O as Orchestrator + + Note over C,O: Connection Establishment + C->>O: HandshakeRequest(NodeInfo, StartTime, LastSeqNum) + + Note over O: Validate Node + alt Valid Node + O->>O: Create Data Plane + O->>O: Setup Message Handlers + O-->>C: HandshakeResponse(Accepted=true, LastSeqNum) + + Note over C: Setup Data Plane + C->>C: Start Control Plane + C->>C: Initialize Data Plane + + Note over 
C,O: Begin Regular Communication
+        C->>O: Initial Heartbeat
+        O-->>C: HeartbeatResponse
+    else Invalid Node
+        O-->>C: HandshakeResponse(Accepted=false, Reason)
+        Note over C: Retry with backoff
+    end
+```
+
+### Regular Operation Flow
+
+The following sequence shows the ongoing communication pattern between compute node and orchestrator, including periodic health checks and configuration updates:
+```mermaid
+sequenceDiagram
+    participant C as Compute Node
+    participant O as Orchestrator
+
+    rect rgb(200, 230, 200)
+        Note over C,O: Periodic Health Monitoring
+        loop Every HeartbeatInterval
+            C->>O: HeartbeatRequest(NodeID, Capacity, LastSeqNum)
+            O-->>C: HeartbeatResponse()
+        end
+    end
+
+    rect rgb(230, 200, 200)
+        Note over C,O: Node Info Updates
+        C->>C: Detect Config Change
+        C->>O: UpdateNodeInfoRequest(NewNodeInfo)
+        O-->>C: UpdateNodeInfoResponse(Accepted)
+    end
+
+    rect rgb(200, 200, 230)
+        Note over C,O: Data Plane Messages
+        O->>C: Execution Messages (with SeqNum)
+        C->>O: Result Messages (with SeqNum)
+        Note over C,O: Both track sequence numbers
+    end
+```
+
+During regular operation:
+- Heartbeats occur every HeartbeatInterval (default 15s)
+- Configuration changes trigger immediate updates
+- Data plane messages flow continuously in both directions
+- Both sides maintain sequence tracking and acknowledgments
+
+### Failure Recovery Flow
+The protocol provides comprehensive failure recovery through several mechanisms:
+```mermaid
+sequenceDiagram
+    participant C as Compute Node
+    participant O as Orchestrator
+
+    rect rgb(240, 200, 200)
+        Note over C,O: Network Failure
+        C->>O: HeartbeatRequest
+        x--xO: Connection Lost
+
+        Note over C: Detect Missing Response
+        C->>C: Mark Disconnected
+        C->>C: Stop Data Plane
+
+        Note over O: Detect Missing Heartbeats
+        O->>O: Mark Node Disconnected
+        O->>O: Cleanup Node Resources
+    end
+
+    rect rgb(200, 240, 200)
+        Note over C,O: Recovery
+        loop Until Connected
+            Note over C: Exponential Backoff
+            C->>O: 
HandshakeRequest(LastSeqNum) + O-->>C: HandshakeResponse(Accepted) + end + + Note over C,O: Resume from Last Checkpoint + Note over C: Restart Data Plane + Note over O: Recreate Node Resources + end +``` + +#### Failure Detection +- Missing heartbeats beyond threshold +- NATS connection failures +- Message publish failures + +#### Recovery Process +1. Both sides independently detect failure +2. Clean up existing resources +3. Compute node initiates reconnection +4. Resume from last checkpoint: + - Load last checkpoint sequence + - Resume event watching + - Rebuild publish state + - Resend pending messages +5. Continue normal operation + +This process ensures: +- No events are lost +- Messages remain ordered +- Efficient recovery +- At-least-once delivery + +## Component Dependencies + +### Compute Node Components: + +``` +ConnectionManager +├── ControlPlane +│ ├── NodeInfoProvider +│ │ └── Monitors node state changes +│ ├── MessageHandler +│ │ └── Processes control messages +│ └── Checkpointer +│ └── Saves progress state +└── DataPlane + ├── LogStreamServer + │ └── Handles job output streaming + ├── MessageHandler + │ └── Processes execution messages + ├── MessageCreator + │ └── Formats outgoing messages + └── EventStore + └── Tracks execution events +``` + +### Orchestrator Components: + +``` +ComputeManager +├── NodeManager +│ ├── Tracks node states +│ └── Manages node lifecycle +├── MessageHandler +│ └── Processes node messages +├── MessageCreatorFactory +│ └── Creates per-node message handlers +└── DataPlane (per node) + ├── Subscriber + │ └── Handles incoming messages + ├── Publisher + │ └── Sends ordered messages + └── Dispatcher + └── Watches and sends events +``` + +## Configuration + +### Connection Management +- `HeartbeatInterval`: How often compute nodes send heartbeats (default: 15s) +- `HeartbeatMissFactor`: Number of missed heartbeats before disconnection (default: 5) +- `NodeInfoUpdateInterval`: How often node info updates are checked (default: 60s) 
+- `RequestTimeout`: Timeout for individual requests (default: 10s) + +### Recovery Settings +- `ReconnectInterval`: Base interval between reconnection attempts (default: 10s) +- `BaseRetryInterval`: Initial retry delay after failure (default: 5s) +- `MaxRetryInterval`: Maximum retry delay (default: 5m) + +### Data Plane Settings +- `CheckpointInterval`: How often sequence progress is saved (default: 30s) + +## Glossary + +- **Checkpoint**: A saved position in the event sequence used for recovery +- **Handshake**: Initial connection protocol between compute node and orchestrator +- **Heartbeat**: Periodic health check message from compute node to orchestrator +- **Node Info**: Current state and capabilities of a compute node +- **Sequence Number**: Monotonically increasing identifier used for message ordering \ No newline at end of file