From 0186dcb43fa0c17cb8b655d4cef71e3bfa33e701 Mon Sep 17 00:00:00 2001 From: werben Date: Tue, 19 Sep 2023 16:22:13 +0800 Subject: [PATCH] Upgrade MQTT to mochi-mqtt v2.4.0 --- mqtt/clients.go | 33 +- mqtt/clients_test.go | 60 +- mqtt/cmd/main.go | 9 +- mqtt/examples/auth/basic/main.go | 8 +- mqtt/examples/auth/encoded/main.go | 8 +- mqtt/examples/benchmark/main.go | 8 +- mqtt/examples/debug/main.go | 18 +- mqtt/examples/direct/main.go | 83 +++ mqtt/examples/hooks/main.go | 38 +- mqtt/examples/paho.testing/main.go | 8 +- mqtt/examples/persistence/badger/main.go | 9 +- mqtt/examples/persistence/bolt/main.go | 13 +- mqtt/examples/persistence/redis/main.go | 19 +- mqtt/examples/tcp/main.go | 8 +- mqtt/examples/tls/main.go | 8 +- mqtt/examples/websocket/main.go | 8 +- mqtt/hooks.go | 64 +- mqtt/hooks/auth/allow_all.go | 2 +- mqtt/hooks/auth/allow_all_test.go | 2 +- mqtt/hooks/auth/auth.go | 26 +- mqtt/hooks/auth/auth_test.go | 23 +- mqtt/hooks/auth/ledger.go | 38 +- mqtt/hooks/auth/ledger_test.go | 8 +- mqtt/hooks/debug/debug.go | 70 +- mqtt/hooks/storage/badger/badger.go | 70 +- mqtt/hooks/storage/badger/badger_test.go | 114 +-- mqtt/hooks/storage/bolt/bolt.go | 79 +-- mqtt/hooks/storage/bolt/bolt_test.go | 106 +-- mqtt/hooks/storage/redis/redis.go | 95 ++- mqtt/hooks/storage/redis/redis_test.go | 18 +- mqtt/hooks/storage/storage.go | 6 +- mqtt/hooks/storage/storage_test.go | 2 +- mqtt/hooks_test.go | 28 +- mqtt/inflight.go | 12 +- mqtt/inflight_test.go | 2 +- mqtt/listeners/http_healthcheck.go | 100 +++ mqtt/listeners/http_healthcheck_test.go | 143 ++++ mqtt/listeners/http_sysinfo.go | 49 +- mqtt/listeners/http_sysinfo_test.go | 10 +- mqtt/listeners/listeners.go | 18 +- mqtt/listeners/listeners_test.go | 7 +- mqtt/listeners/mock.go | 6 +- mqtt/listeners/mock_test.go | 6 +- mqtt/listeners/net.go | 92 +++ mqtt/listeners/net_test.go | 105 +++ mqtt/listeners/tcp.go | 20 +- mqtt/listeners/tcp_test.go | 34 +- mqtt/listeners/unixsock.go | 20 +- mqtt/listeners/unixsock_test.go | 12 +- mqtt/listeners/websocket.go | 21 +- mqtt/listeners/websocket_test.go | 20 +- mqtt/packets/codec.go | 2 +- mqtt/packets/codec_test.go | 2 +- mqtt/packets/codes.go | 3 +- mqtt/packets/codes_test.go | 4 +- mqtt/packets/fixedheader.go | 2 +- mqtt/packets/fixedheader_test.go | 2 +- mqtt/packets/packets.go | 48 +- mqtt/packets/packets_test.go | 6 +- mqtt/packets/properties.go | 6 +- mqtt/packets/properties_test.go | 2 +- mqtt/packets/tpackets.go | 76 +- mqtt/packets/tpackets_test.go | 2 +- mqtt/server.go | 265 ++++--- mqtt/server_test.go | 853 ++++++++++++++++++----- mqtt/system/system.go | 2 +- mqtt/topics.go | 155 +++- mqtt/topics_test.go | 227 +++++- 68 files changed, 2469 insertions(+), 954 deletions(-) create mode 100644 mqtt/examples/direct/main.go create mode 100644 mqtt/listeners/http_healthcheck.go create mode 100644 mqtt/listeners/http_healthcheck_test.go create mode 100644 mqtt/listeners/net.go create mode 100644 mqtt/listeners/net_test.go diff --git a/mqtt/clients.go b/mqtt/clients.go index 39ddbd4..3497df5 100644 --- a/mqtt/clients.go +++ b/mqtt/clients.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co, wind package mqtt @@ -104,13 +104,13 @@ func (cl *Clients) GetByListener(id string) []*Client { // Client contains information about a client known by the broker. type Client struct { - ops *ops // ops provides a reference to server ops. + Properties ClientProperties // client properties State ClientState // the operational state of the client. + Net ClientConnection // network connection state of the client ID string // the client id. - Net ClientConnection // network connection state of the clinet - Properties ClientProperties // client properties - InheritWay int // session inheritance way + ops *ops // ops provides a reference to server ops. sync.RWMutex // mutex + InheritWay int // session inheritance way } // ClientConnection contains the connection transport and metadata for the client. @@ -119,42 +119,42 @@ type ClientConnection struct { bconn *bufio.ReadWriter // a buffered net.Conn for reading packets Remote string // the remote address of the client Listener string // listener id of the client - Inline bool // client is an inline programmetic client + Inline bool // if true, the client is the built-in 'inline' embedded client } // ClientProperties contains the properties which define the client behaviour. type ClientProperties struct { - Username []byte - Will Will Props packets.Properties + Will Will + Username []byte ProtocolVersion byte Clean bool } // Will contains the last will and testament details for a client connection. type Will struct { - TopicName string // - Payload []byte // - User []packets.UserProperty // - + TopicName string // - Flag uint32 // 0,1 WillDelayInterval uint32 // - Qos byte // - Retain bool // - } -// State tracks the state of the client. +// ClientState tracks the state of the client. type ClientState struct { TopicAliases TopicAliases // a map of topic aliases stopCause atomic.Value // reason for stopping - open context.Context // indicate that the client is open for packet exchange - Subscriptions *Subscriptions // a map of the subscription filters a client maintains - outbound chan *packets.Packet // queue for pending outbound packets Inflight *Inflight // a map of in-flight qos messages - cancelOpen context.CancelFunc // cancel function for open context + Subscriptions *Subscriptions // a map of the subscription filters a client maintains disconnected int64 // the time the client disconnected in unix time, for calculating expiry + outbound chan *packets.Packet // queue for pending outbound packets endOnce sync.Once // only end once isTakenOver uint32 // used to identify orphaned clients packetID uint32 // the current highest packetID + open context.Context // indicate that the client is open for packet exchange + cancelOpen context.CancelFunc // cancel function for open context outboundQty int32 // number of messages currently in the outbound queue Keepalive uint16 // the number of seconds the connection can wait ServerKeepalive bool // keepalive was set by the server @@ -200,7 +200,8 @@ func (cl *Client) WriteLoop() { select { case pk := <-cl.State.outbound: if err := cl.WritePacket(*pk); err != nil { - cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet") + // TODO : Figure out what to do with error + cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk) } atomic.AddInt32(&cl.State.outboundQty, -1) case <-cl.State.open.Done(): @@ -318,7 +319,7 @@ func (cl *Client) ResendInflightMessages(force bool) error { return nil } -// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session. +// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session. func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 { deleted := []uint16{} for _, tk := range cl.State.Inflight.GetAll(false) { diff --git a/mqtt/clients_test.go b/mqtt/clients_test.go index a6748f1..957a9ba 100644 --- a/mqtt/clients_test.go +++ b/mqtt/clients_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -29,7 +29,7 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) { cl = newClient(w, &ops{ info: new(system.Info), hooks: new(Hooks), - log: &logger, + log: logger, options: &Options{ Capabilities: &Capabilities{ ReceiveMaximum: 10, @@ -263,7 +263,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) { cl.State.Inflight.internal[uint16(i)] = packets.Packet{} } - cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1) + cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1 i, err := cl.NextPacketID() require.NoError(t, err) require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i) @@ -303,7 +303,7 @@ func TestClientResendInflightMessages(t *testing.T) { err := cl.ResendInflightMessages(true) require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -315,7 +315,7 @@ func TestClientResendInflightMessages(t *testing.T) { func TestClientResendInflightMessagesWriteFailure(t *testing.T) { pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) cl, r, _ := newTestClient() - r.Close() + _ = r.Close() cl.State.Inflight.Set(*pk1.Packet) require.Equal(t, 1, cl.State.Inflight.Len()) @@ -342,8 +342,8 @@ func TestClientReadFixedHeader(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect << 4, 0x00}) - r.Close() + _, _ = r.Write([]byte{packets.Connect << 4, 0x00}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -357,8 +357,8 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) - r.Close() + _, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -372,8 +372,8 @@ func TestClientReadFixedHeaderPacketOversized(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -387,7 +387,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Close() + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -401,8 +401,8 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) - r.Close() + _, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -414,7 +414,7 @@ func TestClientReadOK(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 18, // Fixed header 0, 5, // Topic Name - LSB+MSB 'a', '/', 'b', '/', 'c', // Topic Name @@ -424,7 +424,7 @@ func TestClientReadOK(t *testing.T) { 'd', '/', 'e', '/', 'f', // Topic Name 'y', 'e', 'a', 'h', // Payload }) - r.Close() + _ = r.Close() }() var pks []packets.Packet @@ -499,10 +499,10 @@ func TestClientReadFixedHeaderError(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header }) - r.Close() + _ = r.Close() }() cl.Net.bconn = nil @@ -516,13 +516,13 @@ func TestClientReadReadHandlerErr(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header 0, 5, // Topic Name - LSB+MSB 'd', '/', 'e', '/', 'f', // Topic Name 'y', 'e', 'a', 'h', // Payload }) - r.Close() + _ = r.Close() }() err := cl.Read(func(cl *Client, pk packets.Packet) error { @@ -536,13 +536,13 @@ func TestClientReadReadPacketOK(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -573,7 +573,7 @@ func TestClientReadPacket(t *testing.T) { t.Run(tt.Desc, func(t *testing.T) { atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0) go func() { - r.Write(tt.RawBytes) + _, _ = r.Write(tt.RawBytes) }() fh := new(packets.FixedHeader) @@ -600,7 +600,7 @@ func TestClientReadPacket(t *testing.T) { func TestClientReadPacketInvalidTypeError(t *testing.T) { cl, _, _ := newTestClient() - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() _, err := cl.ReadPacket(&packets.FixedHeader{}) require.Error(t, err) require.Contains(t, err.Error(), "invalid packet type") @@ -624,7 +624,7 @@ func TestClientWritePacket(t *testing.T) { require.NoError(t, err, pkInfo, tt.Case, tt.Desc) time.Sleep(2 * time.Millisecond) - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) @@ -660,13 +660,13 @@ func TestClientReadPacketReadingError(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ 0, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() _, err := cl.ReadPacket(&packets.FixedHeader{ @@ -680,13 +680,13 @@ func TestClientReadPacketReadUnknown(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ 0, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() _, err := cl.ReadPacket(&packets.FixedHeader{ @@ -706,7 +706,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) { func TestClientWritePacketWriteError(t *testing.T) { cl, _, _ := newTestClient() - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() err := cl.WritePacket(*pkTable[1].Packet) require.Error(t, err) diff --git a/mqtt/cmd/main.go b/mqtt/cmd/main.go index 6435940..050fcc4 100644 --- a/mqtt/cmd/main.go +++ b/mqtt/cmd/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -68,7 +68,8 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") + } diff --git a/mqtt/examples/auth/basic/main.go b/mqtt/examples/auth/basic/main.go index 1946131..84be23f 100644 --- a/mqtt/examples/auth/basic/main.go +++ b/mqtt/examples/auth/basic/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -77,7 +77,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/auth/encoded/main.go b/mqtt/examples/auth/encoded/main.go index 5a708cc..8d0095f 100644 --- a/mqtt/examples/auth/encoded/main.go +++ b/mqtt/examples/auth/encoded/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -59,7 +59,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/benchmark/main.go b/mqtt/examples/benchmark/main.go index b30fe45..b1accd2 100644 --- a/mqtt/examples/benchmark/main.go +++ b/mqtt/examples/benchmark/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -45,7 +45,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/debug/main.go b/mqtt/examples/debug/main.go index 798d1bc..6c05f58 100644 --- a/mqtt/examples/debug/main.go +++ b/mqtt/examples/debug/main.go @@ -1,16 +1,16 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main import ( "log" + "log/slog" "os" "os/signal" "syscall" - "github.com/rs/zerolog" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/hooks/debug" @@ -27,8 +27,12 @@ func main() { }() server := mqtt.New(nil) - l := server.Log.Level(zerolog.DebugLevel) - server.Log = &l + + level := new(slog.LevelVar) + server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: level, + })) + level.Set(slog.LevelDebug) err := server.AddHook(new(debug.Hook), &debug.Options{ // ShowPacketData: true, @@ -56,7 +60,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/direct/main.go b/mqtt/examples/direct/main.go new file mode 100644 index 0000000..885a338 --- /dev/null +++ b/mqtt/examples/direct/main.go @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" + + mqtt "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +func main() { + sigs := make(chan os.Signal, 1) + done := make(chan bool, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigs + done <- true + }() + + server := mqtt.New(&mqtt.Options{ + InlineClient: true, // you must enable inline client to use direct publishing and subscribing. + }) + _ = server.AddHook(new(auth.AllowHook), nil) + + // Start the server + go func() { + err := server.Serve() + if err != nil { + log.Fatal(err) + } + }() + + // Demonstration of using an inline client to directly subscribe to a topic and receive a message when + // that subscription is activated. The inline subscription method uses the same internal subscription logic + // as used for external (normal) clients. + go func() { + // Inline subscriptions can also receive retained messages on subscription. + _ = server.Publish("direct/retained", []byte("retained message"), true, 0) + _ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0) + + // Subscribe to a filter and handle any received messages via a callback function. + callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) { + server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload)) + } + server.Log.Info("inline client subscribing") + _ = server.Subscribe("direct/#", 1, callbackFn) + _ = server.Subscribe("direct/#", 2, callbackFn) + }() + + // There is a shorthand convenience function, Publish, for easily sending publish packets if you are not + // concerned with creating your own packets. If you want to have more control over your packets, you can + //directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage. + go func() { + for range time.Tick(time.Second * 3) { + err := server.Publish("direct/publish", []byte("scheduled message"), false, 0) + if err != nil { + server.Log.Error("server.Publish", "error", err) + } + server.Log.Info("main.go issued direct message to direct/publish") + } + }() + + go func() { + time.Sleep(time.Second * 10) + // Unsubscribe from the same filter to stop receiving messages. + server.Log.Info("inline client unsubscribing") + _ = server.Unsubscribe("direct/#", 1) + }() + + <-done + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") +} diff --git a/mqtt/examples/hooks/main.go b/mqtt/examples/hooks/main.go index f4be075..f5835d7 100644 --- a/mqtt/examples/hooks/main.go +++ b/mqtt/examples/hooks/main.go @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main import ( "bytes" + "fmt" "log" "os" "os/signal" @@ -62,9 +63,9 @@ func main() { Payload: []byte("injected scheduled message"), }) if err != nil { - server.Log.Error().Err(err).Msg("server.InjectPacket") + server.Log.Error("server.InjectPacket", "error", err) } - server.Log.Info().Msgf("main.go injected packet to direct/publish") + server.Log.Info("main.go injected packet to direct/publish") } }() @@ -74,16 +75,16 @@ func main() { for range time.Tick(time.Second * 5) { err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) if err != nil { - server.Log.Error().Err(err).Msg("server.Publish") + server.Log.Error("server.Publish", "error", err) } - server.Log.Info().Msgf("main.go issued direct message to direct/publish") + server.Log.Info("main.go issued direct message to direct/publish") } }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } type ExampleHook struct { @@ -106,39 +107,44 @@ func (h *ExampleHook) Provides(b byte) bool { } func (h *ExampleHook) Init(config any) error { - h.Log.Info().Msg("initialised") + h.Log.Info("initialised") return nil } func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error { - h.Log.Info().Str("client", cl.ID).Msgf("client connected") + h.Log.Info("client connected", "client", cl.ID) return nil } func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) { - h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected") + if err != nil { + h.Log.Info("client disconnected", "client", cl.ID, "expire", expire, "error", err) + } else { + h.Log.Info("client disconnected", "client", cl.ID, "expire", expire) + } + } func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { - h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes) + h.Log.Info(fmt.Sprintf("subscribed qos=%v", reasonCodes), "client", cl.ID, "filters", pk.Filters) } func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { - h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed") + h.Log.Info("unsubscribed", "client", cl.ID, "filters", pk.Filters) } func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { - h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client") + h.Log.Info("received from client", "client", cl.ID, "payload", string(pk.Payload)) pkx := pk if string(pk.Payload) == "hello" { pkx.Payload = []byte("hello world") - h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client") + h.Log.Info("received modified packet from client", "client", cl.ID, "payload", string(pkx.Payload)) } return pkx, nil } func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) { - h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client") + h.Log.Info("published to client", "client", cl.ID, "payload", string(pk.Payload)) } diff --git a/mqtt/examples/paho.testing/main.go b/mqtt/examples/paho.testing/main.go index 03694f4..66ce193 100644 --- a/mqtt/examples/paho.testing/main.go +++ b/mqtt/examples/paho.testing/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -45,9 +45,9 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } type pahoAuthHook struct { diff --git a/mqtt/examples/persistence/badger/main.go b/mqtt/examples/persistence/badger/main.go index a27fd09..c3a6e38 100644 --- a/mqtt/examples/persistence/badger/main.go +++ b/mqtt/examples/persistence/badger/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -52,8 +52,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") - + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/persistence/bolt/main.go b/mqtt/examples/persistence/bolt/main.go index 2b2c973..685d59a 100644 --- a/mqtt/examples/persistence/bolt/main.go +++ b/mqtt/examples/persistence/bolt/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -30,12 +30,15 @@ func main() { server := mqtt.New(nil) _ = server.AddHook(new(auth.AllowHook), nil) - err := server.AddHook(new(bolt.Hook), bolt.Options{ + err := server.AddHook(new(bolt.Hook), &bolt.Options{ Path: "bolt.db", Options: &bbolt.Options{ Timeout: 500 * time.Millisecond, }, }) + if err != nil { + log.Fatal(err) + } tcp := listeners.NewTCP("t1", ":1883", nil) err = server.AddListener(tcp) @@ -51,7 +54,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/persistence/redis/main.go b/mqtt/examples/persistence/redis/main.go index 4d0e14b..53d8926 100644 --- a/mqtt/examples/persistence/redis/main.go +++ b/mqtt/examples/persistence/redis/main.go @@ -1,16 +1,16 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main import ( "log" + "log/slog" "os" "os/signal" "syscall" - "github.com/rs/zerolog" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage/redis" @@ -30,8 +30,12 @@ func main() { server := mqtt.New(nil) _ = server.AddHook(new(auth.AllowHook), nil) - l := server.Log.Level(zerolog.DebugLevel) - server.Log = &l + + level := new(slog.LevelVar) + server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: level, + })) + level.Set(slog.LevelDebug) err := server.AddHook(new(redis.Hook), &redis.Options{ Options: &rv8.Options{ @@ -58,8 +62,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") - + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/tcp/main.go b/mqtt/examples/tcp/main.go index 1060b95..8a51a70 100644 --- a/mqtt/examples/tcp/main.go +++ b/mqtt/examples/tcp/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -52,7 +52,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/tls/main.go b/mqtt/examples/tls/main.go index 52a4b62..837158e 100644 --- a/mqtt/examples/tls/main.go +++ b/mqtt/examples/tls/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -111,7 +111,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/websocket/main.go b/mqtt/examples/websocket/main.go index 1a48d1e..1757bcc 100644 --- a/mqtt/examples/websocket/main.go +++ b/mqtt/examples/websocket/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -41,7 +41,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/hooks.go b/mqtt/hooks.go index e6f8a8b..9746675 100644 --- a/mqtt/hooks.go +++ b/mqtt/hooks.go @@ -1,20 +1,19 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co -// SPDX-FileContributor: mochi-co, wind +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, wind, thedevop, dgduncan package mqtt import ( "errors" "fmt" + "log/slog" "sync" "sync/atomic" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/mqtt/system" - - "github.com/rs/zerolog" ) const ( @@ -74,7 +73,7 @@ type Hook interface { Provides(b byte) bool Init(config any) error Stop() error - SetOpts(l *zerolog.Logger, o *HookOptions) + SetOpts(l *slog.Logger, o *HookOptions) OnStarted() OnStopped() OnConnectAuthenticate(cl *Client, pk packets.Packet) bool @@ -125,11 +124,11 @@ type HookOptions struct { // Hooks is a slice of Hook interfaces to be called in sequence. type Hooks struct { - Log *zerolog.Logger // a logger for the hook (from the server) - internal atomic.Value // a slice of []Hook - wg sync.WaitGroup // a waitgroup for syncing hook shutdown - qty int64 // the number of hooks in use - sync.Mutex // a mutex for locking when adding hooks + Log *slog.Logger // a logger for the hook (from the server) + internal atomic.Value // a slice of []Hook + wg sync.WaitGroup // a waitgroup for syncing hook shutdown + qty int64 // the number of hooks in use + sync.Mutex // a mutex for locking when adding hooks } // Len returns the number of hooks added. @@ -187,9 +186,9 @@ func (h *Hooks) GetAll() []Hook { func (h *Hooks) Stop() { go func() { for _, hook := range h.GetAll() { - h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook") + h.Log.Info("stopping hook", "hook", hook.ID()) if err := hook.Stop(); err != nil { - h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook") + h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID()) } h.wg.Done() @@ -274,7 +273,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, if hook.Provides(OnPacketRead) { npk, err := hook.OnPacketRead(cl, pkx) if err != nil && errors.Is(err, packets.ErrRejectPacket) { - h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected") + h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx) return pk, err } else if err != nil { continue @@ -402,10 +401,16 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er npk, err := hook.OnPublish(cl, pkx) if err != nil { if errors.Is(err, packets.ErrRejectPacket) { - h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected") + h.Log.Debug("publish packet rejected", + "error", err, + "hook", hook.ID(), + "packet", pkx) return pk, err } - h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet error") + h.Log.Error("publish packet error", + "error", err, + "hook", hook.ID(), + "packet", pkx) return pk, err } pkx = npk @@ -504,7 +509,10 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will { if hook.Provides(OnWill) { mlwt, err := hook.OnWill(cl, will) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error") + h.Log.Error("parse will error", + "error", err, + "hook", hook.ID(), + "will", will) continue } will = mlwt @@ -548,7 +556,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) { if hook.Provides(StoredClients) { v, err := hook.StoredClients() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients") + h.Log.Error("failed to load clients", "error", err, "hook", hook.ID()) return v, err } @@ -568,7 +576,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { if hook.Provides(StoredSubscriptions) { v, err := hook.StoredSubscriptions() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions") + h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID()) return v, err } @@ -588,7 +596,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { if hook.Provides(StoredInflightMessages) { v, err := hook.StoredInflightMessages() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages") + h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID()) return v, err } @@ -608,7 +616,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { if hook.Provides(StoredRetainedMessages) { v, err := hook.StoredRetainedMessages() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages") + h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID()) return v, err } @@ -627,7 +635,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { if hook.Provides(StoredSysInfo) { v, err := hook.StoredSysInfo() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info") + h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID()) return v, err } @@ -646,7 +654,7 @@ func (h *Hooks) StoredClientByCid(cid string) (v storage.Client, err error) { if hook.Provides(StoredClientByCid) { v, err := hook.StoredClientByCid(cid) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients") + h.Log.Error("failed to load clients", "error", err, "hook", hook.ID()) return v, err } @@ -665,7 +673,7 @@ func (h *Hooks) StoredSubscriptionsByCid(cid string) (v []storage.Subscription, if hook.Provides(StoredSubscriptionsByCid) { v, err := hook.StoredSubscriptionsByCid(cid) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to get subscriptions") + h.Log.Error("failed to get subscriptions", "error", err, "hook", hook.ID()) return v, err } @@ -684,7 +692,7 @@ func (h *Hooks) StoredInflightMessagesByCid(cid string) (v []storage.Message, er if hook.Provides(StoredInflightMessagesByCid) { v, err := hook.StoredInflightMessagesByCid(cid) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to get inflight messages") + h.Log.Error("failed to get inflight messages", "error", err, "hook", hook.ID()) return v, err } @@ -703,7 +711,7 @@ func (h *Hooks) StoredRetainedMessageByTopic(topic string) (v storage.Message, e if hook.Provides(StoredRetainedMessageByTopic) { v, err := hook.StoredRetainedMessageByTopic(topic) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to get retained message") + h.Log.Error("failed to get retained message", "error", err, "hook", hook.ID()) return v, err } @@ -752,7 +760,7 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { // all hooks. type HookBase struct { Hook - Log *zerolog.Logger + Log *slog.Logger Opts *HookOptions } @@ -775,12 +783,12 @@ func (h *HookBase) Init(config any) error { // SetOpts is called by the server to propagate internal values and generally should // not be called manually. -func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) { +func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) { h.Log = l h.Opts = opts } -// Stop is called to gracefully shutdown the hook. +// Stop is called to gracefully shut down the hook. func (h *HookBase) Stop() error { return nil } diff --git a/mqtt/hooks/auth/allow_all.go b/mqtt/hooks/auth/allow_all.go index dd059ce..1e8e71c 100644 --- a/mqtt/hooks/auth/allow_all.go +++ b/mqtt/hooks/auth/allow_all.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth diff --git a/mqtt/hooks/auth/allow_all_test.go b/mqtt/hooks/auth/allow_all_test.go index 4815365..90fd48e 100644 --- a/mqtt/hooks/auth/allow_all_test.go +++ b/mqtt/hooks/auth/allow_all_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth diff --git a/mqtt/hooks/auth/auth.go b/mqtt/hooks/auth/auth.go index a093d2b..c4eae8d 100644 --- a/mqtt/hooks/auth/auth.go +++ b/mqtt/hooks/auth/auth.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth @@ -67,10 +67,9 @@ func (h *Hook) Init(config any) error { } } - h.Log.Info(). - Int("authentication", len(h.ledger.Auth)). - Int("acl", len(h.ledger.ACL)). - Msg("loaded auth rules") + h.Log.Info("loaded auth rules", + "authentication", len(h.ledger.Auth), + "acl", len(h.ledger.ACL)) return nil } @@ -82,11 +81,9 @@ func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { return true } - h.Log.Info(). - Str("username", string(pk.Connect.Username)). - Str("remote", cl.Net.Remote). - Msg("client failed authentication check") - + h.Log.Info("client failed authentication check", + "username", string(pk.Connect.Username), + "remote", cl.Net.Remote) return false } @@ -97,11 +94,10 @@ func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { return true } - h.Log.Debug(). - Str("client", cl.ID). - Str("username", string(cl.Properties.Username)). - Str("topic", topic). - Msg("client failed allowed ACL check") + h.Log.Debug("client failed allowed ACL check", + "client", cl.ID, + "username", string(cl.Properties.Username), + "topic", topic) return false } diff --git a/mqtt/hooks/auth/auth_test.go b/mqtt/hooks/auth/auth_test.go index 60059e5..ee6bd09 100644 --- a/mqtt/hooks/auth/auth_test.go +++ b/mqtt/hooks/auth/auth_test.go @@ -1,20 +1,19 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth import ( + "log/slog" "os" "testing" - - "github.com/rs/zerolog" "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/packets" ) -var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) +var logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) // func teardown(t *testing.T, path string, h *Hook) { // h.Stop() @@ -34,7 +33,7 @@ func TestBasicProvides(t *testing.T) { func TestBasicInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -42,7 +41,7 @@ func TestBasicInitBadConfig(t *testing.T) { func TestBasicInitDefaultConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) @@ -50,7 +49,7 @@ func TestBasicInitDefaultConfig(t *testing.T) { func TestBasicInitWithLedgerPointer(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) ln := &Ledger{ Auth: []AuthRule{ @@ -79,7 +78,7 @@ func TestBasicInitWithLedgerPointer(t *testing.T) { func TestBasicInitWithLedgerJSON(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Nil(t, h.ledger) err := h.Init(&Options{ @@ -93,7 +92,7 @@ func TestBasicInitWithLedgerJSON(t *testing.T) { func TestBasicInitWithLedgerYAML(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Nil(t, h.ledger) err := h.Init(&Options{ @@ -107,7 +106,7 @@ func TestBasicInitWithLedgerYAML(t *testing.T) { func TestBasicInitWithLedgerBadDAta(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Nil(t, h.ledger) err := h.Init(&Options{ @@ -119,7 +118,7 @@ func TestBasicInitWithLedgerBadDAta(t *testing.T) { func TestOnConnectAuthenticate(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) ln := new(Ledger) ln.Auth = checkLedger.Auth @@ -158,7 +157,7 @@ func TestOnConnectAuthenticate(t *testing.T) { func TestOnACL(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) ln := new(Ledger) ln.Auth = checkLedger.Auth diff --git a/mqtt/hooks/auth/ledger.go b/mqtt/hooks/auth/ledger.go index e47d4ce..0a85490 100644 --- a/mqtt/hooks/auth/ledger.go +++ b/mqtt/hooks/auth/ledger.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth @@ -124,8 +124,8 @@ func (r RString) Matches(a string) bool { } // FilterMatches returns true if a filter matches a topic rule. -func (f RString) FilterMatches(a string) bool { - _, ok := MatchTopic(string(f), a) +func (r RString) FilterMatches(a string) bool { + _, ok := MatchTopic(string(r), a) return ok } @@ -205,7 +205,7 @@ func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) { } // ACLOk returns true if the rules indicate the user is allowed to read or write to -// a specific filter or topic respectively, based on the write bool. +// a specific filter or topic respectively, based on the `write` bool. func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) { // If the users map is set, always check for a predefined user first instead // of iterating through global rules. @@ -233,17 +233,31 @@ func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok boo return n, true } - for filter, access := range rule.Filters { - if filter.FilterMatches(topic) { - if !write && (access == ReadOnly || access == ReadWrite) { - return n, true - } else if write && (access == WriteOnly || access == ReadWrite) { - return n, true - } else { - return n, false + if write { + for filter, access := range rule.Filters { + if access == WriteOnly || access == ReadWrite { + if filter.FilterMatches(topic) { + return n, true + } } } } + + if !write { + for filter, access := range rule.Filters { + if access == ReadOnly || access == ReadWrite { + if filter.FilterMatches(topic) { + return n, true + } + } + } + } + + for filter := range rule.Filters { + if filter.FilterMatches(topic) { + return n, false + } + } } } diff --git a/mqtt/hooks/auth/ledger_test.go b/mqtt/hooks/auth/ledger_test.go index 6ac8ac9..18004d1 100644 --- a/mqtt/hooks/auth/ledger_test.go +++ b/mqtt/hooks/auth/ledger_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth @@ -561,17 +561,17 @@ func TestLedgerUpdate(t *testing.T) { }, } - new := &Ledger{ + n := &Ledger{ Auth: AuthRules{ {Remote: "127.0.0.1", Allow: true}, {Remote: "192.168.*", Allow: true}, }, } - old.Update(new) + old.Update(n) require.Len(t, old.Auth, 2) require.Equal(t, RString("192.168.*"), old.Auth[1].Remote) - require.NotSame(t, new, old) + require.NotSame(t, n, old) } func TestLedgerToJSON(t *testing.T) { diff --git a/mqtt/hooks/debug/debug.go b/mqtt/hooks/debug/debug.go index f8c5f86..03921ea 100644 --- a/mqtt/hooks/debug/debug.go +++ b/mqtt/hooks/debug/debug.go @@ -1,17 +1,17 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package debug import ( + "fmt" + "log/slog" "strings" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage" "github.com/wind-c/comqtt/v2/mqtt/packets" - - "github.com/rs/zerolog" ) // Options contains configuration settings for the debug output. @@ -25,7 +25,7 @@ type Options struct { type Hook struct { mqtt.HookBase config *Options - Log *zerolog.Logger + Log *slog.Logger } // ID returns the ID of the hook. @@ -54,25 +54,25 @@ func (h *Hook) Init(config any) error { } // SetOpts is called when the hook receives inheritable server parameters. -func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) { +func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) { h.Log = l - h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send() + h.Log.Debug("", "method", "SetOpts") } // Stop is called when the hook is stopped. func (h *Hook) Stop() error { - h.Log.Debug().Str("method", "Stop").Send() + h.Log.Debug("", "method", "Stop") return nil } // OnStarted is called when the server starts. func (h *Hook) OnStarted() { - h.Log.Debug().Str("method", "OnStarted").Send() + h.Log.Debug("", "method", "OnStarted") } // OnStopped is called when the server stops. func (h *Hook) OnStopped() { - h.Log.Debug().Str("method", "OnStopped").Send() + h.Log.Debug("", "method", "OnStopped") } // OnPacketRead is called when a new packet is received from a client. @@ -81,8 +81,7 @@ func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, return pk, nil } - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID) - + h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) return pk, nil } @@ -92,85 +91,72 @@ func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) { return } - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID) + h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) } // OnRetainMessage is called when a published message is retained (or retain deleted/modified). func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic") + h.Log.Debug("retained message on topic", "m", h.packetMeta(pk)) } // OnQosPublish is called when a publish packet with Qos is issued to a subscriber. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out") + h.Log.Debug("inflight out", "m", h.packetMeta(pk)) } // OnQosComplete is called when the Qos flow for a message has been completed. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete") + h.Log.Debug("inflight complete", "m", h.packetMeta(pk)) } // OnQosDropped is called the Qos flow for a message expires. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped") + h.Log.Debug("inflight dropped", "m", h.packetMeta(pk)) } -// OnLWTSent is called when a will message has been issued from a disconnecting client. +// OnLWTSent is called when a Will Message has been issued from a disconnecting client. func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) { - h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client") + h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID) } // OnRetainedExpired is called when the server clears expired retained messages. func (h *Hook) OnRetainedExpired(filter string) { - h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired") + h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter) } // OnClientExpired is called when the server clears an expired client. func (h *Hook) OnClientExpired(cl *mqtt.Client) { - h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired") + h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID) } // StoredClients is called when the server restores clients from a store. func (h *Hook) StoredClients() (v []storage.Client, err error) { - h.Log.Debug(). - Str("method", "StoredClients"). - Send() + h.Log.Debug("", "method", "StoredClients") return v, nil } -// StoredClients is called when the server restores subscriptions from a store. +// StoredSubscriptions is called when the server restores subscriptions from a store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { - h.Log.Debug(). - Str("method", "StoredSubscriptions"). - Send() - + h.Log.Debug("", "method", "StoredSubscriptions") return v, nil } -// StoredClients is called when the server restores retained messages from a store. +// StoredRetainedMessages is called when the server restores retained messages from a store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { - h.Log.Debug(). - Str("method", "StoredRetainedMessages"). - Send() - + h.Log.Debug("", "method", "StoredRetainedMessages") return v, nil } -// StoredClients is called when the server restores inflight messages from a store. +// StoredInflightMessages is called when the server restores inflight messages from a store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { - h.Log.Debug(). - Str("method", "StoredInflightMessages"). - Send() - + h.Log.Debug("", "method", "StoredInflightMessages") return v, nil } -// StoredClients is called when the server restores system info from a store. +// StoredSysInfo is called when the server restores system info from a store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { - h.Log.Debug(). - Str("method", "StoredClients"). - Send() + h.Log.Debug("", "method", "StoredSysInfo") return v, nil } diff --git a/mqtt/hooks/storage/badger/badger.go b/mqtt/hooks/storage/badger/badger.go index 701d844..c14d935 100644 --- a/mqtt/hooks/storage/badger/badger.go +++ b/mqtt/hooks/storage/badger/badger.go @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co -// SPDX-FileContributor: mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, gsagula package badger @@ -127,8 +127,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } @@ -136,7 +135,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { // updateClient writes the client data to the store. func (h *Hook) updateClient(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -165,14 +164,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) { err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data") + h.Log.Error("failed to upsert client data", "error", err, "data", in) } } // OnDisconnect removes a client from the store if their session has expired. func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -188,14 +187,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { err := h.db.Delete(clientKey(cl), new(storage.Client)) if err != nil { - h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data") + h.Log.Error("failed to delete client data", "error", err, "data", clientKey(cl)) } } // OnSubscribed adds one or more client subscriptions to the store. func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -217,7 +216,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data") + h.Log.Error("failed to upsert subscription data", "error", err, "data", in) } } } @@ -225,14 +224,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by // OnUnsubscribed removes one or more client subscriptions from the store. func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } for i := 0; i < len(pk.Filters); i++ { err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription)) if err != nil { - h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data") + h.Log.Error("failed to delete subscription data", "error", err, "data", subscriptionKey(cl, pk.Filters[i].Filter)) } } } @@ -240,14 +239,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] // OnRetainMessage adds a retained message for a topic to the store. func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } if r == -1 { err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message)) if err != nil { - h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data") + h.Log.Error("failed to delete retained message data", "error", err, "data", retainedKey(pk.TopicName)) } return @@ -276,14 +275,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data") + h.Log.Error("failed to upsert retained message data", "error", err, "data", in) } } // OnQosPublish adds or updates an inflight message in the store. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -312,27 +311,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data") + h.Log.Error("failed to upsert qos inflight data", "error", err, "data", in) } } // OnQosComplete removes a resolved inflight message from the store. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.Delete(inflightKey(cl, pk), new(storage.Message)) if err != nil { - h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data") + h.Log.Error("failed to delete inflight message data", "error", err, "data", inflightKey(cl, pk)) } } // OnQosDropped removes a dropped inflight message from the store. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) } h.OnQosComplete(cl, pk) @@ -341,7 +340,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { // OnSysInfoTick stores the latest system info in the store. func (h *Hook) OnSysInfoTick(sys *system.Info) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -353,40 +352,40 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data") + h.Log.Error("failed to upsert $SYS data", "error", err, "data", in) } } // OnRetainedExpired deletes expired retained messages from the store. func (h *Hook) OnRetainedExpired(filter string) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.Delete(retainedKey(filter), new(storage.Message)) if err != nil { - h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data") + h.Log.Error("failed to delete expired retained message data", "error", err, "id", retainedKey(filter)) } } // OnClientExpired deleted expired clients from the store. func (h *Hook) OnClientExpired(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.Delete(clientKey(cl), new(storage.Client)) if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data") + h.Log.Error("failed to delete expired client data", "error", err, "id", clientKey(cl)) } } // StoredClients returns all stored clients from the store. func (h *Hook) StoredClients() (v []storage.Client, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -401,7 +400,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) { // StoredSubscriptions returns all stored subscriptions from the store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -416,7 +415,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { // StoredRetainedMessages returns all stored retained messages from the store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -431,7 +430,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { // StoredInflightMessages returns all stored inflight messages from the store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -446,7 +445,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { // StoredSysInfo returns the system info from the store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -460,20 +459,21 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { // Errorf satisfies the badger interface for an error logger. func (h *Hook) Errorf(m string, v ...interface{}) { - h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) + } // Warningf satisfies the badger interface for a warning logger. func (h *Hook) Warningf(m string, v ...interface{}) { - h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) } // Infof satisfies the badger interface for an info logger. func (h *Hook) Infof(m string, v ...interface{}) { - h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) } // Debugf satisfies the badger interface for a debug logger. func (h *Hook) Debugf(m string, v ...interface{}) { - h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) } diff --git a/mqtt/hooks/storage/badger/badger_test.go b/mqtt/hooks/storage/badger/badger_test.go index a657797..aaa9da6 100644 --- a/mqtt/hooks/storage/badger/badger_test.go +++ b/mqtt/hooks/storage/badger/badger_test.go @@ -1,16 +1,16 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package badger import ( + "log/slog" "os" "strings" "testing" "time" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" "github.com/timshannon/badgerhold" "github.com/wind-c/comqtt/v2/mqtt" @@ -20,7 +20,7 @@ import ( ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) client = &mqtt.Client{ ID: "test", @@ -38,8 +38,8 @@ var ( ) func teardown(t *testing.T, path string, h *Hook) { - h.Stop() - h.db.Badger().Close() + _ = h.Stop() + _ = h.db.Badger().Close() err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1)) require.NoError(t, err) } @@ -95,7 +95,7 @@ func TestProvides(t *testing.T) { func TestInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -103,7 +103,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitUseDefaults(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -113,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) { func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -146,7 +146,7 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { func TestOnClientExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -170,13 +170,13 @@ func TestOnClientExpired(t *testing.T) { func TestOnClientExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnClientExpired(client) } func TestOnClientExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -185,13 +185,13 @@ func TestOnClientExpiredClosedDB(t *testing.T) { func TestOnSessionEstablishedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSessionEstablished(client, packets.Packet{}) } func TestOnSessionEstablishedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -200,7 +200,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) { func TestOnWillSent(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -219,13 +219,13 @@ func TestOnWillSent(t *testing.T) { func TestOnDisconnectNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnDisconnect(client, nil, false) } func TestOnDisconnectClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -234,7 +234,7 @@ func TestOnDisconnectClosedDB(t *testing.T) { func TestOnDisconnectSessionTakenOver(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) @@ -257,12 +257,12 @@ func TestOnDisconnectSessionTakenOver(t *testing.T) { func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) - h.OnSubscribed(client, pkf, []byte{0}, nil) + h.OnSubscribed(client, pkf, []byte{0}) r := new(storage.Subscription) err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r) @@ -271,7 +271,7 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { require.Equal(t, pkf.Filters[0].Filter, r.Filter) require.Equal(t, byte(0), r.Qos) - h.OnUnsubscribed(client, pkf, nil, nil) + h.OnUnsubscribed(client, pkf) err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r) require.Error(t, err) require.Equal(t, badgerhold.ErrNotFound, err) @@ -279,37 +279,37 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { func TestOnSubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) - h.OnSubscribed(client, pkf, []byte{0}, nil) + h.SetOpts(logger, nil) + h.OnSubscribed(client, pkf, []byte{0}) } func TestOnSubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) - h.OnSubscribed(client, pkf, []byte{0}, nil) + h.OnSubscribed(client, pkf, []byte{0}) } func TestOnUnsubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) - h.OnUnsubscribed(client, pkf, nil, nil) + h.SetOpts(logger, nil) + h.OnUnsubscribed(client, pkf) } func TestOnUnsubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) - h.OnUnsubscribed(client, pkf, nil, nil) + h.OnUnsubscribed(client, pkf) } func TestOnRetainMessageThenUnset(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -344,7 +344,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) { func TestOnRetainedExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -371,13 +371,13 @@ func TestOnRetainedExpired(t *testing.T) { func TestOnRetainExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainedExpired("a/b/c") } func TestOnRetainExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -386,13 +386,13 @@ func TestOnRetainExpiredClosedDB(t *testing.T) { func TestOnRetainMessageNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainMessage(client, packets.Packet{}, 0) } func TestOnRetainMessageClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -401,7 +401,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) { func TestOnQosPublishThenQOSComplete(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -436,13 +436,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) { func TestOnQosPublishNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) } func TestOnQosPublishClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -451,13 +451,13 @@ func TestOnQosPublishClosedDB(t *testing.T) { func TestOnQosCompleteNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosComplete(client, packets.Packet{}) } func TestOnQosCompleteClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -466,13 +466,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) { func TestOnQosDroppedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosDropped(client, packets.Packet{}) } func TestOnSysInfoTick(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -494,13 +494,13 @@ func TestOnSysInfoTick(t *testing.T) { func TestOnSysInfoTickNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSysInfoTick(new(system.Info)) } func TestOnSysInfoTickClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -509,7 +509,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) { func TestStoredClients(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -534,7 +534,7 @@ func TestStoredClients(t *testing.T) { func TestStoredClientsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredClients() require.Empty(t, v) require.NoError(t, err) @@ -542,7 +542,7 @@ func TestStoredClientsNoDB(t *testing.T) { func TestStoredSubscriptions(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -567,7 +567,7 @@ func TestStoredSubscriptions(t *testing.T) { func TestStoredSubscriptionsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSubscriptions() require.Empty(t, v) require.NoError(t, err) @@ -575,7 +575,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) { func TestStoredRetainedMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -603,7 +603,7 @@ func TestStoredRetainedMessages(t *testing.T) { func TestStoredRetainedMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredRetainedMessages() require.Empty(t, v) require.NoError(t, err) @@ -611,7 +611,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) { func TestStoredInflightMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -639,7 +639,7 @@ func TestStoredInflightMessages(t *testing.T) { func TestStoredInflightMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredInflightMessages() require.Empty(t, v) require.NoError(t, err) @@ -647,7 +647,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) { func TestStoredSysInfo(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -669,7 +669,7 @@ func TestStoredSysInfo(t *testing.T) { func TestStoredSysInfoNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSysInfo() require.Empty(t, v) require.NoError(t, err) @@ -678,27 +678,27 @@ func TestStoredSysInfoNoDB(t *testing.T) { func TestErrorf(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Errorf("test", 1, 2, 3) } func TestWarningf(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Warningf("test", 1, 2, 3) } func TestInfof(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Infof("test", 1, 2, 3) } func TestDebugf(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Debugf("test", 1, 2, 3) } diff --git a/mqtt/hooks/storage/bolt/bolt.go b/mqtt/hooks/storage/bolt/bolt.go index 1e42f3e..374e493 100644 --- a/mqtt/hooks/storage/bolt/bolt.go +++ b/mqtt/hooks/storage/bolt/bolt.go @@ -1,7 +1,8 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co -// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. + +// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. package bolt import ( @@ -132,8 +133,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } @@ -141,7 +141,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { // updateClient writes the client data to the store. func (h *Hook) updateClient(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -169,14 +169,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) { } err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data") + h.Log.Error("failed to save client data", "error", err, "data", in) } } // OnDisconnect removes a client from the store if they were using a clean session. func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -190,14 +190,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) if err != nil && !errors.Is(err, storm.ErrNotFound) { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") + h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl)) } } // OnSubscribed adds one or more client subscriptions to the store. func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -219,10 +219,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Str("client", cl.ID). - Interface("data", in). - Msg("failed to save subscription data") + h.Log.Error("failed to save subscription data", "error", err, "client", cl.ID, "data", in) } } } @@ -230,7 +227,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by // OnUnsubscribed removes one or more client subscriptions from the store. func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -239,9 +236,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] ID: subscriptionKey(cl, pk.Filters[i].Filter), }) if err != nil { - h.Log.Error().Err(err). - Str("id", subscriptionKey(cl, pk.Filters[i].Filter)). - Msg("failed to delete client") + h.Log.Error("failed to delete client", "error", err, "id", subscriptionKey(cl, pk.Filters[i].Filter)) } } } @@ -249,7 +244,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] // OnRetainMessage adds a retained message for a topic to the store. func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -258,9 +253,7 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { ID: retainedKey(pk.TopicName), }) if err != nil { - h.Log.Error().Err(err). - Str("id", retainedKey(pk.TopicName)). - Msg("failed to delete retained publish") + h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(pk.TopicName)) } return } @@ -287,17 +280,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { } err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Str("client", cl.ID). - Interface("data", in). - Msg("failed to save retained publish data") + h.Log.Error("failed to save retained publish data", "error", err, "client", cl.ID, "data", in) } } // OnQosPublish adds or updates an inflight message in the store. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -325,17 +315,14 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Str("client", cl.ID). - Interface("data", in). - Msg("failed to save qos inflight data") + h.Log.Error("failed to save qos inflight data", "error", err, "client", cl.ID, "data", in) } } // OnQosComplete removes a resolved inflight message from the store. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -343,16 +330,14 @@ func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { ID: inflightKey(cl, pk), }) if err != nil { - h.Log.Error().Err(err). - Str("id", inflightKey(cl, pk)). - Msg("failed to delete inflight data") + h.Log.Error("failed to delete inflight data", "error", err, "id", inflightKey(cl, pk)) } } // OnQosDropped removes a dropped inflight message from the store. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) } h.OnQosComplete(cl, pk) @@ -361,7 +346,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { // OnSysInfoTick stores the latest system info in the store. func (h *Hook) OnSysInfoTick(sys *system.Info) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -373,41 +358,39 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Interface("data", in). - Msg("failed to save $SYS data") + h.Log.Error("failed to save $SYS data", "error", err, "data", in) } } // OnRetainedExpired deletes expired retained messages from the store. func (h *Hook) OnRetainedExpired(filter string) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil { - h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish") + h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(filter)) } } // OnClientExpired deleted expired clients from the store. func (h *Hook) OnClientExpired(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) if err != nil && !errors.Is(err, storm.ErrNotFound) { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") + h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl)) } } // StoredClients returns all stored clients from the store. func (h *Hook) StoredClients() (v []storage.Client, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -422,7 +405,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) { // StoredSubscriptions returns all stored subscriptions from the store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -437,7 +420,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { // StoredRetainedMessages returns all stored retained messages from the store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -452,7 +435,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { // StoredInflightMessages returns all stored inflight messages from the store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -467,7 +450,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { // StoredSysInfo returns the system info from the store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } diff --git a/mqtt/hooks/storage/bolt/bolt_test.go b/mqtt/hooks/storage/bolt/bolt_test.go index f9175a4..4d7a8f7 100644 --- a/mqtt/hooks/storage/bolt/bolt_test.go +++ b/mqtt/hooks/storage/bolt/bolt_test.go @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package bolt import ( + "log/slog" "os" "testing" "time" @@ -15,12 +16,11 @@ import ( "github.com/wind-c/comqtt/v2/mqtt/system" "github.com/asdine/storm/v3" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) client = &mqtt.Client{ ID: "test", @@ -38,8 +38,8 @@ var ( ) func teardown(t *testing.T, path string, h *Hook) { - h.Stop() - err := os.RemoveAll(path) + _ = h.Stop() + err := os.Remove(path) require.NoError(t, err) } @@ -94,7 +94,7 @@ func TestProvides(t *testing.T) { func TestInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -102,7 +102,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitUseDefaults(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -113,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) { func TestInitBadPath(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(&Options{ Path: "..", }) @@ -122,7 +122,7 @@ func TestInitBadPath(t *testing.T) { func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -155,13 +155,13 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { func TestOnSessionEstablishedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSessionEstablished(client, packets.Packet{}) } func TestOnSessionEstablishedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -170,7 +170,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) { func TestOnWillSent(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -189,7 +189,7 @@ func TestOnWillSent(t *testing.T) { func TestOnClientExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -213,7 +213,7 @@ func TestOnClientExpired(t *testing.T) { func TestOnClientExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -222,19 +222,19 @@ func TestOnClientExpiredClosedDB(t *testing.T) { func TestOnClientExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnClientExpired(client) } func TestOnDisconnectNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnDisconnect(client, nil, false) } func TestOnDisconnectClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -243,7 +243,7 @@ func TestOnDisconnectClosedDB(t *testing.T) { func TestOnDisconnectSessionTakenOver(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) @@ -266,7 +266,7 @@ func TestOnDisconnectSessionTakenOver(t *testing.T) { func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -288,13 +288,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { func TestOnSubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSubscribed(client, pkf, []byte{0}, nil) } func TestOnSubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -303,13 +303,13 @@ func TestOnSubscribedClosedDB(t *testing.T) { func TestOnUnsubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnUnsubscribed(client, pkf, nil, nil) } func TestOnUnsubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -318,7 +318,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) { func TestOnRetainMessageThenUnset(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -353,7 +353,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) { func TestOnRetainedExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -380,7 +380,7 @@ func TestOnRetainedExpired(t *testing.T) { func TestOnRetainedExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -389,19 +389,19 @@ func TestOnRetainedExpiredClosedDB(t *testing.T) { func TestOnRetainedExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainedExpired("a/b/c") } func TestOnRetainMessageNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainMessage(client, packets.Packet{}, 0) } func TestOnRetainMessageClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -410,7 +410,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) { func TestOnQosPublishThenQOSComplete(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -445,13 +445,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) { func TestOnQosPublishNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) } func TestOnQosPublishClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -460,13 +460,13 @@ func TestOnQosPublishClosedDB(t *testing.T) { func TestOnQosCompleteNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosComplete(client, packets.Packet{}) } func TestOnQosCompleteClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -475,13 +475,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) { func TestOnQosDroppedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosDropped(client, packets.Packet{}) } func TestOnSysInfoTick(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -503,13 +503,13 @@ func TestOnSysInfoTick(t *testing.T) { func TestOnSysInfoTickNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSysInfoTick(new(system.Info)) } func TestOnSysInfoTickClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -518,7 +518,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) { func TestStoredClients(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -543,7 +543,7 @@ func TestStoredClients(t *testing.T) { func TestStoredClientsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredClients() require.Empty(t, v) require.NoError(t, err) @@ -551,7 +551,7 @@ func TestStoredClientsNoDB(t *testing.T) { func TestStoredClientsClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -562,7 +562,7 @@ func TestStoredClientsClosedDB(t *testing.T) { func TestStoredSubscriptions(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -587,7 +587,7 @@ func TestStoredSubscriptions(t *testing.T) { func TestStoredSubscriptionsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSubscriptions() require.Empty(t, v) require.NoError(t, err) @@ -595,7 +595,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) { func TestStoredSubscriptionsClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -606,7 +606,7 @@ func TestStoredSubscriptionsClosedDB(t *testing.T) { func TestStoredRetainedMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -634,7 +634,7 @@ func TestStoredRetainedMessages(t *testing.T) { func TestStoredRetainedMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredRetainedMessages() require.Empty(t, v) require.NoError(t, err) @@ -642,7 +642,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) { func TestStoredRetainedMessagesClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -653,7 +653,7 @@ func TestStoredRetainedMessagesClosedDB(t *testing.T) { func TestStoredInflightMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -681,7 +681,7 @@ func TestStoredInflightMessages(t *testing.T) { func TestStoredInflightMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredInflightMessages() require.Empty(t, v) require.NoError(t, err) @@ -689,7 +689,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) { func TestStoredInflightMessagesClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -700,7 +700,7 @@ func TestStoredInflightMessagesClosedDB(t *testing.T) { func TestStoredSysInfo(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -722,7 +722,7 @@ func TestStoredSysInfo(t *testing.T) { func TestStoredSysInfoNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSysInfo() require.Empty(t, v) require.NoError(t, err) @@ -730,7 +730,7 @@ func TestStoredSysInfoNoDB(t *testing.T) { func TestStoredSysInfoClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) diff --git a/mqtt/hooks/storage/redis/redis.go b/mqtt/hooks/storage/redis/redis.go index 49241a1..4ecb025 100644 --- a/mqtt/hooks/storage/redis/redis.go +++ b/mqtt/hooks/storage/redis/redis.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package redis @@ -103,7 +103,7 @@ func (h *Hook) Init(config any) error { } h.ctx = context.Background() - h.ctx.Deadline() + if config == nil { config = &Options{ Options: &redis.Options{ @@ -117,12 +117,11 @@ func (h *Hook) Init(config any) error { h.config.HPrefix = defaultHPrefix } - h.Log.Info(). - Str("address", h.config.Options.Addr). - Str("username", h.config.Options.Username). - Int("password-len", len(h.config.Options.Password)). - Int("db", h.config.Options.DB). - Msg("connecting to redis service") + h.Log.Info("connecting to redis service", + "address", h.config.Options.Addr, + "username", h.config.Options.Username, + "password-len", len(h.config.Options.Password), + "db", h.config.Options.DB) h.db = redis.NewClient(h.config.Options) _, err := h.db.Ping(context.Background()).Result() @@ -130,14 +129,15 @@ func (h *Hook) Init(config any) error { return fmt.Errorf("failed to ping service: %w", err) } - h.Log.Info().Msg("connected to redis service") + h.Log.Info("connected to redis service") return nil } // Stop closes the redis connection. func (h *Hook) Stop() error { - h.Log.Info().Msg("disconnecting from redis service") + h.Log.Info("disconnecting from redis service") + return h.db.Close() } @@ -146,8 +146,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } @@ -155,7 +154,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { // updateClient writes the client data to the store. func (h *Hook) updateClient(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -184,14 +183,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) { err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data") + h.Log.Error("failed to hset client data", "error", err, "data", in) } } // OnDisconnect removes a client from the store if they were using a clean session. func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -205,14 +204,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") + h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl)) } } // OnSubscribed adds one or more client subscriptions to the store. func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -234,7 +233,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data") + h.Log.Error("failed to hset subscription data", "error", err, "data", in) } } } @@ -242,14 +241,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by // OnUnsubscribed removes one or more client subscriptions from the store. func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } for i := 0; i < len(pk.Filters); i++ { err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data") + h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl)) } } } @@ -257,14 +256,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] // OnRetainMessage adds a retained message for a topic to the store. func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } if r == -1 { err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data") + h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName)) } return @@ -293,14 +292,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data") + h.Log.Error("failed to hset retained message data", "error", err, "data", in) } } // OnQosPublish adds or updates an inflight message in the store. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -328,27 +327,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data") + h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in) } } // OnQosComplete removes a resolved inflight message from the store. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data") + h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk)) } } // OnQosDropped removes a dropped inflight message from the store. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) } h.OnQosComplete(cl, pk) @@ -357,7 +356,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { // OnSysInfoTick stores the latest system info in the store. func (h *Hook) OnSysInfoTick(sys *system.Info) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -369,53 +368,53 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data") + h.Log.Error("failed to hset server info data", "error", err, "data", in) } } // OnRetainedExpired deletes expired retained messages from the store. func (h *Hook) OnRetainedExpired(filter string) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data") + h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter)) } } // OnClientExpired deleted expired clients from the store. func (h *Hook) OnClientExpired(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") + h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl)) } } // StoredClients returns all stored clients from the store. func (h *Hook) StoredClients() (v []storage.Client, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll client data") + h.Log.Error("failed to HGetAll client data", "error", err) return } for _, row := range rows { var d storage.Client if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data") + h.Log.Error("failed to unmarshal client data", "error", err, "data", row) } v = append(v, d) @@ -427,20 +426,20 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) { // StoredSubscriptions returns all stored subscriptions from the store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll subscription data") + h.Log.Error("failed to HGetAll subscription data", "error", err) return } for _, row := range rows { var d storage.Subscription if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data") + h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row) } v = append(v, d) @@ -452,20 +451,20 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { // StoredRetainedMessages returns all stored retained messages from the store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll retained message data") + h.Log.Error("failed to HGetAll retained message data", "error", err) return } for _, row := range rows { var d storage.Message if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data") + h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row) } v = append(v, d) @@ -477,20 +476,20 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { // StoredInflightMessages returns all stored inflight messages from the store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data") + h.Log.Error("failed to HGetAll inflight message data", "error", err) return } for _, row := range rows { var d storage.Message if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data") + h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row) } v = append(v, d) @@ -502,7 +501,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { // StoredSysInfo returns the system info from the store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -512,7 +511,7 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { } if err = v.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data") + h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row) } return v, nil diff --git a/mqtt/hooks/storage/redis/redis_test.go b/mqtt/hooks/storage/redis/redis_test.go index 0226806..9a41af4 100644 --- a/mqtt/hooks/storage/redis/redis_test.go +++ b/mqtt/hooks/storage/redis/redis_test.go @@ -5,6 +5,7 @@ package redis import ( + "log/slog" "os" "sort" "testing" @@ -17,12 +18,11 @@ import ( miniredis "github.com/alicebob/miniredis/v2" redis "github.com/go-redis/redis/v8" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) client = &mqtt.Client{ ID: "test", @@ -41,7 +41,7 @@ var ( func newHook(t *testing.T, addr string) *Hook { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(&Options{ Options: &redis.Options{ @@ -87,13 +87,13 @@ func TestSysInfoKey(t *testing.T) { func TestID(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Equal(t, "redis-db", h.ID()) } func TestProvides(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.True(t, h.Provides(mqtt.OnSessionEstablished)) require.True(t, h.Provides(mqtt.OnDisconnect)) require.True(t, h.Provides(mqtt.OnSubscribed)) @@ -116,7 +116,7 @@ func TestHKey(t *testing.T) { s := miniredis.RunT(t) defer s.Close() h := newHook(t, s.Addr()) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Equal(t, defaultHPrefix+"test", h.hKey("test")) } @@ -126,7 +126,7 @@ func TestInitUseDefaults(t *testing.T) { defer s.Close() h := newHook(t, defaultAddr) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h) @@ -137,7 +137,7 @@ func TestInitUseDefaults(t *testing.T) { func TestInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -145,7 +145,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitBadAddr(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(&Options{ Options: &redis.Options{ Addr: "abc:123", diff --git a/mqtt/hooks/storage/storage.go b/mqtt/hooks/storage/storage.go index 8cabf01..b7f6ce4 100644 --- a/mqtt/hooks/storage/storage.go +++ b/mqtt/hooks/storage/storage.go @@ -25,7 +25,7 @@ var ( ErrDBFileNotOpen = errors.New("db file not open") ) -// Client is a storable representation of an mqtt client. +// Client is a storable representation of an MQTT client. type Client struct { Will ClientWill `json:"will,omitempty"` // will topic and payload data if applicable Properties ClientProperties `json:"properties,omitempty"` // the connect properties for the client @@ -55,9 +55,9 @@ type ClientProperties struct { // ClientWill contains a will message for a client, and limited mqtt v5 properties. type ClientWill struct { - TopicName string `json:"topicName,omitempty"` Payload []byte `json:"payload,omitempty"` User []packets.UserProperty `json:"user,omitempty"` + TopicName string `json:"topicName,omitempty"` Flag uint32 `json:"flag,omitempty"` WillDelayInterval uint32 `json:"willDelayInterval,omitempty"` Qos byte `json:"qos,omitempty"` @@ -147,7 +147,7 @@ func (d *Message) ToPacket() packets.Packet { return pk } -// Subscription is a storable representation of an mqtt subscription. +// Subscription is a storable representation of an MQTT subscription. type Subscription struct { T string `json:"t,omitempty"` ID string `json:"id,omitempty" storm:"id"` diff --git a/mqtt/hooks/storage/storage_test.go b/mqtt/hooks/storage/storage_test.go index daee469..1cdc346 100644 --- a/mqtt/hooks/storage/storage_test.go +++ b/mqtt/hooks/storage/storage_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package storage diff --git a/mqtt/hooks_test.go b/mqtt/hooks_test.go index 35b1fde..94ed049 100644 --- a/mqtt/hooks_test.go +++ b/mqtt/hooks_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -215,7 +215,7 @@ func TestHooksAddInitFailure(t *testing.T) { func TestHooksStop(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger err := h.Add(new(HookBase), nil) require.NoError(t, err) @@ -334,7 +334,7 @@ func TestHooksOnUnsubscribe(t *testing.T) { func TestHooksOnPublish(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -360,7 +360,7 @@ func TestHooksOnPublish(t *testing.T) { func TestHooksOnPacketRead(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -386,7 +386,7 @@ func TestHooksOnPacketRead(t *testing.T) { func TestHooksOnAuthPacket(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -404,7 +404,7 @@ func TestHooksOnAuthPacket(t *testing.T) { func TestHooksOnConnect(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -420,7 +420,7 @@ func TestHooksOnConnect(t *testing.T) { func TestHooksOnPacketEncode(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -432,7 +432,7 @@ func TestHooksOnPacketEncode(t *testing.T) { func TestHooksOnLWT(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -449,7 +449,7 @@ func TestHooksOnLWT(t *testing.T) { func TestHooksStoredClients(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredClients() require.NoError(t, err) @@ -471,7 +471,7 @@ func TestHooksStoredClients(t *testing.T) { func TestHooksStoredSubscriptions(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredSubscriptions() require.NoError(t, err) @@ -493,7 +493,7 @@ func TestHooksStoredSubscriptions(t *testing.T) { func TestHooksStoredRetainedMessages(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredRetainedMessages() require.NoError(t, err) @@ -515,7 +515,7 @@ func TestHooksStoredRetainedMessages(t *testing.T) { func TestHooksStoredInflightMessages(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredInflightMessages() require.NoError(t, err) @@ -537,7 +537,7 @@ func TestHooksStoredInflightMessages(t *testing.T) { func TestHooksStoredSysInfo(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredSysInfo() require.NoError(t, err) @@ -575,7 +575,7 @@ func TestHookBaseInit(t *testing.T) { func TestHookBaseSetOpts(t *testing.T) { h := new(HookBase) - h.SetOpts(&logger, new(HookOptions)) + h.SetOpts(logger, new(HookOptions)) require.NotNil(t, h.Log) require.NotNil(t, h.Opts) } diff --git a/mqtt/inflight.go b/mqtt/inflight.go index 51e77fc..b94b2fd 100644 --- a/mqtt/inflight.go +++ b/mqtt/inflight.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -14,12 +14,12 @@ import ( // Inflight is a map of InflightMessage keyed on packet id. type Inflight struct { - internal map[uint16]packets.Packet // internal contains the inflight packets sync.RWMutex - receiveQuota int32 // remaining inbound qos quota for flow control - sendQuota int32 // remaining outbound qos quota for flow control - maximumReceiveQuota int32 // maximum allowed receive quota - maximumSendQuota int32 // maximum allowed send quota + internal map[uint16]packets.Packet // internal contains the inflight packets + receiveQuota int32 // remaining inbound qos quota for flow control + sendQuota int32 // remaining outbound qos quota for flow control + maximumReceiveQuota int32 // maximum allowed receive quota + maximumSendQuota int32 // maximum allowed send quota } // NewInflights returns a new instance of an Inflight packets map. diff --git a/mqtt/inflight_test.go b/mqtt/inflight_test.go index 9b55d12..8de6e65 100644 --- a/mqtt/inflight_test.go +++ b/mqtt/inflight_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt diff --git a/mqtt/listeners/http_healthcheck.go b/mqtt/listeners/http_healthcheck.go new file mode 100644 index 0000000..a82e2e3 --- /dev/null +++ b/mqtt/listeners/http_healthcheck.go @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Derek Duncan + +package listeners + +import ( + "context" + "log/slog" + "net/http" + "sync" + "sync/atomic" + "time" +) + +// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint. +type HTTPHealthCheck struct { + sync.RWMutex + id string // the internal id of the listener + address string // the network address to bind to + config *Config // configuration values for the listener + listen *http.Server // the http server + end uint32 // ensure the close methods are only called once +} + +// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address. +func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck { + if config == nil { + config = new(Config) + } + return &HTTPHealthCheck{ + id: id, + address: address, + config: config, + } +} + +// ID returns the id of the listener. +func (l *HTTPHealthCheck) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *HTTPHealthCheck) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *HTTPHealthCheck) Protocol() string { + if l.listen != nil && l.listen.TLSConfig != nil { + return "https" + } + + return "http" +} + +// Init initializes the listener. +func (l *HTTPHealthCheck) Init(_ *slog.Logger) error { + mux := http.NewServeMux() + mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + } + }) + l.listen = &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + Addr: l.address, + Handler: mux, + } + + if l.config.TLSConfig != nil { + l.listen.TLSConfig = l.config.TLSConfig + } + + return nil +} + +// Serve starts listening for new connections and serving responses. +func (l *HTTPHealthCheck) Serve(establish EstablishFn) { + if l.listen.TLSConfig != nil { + _ = l.listen.ListenAndServeTLS("", "") + } else { + _ = l.listen.ListenAndServe() + } +} + +// Close closes the listener and any client connections. +func (l *HTTPHealthCheck) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = l.listen.Shutdown(ctx) + } + + closeClients(l.id) +} diff --git a/mqtt/listeners/http_healthcheck_test.go b/mqtt/listeners/http_healthcheck_test.go new file mode 100644 index 0000000..1c753c1 --- /dev/null +++ b/mqtt/listeners/http_healthcheck_test.go @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Derek Duncan + +package listeners + +import ( + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewHTTPHealthCheck(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, "healthcheck", l.id) + require.Equal(t, testAddr, l.address) +} + +func TestHTTPHealthCheckID(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, "healthcheck", l.ID()) +} + +func TestHTTPHealthCheckAddress(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, testAddr, l.Address()) +} + +func TestHTTPHealthCheckProtocol(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, "http", l.Protocol()) +} + +func TestHTTPHealthCheckTLSProtocol(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ + TLSConfig: tlsConfigBasic, + }) + + _ = l.Init(logger) + require.Equal(t, "https", l.Protocol()) +} + +func TestHTTPHealthCheckInit(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + err := l.Init(logger) + require.NoError(t, err) + + require.NotNil(t, l.listen) + require.Equal(t, testAddr, l.listen.Addr) +} + +func TestHTTPHealthCheckServeAndClose(t *testing.T) { + // setup http stats listener + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + // call healthcheck + resp, err := http.Get("http://localhost" + testAddr + "/healthcheck") + require.NoError(t, err) + require.NotNil(t, resp) + + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + + _, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck") + require.Error(t, err) + <-o +} + +func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) { + // setup http stats listener + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + // make disallowed method type http request + resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody) + require.NoError(t, err) + require.NotNil(t, resp) + + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + + _, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody) + require.Error(t, err) + <-o +} + +func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ + TLSConfig: tlsConfigBasic, + }) + + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + l.Close(MockCloser) +} diff --git a/mqtt/listeners/http_sysinfo.go b/mqtt/listeners/http_sysinfo.go index 0454f32..63d7c28 100644 --- a/mqtt/listeners/http_sysinfo.go +++ b/mqtt/listeners/http_sysinfo.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -8,26 +8,24 @@ import ( "context" "encoding/json" "io" + "log/slog" "net/http" "sync" "sync/atomic" "time" "github.com/wind-c/comqtt/v2/mqtt/system" - - "github.com/rs/zerolog" ) // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. type HTTPStats struct { sync.RWMutex - id string // the internal id of the listener - address string // the network address to bind to - config *Config // configuration values for the listener - listen *http.Server // the http server - log *zerolog.Logger // server logger - sysInfo *system.Info // pointers to the server data - end uint32 // ensure the close methods are only called once + id string // the internal id of the listener + address string // the network address to bind to + config *Config // configuration values for the listener + listen *http.Server // the http server + sysInfo *system.Info // pointers to the server data + end uint32 // ensure the close methods are only called once handlers Handlers } @@ -48,7 +46,15 @@ func NewHTTP(id, address string, config *Config, sysInfo *system.Info, handlers // NewHTTPStats initialises and returns a new HTTP listener, listening on an address. func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats { - return NewHTTP(id, address, config, sysInfo, nil) + if config == nil { + config = new(Config) + } + return &HTTPStats{ + id: id, + address: address, + sysInfo: sysInfo, + config: config, + } } // ID returns the id of the listener. @@ -71,16 +77,9 @@ func (l *HTTPStats) Protocol() string { } // Init initializes the listener. -func (l *HTTPStats) Init(log *zerolog.Logger) error { - l.log = log - +func (l *HTTPStats) Init(_ *slog.Logger) error { mux := http.NewServeMux() - mux.HandleFunc("/mqtt/stats", l.jsonHandler) - - for path, handler := range l.handlers { - mux.HandleFunc(path, handler) - } - + mux.HandleFunc("/", l.jsonHandler) l.listen = &http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, @@ -98,9 +97,9 @@ func (l *HTTPStats) Init(log *zerolog.Logger) error { // Serve starts listening for new connections and serving responses. func (l *HTTPStats) Serve(establish EstablishFn) { if l.listen.TLSConfig != nil { - l.listen.ListenAndServeTLS("", "") + _ = l.listen.ListenAndServeTLS("", "") } else { - l.listen.ListenAndServe() + _ = l.listen.ListenAndServe() } } @@ -112,7 +111,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) { if atomic.CompareAndSwapUint32(&l.end, 0, 1) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - l.listen.Shutdown(ctx) + _ = l.listen.Shutdown(ctx) } closeClients(l.id) @@ -124,8 +123,8 @@ func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { out, err := json.MarshalIndent(info, "", "\t") if err != nil { - io.WriteString(w, err.Error()) + _, _ = io.WriteString(w, err.Error()) } - w.Write(out) + _, _ = w.Write(out) } diff --git a/mqtt/listeners/http_sysinfo_test.go b/mqtt/listeners/http_sysinfo_test.go index 34f29fc..76dbe3c 100644 --- a/mqtt/listeners/http_sysinfo_test.go +++ b/mqtt/listeners/http_sysinfo_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -42,14 +42,14 @@ func TestHTTPStatsTLSProtocol(t *testing.T) { TLSConfig: tlsConfigBasic, }, nil) - l.Init(nil) + _ = l.Init(logger) require.Equal(t, "https", l.Protocol()) } func TestHTTPStatsInit(t *testing.T) { sysInfo := new(system.Info) l := NewHTTPStats("t1", testAddr, nil, sysInfo) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) require.NotNil(t, l.sysInfo) @@ -65,7 +65,7 @@ func TestHTTPStatsServeAndClose(t *testing.T) { // setup http stats listener l := NewHTTPStats("t1", testAddr, nil, sysInfo) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -113,7 +113,7 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) { TLSConfig: tlsConfigBasic, }, sysInfo) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) diff --git a/mqtt/listeners/listeners.go b/mqtt/listeners/listeners.go index 0dd8f15..429f497 100644 --- a/mqtt/listeners/listeners.go +++ b/mqtt/listeners/listeners.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -9,7 +9,7 @@ import ( "net" "sync" - "github.com/rs/zerolog" + "log/slog" ) // Config contains configuration values for a listener. @@ -22,18 +22,18 @@ type Config struct { // EstablishFn is a callback function for establishing new clients. type EstablishFn func(id string, c net.Conn) error -// CloseFunc is a callback function for closing all listener clients. +// CloseFn is a callback function for closing all listener clients. type CloseFn func(id string) // Listener is an interface for network listeners. A network listener listens // for incoming client connections and adds them to the server. type Listener interface { - Init(*zerolog.Logger) error // open the network address - Serve(EstablishFn) // starting actively listening for new connections - ID() string // return the id of the listener - Address() string // the address of the listener - Protocol() string // the protocol in use by the listener - Close(CloseFn) // stop and close the listener + Init(*slog.Logger) error // open the network address + Serve(EstablishFn) // starting actively listening for new connections + ID() string // return the id of the listener + Address() string // the address of the listener + Protocol() string // the protocol in use by the listener + Close(CloseFn) // stop and close the listener } // Listeners contains the network listeners for the broker. diff --git a/mqtt/listeners/listeners_test.go b/mqtt/listeners/listeners_test.go index aabc9f2..63e841c 100644 --- a/mqtt/listeners/listeners_test.go +++ b/mqtt/listeners/listeners_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -11,14 +11,15 @@ import ( "testing" "time" - "github.com/rs/zerolog" + "log/slog" + "github.com/stretchr/testify/require" ) const testAddr = ":22222" var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) testCertificate = []byte(`-----BEGIN CERTIFICATE----- MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB diff --git a/mqtt/listeners/mock.go b/mqtt/listeners/mock.go index 778c8e5..826f80c 100644 --- a/mqtt/listeners/mock.go +++ b/mqtt/listeners/mock.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -9,7 +9,7 @@ import ( "net" "sync" - "github.com/rs/zerolog" + "log/slog" ) // MockEstablisher is a function signature which can be used in testing. @@ -53,7 +53,7 @@ func (l *MockListener) Serve(establisher EstablishFn) { } // Init initializes the listener. -func (l *MockListener) Init(log *zerolog.Logger) error { +func (l *MockListener) Init(log *slog.Logger) error { if l.ErrListen { return fmt.Errorf("listen failure") } diff --git a/mqtt/listeners/mock_test.go b/mqtt/listeners/mock_test.go index c2170ce..46aa922 100644 --- a/mqtt/listeners/mock_test.go +++ b/mqtt/listeners/mock_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -16,7 +16,7 @@ func TestMockEstablisher(t *testing.T) { _, w := net.Pipe() err := MockEstablisher("t1", w) require.NoError(t, err) - w.Close() + _ = w.Close() } func TestNewMockListener(t *testing.T) { @@ -86,7 +86,7 @@ func TestMockListenerServe(t *testing.T) { require.Equal(t, true, closed) <-o - mocked.Init(nil) + _ = mocked.Init(nil) } func TestMockListenerClose(t *testing.T) { diff --git a/mqtt/listeners/net.go b/mqtt/listeners/net.go new file mode 100644 index 0000000..fa4ef3d --- /dev/null +++ b/mqtt/listeners/net.go @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Jeroen Rinzema + +package listeners + +import ( + "net" + "sync" + "sync/atomic" + + "log/slog" +) + +// Net is a listener for establishing client connections on basic TCP protocol. +type Net struct { // [MQTT-4.2.0-1] + mu sync.Mutex + listener net.Listener // a net.Listener which will listen for new clients + id string // the internal id of the listener + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once +} + +// NewNet initialises and returns a listener serving incoming connections on the given net.Listener +func NewNet(id string, listener net.Listener) *Net { + return &Net{ + id: id, + listener: listener, + } +} + +// ID returns the id of the listener. +func (l *Net) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *Net) Address() string { + return l.listener.Addr().String() +} + +// Protocol returns the network of the listener. +func (l *Net) Protocol() string { + return l.listener.Addr().Network() +} + +// Init initializes the listener. +func (l *Net) Init(log *slog.Logger) error { + l.log = log + return nil +} + +// Serve starts waiting for new TCP connections, and calls the establish +// connection callback for any received. +func (l *Net) Serve(establish EstablishFn) { + for { + if atomic.LoadUint32(&l.end) == 1 { + return + } + + conn, err := l.listener.Accept() + if err != nil { + return + } + + if atomic.LoadUint32(&l.end) == 0 { + go func() { + err = establish(l.id, conn) + if err != nil { + l.log.Warn("", "error", err) + } + }() + } + } +} + +// Close closes the listener and any client connections. +func (l *Net) Close(closeClients CloseFn) { + l.mu.Lock() + defer l.mu.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + closeClients(l.id) + } + + if l.listener != nil { + err := l.listener.Close() + if err != nil { + return + } + } +} diff --git a/mqtt/listeners/net_test.go b/mqtt/listeners/net_test.go new file mode 100644 index 0000000..14a1ad6 --- /dev/null +++ b/mqtt/listeners/net_test.go @@ -0,0 +1,105 @@ +package listeners + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewNet(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "t1", l.id) +} + +func TestNetID(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "t1", l.ID()) +} + +func TestNetAddress(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, n.Addr().String(), l.Address()) +} + +func TestNetProtocol(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "tcp", l.Protocol()) +} + +func TestNetInit(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + l.Close(MockCloser) + require.NoError(t, err) +} + +func TestNetServeAndClose(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.True(t, closed) + <-o + + l.Close(MockCloser) // coverage: close closed + l.Serve(MockEstablisher) // coverage: serve closed +} + +func TestNetEstablishThenEnd(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + established := make(chan bool) + go func() { + l.Serve(func(id string, c net.Conn) error { + established <- true + return errors.New("ending") // return an error to exit immediately + }) + o <- true + }() + + time.Sleep(time.Millisecond) + _, _ = net.Dial("tcp", n.Addr().String()) + require.Equal(t, true, <-established) + l.Close(MockCloser) + <-o +} diff --git a/mqtt/listeners/tcp.go b/mqtt/listeners/tcp.go index ca25a67..1682734 100644 --- a/mqtt/listeners/tcp.go +++ b/mqtt/listeners/tcp.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -10,18 +10,18 @@ import ( "sync" "sync/atomic" - "github.com/rs/zerolog" + "log/slog" ) // TCP is a listener for establishing client connections on basic TCP protocol. type TCP struct { // [MQTT-4.2.0-1] sync.RWMutex - id string // the internal id of the listener - address string // the network address to bind to - listen net.Listener // a net.Listener which will listen for new clients - config *Config // configuration values for the listener - log *zerolog.Logger // server logger - end uint32 // ensure the close methods are only called once + id string // the internal id of the listener + address string // the network address to bind to + listen net.Listener // a net.Listener which will listen for new clients + config *Config // configuration values for the listener + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once } // NewTCP initialises and returns a new TCP listener, listening on an address. @@ -53,7 +53,7 @@ func (l *TCP) Protocol() string { } // Init initializes the listener. -func (l *TCP) Init(log *zerolog.Logger) error { +func (l *TCP) Init(log *slog.Logger) error { l.log = log var err error @@ -83,7 +83,7 @@ func (l *TCP) Serve(establish EstablishFn) { go func() { err = establish(l.id, conn) if err != nil { - l.log.Warn().Err(err).Send() + l.log.Warn("", "error", err) } }() } diff --git a/mqtt/listeners/tcp_test.go b/mqtt/listeners/tcp_test.go index 6e577ed..636c8ab 100644 --- a/mqtt/listeners/tcp_test.go +++ b/mqtt/listeners/tcp_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -35,35 +35,33 @@ func TestTCPProtocol(t *testing.T) { } func TestTCPProtocolTLS(t *testing.T) { - // pick a random port: - l := NewTCP("t1", ":0", &Config{ + l := NewTCP("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err := l.Init(&logger) - require.NoError(t, err) + + _ = l.Init(logger) + defer l.listen.Close() require.Equal(t, "tcp", l.Protocol()) - err = l.listen.Close() - require.NoError(t, err) } func TestTCPInit(t *testing.T) { - l := NewTCP("t1", ":0", nil) - err := l.Init(&logger) + l := NewTCP("t1", testAddr, nil) + err := l.Init(logger) l.Close(MockCloser) require.NoError(t, err) - l2 := NewTCP("t2", ":0", &Config{ + l2 := NewTCP("t2", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err = l2.Init(&logger) + err = l2.Init(logger) l2.Close(MockCloser) require.NoError(t, err) require.NotNil(t, l2.config.TLSConfig) } func TestTCPServeAndClose(t *testing.T) { - l := NewTCP("t1", ":0", nil) - err := l.Init(&logger) + l := NewTCP("t1", testAddr, nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -87,10 +85,10 @@ func TestTCPServeAndClose(t *testing.T) { } func TestTCPServeTLSAndClose(t *testing.T) { - l := NewTCP("t1", ":0", &Config{ + l := NewTCP("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err := l.Init(&logger) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -111,8 +109,8 @@ func TestTCPServeTLSAndClose(t *testing.T) { } func TestTCPEstablishThenEnd(t *testing.T) { - l := NewTCP("t1", ":0", nil) - err := l.Init(&logger) + l := NewTCP("t1", testAddr, nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -126,7 +124,7 @@ func TestTCPEstablishThenEnd(t *testing.T) { }() time.Sleep(time.Millisecond) - net.Dial("tcp", l.listen.Addr().String()) + _, _ = net.Dial("tcp", l.listen.Addr().String()) require.Equal(t, true, <-established) l.Close(MockCloser) <-o diff --git a/mqtt/listeners/unixsock.go b/mqtt/listeners/unixsock.go index 1ceaf99..5892fc9 100644 --- a/mqtt/listeners/unixsock.go +++ b/mqtt/listeners/unixsock.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: jason@zgwit.com package listeners @@ -10,17 +10,17 @@ import ( "sync" "sync/atomic" - "github.com/rs/zerolog" + "log/slog" ) // UnixSock is a listener for establishing client connections on basic UnixSock protocol. type UnixSock struct { sync.RWMutex - id string // the internal id of the listener. - address string // the network address to bind to. - listen net.Listener // a net.Listener which will listen for new clients. - log *zerolog.Logger // server logger - end uint32 // ensure the close methods are only called once. + id string // the internal id of the listener. + address string // the network address to bind to. + listen net.Listener // a net.Listener which will listen for new clients. + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once. } // NewUnixSock initialises and returns a new UnixSock listener, listening on an address. @@ -47,11 +47,11 @@ func (l *UnixSock) Protocol() string { } // Init initializes the listener. -func (l *UnixSock) Init(log *zerolog.Logger) error { +func (l *UnixSock) Init(log *slog.Logger) error { l.log = log var err error - _ = os.RemoveAll(l.address) + _ = os.Remove(l.address) l.listen, err = net.Listen("unix", l.address) return err } @@ -73,7 +73,7 @@ func (l *UnixSock) Serve(establish EstablishFn) { go func() { err = establish(l.id, conn) if err != nil { - l.log.Warn().Err(err).Send() + l.log.Warn("", "error", err) } }() } diff --git a/mqtt/listeners/unixsock_test.go b/mqtt/listeners/unixsock_test.go index d09f776..06ce24d 100644 --- a/mqtt/listeners/unixsock_test.go +++ b/mqtt/listeners/unixsock_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: jason@zgwit.com package listeners @@ -38,19 +38,19 @@ func TestUnixSockProtocol(t *testing.T) { func TestUnixSockInit(t *testing.T) { l := NewUnixSock("t1", testUnixAddr) - err := l.Init(&logger) + err := l.Init(logger) l.Close(MockCloser) require.NoError(t, err) l2 := NewUnixSock("t2", testUnixAddr) - err = l2.Init(&logger) + err = l2.Init(logger) l2.Close(MockCloser) require.NoError(t, err) } func TestUnixSockServeAndClose(t *testing.T) { l := NewUnixSock("t1", testUnixAddr) - err := l.Init(&logger) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -75,7 +75,7 @@ func TestUnixSockServeAndClose(t *testing.T) { func TestUnixSockEstablishThenEnd(t *testing.T) { l := NewUnixSock("t1", testUnixAddr) - err := l.Init(&logger) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -89,7 +89,7 @@ func TestUnixSockEstablishThenEnd(t *testing.T) { }() time.Sleep(time.Millisecond) - net.Dial("unix", l.listen.Addr().String()) + _, _ = net.Dial("unix", l.listen.Addr().String()) require.Equal(t, true, <-established) l.Close(MockCloser) <-o diff --git a/mqtt/listeners/websocket.go b/mqtt/listeners/websocket.go index 4e1f4d8..50715fc 100644 --- a/mqtt/listeners/websocket.go +++ b/mqtt/listeners/websocket.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -14,8 +14,9 @@ import ( "sync/atomic" "time" + "log/slog" + "github.com/gorilla/websocket" - "github.com/rs/zerolog" ) var ( @@ -29,8 +30,8 @@ type Websocket struct { // [MQTT-4.2.0-1] id string // the internal id of the listener address string // the network address to bind to config *Config // configuration values for the listener - listen *http.Server // an http server for serving websocket connections - log *zerolog.Logger // server logger + listen *http.Server // a http server for serving websocket connections + log *slog.Logger // server logger establish EstablishFn // the server's establish connection handler upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection. end uint32 // ensure the close methods are only called once @@ -75,7 +76,7 @@ func (l *Websocket) Protocol() string { } // Init initializes the listener. -func (l *Websocket) Init(log *zerolog.Logger) error { +func (l *Websocket) Init(log *slog.Logger) error { l.log = log mux := http.NewServeMux() @@ -101,7 +102,7 @@ func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) { err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c}) if err != nil { - l.log.Warn().Err(err).Send() + l.log.Warn("", "error", err) } } @@ -111,9 +112,9 @@ func (l *Websocket) Serve(establish EstablishFn) { l.establish = establish if l.listen.TLSConfig != nil { - l.listen.ListenAndServeTLS("", "") + _ = l.listen.ListenAndServeTLS("", "") } else { - l.listen.ListenAndServe() + _ = l.listen.ListenAndServe() } } @@ -125,7 +126,7 @@ func (l *Websocket) Close(closeClients CloseFn) { if atomic.CompareAndSwapUint32(&l.end, 0, 1) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - l.listen.Shutdown(ctx) + _ = l.listen.Shutdown(ctx) } closeClients(l.id) @@ -136,7 +137,7 @@ type wsConn struct { net.Conn c *websocket.Conn - // reader for the current message (may be nil) + // reader for the current message (can be nil) r io.Reader } diff --git a/mqtt/listeners/websocket_test.go b/mqtt/listeners/websocket_test.go index ee91f81..a2db1bb 100644 --- a/mqtt/listeners/websocket_test.go +++ b/mqtt/listeners/websocket_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -37,24 +37,24 @@ func TestWebsocketProtocol(t *testing.T) { require.Equal(t, "ws", l.Protocol()) } -func TestWebsocketProtocoTLS(t *testing.T) { +func TestWebsocketProtocolTLS(t *testing.T) { l := NewWebsocket("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) require.Equal(t, "wss", l.Protocol()) } -func TestWebsockeInit(t *testing.T) { +func TestWebsocketInit(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) require.Nil(t, l.listen) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) require.NotNil(t, l.listen) } func TestWebsocketServeAndClose(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(nil) + _ = l.Init(logger) o := make(chan bool) go func(o chan bool) { @@ -77,7 +77,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) { l := NewWebsocket("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -96,7 +96,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) { func TestWebsocketUpgrade(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(nil) + _ = l.Init(logger) e := make(chan bool) l.establish = func(id string, c net.Conn) error { @@ -110,12 +110,12 @@ func TestWebsocketUpgrade(t *testing.T) { require.Equal(t, true, <-e) s.Close() - ws.Close() + _ = ws.Close() } func TestWebsocketConnectionReads(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(nil) + _ = l.Init(nil) recv := make(chan []byte) l.establish = func(id string, c net.Conn) error { @@ -151,5 +151,5 @@ func TestWebsocketConnectionReads(t *testing.T) { require.Equal(t, pkt, got) s.Close() - ws.Close() + _ = ws.Close() } diff --git a/mqtt/packets/codec.go b/mqtt/packets/codec.go index 029cfa7..152d777 100644 --- a/mqtt/packets/codec.go +++ b/mqtt/packets/codec.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/codec_test.go b/mqtt/packets/codec_test.go index 8b10126..9129721 100644 --- a/mqtt/packets/codec_test.go +++ b/mqtt/packets/codec_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/codes.go b/mqtt/packets/codes.go index 7e314de..5af1b74 100644 --- a/mqtt/packets/codes.go +++ b/mqtt/packets/codes.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -126,6 +126,7 @@ var ( ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"} ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} + ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."} // MQTTv3 specific bytes. Err3UnsupportedProtocolVersion = Code{Code: 0x01} diff --git a/mqtt/packets/codes_test.go b/mqtt/packets/codes_test.go index 694f47e..aed8e57 100644 --- a/mqtt/packets/codes_test.go +++ b/mqtt/packets/codes_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -19,7 +19,7 @@ func TestCodesString(t *testing.T) { require.Equal(t, "test", c.String()) } -func TestCodesErrorr(t *testing.T) { +func TestCodesError(t *testing.T) { c := Code{ Reason: "error", Code: 0x1, diff --git a/mqtt/packets/fixedheader.go b/mqtt/packets/fixedheader.go index ddf68ca..eb20451 100644 --- a/mqtt/packets/fixedheader.go +++ b/mqtt/packets/fixedheader.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/fixedheader_test.go b/mqtt/packets/fixedheader_test.go index 8f7acf4..fe8c497 100644 --- a/mqtt/packets/fixedheader_test.go +++ b/mqtt/packets/fixedheader_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/packets.go b/mqtt/packets/packets.go index e53fe12..ff5930b 100644 --- a/mqtt/packets/packets.go +++ b/mqtt/packets/packets.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -14,7 +14,7 @@ import ( "sync" ) -// All of the valid packet types and their packet identifier. +// All valid packet types and their packet identifiers. const ( Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets. Connect // 1 @@ -37,9 +37,9 @@ const ( var ( // ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification. - ErrNoValidPacketAvailable error = errors.New("no valid packet available") + ErrNoValidPacketAvailable = errors.New("no valid packet available") - // PacketNames is a map of packet bytes to human readable names, for easier debugging. + // PacketNames is a map of packet bytes to human-readable names, for easier debugging. PacketNames = map[byte]string{ 0: "Reserved", 1: "Connect", @@ -272,28 +272,28 @@ func (s Subscription) Merge(n Subscription) Subscription { } // encode encodes a subscription and properties into bytes. -func (p Subscription) encode() byte { +func (s Subscription) encode() byte { var flag byte - flag |= p.Qos + flag |= s.Qos - if p.NoLocal { + if s.NoLocal { flag |= 1 << 2 } - if p.RetainAsPublished { + if s.RetainAsPublished { flag |= 1 << 3 } - flag |= p.RetainHandling << 4 + flag |= s.RetainHandling << 4 return flag } // decode decodes subscription bytes into a subscription struct. -func (p *Subscription) decode(b byte) { - p.Qos = b & 3 // byte - p.NoLocal = 1&(b>>2) > 0 // bool - p.RetainAsPublished = 1&(b>>3) > 0 // bool - p.RetainHandling = 3 & (b >> 4) // byte +func (s *Subscription) decode(b byte) { + s.Qos = b & 3 // byte + s.NoLocal = 1&(b>>2) > 0 // bool + s.RetainAsPublished = 1&(b>>3) > 0 // bool + s.RetainHandling = 3 & (b >> 4) // byte } // ConnectEncode encodes a connect packet. @@ -343,7 +343,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -505,7 +505,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -548,7 +548,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -619,7 +619,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -707,7 +707,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -844,7 +844,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -901,7 +901,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -996,7 +996,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -1049,7 +1049,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -1109,7 +1109,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } diff --git a/mqtt/packets/packets_test.go b/mqtt/packets/packets_test.go index c08ff10..1e18f1f 100644 --- a/mqtt/packets/packets_test.go +++ b/mqtt/packets/packets_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -150,7 +150,7 @@ func TestPacketEncode(t *testing.T) { } pk := new(Packet) - copier.Copy(pk, wanted.Packet) + _ = copier.Copy(pk, wanted.Packet) require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc) pk.Mods.AllowResponseInfo = true @@ -218,7 +218,7 @@ func TestPacketDecode(t *testing.T) { pk := &Packet{FixedHeader: FixedHeader{Type: pkt}} pk.Mods.AllowResponseInfo = true - pk.FixedHeader.Decode(wanted.RawBytes[0]) + _ = pk.FixedHeader.Decode(wanted.RawBytes[0]) if len(wanted.RawBytes) > 0 { pk.FixedHeader.Remaining = int(wanted.RawBytes[1]) } diff --git a/mqtt/packets/properties.go b/mqtt/packets/properties.go index ea77e2b..1fc02fd 100644 --- a/mqtt/packets/properties.go +++ b/mqtt/packets/properties.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -77,7 +77,7 @@ type UserProperty struct { // [MQTT-1.5.7-1] Val string `json:"v"` } -// Properties contains all of the mqtt v5 properties available for a packet. +// Properties contains all mqtt v5 properties available for a packet. // Some properties have valid values of 0 or not-present. In this case, we opt for // property flags to indicate the usage of property. // Refer to mqtt v5 2.2.2.2 Property spec for more information. @@ -355,7 +355,7 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) { } encodeLength(b, int64(buf.Len())) - buf.WriteTo(b) // [MQTT-3.1.3-10] + _, _ = buf.WriteTo(b) // [MQTT-3.1.3-10] } // Decode decodes property bytes into a properties struct. diff --git a/mqtt/packets/properties_test.go b/mqtt/packets/properties_test.go index 8d326ba..b0a2f10 100644 --- a/mqtt/packets/properties_test.go +++ b/mqtt/packets/properties_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/tpackets.go b/mqtt/packets/tpackets.go index 8c21dbd..267721e 100644 --- a/mqtt/packets/tpackets.go +++ b/mqtt/packets/tpackets.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -40,7 +40,6 @@ const ( TConnectMqtt5 TConnectMqtt5LWT TConnectClean - TConnectCleanLWT TConnectUserPass TConnectUserPassLWT TConnectMalProtocolName @@ -61,7 +60,6 @@ const ( TConnectInvalidProtocolVersion2 TConnectInvalidReservedBit TConnectInvalidClientIDTooLong - TConnectInvalidPasswordNoUsername TConnectInvalidFlagNoUsername TConnectInvalidFlagNoPassword TConnectInvalidUsernameNoFlag @@ -131,12 +129,14 @@ const ( TPublishSpecDenySysTopic TPuback TPubackMqtt5 + TPubackMqtt5NotAuthorized TPubackMalPacketID TPubackMalProperties TPubackUnexpectedError TPubrec TPubrecMqtt5 TPubrecMqtt5IDInUse + TPubrecMqtt5NotAuthorized TPubrecMalPacketID TPubrecMalProperties TPubrecMalReasonCode @@ -184,7 +184,6 @@ const ( TUnsubscribe TUnsubscribeMany TUnsubscribeMqtt5 - TUnsubscribeDropProperties TUnsubscribeMalPacketID TUnsubscribeMalTopicName TUnsubscribeMalProperties @@ -202,7 +201,6 @@ const ( TDisconnect TDisconnectTakeover TDisconnectMqtt5 - TDisconnectNormalMqtt5 TDisconnectSecondConnect TDisconnectReceiveMaximum TDisconnectDropProperties @@ -2274,6 +2272,40 @@ var TPacketData = map[byte]TPacketCases{ }, }, }, + { + Case: TPubackMqtt5NotAuthorized, + Desc: "QOS 1 publish not authorized mqtt5", + Primary: true, + RawBytes: []byte{ + Puback << 4, 37, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrNotAuthorized.Code, // Reason Code + 33, // Properties Length + 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u', + 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 31, + }, + PacketID: 7, + ReasonCode: ErrNotAuthorized.Code, + Properties: Properties{ + ReasonString: ErrNotAuthorized.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, { Case: TPubackUnexpectedError, Desc: "unexpected error", @@ -2412,6 +2444,40 @@ var TPacketData = map[byte]TPacketCases{ }, }, }, + { + Case: TPubrecMqtt5NotAuthorized, + Desc: "QOS 2 publish not authorized mqtt5", + Primary: true, + RawBytes: []byte{ + Pubrec << 4, 37, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrNotAuthorized.Code, // Reason Code + 33, // Properties Length + 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u', + 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubrec, + Remaining: 31, + }, + PacketID: 7, + ReasonCode: ErrNotAuthorized.Code, + Properties: Properties{ + ReasonString: ErrNotAuthorized.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, { Case: TPubrecMalReasonCode, Desc: "malformed reason code", diff --git a/mqtt/packets/tpackets_test.go b/mqtt/packets/tpackets_test.go index c50bb55..8114207 100644 --- a/mqtt/packets/tpackets_test.go +++ b/mqtt/packets/tpackets_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/server.go b/mqtt/server.go index 30fa937..f74b1e0 100644 --- a/mqtt/server.go +++ b/mqtt/server.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co, wind // package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility. @@ -22,12 +22,14 @@ import ( "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/mqtt/system" - "github.com/rs/zerolog" + "log/slog" ) const ( - Version = "2.3.0" // the current server version. + Version = "2.4.0" // the current server version. defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes + LocalListener = "local" + InlineClientId = "inline" ) var ( @@ -36,7 +38,7 @@ var ( MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over ReceiveMaximum: 1024, // maximum number of concurrent qos messages per client - MaximumQos: 2, // maxmimum qos value available to clients + MaximumQos: 2, // maximum qos value available to clients RetainAvailable: 1, // retain messages is available MaximumPacketSize: 0, // no maximum packet size TopicAliasMaximum: math.MaxUint16, // maximum topic alias value @@ -47,15 +49,16 @@ var ( MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client } - ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists. - ErrConnectionClosed = errors.New("connection not open") // connection is closed + ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists + ErrConnectionClosed = errors.New("connection not open") // connection is closed + ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default ) // Capabilities indicates the capabilities and features provided by the server. type Capabilities struct { MaximumMessageExpiryInterval int64 `yaml:"maximum-message-expiry-interval"` - MaximumSessionExpiryInterval uint32 `yaml:"maximum-session-expiry-interval"` MaximumClientWritesPending int32 `yaml:"maximum-client-writes-pending"` + MaximumSessionExpiryInterval uint32 `yaml:"maximum-session-expiry-interval"` MaximumPacketSize uint32 `yaml:"maximum-packet-size"` maximumPacketID uint32 // unexported, used for testing only ReceiveMaximum uint16 `yaml:"receive-maximum"` @@ -85,36 +88,44 @@ type Options struct { // server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 Capabilities *Capabilities + // ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. + ClientNetWriteBufferSize int + + // ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. + ClientNetReadBufferSize int + // Logger specifies a custom configured implementation of zerolog to override // the servers default logger configuration. If you wish to change the log level, // of the default logger, you can do so by setting // server := mqtt.New(nil) - // l := server.Log.Level(zerolog.DebugLevel) - // server.Log = &l - Logger *zerolog.Logger - - // ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. - ClientNetWriteBufferSize int `yaml:"client-write-buffer-size"` - - // ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. - ClientNetReadBufferSize int `yaml:"client-read-buffer-size"` + // level := new(slog.LevelVar) + // server.Slog = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + // Level: level, + // })) + // level.Set(slog.LevelDebug) + Logger *slog.Logger // SysTopicResendInterval specifies the interval between $SYS topic updates in seconds. - SysTopicResendInterval int64 `yaml:"sys-topic-resend-interval"` + SysTopicResendInterval int64 + + // Enable Inline client to allow direct subscribing and publishing from the parent codebase, + // with negligible performance difference (disabled by default to prevent confusion in statistics). + InlineClient bool } // Server is an MQTT broker server. It should be created with server.New() // in order to ensure all the internal fields are correctly populated. type Server struct { - Options *Options // configurable server options - Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections - Clients *Clients // clients known to the broker - Topics *TopicsIndex // an index of topic filter subscriptions and retained messages - Info *system.Info // values about the server commonly known as $SYS topics - loop *loop // loop contains tickers for the system event loop - done chan bool // indicate that the server is ending - Log *zerolog.Logger // minimal no-alloc logger - hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage. + Options *Options // configurable server options + Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections + Clients *Clients // clients known to the broker + Topics *TopicsIndex // an index of topic filter subscriptions and retained messages + Info *system.Info // values about the server commonly known as $SYS topics + loop *loop // loop contains tickers for the system event loop + done chan bool // indicate that the server is ending + Log *slog.Logger // minimal no-alloc logger + hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage + inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish } // loop contains interval tickers for the system events loop. @@ -123,16 +134,16 @@ type loop struct { clientExpiry *time.Ticker // interval ticker for cleaning expired clients inflightExpiry *time.Ticker // interval ticker for cleaning up expired inflight messages retainedExpiry *time.Ticker // interval ticker for cleaning retained messages - willDelaySend *time.Ticker // interval ticker for sending will messages with a delay + willDelaySend *time.Ticker // interval ticker for sending Will Messages with a delay willDelayed *packets.Packets // activate LWT packets which will be sent after a delay } // ops contains server values which can be propagated to other structs. type ops struct { - options *Options // a pointer to the server options and capabilities, for referencing in clients - info *system.Info // pointers to server system info - hooks *Hooks // pointer to the server hooks - log *zerolog.Logger // a structured logger for the client + options *Options // a pointer to the server options and capabilities, for referencing in clients + info *system.Info // pointers to server system info + hooks *Hooks // pointer to the server hooks + log *slog.Logger // a structured logger for the client } // New returns a new instance of comqtt broker. Optional parameters @@ -168,6 +179,11 @@ func New(opts *Options) *Server { }, } + if s.Options.InlineClient { + s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true) + s.Clients.Add(s.inlineClient) + } + return s } @@ -192,8 +208,8 @@ func (o *Options) ensureDefaults() { } if o.Logger == nil { - log := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.InfoLevel).Output(zerolog.ConsoleWriter{Out: os.Stderr}) - o.Logger = &log + log := slog.New(slog.NewTextHandler(os.Stdout, nil)) + o.Logger = log } } @@ -227,12 +243,12 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) // AddHook attaches a new Hook to the server. Ideally, this should be called // before the server is started with s.Serve(). func (s *Server) AddHook(hook Hook, config any) error { - nl := s.Log.With().Str("hook", hook.ID()).Logger() - hook.SetOpts(&nl, &HookOptions{ + nl := s.Log.With("hook", hook.ID()) + hook.SetOpts(nl, &HookOptions{ Capabilities: s.Options.Capabilities, }) - s.Log.Info().Str("hook", hook.ID()).Msg("added hook") + s.Log.Info("added hook", "hook", hook.ID()) return s.hooks.Add(hook, config) } @@ -242,23 +258,23 @@ func (s *Server) AddListener(l listeners.Listener) error { return ErrListenerIDExists } - nl := s.Log.With().Str("listener", l.ID()).Logger() - err := l.Init(&nl) + nl := s.Log.With(slog.String("listener", l.ID())) + err := l.Init(nl) if err != nil { return err } s.Listeners.Add(l) - s.Log.Info().Str("id", l.ID()).Str("protocol", l.Protocol()).Str("address", l.Address()).Msg("attached listener") + s.Log.Info("attached listener", "id", l.ID(), "protocol", l.Protocol(), "address", l.Address()) return nil } // Serve starts the event loops responsible for establishing client connections // on all attached listeners, publishing the system topics, and starting all hooks. func (s *Server) Serve() error { - //s.Log.Info().Str("version", Version).Msg("comqtt starting") - defer s.Log.Info().Msg("comqtt server started") + //s.Log.Info("version", Version).Msg("comqtt starting") + defer s.Log.Info("comqtt server started") if s.hooks.Provides( StoredClients, @@ -283,8 +299,8 @@ func (s *Server) Serve() error { // eventLoop loops forever, running various server housekeeping methods at different intervals. func (s *Server) eventLoop() { - s.Log.Debug().Msg("system event loop started") - defer s.Log.Debug().Msg("system event loop halted") + s.Log.Debug("system event loop started") + defer s.Log.Debug("system event loop halted") for { select { @@ -375,8 +391,8 @@ func (s *Server) attachClient(cl *Client, listener string) error { } else { cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] } + s.Log.Debug("client disconnected", "error", err, "client", cl.ID, "remote", cl.Net.Remote, "listener", listener) - s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected") expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) s.hooks.OnDisconnect(cl, err, expire) @@ -418,10 +434,10 @@ func (s *Server) receivePacket(cl *Client, pk packets.Packet) error { if code, ok := err.(packets.Code); ok && cl.Properties.ProtocolVersion == 5 && code.Code >= packets.ErrUnspecifiedError.Code { - s.DisconnectClient(cl, code) + _ = s.DisconnectClient(cl, code) } - s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("pk", pk).Msg("error processing packet") + s.Log.Warn("error processing packet", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "pk", pk) return err } @@ -456,7 +472,7 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code { // session is abandoned. func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok { - s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] + _ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] s.UnsubscribeClient(existing) existing.ClearInflights(math.MaxInt64, 0) @@ -487,10 +503,8 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { // from increasing memory usage by inflights + subs * client-id. s.UnsubscribeClient(existing) existing.ClearInflights(math.MaxInt64, 0) - s.Log.Debug().Str("client", cl.ID). - Str("old_remote", existing.Net.Remote). - Str("new_remote", cl.Net.Remote). - Msg("session taken over") + + s.Log.Debug("session taken over", "client", cl.ID, "old_remote", existing.Net.Remote, "new_remote", cl.Net.Remote) cl.InheritWay = InheritWayLocal return true // [MQTT-3.2.2-3] @@ -676,13 +690,16 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error { }) } -// Publish publishes a publish packet into the broker as if it were sent from the speicfied client. +// Publish publishes a publish packet into the broker as if it were sent from the specified client. // This is a convenience function which wraps InjectPacket. As such, this method can publish packets // to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the // outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { - cl := s.NewClient(nil, "local", "inline", true) - return s.InjectPacket(cl, packets.Packet{ + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + return s.InjectPacket(s.inlineClient, packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, Qos: qos, @@ -694,6 +711,75 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er }) } +// Subscribe adds an inline subscription for the specified topic filter and subscription identifier +// with the provided handler function. +func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + if handler == nil { + return packets.ErrInlineSubscriptionHandlerInvalid + } + + if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + + subscription := packets.Subscription{ + Identifier: subscriptionId, + Filter: filter, + } + + pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{ // subscribe like a normal client. + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Subscribe}, + Filters: packets.Subscriptions{subscription}, + }) + + inlineSubscription := InlineSubscription{ + Subscription: subscription, + Handler: handler, + } + + _, count := s.Topics.InlineSubscribe(inlineSubscription) + s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code}, []int{count}) + + // Handling retained messages. + for _, pkv := range s.Topics.Messages(filter) { // [MQTT-3.8.4-4] + handler(s.inlineClient, inlineSubscription.Subscription, pkv) + } + return nil +} + +// Unsubscribe removes an inline subscription for the specified subscription and topic filter. +// It allows you to unsubscribe a specific subscription from the internal subscription +// associated with the given topic filter. +func (s *Server) Unsubscribe(filter string, subscriptionId int) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + + pk := s.hooks.OnUnsubscribe(s.inlineClient, packets.Packet{ + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, + Filters: packets.Subscriptions{ + { + Identifier: subscriptionId, + Filter: filter, + }, + }, + }) + + _, count := s.Topics.InlineUnsubscribe(subscriptionId, filter) + s.hooks.OnUnsubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code}, []int{count}) + return nil +} + // InjectPacket injects a packet into the broker as if it were sent from the specified client. // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { @@ -723,7 +809,21 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { } if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) { - return nil + if pk.FixedHeader.Qos == 0 { + return nil + } + + if cl.Properties.ProtocolVersion != 5 { + return s.DisconnectClient(cl, packets.ErrNotAuthorized) + } + + ackType := packets.Puback + if pk.FixedHeader.Qos == 2 { + ackType = packets.Pubrec + } + + ack := s.buildAck(pk.PacketID, ackType, 0, pk.Properties, packets.ErrNotAuthorized) + return cl.WritePacket(ack) } pk.Origin = cl.ID @@ -746,7 +846,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { } if pk.FixedHeader.Qos > s.Options.Capabilities.MaximumQos { - pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce Qos based on server max qos capability + pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce qos based on server max qos capability } pkx, err := s.hooks.OnPublish(cl, pk) @@ -768,7 +868,10 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { s.retainMessage(cl, pk) } - if pk.FixedHeader.Qos == 0 { + // If it's inlineClient, it can't handle PUBREC and PUBREL. + // When it publishes a package with a qos > 0, the server treats + // the package as qos=0, and the client receives it as qos=1 or 2. + if pk.FixedHeader.Qos == 0 || cl.Net.Inline { s.PublishToSubscribers(pk) s.hooks.OnPublished(cl, pk) return nil @@ -841,11 +944,15 @@ func (s *Server) PublishToSubscribers(pk packets.Packet) { subscribers.MergeSharedSelected() } + for _, inlineSubscription := range subscribers.InlineSubscriptions { + inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk) + } + for id, subs := range subscribers.Subscriptions { if cl, ok := s.Clients.Get(id); ok { _, err := s.publishToClient(cl, subs, pk) if err != nil { - s.Log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet") + s.Log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk) } } } @@ -857,6 +964,9 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet } out := pk.Copy(false) + if !s.hooks.OnACLCheck(cl, pk.TopicName, false) { + return out, packets.ErrNotAuthorized + } if !sub.FwdRetainedFlag && ((cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished) || cl.Properties.ProtocolVersion < 5) { // ![MQTT-3.3.1-13] [v3 MQTT-3.3.1-9] out.FixedHeader.Retain = false // [MQTT-3.3.1-12] } @@ -892,7 +1002,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1] if err != nil { s.hooks.OnPacketIDExhausted(cl, pk) - s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted") + s.Log.Warn("packet ids exhausted", "error", err, "client", cl.ID, "listener", cl.Net.Listener) return out, packets.ErrQuotaExceeded } @@ -943,7 +1053,7 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4] _, err := s.publishToClient(cl, sub, pkv) if err != nil { - s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message") + s.Log.Debug("failed to publish retained message", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "packet", pkv) continue } s.hooks.OnRetainPublished(cl, pkv) @@ -1182,12 +1292,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error { func (s *Server) UnsubscribeClient(cl *Client) { i := 0 filterMap := cl.State.Subscriptions.GetAll() + + for k := range filterMap { + cl.State.Subscriptions.Delete(k) + } + + if atomic.LoadUint32(&cl.State.isTakenOver) == 1 { + return + } + length := len(filterMap) filters := make([]packets.Subscription, length) reasonCodes := make([]byte, length) counts := make([]int, length) // An array of the number of subscribers for the same filter for k, v := range filterMap { - cl.State.Subscriptions.Delete(k) q, count := s.Topics.Unsubscribe(k, cl.ID) if q { atomic.AddInt64(&s.Info.Subscriptions, -1) @@ -1200,7 +1318,7 @@ func (s *Server) UnsubscribeClient(cl *Client) { i++ } - s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters}, reasonCodes, counts) + s.hooks.OnUnsubscribed(cl, packets.Packet{FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, Filters: filters}, reasonCodes, counts) } // processAuth processes an Auth packet. @@ -1318,7 +1436,7 @@ func (s *Server) Close() error { s.hooks.OnStopped() s.hooks.Stop() - s.Log.Info().Msg("comqtt server stopped") + s.Log.Info("comqtt server stopped") return nil } @@ -1326,7 +1444,7 @@ func (s *Server) Close() error { func (s *Server) closeListenerClients(listener string) { clients := s.Clients.GetByListener(listener) for _, cl := range clients { - s.DisconnectClient(cl, packets.ErrServerShuttingDown) + _ = s.DisconnectClient(cl, packets.ErrServerShuttingDown) } } @@ -1377,9 +1495,7 @@ func (s *Server) readStore() error { return fmt.Errorf("failed to load clients; %w", err) } s.loadClients(clients) - s.Log.Debug(). - Int("len", len(clients)). - Msg("loaded clients from store") + s.Log.Debug("loaded clients from store", "len", len(clients)) } if s.hooks.Provides(StoredSubscriptions) { @@ -1388,9 +1504,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load subscriptions; %w", err) } s.loadSubscriptions(subs) - s.Log.Debug(). - Int("len", len(subs)). - Msg("loaded subscriptions from store") + s.Log.Debug("loaded subscriptions from store", "len", len(subs)) } if s.hooks.Provides(StoredInflightMessages) { @@ -1399,9 +1513,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load inflight; %w", err) } s.loadInflight(inflight) - s.Log.Debug(). - Int("len", len(inflight)). - Msg("loaded inflights from store") + s.Log.Debug("loaded inflights from store", "len", len(inflight)) } if s.hooks.Provides(StoredRetainedMessages) { @@ -1410,9 +1522,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load retained; %w", err) } s.loadRetained(retained) - s.Log.Debug(). - Int("len", len(retained)). - Msg("loaded retained messages from store") + s.Log.Debug("loaded retained messages from store", "len", len(retained)) } if s.hooks.Provides(StoredSysInfo) { @@ -1421,8 +1531,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load server info; %w", err) } s.loadServerInfo(sysInfo.Info) - s.Log.Debug(). - Msg("loaded $SYS info from store") + s.Log.Debug("loaded $SYS info from store") } return nil diff --git a/mqtt/server_test.go b/mqtt/server_test.go index 3234911..3171bb0 100644 --- a/mqtt/server_test.go +++ b/mqtt/server_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -8,9 +8,10 @@ import ( "bytes" "encoding/binary" "io" + "log/slog" "net" - "os" "strconv" + "sync" "sync/atomic" "testing" "time" @@ -20,11 +21,10 @@ import ( "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/mqtt/system" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) -var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) +var logger = slog.New(slog.NewTextHandler(io.Discard, nil)) type ProtocolTest []struct { protocolVersion byte @@ -37,6 +37,11 @@ type AllowHook struct { HookBase } +func (h *AllowHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + func (h *AllowHook) ID() string { return "allow-all-auth" } @@ -48,11 +53,36 @@ func (h *AllowHook) Provides(b byte) bool { func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true } func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool { return true } +type DenyHook struct { + HookBase +} + +func (h *DenyHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + +func (h *DenyHook) ID() string { + return "deny-all-auth" +} + +func (h *DenyHook) Provides(b byte) bool { + return bytes.Contains([]byte{OnConnectAuthenticate, OnACLCheck}, []byte{b}) +} + +func (h *DenyHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return false } +func (h *DenyHook) OnACLCheck(cl *Client, topic string, write bool) bool { return false } + type DelayHook struct { HookBase DisconnectDelay time.Duration } +func (h *DelayHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + func (h *DelayHook) ID() string { return "delay-hook" } @@ -69,12 +99,24 @@ func newServer() *Server { cc := *DefaultServerCapabilities cc.MaximumMessageExpiryInterval = 0 cc.ReceiveMaximum = 0 + s := New(&Options{ + Logger: logger, + Capabilities: &cc, + }) + _ = s.AddHook(new(AllowHook), nil) + return s +} +func newServerWithInlineClient() *Server { + cc := *DefaultServerCapabilities + cc.MaximumMessageExpiryInterval = 0 + cc.ReceiveMaximum = 0 s := New(&Options{ - Logger: &logger, + Logger: logger, Capabilities: &cc, + InlineClient: true, }) - s.AddHook(new(AllowHook), nil) + _ = s.AddHook(new(AllowHook), nil) return s } @@ -106,6 +148,16 @@ func TestNew(t *testing.T) { require.NotNil(t, s.hooks) require.NotNil(t, s.hooks.Log) require.NotNil(t, s.done) + require.Nil(t, s.inlineClient) + require.Equal(t, 0, s.Clients.Len()) +} + +func TestNewWithInlineClient(t *testing.T) { + s := New(&Options{ + InlineClient: true, + }) + require.NotNil(t, s.inlineClient) + require.Equal(t, 1, s.Clients.Len()) } func TestNewNilOpts(t *testing.T) { @@ -116,7 +168,7 @@ func TestNewNilOpts(t *testing.T) { func TestServerNewClient(t *testing.T) { s := New(nil) - s.Log = &logger + s.Log = logger r, _ := net.Pipe() cl := s.NewClient(r, "testing", "test", false) @@ -143,7 +195,8 @@ func TestServerNewClientInline(t *testing.T) { func TestServerAddHook(t *testing.T) { s := New(nil) - s.Log = &logger + + s.Log = logger require.NotNil(t, s) require.Equal(t, int64(0), s.hooks.Len()) @@ -247,8 +300,8 @@ func TestServerReadConnectionPacket(t *testing.T) { }() go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _ = r.Close() }() require.Equal(t, *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet, <-o) @@ -268,8 +321,8 @@ func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) { }() go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes) + _ = r.Close() }() err := <-o @@ -285,8 +338,8 @@ func TestServerReadConnectionPacketBadPacketType(t *testing.T) { s.Clients.Add(cl) go func() { - r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes) + _ = r.Close() }() _, err := s.readConnectionPacket(cl) @@ -302,8 +355,8 @@ func TestServerReadConnectionPacketBadPacket(t *testing.T) { s.Clients.Add(cl) go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes) + _ = r.Close() }() _, err := s.readConnectionPacket(cl) @@ -322,8 +375,8 @@ func TestEstablishConnection(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -336,14 +389,18 @@ func TestEstablishConnection(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Todo: + // s.Clients is already empty here. Is it necessary to check v.StopCause()? + + // for _, v := range s.Clients.GetAll() { + // require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect + // } require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() // client must be deleted on session close if Clean = true _, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet.Connect.ClientIdentifier) @@ -361,15 +418,15 @@ func TestEstablishConnectionAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestEstablishConnectionReadError(t *testing.T) { @@ -383,8 +440,8 @@ func TestEstablishConnectionReadError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes) - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error }() // receive the connack @@ -397,9 +454,11 @@ func TestEstablishConnectionReadError(t *testing.T) { err := <-o require.Error(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect ret := <-recv require.Equal(t, append( @@ -408,8 +467,8 @@ func TestEstablishConnectionReadError(t *testing.T) { ret, ) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() } func TestEstablishConnectionInheritExisting(t *testing.T) { @@ -432,9 +491,9 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect. - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the disconnect session takeover @@ -455,9 +514,11 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect connackPlusPacket := append( packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, @@ -467,8 +528,8 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover) time.Sleep(time.Microsecond * 100) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) @@ -478,12 +539,12 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { require.Empty(t, cl.State.Subscriptions.GetAll()) } -// See https://github.com/mochi-co/mqtt/issues/173 +// See https://github.com/mochi-mqtt/server/issues/173 func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { s := newServer() d := new(DelayHook) d.DisconnectDelay = time.Millisecond * 200 - s.AddHook(d, nil) + _ = s.AddHook(d, nil) defer s.Close() // Clean session, 0 session expiry interval @@ -508,7 +569,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { o1 <- err }() go func() { - w1.Write(cl1RawBytes) + _, _ = w1.Write(cl1RawBytes) }() // receive the first connack @@ -537,7 +598,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { go func() { x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:] x[19] = '.' // differentiate username bytes in debugging - w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) + _, _ = w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) }() // receive the second connack @@ -565,7 +626,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { require.NotEmpty(t, clp2.State.Subscriptions.GetAll()) require.Empty(t, clp1.State.Subscriptions.GetAll()) - w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) require.NoError(t, <-o2) } @@ -588,7 +649,7 @@ func TestEstablishConnectionResentPendingInflightsError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) }() go func() { @@ -623,8 +684,8 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the disconnect @@ -645,15 +706,17 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) @@ -662,7 +725,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { func TestEstablishConnectionBadAuthentication(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) defer s.Close() @@ -673,8 +736,8 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -690,13 +753,13 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) { require.ErrorIs(t, err, packets.ErrBadUsernameOrPassword) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackBadUsernamePasswordNoSession).RawBytes, <-recv) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() } func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) defer s.Close() @@ -707,15 +770,15 @@ func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionInvalidConnect(t *testing.T) { @@ -728,8 +791,8 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -745,10 +808,10 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { require.ErrorIs(t, packets.ErrProtocolViolationReservedBit, err) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackProtocolViolationNoSession).RawBytes, <-recv) - r.Close() + _ = r.Close() } -// See https://github.com/mochi-co/mqtt/issues/178 +// See https://github.com/mochi-mqtt/server/issues/178 func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { s := newServer() @@ -759,8 +822,8 @@ func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack error @@ -772,7 +835,7 @@ func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { err := <-o require.NoError(t, err) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { @@ -785,15 +848,15 @@ func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionBadPacket(t *testing.T) { @@ -806,15 +869,15 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() err := <-o require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationRequireFirstConnect) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionOnConnectError(t *testing.T) { @@ -831,14 +894,14 @@ func TestServerEstablishConnectionOnConnectError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) }() err = <-o require.Error(t, err) require.ErrorIs(t, err, errTestHook) - r.Close() + _ = r.Close() } func TestServerSendConnack(t *testing.T) { @@ -852,7 +915,7 @@ func TestServerSendConnack(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -867,7 +930,7 @@ func TestServerSendConnackFailureReason(t *testing.T) { go func() { err := s.SendConnack(cl, packets.ErrUnspecifiedError, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -884,7 +947,7 @@ func TestServerSendConnackWithServerKeepalive(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -963,7 +1026,7 @@ func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, false, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1042,7 +1105,7 @@ func TestServerProcessPacketPingreq(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1071,7 +1134,7 @@ func TestServerProcessPacketPublishInvalid(t *testing.T) { func TestInjectPacketPublishAndReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1096,17 +1159,18 @@ func TestInjectPacketPublishAndReceive(t *testing.T) { go func() { err := s.InjectPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w1.Close() + _ = w1.Close() time.Sleep(time.Millisecond * 10) - w2.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) } -func TestServerDirectPublishAndReceive(t *testing.T) { - s := newServer() - s.Serve() +func TestServerPublishAndReceive(t *testing.T) { + s := newServerWithInlineClient() + + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1132,14 +1196,22 @@ func TestServerDirectPublishAndReceive(t *testing.T) { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) require.NoError(t, err) - w1.Close() + _ = w1.Close() time.Sleep(time.Millisecond * 10) - w2.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) } +func TestServerPublishNoInlineClient(t *testing.T) { + s := newServer() + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + func TestInjectPacketError(t *testing.T) { s := newServer() defer s.Close() @@ -1164,7 +1236,7 @@ func TestInjectPacketPublishInvalidTopic(t *testing.T) { func TestServerProcessPacketPublishAndReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1190,8 +1262,8 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) { err := s.processPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) require.NoError(t, err) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -1256,7 +1328,7 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1266,15 +1338,15 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) { require.Equal(t, int32(4), cl.State.Inflight.sendQuota) } -func TestServerProcessPacketPublishAckFailure(t *testing.T) { +func TestServerProcessPublishAckFailure(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, _, w := newTestClient() s.Clients.Add(cl) - w.Close() + _ = w.Close() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) @@ -1291,7 +1363,7 @@ func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) { cl, _, w := newTestClient() cl.Properties.ProtocolVersion = 5 s.Clients.Add(cl) - w.Close() + _ = w.Close() err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.Error(t, err) @@ -1305,7 +1377,7 @@ func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { hook.err = packets.ErrPayloadFormatInvalid err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1315,7 +1387,7 @@ func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1330,7 +1402,7 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { hook.err = packets.CodeSuccessIgnore err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1354,8 +1426,8 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() - w2.Close() + _ = w.Close() + _ = w2.Close() }() buf, err := io.ReadAll(r) @@ -1367,7 +1439,7 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1379,7 +1451,7 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrReceiveMaximum) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1389,22 +1461,107 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { func TestServerProcessPublishInvalidTopic(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, _, _ := newTestClient() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet) - require.NoError(t, err) // $SYS topics should be ignored? + require.NoError(t, err) // $SYS Topics should be ignored? } func TestServerProcessPublishACLCheckDeny(t *testing.T) { - s := New(&Options{ - Logger: &logger, - }) - s.Serve() - defer s.Close() - cl, _, _ := newTestClient() - err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) - require.NoError(t, err) // ACL check fails silently + tt := []struct { + name string + protocolVersion byte + pk packets.Packet + expectErr error + expectReponse []byte + expectDisconnect bool + }{ + { + name: "v4_QOS0", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v4_QOS1", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v4_QOS2", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v5_QOS0", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v5_QOS1", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Puback].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, + { + name: "v5_QOS2", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, + } + + for _, tx := range tt { + t.Run(tx.name, func(t *testing.T) { + cc := *DefaultServerCapabilities + s := New(&Options{ + Logger: logger, + Capabilities: &cc, + }) + _ = s.AddHook(new(DenyHook), nil) + _ = s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = tx.protocolVersion + s.Clients.Add(cl) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + err := s.processPublish(cl, tx.pk) + require.ErrorIs(t, err, tx.expectErr) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + + if tx.expectReponse != nil { + require.Equal(t, tx.expectReponse, buf) + } + + require.Equal(t, tx.expectDisconnect, cl.Closed()) + wg.Wait() + }) + } } func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) { @@ -1417,7 +1574,7 @@ func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) { err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, _, _ := newTestClient() err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) @@ -1431,7 +1588,7 @@ func TestServerProcessPacketPublishQos0(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1448,7 +1605,7 @@ func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1467,7 +1624,7 @@ func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1483,7 +1640,7 @@ func TestServerProcessPacketPublishQos1(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1498,7 +1655,7 @@ func TestServerProcessPacketPublishQos2(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1514,7 +1671,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1534,7 +1691,7 @@ func TestPublishToSubscribersSelfNoLocal(t *testing.T) { pkx.Origin = cl.ID s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1592,9 +1749,9 @@ func TestPublishToSubscribers(t *testing.T) { go func() { s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w1.Close() - w2.Close() - w3.Close() + _ = w1.Close() + _ = w2.Close() + _ = w3.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-cl1Recv) @@ -1636,7 +1793,7 @@ func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { pkx.Created = time.Now().Unix() - 30 s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w1.Close() + _ = w1.Close() }() b := <-cl1Recv @@ -1660,7 +1817,7 @@ func TestPublishToSubscribersIdentifiers(t *testing.T) { go func() { s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1685,7 +1842,7 @@ func TestPublishToSubscribersPkIgnore(t *testing.T) { pk.Ignore = true s.PublishToSubscribers(pk) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1712,9 +1869,9 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.FixedHeader.Qos = 2 - s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) time.Sleep(time.Microsecond * 100) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1741,9 +1898,9 @@ func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.FixedHeader.Qos = 2 - s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) time.Sleep(time.Microsecond * 100) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1763,7 +1920,7 @@ func TestPublishToClientExceedClientWritesPending(t *testing.T) { cl := newClient(w, &ops{ info: new(system.Info), hooks: new(Hooks), - log: &logger, + log: logger, options: &Options{ Capabilities: &Capabilities{ MaximumClientWritesPending: 3, @@ -1792,10 +1949,10 @@ func TestPublishToClientServerTopicAlias(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet - s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) - s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1848,6 +2005,19 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) { require.ErrorIs(t, err, packets.ErrQuotaExceeded) } +func TestPublishToClientACLNotAuthorized(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + err := s.AddHook(new(DenyHook), nil) + require.NoError(t, err) + cl, _, _ := newTestClient() + + _, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrNotAuthorized) +} + func TestPublishToClientNoConn(t *testing.T) { s := newServer() cl, _, _ := newTestClient() @@ -1875,10 +2045,10 @@ func TestProcessPublishWithTopicAlias(t *testing.T) { pkx.Properties.SubscriptionIdentifier = []int{} // must not contain from client to server pkx.TopicName = "" pkx.Properties.TopicAlias = 1 - s.processPacket(cl2, pkx) + _ = s.processPacket(cl2, pkx) time.Sleep(time.Millisecond) - w2.Close() - w.Close() + _ = w2.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1898,12 +2068,12 @@ func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { @@ -1920,12 +2090,12 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishToSubscribersNoConnection(t *testing.T) { @@ -1938,10 +2108,10 @@ func TestPublishToSubscribersNoConnection(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishRetainedToClient(t *testing.T) { @@ -1959,7 +2129,7 @@ func TestPublishRetainedToClient(t *testing.T) { go func() { s.publishRetainedToClient(cl, packets.Subscription{Filter: "a/b/c"}, false) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1979,7 +2149,7 @@ func TestPublishRetainedToClientIsShared(t *testing.T) { go func() { s.publishRetainedToClient(cl, sub, false) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2000,7 +2170,7 @@ func TestPublishRetainedToClientError(t *testing.T) { retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) require.Equal(t, int64(1), retained) - w.Close() + _ = w.Close() s.publishRetainedToClient(cl, sub, false) } @@ -2103,7 +2273,7 @@ func TestServerProcessPacketPubrec(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).RawBytes, <-recv) @@ -2131,7 +2301,7 @@ func TestServerProcessPacketPubrecNoPacketID(t *testing.T) { pk := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet // not sending properties err := s.processPacket(cl, pk) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrelMqtt5AckNoPacket).RawBytes, <-recv) @@ -2181,7 +2351,7 @@ func TestServerProcessPacketPubrel(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) @@ -2210,7 +2380,7 @@ func TestServerProcessPacketPubrelNoPacketID(t *testing.T) { pk := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet // not sending properties err := s.processPacket(cl, pk) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5AckNoPacket).RawBytes, <-recv) @@ -2322,7 +2492,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) { err := s.processPacket(cl, *tx.in.Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, tx.out.RawBytes, <-recv) if i == 0 { @@ -2403,7 +2573,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { } time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() if i != 2 { require.Equal(t, tx.out.RawBytes, <-recv) @@ -2426,7 +2596,7 @@ func TestServerProcessPacketSubscribe(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2445,7 +2615,7 @@ func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, pkx) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2471,7 +2641,7 @@ func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidFilter).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2487,7 +2657,7 @@ func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2507,7 +2677,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2528,7 +2698,7 @@ func TestServerProcessSubscribeDowngradeQos(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2550,7 +2720,7 @@ func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2571,7 +2741,7 @@ func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2592,7 +2762,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2606,7 +2776,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { func TestServerProcessSubscribeNoConnection(t *testing.T) { s := newServer() cl, r, _ := newTestClient() - r.Close() + _ = r.Close() err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) @@ -2614,16 +2784,16 @@ func TestServerProcessSubscribeNoConnection(t *testing.T) { func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) - s.Serve() + _ = s.Serve() cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2633,9 +2803,9 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) - s.Serve() + _ = s.Serve() s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 @@ -2643,7 +2813,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { go func() { err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2660,7 +2830,7 @@ func TestServerProcessSubscribeErrorDowngrade(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2676,7 +2846,7 @@ func TestServerProcessPacketUnsubscribe(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2693,7 +2863,7 @@ func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2729,7 +2899,7 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) { err := s.receivePacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2744,7 +2914,7 @@ func TestServerRecievePacketDisconnectClient(t *testing.T) { go func() { err := s.DisconnectClient(cl, packets.CodeDisconnect) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2789,7 +2959,7 @@ func TestServerProcessPacketAuth(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2823,7 +2993,7 @@ func TestServerProcessPacketAuthFailure(t *testing.T) { func TestServerSendLWT(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -2853,8 +3023,8 @@ func TestServerSendLWT(t *testing.T) { go func() { s.sendLWT(sender) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -2862,7 +3032,7 @@ func TestServerSendLWT(t *testing.T) { func TestServerSendLWTRetain(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -2893,8 +3063,8 @@ func TestServerSendLWTRetain(t *testing.T) { go func() { s.sendLWT(sender) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -2930,7 +3100,7 @@ func TestServerSendLWTDelayed(t *testing.T) { s.sendDelayedLWT(time.Now().Unix()) require.Equal(t, 0, s.loop.willDelayed.Len()) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() recv := make(chan []byte) @@ -2946,7 +3116,7 @@ func TestServerSendLWTDelayed(t *testing.T) { func TestServerReadStore(t *testing.T) { s := newServer() hook := new(modifiedHookBase) - s.AddHook(hook, nil) + _ = s.AddHook(hook, nil) hook.failAt = 1 // clients err := s.readStore() @@ -3007,6 +3177,7 @@ func TestServerLoadInflightMessages(t *testing.T) { {ID: "zen"}, {ID: "mochi-co"}, }) + require.Equal(t, 3, s.Clients.Len()) v := []storage.Message{ @@ -3051,7 +3222,7 @@ func TestServerClose(t *testing.T) { s := newServer() hook := new(modifiedHookBase) - s.AddHook(hook, nil) + _ = s.AddHook(hook, nil) cl, r, _ := newTestClient() cl.Net.Listener = "t1" @@ -3060,7 +3231,7 @@ func TestServerClose(t *testing.T) { err := s.AddListener(listeners.NewMockListener("t1", ":1882")) require.NoError(t, err) - s.Serve() + _ = s.Serve() // receive the disconnect recv := make(chan []byte) @@ -3077,7 +3248,7 @@ func TestServerClose(t *testing.T) { require.Equal(t, true, ok) require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) - s.Close() + _ = s.Close() time.Sleep(time.Millisecond) require.Equal(t, false, listener.(*listeners.MockListener).IsServing()) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv) @@ -3166,7 +3337,6 @@ func TestServerClearExpiredClients(t *testing.T) { require.Equal(t, 4, s.Clients.Len()) s.clearExpiredClients(n) - require.Equal(t, 2, s.Clients.Len()) } @@ -3186,3 +3356,308 @@ func TestAtomicItoa(t *testing.T) { ip := &i require.Equal(t, "22", AtomicItoa(ip)) } + +func TestServerSubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {} + + s := newServerWithInlineClient() + require.NotNil(t, s) + + tt := []struct { + desc string + filter string + identifier int + handler InlineSubFn + expect error + }{ + { + desc: "subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe d/e/f", + filter: "d/e/f", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe d/e/f by different identifier", + filter: "d/e/f", + identifier: 2, + handler: handler, + expect: nil, + }, + { + desc: "subscribe different handler", + filter: "a/b/c", + identifier: 1, + handler: func(cl *Client, sub packets.Subscription, pk packets.Packet) {}, + expect: nil, + }, + { + desc: "subscribe $SYS/info", + filter: "$SYS/info", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe invalid ###", + filter: "###", + identifier: 1, + handler: handler, + expect: packets.ErrTopicFilterInvalid, + }, + { + desc: "subscribe invalid handler", + filter: "a/b/c", + identifier: 1, + handler: nil, + expect: packets.ErrInlineSubscriptionHandlerInvalid, + }, + } + + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + require.Equal(t, tx.expect, s.Subscribe(tx.filter, tx.identifier, tx.handler)) + }) + } +} + +func TestServerSubscribeNoInlineClient(t *testing.T) { + s := newServer() + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {}) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestServerUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + s := newServerWithInlineClient() + err := s.Subscribe("a/b/c", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 2, handler) + require.Nil(t, err) + + err = s.Unsubscribe("a/b/c", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 2) + require.Nil(t, err) + + err = s.Unsubscribe("not/exist", 1) + require.Nil(t, err) + + err = s.Unsubscribe("#/#/invalid", 1) + require.Equal(t, packets.ErrTopicFilterInvalid, err) +} + +func TestServerUnsubscribeNoInlineClient(t *testing.T) { + s := newServer() + err := s.Unsubscribe("a/b/c", 1) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestPublishToInlineSubscriber(t *testing.T) { + s := newServerWithInlineClient() + finishCh := make(chan bool) + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.PublishToSubscribers(pkx) + }() + + require.Equal(t, true, <-finishCh) +} + +func TestPublishToInlineSubscribersDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.PublishToSubscribers(pkx) + + pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet + s.PublishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestPublishToInlineSubscribersDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.PublishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetain(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 1 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + require.Equal(t, true, <-finishCh) +} + +func TestServerSubscribeWithRetainDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + retained = s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetainDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} diff --git a/mqtt/system/system.go b/mqtt/system/system.go index 647ae00..2ed47d0 100644 --- a/mqtt/system/system.go +++ b/mqtt/system/system.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package system diff --git a/mqtt/topics.go b/mqtt/topics.go index 0cc80d5..f9d122e 100644 --- a/mqtt/topics.go +++ b/mqtt/topics.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -186,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio return m } +// InlineSubFn is the signature for a callback function which will be called +// when an inline client receives a message on a topic it is subscribed to. +// The sub argument contains information about the subscription that was matched for any filters. +type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet) + +// InlineSubscriptions represents a map of internal subscriptions keyed on client. +type InlineSubscriptions struct { + internal map[int]InlineSubscription + sync.RWMutex +} + +// NewInlineSubscriptions returns a new instance of InlineSubscriptions. +func NewInlineSubscriptions() *InlineSubscriptions { + return &InlineSubscriptions{ + internal: map[int]InlineSubscription{}, + } +} + +// Add adds a new internal subscription for a client id. +func (s *InlineSubscriptions) Add(val InlineSubscription) { + s.Lock() + defer s.Unlock() + s.internal[val.Identifier] = val +} + +// GetAll returns all internal subscriptions. +func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription { + s.RLock() + defer s.RUnlock() + m := map[int]InlineSubscription{} + for k, v := range s.internal { + m[k] = v + } + return m +} + +// Get returns an internal subscription for a client id. +func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) { + s.RLock() + defer s.RUnlock() + val, ok = s.internal[id] + return val, ok +} + +// Len returns the number of internal subscriptions. +func (s *InlineSubscriptions) Len() int { + s.RLock() + defer s.RUnlock() + val := len(s.internal) + return val +} + +// Delete removes an internal subscription by the client id. +func (s *InlineSubscriptions) Delete(id int) { + s.Lock() + defer s.Unlock() + delete(s.internal, id) +} + // Subscriptions is a map of subscriptions keyed on client. type Subscriptions struct { internal map[string]packets.Subscription @@ -244,11 +303,17 @@ func (s *Subscriptions) Delete(id string) { // ClientSubscriptions is a map of aggregated subscriptions for a client. type ClientSubscriptions map[string]packets.Subscription +type InlineSubscription struct { + packets.Subscription + Handler InlineSubFn +} + // Subscribers contains the shared and non-shared subscribers matching a topic. type Subscribers struct { - Shared map[string]map[string]packets.Subscription - SharedSelected map[string]packets.Subscription - Subscriptions map[string]packets.Subscription + Shared map[string]map[string]packets.Subscription + SharedSelected map[string]packets.Subscription + Subscriptions map[string]packets.Subscription + InlineSubscriptions map[int]InlineSubscription } // SelectShared returns one subscriber for each shared subscription group. @@ -298,6 +363,39 @@ func NewTopicsIndex() *TopicsIndex { } } +// InlineSubscribe adds a new internal subscription for a topic filter, returning +// true if the subscription was new. +func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) (bool, int) { + x.root.Lock() + defer x.root.Unlock() + + var existed bool + n := x.set(subscription.Filter, 0) + _, existed = n.inlineSubscriptions.Get(subscription.Identifier) + n.inlineSubscriptions.Add(subscription) + + return !existed, n.inlineSubscriptions.Len() +} + +// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client, +// returning true if the subscription existed. +func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) (bool, int) { + x.root.Lock() + defer x.root.Unlock() + + particle := x.seek(filter, 0) + if particle == nil { + return false, 0 + } + + particle.inlineSubscriptions.Delete(id) + + if particle.inlineSubscriptions.Len() == 0 { + x.trim(particle) + } + return true, particle.inlineSubscriptions.Len() +} + // Subscribe adds a new subscription for a client to a topic filter, returning // true if the subscription was new. func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) (bool, int) { @@ -490,9 +588,10 @@ func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []pack // their subscription ids and highest qos. func (x *TopicsIndex) Subscribers(topic string) *Subscribers { return x.scanSubscribers(topic, 0, nil, &Subscribers{ - Shared: map[string]map[string]packets.Subscription{}, - SharedSelected: map[string]packets.Subscription{}, - Subscriptions: map[string]packets.Subscription{}, + Shared: map[string]map[string]packets.Subscription{}, + SharedSelected: map[string]packets.Subscription{}, + Subscriptions: map[string]packets.Subscription{}, + InlineSubscriptions: map[int]InlineSubscription{}, }) } @@ -514,10 +613,12 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su } else { x.gatherSubscriptions(topic, particle, subs) x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) if wild := particle.particles.get("#"); wild != nil && partKey != "+" { x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2 x.gatherSharedSubscriptions(wild, subs) + x.gatherInlineSubscriptions(particle, subs) } } } @@ -526,6 +627,7 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su if particle := n.particles.get("#"); particle != nil { x.gatherSubscriptions(topic, particle, subs) x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) } return subs @@ -568,6 +670,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr } } +// gatherSharedSubscriptions gathers all inline subscriptions for a particle. +func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) { + if subs.InlineSubscriptions == nil { + subs.InlineSubscriptions = map[int]InlineSubscription{} + } + + for id, inline := range particle.inlineSubscriptions.GetAll() { + subs.InlineSubscriptions[id] = inline + } +} + // isolateParticle extracts a particle between d / and d+1 / without allocations. func isolateParticle(filter string, d int) (particle string, hasNext bool) { var next, end int @@ -598,7 +711,7 @@ func IsSharedFilter(filter string) bool { // IsValidFilter returns true if the filter is valid. func IsValidFilter(filter string, forPublish bool) bool { - if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs. + if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish. return false // [MQTT-4.7.3-1] } @@ -639,23 +752,25 @@ func IsValidFilter(filter string, forPublish bool) bool { // particle is a child node on the tree. type particle struct { - parent *particle // a pointer to the parent of the particle - subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address - shared *SharedSubscriptions // a map of shared subscriptions keyed on group name - key string // the key of the particle - retainPath string // path of a retained message - particles particles // a map of child particles - sync.Mutex // mutex for when making changes to the particle + key string // the key of the particle + parent *particle // a pointer to the parent of the particle + particles particles // a map of child particles + subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address + shared *SharedSubscriptions // a map of shared subscriptions keyed on group name + inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle + retainPath string // path of a retained message + sync.Mutex // mutex for when making changes to the particle } // newParticle returns a pointer to a new instance of particle. func newParticle(key string, parent *particle) *particle { return &particle{ - key: key, - parent: parent, - particles: newParticles(), - subscriptions: NewSubscriptions(), - shared: NewSharedSubscriptions(), + key: key, + parent: parent, + particles: newParticles(), + subscriptions: NewSubscriptions(), + shared: NewSharedSubscriptions(), + inlineSubscriptions: NewInlineSubscriptions(), } } diff --git a/mqtt/topics_test.go b/mqtt/topics_test.go index 9de4bfe..8a5c5dc 100644 --- a/mqtt/topics_test.go +++ b/mqtt/topics_test.go @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -857,3 +858,227 @@ func TestNewTopicAliases(t *testing.T) { require.NotNil(t, a.Outbound) require.Equal(t, uint16(5), a.Outbound.maximum) } + +func TestNewInlineSubscriptions(t *testing.T) { + subscriptions := NewInlineSubscriptions() + require.NotNil(t, subscriptions) + require.NotNil(t, subscriptions.internal) + require.Equal(t, 0, subscriptions.Len()) +} + +func TestInlineSubscriptionAdd(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) +} + +func TestInlineSubscriptionGet(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) + + _, ok = subscriptions.Get(999) + require.False(t, ok) +} + +func TestInlineSubscriptionsGetAll(t *testing.T) { + subscriptions := NewInlineSubscriptions() + + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3}, + }) + + allSubs := subscriptions.GetAll() + require.Len(t, allSubs, 3) + require.Contains(t, allSubs, 1) + require.Contains(t, allSubs, 2) + require.Contains(t, allSubs, 3) +} + +func TestInlineSubscriptionDelete(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + subscriptions.Delete(1) + _, ok := subscriptions.Get(1) + require.False(t, ok) + require.Empty(t, subscriptions.GetAll()) + require.Zero(t, subscriptions.Len()) +} + +func TestInlineSubscribe(t *testing.T) { + + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + tt := []struct { + desc string + filter string + subscription InlineSubscription + wasNew bool + count int + }{ + { + desc: "subscribe", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: true, + count: 1, + }, + { + desc: "subscribe existed", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: false, + count: 1, + }, + { + desc: "subscribe different identifier", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}}, + wasNew: true, + count: 2, + }, + { + desc: "subscribe case sensitive didnt exist", + filter: "A/B/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}}, + wasNew: true, + count: 1, + }, + { + desc: "wildcard+ sub", + filter: "d/+", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}}, + wasNew: true, + count: 1, + }, + { + desc: "wildcard# sub", + filter: "d/e/#", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}}, + wasNew: true, + count: 1, + }, + } + + index := NewTopicsIndex() + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + exist, count := index.InlineSubscribe(tx.subscription) + require.Equal(t, tx.wasNew, exist) + require.Equal(t, tx.count, count) + }) + } + + final := index.root.particles.get("a").particles.get("b").particles.get("c") + require.NotNil(t, final) +} + +func TestInlineUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + index := NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index = NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}}) + sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok, count := index.InlineUnsubscribe(1, "a/b/c/d") + require.Equal(t, 0, count) + require.True(t, ok) + require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok, _ = index.InlineUnsubscribe(1, "d/e/f") + require.Equal(t, 0, count) + require.True(t, ok) + require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f")) + + ok, _ = index.InlineUnsubscribe(1, "not/exist") + require.Equal(t, 0, count) + require.False(t, ok) +}