diff --git a/pkg/models/node_info.go b/pkg/models/node_info.go index 5db51d451f..c889618216 100644 --- a/pkg/models/node_info.go +++ b/pkg/models/node_info.go @@ -121,8 +121,8 @@ func (n *NodeInfo) Copy() *NodeInfo { // Deep copy maps cpy.Labels = maps.Clone(n.Labels) cpy.SupportedProtocols = slices.Clone(n.SupportedProtocols) - cpy.ComputeNodeInfo = *n.ComputeNodeInfo.Copy() - cpy.BacalhauVersion = *n.BacalhauVersion.Copy() + cpy.ComputeNodeInfo = copyOrZero(n.ComputeNodeInfo.Copy()) + cpy.BacalhauVersion = copyOrZero(n.BacalhauVersion.Copy()) return cpy } @@ -173,9 +173,9 @@ func (c *ComputeNodeInfo) Copy() *ComputeNodeInfo { cpy.ExecutionEngines = slices.Clone(c.ExecutionEngines) cpy.Publishers = slices.Clone(c.Publishers) cpy.StorageSources = slices.Clone(c.StorageSources) - cpy.MaxCapacity = *c.MaxCapacity.Copy() - cpy.QueueUsedCapacity = *c.QueueUsedCapacity.Copy() - cpy.AvailableCapacity = *c.AvailableCapacity.Copy() - cpy.MaxJobRequirements = *c.MaxJobRequirements.Copy() + cpy.MaxCapacity = copyOrZero(c.MaxCapacity.Copy()) + cpy.QueueUsedCapacity = copyOrZero(c.QueueUsedCapacity.Copy()) + cpy.AvailableCapacity = copyOrZero(c.AvailableCapacity.Copy()) + cpy.MaxJobRequirements = copyOrZero(c.MaxJobRequirements.Copy()) return cpy } diff --git a/pkg/models/utils.go b/pkg/models/utils.go index 33d8702f36..d86243e1fa 100644 --- a/pkg/models/utils.go +++ b/pkg/models/utils.go @@ -22,3 +22,12 @@ func ValidateSlice[T Validatable](slice []T) error { } return nil } + +// Helper function for copying or getting zero value +func copyOrZero[T any](v *T) T { + var zero T // Create zero value + if v == nil { + return zero + } + return *v +} diff --git a/pkg/test/utils/watcher.go b/pkg/test/utils/watcher.go index f4dccc3716..bc87c02b62 100644 --- a/pkg/test/utils/watcher.go +++ b/pkg/test/utils/watcher.go @@ -28,14 +28,7 @@ func CreateComputeEventStore(t *testing.T) watcher.EventStore { ) require.NoError(t, err) - database := watchertest.CreateBoltDB(t) - - eventStore, err := boltdb_watcher.NewEventStore(database, - boltdb_watcher.WithEventsBucket("events"), - boltdb_watcher.WithCheckpointBucket("checkpoints"), - boltdb_watcher.WithEventSerializer(eventObjectSerializer), - ) - require.NoError(t, err) + eventStore := createEventStore(t, eventObjectSerializer) return eventStore } @@ -43,36 +36,35 @@ func CreateJobEventStore(t *testing.T) watcher.EventStore { eventObjectSerializer := watcher.NewJSONSerializer() err := errors.Join( eventObjectSerializer.RegisterType(jobstore.EventObjectExecutionUpsert, reflect.TypeOf(models.ExecutionUpsert{})), - eventObjectSerializer.RegisterType(jobstore.EventObjectEvaluation, reflect.TypeOf(models.Event{})), + eventObjectSerializer.RegisterType(jobstore.EventObjectEvaluation, reflect.TypeOf(models.Evaluation{})), ) require.NoError(t, err) - database := watchertest.CreateBoltDB(t) - - eventStore, err := boltdb_watcher.NewEventStore(database, - boltdb_watcher.WithEventsBucket("events"), - boltdb_watcher.WithCheckpointBucket("checkpoints"), - boltdb_watcher.WithEventSerializer(eventObjectSerializer), - ) - require.NoError(t, err) + eventStore := createEventStore(t, eventObjectSerializer) return eventStore } +// CreateStringEventStore creates a new event store for string events using BoltDB +// and returns both the event store and an envelope registry. +// The returned EventStore must be closed by the caller when no longer needed. func CreateStringEventStore(t *testing.T) (watcher.EventStore, *envelope.Registry) { eventObjectSerializer := watcher.NewJSONSerializer() require.NoError(t, eventObjectSerializer.RegisterType(TypeString, reflect.TypeOf(""))) - database := watchertest.CreateBoltDB(t) + eventStore := createEventStore(t, eventObjectSerializer) + registry := envelope.NewRegistry() + require.NoError(t, registry.Register(TypeString, "")) + return eventStore, registry +} + +func createEventStore(t *testing.T, serializer *watcher.JSONSerializer) watcher.EventStore { + database := watchertest.CreateBoltDB(t) eventStore, err := boltdb_watcher.NewEventStore(database, boltdb_watcher.WithEventsBucket("events"), boltdb_watcher.WithCheckpointBucket("checkpoints"), - boltdb_watcher.WithEventSerializer(eventObjectSerializer), + boltdb_watcher.WithEventSerializer(serializer), ) require.NoError(t, err) - - registry := envelope.NewRegistry() - require.NoError(t, registry.Register(TypeString, "")) - - return eventStore, registry + return eventStore } diff --git a/pkg/transport/nclprotocol/compute/manager_test.go b/pkg/transport/nclprotocol/compute/manager_test.go index f3c19f9ae8..c4f40fa4b2 100644 --- a/pkg/transport/nclprotocol/compute/manager_test.go +++ b/pkg/transport/nclprotocol/compute/manager_test.go @@ -82,7 +82,7 @@ func (s *ConnectionManagerTestSuite) SetupTest() { } // Setup mock responder - mockResponder, err := ncltest.NewMockResponder(s.natsConn, nil) + mockResponder, err := ncltest.NewMockResponder(s.ctx, s.natsConn, nil) s.Require().NoError(err) s.mockResponder = mockResponder diff --git a/pkg/transport/nclprotocol/test/control_plane.go b/pkg/transport/nclprotocol/test/control_plane.go index c9474171cb..00b2dcac9a 100644 --- a/pkg/transport/nclprotocol/test/control_plane.go +++ b/pkg/transport/nclprotocol/test/control_plane.go @@ -55,7 +55,7 @@ type MockResponder struct { // NewMockResponder creates a new mock responder with the given behavior. // If behavior is nil, default success responses are used. -func NewMockResponder(conn *nats.Conn, behavior *MockResponderBehavior) (*MockResponder, error) { +func NewMockResponder(ctx context.Context, conn *nats.Conn, behavior *MockResponderBehavior) (*MockResponder, error) { if behavior == nil { behavior = &MockResponderBehavior{ HandshakeResponse: struct { @@ -94,8 +94,8 @@ func NewMockResponder(conn *nats.Conn, behavior *MockResponderBehavior) (*MockRe responder: responder, } - if err := mr.setupHandlers(context.Background()); err != nil { - responder.Close(context.Background()) + if err := mr.setupHandlers(ctx); err != nil { + responder.Close(ctx) return nil, err } diff --git a/pkg/transport/nclprotocol/test/nodes.go b/pkg/transport/nclprotocol/test/nodes.go index 4830ac7187..a6ea73bcfe 100644 --- a/pkg/transport/nclprotocol/test/nodes.go +++ b/pkg/transport/nclprotocol/test/nodes.go @@ -30,7 +30,7 @@ func NewMockNodeInfoProvider() *MockNodeInfoProvider { func (m *MockNodeInfoProvider) GetNodeInfo(ctx context.Context) models.NodeInfo { m.mu.RLock() defer m.mu.RUnlock() - return m.nodeInfo + return *m.nodeInfo.Copy() } func (m *MockNodeInfoProvider) SetNodeInfo(nodeInfo models.NodeInfo) {