diff --git a/client_test.go b/client_test.go index 0c44e588..bd0f5487 100644 --- a/client_test.go +++ b/client_test.go @@ -4,21 +4,19 @@ import ( "context" "testing" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) func TestClient_Send_DoesNotPanicWhenDisconnected(t *testing.T) { c, err := NewClient("opc.tcp://example.com:4840") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + err = c.Send(context.Background(), &ua.ReadRequest{}, func(i ua.Response) error { return nil }) - verify.Values(t, "", err, ua.StatusBadServerNotConnected) + require.Equal(t, ua.StatusBadServerNotConnected, err) } func TestCloneReadRequest(t *testing.T) { @@ -100,7 +98,7 @@ func TestCloneReadRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := cloneReadRequest(tt.req) - verify.Values(t, "", got, tt.want) + require.Equal(t, tt.want, got) }) } } @@ -181,7 +179,7 @@ func TestCloneBrowseRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := cloneBrowseRequest(tt.req) - verify.Values(t, "", got, tt.want) + require.Equal(t, tt.want, got) }) } } diff --git a/config_test.go b/config_test.go index a427efb2..d9064790 100644 --- a/config_test.go +++ b/config_test.go @@ -12,12 +12,11 @@ import ( "testing" "time" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua/ua" "github.com/gopcua/opcua/uacp" "github.com/gopcua/opcua/uapolicy" "github.com/gopcua/opcua/uasc" + "github.com/stretchr/testify/require" ) // test certificate generated with @@ -129,11 +128,7 @@ func TestOptions(t *testing.T) { randomRequestID = func() uint32 { return 125 } defer func() { randomRequestID = nil }() - d, err := os.MkdirTemp("", "gopcua") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(d) + d := t.TempDir() var ( certDERFile = filepath.Join(d, "cert.der") @@ -152,18 +147,17 @@ func TestOptions(t *testing.T) { } } - if err := os.WriteFile(certDERFile, certDER, 0644); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(certPEMFile, certPEM, 0644); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(keyDERFile, keyDER, 0644); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(keyPEMFile, keyPEM, 0644); err != nil { - t.Fatal(err) - } + err := os.WriteFile(certDERFile, certDER, 0644) + require.NoError(t, err, "WriteFile(certDERFile) failed") + + err = os.WriteFile(certPEMFile, certPEM, 0644) + require.NoError(t, err, "WriteFile(certPEMFile) failed") + + err = os.WriteFile(keyDERFile, keyDER, 0644) + require.NoError(t, err, "WriteFile(keyDERFile) failed") + + err = os.WriteFile(keyPEMFile, keyPEM, 0644) + require.NoError(t, err, "WriteFile(keyPEMFile) failed") defer os.Remove(keyPEMFile) tests := []struct { @@ -814,15 +808,10 @@ func TestOptions(t *testing.T) { cfg, err := ApplyConfig(tt.opt) if got, want := errstr(err), errstr(tt.err); got != "" || want != "" { - if got != want { - t.Fatalf("got error %q want %q", got, want) - } + require.Equal(t, want, got, "got error %q want %q", got, want) return } - if !verify.Values(t, "", cfg, tt.cfg) { - t.Logf("got %#v", cfg) - t.Logf("want %#v", tt.cfg) - } + require.Equal(t, tt.cfg, cfg) }) } } diff --git a/errors/errors_test.go b/errors/errors_test.go index 926798b2..d5a4f439 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -1,20 +1,23 @@ package errors -import "testing" +import ( + "errors" + "testing" -func TestErrors(t *testing.T) { - err := Errorf("hello %s", "world") - if err.Error() != "opcua: hello world" { - t.Fatalf("got %s, wanted %s", err.Error(), "opcua: hello world") - } - - err = New("hello") - if err.Error() != "opcua: hello" { - t.Fatalf("got %s, wanted %s", err.Error(), "opcua: hello") - } + "github.com/stretchr/testify/require" +) - err = New("hello %s") - if err.Error() != "opcua: hello %s" { - t.Fatalf("got %s, wanted %s", err.Error(), "opcua: %s") - } +func TestErrors(t *testing.T) { + t.Run("expand", func(t *testing.T) { + err := Errorf("hello %s", "world") + require.Error(t, err, errors.New("opcua: hello world")) + }) + t.Run("simple", func(t *testing.T) { + err := New("hello") + require.Error(t, err, errors.New("opcua: hello")) + }) + t.Run("parameter", func(t *testing.T) { + err := New("hello %s") + require.Error(t, err, errors.New("opcua: hello %s")) + }) } diff --git a/go.mod b/go.mod index 24e3ec82..b6f9f8a3 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,16 @@ module github.com/gopcua/opcua go 1.22.0 require ( - github.com/google/uuid v1.3.0 - github.com/pascaldekloe/goe v0.1.1 + github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.10.0 golang.org/x/exp v0.0.0-20241204233417-43b7b7cde48d - golang.org/x/term v0.8.0 + golang.org/x/term v0.27.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sys v0.8.0 // indirect + golang.org/x/sys v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f1d7d4bb..3acfdf5b 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/pascaldekloe/goe v0.1.1 h1:Ah6WQ56rZONR3RW3qWa2NCZ6JAVvSpUcoLBaOmYFt9Q= -github.com/pascaldekloe/goe v0.1.1/go.mod h1:KSyfaxQOh0HZPjDP1FL/kFtbqYqrALJTaMafFUIccqU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/exp v0.0.0-20241204233417-43b7b7cde48d h1:0olWaB5pg3+oychR51GUVCEsGkeCU/2JxjBgIo4f3M0= golang.org/x/exp v0.0.0-20241204233417-43b7b7cde48d/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/stats/stats_test.go b/stats/stats_test.go index 651cdbc9..27992de9 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/gopcua/opcua/ua" - "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) func newExpVarInt(i int64) *expvar.Int { @@ -20,13 +20,13 @@ func TestConvienienceFuncs(t *testing.T) { Reset() Client().Add("a", 1) - verify.Values(t, "", Client().Get("a"), newExpVarInt(1)) + require.Equal(t, newExpVarInt(1), Client().Get("a")) Error().Add("b", 2) - verify.Values(t, "", Error().Get("b"), newExpVarInt(2)) + require.Equal(t, newExpVarInt(2), Error().Get("b")) Subscription().Add("c", 3) - verify.Values(t, "", Subscription().Get("c"), newExpVarInt(3)) + require.Equal(t, newExpVarInt(3), Subscription().Get("c")) } func TestRecordError(t *testing.T) { @@ -46,8 +46,7 @@ func TestRecordError(t *testing.T) { t.Run(tt.key, func(t *testing.T) { s := NewStats() s.RecordError(tt.err) - got, want := s.Error.Get(tt.key), newExpVarInt(1) - verify.Values(t, "", got, want) + require.Equal(t, newExpVarInt(1), s.Error.Get(tt.key)) }) } } diff --git a/tests/go/namespace_test.go b/tests/go/namespace_test.go index ff91b974..be13571f 100644 --- a/tests/go/namespace_test.go +++ b/tests/go/namespace_test.go @@ -8,10 +8,9 @@ import ( "testing" "time" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) func TestNamespace(t *testing.T) { @@ -21,41 +20,32 @@ func TestNamespace(t *testing.T) { defer srv.Close() c, err := opcua.NewClient("opc.tcp://localhost:4840", opcua.SecurityMode(ua.MessageSecurityModeNone)) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) time.Sleep(2 * time.Second) t.Run("NamespaceArray", func(t *testing.T) { got, err := c.NamespaceArray(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NamespaceArray failed") + want := []string{ "http://opcfoundation.org/UA/", "NodeNamespace", "http://gopcua.com/", } - verify.Values(t, "", got, want) + require.Equal(t, want, got, "NamespaceArray not equal") }) t.Run("FindNamespace", func(t *testing.T) { ns, err := c.FindNamespace(ctx, "http://gopcua.com/") - if err != nil { - t.Fatal(err) - } - if got, want := ns, uint16(2); got != want { - t.Fatalf("got namespace id %d want %d", got, want) - } + require.NoError(t, err, "FindNamespace failed") + require.Equal(t, uint16(2), ns, "namespace id not equal") }) t.Run("UpdateNamespaces", func(t *testing.T) { err := c.UpdateNamespaces(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "UpdateNamespaces failed") }) } diff --git a/tests/go/read_test.go b/tests/go/read_test.go index 9cad5166..3df1043d 100644 --- a/tests/go/read_test.go +++ b/tests/go/read_test.go @@ -8,10 +8,9 @@ import ( "testing" "time" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) // TestRead performs an integration test to read values @@ -38,12 +37,10 @@ func TestRead(t *testing.T) { time.Sleep(2 * time.Second) c, err := opcua.NewClient("opc.tcp://localhost:4840", opcua.SecurityMode(ua.MessageSecurityModeNone)) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) for _, tt := range tests { @@ -68,15 +65,9 @@ func testRead(t *testing.T, ctx context.Context, c *opcua.Client, v interface{}, }, TimestampsToReturn: ua.TimestampsToReturnBoth, }) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - if resp.Results[0].Status != ua.StatusOK { - t.Fatalf("Status not OK: %v", resp.Results[0].Status) - } - if got, want := resp.Results[0].Value.Value(), v; !verify.Values(t, "", got, want) { - t.Fail() - } + require.NoError(t, err, "Read failed") + require.Equal(t, ua.StatusOK, resp.Results[0].Status, "Status not OK") + require.Equal(t, v, resp.Results[0].Value.Value(), "Results[0].Value not equal") } func testRegisteredRead(t *testing.T, ctx context.Context, c *opcua.Client, v interface{}, id *ua.NodeID) { @@ -85,9 +76,7 @@ func testRegisteredRead(t *testing.T, ctx context.Context, c *opcua.Client, v in resp, err := c.RegisterNodes(ctx, &ua.RegisterNodesRequest{ NodesToRegister: []*ua.NodeID{id}, }) - if err != nil { - t.Fatalf("RegisterNodes failed: %s", err) - } + require.NoError(t, err, "RegisterNodes failed") testRead(t, ctx, c, v, resp.RegisteredNodeIDs[0]) testRead(t, ctx, c, v, resp.RegisteredNodeIDs[0]) @@ -98,7 +87,5 @@ func testRegisteredRead(t *testing.T, ctx context.Context, c *opcua.Client, v in _, err = c.UnregisterNodes(ctx, &ua.UnregisterNodesRequest{ NodesToUnregister: []*ua.NodeID{id}, }) - if err != nil { - t.Fatalf("UnregisterNodes failed: %s", err) - } + require.NoError(t, err, "UnregisterNodes failed") } diff --git a/tests/go/stats_test.go b/tests/go/stats_test.go index 4e9df96f..fc4db025 100644 --- a/tests/go/stats_test.go +++ b/tests/go/stats_test.go @@ -5,11 +5,10 @@ import ( "expvar" "testing" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua" "github.com/gopcua/opcua/stats" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) func newExpVarInt(i int64) *expvar.Int { @@ -27,12 +26,10 @@ func TestStats(t *testing.T) { defer srv.Close() c, err := opcua.NewClient("opc.tcp://localhost:4840", opcua.SecurityMode(ua.MessageSecurityModeNone)) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") c.Close(ctx) want := map[string]*expvar.Int{ @@ -53,20 +50,14 @@ func TestStats(t *testing.T) { got := map[string]expvar.Var{} stats.Client().Do(func(kv expvar.KeyValue) { got[kv.Key] = kv.Value }) for k := range got { - if _, ok := want[k]; !ok { - t.Fatalf("got unexpected key %q", k) - } + require.Contains(t, want, k, "got unexpected key %q", k) } for k := range want { - if _, ok := got[k]; !ok { - t.Fatalf("missing expected key %q", k) - } + require.Contains(t, got, k, "missing expected key %q", k) } for k, ev := range want { v := stats.Client().Get(k) - if !verify.Values(t, "", v, ev) { - t.Errorf("got %s for %q, want %s", v.String(), k, ev.String()) - } + require.Equal(t, ev, v) } } diff --git a/tests/go/write_test.go b/tests/go/write_test.go index 475bbbae..5d7c84dc 100644 --- a/tests/go/write_test.go +++ b/tests/go/write_test.go @@ -10,6 +10,7 @@ import ( "github.com/gopcua/opcua" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) // TestWrite performs an integration test to first write @@ -36,12 +37,10 @@ func TestWrite(t *testing.T) { time.Sleep(2 * time.Second) c, err := opcua.NewClient("opc.tcp://localhost:4840", opcua.SecurityMode(ua.MessageSecurityModeNone)) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) for _, tt := range tests { @@ -73,10 +72,6 @@ func testWrite(t *testing.T, ctx context.Context, c *opcua.Client, status ua.Sta t.Helper() resp, err := c.Write(ctx, req) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - if got, want := resp.Results[0], status; got != want { - t.Fatalf("got status %v want %v", got, want) - } + require.NoError(t, err, "Write failed") + require.Equal(t, status, resp.Results[0], "status not equal") } diff --git a/tests/python/method_test.go b/tests/python/method_test.go index 4739af5f..e8de5352 100644 --- a/tests/python/method_test.go +++ b/tests/python/method_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -7,10 +6,9 @@ import ( "context" "testing" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) type Complex struct { @@ -62,26 +60,18 @@ func TestCallMethod(t *testing.T) { defer srv.Close() c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) for _, tt := range tests { t.Run(tt.req.ObjectID.String(), func(t *testing.T) { resp, err := c.Call(ctx, tt.req) - if err != nil { - t.Fatal(err) - } - if got, want := resp.StatusCode, ua.StatusOK; got != want { - t.Fatalf("got status %v want %v", got, want) - } - if got, want := resp.OutputArguments, tt.out; !verify.Values(t, "", got, want) { - t.Fail() - } + require.NoError(t, err, "Call failed") + require.Equal(t, ua.StatusOK, resp.StatusCode, "StatusCode not equal") + require.Equal(t, tt.out, resp.OutputArguments, "OuptutArgs not equal") }) } } diff --git a/tests/python/namespace_test.go b/tests/python/namespace_test.go index bf59d1f3..55434879 100644 --- a/tests/python/namespace_test.go +++ b/tests/python/namespace_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -7,9 +6,8 @@ import ( "context" "testing" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua" + "github.com/stretchr/testify/require" ) func TestNamespace(t *testing.T) { @@ -19,39 +17,30 @@ func TestNamespace(t *testing.T) { defer srv.Close() c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) t.Run("NamespaceArray", func(t *testing.T) { got, err := c.NamespaceArray(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NamespaceArray failed") + want := []string{ "http://opcfoundation.org/UA/", "urn:freeopcua:python:server", "http://gopcua.com/", } - verify.Values(t, "", got, want) + require.Equal(t, want, got) }) t.Run("FindNamespace", func(t *testing.T) { ns, err := c.FindNamespace(ctx, "http://gopcua.com/") - if err != nil { - t.Fatal(err) - } - if got, want := ns, uint16(2); got != want { - t.Fatalf("got namespace id %d want %d", got, want) - } + require.NoError(t, err, "FindNamespace failed") + require.Equal(t, uint16(2), ns, "namespace id not equal") }) t.Run("UpdateNamespaces", func(t *testing.T) { err := c.UpdateNamespaces(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "UpdateNamespaces failed") }) } diff --git a/tests/python/py_server.go b/tests/python/py_server.go index 1fd50ff4..88add57b 100644 --- a/tests/python/py_server.go +++ b/tests/python/py_server.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest diff --git a/tests/python/read_test.go b/tests/python/read_test.go index 1f8e5207..44848c9d 100644 --- a/tests/python/read_test.go +++ b/tests/python/read_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -7,10 +6,9 @@ import ( "context" "testing" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) // TestRead performs an integration test to read values @@ -34,12 +32,10 @@ func TestRead(t *testing.T) { defer srv.Close() c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) for _, tt := range tests { @@ -63,15 +59,9 @@ func testRead(t *testing.T, ctx context.Context, c *opcua.Client, v interface{}, }, TimestampsToReturn: ua.TimestampsToReturnBoth, }) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - if resp.Results[0].Status != ua.StatusOK { - t.Fatalf("Status not OK: %v", resp.Results[0].Status) - } - if got, want := resp.Results[0].Value.Value(), v; !verify.Values(t, "", got, want) { - t.Fail() - } + require.NoError(t, err, "Read failed") + require.Equal(t, ua.StatusOK, resp.Results[0].Status, "Status not OK") + require.Equal(t, v, resp.Results[0].Value.Value(), "Results[0].Value not equal") } func testRegisteredRead(t *testing.T, ctx context.Context, c *opcua.Client, v interface{}, id *ua.NodeID) { @@ -80,9 +70,7 @@ func testRegisteredRead(t *testing.T, ctx context.Context, c *opcua.Client, v in resp, err := c.RegisterNodes(ctx, &ua.RegisterNodesRequest{ NodesToRegister: []*ua.NodeID{id}, }) - if err != nil { - t.Fatalf("RegisterNodes failed: %s", err) - } + require.NoError(t, err, "RegisterNodes failed") testRead(t, ctx, c, v, resp.RegisteredNodeIDs[0]) testRead(t, ctx, c, v, resp.RegisteredNodeIDs[0]) @@ -93,7 +81,5 @@ func testRegisteredRead(t *testing.T, ctx context.Context, c *opcua.Client, v in _, err = c.UnregisterNodes(ctx, &ua.UnregisterNodesRequest{ NodesToUnregister: []*ua.NodeID{id}, }) - if err != nil { - t.Fatalf("UnregisterNodes failed: %s", err) - } + require.NoError(t, err, "UnregisterNodes failed") } diff --git a/tests/python/read_unknow_node_id_test.go b/tests/python/read_unknow_node_id_test.go index c3335652..db8dd815 100644 --- a/tests/python/read_unknow_node_id_test.go +++ b/tests/python/read_unknow_node_id_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -10,6 +9,7 @@ import ( "github.com/gopcua/opcua" "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) // TestRead performs an integration test to read values @@ -21,12 +21,10 @@ func TestReadUnknowNodeID(t *testing.T) { defer srv.Close() c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) // read node with unknown extension object @@ -37,13 +35,9 @@ func TestReadUnknowNodeID(t *testing.T) { {NodeID: nodeWithUnknownType}, }, }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Read failed") - if got, want := resp.Results[0].Status, ua.StatusBadDataTypeIDUnknown; got != want { - t.Errorf("got status %v want %v for a node with an unknown type", got, want) - } + require.Equal(t, ua.StatusBadDataTypeIDUnknown, resp.Results[0].Status, "got different status for a node with an unknown type") // check that the connection is still usable by reading another node. _, err = c.Read(ctx, &ua.ReadRequest{ @@ -53,7 +47,5 @@ func TestReadUnknowNodeID(t *testing.T) { }, }, }) - if err != nil { - t.Error(err) - } + require.NoError(t, err, "Read failed") } diff --git a/tests/python/reconnection_test.go b/tests/python/reconnection_test.go index ff7899ad..97cf955f 100644 --- a/tests/python/reconnection_test.go +++ b/tests/python/reconnection_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -11,6 +10,7 @@ import ( "github.com/gopcua/opcua" "github.com/gopcua/opcua/monitor" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) const ( @@ -28,18 +28,14 @@ func TestAutoReconnection(t *testing.T) { defer srv.Close() c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) m, err := monitor.NewNodeMonitor(c) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewNodeMonitor failed") ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -92,17 +88,14 @@ func TestAutoReconnection(t *testing.T) { ch, currentTimeNodeID, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "ChanSubscribe failed") defer sub.Unsubscribe(ctx) for _, tt := range tests { - ok := t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { - if msg := <-ch; msg.Error != nil { - t.Fatalf("No error expected for first value: %s", msg.Error) - } + msg := <-ch + require.NoError(t, msg.Error, "No error expected for first value") downC := make(chan struct{}, 1) dTimeout := time.NewTimer(disconnectTimeout) @@ -132,7 +125,7 @@ func TestAutoReconnection(t *testing.T) { select { case <-dTimeout.C: cancel() - t.Fatal("Timeout reached, the connection did not go down as expected") + require.Fail(t, "Timeout reached, the connection did not go down as expected") case <-downC: } @@ -144,16 +137,10 @@ func TestAutoReconnection(t *testing.T) { rTimeout := time.NewTimer(reconnectionTimeout) select { case <-rTimeout.C: - t.Fatal("Timeout reached, reconnection failed") + require.Fail(t, "Timeout reached, reconnection failed") case msg := <-ch: - if err := msg.Error; err != nil { - t.Fatal(err) - } + require.NoError(t, msg.Error) } }) - - if !ok { - t.FailNow() - } } } diff --git a/tests/python/stats_test.go b/tests/python/stats_test.go index 1a99bdf3..59ffd2b5 100644 --- a/tests/python/stats_test.go +++ b/tests/python/stats_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -8,10 +7,9 @@ import ( "expvar" "testing" - "github.com/pascaldekloe/goe/verify" - "github.com/gopcua/opcua" "github.com/gopcua/opcua/stats" + "github.com/stretchr/testify/require" ) func newExpVarInt(i int64) *expvar.Int { @@ -29,13 +27,11 @@ func TestStats(t *testing.T) { defer srv.Close() c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } - c.Close(ctx) + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") + c.Close(ctx) // close immediately want := map[string]*expvar.Int{ "Dial": newExpVarInt(1), @@ -55,20 +51,14 @@ func TestStats(t *testing.T) { got := map[string]expvar.Var{} stats.Client().Do(func(kv expvar.KeyValue) { got[kv.Key] = kv.Value }) for k := range got { - if _, ok := want[k]; !ok { - t.Fatalf("got unexpected key %q", k) - } + require.Contains(t, want, k, "got unexpected key %q", k) } for k := range want { - if _, ok := got[k]; !ok { - t.Fatalf("missing expected key %q", k) - } + require.Contains(t, got, k, "missing expected key %q", k) } for k, ev := range want { v := stats.Client().Get(k) - if !verify.Values(t, "", v, ev) { - t.Errorf("got %s for %q, want %s", v.String(), k, ev.String()) - } + require.Equal(t, ev, v, "got %s for %q, want %s", v.String(), k, ev.String()) } } diff --git a/tests/python/timeout_test.go b/tests/python/timeout_test.go index 8a83c96d..1467a55e 100644 --- a/tests/python/timeout_test.go +++ b/tests/python/timeout_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -11,6 +10,7 @@ import ( "github.com/gopcua/opcua" "github.com/gopcua/opcua/errors" + "github.com/stretchr/testify/require" ) const ( @@ -22,18 +22,13 @@ const ( func TestClientTimeoutViaOptions(t *testing.T) { c, err := opcua.NewClient(tcpNoRstTestServer, opcua.DialTimeout(forceTimeoutDuration)) - if err != nil { - t.Fatal(err) - } - + require.NoError(t, err, "NewClient failed") connectAndValidate(t, c, context.Background(), forceTimeoutDuration) } func TestClientTimeoutViaContext(t *testing.T) { c, err := opcua.NewClient(tcpNoRstTestServer) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") ctx, cancel := context.WithTimeout(context.Background(), forceTimeoutDuration) defer cancel() @@ -45,26 +40,24 @@ func connectAndValidate(t *testing.T, c *opcua.Client, ctx context.Context, d ti start := time.Now() err := c.Connect(ctx) - if err == nil { - t.Fatal("err should not be nil") - } + require.Error(t, err, "Connect should fail") elapsed := time.Since(start) var oe *net.OpError switch { case errors.As(err, &oe) && !oe.Timeout(): - t.Fatalf("got %#v, wanted net.timeoutError", oe.Unwrap()) + require.Fail(t, "got %#v, wanted net.timeoutError", oe.Unwrap()) case errors.As(err, &oe): // ignore default: - t.Fatalf("got %T, wanted %T", err, &net.OpError{}) + require.Fail(t, "got %T, wanted %T", err, &net.OpError{}) } pct := 0.05 if !within(elapsed, d, pct) { - t.Fatalf("took %s, expected %s +/- %v%%", elapsed, d, pct*100) + require.Fail(t, "took %s, expected %s +/- %v%%", elapsed, d, pct*100) } } diff --git a/tests/python/write_test.go b/tests/python/write_test.go index 3b72edf1..b5b7e599 100644 --- a/tests/python/write_test.go +++ b/tests/python/write_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package uatest @@ -9,6 +8,7 @@ import ( "github.com/gopcua/opcua" "github.com/gopcua/opcua/ua" + "github.com/stretchr/testify/require" ) // TestWrite performs an integration test to first write @@ -33,12 +33,10 @@ func TestWrite(t *testing.T) { defer srv.Close() c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) - if err != nil { - t.Fatal(err) - } - if err := c.Connect(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err, "NewClient failed") + + err = c.Connect(ctx) + require.NoError(t, err, "Connect failed") defer c.Close(ctx) for _, tt := range tests { @@ -70,10 +68,6 @@ func testWrite(t *testing.T, ctx context.Context, c *opcua.Client, status ua.Sta t.Helper() resp, err := c.Write(ctx, req) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - if got, want := resp.Results[0], status; got != want { - t.Fatalf("got status %v want %v", got, want) - } + require.NoError(t, err, "Write failed") + require.Equal(t, status, resp.Results[0], "Write result not equal") } diff --git a/ua/codec_test.go b/ua/codec_test.go index 14e20aae..4dfc7cc2 100644 --- a/ua/codec_test.go +++ b/ua/codec_test.go @@ -38,26 +38,23 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { case reflect.Slice: v = reflect.New(typ) // typ: []x, v: *[]x default: - t.Fatalf("%T is not a pointer or a slice", c.Struct) + require.Fail(t, "%T is not a pointer or a slice", c.Struct) } - if _, err := Decode(c.Bytes, v.Interface()); err != nil { - t.Fatal(err) - } + _, err := Decode(c.Bytes, v.Interface()) + require.NoError(t, err, "Decode failed") // if v is a *[]x we need to dereference it before comparing it. if typ.Kind() == reflect.Slice { v = v.Elem() } - require.Equal(t, c.Struct, v.Interface()) + require.Equal(t, c.Struct, v.Interface(), "Decoded payload not equal") }) t.Run("encode", func(t *testing.T) { b, err := Encode(c.Struct) - if err != nil { - t.Fatal(err) - } - require.Equal(t, c.Bytes, b) + require.NoError(t, err, "Encode failed") + require.Equal(t, c.Bytes, b, "Encoded payload not equal") }) }) } diff --git a/ua/decode_test.go b/ua/decode_test.go index acdc04f0..74e4160b 100644 --- a/ua/decode_test.go +++ b/ua/decode_test.go @@ -5,10 +5,11 @@ package ua import ( - "bytes" "reflect" "testing" "time" + + "github.com/stretchr/testify/require" ) type A struct { @@ -384,9 +385,7 @@ func TestCodec(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if reflect.ValueOf(tt.v).Kind() != reflect.Ptr { - t.Fatalf("%T is not a pointer", tt.v) - } + require.Equal(t, reflect.Ptr, reflect.ValueOf(tt.v).Kind(), "%T is not a pointer", tt.v) t.Run("decode", func(t *testing.T) { // create a new instance of the same type as tt.v @@ -394,22 +393,14 @@ func TestCodec(t *testing.T) { typ := reflect.ValueOf(tt.v).Type() v := reflect.New(typ.Elem()) - if _, err := Decode(tt.b, v.Interface()); err != nil { - t.Fatal(err) - } - - if got, want := v.Interface(), tt.v; !reflect.DeepEqual(got, want) { - t.Fatalf("got %#v, want %#v", got, want) - } + _, err := Decode(tt.b, v.Interface()) + require.NoError(t, err, "Decode failed") + require.Equal(t, tt.v, v.Interface(), "Decoded payload not equal") }) t.Run("encode", func(t *testing.T) { b, err := Encode(tt.v) - if err != nil { - t.Fatal(err) - } - if got, want := b, tt.b; !bytes.Equal(got, want) { - t.Fatalf("\ngot %#v\nwant %#v", got, want) - } + require.NoError(t, err, "Encode failed") + require.Equal(t, tt.b, b, "Encoded payload not equal") }) }) } @@ -426,7 +417,5 @@ func TestFailDecodeArray(t *testing.T) { } var a [2]int32 _, err := Decode(b, &a) - if err == nil { - t.Fatalf("was expecting error for tryig to decode a stream of bytes with length 3 into an array of size 2") - } + require.Error(t, err, "was expecting error for tryig to decode a stream of bytes with length 3 into an array of size 2") } diff --git a/ua/expanded_node_id_test.go b/ua/expanded_node_id_test.go index 827c2404..2f831320 100644 --- a/ua/expanded_node_id_test.go +++ b/ua/expanded_node_id_test.go @@ -6,10 +6,10 @@ package ua import ( "math" - "reflect" "testing" "github.com/gopcua/opcua/errors" + "github.com/stretchr/testify/require" ) func TestExpandedNodeID(t *testing.T) { @@ -101,12 +101,8 @@ func TestParseExpandedNodeID(t *testing.T) { for _, c := range cases { t.Run(c.s, func(t *testing.T) { n, err := ParseExpandedNodeID(c.s, c.ns) - if got, want := err, c.err; !errors.Equal(got, want) { - t.Fatalf("got error %v want %v", got, want) - } - if got, want := n, c.n; !reflect.DeepEqual(got, want) { - t.Fatalf("\ngot %#v\nwant %#v", got, want) - } + require.Equal(t, c.err, err, "Errors not equal") + require.Equal(t, c.n, n, "ExpandedNodeID not equal") }) } } diff --git a/ua/node_id_test.go b/ua/node_id_test.go index a4cc97b9..a76193ad 100644 --- a/ua/node_id_test.go +++ b/ua/node_id_test.go @@ -8,10 +8,10 @@ import ( "encoding/base64" "encoding/json" "math" - "reflect" "testing" "github.com/gopcua/opcua/errors" + "github.com/stretchr/testify/require" ) func TestNodeID(t *testing.T) { @@ -173,12 +173,8 @@ func TestParseNodeID(t *testing.T) { for _, c := range cases { t.Run(c.s, func(t *testing.T) { n, err := ParseNodeID(c.s) - if got, want := err, c.err; !errors.Equal(got, want) { - t.Fatalf("got error %v want %v", got, want) - } - if got, want := n, c.n; !reflect.DeepEqual(got, want) { - t.Fatalf("\ngot %#v\nwant %#v", got, want) - } + require.Equal(t, c.err, err, "Error not equal") + require.Equal(t, c.n, n, "Parsed NodeID not equal") }) } } @@ -197,9 +193,7 @@ func TestStringID(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - if got, want := c.n.String(), c.s; got != want { - t.Fatalf("got %s want %s", got, want) - } + require.Equal(t, c.s, c.n.String()) }) } } @@ -266,22 +260,17 @@ func TestSetIntID(t *testing.T) { v := tt.n.IntID() // sanity check - if before, after := v, tt.v; before == after { - t.Fatalf("before == after: %d == %d", before, after) - } + require.NotEqual(t, tt.v, v, "before == after: %d == %d", v, tt.v) err := tt.n.SetIntID(tt.v) - if got, want := err, tt.err; !errors.Equal(got, want) { - t.Fatalf("got error %v want %v", got, want) - } + require.Equal(t, tt.err, err) + // if the test should fail and the error was correct // we need to stop here. if tt.err != nil { return } - if got, want := tt.n.IntID(), tt.v; got != want { - t.Fatalf("got value %d want %d", got, want) - } + require.Equal(t, tt.v, tt.n.IntID(), "IntID not equal") }) } } @@ -342,22 +331,17 @@ func TestSetStringID(t *testing.T) { v := tt.n.StringID() // sanity check - if before, after := v, tt.v; before == after { - t.Fatalf("before == after: %s == %s", before, after) - } + require.NotEqual(t, tt.v, v, "before == after: %s == %s", v, tt.v) err := tt.n.SetStringID(tt.v) - if got, want := err, tt.err; !errors.Equal(got, want) { - t.Fatalf("got error %q (%T) want %q (%T)", got, got, want, want) - } + require.Equal(t, tt.err, err) + // if the test should fail and the error was correct // we need to stop here. if tt.err != nil { return } - if got, want := tt.n.StringID(), tt.v; got != want { - t.Fatalf("got value %s want %s", got, want) - } + require.Equal(t, tt.v, tt.n.StringID()) }) } } @@ -404,17 +388,14 @@ func TestSetNamespace(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.n.SetNamespace(tt.v) - if got, want := err, tt.err; !errors.Equal(got, want) { - t.Fatalf("got error %v want %v", got, want) - } + require.Equal(t, tt.err, err) + // if the test should fail and the error was correct // we need to stop here. if tt.err != nil { return } - if got, want := tt.n.Namespace(), tt.v; got != want { - t.Fatalf("got value %d want %d", got, want) - } + require.Equal(t, tt.v, tt.n.Namespace()) }) } } @@ -422,64 +403,43 @@ func TestSetNamespace(t *testing.T) { func TestNodeIDJSON(t *testing.T) { t.Run("value", func(t *testing.T) { n, err := ParseNodeID(`ns=4;s=abc`) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + b, err := json.Marshal(n) - if err != nil { - t.Fatal(err) - } - if got, want := string(b), `"ns=4;s=abc"`; got != want { - t.Fatalf("got %s want %s", got, want) - } + require.NoError(t, err) + require.Equal(t, `"ns=4;s=abc"`, string(b)) + var nn NodeID - if err := json.Unmarshal(b, &nn); err != nil { - t.Fatal(err) - } - if got, want := nn.String(), n.String(); got != want { - t.Fatalf("got %s want %s", got, want) - } + err = json.Unmarshal(b, &nn) + require.NoError(t, err) + require.Equal(t, n.String(), nn.String(), "NodeIDs not equal") }) t.Run("nil", func(t *testing.T) { var n *NodeID b, err := json.Marshal(n) - if err != nil { - t.Fatal(err) - } - if got, want := string(b), "null"; got != want { - t.Fatalf("got %s want %s", got, want) - } + require.NoError(t, err) + require.Equal(t, "null", string(b)) }) type X struct{ N *NodeID } t.Run("struct", func(t *testing.T) { x := X{NewStringNodeID(4, "abc")} b, err := json.Marshal(x) - if err != nil { - t.Fatal(err) - } - if got, want := string(b), `{"N":"ns=4;s=abc"}`; got != want { - t.Fatalf("got %s want %s", got, want) - } + require.NoError(t, err) + require.Equal(t, `{"N":"ns=4;s=abc"}`, string(b)) }) t.Run("nil struct", func(t *testing.T) { var x X b, err := json.Marshal(x) - if err != nil { - t.Fatal(err) - } - if got, want := string(b), `{"N":null}`; got != want { - t.Fatalf("got %s want %s", got, want) - } + require.NoError(t, err) + require.Equal(t, `{"N":null}`, string(b)) + var xx X - if err := json.Unmarshal(b, &xx); err != nil { - t.Fatal(err) - } - if got, want := xx, x; !reflect.DeepEqual(got, want) { - t.Fatalf("got %s want %s", got, want) - } + err = json.Unmarshal(b, &xx) + require.NoError(t, err) + require.Equal(t, x, xx) }) } @@ -507,12 +467,8 @@ func TestNodeIDToString(t *testing.T) { for _, tt := range tests { t.Run(tt.s, func(t *testing.T) { n, err := ParseNodeID(tt.s) - if err != nil { - t.Fatal(err) - } - if got, want := n.String(), tt.want; got != want { - t.Fatalf("got %s want %s", got, want) - } + require.NoError(t, err) + require.Equal(t, tt.want, n.String()) }) } } @@ -578,9 +534,7 @@ func TestNewNodeIDFromExpandedNodeID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewNodeIDFromExpandedNodeID(tt.args.id); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewNodeIDFromExpandedNodeID() = %#v, want %#v", got, tt.want) - } + require.Equal(t, tt.want, NewNodeIDFromExpandedNodeID(tt.args.id)) }) } } diff --git a/ua/variant_test.go b/ua/variant_test.go index 37df8b09..01f76eac 100644 --- a/ua/variant_test.go +++ b/ua/variant_test.go @@ -10,9 +10,7 @@ import ( "testing" "time" - "github.com/gopcua/opcua/errors" - - "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) func TestVariant(t *testing.T) { @@ -493,43 +491,25 @@ func TestVariant(t *testing.T) { func TestMustVariant(t *testing.T) { t.Run("int", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatalf("MustVariant(int) did not panic") - } - }() - MustVariant(int(5)) + require.Panics(t, func() { MustVariant(int(5)) }) }) t.Run("uint", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatalf("MustVariant(uint) did not panic") - } - }() - MustVariant(uint(5)) + require.Panics(t, func() { MustVariant(uint(5)) }) }) } func TestArray(t *testing.T) { t.Run("one-dimension", func(t *testing.T) { v := MustVariant([]uint32{1, 2, 3}) - if got, want := v.ArrayLength(), int32(3); got != want { - t.Fatalf("got length %d want %d", got, want) - } - if got, want := v.EncodingMask(), byte(TypeIDUint32|VariantArrayValues); got != want { - t.Fatalf("got mask %d want %d", got, want) - } - verify.Values(t, "", v.ArrayDimensions(), []int32{}) + require.Equal(t, int32(3), v.ArrayLength(), "ArrayLength not equal") + require.Equal(t, byte(TypeIDUint32|VariantArrayValues), v.EncodingMask(), "EncodingMask not equal") + require.Equal(t, []int32(nil), v.ArrayDimensions()) }) t.Run("multi-dimension", func(t *testing.T) { v := MustVariant([][]uint32{{1, 1}, {2, 2}, {3, 3}}) - if got, want := v.ArrayLength(), int32(6); got != want { - t.Fatalf("got length %d want %d", got, want) - } - if got, want := v.EncodingMask(), byte(TypeIDUint32|VariantArrayValues|VariantArrayDimensions); got != want { - t.Fatalf("got mask %d want %d", got, want) - } - verify.Values(t, "", v.ArrayDimensions(), []int32{3, 2}) + require.Equal(t, int32(6), v.ArrayLength(), "ArrayLength not equal") + require.Equal(t, byte(TypeIDUint32|VariantArrayValues|VariantArrayDimensions), v.EncodingMask(), "EncodingMask not equal") + require.Equal(t, []int32{3, 2}, v.ArrayDimensions()) }) t.Run("unbalanced", func(t *testing.T) { b := []byte{ @@ -549,9 +529,7 @@ func TestArray(t *testing.T) { } _, err := Decode(b, MustVariant([]uint32{0})) - if got, want := err, errUnbalancedSlice; !errors.Equal(got, want) { - t.Fatalf("got error %#v want %#v", got, want) - } + require.ErrorIs(t, err, errUnbalancedSlice) }) t.Run("length too big", func(t *testing.T) { b := []byte{ @@ -564,9 +542,7 @@ func TestArray(t *testing.T) { } _, err := Decode(b, MustVariant([]uint32{0})) - if got, want := err, StatusBadEncodingLimitsExceeded; !errors.Equal(got, want) { - t.Fatalf("got error %v want %v", err, StatusBadEncodingLimitsExceeded) - } + require.ErrorIs(t, err, StatusBadEncodingLimitsExceeded) }) t.Run("dimensions length negative", func(t *testing.T) { b := []byte{ @@ -585,9 +561,7 @@ func TestArray(t *testing.T) { } _, err := Decode(b, MustVariant([]uint32{0})) - if got, want := err, StatusBadEncodingLimitsExceeded; !errors.Equal(got, want) { - t.Fatalf("got error %#v want %#v", got, want) - } + require.ErrorIs(t, err, StatusBadEncodingLimitsExceeded) }) t.Run("dimensions negative", func(t *testing.T) { b := []byte{ @@ -606,9 +580,7 @@ func TestArray(t *testing.T) { } _, err := Decode(b, MustVariant([]uint32{0})) - if got, want := err, StatusBadEncodingLimitsExceeded; !errors.Equal(got, want) { - t.Fatalf("got error %#v want %#v", got, want) - } + require.ErrorIs(t, err, StatusBadEncodingLimitsExceeded) }) t.Run("dimensions zero", func(t *testing.T) { b := []byte{ @@ -625,9 +597,7 @@ func TestArray(t *testing.T) { } _, err := Decode(b, MustVariant([][]uint32{{}, {}})) - if got, want := err, StatusBadEncodingLimitsExceeded; !errors.Equal(got, want) { - t.Fatalf("got error %#v want %#v", got, want) - } + require.ErrorIs(t, err, StatusBadEncodingLimitsExceeded) }) } @@ -725,14 +695,12 @@ func TestSet(t *testing.T) { for _, tt := range tests { t.Run(fmt.Sprintf("%T", tt.v), func(t *testing.T) { va, err := NewVariant(tt.v) - if got, want := err, tt.err; got != want { - t.Fatalf("got error %v want %v", got, want) - } - verify.Values(t, "variant.mask", va.mask, tt.va.mask) - verify.Values(t, "variant.arrayLength", va.arrayLength, tt.va.arrayLength) - verify.Values(t, "variant.arrayDimensionsLength", va.arrayDimensionsLength, tt.va.arrayDimensionsLength) - verify.Values(t, "variant.arrayDimensions", va.arrayDimensions, tt.va.arrayDimensions) - verify.Values(t, "variant.value", va.value, tt.va.value) + require.Equal(t, tt.err, err) + require.Equal(t, tt.va.mask, va.mask) + require.Equal(t, tt.va.arrayLength, va.arrayLength) + require.Equal(t, tt.va.arrayDimensionsLength, va.arrayDimensionsLength) + require.Equal(t, tt.va.arrayDimensions, va.arrayDimensions) + require.Equal(t, tt.va.value, va.value) }) } } @@ -812,18 +780,10 @@ func TestSliceDim(t *testing.T) { for _, tt := range tests { t.Run(fmt.Sprintf("%T", tt.v), func(t *testing.T) { et, dim, len, err := sliceDim(reflect.ValueOf(tt.v)) - if got, want := err, tt.err; got != want { - t.Fatalf("got error %v want %v", got, want) - } - if got, want := et, tt.et; got != want { - t.Fatalf("got type %v want %v", got, want) - } - if got, want := dim, tt.dim; !reflect.DeepEqual(got, want) { - t.Fatalf("got dimensions %v want %v", got, want) - } - if got, want := len, tt.len; got != want { - t.Fatalf("got len %v want %v", got, want) - } + require.Equal(t, tt.err, err) + require.Equal(t, tt.et, et) + require.Equal(t, tt.dim, dim) + require.Equal(t, tt.len, len) }) } } @@ -832,17 +792,14 @@ func TestVariantUnsupportedType(t *testing.T) { tests := []interface{}{int(5), uint(5)} for _, v := range tests { t.Run(fmt.Sprintf("%T", v), func(t *testing.T) { - if _, err := NewVariant(v); err == nil { - t.Fatal("got nil want err") - } + _, err := NewVariant(v) + require.Error(t, err) }) } } func TestVariantValueMethod(t *testing.T) { - if got, want := MustVariant(int32(5)).Value().(int32), int32(5); got != want { - t.Fatalf("got %d want %d", got, want) - } + require.Equal(t, int32(5), MustVariant(int32(5)).Value().(int32)) } func TestVariantValueHelpers(t *testing.T) { @@ -1270,7 +1227,7 @@ func TestVariantValueHelpers(t *testing.T) { for i, tt := range tests { name := fmt.Sprintf("test-%d %T -> %T", i, tt.v, tt.want) t.Run(name, func(t *testing.T) { - verify.Values(t, "", tt.fn(MustVariant(tt.v)), tt.want) + require.Equal(t, tt.want, tt.fn(MustVariant(tt.v))) }) } } @@ -1283,7 +1240,5 @@ func TestDecodeInvalidType(t *testing.T) { v := &Variant{} _, err := v.Decode(b) - if got, want := err, errors.New("invalid type id: 32"); !errors.Equal(got, want) { - t.Fatalf("got error %s want %s", got, want) - } + require.EqualError(t, err, "opcua: invalid type id: 32") } diff --git a/uacp/codec_test.go b/uacp/codec_test.go index 360ab773..493063c8 100644 --- a/uacp/codec_test.go +++ b/uacp/codec_test.go @@ -11,7 +11,7 @@ import ( "testing" "github.com/gopcua/opcua/ua" - "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) // CodecTestCase describes a test case for a encoding and decoding an @@ -39,26 +39,23 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { case reflect.Slice: v = reflect.New(typ) // typ: []x, v: *[]x default: - t.Fatalf("%T is not a pointer or a slice", c.Struct) + require.Fail(t, "%T is not a pointer or a slice", c.Struct) } - if _, err := ua.Decode(c.Bytes, v.Interface()); err != nil { - t.Fatal(err) - } + _, err := ua.Decode(c.Bytes, v.Interface()) + require.NoError(t, err, "Decode failed") // if v is a *[]x we need to dereference it before comparing it. if typ.Kind() == reflect.Slice { v = v.Elem() } - verify.Values(t, "", v.Interface(), c.Struct) + require.Equal(t, c.Struct, v.Interface(), "Decoded payload not equal") }) t.Run("encode", func(t *testing.T) { b, err := ua.Encode(c.Struct) - if err != nil { - t.Fatal(err) - } - verify.Values(t, "", b, c.Bytes) + require.NoError(t, err, "Encode failed") + require.Equal(t, c.Bytes, b, "Encoded payload not equal") }) }) } diff --git a/uacp/conn_test.go b/uacp/conn_test.go index 2269bbe2..44adb241 100644 --- a/uacp/conn_test.go +++ b/uacp/conn_test.go @@ -11,16 +11,14 @@ import ( "time" "github.com/gopcua/opcua/errors" - "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) func TestConn(t *testing.T) { t.Run("server exists ", func(t *testing.T) { ep := "opc.tcp://127.0.0.1:4840/foo/bar" ln, err := Listen(ep, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -45,9 +43,9 @@ func TestConn(t *testing.T) { select { case <-done: case err := <-acceptErr: - t.Fatalf("accept fail: %v", err) + require.Fail(t, "accept fail: %v", err) case <-time.After(time.Second): - t.Fatal("timed out") + require.Fail(t, "timed out") } }) @@ -67,9 +65,7 @@ func TestConn(t *testing.T) { func TestClientWrite(t *testing.T) { ep := "opc.tcp://127.0.0.1:4840/foo/bar" ln, err := Listen(ep, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Listen failed") defer ln.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -90,49 +86,40 @@ func TestClientWrite(t *testing.T) { }() cliConn, err := Dial(ctx, ep) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Dial failed") for { select { case _, ok := <-done: - if !ok { - t.Fatal("failed to setup secure channel") - } + require.True(t, ok, "failed to setup secure channel") goto NEXT case err := <-acceptErr: - t.Fatalf("accept fail: %v", err) + require.Fail(t, "accept fail: %v", err) case <-time.After(time.Second): - t.Fatal("timed out") + require.Fail(t, "timed out") } } NEXT: msg := &Message{Data: []byte{0xde, 0xad, 0xbe, 0xef}} - if err := cliConn.Send("MSGF", msg); err != nil { - t.Fatal(err) - } + err = cliConn.Send("MSGF", msg) + require.NoError(t, err, "Send failed") got, err := srvConn.Receive() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Receive failed") + got = got[hdrlen:] want, err := msg.Encode() - if err != nil { - t.Fatal(err) - } - verify.Values(t, "", got, want) + require.NoError(t, err, "Encode failed") + + require.Equal(t, want, got) } func TestServerWrite(t *testing.T) { ep := "opc.tcp://127.0.0.1:4840/foo/bar" ln, err := Listen(ep, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Listen failed") defer ln.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -153,35 +140,29 @@ func TestServerWrite(t *testing.T) { }() cliConn, err := Dial(ctx, ep) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Dial failed") for { select { case _, ok := <-done: - if !ok { - t.Fatal("failed to setup secure channel") - } + require.True(t, ok, "failed to setup secure channel") goto NEXT case err := <-acceptErr: - t.Fatalf("accept fail: %v", err) + require.Fail(t, "accept fail: %v", err) case <-time.After(time.Second): - t.Fatal("timed out") + require.Fail(t, "timed out") } } NEXT: want := []byte{0xde, 0xad, 0xbe, 0xef} - if _, err := srvConn.Write(want); err != nil { - t.Fatal(err) - } + _, err = srvConn.Write(want) + require.NoError(t, err, "Write failed") got := make([]byte, cliConn.ReceiveBufSize()) n, err := cliConn.Read(got) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Read failed") + got = got[:n] - verify.Values(t, "", got, want) + require.Equal(t, want, got) } diff --git a/uacp/endpoint_test.go b/uacp/endpoint_test.go index 4637371e..afe699ac 100644 --- a/uacp/endpoint_test.go +++ b/uacp/endpoint_test.go @@ -7,6 +7,8 @@ package uacp import ( "net" "testing" + + "github.com/stretchr/testify/require" ) func TestResolveEndpoint(t *testing.T) { @@ -58,19 +60,11 @@ func TestResolveEndpoint(t *testing.T) { } for _, c := range cases { - var errStr string network, addr, err := ResolveEndpoint(c.input) + require.Equal(t, c.network, network, "network not equal") + require.Equal(t, c.addr.String(), addr.String(), "addr not equal") if err != nil { - errStr = err.Error() - } - if got, want := network, c.network; got != want { - t.Fatalf("got network %q want %q", got, want) - } - if got, want := addr.String(), c.addr.String(); got != want { - t.Fatalf("got addr %q want %q", got, want) - } - if got, want := errStr, c.errStr; got != want { - t.Fatalf("got error %q want %q", got, want) + require.ErrorContains(t, err, c.errStr) } } } diff --git a/uapolicy/securitypolicy_test.go b/uapolicy/securitypolicy_test.go index e9618711..2a2b0121 100644 --- a/uapolicy/securitypolicy_test.go +++ b/uapolicy/securitypolicy_test.go @@ -5,7 +5,6 @@ package uapolicy import ( - "bytes" "crypto" "crypto/rand" "crypto/rsa" @@ -13,8 +12,7 @@ import ( "testing" "github.com/gopcua/opcua/ua" - - "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) func TestSupportedPolicies(t *testing.T) { @@ -24,33 +22,23 @@ func TestSupportedPolicies(t *testing.T) { want = append(want, k) } sort.Strings(want) - verify.Values(t, "", got, want) + require.Equal(t, want, got) } func TestGenerateKeysLength(t *testing.T) { localNonce := make([]byte, 32) remoteNonce := make([]byte, 32) _, err := rand.Read(localNonce) - if err != nil { - t.Fatalf("Could not generate local nonce") - } + require.NoError(t, err, "Could not generate local nonce") _, err = rand.Read(remoteNonce) - if err != nil { - t.Fatalf("Could not generate remote nonce") - } + require.NoError(t, err, "Could not generate remote nonce") hmac := &HMAC{Hash: crypto.SHA256, Secret: remoteNonce} keys := generateKeys(hmac, localNonce, 32, 32, 16) - if len(keys.signing) != 32 { - t.Errorf("Signing Key Invalid Length\n") - } - if len(keys.encryption) != 32 { - t.Errorf("Encryption Key Invalid Length\n") - } - if len(keys.iv) != 16 { - t.Errorf("Encryption IV Invalid Length\n") - } + require.Equal(t, 32, len(keys.signing), "Signing Key Invalid Length") + require.Equal(t, 32, len(keys.encryption), "Encryption Key Invalid Length") + require.Equal(t, 16, len(keys.iv), "Encryption IV Invalid Length") } func TestGenerateKeys(t *testing.T) { @@ -71,27 +59,15 @@ func TestGenerateKeys(t *testing.T) { localHmac := &HMAC{Hash: crypto.SHA1, Secret: localNonce} keys := generateKeys(localHmac, remoteNonce, 16, 16, 16) - if got, want := keys.signing, localKeys.signing; !bytes.Equal(got, want) { - t.Errorf("local signing key generation failed:\ngot %#v want %#v\n", got, want) - } - if got, want := keys.encryption, localKeys.encryption; !bytes.Equal(got, want) { - t.Errorf("local encryption key generation failed:\ngot %#v want %#v\n", got, want) - } - if got, want := keys.iv, localKeys.iv; !bytes.Equal(got, want) { - t.Errorf("local iv key generation failed:\ngot %#v want %#v\n", got, want) - } + require.Equal(t, localKeys.signing, keys.signing, "local signing key generation failed") + require.Equal(t, localKeys.encryption, keys.encryption, "local encryption key generation failed") + require.Equal(t, localKeys.iv, keys.iv, "local iv key generation failed") remoteHmac := &HMAC{Hash: crypto.SHA1, Secret: remoteNonce} keys = generateKeys(remoteHmac, localNonce, 16, 16, 16) - if got, want := keys.signing, remoteKeys.signing; !bytes.Equal(got, want) { - t.Errorf("remote signing key generation failed:\ngot %#v want %#v\n", got, want) - } - if got, want := keys.encryption, remoteKeys.encryption; !bytes.Equal(got, want) { - t.Errorf("remote encryption key generation failed:\ngot %#v want %#v\n", got, want) - } - if got, want := keys.iv, remoteKeys.iv; !bytes.Equal(got, want) { - t.Errorf("remote iv key generation failed:\ngot %#v want %#v\n", got, want) - } + require.Equal(t, remoteKeys.signing, keys.signing, "remote signing key generation failed") + require.Equal(t, remoteKeys.encryption, keys.encryption, "remote encryption key generation failed") + require.Equal(t, remoteKeys.iv, keys.iv, "remote iv key generation failed") } // Test all supported encryption algorithms. Because the majority of the algorithms @@ -102,9 +78,7 @@ func TestGenerateKeys(t *testing.T) { func TestEncryptionAlgorithms(t *testing.T) { payload := make([]byte, 5000) _, err := rand.Read(payload) - if err != nil { - t.Fatalf("could not generate random payload") - } + require.NoError(t, err, "could not generate random payload") payloadRef := make([]byte, len(payload)) copy(payloadRef, payload) @@ -113,20 +87,15 @@ func TestEncryptionAlgorithms(t *testing.T) { // This won't be the case forever and will be too small for future algorithms // and the test will need to be able to input keys of varying size localKey, err := generatePrivateKey(2048) - if err != nil { - t.Fatalf("Unable to generate local private key\n") - } + require.NoError(t, err, "Unable to generate local private key") + remoteKey, err := generatePrivateKey(2048) - if err != nil { - t.Fatalf("Unable to generate remote private key\n") - } + require.NoError(t, err, "Unable to generate remote private key") for uri, p := range policies { t.Run(uri, func(t *testing.T) { localAsymmetric, err := p.asymmetric(localKey, &remoteKey.PublicKey) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) makeNonce := func(n int) []byte { t.Helper() @@ -134,32 +103,23 @@ func TestEncryptionAlgorithms(t *testing.T) { return nil } b := make([]byte, n) - if _, err = rand.Read(b); err != nil { - t.Fatalf("could not generate nonce") - } + _, err = rand.Read(b) + require.NoError(t, err, "could not generate nonce") return b } nonceLength := localAsymmetric.NonceLength() localNonce, remoteNonce := makeNonce(nonceLength), makeNonce(nonceLength) - if nonceLength == 0 && uri != ua.SecurityPolicyURINone { - t.Fatalf("client nonce length zero") - } + require.False(t, nonceLength == 0 && uri != ua.SecurityPolicyURINone, "client nonce length zero") localSymmetric, err := p.symmetric(localNonce, remoteNonce) - if err != nil { - t.Fatalf("failed local Symmetric: %s", err) - } + require.NoError(t, err, "failed local Symmetric: %s", err) remoteSymmetric, err := p.symmetric(remoteNonce, localNonce) - if err != nil { - t.Fatalf("failed remote Symmetric: %s", err) - } + require.NoError(t, err, "failed remote Symmetric: %s", err) remoteAsymmetric, err := p.asymmetric(remoteKey, &localKey.PublicKey) - if err != nil { - t.Fatalf("failed remote Asymmetric: %s", err) - } + require.NoError(t, err, "failed remote Asymmetric: %s", err) // Symmetric Algorithm plaintext := make([]byte, len(payload)) @@ -173,64 +133,43 @@ func TestEncryptionAlgorithms(t *testing.T) { copy(paddedPlaintext, plaintext) symCiphertext, err := localSymmetric.Encrypt(paddedPlaintext) - if err != nil { - t.Fatalf("failed to encrypt Symmetric: %s", err) - } + require.NoError(t, err, "failed to encrypt Symmetric: %s", err) symDeciphered, err := remoteSymmetric.Decrypt(symCiphertext) - if err != nil { - t.Fatalf("failed to decrypt Symmetric: %s", err) - } + require.NoError(t, err, "failed to decrypt Symmetric: %s", err) + symDeciphered = symDeciphered[:len(symDeciphered)-padSize] // Trim off padding - if got, want := symDeciphered, plaintext; !bytes.Equal(got, want) { - t.Errorf("symmetric encryption failed:\ngot %#v want %#v\n", got, want) - } + require.Equal(t, plaintext, symDeciphered, "symmetric encryption failed") // Modify the plaintext and detect if the decrypted message changes; if it does, // our byte slices are referencing the same data and the previous test may have // been a false positive paddedPlaintext[4] = 0xff ^ paddedPlaintext[4] - if got, want := symDeciphered, payloadRef; !bytes.Equal(got, want) { - t.Errorf("symmetric input corruption detected:\ngot %#v want %#v\n", got, want) - } + require.Equal(t, payloadRef, symDeciphered, "symmetric input corruption detected") symSignature, err := localSymmetric.Signature(paddedPlaintext) - if err != nil { - t.Errorf("symmetric signature generation failed") - } + require.NoError(t, err, "symmetric signature generation failed") err = remoteSymmetric.VerifySignature(paddedPlaintext, symSignature) - if err != nil { - t.Errorf("symmetric signature validation failed") - } + require.NoError(t, err, "symmetric signature validation failed") // Asymmetric Algorithm asymCiphertext, err := localAsymmetric.Encrypt(plaintext) - if err != nil { - t.Fatalf("failed to encrypt Asymmetric: %s", err) - } + require.NoError(t, err, "failed to encrypt Asymmetric: %s", err) + asymDeciphered, err := remoteAsymmetric.Decrypt(asymCiphertext) - if err != nil { - t.Fatalf("failed to decrypt Asymmetric: %s", err) - } - if got, want := asymDeciphered, plaintext; !bytes.Equal(got, want) { - t.Errorf("asymmetric encryption failed:\ngot %#v want %#v\n", got, want) - } + require.NoError(t, err, "failed to decrypt Asymmetric: %s", err) + + require.Equal(t, plaintext, asymDeciphered, "asymmetric encryption failed") paddedPlaintext[4] = 0xff ^ paddedPlaintext[4] - if got, want := asymDeciphered, payloadRef; !bytes.Equal(got, want) { - t.Errorf("asymmetric input corruption detected:\ngot %#v want %#v\n", got, want) - } + require.Equal(t, payloadRef, asymDeciphered, "asymmetric input corruption detected") asymSignature, err := localAsymmetric.Signature(plaintext) - if err != nil { - t.Errorf("asymmetric signature generation failed\n") - } + require.NoError(t, err, "asymmetric signature generation failed") err = remoteAsymmetric.VerifySignature(plaintext, asymSignature) - if err != nil { - t.Errorf("asymmetric signature validation failed\n") - } + require.NoError(t, err, "asymmetric signature validation failed") }) } } @@ -238,17 +177,13 @@ func TestEncryptionAlgorithms(t *testing.T) { func TestMissingKey(t *testing.T) { payload := make([]byte, 5000) _, err := rand.Read(payload) - if err != nil { - t.Fatalf("could not generate random payload") - } + require.NoError(t, err, "could not generate random payload") payloadRef := make([]byte, len(payload)) copy(payloadRef, payload) key, err := generatePrivateKey(2048) - if err != nil { - t.Fatalf("Unable to generate private key\n") - } + require.NoError(t, err, "Unable to generate private key") for uri, p := range policies { if uri == ua.SecurityPolicyURINone { @@ -257,87 +192,58 @@ func TestMissingKey(t *testing.T) { t.Run(uri, func(t *testing.T) { encryptOnly, err := p.asymmetric(nil, &key.PublicKey) - if err != nil { - t.Fatalf("failed to create encrypt-only asymmetric algorithms: %s", err) - } + require.NoError(t, err, "failed to create encrypt-only asymmetric algorithms") decryptOnly, err := p.asymmetric(key, nil) - if err != nil { - t.Fatalf("failed to create decrypt-only asymmetric algorithms: %s", err) - } + require.NoError(t, err, "failed to create decrypt-only asymmetric algorithms") ciphertext, err := encryptOnly.Encrypt(payload) - if err != nil { - t.Fatalf("failed to encrypt with encrypt-only policy: %s", err) - } + require.NoError(t, err, "failed to encrypt with encrypt-only policy") signature, err := decryptOnly.Signature(payload) - if err != nil { - t.Fatalf("decrypt-only algorithm failed to generate signature: %s", err) - } + require.NoError(t, err, "decrypt-only algorithm failed to generate signature") err = encryptOnly.VerifySignature(payload, signature) - if err != nil { - t.Fatalf("failed to verify signature with encrypt-only algorithm: %s", err) - } + require.NoError(t, err, "failed to verify signature with encrypt-only algorithm") plaintext, err := decryptOnly.Decrypt(ciphertext) - if err != nil { - t.Fatalf("failed to decrypt with decrypt-only algorithm: %s", err) - } + require.NoError(t, err, "failed to decrypt with decrypt-only algorithm") - if got, want := plaintext, payloadRef; !bytes.Equal(got, want) { - t.Errorf("decryption failed:\ngot %#v want %#v\n", got, want) - } + require.Equal(t, payloadRef, plaintext, "decryption failed") _, err = encryptOnly.Decrypt(ciphertext) - if err == nil { - t.Fatal("encrypt-only algorithm decrypted block without error - should be impossible") - } + require.Error(t, err, "encrypt-only algorithm decrypted block without error - should be impossible") _, err = encryptOnly.Signature(payload) - if err == nil { - t.Fatalf("encrypt-only algorithm generated a signature without error - should be impossible") - } + require.Error(t, err, "encrypt-only algorithm generated a signature without error - should be impossible") _, err = decryptOnly.Encrypt(payload) - if err == nil { - t.Fatal("decrypt-only algorithm encrypted block without error - should be impossible") - } + require.Error(t, err, "decrypt-only algorithm encrypted block without error - should be impossible") err = decryptOnly.VerifySignature(payload, signature) - if err == nil { - t.Fatalf("decrypt-only algorithm verified a signature without error - should be impossible") - } - + require.Error(t, err, "decrypt-only algorithm verified a signature without error - should be impossible") }) } } func TestZeroStruct(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Error("panicked while checking zero value of struct", r) - } - }() - ze := &EncryptionAlgorithm{} const payload string = "The quick brown fox jumps over the lazy dog." plaintext := []byte(payload) // Call all the methods and make sure they don't panic due to nil pointers - _ = ze.BlockSize() - _ = ze.PlaintextBlockSize() - _, _ = ze.Encrypt(plaintext) - _, _ = ze.Decrypt(plaintext) - _, _ = ze.Signature(plaintext) - _ = ze.VerifySignature(plaintext, plaintext) - _ = ze.NonceLength() - _ = ze.SignatureLength() - _ = ze.EncryptionURI() - _ = ze.SignatureURI() + require.NotPanics(t, func() { ze.BlockSize() }) + require.NotPanics(t, func() { ze.PlaintextBlockSize() }) + require.NotPanics(t, func() { ze.Encrypt(plaintext) }) + require.NotPanics(t, func() { ze.Decrypt(plaintext) }) + require.NotPanics(t, func() { ze.Signature(plaintext) }) + require.NotPanics(t, func() { ze.VerifySignature(plaintext, plaintext) }) + require.NotPanics(t, func() { ze.NonceLength() }) + require.NotPanics(t, func() { ze.SignatureLength() }) + require.NotPanics(t, func() { ze.EncryptionURI() }) + require.NotPanics(t, func() { ze.SignatureURI() }) } func generatePrivateKey(bitSize int) (*rsa.PrivateKey, error) { diff --git a/uasc/codec_test.go b/uasc/codec_test.go index 18f94aee..4575640e 100644 --- a/uasc/codec_test.go +++ b/uasc/codec_test.go @@ -11,7 +11,7 @@ import ( "testing" "github.com/gopcua/opcua/ua" - "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) // CodecTestCase describes a test case for a encoding and decoding an @@ -39,26 +39,23 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { case reflect.Slice: v = reflect.New(typ) // typ: []x, v: *[]x default: - t.Fatalf("%T is not a pointer or a slice", c.Struct) + require.Fail(t, "%T is not a pointer or a slice", c.Struct) } - if _, err := ua.Decode(c.Bytes, v.Interface()); err != nil { - t.Fatal(err) - } + _, err := ua.Decode(c.Bytes, v.Interface()) + require.NoError(t, err, "Decode failed") // if v is a *[]x we need to dereference it before comparing it. if typ.Kind() == reflect.Slice { v = v.Elem() } - verify.Values(t, "", v.Interface(), c.Struct) + require.Equal(t, c.Struct, v.Interface(), "Decoded payload not equal") }) t.Run("encode", func(t *testing.T) { b, err := ua.Encode(c.Struct) - if err != nil { - t.Fatal(err) - } - verify.Values(t, "", b, c.Bytes) + require.NoError(t, err, "Encode failed") + require.Equal(t, c.Bytes, b, "Encoded payload not equal") }) }) } diff --git a/uasc/secure_channel_test.go b/uasc/secure_channel_test.go index 0065e739..3918a34f 100644 --- a/uasc/secure_channel_test.go +++ b/uasc/secure_channel_test.go @@ -1,13 +1,11 @@ package uasc import ( - "bytes" "crypto/rsa" "crypto/x509" "encoding/pem" "fmt" "math" - "strings" "testing" "time" @@ -16,7 +14,7 @@ import ( "github.com/gopcua/opcua/ua" "github.com/gopcua/opcua/uacp" "github.com/gopcua/opcua/uapolicy" - "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) func TestNewRequestMessage(t *testing.T) { @@ -146,10 +144,8 @@ func TestNewRequestMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m, err := tt.sechan.activeInstance.newRequestMessage(tt.req, tt.sechan.nextRequestID(), tt.authToken, tt.timeout) - if err != nil { - t.Fatal(err) - } - verify.Values(t, "", m, tt.m) + require.NoError(t, err) + require.Equal(t, tt.m, m) }) } } @@ -159,21 +155,15 @@ func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) { t.Helper() certPEM, keyPEM, err := uatest.GenerateCert("localhost", bits, 24*time.Hour) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) block, _ := pem.Decode(keyPEM) pk, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) certblock, _ := pem.Decode(certPEM) remoteX509Cert, err := x509.ParseCertificate(certblock.Bytes) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) remoteKey := remoteX509Cert.PublicKey.(*rsa.PublicKey) alg, _ := uapolicy.Asymmetric(uri, pk, remoteKey) @@ -293,23 +283,17 @@ func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cipher, err := tt.c.signAndEncrypt(tt.m, tt.b) - if err != nil { - t.Fatalf("error: message encrypt: %v", err) - } + require.NoError(t, err, "error: message encrypt") m := new(MessageChunk) - if _, err := m.Decode(cipher); err != nil { - t.Fatalf("error: message decode: %v", err) - } + _, err = m.Decode(cipher) + require.NoError(t, err, "error: message decode") + plain, err := tt.c.verifyAndDecrypt(m, cipher) - if err != nil { - t.Fatalf("error: message decrypt: %v", err) - } + require.NoError(t, err, "error: message decrypt") headerLength := 12 + m.AsymmetricSecurityHeader.Len() - if got, want := plain, tt.b[headerLength:]; !bytes.Equal(got, want) { - t.Fatalf("got bytes %v want %v", got, want) - } + require.Equal(t, tt.b[headerLength:], plain, "header not equal") }) } } @@ -317,39 +301,29 @@ func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) { func TestNewSecureChannel(t *testing.T) { t.Run("no connection", func(t *testing.T) { _, err := NewSecureChannel("", nil, nil, nil) - errorContains(t, err, "no connection") + require.ErrorContains(t, err, "no connection") }) t.Run("no error channel", func(t *testing.T) { _, err := NewSecureChannel("", &uacp.Conn{}, nil, nil) - errorContains(t, err, "no secure channel config") + require.ErrorContains(t, err, "no secure channel config") }) t.Run("no config", func(t *testing.T) { _, err := NewSecureChannel("", &uacp.Conn{}, nil, make(chan error)) - errorContains(t, err, "no secure channel config") + require.ErrorContains(t, err, "no secure channel config") }) t.Run("uri none, mode not none", func(t *testing.T) { cfg := &Config{SecurityPolicyURI: ua.SecurityPolicyURINone, SecurityMode: ua.MessageSecurityModeSign} _, err := NewSecureChannel("", &uacp.Conn{}, cfg, make(chan error)) - errorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#None' cannot be used with 'MessageSecurityModeSign'") + require.ErrorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#None' cannot be used with 'MessageSecurityModeSign'") }) t.Run("uri not none, mode none", func(t *testing.T) { cfg := &Config{SecurityPolicyURI: ua.SecurityPolicyURIBasic256, SecurityMode: ua.MessageSecurityModeNone} _, err := NewSecureChannel("", &uacp.Conn{}, cfg, make(chan error)) - errorContains(t, err, "Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' cannot be used with 'MessageSecurityModeNone'") + require.ErrorContains(t, err, "Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' cannot be used with 'MessageSecurityModeNone'") }) t.Run("uri not none, local key missing", func(t *testing.T) { cfg := &Config{SecurityPolicyURI: ua.SecurityPolicyURIBasic256, SecurityMode: ua.MessageSecurityModeSign} _, err := NewSecureChannel("", &uacp.Conn{}, cfg, make(chan error)) - errorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' requires a private key") + require.ErrorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' requires a private key") }) } - -func errorContains(t *testing.T, err error, msg string) { - t.Helper() - if err == nil { - t.Fatal("expected an error but got nil") - } - if !strings.Contains(err.Error(), msg) { - t.Fatalf("error '%s' does not contain '%s'", err, msg) - } -}