diff --git a/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.json b/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.json index a14456cdba..3cad92d6b8 100644 --- a/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.json +++ b/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.json @@ -3,30 +3,7 @@ "uri": "mongodb://a,b/", "phases": [ { - "responses": [ - [ - "a:27017", - { - "ok": 1, - "ismaster": true, - "setName": "rs", - "setVersion": 1, - "primary": "a:27017", - "hosts": [ - "a:27017" - ], - "minWireVersion": 0, - "maxWireVersion": 4 - } - ], - [ - "b:27017", - { - "ok": 1, - "ismaster": true - } - ] - ], + "responses": [], "outcome": { "events": [ { @@ -73,7 +50,37 @@ "topologyId": "42", "address": "b:27017" } - }, + } + ] + } + }, + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "ismaster": true, + "setName": "rs", + "setVersion": 1, + "primary": "a:27017", + "hosts": [ + "a:27017" + ], + "minWireVersion": 0, + "maxWireVersion": 4 + } + ], + [ + "b:27017", + { + "ok": 1, + "ismaster": true + } + ] + ], + "outcome": { + "events": [ { "server_description_changed_event": { "topologyId": "42", diff --git a/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.yml b/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.yml index e156ec7100..754d8270e6 100644 --- a/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.yml +++ b/data/server-discovery-and-monitoring/monitoring/replica_set_with_removal.yml @@ -2,22 +2,7 @@ description: "Monitoring a replica set with non member" uri: "mongodb://a,b/" phases: - - responses: - - - - "a:27017" - - { - ok: 1, - ismaster: true, - setName: "rs", - setVersion: 1.0, - primary: "a:27017", - hosts: [ "a:27017" ], - minWireVersion: 0, - maxWireVersion: 4 - } - - - - "b:27017" - - { ok: 1, ismaster: true } + responses: [] outcome: events: - @@ -52,6 +37,25 @@ phases: server_opening_event: topologyId: "42" address: "b:27017" + - + responses: + - + - "a:27017" + - { + ok: 1, + ismaster: true, + setName: "rs", + setVersion: 1.0, + primary: "a:27017", + hosts: [ "a:27017" ], + minWireVersion: 0, + maxWireVersion: 4 + } + - + - "b:27017" + - { ok: 1, ismaster: true } + outcome: + events: - server_description_changed_event: topologyId: "42" diff --git a/data/server-discovery-and-monitoring/monitoring/standalone.json b/data/server-discovery-and-monitoring/monitoring/standalone.json index 5d40286c97..f3df3ec764 100644 --- a/data/server-discovery-and-monitoring/monitoring/standalone.json +++ b/data/server-discovery-and-monitoring/monitoring/standalone.json @@ -1,6 +1,6 @@ { "description": "Monitoring a standalone connection", - "uri": "mongodb://a:27017", + "uri": "mongodb://a:27017/?directConnection=true", "phases": [ { "responses": [ diff --git a/data/server-discovery-and-monitoring/monitoring/standalone.yml b/data/server-discovery-and-monitoring/monitoring/standalone.yml index aff3f7322c..de79d0ff49 100644 --- a/data/server-discovery-and-monitoring/monitoring/standalone.yml +++ b/data/server-discovery-and-monitoring/monitoring/standalone.yml @@ -1,5 +1,5 @@ description: "Monitoring a standalone connection" -uri: "mongodb://a:27017" +uri: "mongodb://a:27017/?directConnection=true" phases: - responses: diff --git a/data/server-discovery-and-monitoring/monitoring/standalone_suppress_equal_description_changes.json b/data/server-discovery-and-monitoring/monitoring/standalone_suppress_equal_description_changes.json new file mode 100644 index 0000000000..b8122a363b --- /dev/null +++ b/data/server-discovery-and-monitoring/monitoring/standalone_suppress_equal_description_changes.json @@ -0,0 +1,113 @@ +{ + "description": "Monitoring a standalone connection - suppress update events for equal server descriptions", + "uri": "mongodb://a:27017/?directConnection=true", + "phases": [ + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "ismaster": true, + "minWireVersion": 0, + "maxWireVersion": 4 + } + ], + [ + "a:27017", + { + "ok": 1, + "ismaster": true, + "minWireVersion": 0, + "maxWireVersion": 4 + } + ] + ], + "outcome": { + "events": [ + { + "topology_opening_event": { + "topologyId": "42" + } + }, + { + "topology_description_changed_event": { + "topologyId": "42", + "previousDescription": { + "topologyType": "Unknown", + "servers": [] + }, + "newDescription": { + "topologyType": "Single", + "servers": [ + { + "address": "a:27017", + "arbiters": [], + "hosts": [], + "passives": [], + "type": "Unknown" + } + ] + } + } + }, + { + "server_opening_event": { + "topologyId": "42", + "address": "a:27017" + } + }, + { + "server_description_changed_event": { + "topologyId": "42", + "address": "a:27017", + "previousDescription": { + "address": "a:27017", + "arbiters": [], + "hosts": [], + "passives": [], + "type": "Unknown" + }, + "newDescription": { + "address": "a:27017", + "arbiters": [], + "hosts": [], + "passives": [], + "type": "Standalone" + } + } + }, + { + "topology_description_changed_event": { + "topologyId": "42", + "previousDescription": { + "topologyType": "Single", + "servers": [ + { + "address": "a:27017", + "arbiters": [], + "hosts": [], + "passives": [], + "type": "Unknown" + } + ] + }, + "newDescription": { + "topologyType": "Single", + "servers": [ + { + "address": "a:27017", + "arbiters": [], + "hosts": [], + "passives": [], + "type": "Standalone" + } + ] + } + } + } + ] + } + } + ] +} diff --git a/data/server-discovery-and-monitoring/monitoring/standalone_suppress_equal_description_changes.yml b/data/server-discovery-and-monitoring/monitoring/standalone_suppress_equal_description_changes.yml new file mode 100644 index 0000000000..f9b13bdebd --- /dev/null +++ b/data/server-discovery-and-monitoring/monitoring/standalone_suppress_equal_description_changes.yml @@ -0,0 +1,73 @@ +description: "Monitoring a standalone connection - suppress update events for equal server descriptions" +uri: "mongodb://a:27017/?directConnection=true" +phases: + - + responses: + - + - "a:27017" + - { ok: 1, ismaster: true, minWireVersion: 0, maxWireVersion: 4 } + - + - "a:27017" + - { ok: 1, ismaster: true, minWireVersion: 0, maxWireVersion: 4 } + + outcome: + events: + - + topology_opening_event: + topologyId: "42" + - + topology_description_changed_event: + topologyId: "42" + previousDescription: + topologyType: "Unknown" + servers: [] + newDescription: + topologyType: "Single" + servers: + - + address: "a:27017" + arbiters: [] + hosts: [] + passives: [] + type: "Unknown" + - + server_opening_event: + topologyId: "42" + address: "a:27017" + - + server_description_changed_event: + topologyId: "42" + address: "a:27017" + previousDescription: + address: "a:27017" + arbiters: [] + hosts: [] + passives: [] + type: "Unknown" + newDescription: + address: "a:27017" + arbiters: [] + hosts: [] + passives: [] + type: "Standalone" + - + topology_description_changed_event: + topologyId: "42" + previousDescription: + topologyType: "Single" + servers: + - + address: "a:27017" + arbiters: [] + hosts: [] + passives: [] + type: "Unknown" + newDescription: + topologyType: "Single" + servers: + - + address: "a:27017" + arbiters: [] + hosts: [] + passives: [] + type: "Standalone" diff --git a/event/monitoring.go b/event/monitoring.go index 240f2398e6..eb21933126 100644 --- a/event/monitoring.go +++ b/event/monitoring.go @@ -10,6 +10,9 @@ import ( "context" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" ) // CommandStartedEvent represents an event generated when a command is sent to a server. @@ -89,3 +92,80 @@ type PoolEvent struct { type PoolMonitor struct { Event func(*PoolEvent) } + +// ServerDescriptionChangedEvent represents a server description change. +type ServerDescriptionChangedEvent struct { + Address address.Address + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of + PreviousDescription description.Server + NewDescription description.Server +} + +// ServerOpeningEvent is an event generated when the server is initialized. +type ServerOpeningEvent struct { + Address address.Address + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// ServerClosedEvent is an event generated when the server is closed. +type ServerClosedEvent struct { + Address address.Address + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// TopologyDescriptionChangedEvent represents a topology description change. +type TopologyDescriptionChangedEvent struct { + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of + PreviousDescription description.Topology + NewDescription description.Topology +} + +// TopologyOpeningEvent is an event generated when the topology is initialized. +type TopologyOpeningEvent struct { + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// TopologyClosedEvent is an event generated when the topology is closed. +type TopologyClosedEvent struct { + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// ServerHeartbeatStartedEvent is an event generated when the heartbeat is started. +type ServerHeartbeatStartedEvent struct { + ConnectionID string // The address this heartbeat was sent to with a unique identifier + Awaited bool // If this heartbeat was awaitable +} + +// ServerHeartbeatSucceededEvent is an event generated when the heartbeat succeeds. +type ServerHeartbeatSucceededEvent struct { + DurationNanos int64 + Reply description.Server + ConnectionID string // The address this heartbeat was sent to with a unique identifier + Awaited bool // If this heartbeat was awaitable +} + +// ServerHeartbeatFailedEvent is an event generated when the heartbeat fails. +type ServerHeartbeatFailedEvent struct { + DurationNanos int64 + Failure error + ConnectionID string // The address this heartbeat was sent to with a unique identifier + Awaited bool // If this heartbeat was awaitable +} + +// ServerMonitor represents a monitor that is triggered for different server events. The client +// will monitor changes on the MongoDB deployment it is connected to, and this monitor reports +// the changes in the client's representation of the deployment. The topology represents the +// overall deployment, and heartbeats are sent to individual servers to check their current status. +type ServerMonitor struct { + ServerDescriptionChanged func(*ServerDescriptionChangedEvent) + ServerOpening func(*ServerOpeningEvent) + ServerClosed func(*ServerClosedEvent) + // TopologyDescriptionChanged is called when the topology is locked, so the callback should + // not attempt any operation that requires server selection on the same client. + TopologyDescriptionChanged func(*TopologyDescriptionChangedEvent) + TopologyOpening func(*TopologyOpeningEvent) + TopologyClosed func(*TopologyClosedEvent) + ServerHeartbeatStarted func(*ServerHeartbeatStartedEvent) + ServerHeartbeatSucceeded func(*ServerHeartbeatSucceededEvent) + ServerHeartbeatFailed func(*ServerHeartbeatFailedEvent) +} diff --git a/mongo/client.go b/mongo/client.go index 6e5eedcee1..8b8691bcbe 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -62,6 +62,7 @@ type Client struct { registry *bsoncodec.Registry marshaller BSONAppender monitor *event.CommandMonitor + serverMonitor *event.ServerMonitor sessionPool *session.Pool // client-side encryption fields @@ -495,6 +496,19 @@ func (c *Client) configure(opts *options.ClientOptions) error { func(*event.CommandMonitor) *event.CommandMonitor { return opts.Monitor }, )) } + // ServerMonitor + if opts.ServerMonitor != nil { + c.serverMonitor = opts.ServerMonitor + serverOpts = append( + serverOpts, + topology.WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return opts.ServerMonitor }), + ) + + topologyOpts = append( + topologyOpts, + topology.WithTopologyServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return opts.ServerMonitor }), + ) + } // ReadConcern c.readConcern = readconcern.New() if opts.ReadConcern != nil { diff --git a/mongo/client_test.go b/mongo/client_test.go index fc894f8a21..cd46a80861 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -238,6 +238,11 @@ func TestClient(t *testing.T) { client := setupClient(options.Client().SetWriteConcern(wc)) assert.Equal(t, wc, client.writeConcern, "mismatch; expected write concern %v, got %v", wc, client.writeConcern) }) + t.Run("server monitor", func(t *testing.T) { + monitor := &event.ServerMonitor{} + client := setupClient(options.Client().SetServerMonitor(monitor)) + assert.Equal(t, monitor, client.serverMonitor, "expected sdam monitor %v, got %v", monitor, client.serverMonitor) + }) t.Run("GetURI", func(t *testing.T) { t.Run("ApplyURI not called", func(t *testing.T) { opts := options.Client().SetHosts([]string{"localhost:27017"}) diff --git a/mongo/description/server.go b/mongo/description/server.go index 7b4181e80c..fec77ddc04 100644 --- a/mongo/description/server.go +++ b/mongo/description/server.go @@ -31,12 +31,14 @@ type SelectedServer struct { type Server struct { Addr address.Address + Arbiters []string AverageRTT time.Duration AverageRTTSet bool Compression []string // compression methods returned by server CanonicalAddr address.Address ElectionID primitive.ObjectID HeartbeatInterval time.Duration + Hosts []string LastError error LastUpdateTime time.Time LastWriteTime time.Time @@ -44,6 +46,8 @@ type Server struct { MaxDocumentSize uint32 MaxMessageSize uint32 Members []address.Address + Passives []string + Primary address.Address ReadOnly bool SessionTimeoutMinutes uint32 SetName string @@ -69,12 +73,11 @@ func NewServer(addr address.Address, response bson.Raw) Server { var isReplicaSet, isMaster, hidden, secondary, arbiterOnly bool var msg string var version VersionRange - var hosts, passives, arbiters []string for _, element := range elements { switch element.Key() { case "arbiters": var err error - arbiters, err = decodeStringSlice(element, "arbiters") + desc.Arbiters, err = decodeStringSlice(element, "arbiters") if err != nil { desc.LastError = err return desc @@ -106,7 +109,7 @@ func NewServer(addr address.Address, response bson.Raw) Server { } case "hosts": var err error - hosts, err = decodeStringSlice(element, "hosts") + desc.Hosts, err = decodeStringSlice(element, "hosts") if err != nil { desc.LastError = err return desc @@ -203,11 +206,18 @@ func NewServer(addr address.Address, response bson.Raw) Server { } case "passives": var err error - passives, err = decodeStringSlice(element, "passives") + desc.Passives, err = decodeStringSlice(element, "passives") if err != nil { desc.LastError = err return desc } + case "primary": + primary, ok := element.Value().StringValueOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'primary' to be a string but it's a BSON %s", element.Value().Type) + return desc + } + desc.Primary = address.Address(primary) case "readOnly": desc.ReadOnly, ok = element.Value().BooleanOK() if !ok { @@ -269,15 +279,15 @@ func NewServer(addr address.Address, response bson.Raw) Server { } } - for _, host := range hosts { + for _, host := range desc.Hosts { desc.Members = append(desc.Members, address.Address(host).Canonicalize()) } - for _, passive := range passives { + for _, passive := range desc.Passives { desc.Members = append(desc.Members, address.Address(passive).Canonicalize()) } - for _, arbiter := range arbiters { + for _, arbiter := range desc.Arbiters { desc.Members = append(desc.Members, address.Address(arbiter).Canonicalize()) } @@ -412,3 +422,81 @@ func decodeStringMap(element bson.RawElement, name string) (map[string]string, e func (s Server) SupportsRetryWrites() bool { return s.SessionTimeoutMinutes != 0 && s.Kind != Standalone } + +// Equal compares two server descriptions and returns true if they are equal +func (s Server) Equal(other Server) bool { + if s.CanonicalAddr.String() != other.CanonicalAddr.String() { + return false + } + + if !sliceStringEqual(s.Arbiters, other.Arbiters) { + return false + } + + if !sliceStringEqual(s.Hosts, other.Hosts) { + return false + } + + if !sliceStringEqual(s.Passives, other.Passives) { + return false + } + + if s.Primary != other.Primary { + return false + } + + if s.SetName != other.SetName { + return false + } + + if s.Kind != other.Kind { + return false + } + + if s.LastError != nil || other.LastError != nil { + if s.LastError == nil || other.LastError == nil { + return false + } + if s.LastError.Error() != other.LastError.Error() { + return false + } + } + + if !s.WireVersion.Equals(other.WireVersion) { + return false + } + + if len(s.Tags) != len(other.Tags) || !s.Tags.ContainsAll(other.Tags) { + return false + } + + if s.SetVersion != other.SetVersion { + return false + } + + if s.ElectionID != other.ElectionID { + return false + } + + if s.SessionTimeoutMinutes != other.SessionTimeoutMinutes { + return false + } + + if s.TopologyVersion != other.TopologyVersion && CompareTopologyVersion(s.TopologyVersion, other.TopologyVersion) != 0 { + return false + } + + return true +} + +func sliceStringEqual(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/mongo/description/server_test.go b/mongo/description/server_test.go new file mode 100644 index 0000000000..1ea3ed3a22 --- /dev/null +++ b/mongo/description/server_test.go @@ -0,0 +1,70 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package description + +import ( + "errors" + "testing" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal/testutil/assert" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/tag" +) + +func TestServer(t *testing.T) { + t.Run("equals", func(t *testing.T) { + defaultServer := Server{} + // Only some of the Server fields affect equality + testCases := []struct { + name string + server Server + equal bool + }{ + {"empty", Server{}, true}, + {"address", Server{Addr: address.Address("foo")}, true}, + {"arbiters", Server{Arbiters: []string{"foo"}}, false}, + {"rtt", Server{AverageRTT: time.Second}, true}, + {"compression", Server{Compression: []string{"foo"}}, true}, + {"canonicalAddr", Server{CanonicalAddr: address.Address("foo")}, false}, + {"electionID", Server{ElectionID: primitive.NewObjectID()}, false}, + {"heartbeatInterval", Server{HeartbeatInterval: time.Second}, true}, + {"hosts", Server{Hosts: []string{"foo"}}, false}, + {"lastError", Server{LastError: errors.New("foo")}, false}, + {"lastUpdateTime", Server{LastUpdateTime: time.Now()}, true}, + {"lastWriteTime", Server{LastWriteTime: time.Now()}, true}, + {"maxBatchCount", Server{MaxBatchCount: 1}, true}, + {"maxDocumentSize", Server{MaxDocumentSize: 1}, true}, + {"maxMessageSize", Server{MaxMessageSize: 1}, true}, + {"members", Server{Members: []address.Address{address.Address("foo")}}, true}, + {"passives", Server{Passives: []string{"foo"}}, false}, + {"primary", Server{Primary: address.Address("foo")}, false}, + {"readOnly", Server{ReadOnly: true}, true}, + {"sessionTimeoutMinutes", Server{SessionTimeoutMinutes: 1}, false}, + {"setName", Server{SetName: "foo"}, false}, + {"setVersion", Server{SetVersion: 1}, false}, + { + "speculativeAuthenticate", + Server{SpeculativeAuthenticate: bson.Raw{'\x08', '\x00', '\x00', '\x00', '\x0A', 'x', '\x00', '\x00'}}, + true, + }, + {"tags", Server{Tags: tag.Set{tag.Tag{"foo", "bar"}}}, false}, + {"topologyVersion", Server{TopologyVersion: &TopologyVersion{primitive.NewObjectID(), 0}}, false}, + {"kind", Server{Kind: Standalone}, false}, + {"wireVersion", Server{WireVersion: &VersionRange{1, 2}}, false}, + {"saslSupportedMechs", Server{SaslSupportedMechs: []string{"foo"}}, true}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := defaultServer.Equal(tc.server) + assert.Equal(t, actual, tc.equal, "expected %v, got %v", tc.equal, actual) + }) + } + }) +} diff --git a/mongo/description/topology.go b/mongo/description/topology.go index f192f4b6fa..7f394e7bf4 100644 --- a/mongo/description/topology.go +++ b/mongo/description/topology.go @@ -10,11 +10,13 @@ import ( "fmt" "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/readpref" ) // Topology represents a description of a mongodb topology type Topology struct { Servers []Server + SetName string Kind TopologyKind SessionTimeoutMinutes uint32 CompatibilityErr error @@ -103,3 +105,113 @@ func (t Topology) String() string { } return fmt.Sprintf("Type: %s, Servers: [%s]", t.Kind, serversStr) } + +// Equal compares two topology descriptions and returns true if they are equal +func (t Topology) Equal(other Topology) bool { + + diff := DiffTopology(t, other) + if len(diff.Added) != 0 || len(diff.Removed) != 0 { + return false + } + + if t.Kind != other.Kind { + return false + } + + topoServers := make(map[string]Server) + for _, s := range t.Servers { + topoServers[s.Addr.String()] = s + } + + otherServers := make(map[string]Server) + for _, s := range other.Servers { + otherServers[s.Addr.String()] = s + } + + if len(topoServers) != len(otherServers) { + return false + } + + for _, server := range topoServers { + otherServer := otherServers[server.Addr.String()] + + if !server.Equal(otherServer) { + return false + } + } + + return true +} + +// HasReadableServer returns true if a topology has a server available for reading +// based on the specified read preference. Single and sharded topologies only require an +// available server, while replica sets require an available server that has a kind +// compatible with the given read preference mode. +func (t Topology) HasReadableServer(mode readpref.Mode) bool { + switch t.Kind { + case Single, Sharded: + return hasAvailableServer(t.Servers, 0) + case ReplicaSetWithPrimary: + return hasAvailableServer(t.Servers, mode) + case ReplicaSetNoPrimary, ReplicaSet: + if mode == readpref.PrimaryMode { + return false + } + // invalid read preference + if !mode.IsValid() { + return false + } + + return hasAvailableServer(t.Servers, mode) + } + return false +} + +// HasWritableServer returns true if a topology has a server available for writing +func (t Topology) HasWritableServer() bool { + return t.HasReadableServer(readpref.PrimaryMode) +} + +// hasAvailableServer returns true if any servers are available based on +// the read preference. +func hasAvailableServer(servers []Server, mode readpref.Mode) bool { + switch mode { + case readpref.PrimaryMode: + for _, s := range servers { + if s.Kind == RSPrimary { + return true + } + } + return false + case readpref.PrimaryPreferredMode, readpref.SecondaryPreferredMode, readpref.NearestMode: + for _, s := range servers { + if s.Kind == RSPrimary || s.Kind == RSSecondary { + return true + } + } + return false + case readpref.SecondaryMode: + for _, s := range servers { + if s.Kind == RSSecondary { + return true + } + } + return false + } + + // read preference is not specified + for _, s := range servers { + switch s.Kind { + case Standalone, + RSMember, + RSPrimary, + RSSecondary, + RSArbiter, + RSGhost, + Mongos: + return true + } + } + + return false +} diff --git a/mongo/description/version_range.go b/mongo/description/version_range.go index 984dff89e9..5d6270c521 100644 --- a/mongo/description/version_range.go +++ b/mongo/description/version_range.go @@ -25,6 +25,17 @@ func (vr VersionRange) Includes(v int32) bool { return v >= vr.Min && v <= vr.Max } +// Equals returns a bool indicating whether the supplied VersionRange is equal. +func (vr *VersionRange) Equals(other *VersionRange) bool { + if vr == nil && other == nil { + return true + } + if vr == nil || other == nil { + return false + } + return vr.Min == other.Min && vr.Max == other.Max +} + // String implements the fmt.Stringer interface. func (vr VersionRange) String() string { return fmt.Sprintf("[%d, %d]", vr.Min, vr.Max) diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index f4a3917f16..1cfd7151c6 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -108,6 +108,7 @@ type ClientOptions struct { MinPoolSize *uint64 PoolMonitor *event.PoolMonitor Monitor *event.CommandMonitor + ServerMonitor *event.ServerMonitor ReadConcern *readconcern.ReadConcern ReadPreference *readpref.ReadPref Registry *bsoncodec.Registry @@ -518,6 +519,12 @@ func (c *ClientOptions) SetMonitor(m *event.CommandMonitor) *ClientOptions { return c } +// SetServerMonitor specifies an SDAM monitor used to monitor SDAM events. +func (c *ClientOptions) SetServerMonitor(m *event.ServerMonitor) *ClientOptions { + c.ServerMonitor = m + return c +} + // SetReadConcern specifies the read concern to use for read operations. A read concern level can also be set through // the "readConcernLevel" URI option (e.g. "readConcernLevel=majority"). The default is nil, meaning the server will use // its configured default. @@ -750,6 +757,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.Monitor != nil { c.Monitor = opt.Monitor } + if opt.ServerMonitor != nil { + c.ServerMonitor = opt.ServerMonitor + } if opt.ReadConcern != nil { c.ReadConcern = opt.ReadConcern } diff --git a/mongo/readpref/mode.go b/mongo/readpref/mode.go index deacf9f337..ce036504cb 100644 --- a/mongo/readpref/mode.go +++ b/mongo/readpref/mode.go @@ -72,3 +72,17 @@ func (mode Mode) String() string { return "unknown" } } + +// IsValid checks whether the mode is valid. +func (mode Mode) IsValid() bool { + switch mode { + case PrimaryMode, + PrimaryPreferredMode, + SecondaryMode, + SecondaryPreferredMode, + NearestMode: + return true + default: + return false + } +} diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 94f2034a12..069305fb66 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -459,6 +459,10 @@ func (c *connection) setSocketTimeout(timeout time.Duration) { c.writeTimeout = timeout } +func (c *connection) ID() string { + return c.id +} + // initConnection is an adapter used during connection initialization. It has the minimum // functionality necessary to implement the driver.Connection interface, which is required to pass a // *connection to a Handshaker. diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index e041b8f343..e0886b3bd8 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -24,6 +24,29 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver" ) +type testHandshaker struct { + getDescription func(context.Context, address.Address, driver.Connection) (description.Server, error) + finishHandshake func(context.Context, driver.Connection) error +} + +// GetDescription implements the Handshaker interface. +func (th *testHandshaker) GetDescription(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) { + if th.getDescription != nil { + return th.getDescription(ctx, addr, conn) + } + return description.Server{}, nil +} + +// FinishHandshake implements the Handshaker interface. +func (th *testHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error { + if th.finishHandshake != nil { + return th.finishHandshake(ctx, conn) + } + return nil +} + +var _ driver.Handshaker = &testHandshaker{} + func TestConnection(t *testing.T) { t.Run("connection", func(t *testing.T) { t.Run("newConnection", func(t *testing.T) { diff --git a/x/mongo/driver/topology/fsm.go b/x/mongo/driver/topology/fsm.go index bed7be8df8..94dadc8ca5 100644 --- a/x/mongo/driver/topology/fsm.go +++ b/x/mongo/driver/topology/fsm.go @@ -21,7 +21,6 @@ var minSupportedMongoDBVersion = "2.6" type fsm struct { description.Topology - SetName string maxElectionID primitive.ObjectID maxSetVersion uint32 compatible atomic.Value @@ -48,6 +47,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser f.Topology = description.Topology{ Kind: f.Kind, Servers: newServers, + SetName: f.SetName, } // For data bearing servers, set SessionTimeoutMinutes to the lowest among them diff --git a/x/mongo/driver/topology/sdam_spec_test.go b/x/mongo/driver/topology/sdam_spec_test.go index ab14026eca..e288a7eee2 100644 --- a/x/mongo/driver/topology/sdam_spec_test.go +++ b/x/mongo/driver/topology/sdam_spec_test.go @@ -12,12 +12,14 @@ import ( "io/ioutil" "net" "path" + "sync" "sync/atomic" "testing" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/testutil/assert" testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers" "go.mongodb.org/mongo-driver/mongo/address" @@ -53,6 +55,7 @@ type IsMaster struct { Msg string `bson:"msg,omitempty"` OK int32 `bson:"ok"` Passives []string `bson:"passives,omitempty"` + Primary string `bson:"primary,omitempty"` ReadOnly bool `bson:"readOnly,omitempty"` SaslSupportedMechs []string `bson:"saslSupportedMechs,omitempty"` Secondary bool `bson:"secondary,omitempty"` @@ -95,6 +98,57 @@ type applicationError struct { Response bsoncore.Document } +type topologyDescription struct { + TopologyType string `bson:"topologyType"` + Servers []serverDescription `bson:"servers"` + SetName string `bson:"setName,omitempty"` +} + +type serverDescription struct { + Address string `bson:"address"` + Arbiters []string `bson:"arbiters"` + Hosts []string `bson:"hosts"` + Passives []string `bson:"passives"` + Primary string `bson:"primary,omitempty"` + SetName string `bson:"setName,omitempty"` + Type string `bson:"type"` +} + +type topologyOpeningEvent struct { + TopologyID string `bson:"topologyId"` +} + +type serverOpeningEvent struct { + Address string `bson:"address"` + TopologyID string `bson:"topologyId"` +} + +type topologyDescriptionChangedEvent struct { + TopologyID string `bson:"topologyId"` + PreviousDescription topologyDescription `bson:"previousDescription"` + NewDescription topologyDescription `bson:"newDescription"` +} + +type serverDescriptionChangedEvent struct { + Address string `bson:"address"` + TopologyID string `bson:"topologyId"` + PreviousDescription serverDescription `bson:"previousDescription"` + NewDescription serverDescription `bson:"newDescription"` +} + +type serverClosedEvent struct { + Address string `bson:"address"` + TopologyID string `bson:"topologyId"` +} + +type monitoringEvent struct { + TopologyOpeningEvent *topologyOpeningEvent `bson:"topology_opening_event,omitempty"` + ServerOpeningEvent *serverOpeningEvent `bson:"server_opening_event,omitempty"` + TopologyDescriptionChangedEvent *topologyDescriptionChangedEvent `bson:"topology_description_changed_event,omitempty"` + ServerDescriptionChangedEvent *serverDescriptionChangedEvent `bson:"server_description_changed_event,omitempty"` + ServerClosedEvent *serverClosedEvent `bson:"server_closed_event,omitempty"` +} + type outcome struct { Servers map[string]server TopologyType string @@ -103,6 +157,7 @@ type outcome struct { MaxSetVersion uint32 MaxElectionID primitive.ObjectID `bson:"maxElectionId"` Compatible *bool + Events []monitoringEvent } type phase struct { @@ -118,8 +173,41 @@ type testCase struct { Phases []phase } +func serverDescriptionChanged(e *event.ServerDescriptionChangedEvent) { + lock.Lock() + publishedEvents = append(publishedEvents, *e) + lock.Unlock() +} + +func serverOpening(e *event.ServerOpeningEvent) { + lock.Lock() + publishedEvents = append(publishedEvents, *e) + lock.Unlock() +} + +func topologyDescriptionChanged(e *event.TopologyDescriptionChangedEvent) { + lock.Lock() + publishedEvents = append(publishedEvents, *e) + lock.Unlock() +} + +func topologyOpening(e *event.TopologyOpeningEvent) { + lock.Lock() + publishedEvents = append(publishedEvents, *e) + lock.Unlock() +} + +func serverClosed(e *event.ServerClosedEvent) { + lock.Lock() + publishedEvents = append(publishedEvents, *e) + lock.Unlock() +} + const testsDir string = "../../../../data/server-discovery-and-monitoring/" +var publishedEvents []interface{} +var lock sync.Mutex + func (r *response) UnmarshalBSON(buf []byte) error { doc := bson.Raw(buf) if err := doc.Index(0).Value().Unmarshal(&r.Host); err != nil { @@ -137,12 +225,21 @@ func setUpTopology(t *testing.T, uri string) *Topology { cs, err := connstring.ParseAndValidate(uri) assert.Nil(t, err, "Parse error: %v", err) + sdam := &event.ServerMonitor{ + ServerDescriptionChanged: serverDescriptionChanged, + ServerOpening: serverOpening, + TopologyDescriptionChanged: topologyDescriptionChanged, + TopologyOpening: topologyOpening, + ServerClosed: serverClosed, + } + // Disable server monitoring because the hosts in the SDAM spec tests don't actually exist, so the server monitor // can race with the test and mark the server Unknown when it fails to connect, which causes tests to fail. serverOpts := []ServerOption{ withMonitoringDisabled(func(bool) bool { return true }), + WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return sdam }), } topo, err := New( WithConnString(func(connstring.ConnString) connstring.ConnString { @@ -151,27 +248,14 @@ func setUpTopology(t *testing.T, uri string) *Topology { WithServerOptions(func(opts ...ServerOption) []ServerOption { return append(opts, serverOpts...) }), + WithTopologyServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { + return sdam + }), ) assert.Nil(t, err, "topology.New error: %v", err) - // add servers to topology without starting heartbeat goroutines - topo.serversLock.Lock() - for _, a := range topo.cfg.seedList { - addr := address.Address(a).Canonicalize() - topo.fsm.Servers = append(topo.fsm.Servers, description.Server{Addr: addr}) - - svr, err := NewServer(addr, primitive.NewObjectID(), topo.cfg.serverOpts...) - assert.Nil(t, err, "NewServer error: %v", err) - atomic.StoreInt32(&svr.connectionstate, connected) - svr.desc.Store(description.NewDefaultServer(svr.address)) - svr.updateTopologyCallback.Store(topo.updateCallback) - - topo.servers[addr] = svr - } - topo.desc.Store(description.Topology{Servers: topo.fsm.Servers}) - topo.serversLock.Unlock() - - atomic.StoreInt32(&topo.connectionstate, connected) + err = topo.Connect() + assert.Nil(t, err, "topology.Connect error: %v", err) return topo } @@ -404,6 +488,115 @@ func applyErrors(t *testing.T, topo *Topology, errors []applicationError) { } } +func compareServerDescriptions(t *testing.T, + expected serverDescription, actual description.Server, idx int) { + t.Helper() + + assert.Equal(t, expected.Address, actual.Addr.String(), + "%v: expected server address %s, got %s", idx, expected.Address, actual.Addr) + + assert.Equal(t, len(expected.Hosts), len(actual.Hosts), + "%v: expected %d hosts, got %d", idx, len(expected.Hosts), len(actual.Hosts)) + for idx, expectedHost := range expected.Hosts { + actualHost := actual.Hosts[idx] + assert.Equal(t, expectedHost, string(actualHost), "%v: expected host %s, got %s", idx, expectedHost, actualHost) + } + + assert.Equal(t, len(expected.Passives), len(actual.Passives), + "%v: expected %d hosts, got %d", idx, len(expected.Passives), len(actual.Passives)) + for idx, expectedPassive := range expected.Passives { + actualPassive := actual.Passives[idx] + assert.Equal(t, expectedPassive, string(actualPassive), "%v: expected passive %s, got %s", idx, expectedPassive, actualPassive) + } + + assert.Equal(t, expected.Primary, string(actual.Primary), + "%v: expected primary %s, got %s", idx, expected.Primary, actual.Primary) + assert.Equal(t, expected.SetName, actual.SetName, + "%v: expected set name %s, got %s", idx, expected.SetName, actual.SetName) + + // PossiblePrimary is only relevant to single-threaded drivers. + if expected.Type == "PossiblePrimary" { + expected.Type = "Unknown" + } + assert.Equal(t, expected.Type, actual.Kind.String(), + "%v: expected server kind %s, got %s", idx, expected.Type, actual.Kind.String()) +} + +func compareTopologyDescriptions(t *testing.T, + expected topologyDescription, actual description.Topology, idx int) { + t.Helper() + + assert.Equal(t, expected.TopologyType, actual.Kind.String(), + "%v: expected topology kind %s, got %s", idx, expected.TopologyType, actual.Kind.String()) + assert.Equal(t, len(expected.Servers), len(actual.Servers), + "%v: expected %d servers, got %d", idx, len(expected.Servers), len(actual.Servers)) + + for idx, es := range expected.Servers { + as := actual.Servers[idx] + compareServerDescriptions(t, es, as, idx) + } + + assert.Equal(t, expected.SetName, actual.SetName, + "%v: expected set name %s, got %s", idx, expected.SetName, actual.SetName) +} + +func compareEvents(t *testing.T, events []monitoringEvent) { + t.Helper() + + lock.Lock() + defer lock.Unlock() + + assert.Equal(t, len(events), len(publishedEvents), + "expected %d published events, got %d\n", + len(events), len(publishedEvents)) + + for idx, me := range events { + if me.TopologyOpeningEvent != nil { + actual, ok := publishedEvents[idx].(event.TopologyOpeningEvent) + assert.True(t, ok, "%v: expected type %T, got %T", idx, event.TopologyOpeningEvent{}, publishedEvents[idx]) + assert.False(t, primitive.ObjectID(actual.TopologyID).IsZero(), "%v: expected topology id", idx) + } + if me.ServerOpeningEvent != nil { + actual, ok := publishedEvents[idx].(event.ServerOpeningEvent) + assert.True(t, ok, "%v: expected type %T, got %T", idx, event.ServerOpeningEvent{}, publishedEvents[idx]) + + evt := me.ServerOpeningEvent + assert.Equal(t, evt.Address, string(actual.Address), + "%v: expected address %s, got %s", idx, evt.Address, actual.Address) + assert.False(t, primitive.ObjectID(actual.TopologyID).IsZero(), "%v: expected topology id", idx) + } + if me.TopologyDescriptionChangedEvent != nil { + actual, ok := publishedEvents[idx].(event.TopologyDescriptionChangedEvent) + assert.True(t, ok, "%v: expected type %T, got %T", idx, event.TopologyDescriptionChangedEvent{}, publishedEvents[idx]) + + evt := me.TopologyDescriptionChangedEvent + compareTopologyDescriptions(t, evt.PreviousDescription, actual.PreviousDescription, idx) + compareTopologyDescriptions(t, evt.NewDescription, actual.NewDescription, idx) + assert.False(t, primitive.ObjectID(actual.TopologyID).IsZero(), "%v: expected topology id", idx) + } + if me.ServerDescriptionChangedEvent != nil { + actual, ok := publishedEvents[idx].(event.ServerDescriptionChangedEvent) + assert.True(t, ok, "%v: expected type %T, got %T", idx, event.ServerDescriptionChangedEvent{}, publishedEvents[idx]) + + evt := me.ServerDescriptionChangedEvent + assert.Equal(t, evt.Address, string(actual.Address), + "%v: expected server address %s, got %s", idx, evt.Address, actual.Address) + compareServerDescriptions(t, evt.PreviousDescription, actual.PreviousDescription, idx) + compareServerDescriptions(t, evt.NewDescription, actual.NewDescription, idx) + assert.False(t, primitive.ObjectID(actual.TopologyID).IsZero(), "%v: expected topology id", idx) + } + if me.ServerClosedEvent != nil { + actual, ok := publishedEvents[idx].(event.ServerClosedEvent) + assert.True(t, ok, "%v: expected type %T, got %T", idx, event.ServerClosedEvent{}, publishedEvents[idx]) + + evt := me.ServerClosedEvent + assert.Equal(t, evt.Address, string(actual.Address), + "%v: expected server address %s, got %s", idx, evt.Address, actual.Address) + assert.False(t, primitive.ObjectID(actual.TopologyID).IsZero(), "%v: expected topology id", idx) + } + } +} + func runTest(t *testing.T, directory string, filename string) { filepath := path.Join(testsDir, directory, filename) content, err := ioutil.ReadFile(filepath) @@ -424,6 +617,13 @@ func runTest(t *testing.T, directory string, filename string) { for _, phase := range test.Phases { applyResponses(t, topo, phase.Responses, sub) applyErrors(t, topo, phase.ApplicationErrors) + + if phase.Outcome.Events != nil { + compareEvents(t, phase.Outcome.Events) + publishedEvents = nil + continue + } + publishedEvents = nil if phase.Outcome.Compatible == nil || *phase.Outcome.Compatible { assert.True(t, topo.fsm.compatible.Load().(bool), "Expected servers to be compatible") assert.Nil(t, topo.fsm.compatibilityErr, "expected fsm.compatiblity to be nil, got %v", @@ -495,7 +695,7 @@ func runTest(t *testing.T, directory string, filename string) { // Test case for all SDAM spec tests. func TestSDAMSpec(t *testing.T) { - for _, subdir := range []string{"single", "rs", "sharded", "errors"} { + for _, subdir := range []string{"single", "rs", "sharded", "errors", "monitoring"} { for _, file := range testhelpers.FindJSONFilesInDir(t, path.Join(testsDir, subdir)) { runTest(t, subdir, file) } diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index c2ef1a73f9..235ce2d3ef 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -184,6 +184,9 @@ func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...Serv if err != nil { return nil, err } + + s.publishServerOpeningEvent(s.address) + return s, nil } @@ -633,6 +636,7 @@ func (s *Server) createBaseOperation(conn driver.Connection) *operation.IsMaster func (s *Server) check() (description.Server, error) { var descPtr *description.Server var err error + var durationNanos int64 // Create a new connection if this is the first check, the connection was closed after an error during the previous // check, or the previous check was cancelled. @@ -643,6 +647,7 @@ func (s *Server) check() (description.Server, error) { // Use the description from the connection handshake as the value for this check. s.rttMonitor.addSample(s.conn.isMasterRTT) descPtr = &s.conn.desc + durationNanos = s.conn.isMasterRTT.Nanoseconds() } } @@ -653,12 +658,15 @@ func (s *Server) check() (description.Server, error) { heartbeatConn := initConnection{s.conn} baseOperation := s.createBaseOperation(heartbeatConn) previousDescription := s.Description() + streamable := previousDescription.TopologyVersion != nil + s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) + start := time.Now() switch { case s.conn.getCurrentlyStreaming(): // The connection is already in a streaming state, so we stream the next response. err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn) - case previousDescription.TopologyVersion != nil: + case streamable: // The server supports the streamable protocol. Set the socket timeout to // connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable isMaster request. Set conn.canStream so // the wire message will advertise streaming support to the server. @@ -684,15 +692,19 @@ func (s *Server) check() (description.Server, error) { s.conn.setSocketTimeout(s.cfg.heartbeatTimeout) err = baseOperation.Execute(s.heartbeatCtx) } + durationNanos = time.Since(start).Nanoseconds() + if err == nil { tempDesc := baseOperation.Result(s.address) descPtr = &tempDesc + s.publishServerHeartbeatSucceededEvent(s.conn.ID(), durationNanos, tempDesc, s.conn.getCurrentlyStreaming() || streamable) } else { // Close the connection here rather than below so we ensure we're not closing a connection that wasn't // successfully created. if s.conn != nil { _ = s.conn.close() } + s.publishServerHeartbeatFailedEvent(s.conn.ID(), durationNanos, err, s.conn.getCurrentlyStreaming() || streamable) } } @@ -781,6 +793,64 @@ func (ss *ServerSubscription) Unsubscribe() error { return nil } +// publishes a ServerOpeningEvent to indicate the server is being initialized +func (s *Server) publishServerOpeningEvent(addr address.Address) { + serverOpening := &event.ServerOpeningEvent{ + Address: addr, + TopologyID: s.topologyID, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerOpening != nil { + s.cfg.serverMonitor.ServerOpening(serverOpening) + } +} + +// publishes a ServerHeartbeatStartedEvent to indicate an ismaster command has started +func (s *Server) publishServerHeartbeatStartedEvent(connectionID string, await bool) { + serverHeartbeatStarted := &event.ServerHeartbeatStartedEvent{ + ConnectionID: connectionID, + Awaited: await, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatStarted != nil { + s.cfg.serverMonitor.ServerHeartbeatStarted(serverHeartbeatStarted) + } +} + +// publishes a ServerHeartbeatSucceededEvent to indicate ismaster has succeeded +func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string, + durationNanos int64, + desc description.Server, + await bool) { + serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{ + DurationNanos: durationNanos, + Reply: desc, + ConnectionID: connectionID, + Awaited: await, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatSucceeded != nil { + s.cfg.serverMonitor.ServerHeartbeatSucceeded(serverHeartbeatSucceeded) + } +} + +// publishes a ServerHeartbeatFailedEvent to indicate ismaster has failed +func (s *Server) publishServerHeartbeatFailedEvent(connectionID string, + durationNanos int64, + err error, + await bool) { + serverHeartbeatFailed := &event.ServerHeartbeatFailedEvent{ + DurationNanos: durationNanos, + Failure: err, + ConnectionID: connectionID, + Awaited: await, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatFailed != nil { + s.cfg.serverMonitor.ServerHeartbeatFailed(serverHeartbeatFailed) + } +} + // unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error. func unwrapConnectionError(err error) error { // This is essentially an implementation of errors.As to unwrap this error until we get a ConnectionError and then diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index 6675e7a7ae..902bfbdd76 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -27,6 +27,7 @@ type serverConfig struct { maxConns uint64 minConns uint64 poolMonitor *event.PoolMonitor + serverMonitor *event.ServerMonitor connectionPoolMaxIdleTime time.Duration registry *bsoncodec.Registry monitoringDisabled bool @@ -138,6 +139,14 @@ func WithConnectionPoolMonitor(fn func(*event.PoolMonitor) *event.PoolMonitor) S } } +// WithServerMonitor configures the monitor for all SDAM events for a server +func WithServerMonitor(fn func(*event.ServerMonitor) *event.ServerMonitor) ServerOption { + return func(cfg *serverConfig) error { + cfg.serverMonitor = fn(cfg.serverMonitor) + return nil + } +} + // WithClock configures the ClusterClock for the server to use. func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) ServerOption { return func(cfg *serverConfig) error { diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 87adc77d8c..a20d95ea90 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -4,10 +4,13 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +build go1.13 + package topology import ( "context" + "errors" "net" "runtime" "sync" @@ -18,6 +21,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/testutil/assert" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" @@ -41,6 +45,7 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n cnc := &drivertest.ChannelNetConn{ Written: make(chan []byte, 1), ReadResp: make(chan []byte, 2), + ReadErr: make(chan error, 1), } if err := cnc.AddResponse(makeIsMasterReply()); err != nil { return nil, err @@ -49,29 +54,6 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n return cnc, nil } -type testHandshaker struct { - getDescription func(context.Context, address.Address, driver.Connection) (description.Server, error) - finishHandshake func(context.Context, driver.Connection) error -} - -// GetDescription implements the Handshaker interface. -func (th *testHandshaker) GetDescription(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) { - if th.getDescription != nil { - return th.getDescription(ctx, addr, conn) - } - return description.Server{}, nil -} - -// FinishHandshake implements the Handshaker interface. -func (th *testHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error { - if th.finishHandshake != nil { - return th.finishHandshake(ctx, conn) - } - return nil -} - -var _ driver.Handshaker = &testHandshaker{} - func TestServer(t *testing.T) { var serverTestTable = []struct { name string @@ -309,6 +291,97 @@ func TestServer(t *testing.T) { t.Fatal("client metadata not expected in heartbeat but found") } }) + t.Run("heartbeat monitoring", func(t *testing.T) { + var publishedEvents []interface{} + + serverHeartbeatStarted := func(e *event.ServerHeartbeatStartedEvent) { + publishedEvents = append(publishedEvents, *e) + } + + serverHeartbeatSucceeded := func(e *event.ServerHeartbeatSucceededEvent) { + publishedEvents = append(publishedEvents, *e) + } + + serverHeartbeatFailed := func(e *event.ServerHeartbeatFailedEvent) { + publishedEvents = append(publishedEvents, *e) + } + + sdam := &event.ServerMonitor{ + ServerHeartbeatStarted: serverHeartbeatStarted, + ServerHeartbeatSucceeded: serverHeartbeatSucceeded, + ServerHeartbeatFailed: serverHeartbeatFailed, + } + + dialer := &channelNetConnDialer{} + dialerOpt := WithDialer(func(Dialer) Dialer { + return dialer + }) + serverOpts := []ServerOption{ + WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption { + return append(connOpts, dialerOpt) + }), + withMonitoringDisabled(func(bool) bool { return true }), + WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return sdam }), + } + + s, err := NewServer(address.Address("localhost:27017"), primitive.NewObjectID(), serverOpts...) + if err != nil { + t.Fatalf("error from NewServer: %v", err) + } + + // set up heartbeat connection, which doesn't send events + _, err = s.check() + assert.Nil(t, err, "check error: %v", err) + + channelConn := s.conn.nc.(*drivertest.ChannelNetConn) + _ = channelConn.GetWrittenMessage() + + t.Run("success", func(t *testing.T) { + publishedEvents = nil + // do a heartbeat with a non-nil connection + if err = channelConn.AddResponse(makeIsMasterReply()); err != nil { + t.Fatalf("error adding response: %v", err) + } + _, err = s.check() + _ = channelConn.GetWrittenMessage() + assert.Nil(t, err, "check error: %v", err) + + assert.Equal(t, len(publishedEvents), 2, "expected %v events, got %v", 2, len(publishedEvents)) + + started, ok := publishedEvents[0].(event.ServerHeartbeatStartedEvent) + assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatStartedEvent{}, publishedEvents[0]) + assert.Equal(t, started.ConnectionID, s.conn.ID(), "expected connectionID to match") + assert.False(t, started.Awaited, "expected awaited to be false") + + succeeded, ok := publishedEvents[1].(event.ServerHeartbeatSucceededEvent) + assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatSucceededEvent{}, publishedEvents[1]) + assert.Equal(t, succeeded.ConnectionID, s.conn.ID(), "expected connectionID to match") + assert.Equal(t, succeeded.Reply.Addr, s.address, "expected address %v, got %v", s.address, succeeded.Reply.Addr) + assert.False(t, succeeded.Awaited, "expected awaited to be false") + }) + t.Run("failure", func(t *testing.T) { + publishedEvents = nil + // do a heartbeat with a non-nil connection + readErr := errors.New("error") + channelConn.ReadErr <- readErr + _, err = s.check() + _ = channelConn.GetWrittenMessage() + assert.Nil(t, err, "check error: %v", err) + + assert.Equal(t, len(publishedEvents), 2, "expected %v events, got %v", 2, len(publishedEvents)) + + started, ok := publishedEvents[0].(event.ServerHeartbeatStartedEvent) + assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatStartedEvent{}, publishedEvents[0]) + assert.Equal(t, started.ConnectionID, s.conn.ID(), "expected connectionID to match") + assert.False(t, started.Awaited, "expected awaited to be false") + + failed, ok := publishedEvents[1].(event.ServerHeartbeatFailedEvent) + assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatFailedEvent{}, publishedEvents[1]) + assert.Equal(t, failed.ConnectionID, s.conn.ID(), "expected connectionID to match") + assert.False(t, failed.Awaited, "expected awaited to be false") + assert.True(t, errors.Is(failed.Failure, readErr), "expected Failure to be %v, got: %v", readErr, failed.Failure) + }) + }) t.Run("WithServerAppName", func(t *testing.T) { name := "test" diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index b78d65cb1f..f2f87a0983 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -22,6 +22,7 @@ import ( "fmt" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -131,22 +132,12 @@ func New(opts ...Option) (*Topology, error) { return t.apply(context.TODO(), desc) } - // A replica set name sets the initial topology type to ReplicaSetNoPrimary unless a direct connection is also - // specified, in which case the initial type is Single. - if cfg.replicaSetName != "" { - t.fsm.SetName = cfg.replicaSetName - t.fsm.Kind = description.ReplicaSetNoPrimary - } - - // A direct connection unconditionally sets the topology type to Single. - if cfg.mode == SingleMode { - t.fsm.Kind = description.Single - } - if t.cfg.uri != "" { t.pollingRequired = strings.HasPrefix(t.cfg.uri, "mongodb+srv://") } + t.publishTopologyOpeningEvent() + return t, nil } @@ -160,9 +151,34 @@ func (t *Topology) Connect() error { t.desc.Store(description.Topology{}) var err error t.serversLock.Lock() + + // A replica set name sets the initial topology type to ReplicaSetNoPrimary unless a direct connection is also + // specified, in which case the initial type is Single. + if t.cfg.replicaSetName != "" { + t.fsm.SetName = t.cfg.replicaSetName + t.fsm.Kind = description.ReplicaSetNoPrimary + } + + // A direct connection unconditionally sets the topology type to Single. + if t.cfg.mode == SingleMode { + t.fsm.Kind = description.Single + } + + for _, a := range t.cfg.seedList { + addr := address.Address(a).Canonicalize() + t.fsm.Servers = append(t.fsm.Servers, description.NewDefaultServer(addr)) + } + + // store new description + newDesc := description.Topology{ + Kind: t.fsm.Kind, + Servers: t.fsm.Servers, + SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes, + } + t.desc.Store(newDesc) + t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology) for _, a := range t.cfg.seedList { addr := address.Address(a).Canonicalize() - t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: addr}) err = t.addServer(addr) if err != nil { return err @@ -198,6 +214,7 @@ func (t *Topology) Disconnect(ctx context.Context) error { for _, server := range servers { _ = server.Disconnect(ctx) + t.publishServerClosedEvent(server.address) } t.subLock.Lock() @@ -216,6 +233,7 @@ func (t *Topology) Disconnect(ctx context.Context) error { t.desc.Store(description.Topology{}) atomic.StoreInt32(&t.connectionstate, disconnected) + t.publishTopologyClosedEvent() return nil } @@ -545,6 +563,7 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { if t.serversClosed { return false } + prev := t.fsm.Topology diff := t.fsm.Topology.DiffHostlist(parsedHosts) if len(diff.Added) == 0 && len(diff.Removed) == 0 { @@ -564,6 +583,7 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { }() delete(t.servers, addr) t.fsm.removeServerByAddr(addr) + t.publishServerClosedEvent(s.address) } for _, a := range diff.Added { addr := address.Address(a).Canonicalize() @@ -578,6 +598,10 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { } t.desc.Store(newDesc) + if !prev.Equal(newDesc) { + t.publishTopologyDescriptionChangedEvent(prev, newDesc) + } + t.subLock.Lock() for _, ch := range t.subscribers { // We drain the description if there's one in the channel @@ -617,6 +641,10 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti return desc } + if !oldDesc.Equal(desc) { + t.publishServerDescriptionChangedEvent(oldDesc, desc) + } + diff := description.DiffTopology(prev, current) for _, removed := range diff.Removed { @@ -627,6 +655,7 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti _ = s.Disconnect(cancelCtx) }() delete(t.servers, removed.Addr) + t.publishServerClosedEvent(s.address) } } @@ -635,6 +664,9 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti } t.desc.Store(current) + if !prev.Equal(current) { + t.publishTopologyDescriptionChangedEvent(prev, current) + } t.subLock.Lock() for _, ch := range t.subscribers { @@ -677,3 +709,64 @@ func (t *Topology) String() string { } return fmt.Sprintf("Type: %s, Servers: [%s]", desc.Kind, serversStr) } + +// publishes a ServerDescriptionChangedEvent to indicate the server description has changed +func (t *Topology) publishServerDescriptionChangedEvent(prev description.Server, current description.Server) { + serverDescriptionChanged := &event.ServerDescriptionChangedEvent{ + Address: current.Addr, + TopologyID: t.id, + PreviousDescription: prev, + NewDescription: current, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.ServerDescriptionChanged != nil { + t.cfg.serverMonitor.ServerDescriptionChanged(serverDescriptionChanged) + } +} + +// publishes a ServerClosedEvent to indicate the server has closed +func (t *Topology) publishServerClosedEvent(addr address.Address) { + serverClosed := &event.ServerClosedEvent{ + Address: addr, + TopologyID: t.id, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.ServerClosed != nil { + t.cfg.serverMonitor.ServerClosed(serverClosed) + } +} + +// publishes a TopologyDescriptionChangedEvent to indicate the topology description has changed +func (t *Topology) publishTopologyDescriptionChangedEvent(prev description.Topology, current description.Topology) { + topologyDescriptionChanged := &event.TopologyDescriptionChangedEvent{ + TopologyID: t.id, + PreviousDescription: prev, + NewDescription: current, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.TopologyDescriptionChanged != nil { + t.cfg.serverMonitor.TopologyDescriptionChanged(topologyDescriptionChanged) + } +} + +// publishes a TopologyOpeningEvent to indicate the topology is being initialized +func (t *Topology) publishTopologyOpeningEvent() { + topologyOpening := &event.TopologyOpeningEvent{ + TopologyID: t.id, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.TopologyOpening != nil { + t.cfg.serverMonitor.TopologyOpening(topologyOpening) + } +} + +// publishes a TopologyClosedEvent to indicate the topology has been closed +func (t *Topology) publishTopologyClosedEvent() { + topologyClosed := &event.TopologyClosedEvent{ + TopologyID: t.id, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.TopologyClosed != nil { + t.cfg.serverMonitor.TopologyClosed(topologyClosed) + } +} diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index b58f7f18bc..c40f5510cc 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -17,6 +17,7 @@ import ( "strings" "time" + "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" @@ -34,6 +35,7 @@ type config struct { cs connstring.ConnString // This must not be used for any logic in topology.Topology. uri string serverSelectionTimeout time.Duration + serverMonitor *event.ServerMonitor } func newConfig(opts ...Option) (*config, error) { @@ -274,6 +276,14 @@ func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option { } } +// WithTopologyServerMonitor configures the monitor for all SDAM events +func WithTopologyServerMonitor(fn func(*event.ServerMonitor) *event.ServerMonitor) Option { + return func(cfg *config) error { + cfg.serverMonitor = fn(cfg.serverMonitor) + return nil + } +} + // WithURI specifies the URI that was used to create the topology. func WithURI(fn func(string) string) Option { return func(cfg *config) error {