From bbfe721db664ea5fb438f7acc1487c7817967925 Mon Sep 17 00:00:00 2001 From: 1998-felix Date: Wed, 24 Jul 2024 09:21:17 +0300 Subject: [PATCH] feat: add udp and dtls proxy Signed-off-by: 1998-felix --- go.mod | 1 + go.sum | 2 + pkg/coap/coap.go | 465 +++++++++++++++++++---------------------------- pkg/tls/tls.go | 39 ++-- 4 files changed, 217 insertions(+), 290 deletions(-) diff --git a/go.mod b/go.mod index b1ffb23..d8f535b 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( require ( github.com/dsnet/golib/memfile v1.0.0 // indirect + github.com/dustin/go-coap v0.0.0-20190908170653-752e0f79981e github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/pion/logging v0.2.2 // indirect diff --git a/go.sum b/go.sum index 86a1242..a36f6a8 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dsnet/golib/memfile v1.0.0 h1:J9pUspY2bDCbF9o+YGwcf3uG6MdyITfh/Fk3/CaEiFs= github.com/dsnet/golib/memfile v1.0.0/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= +github.com/dustin/go-coap v0.0.0-20190908170653-752e0f79981e h1:oppjHFVTardH+VyOD32F9uBtgT5Wd/qVqEGcwj389Lc= +github.com/dustin/go-coap v0.0.0-20190908170653-752e0f79981e/go.mod h1:as2rZ2aojRzZF8bGx1bPAn1yi9ICG6LwkiPOj6PBtjc= github.com/eclipse/paho.mqtt.golang v1.4.3 h1:2kwcUGn8seMUfWndX0hGbvH8r7crgcJguQNCyp70xik= github.com/eclipse/paho.mqtt.golang v1.4.3/go.mod h1:CSYvoAlsMkhYOXh/oKyxa8EcBci6dVkLCbo5tTC1RIE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/pkg/coap/coap.go b/pkg/coap/coap.go index 3569d5d..36bd795 100644 --- a/pkg/coap/coap.go +++ b/pkg/coap/coap.go @@ -4,29 +4,35 @@ package coap import ( - "bytes" "context" - "errors" "fmt" "log/slog" + "net" + "sync" + "time" "github.com/absmach/mproxy" "github.com/absmach/mproxy/pkg/session" - "github.com/plgd-dev/go-coap/v3/dtls" - dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server" - "github.com/plgd-dev/go-coap/v3/message" - "github.com/plgd-dev/go-coap/v3/message/codes" - "github.com/plgd-dev/go-coap/v3/message/pool" - "github.com/plgd-dev/go-coap/v3/mux" - "github.com/plgd-dev/go-coap/v3/net" - "github.com/plgd-dev/go-coap/v3/options" - "github.com/plgd-dev/go-coap/v3/udp" - udpServer "github.com/plgd-dev/go-coap/v3/udp/server" + mptls "github.com/absmach/mproxy/pkg/tls" + gocoap "github.com/dustin/go-coap" + "github.com/pion/dtls/v2" + "golang.org/x/sync/errgroup" ) -const startObserve uint32 = 0 +const ( + bufferSize uint64 = 1280 + startObserve uint32 = 0 +) -var errUnsupportedMethod = errors.New("unsupported CoAP method") +var ( + ConnMap = make(map[string]*Conn) + mutex sync.Mutex +) + +type Conn struct { + clientAddr *net.UDPAddr + serverConn *net.UDPConn +} type Proxy struct { config mproxy.Config @@ -34,30 +40,6 @@ type Proxy struct { logger *slog.Logger } -type udpNilMonitor struct{} - -func (u *udpNilMonitor) UDPServerApply(cfg *udpServer.Config) { - cfg.CreateInactivityMonitor = nil -} - -func NewUDPNilMonitor() udpServer.Option { - return &udpNilMonitor{} -} - -var _ udpServer.Option = (*udpNilMonitor)(nil) - -type dtlsNilMonitor struct{} - -func (d *dtlsNilMonitor) DTLSServerApply(cfg *dtlsServer.Config) { - cfg.CreateInactivityMonitor = nil -} - -func NewDTLSNilMonitor() dtlsServer.Option { - return &dtlsNilMonitor{} -} - -var _ udpServer.Option = (*udpNilMonitor)(nil) - func NewProxy(config mproxy.Config, handler session.Handler, logger *slog.Logger) *Proxy { return &Proxy{ config: config, @@ -66,306 +48,241 @@ func NewProxy(config mproxy.Config, handler session.Handler, logger *slog.Logger } } -func sendErrorMessage(cc mux.Conn, token []byte, err error, code codes.Code) error { - m := cc.AcquireMessage(cc.Context()) - defer cc.ReleaseMessage(m) - m.SetCode(code) - m.SetBody(bytes.NewReader(([]byte)(err.Error()))) - m.SetToken(token) - m.SetContentFormat(message.TextPlain) - return cc.WriteMessage(m) +func (p *Proxy) proxy(ctx context.Context, l *net.UDPConn) { + buffer := make([]byte, 1024) + for { + select { + case <-ctx.Done(): + return + default: + n, clientAddr, err := l.ReadFromUDP(buffer) + if err != nil { + return + } + mutex.Lock() + conn, ok := ConnMap[clientAddr.String()] + if !ok { + conn, err = p.newConn(clientAddr) + if err != nil { + p.logger.Error("Failed to create new connection", slog.Any("error", err)) + mutex.Unlock() + return + } + ConnMap[clientAddr.String()] = conn + go p.down(l, conn) + } + mutex.Unlock() + p.up(conn, buffer[:n]) + } + } } -func (p *Proxy) postUpstream(cc mux.Conn, req *mux.Message, token []byte) error { - outbound, err := udp.Dial(p.config.Target) +func (p *Proxy) Listen(ctx context.Context) error { + addr, err := net.ResolveUDPAddr("udp", p.config.Address) if err != nil { + p.logger.Error("Failed to resolve UDP address", slog.Any("error", err)) return err } - defer outbound.Close() + g, ctx := errgroup.WithContext(ctx) + switch { + case p.config.DTLSConfig != nil: + l, err := dtls.Listen("udp", addr, p.config.DTLSConfig) + if err != nil { + return err + } + defer l.Close() - path, err := req.Options().Path() - if err != nil { - return err - } + g.Go(func() error { + p.accept(ctx, l) + return nil + }) + + g.Go(func() error { + <-ctx.Done() + return l.Close() + }) - format := message.TextPlain - if req.HasOption(message.ContentFormat) { - format, err = req.ContentFormat() + default: + l, err := net.ListenUDP("udp", addr) if err != nil { return err } - } + defer l.Close() - pm, err := outbound.Post(cc.Context(), path, format, req.Body(), req.Options()...) - if err != nil { - return err + g.Go(func() error { + p.proxy(ctx, l) + return nil + }) + + g.Go(func() error { + <-ctx.Done() + return l.Close() + }) } - pm.SetToken(token) - return cc.WriteMessage(pm) -} -func (p *Proxy) getUpstream(cc mux.Conn, req *mux.Message, token []byte) error { - path, err := req.Options().Path() - if err != nil { - return err + status := mptls.SecurityStatus(p.config.DTLSConfig) + p.logger.Info(fmt.Sprintf("COAP proxy server started at %s with %s", p.config.Address, status)) + + if err := g.Wait(); err != nil { + p.logger.Info(fmt.Sprintf("COAP proxy server at %s exiting with errors", p.config.Address), slog.String("error", err.Error())) + } else { + p.logger.Info(fmt.Sprintf("COAP proxy server at %s exiting...", p.config.Address)) } + return nil +} - outbound, err := udp.Dial(p.config.Target) +func (p *Proxy) newConn(clientAddr *net.UDPAddr) (*Conn, error) { + conn := new(Conn) + conn.clientAddr = clientAddr + addr, err := net.ResolveUDPAddr("udp", p.config.Target) if err != nil { - return err + return nil, err } - defer outbound.Close() - pm, err := outbound.Get(cc.Context(), path, req.Options()...) + t, err := net.DialUDP("udp", nil, addr) if err != nil { - return err + return nil, err } - pm.SetToken(token) - return cc.WriteMessage(pm) + conn.serverConn = t + return conn, nil } -func (p *Proxy) observeUpstream(ctx context.Context, cc mux.Conn, opts []message.Option, token []byte, path string) { - outbound, err := udp.Dial(p.config.Target) +func (p *Proxy) up(conn *Conn, buffer []byte) { + p.handleCoAPMessage(buffer) + _, err := conn.serverConn.Write(buffer) if err != nil { - if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil { - p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) - } - } - defer outbound.Close() - doneObserving := make(chan struct{}) - - pm := outbound.AcquireMessage(outbound.Context()) - defer outbound.ReleaseMessage(pm) - pm.SetToken(token) - pm.SetCode(codes.GET) - for _, opt := range opts { - pm.SetOptionBytes(opt.ID, opt.Value) - } - if err := pm.SetPath(path); err != nil { - if err := sendErrorMessage(cc, token, err, codes.BadOption); err != nil { - p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) - } return } +} - obs, err := outbound.DoObserve(pm, func(req *pool.Message) { - req.SetToken(token) - if err := cc.WriteMessage(req); err != nil { - if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil { - p.logger.Error(err.Error()) - } - p.logger.Error(err.Error()) +func (p *Proxy) down(l *net.UDPConn, conn *Conn) { + buffer := make([]byte, bufferSize) + for { + err := conn.serverConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + return } - if req.Code() == codes.NotFound { - close(doneObserving) + n, err := conn.serverConn.Read(buffer) + if err != nil { + p.close(conn) + return } - }) - if err != nil { - if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil { - p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) + _, err = l.WriteToUDP(buffer[:n], conn.clientAddr) + if err != nil { + return } } +} - select { - case <-doneObserving: - if err := obs.Cancel(ctx); err != nil { - p.logger.Error(fmt.Sprintf("failed to cancel observation:%v", err)) - } - case <-ctx.Done(): - return - } +func (p *Proxy) close(conn *Conn) { + mutex.Lock() + defer mutex.Unlock() + delete(ConnMap, conn.clientAddr.String()) + conn.serverConn.Close() } -func (p *Proxy) CancelObservation(cc mux.Conn, opts []message.Option, token []byte, path string) error { - outbound, err := udp.Dial(p.config.Target) - if err != nil { - if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil { - p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) - } - } - defer outbound.Close() - - pm := outbound.AcquireMessage(outbound.Context()) - defer outbound.ReleaseMessage(pm) - pm.SetToken(token) - pm.SetCode(codes.GET) - for _, opt := range opts { - pm.SetOptionBytes(opt.ID, opt.Value) - } - if err := pm.SetPath(path); err != nil { - if err := sendErrorMessage(cc, token, err, codes.BadOption); err != nil { - p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) +func (p *Proxy) accept(ctx context.Context, l net.Listener) { + for { + select { + case <-ctx.Done(): + return + default: + conn, err := l.Accept() + if err != nil { + p.logger.Warn("Accept error " + err.Error()) + continue + } + p.logger.Info("Accepted new client") + go p.handleDTLS(conn) } - return err } - if err := outbound.WriteMessage(pm); err != nil { - return err - } - pm.SetCode(codes.Content) - return cc.WriteMessage(pm) } -func (p *Proxy) handler(w mux.ResponseWriter, r *mux.Message) { - tok, err := r.Options().GetBytes(message.URIQuery) +func (p *Proxy) handleDTLS(inbound net.Conn) { + outboundAddr, err := net.ResolveUDPAddr("udp", p.config.Address) if err != nil { - if err := sendErrorMessage(w.Conn(), r.Token(), err, codes.Unauthorized); err != nil { - p.logger.Error(err.Error()) - } return } - ctx := session.NewContext(r.Context(), &session.Session{Password: tok}) - if err := p.session.AuthConnect(ctx); err != nil { - if err := sendErrorMessage(w.Conn(), r.Token(), err, codes.Unauthorized); err != nil { - p.logger.Error(err.Error()) - } - return - } - path, err := r.Options().Path() + + outbound, err := net.DialUDP("udp", nil, outboundAddr) if err != nil { - if err := sendErrorMessage(w.Conn(), r.Token(), err, codes.BadOption); err != nil { - p.logger.Error(err.Error()) - } + p.logger.Error("Cannot connect to remote broker " + p.config.Address + " due to: " + err.Error()) return } - switch r.Code() { - case codes.GET: - p.handleGet(ctx, path, w.Conn(), r.Token(), r) - case codes.POST: - body, err := r.ReadBody() + go p.dtlsUp(outbound, inbound) + go p.dtlsDown(inbound, outbound) +} + +func (p *Proxy) dtlsUp(outbound *net.UDPConn, inbound net.Conn) { + buffer := make([]byte, bufferSize) + for { + n, err := inbound.Read(buffer) if err != nil { - if err := sendErrorMessage(w.Conn(), r.Token(), err, codes.BadRequest); err != nil { - p.logger.Error(err.Error()) - } return } - p.handlePost(ctx, w.Conn(), body, r.Token(), path, r) - default: - if err := sendErrorMessage(w.Conn(), r.Token(), errUnsupportedMethod, codes.MethodNotAllowed); err != nil { - p.logger.Error(err.Error()) + p.handleCoAPMessage(buffer[:n]) + + _, err = outbound.Write(buffer[:n]) + if err != nil { + slog.Error("Failed to write to server", slog.Any("err", err)) } } } -func (p *Proxy) handleGet(ctx context.Context, path string, con mux.Conn, token []byte, r *mux.Message) { - if err := p.session.AuthSubscribe(ctx, &[]string{path}); err != nil { - if err := sendErrorMessage(con, token, err, codes.Unauthorized); err != nil { - p.logger.Error(err.Error()) - } - return - } - if err := p.session.Subscribe(ctx, &[]string{path}); err != nil { - if err := sendErrorMessage(con, token, err, codes.Unauthorized); err != nil { - p.logger.Error(err.Error()) - } - return - } - switch { - case r.HasOption(message.Observe): - obs, err := r.Options().Observe() +func (p *Proxy) dtlsDown(inbound net.Conn, outbound *net.UDPConn) { + buffer := make([]byte, bufferSize) + for { + err := outbound.SetReadDeadline(time.Now().Add(1 * time.Minute)) if err != nil { - if err := sendErrorMessage(con, r.Token(), err, codes.BadRequest); err != nil { - p.logger.Error(err.Error()) - } return } - switch obs { - case startObserve: - go p.observeUpstream(ctx, con, r.Options(), token, path) - default: - if err := p.CancelObservation(con, r.Options(), token, path); err != nil { - p.logger.Error(fmt.Sprintf("error performing cancel observation: %v\n", err)) - if err := sendErrorMessage(con, token, err, codes.BadGateway); err != nil { - p.logger.Error(err.Error()) - } - return - } + n, err := outbound.Read(buffer) + defer outbound.Close() + if err != nil { + return } - default: - if err := p.getUpstream(con, r, token); err != nil { - p.logger.Error(fmt.Sprintf("error performing get: %v\n", err)) - if err := sendErrorMessage(con, token, err, codes.BadGateway); err != nil { - p.logger.Error(err.Error()) - } + + _, err = inbound.Write(buffer[:n]) + defer inbound.Close() + if err != nil { return } } } -func (p *Proxy) handlePost(ctx context.Context, con mux.Conn, body, token []byte, path string, r *mux.Message) { - if err := p.session.AuthPublish(ctx, &path, &body); err != nil { - if err := sendErrorMessage(con, token, err, codes.Unauthorized); err != nil { - p.logger.Error(err.Error()) - } +func (p *Proxy) handleCoAPMessage(buffer []byte) { + msg, err := gocoap.ParseMessage(buffer) + if err != nil { + p.logger.Error("Failed to parse message", slog.Any("error", err)) return } - if err := p.session.Publish(ctx, &path, &body); err != nil { - if err := sendErrorMessage(con, token, err, codes.BadRequest); err != nil { - p.logger.Error(err.Error()) + + token := msg.Token + path := msg.Path() + ctx := session.NewContext(context.Background(), &session.Session{Password: token}) + + switch msg.Code { + case gocoap.POST: + if err := p.session.AuthConnect(ctx); err != nil { + return } - return - } - if err := p.postUpstream(con, r, token); err != nil { - p.logger.Debug(fmt.Sprintf("error performing post: %v\n", err)) - if err := sendErrorMessage(con, token, err, codes.BadGateway); err != nil { - p.logger.Error(err.Error()) + if err := p.session.AuthPublish(ctx, &path[0], &msg.Payload); err != nil { + return } - return - } -} - -func (p *Proxy) Listen(ctx context.Context) error { - if p.config.DTLSConfig != nil { - l, err := net.NewDTLSListener("udp", p.config.Address, p.config.DTLSConfig) - if err != nil { - return err + if err := p.session.Publish(ctx, &path[0], &msg.Payload); err != nil { + return } - defer l.Close() - - p.logger.Info(fmt.Sprintf("CoAP proxy server started on port %s with DTLS", p.config.Address)) - var dialOpts []dtlsServer.Option - dialOpts = append(dialOpts, options.WithMux(mux.HandlerFunc(p.handler)), NewDTLSNilMonitor()) - - s := dtls.NewServer(dialOpts...) - - errCh := make(chan error) - go func() { - errCh <- s.Serve(l) - }() - - select { - case <-ctx.Done(): - p.logger.Info(fmt.Sprintf("CoAP proxy server on port %s with DTLS exiting ...", p.config.Address)) - l.Close() - case err := <-errCh: - p.logger.Error(fmt.Sprintf("CoAP proxy server on port %s with DTLS exiting with errors: %s", p.config.Address, err.Error())) - return err + case gocoap.GET: + if err := p.session.AuthConnect(ctx); err != nil { + return + } + if msg.Option(gocoap.Observe) == startObserve { + if err := p.session.AuthSubscribe(ctx, &path); err != nil { + return + } + if err := p.session.Subscribe(ctx, &path); err != nil { + return + } } - return nil - } - l, err := net.NewListenUDP("udp", p.config.Address) - if err != nil { - return err - } - defer l.Close() - - p.logger.Info(fmt.Sprintf("CoAP proxy server started at %s without DTLS", p.config.Address)) - var dialOpts []udpServer.Option - dialOpts = append(dialOpts, options.WithMux(mux.HandlerFunc(p.handler)), NewUDPNilMonitor()) - - s := udp.NewServer(dialOpts...) - - errCh := make(chan error) - go func() { - errCh <- s.Serve(l) - }() - - select { - case <-ctx.Done(): - p.logger.Info(fmt.Sprintf("CoAP proxy server on port %s without DTLS exiting ...", p.config.Address)) - l.Close() - case err := <-errCh: - p.logger.Error(fmt.Sprintf("CoAP proxy server on port %s without DTLS exiting with errors: %s", p.config.Address, err.Error())) - return err } - return nil } diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go index b930ccd..ff7da54 100644 --- a/pkg/tls/tls.go +++ b/pkg/tls/tls.go @@ -105,6 +105,29 @@ func LoadTLSConfig[sc TLSConfig](c *Config, s sc) (sc, error) { } } +// SecurityStatus returns log message from TLS config. +func SecurityStatus[sc TLSConfig](s sc) string { + if s == nil { + return "no TLS" + } + switch c := any(s).(type) { + case *tls.Config: + ret := "TLS" + // It is possible to establish TLS with client certificates only. + if c.Certificates == nil || len(c.Certificates) == 0 { + ret = "no server certificates" + } + if c.ClientCAs != nil { + ret += " and " + c.ClientAuth.String() + } + return ret + case *dtls.Config: + return "DTLS" + default: + return "no TLS" + } +} + // ClientCert returns client certificate. func ClientCert(conn net.Conn) (x509.Certificate, error) { switch connVal := conn.(type) { @@ -126,22 +149,6 @@ func ClientCert(conn net.Conn) (x509.Certificate, error) { } } -// SecurityStatus returns log message from TLS config. -func SecurityStatus(c *tls.Config) string { - if c == nil { - return "no TLS" - } - ret := "TLS" - // It is possible to establish TLS with client certificates only. - if c.Certificates == nil || len(c.Certificates) == 0 { - ret = "no server certificates" - } - if c.ClientCAs != nil { - ret += " and " + c.ClientAuth.String() - } - return ret -} - func loadCertFile(certFile string) ([]byte, error) { if certFile != "" { return os.ReadFile(certFile)