From 783fd928985bd7871016611d264ed7a0604dc751 Mon Sep 17 00:00:00 2001 From: Keisuke Kanao Date: Thu, 23 May 2024 16:36:53 +0900 Subject: [PATCH] [BLOCK-2310] Modified to catch panic from json.Marshal function --- eosws/marshal_test.go | 34 ++++++++++++++++++++++++++++++++++ eosws/wsconn.go | 15 +++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 eosws/marshal_test.go diff --git a/eosws/marshal_test.go b/eosws/marshal_test.go new file mode 100644 index 00000000..e23e40de --- /dev/null +++ b/eosws/marshal_test.go @@ -0,0 +1,34 @@ +package eosws + +import ( + "fmt" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/dfuse-io/dfuse-eosio/eosws/wsmsg" +) + +// TestMarshalOutgoingMessager is a code snippet extracted from eosws.(*WSConn).Emit() function. +// Here, SetReqID() is called to set an invalid string, because another function for OutgoingMessager interface, SetType() is called +// and the result is validated in Emit(). +func TestMarshalOutgoingMessager(t *testing.T) { + var err error + defer func() { + if r := recover(); r != nil { + switch v := r.(type) { + case error: + err = fmt.Errorf("unexpected error marshalling message: %w", v) + case string, fmt.Stringer: + err = fmt.Errorf("unexpected error marshalling message: %s", v) + default: + err = fmt.Errorf("unexpected error marshalling message: %v", v) + } + fmt.Printf("%s\n", err) + assert.Error(t, err) + } + }() + var msg wsmsg.OutgoingMessager + msg.SetReqID("") + _, err = json.Marshal(msg) +} diff --git a/eosws/wsconn.go b/eosws/wsconn.go index 3f3983ff..4c751293 100644 --- a/eosws/wsconn.go +++ b/eosws/wsconn.go @@ -307,6 +307,21 @@ func (ws *WSConn) handleMessage(rawMsg []byte) { } func (ws *WSConn) Emit(ctx context.Context, msg wsmsg.OutgoingMessager) { + var err error + defer func() { + if r := recover(); r != nil { + switch v := r.(type) { + case error: + err = fmt.Errorf("unexpected error marshalling message: %w", v) + case string, fmt.Stringer: + err = fmt.Errorf("unexpected error marshalling message: %s", v) + default: + err = fmt.Errorf("unexpected error marshalling message: %v", v) + } + ws.Shutdown(err) + } + }() + zlogger := logging.Logger(ctx, zlog) msgType, err := wsmsg.GetType(msg)