Skip to content

Commit

Permalink
fix(mqtt): resume subscription in reconnect (#3199)
Browse files Browse the repository at this point in the history
Signed-off-by: Jiyong Huang <[email protected]>
  • Loading branch information
ngjaying authored Sep 14, 2024
1 parent 61e984a commit 921d230
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 10 deletions.
25 changes: 22 additions & 3 deletions internal/io/mqtt/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type Connection struct {
status atomic.Value
scHandler api.StatusChangeHandler
conf *ConnectionConfig
// key is the topic. Each topic will have only one connector
subscriptions map[string]*subscriptionInfo
}

type ConnectionConfig struct {
Expand All @@ -53,16 +55,23 @@ type ConnectionConfig struct {
tls *tls.Config
}

type subscriptionInfo struct {
Qos byte
Handler pahoMqtt.MessageHandler
}

func CreateConnection(_ api.StreamContext) modules.Connection {
return &Connection{}
return &Connection{
subscriptions: make(map[string]*subscriptionInfo),
}
}

func (conn *Connection) Provision(ctx api.StreamContext, conId string, props map[string]any) error {
c, err := ValidateConfig(props)
if err != nil {
return err
}
opts := pahoMqtt.NewClientOptions().AddBroker(c.Server).SetProtocolVersion(c.pversion).SetAutoReconnect(true).SetMaxReconnectInterval(time.Minute)
opts := pahoMqtt.NewClientOptions().AddBroker(c.Server).SetProtocolVersion(c.pversion).SetAutoReconnect(true).SetMaxReconnectInterval(connection.DefaultMaxInterval)

opts = opts.SetTLSConfig(c.tls)

Expand All @@ -72,7 +81,6 @@ func (conn *Connection) Provision(ctx api.StreamContext, conId string, props map
if c.Password != "" {
opts = opts.SetPassword(c.Password)
}
opts = opts.SetClientID(c.ClientId).SetAutoReconnect(true).SetResumeSubs(true).SetMaxReconnectInterval(connection.DefaultMaxInterval)

conn.status.Store(modules.ConnectionStatus{Status: api.ConnectionConnecting})
opts.OnConnect = conn.onConnect
Expand Down Expand Up @@ -119,6 +127,12 @@ func (conn *Connection) onConnect(_ pahoMqtt.Client) {
conn.scHandler(api.ConnectionConnected, "")
}
conn.logger.Infof("The connection to mqtt broker is established")
for topic, info := range conn.subscriptions {
err := conn.Subscribe(topic, info.Qos, info.Handler)
if err != nil { // should never happen. If happens because of connection, it will retry later
conn.logger.Errorf("Failed to subscribe topic %s: %v", topic, err)
}
}
}

func (conn *Connection) onConnectLost(_ pahoMqtt.Client, err error) {
Expand All @@ -143,6 +157,7 @@ func (conn *Connection) DetachSub(ctx api.StreamContext, props map[string]any) {
if err != nil {
return
}
delete(conn.subscriptions, topic)
conn.Client.Unsubscribe(topic)
}

Expand Down Expand Up @@ -173,6 +188,10 @@ func (conn *Connection) Publish(topic string, qos byte, retained bool, payload a
}

func (conn *Connection) Subscribe(topic string, qos byte, callback pahoMqtt.MessageHandler) error {
conn.subscriptions[topic] = &subscriptionInfo{
Qos: qos,
Handler: callback,
}
token := conn.Client.Subscribe(topic, qos, callback)
return handleToken(token)
}
Expand Down
42 changes: 36 additions & 6 deletions internal/io/mqtt/source_sink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,39 @@
package mqtt

import (
"fmt"
"testing"
"time"

"github.com/lf-edge/ekuiper/contract/v2/api"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/lf-edge/ekuiper/v2/internal/conf"
"github.com/lf-edge/ekuiper/v2/internal/pkg/store"
"github.com/lf-edge/ekuiper/v2/internal/testx"
"github.com/lf-edge/ekuiper/v2/internal/topo/topotest/mockclock"
"github.com/lf-edge/ekuiper/v2/pkg/connection"
"github.com/lf-edge/ekuiper/v2/pkg/mock"
"github.com/lf-edge/ekuiper/v2/pkg/model"
)

func TestSourceSink(t *testing.T) {
url, cancel, err := testx.InitBroker("TestSourceSink")
func TestSourceSinkRecon(t *testing.T) {
// Create the new MQTT Server.
server := mqtt.New(nil)
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP(listeners.Config{ID: "testcon", Address: ":2883"})
err := server.AddListener(tcp)
require.NoError(t, err)
defer func() {
cancel()
go func() {
err = server.Serve()
fmt.Println(err)
}()
url := tcp.Address()
dataDir, err := conf.GetDataLoc()
require.NoError(t, err)
require.NoError(t, store.SetupDefault(dataDir))
Expand Down Expand Up @@ -74,7 +86,25 @@ func TestSourceSink(t *testing.T) {
"qos": 0,
"topic": "demo",
}, result, func() {
err := mock.RunBytesSinkCollect(sk, data, map[string]any{
err := mock.RunBytesSinkCollect(sk, data[:1], map[string]any{
"server": url,
"topic": "demo",
"qos": 0,
"retained": false,
})
assert.NoError(t, err)
err = server.Close()
tcp.Close(nil)
assert.NoError(t, err)
go func() {
tcp := listeners.NewTCP(listeners.Config{Address: url})
err := server.AddListener(tcp)
require.NoError(t, err)
err = server.Serve()
require.NoError(t, err)
}()
time.Sleep(time.Millisecond * 100)
err = mock.RunBytesSinkCollect(sk, data[1:], map[string]any{
"server": url,
"topic": "demo",
"qos": 0,
Expand Down
2 changes: 1 addition & 1 deletion pkg/mock/test_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func TestSourceConnectorCompare(t *testing.T, r api.Source, props map[string]any
assert.NoError(t, err)
}()

ticker := time.After(60000 * time.Second)
ticker := time.After(60 * time.Second)
finished := make(chan struct{})
go func() {
wg.Wait()
Expand Down

0 comments on commit 921d230

Please sign in to comment.