diff --git a/mongo/integration/cmd_monitoring_helpers_test.go b/mongo/integration/cmd_monitoring_helpers_test.go index ef821d9e9a8..1869c02b535 100644 --- a/mongo/integration/cmd_monitoring_helpers_test.go +++ b/mongo/integration/cmd_monitoring_helpers_test.go @@ -270,21 +270,30 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson. return } - for idx, expectation := range *expectations { - var err error + startedEvents := make([]*cmdStartedEvt, 0, len(*expectations)) + succeededEvents := make([]*cmdSucceededEvt, 0, len(*expectations)) + failedEvents := make([]*cmdFailedEvt, 0, len(*expectations)) + for _, expectation := range *expectations { if expectation.CommandStartedEvent != nil { - err = compareStartedEvent(mt, expectation, id0, id1) + startedEvents = append(startedEvents, expectation.CommandStartedEvent) } if expectation.CommandSucceededEvent != nil { - err = compareSucceededEvent(mt, expectation) + succeededEvents = append(succeededEvents, expectation.CommandSucceededEvent) } if expectation.CommandFailedEvent != nil { - err = compareFailedEvent(mt, expectation) + failedEvents = append(failedEvents, expectation.CommandFailedEvent) } - - assert.Nil(mt, err, "expectation comparison error at index %v: %s", idx, err) } + + var err error + err = compareStartedEvents(mt, startedEvents, id0, id1) + assert.Nil(mt, err, "expectation comparison %s", err) + err = compareSucceededEvents(mt, succeededEvents) + assert.Nil(mt, err, "expectation comparison %s", err) + err = compareFailedEvents(mt, failedEvents) + assert.Nil(mt, err, "expectation comparison %s", err) + } // newMatchError appends `expected` and `actual` BSON data to an error. @@ -298,83 +307,104 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin return fmt.Errorf("%s\nExpected %s\nGot: %s", msg, string(expectedJSON), string(actualJSON)) } -func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bson.Raw) error { +func compareStartedEvents(mt *mtest.T, expectations []*cmdStartedEvt, id0, id1 bson.Raw) error { mt.Helper() - expected := expectation.CommandStartedEvent - - if len(expected.Extra) > 0 { - return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra) - } - - evt := mt.GetStartedEvent() - if evt == nil { - return errors.New("expected CommandStartedEvent, got nil") - } - - if expected.CommandName != "" && expected.CommandName != evt.CommandName { - return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName) - } - if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName { - return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName) - } - - eElems, err := expected.Command.Elements() - if err != nil { - return fmt.Errorf("error getting expected command elements: %s", err) + expectedCmds := make(map[string]bool) + for _, expected := range expectations { + expectedCmds[expected.CommandName] = true } - for _, elem := range eElems { - key := elem.Key() - val := elem.Value() - - actualVal, err := evt.Command.LookupErr(key) + compare := func(expected *cmdStartedEvt) error { + if len(expected.Extra) > 0 { + return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra) + } - // Keys that may be nil - if val.Type == bson.TypeNull { - // Expected value is BSON null. Expect the actual field to be omitted. - if errors.Is(err, bsoncore.ErrElementNotFound) { - continue + var evt *event.CommandStartedEvent + // skip events not in expectations + for { + evt = mt.GetStartedEvent() + if evt == nil { + return fmt.Errorf("expected CommandStartedEvent %s, got nil", expected.CommandName) } - if err != nil { - return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err) + if expected.CommandName == "" { + break + } else if v, ok := expectedCmds[evt.CommandName]; ok && v { + break } - return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal) } - assert.Nil(mt, err, "expected command to contain key %q", key) - if key == "batchSize" { - // Some command monitoring tests expect that the driver will send a lower batch size if the required batch - // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server - // versions do not support the limit option, but not for 3.2+. We've already validated that the command - // contains a batchSize field above and we can skip the actual value comparison below. - continue + if expected.CommandName != "" && expected.CommandName != evt.CommandName { + return fmt.Errorf("command name mismatch for started event; expected %s, got %s", expected.CommandName, evt.CommandName) + } + if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName { + return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName) } - switch key { - case "lsid": - sessName := val.StringValue() - var expectedID bson.Raw - actualID := actualVal.Document() + eElems, err := expected.Command.Elements() + if err != nil { + return fmt.Errorf("error getting expected command elements: %s", err) + } - switch sessName { - case "session0": - expectedID = id0 - case "session1": - expectedID = id1 - default: - return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName) + for _, elem := range eElems { + key := elem.Key() + val := elem.Value() + + actualVal, err := evt.Command.LookupErr(key) + + // Keys that may be nil + if val.Type == bson.TypeNull { + // Expected value is BSON null. Expect the actual field to be omitted. + if errors.Is(err, bsoncore.ErrElementNotFound) { + continue + } + if err != nil { + return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err) + } + return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal) } + assert.Nil(mt, err, "expected command to contain key %q", key) - if !bytes.Equal(expectedID, actualID) { - return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID, - actualID) + if key == "batchSize" { + // Some command monitoring tests expect that the driver will send a lower batch size if the required batch + // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server + // versions do not support the limit option, but not for 3.2+. We've already validated that the command + // contains a batchSize field above and we can skip the actual value comparison below. + continue } - default: - if err := compareValues(mt, key, val, actualVal); err != nil { - return newMatchError(mt, expected.Command, evt.Command, "%s", err) + + switch key { + case "lsid": + sessName := val.StringValue() + var expectedID bson.Raw + actualID := actualVal.Document() + + switch sessName { + case "session0": + expectedID = id0 + case "session1": + expectedID = id1 + default: + return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName) + } + + if !bytes.Equal(expectedID, actualID) { + return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID, + actualID) + } + default: + if err := compareValues(mt, key, val, actualVal); err != nil { + return newMatchError(mt, expected.Command, evt.Command, "%s", err) + } } } + return nil + } + for idx, expected := range expectations { + err := compare(expected) + if err != nil { + return fmt.Errorf("error at index %d: %s", idx, err) + } } return nil } @@ -416,60 +446,106 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error { return nil } -func compareSucceededEvent(mt *mtest.T, expectation *expectation) error { +func compareSucceededEvents(mt *mtest.T, expectations []*cmdSucceededEvt) error { mt.Helper() - expected := expectation.CommandSucceededEvent - if len(expected.Extra) > 0 { - return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra) - } - evt := mt.GetSucceededEvent() - if evt == nil { - return errors.New("expected CommandSucceededEvent, got nil") + expectedCmds := make(map[string]bool) + for _, expected := range expectations { + expectedCmds[expected.CommandName] = true } - if expected.CommandName != "" && expected.CommandName != evt.CommandName { - return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName) - } + compare := func(expected *cmdSucceededEvt) error { + if len(expected.Extra) > 0 { + return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra) + } - eElems, err := expected.Reply.Elements() - if err != nil { - return fmt.Errorf("error getting expected reply elements: %s", err) - } + var evt *event.CommandSucceededEvent + // skip events not in expectations + for { + evt = mt.GetSucceededEvent() + if evt == nil { + return fmt.Errorf("expected CommandSucceededEvent %s, got nil", expected.CommandName) + } + if expected.CommandName == "" { + break + } else if v, ok := expectedCmds[evt.CommandName]; ok && v { + break + } + } - for _, elem := range eElems { - key := elem.Key() - val := elem.Value() - actualVal := evt.Reply.Lookup(key) + if expected.CommandName != "" && expected.CommandName != evt.CommandName { + return fmt.Errorf("command name mismatch for succeeded event; expected %s, got %s", expected.CommandName, evt.CommandName) + } - switch key { - case "writeErrors": - if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil { - return newMatchError(mt, expected.Reply, evt.Reply, "%s", err) - } - default: - if err := compareValues(mt, key, val, actualVal); err != nil { - return newMatchError(mt, expected.Reply, evt.Reply, "%s", err) + eElems, err := expected.Reply.Elements() + if err != nil { + return fmt.Errorf("error getting expected reply elements: %s", err) + } + + for _, elem := range eElems { + key := elem.Key() + val := elem.Value() + actualVal := evt.Reply.Lookup(key) + + switch key { + case "writeErrors": + if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil { + return newMatchError(mt, expected.Reply, evt.Reply, "%s", err) + } + default: + if err := compareValues(mt, key, val, actualVal); err != nil { + return newMatchError(mt, expected.Reply, evt.Reply, "%s", err) + } } } + return nil + } + for idx, expected := range expectations { + err := compare(expected) + if err != nil { + return fmt.Errorf("error at index %d: %s", idx, err) + } } return nil } -func compareFailedEvent(mt *mtest.T, expectation *expectation) error { +func compareFailedEvents(mt *mtest.T, expectations []*cmdFailedEvt) error { mt.Helper() - expected := expectation.CommandFailedEvent - if len(expected.Extra) > 0 { - return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra) - } - evt := mt.GetFailedEvent() - if evt == nil { - return errors.New("expected CommandFailedEvent, got nil") + expectedCmds := make(map[string]bool) + for _, expected := range expectations { + expectedCmds[expected.CommandName] = true } - if expected.CommandName != "" && expected.CommandName != evt.CommandName { - return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName) + compare := func(expected *cmdFailedEvt) error { + if len(expected.Extra) > 0 { + return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra) + } + + var evt *event.CommandFailedEvent + // skip events not in expectations + for { + evt = mt.GetFailedEvent() + if evt == nil { + return fmt.Errorf("expected CommandFailedEvent %s, got nil", expected.CommandName) + } + if expected.CommandName == "" { + break + } else if v, ok := expectedCmds[evt.CommandName]; ok && v { + break + } + } + + if expected.CommandName != "" && expected.CommandName != evt.CommandName { + return fmt.Errorf("command name mismatch for failed event; expected %s, got %s", expected.CommandName, evt.CommandName) + } + return nil + } + for idx, expected := range expectations { + err := compare(expected) + if err != nil { + return fmt.Errorf("error at index %d: %s", idx, err) + } } return nil } diff --git a/mongo/integration/mtest/mongotest.go b/mongo/integration/mtest/mongotest.go index 7fd40890d39..70c0f0333db 100644 --- a/mongo/integration/mtest/mongotest.go +++ b/mongo/integration/mtest/mongotest.go @@ -616,9 +616,11 @@ func (t *T) ClearFailPoints() { if err != nil { t.Fatalf("error clearing fail point %s: %v", fp.name, err) } - if fp.client != t.Client { - _ = fp.client.Disconnect(context.Background()) - t.fpClients[fp.client] = false + t.fpClients[fp.client] = false + } + for client, active := range t.fpClients { + if !active && client != t.Client { + _ = client.Disconnect(context.Background()) } } t.failPoints = t.failPoints[:0] diff --git a/mongo/integration/unified_spec_test.go b/mongo/integration/unified_spec_test.go index df2d705132b..857abfbee0e 100644 --- a/mongo/integration/unified_spec_test.go +++ b/mongo/integration/unified_spec_test.go @@ -141,21 +141,27 @@ type operation struct { } type expectation struct { - CommandStartedEvent *struct { - CommandName string `bson:"command_name"` - DatabaseName string `bson:"database_name"` - Command bson.Raw `bson:"command"` - Extra map[string]interface{} `bson:",inline"` - } `bson:"command_started_event"` - CommandSucceededEvent *struct { - CommandName string `bson:"command_name"` - Reply bson.Raw `bson:"reply"` - Extra map[string]interface{} `bson:",inline"` - } `bson:"command_succeeded_event"` - CommandFailedEvent *struct { - CommandName string `bson:"command_name"` - Extra map[string]interface{} `bson:",inline"` - } `bson:"command_failed_event"` + CommandStartedEvent *cmdStartedEvt `bson:"command_started_event"` + CommandSucceededEvent *cmdSucceededEvt `bson:"command_succeeded_event"` + CommandFailedEvent *cmdFailedEvt `bson:"command_failed_event"` +} + +type cmdStartedEvt struct { + CommandName string `bson:"command_name"` + DatabaseName string `bson:"database_name"` + Command bson.Raw `bson:"command"` + Extra map[string]interface{} `bson:",inline"` +} + +type cmdSucceededEvt struct { + CommandName string `bson:"command_name"` + Reply bson.Raw `bson:"reply"` + Extra map[string]interface{} `bson:",inline"` +} + +type cmdFailedEvt struct { + CommandName string `bson:"command_name"` + Extra map[string]interface{} `bson:",inline"` } type outcome struct {