diff --git a/cluster/rest/entity.go b/cluster/rest/entity.go new file mode 100644 index 0000000..44527da --- /dev/null +++ b/cluster/rest/entity.go @@ -0,0 +1,12 @@ +package rest + +type result struct { + Url string `json:"url"` + Data string `json:"data"` + Err string `json:"err"` +} + +type node struct { + Name string `json:"name"` + Addr string `json:"addr"` +} diff --git a/cluster/rest/http.go b/cluster/rest/http.go new file mode 100644 index 0000000..f5c8e00 --- /dev/null +++ b/cluster/rest/http.go @@ -0,0 +1,85 @@ +package rest + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "sync" + "time" +) + +const ( + HttpGet = "GET" + HttpPost = "POST" + HttpDelete = "DELETE" + Timeout = 3 * time.Second +) + +func fetch(ctx context.Context, method string, url string, data []byte) result { + rs := result{Url: url} + var req *http.Request + var err error + var body io.Reader + if data != nil { + body = bytes.NewBuffer(data) + } + if ctx != nil { + req, err = http.NewRequestWithContext(ctx, method, url, body) + } else { + req, err = http.NewRequest(method, url, body) + } + if err != nil { + rs.Err = err.Error() + return rs + } + + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + if ctxErr := ctx.Err(); errors.Is(ctxErr, context.DeadlineExceeded) { + rs.Err = "Request timeout" + } else { + rs.Err = err.Error() + } + return rs + } + defer resp.Body.Close() + if data, err := io.ReadAll(resp.Body); err != nil { + rs.Err = err.Error() + } else { + if resp.StatusCode == http.StatusOK { + rs.Data = string(data) + } else { + rs.Err = string(data) + } + } + return rs +} + +func fetchM(method string, urls []string, body []byte) []result { + ctx, cancel := context.WithTimeout(context.Background(), Timeout) + defer cancel() + + var wg sync.WaitGroup + ch := make(chan result, len(urls)) + + wg.Add(len(urls)) + for _, url := range urls { + go func(url string) { + defer wg.Done() + ch <- fetch(ctx, method, url, body) + }(url) + } + + wg.Wait() + close(ch) + results := make([]result, len(urls)) + index := 0 + for rs := range ch { + results[index] = rs + index++ + } + return results +} diff --git a/cluster/rest/rest.go b/cluster/rest/rest.go new file mode 100644 index 0000000..f138ddb --- /dev/null +++ b/cluster/rest/rest.go @@ -0,0 +1,177 @@ +package rest + +import ( + "encoding/json" + "fmt" + cs "github.com/wind-c/comqtt/v2/cluster" + "github.com/wind-c/comqtt/v2/cluster/discovery" + rt "github.com/wind-c/comqtt/v2/mqtt/rest" + "net/http" + "net/netip" + "strings" +) + +type rest struct { + agent *cs.Agent +} + +func New(agent *cs.Agent) *rest { + return &rest{ + agent: agent, + } +} + +func (s *rest) GenHandlers() map[string]rt.Handler { + return map[string]rt.Handler{ + "GET /api/v1/node/config": s.viewConfig, + "DELETE /api/v1/node/{name}": s.leave, + "GET /api/v1/cluster/nodes": s.getNodes, + "POST /api/v1/cluster/nodes": s.join, + "POST /api/v1/cluster/peers": s.addRaftPeer, + "DELETE /api/v1/cluster/peers/{name}": s.removeRaftPeer, + "GET /api/v1/cluster/stat/online": s.getOnlineCount, + "GET /api/v1/cluster/clients/{id}": s.getClient, + "POST /api/v1/cluster/blacklist/{id}": s.kickClient, + "DELETE /api/v1/cluster/blacklist/{id}": s.blanchClient, + } +} + +// viewConfig return the configuration parameters of this node +// GET api/v1/node/config +func (s *rest) viewConfig(w http.ResponseWriter, r *http.Request) { + rt.Ok(w, s.agent.Config) +} + +// getMembers return all nodes in the cluster +// GET api/v1/cluster/nodes +func (s *rest) getNodes(w http.ResponseWriter, r *http.Request) { + rt.Ok(w, s.agent.GetMemberList()) +} + +// join add a node to the cluster +// POST api/v1/cluster/nodes +func (s *rest) join(w http.ResponseWriter, r *http.Request) { + var n node + if err := json.NewDecoder(r.Body).Decode(&n); err != nil { + rt.Error(w, http.StatusBadRequest, err.Error()) + return + } + n.Name = strings.TrimSpace(n.Name) + n.Addr = strings.TrimSpace(n.Addr) + if n.Name == "" || n.Addr == "" { + rt.Error(w, http.StatusBadRequest, "name and addr cannot be empty") + return + } + if _, err := netip.ParseAddrPort(n.Addr); err != nil { + rt.Error(w, http.StatusBadRequest, "invalid address") + return + } + + if err := s.agent.Join(n.Name, n.Addr); err != nil { + rt.Error(w, http.StatusInternalServerError, err.Error()) + } else { + rt.Ok(w, n) + } +} + +// leave local node gracefully exits the cluster +// DELETE api/v1/node/{name} +func (s *rest) leave(w http.ResponseWriter, r *http.Request) { + name := r.PathValue("name") + localName := s.agent.GetLocalName() + if name != localName { + rt.Error(w, http.StatusBadRequest, fmt.Sprintf("cannot remove not local node %s", localName)) + return + } + + if err := s.agent.Leave(); err != nil { + rt.Error(w, http.StatusInternalServerError, err.Error()) + return + } else { + rt.Ok(w, name) + } +} + +// addRaftPeer add peer to raft cluster +// POST api/v1/cluster/peers +func (s *rest) addRaftPeer(w http.ResponseWriter, r *http.Request) { + var p node + if err := json.NewDecoder(r.Body).Decode(&p); err != nil { + rt.Error(w, http.StatusBadRequest, err.Error()) + return + } + p.Name = strings.TrimSpace(p.Name) + p.Addr = strings.TrimSpace(p.Addr) + if p.Name == "" || p.Addr == "" { + rt.Error(w, http.StatusBadRequest, "name and addr cannot be empty") + return + } + if _, err := netip.ParseAddrPort(p.Addr); err != nil { + rt.Error(w, http.StatusBadRequest, "invalid address") + return + } + + s.agent.AddRaftPeer(p.Name, p.Addr) + rt.Ok(w, p) +} + +// removeRaftPeer remove peer from raft cluster +// DELETE api/v1/cluster/peers/{name} +func (s *rest) removeRaftPeer(w http.ResponseWriter, r *http.Request) { + name := r.PathValue("name") + if strings.TrimSpace(name) == "" { + rt.Error(w, http.StatusBadRequest, "name cannot be empty") + return + } + + s.agent.RemoveRaftPeer(name) + rt.Ok(w, name) +} + +// getOnlineCount return online number from all nodes in the cluster +// GET api/v1/cluster/stat/online +func (s *rest) getOnlineCount(w http.ResponseWriter, r *http.Request) { + path := rt.MqttGetOnlinePath + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpGet, urls, nil) + rt.Ok(w, rs) +} + +// getClient return a client information, search from all nodes in the cluster +// GET api/v1/cluster/clients/{id} +func (s *rest) getClient(w http.ResponseWriter, r *http.Request) { + cid := r.PathValue("id") + path := strings.Replace(rt.MqttGetClientPath, "{id}", cid, 1) + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpGet, urls, nil) + rt.Ok(w, rs) +} + +// kickClient add it to the blacklist on all nodes in the cluster +// POST api/v1/cluster/blacklist/{id} +func (s *rest) kickClient(w http.ResponseWriter, r *http.Request) { + cid := r.PathValue("id") + path := strings.Replace(rt.MqttAddBlacklistPath, "{id}", cid, 1) + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpPost, urls, nil) + rt.Ok(w, rs) +} + +// blanchClient remove from the blacklist on all nodes in the cluster +// DELETE api/v1/cluster/blacklist/{id} +func (s *rest) blanchClient(w http.ResponseWriter, r *http.Request) { + cid := r.PathValue("id") + path := strings.Replace(rt.MqttDelBlacklistPath, "{id}", cid, 1) + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpDelete, urls, nil) + rt.Ok(w, rs) +} + +// genUrls generate urls +func genUrls(ms []discovery.Member, path string) []string { + urls := make([]string, len(ms)) + for i, m := range ms { + urls[i] = "http://" + m.Addr + ":8080" + path + } + return urls +} diff --git a/cmd/cluster/main.go b/cmd/cluster/main.go index a250897..c57edf8 100644 --- a/cmd/cluster/main.go +++ b/cmd/cluster/main.go @@ -6,10 +6,10 @@ package main import ( "context" - "encoding/json" "flag" "fmt" - "io" + csRt "github.com/wind-c/comqtt/v2/cluster/rest" + "maps" "net" "net/http" _ "net/http/pprof" @@ -27,6 +27,7 @@ import ( mqtt "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/listeners" + mqttRt "github.com/wind-c/comqtt/v2/mqtt/rest" "github.com/wind-c/comqtt/v2/plugin" hauth "github.com/wind-c/comqtt/v2/plugin/auth/http" mauth "github.com/wind-c/comqtt/v2/plugin/auth/mysql" @@ -139,11 +140,10 @@ func realMain(ctx context.Context) error { onError(server.AddListener(ws), "add websocket listener") // add http listener - handles := make(map[string]func(http.ResponseWriter, *http.Request), 1) - handles["/cluster/conf"] = ConfHandler - handles["/cluster/ms"] = MsHandler - handles["/cluster/peer/"] = PeerHandler //for test peer join and leave - http := listeners.NewHTTP("stats", cfg.Mqtt.HTTP, nil, server.Info, handles) + csHls := csRt.New(agent).GenHandlers() + mqHls := mqttRt.New(server).GenHandlers() + maps.Copy(csHls, mqHls) + http := listeners.NewHTTP("stats", cfg.Mqtt.HTTP, nil, csHls) onError(server.AddListener(http), "add http listener") errCh := make(chan error, 1) @@ -248,79 +248,3 @@ func onError(err error, msg string) { os.Exit(1) } } - -func ConfHandler(w http.ResponseWriter, req *http.Request) { - body, err := json.MarshalIndent(agent.Config, "", "\t") - if err != nil { - io.WriteString(w, err.Error()) - return - } - - w.Write(body) -} - -func MsHandler(w http.ResponseWriter, r *http.Request) { - body, err := json.MarshalIndent(agent.GetMemberList(), "", "\t") - if err != nil { - io.WriteString(w, err.Error()) - return - } - - w.Write(body) -} - -func PeerHandler(w http.ResponseWriter, r *http.Request) { - key := strings.SplitN(r.RequestURI, "/", 4)[3] - defer r.Body.Close() - switch r.Method { - case http.MethodPut: - //val, err := io.ReadAll(r.Body) - //if err != nil { - // //logger.Errorf("[http] failed to read on PUT: %v", err) - // http.Error(w, "Failed to PUT", http.StatusBadRequest) - // return - //} - - //agent.Propose(key, string(val)) - w.WriteHeader(http.StatusNoContent) - case http.MethodGet: - if val := agent.GetValue(key); len(val) > 0 { - w.Write([]byte(strings.Join(val, ","))) - } else { - http.Error(w, "Failed to GET", http.StatusNotFound) - } - case http.MethodPost: - addr, err := io.ReadAll(r.Body) - if err != nil { - //logger.Errorf("[http] failed to read on POST: %v", err) - http.Error(w, "Failed to POST", http.StatusBadRequest) - return - } - - nodeId, err := strconv.ParseUint(key, 0, 64) - if err != nil { - //logger.Errorf("[http] failed to convert ID for conf change: %v", err) - http.Error(w, "Failed to POST", http.StatusBadRequest) - return - } - - agent.AddRaftPeer(fmt.Sprint(nodeId), string(addr)) - w.WriteHeader(http.StatusNoContent) - case http.MethodDelete: - nodeId, err := strconv.ParseUint(key, 0, 64) - if err != nil { - //logger.Errorf("[http] failed to convert ID for conf change: %v", err) - http.Error(w, "Failed to POST", http.StatusBadRequest) - return - } - - agent.RemoveRaftPeer(fmt.Sprint(nodeId)) - w.WriteHeader(http.StatusNoContent) - default: - w.Header().Add("Allow", http.MethodPut) - w.Header().Add("Allow", http.MethodGet) - w.Header().Add("Allow", http.MethodPost) - w.Header().Add("Allow", http.MethodDelete) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } -} diff --git a/cmd/config/node1.yml b/cmd/config/node1.yml index ae351bb..9ea66d1 100644 --- a/cmd/config/node1.yml +++ b/cmd/config/node1.yml @@ -1,5 +1,5 @@ storage-way: 3 #Storage way optional items:0 memory、1 bolt、2 badger、3 redis;Only redis can be used in cluster mode. -bridge-way: 1 #Bridge way optional items:0 disable、1 kafka +bridge-way: 0 #Bridge way optional items:0 disable、1 kafka bridge-path: ./config/bridge-kafka.yml #The bridge config file path pprof-enable: false #Whether to enable the performance analysis tool http://ip:6060 @@ -40,7 +40,7 @@ mqtt: client-write-buffer-size: 1024 #It is the number of individual workers and queues to initialize. client-read-buffer-size: 1024 #It is the size of the queue per worker. sys-topic-resend-interval: 1 #It specifies the interval between $SYS topic updates in seconds. - inline-client: false #Whether to enable the inline client. + inline-client: true #Whether to enable the inline client. capabilities: compatibilities: obscure-not-authorized: false #Return unspecified errors instead of not authorized diff --git a/cmd/config/node2.yml b/cmd/config/node2.yml index 3af9fe9..ab0d088 100644 --- a/cmd/config/node2.yml +++ b/cmd/config/node2.yml @@ -40,7 +40,7 @@ mqtt: client-write-buffer-size: 1024 #It is the number of individual workers and queues to initialize. client-read-buffer-size: 1024 #It is the size of the queue per worker. sys-topic-resend-interval: 1 #It specifies the interval between $SYS topic updates in seconds. - inline-client: false #Whether to enable the inline client. + inline-client: true #Whether to enable the inline client. capabilities: compatibilities: obscure-not-authorized: false #Return unspecified errors instead of not authorized diff --git a/cmd/config/node3.yml b/cmd/config/node3.yml index 4e28b25..891bd74 100644 --- a/cmd/config/node3.yml +++ b/cmd/config/node3.yml @@ -40,7 +40,7 @@ mqtt: client-write-buffer-size: 1024 #It is the number of individual workers and queues to initialize. client-read-buffer-size: 1024 #It is the size of the queue per worker. sys-topic-resend-interval: 1 #It specifies the interval between $SYS topic updates in seconds. - inline-client: false #Whether to enable the inline client. + inline-client: true #Whether to enable the inline client. capabilities: compatibilities: obscure-not-authorized: false #Return unspecified errors instead of not authorized diff --git a/cmd/config/single.yml b/cmd/config/single.yml index ca42eac..96b279e 100644 --- a/cmd/config/single.yml +++ b/cmd/config/single.yml @@ -21,7 +21,7 @@ mqtt: client-write-buffer-size: 1024 #It is the number of individual workers and queues to initialize. client-read-buffer-size: 1024 #It is the size of the queue per worker. sys-topic-resend-interval: 1 #It specifies the interval between $SYS topic updates in seconds. - inline-client: false #Whether to enable the inline client. + inline-client: true #Whether to enable the inline client. capabilities: compatibilities: obscure-not-authorized: false #Return unspecified errors instead of not authorized diff --git a/cmd/single/main.go b/cmd/single/main.go index 1674b30..51d0f9f 100644 --- a/cmd/single/main.go +++ b/cmd/single/main.go @@ -23,6 +23,7 @@ import ( "github.com/wind-c/comqtt/v2/mqtt/hooks/storage/bolt" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage/redis" "github.com/wind-c/comqtt/v2/mqtt/listeners" + "github.com/wind-c/comqtt/v2/mqtt/rest" "github.com/wind-c/comqtt/v2/plugin" hauth "github.com/wind-c/comqtt/v2/plugin/auth/http" mauth "github.com/wind-c/comqtt/v2/plugin/auth/mysql" @@ -107,7 +108,7 @@ func realMain(ctx context.Context) error { onError(server.AddListener(ws), "add websocket listener") // add http listener - http := listeners.NewHTTPStats("stats", cfg.Mqtt.HTTP, nil, server.Info) + http := listeners.NewHTTP("stats", cfg.Mqtt.HTTP, nil, rest.New(server).GenHandlers()) onError(server.AddListener(http), "add http listener") errCh := make(chan error, 1) diff --git a/go.mod b/go.mod index 5715f94..aaa6064 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/wind-c/comqtt/v2 -go 1.21 +go 1.22 require ( github.com/alicebob/miniredis/v2 v2.32.1 @@ -33,7 +33,7 @@ require ( go.etcd.io/etcd/server/v3 v3.5.13 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 - golang.org/x/crypto v0.22.0 + golang.org/x/crypto v0.25.0 google.golang.org/grpc v1.63.0 gopkg.in/h2non/gock.v1 v1.1.2 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -85,13 +85,13 @@ require ( go.etcd.io/etcd/api/v3 v3.5.13 // indirect go.etcd.io/etcd/pkg/v3 v3.5.13 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/net v0.24.0 // indirect + golang.org/x/mod v0.19.0 // indirect + golang.org/x/net v0.27.0 // indirect golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.19.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.20.0 // indirect + golang.org/x/tools v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/protobuf v1.33.0 // indirect ) diff --git a/go.sum b/go.sum index 9268be3..4545f60 100644 --- a/go.sum +++ b/go.sum @@ -370,9 +370,11 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -381,8 +383,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= +golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -402,8 +404,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -453,8 +455,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -469,8 +471,8 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -486,8 +488,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY= -golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg= +golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= +golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/mqtt/listeners/http_sysinfo.go b/mqtt/listeners/http_sysinfo.go index 63d7c28..9dd6df1 100644 --- a/mqtt/listeners/http_sysinfo.go +++ b/mqtt/listeners/http_sysinfo.go @@ -20,25 +20,24 @@ import ( // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. type HTTPStats struct { sync.RWMutex - id string // the internal id of the listener - address string // the network address to bind to - config *Config // configuration values for the listener - listen *http.Server // the http server - sysInfo *system.Info // pointers to the server data - end uint32 // ensure the close methods are only called once - handlers Handlers + id string // the internal id of the listener + address string // the network address to bind to + config *Config // configuration values for the listener + listen *http.Server // the http server + sysInfo *system.Info // pointers to the server data + end uint32 // ensure the close methods are only called once + handlers map[string]Handler } -type Handlers map[string]func(http.ResponseWriter, *http.Request) +type Handler = func(http.ResponseWriter, *http.Request) -func NewHTTP(id, address string, config *Config, sysInfo *system.Info, handlers Handlers) *HTTPStats { +func NewHTTP(id, address string, config *Config, handlers map[string]Handler) *HTTPStats { if config == nil { config = new(Config) } return &HTTPStats{ id: id, address: address, - sysInfo: sysInfo, config: config, handlers: handlers, } @@ -79,7 +78,14 @@ func (l *HTTPStats) Protocol() string { // Init initializes the listener. func (l *HTTPStats) Init(_ *slog.Logger) error { mux := http.NewServeMux() - mux.HandleFunc("/", l.jsonHandler) + if len(l.handlers) > 0 { + for path, handler := range l.handlers { + mux.HandleFunc(path, handler) + } + } else { + mux.HandleFunc("/", l.jsonHandler) + } + l.listen = &http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, diff --git a/mqtt/rest/entity.go b/mqtt/rest/entity.go new file mode 100644 index 0000000..a962577 --- /dev/null +++ b/mqtt/rest/entity.go @@ -0,0 +1,51 @@ +package rest + +import ( + "github.com/wind-c/comqtt/v2/mqtt" +) + +type client struct { + ID string `json:"id"` + IP string `json:"ip"` + Online bool `json:"online"` + Username string `json:"username"` + TopicFilters []string `json:"topic_filters"` + ProtocolVersion byte `json:"protocol_version"` + SessionClean bool `json:"session_clean"` + WillTopicName string `json:"will_topic_name"` + WillPayload string `json:"will_payload"` + WillRetain bool `json:"will_retain"` + InflightCount int `json:"inflight_count"` +} + +func genClient(cl *mqtt.Client) client { + filters := make([]string, 0, len(cl.State.Subscriptions.GetAll())) + for k := range cl.State.Subscriptions.GetAll() { + filters = append(filters, k) + } + + nc := client{ + ID: cl.ID, + IP: cl.Net.Remote, + Online: !cl.Closed(), + Username: string(cl.Properties.Username), + TopicFilters: filters, + ProtocolVersion: cl.Properties.ProtocolVersion, + SessionClean: cl.Properties.Clean, + WillTopicName: cl.Properties.Will.TopicName, + WillRetain: cl.Properties.Will.Retain, + InflightCount: cl.State.Inflight.Len(), + } + if cl.Properties.Will.Payload != nil { + nc.WillPayload = string(cl.Properties.Will.Payload) + } + + return nc +} + +type message struct { + TopicName string `json:"topic_name"` + Payload string `json:"payload"` + Retain bool `json:"retain"` + Qos byte `json:"qos"` +} diff --git a/mqtt/rest/http.go b/mqtt/rest/http.go new file mode 100644 index 0000000..4b35823 --- /dev/null +++ b/mqtt/rest/http.go @@ -0,0 +1,22 @@ +package rest + +import ( + "encoding/json" + "net/http" +) + +func Ok(w http.ResponseWriter, data any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(data); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} + +func Error(w http.ResponseWriter, code int, err string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + if e := json.NewEncoder(w).Encode(err); e != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} diff --git a/mqtt/rest/rest.go b/mqtt/rest/rest.go new file mode 100644 index 0000000..0230bf8 --- /dev/null +++ b/mqtt/rest/rest.go @@ -0,0 +1,126 @@ +package rest + +import ( + "encoding/json" + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" + "net/http" + "slices" +) + +const ( + MqttGetOverallPath = "/api/v1/mqtt/stat/overall" + MqttGetOnlinePath = "/api/v1/mqtt/stat/online" + MqttGetClientPath = "/api/v1/mqtt/clients/{id}" + MqttGetBlacklistPath = "/api/v1/mqtt/blacklist" + MqttAddBlacklistPath = "/api/v1/mqtt/blacklist/{id}" + MqttDelBlacklistPath = "/api/v1/mqtt/blacklist/{id}" + MqttPublishMessagePath = "/api/v1/mqtt/message" +) + +type Handler = func(http.ResponseWriter, *http.Request) + +type Rest struct { + server *mqtt.Server +} + +func New(server *mqtt.Server) *Rest { + return &Rest{ + server: server, + } +} + +func (s *Rest) GenHandlers() map[string]Handler { + return map[string]Handler{ + "GET " + MqttGetOverallPath: s.getOverallInfo, + "GET " + MqttGetOnlinePath: s.getOnlineCount, + "GET " + MqttGetClientPath: s.getClient, + "GET " + MqttGetBlacklistPath: s.blacklist, + "POST " + MqttAddBlacklistPath: s.kickClient, + "DELETE " + MqttDelBlacklistPath: s.blanchClient, + "POST " + MqttPublishMessagePath: s.publishMessage, + } +} + +// getOverallInfo return server info +// GET api/v1/mqtt/stat/overall +func (s *Rest) getOverallInfo(w http.ResponseWriter, r *http.Request) { + Ok(w, s.server.Info) +} + +// viewConfig return the configuration parameters of broker +// GET api/v1/mqtt/config +func (s *Rest) viewConfig(w http.ResponseWriter, r *http.Request) { + Ok(w, s.server.Options) +} + +// getOnlineCount return online number +// GET api/v1/mqtt/stat/online +func (s *Rest) getOnlineCount(w http.ResponseWriter, r *http.Request) { + count := s.server.Info.ClientsConnected + Ok(w, count) +} + +// getClient return a client information +// GET api/v1/mqtt/clients/{id} +func (s *Rest) getClient(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if cl, ol := s.server.Clients.Get(id); ol { + Ok(w, genClient(cl)) + } else { + Error(w, http.StatusNotFound, "client not found") + } +} + +// publishMessage a message +// POST api/v1/mqtt/message +func (s *Rest) publishMessage(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + var msg message + if err := json.NewDecoder(r.Body).Decode(&msg); err != nil { + Error(w, http.StatusBadRequest, err.Error()) + return + } + + if err := s.server.Publish(msg.TopicName, []byte(msg.Payload), msg.Retain, msg.Qos); err != nil { + Error(w, http.StatusInternalServerError, err.Error()) + } else { + Ok(w, msg) + } +} + +// kickClient disconnect the client and add it to the blacklist +// POST api/v1/mqtt/blacklist/{id} +func (s *Rest) kickClient(w http.ResponseWriter, r *http.Request) { + cid := r.PathValue("id") + if !slices.Contains(s.server.Blacklist, cid) { + s.server.Blacklist = append(s.server.Blacklist, cid) + } + if cl, ol := s.server.Clients.Get(cid); ol { + s.server.DisconnectClient(cl, packets.ErrNotAuthorized) + Ok(w, cid) + } else { + Error(w, http.StatusNotFound, "client not found") + } +} + +// blanchClient remove from the blacklist +// DELETE api/v1/mqtt/blacklist/{id} +func (s *Rest) blanchClient(w http.ResponseWriter, r *http.Request) { + cid := r.PathValue("id") + if slices.Contains(s.server.Blacklist, cid) { + slices.DeleteFunc(s.server.Blacklist, func(s string) bool { return s == cid }) + Ok(w, cid) + } +} + +// blacklist return to the blacklist +// GET api/v1/mqtt/blacklist +func (s *Rest) blacklist(w http.ResponseWriter, r *http.Request) { + if s.server.Blacklist == nil { + Error(w, http.StatusNotFound, "blacklist not found") + } else { + Ok(w, s.server.Blacklist) + } +} diff --git a/mqtt/server.go b/mqtt/server.go index 98c209d..056c2d7 100644 --- a/mqtt/server.go +++ b/mqtt/server.go @@ -12,6 +12,7 @@ import ( "net" "os" "runtime" + "slices" "sort" "strconv" "strings" @@ -127,6 +128,7 @@ type Server struct { Log *slog.Logger // minimal no-alloc logger hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish + Blacklist []string // blacklist of client id } // loop contains interval tickers for the system events loop. @@ -337,6 +339,10 @@ func (s *Server) attachClient(cl *Client, listener string) error { } cl.ParseConnect(listener, pk) + if slices.Contains(s.Blacklist, cl.ID) { + return fmt.Errorf("blacklisted client: %s", cl.ID) + } + code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] if code != packets.CodeSuccess { if err := s.SendConnack(cl, code, false, nil); err != nil {