Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Aug 9, 2024
1 parent 928555a commit 4dcc585
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 122 deletions.
284 changes: 180 additions & 104 deletions mongo/integration/cmd_monitoring_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,21 +270,30 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.
return
}

for idx, expectation := range *expectations {
var err error
startedEvents := make([]*cmdStartedEvt, 0, len(*expectations))
succeededEvents := make([]*cmdSucceededEvt, 0, len(*expectations))
failedEvents := make([]*cmdFailedEvt, 0, len(*expectations))

for _, expectation := range *expectations {
if expectation.CommandStartedEvent != nil {
err = compareStartedEvent(mt, expectation, id0, id1)
startedEvents = append(startedEvents, expectation.CommandStartedEvent)
}
if expectation.CommandSucceededEvent != nil {
err = compareSucceededEvent(mt, expectation)
succeededEvents = append(succeededEvents, expectation.CommandSucceededEvent)
}
if expectation.CommandFailedEvent != nil {
err = compareFailedEvent(mt, expectation)
failedEvents = append(failedEvents, expectation.CommandFailedEvent)
}

assert.Nil(mt, err, "expectation comparison error at index %v: %s", idx, err)
}

var err error
err = compareStartedEvents(mt, startedEvents, id0, id1)
assert.Nil(mt, err, "expectation comparison %s", err)
err = compareSucceededEvents(mt, succeededEvents)
assert.Nil(mt, err, "expectation comparison %s", err)
err = compareFailedEvents(mt, failedEvents)
assert.Nil(mt, err, "expectation comparison %s", err)

}

// newMatchError appends `expected` and `actual` BSON data to an error.
Expand All @@ -298,83 +307,104 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin
return fmt.Errorf("%s\nExpected %s\nGot: %s", msg, string(expectedJSON), string(actualJSON))
}

func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bson.Raw) error {
func compareStartedEvents(mt *mtest.T, expectations []*cmdStartedEvt, id0, id1 bson.Raw) error {
mt.Helper()

expected := expectation.CommandStartedEvent

if len(expected.Extra) > 0 {
return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra)
}

evt := mt.GetStartedEvent()
if evt == nil {
return errors.New("expected CommandStartedEvent, got nil")
}

if expected.CommandName != "" && expected.CommandName != evt.CommandName {
return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
}
if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName {
return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName)
}

eElems, err := expected.Command.Elements()
if err != nil {
return fmt.Errorf("error getting expected command elements: %s", err)
expectedCmds := make(map[string]bool)
for _, expected := range expectations {
expectedCmds[expected.CommandName] = true
}

for _, elem := range eElems {
key := elem.Key()
val := elem.Value()

actualVal, err := evt.Command.LookupErr(key)
compare := func(expected *cmdStartedEvt) error {
if len(expected.Extra) > 0 {
return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra)
}

// Keys that may be nil
if val.Type == bson.TypeNull {
// Expected value is BSON null. Expect the actual field to be omitted.
if errors.Is(err, bsoncore.ErrElementNotFound) {
continue
var evt *event.CommandStartedEvent
// skip events not in expectations
for {
evt = mt.GetStartedEvent()
if evt == nil {
return fmt.Errorf("expected CommandStartedEvent %s, got nil", expected.CommandName)
}
if err != nil {
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err)
if expected.CommandName == "" {
break
} else if v, ok := expectedCmds[evt.CommandName]; ok && v {
break
}
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal)
}
assert.Nil(mt, err, "expected command to contain key %q", key)

if key == "batchSize" {
// Some command monitoring tests expect that the driver will send a lower batch size if the required batch
// size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
// versions do not support the limit option, but not for 3.2+. We've already validated that the command
// contains a batchSize field above and we can skip the actual value comparison below.
continue
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
return fmt.Errorf("command name mismatch for started event; expected %s, got %s", expected.CommandName, evt.CommandName)
}
if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName {
return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName)
}

switch key {
case "lsid":
sessName := val.StringValue()
var expectedID bson.Raw
actualID := actualVal.Document()
eElems, err := expected.Command.Elements()
if err != nil {
return fmt.Errorf("error getting expected command elements: %s", err)
}

switch sessName {
case "session0":
expectedID = id0
case "session1":
expectedID = id1
default:
return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName)
for _, elem := range eElems {
key := elem.Key()
val := elem.Value()

actualVal, err := evt.Command.LookupErr(key)

// Keys that may be nil
if val.Type == bson.TypeNull {
// Expected value is BSON null. Expect the actual field to be omitted.
if errors.Is(err, bsoncore.ErrElementNotFound) {
continue
}
if err != nil {
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err)
}
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal)
}
assert.Nil(mt, err, "expected command to contain key %q", key)

if !bytes.Equal(expectedID, actualID) {
return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
actualID)
if key == "batchSize" {
// Some command monitoring tests expect that the driver will send a lower batch size if the required batch
// size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
// versions do not support the limit option, but not for 3.2+. We've already validated that the command
// contains a batchSize field above and we can skip the actual value comparison below.
continue
}
default:
if err := compareValues(mt, key, val, actualVal); err != nil {
return newMatchError(mt, expected.Command, evt.Command, "%s", err)

switch key {
case "lsid":
sessName := val.StringValue()
var expectedID bson.Raw
actualID := actualVal.Document()

switch sessName {
case "session0":
expectedID = id0
case "session1":
expectedID = id1
default:
return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName)
}

if !bytes.Equal(expectedID, actualID) {
return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
actualID)
}
default:
if err := compareValues(mt, key, val, actualVal); err != nil {
return newMatchError(mt, expected.Command, evt.Command, "%s", err)
}
}
}
return nil
}
for idx, expected := range expectations {
err := compare(expected)
if err != nil {
return fmt.Errorf("error at index %d: %s", idx, err)
}
}
return nil
}
Expand Down Expand Up @@ -416,60 +446,106 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
return nil
}

func compareSucceededEvent(mt *mtest.T, expectation *expectation) error {
func compareSucceededEvents(mt *mtest.T, expectations []*cmdSucceededEvt) error {
mt.Helper()

expected := expectation.CommandSucceededEvent
if len(expected.Extra) > 0 {
return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra)
}
evt := mt.GetSucceededEvent()
if evt == nil {
return errors.New("expected CommandSucceededEvent, got nil")
expectedCmds := make(map[string]bool)
for _, expected := range expectations {
expectedCmds[expected.CommandName] = true
}

if expected.CommandName != "" && expected.CommandName != evt.CommandName {
return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
}
compare := func(expected *cmdSucceededEvt) error {
if len(expected.Extra) > 0 {
return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra)
}

eElems, err := expected.Reply.Elements()
if err != nil {
return fmt.Errorf("error getting expected reply elements: %s", err)
}
var evt *event.CommandSucceededEvent
// skip events not in expectations
for {
evt = mt.GetSucceededEvent()
if evt == nil {
return fmt.Errorf("expected CommandSucceededEvent %s, got nil", expected.CommandName)
}
if expected.CommandName == "" {
break
} else if v, ok := expectedCmds[evt.CommandName]; ok && v {
break
}
}

for _, elem := range eElems {
key := elem.Key()
val := elem.Value()
actualVal := evt.Reply.Lookup(key)
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
return fmt.Errorf("command name mismatch for succeeded event; expected %s, got %s", expected.CommandName, evt.CommandName)
}

switch key {
case "writeErrors":
if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil {
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
}
default:
if err := compareValues(mt, key, val, actualVal); err != nil {
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
eElems, err := expected.Reply.Elements()
if err != nil {
return fmt.Errorf("error getting expected reply elements: %s", err)
}

for _, elem := range eElems {
key := elem.Key()
val := elem.Value()
actualVal := evt.Reply.Lookup(key)

switch key {
case "writeErrors":
if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil {
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
}
default:
if err := compareValues(mt, key, val, actualVal); err != nil {
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
}
}
}
return nil
}
for idx, expected := range expectations {
err := compare(expected)
if err != nil {
return fmt.Errorf("error at index %d: %s", idx, err)
}
}
return nil
}

func compareFailedEvent(mt *mtest.T, expectation *expectation) error {
func compareFailedEvents(mt *mtest.T, expectations []*cmdFailedEvt) error {
mt.Helper()

expected := expectation.CommandFailedEvent
if len(expected.Extra) > 0 {
return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra)
}
evt := mt.GetFailedEvent()
if evt == nil {
return errors.New("expected CommandFailedEvent, got nil")
expectedCmds := make(map[string]bool)
for _, expected := range expectations {
expectedCmds[expected.CommandName] = true
}

if expected.CommandName != "" && expected.CommandName != evt.CommandName {
return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
compare := func(expected *cmdFailedEvt) error {
if len(expected.Extra) > 0 {
return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra)
}

var evt *event.CommandFailedEvent
// skip events not in expectations
for {
evt = mt.GetFailedEvent()
if evt == nil {
return fmt.Errorf("expected CommandFailedEvent %s, got nil", expected.CommandName)
}
if expected.CommandName == "" {
break
} else if v, ok := expectedCmds[evt.CommandName]; ok && v {
break
}
}

if expected.CommandName != "" && expected.CommandName != evt.CommandName {
return fmt.Errorf("command name mismatch for failed event; expected %s, got %s", expected.CommandName, evt.CommandName)
}
return nil
}
for idx, expected := range expectations {
err := compare(expected)
if err != nil {
return fmt.Errorf("error at index %d: %s", idx, err)
}
}
return nil
}
8 changes: 5 additions & 3 deletions mongo/integration/mtest/mongotest.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,11 @@ func (t *T) ClearFailPoints() {
if err != nil {
t.Fatalf("error clearing fail point %s: %v", fp.name, err)
}
if fp.client != t.Client {
_ = fp.client.Disconnect(context.Background())
t.fpClients[fp.client] = false
t.fpClients[fp.client] = false
}
for client, active := range t.fpClients {
if !active && client != t.Client {
_ = client.Disconnect(context.Background())
}
}
t.failPoints = t.failPoints[:0]
Expand Down
Loading

0 comments on commit 4dcc585

Please sign in to comment.