diff --git a/service/udp.go b/service/udp.go index 87c5ba4b..6ca1cd2c 100644 --- a/service/udp.go +++ b/service/udp.go @@ -48,6 +48,9 @@ type UDPAssocationMetrics interface { // Max UDP buffer size for the server code. const serverUDPBufferSize = 64 * 1024 +// Buffer pool used for reading UDP packets. +var readBufPool = slicepool.MakePool(serverUDPBufferSize) + // Wrapper for slog.Debug during UDP proxying. func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction @@ -153,57 +156,100 @@ type AssocationHandleFunc func(assocation net.Conn) // function for each association. It uses a NAT map to track active associations // and handles their lifecycle. func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics NATMetrics) { + // This goroutine continuously reads from clientConn and sends the received data + // to readCh. It uses a buffer pool (readBufPool) to efficiently manage buffers + // and minimize allocations. The LazySlice is sent along with the read event + // to allow the receiver to release the buffer back to the pool after processing. + readCh := make(chan readEvent, 10) + go func() { + for { + lazySlice := readBufPool.LazySlice() + buffer := lazySlice.Acquire() + n, addr, err := clientConn.ReadFrom(buffer) + if err != nil { + lazySlice.Release() + if errors.Is(err, net.ErrClosed) { + readCh <- readEvent{err: err} + return + } + slog.Warn("Failed to read from client. Continuing to listen.", "err", err) + continue + } + readCh <- readEvent{ + poolSlice: lazySlice, + pkt: buffer[:n], + addr: addr, + } + } + }() + nm := newNATmap() defer nm.Close() - buffer := make([]byte, serverUDPBufferSize) + // This loop handles events from closeCh (connection closures) and readCh + // (incoming data). It removes NAT entries for closed connections and processes + // incoming data packets. The loop also ensures that buffers acquired from + // the readBufPool are released back to the pool after processing is complete. + closeCh := make(chan net.Addr, 10) for { - n, addr, err := clientConn.ReadFrom(buffer) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break + select { + case addr := <-closeCh: + metrics.RemoveNATEntry() + nm.Del(addr) + case read := <-readCh: + if read.err != nil { + return } - slog.Warn("Failed to read from client. Continuing to listen.", "err", err) - continue - } - pkt := buffer[:n] - - // TODO: Include server address in the NAT key as well. - conn := nm.Get(addr.String()) - if conn == nil { - conn = &natconn{ - Conn: &packetConnWrapper{PacketConn: clientConn, raddr: addr}, - readBufCh: make(chan []byte, 1), - bytesReadCh: make(chan int, 1), + + poolSlice := read.poolSlice + pkt := read.pkt + addr := read.addr + + // TODO: Include server address in the NAT key as well. + conn := nm.Get(addr.String()) + if conn == nil { + conn = &natconn{ + PacketConn: clientConn, + raddr: addr, + closeCh: closeCh, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), + } + metrics.AddNATEntry() + nm.Add(addr, conn) + go handle(conn) } - metrics.AddNATEntry() - deleteEntry := nm.Add(addr, conn) - go func(conn *natconn) { - defer func() { - conn.Close() - deleteEntry() - metrics.RemoveNATEntry() - }() - handle(conn) - }(conn) - } - readBuf, ok := <-conn.readBufCh - if !ok { - continue + readBuf, ok := <-conn.readBufCh + if !ok { + poolSlice.Release() + continue + } + copy(readBuf, pkt) + poolSlice.Release() + conn.bytesReadCh <- len(pkt) } - copy(readBuf, pkt) - conn.bytesReadCh <- n } } +type readEvent struct { + poolSlice slicepool.LazySlice + pkt []byte + addr net.Addr + err error +} + // natconn adapts a [net.Conn] to provide a synchronized reading mechanism for NAT traversal. // // The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!) // which minimizes buffer allocations and copying. type natconn struct { - net.Conn + net.PacketConn + raddr net.Addr + closeCh chan net.Addr + doneCh chan struct{} // readBufCh provides a buffer to copy incoming packet data into. - readBufCh chan []byte + readBufCh chan []byte // bytesReadCh is used to signal the availability of new data and carries // the length of the received packet. @@ -213,21 +259,34 @@ type natconn struct { var _ net.Conn = (*natconn)(nil) func (c *natconn) Read(p []byte) (int, error) { - c.readBufCh <- p - n, ok := <-c.bytesReadCh - if !ok { + select { + case <-c.doneCh: + c.closeCh <- c.raddr return 0, net.ErrClosed + case c.readBufCh <- p: + n, ok := <-c.bytesReadCh + if !ok { + c.closeCh <- c.raddr + return 0, net.ErrClosed + } + return n, nil } - return n, nil +} + +func (c *natconn) Write(b []byte) (n int, err error) { + return c.PacketConn.WriteTo(b, c.raddr) } func (c *natconn) Close() error { - close(c.readBufCh) + close(c.doneCh) close(c.bytesReadCh) - c.Conn.Close() return nil } +func (c *natconn) RemoteAddr() net.Addr { + return c.raddr +} + func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPAssocationMetrics) { if connMetrics == nil { connMetrics = &NoOpUDPAssocationMetrics{} @@ -235,13 +294,13 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA targetConn, err := h.targetConnFactory() if err != nil { - slog.Error("UDP: failed to create target connection", slog.Any("err", err)) + h.logger.Error("UDP: failed to create target connection", slog.Any("err", err)) return } - cipherLazySlice := h.bufPool.LazySlice() - cipherBuf := cipherLazySlice.Acquire() - defer cipherLazySlice.Release() + cipherSlice := h.bufPool.LazySlice() + cipherBuf := cipherSlice.Acquire() + defer cipherSlice.Release() textLazySlice := h.bufPool.LazySlice() @@ -250,18 +309,18 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA for { clientProxyBytes, err := clientAssociation.Read(cipherBuf) if errors.Is(err, net.ErrClosed) { - cipherLazySlice.Release() + cipherSlice.Release() return } debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes)) - + connError := func() *onet.ConnectionError { defer func() { if r := recover(); r != nil { - slog.Error("Panic in UDP loop. Continuing to listen.", "err", r) + h.logger.Error("Panic in UDP loop. Continuing to listen.", "err", r) debug.PrintStack() } - slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String())) + h.logger.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String())) }() cipherData := cipherBuf[:clientProxyBytes] @@ -285,9 +344,10 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA connMetrics.AddAuthenticated(keyID) go func() { - defer connMetrics.AddClosed() timedCopy(clientAssociation, targetConn, cryptoKey, connMetrics, h.logger) + connMetrics.AddClosed() targetConn.Close() + clientAssociation.Close() }() } else { @@ -422,21 +482,13 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc *natconn) { - m.Lock() - defer m.Unlock() - - m.keyConn[key] = pc - return -} - -func (m *natmap) del(key string) *natconn { +func (m *natmap) Del(addr net.Addr) net.PacketConn { m.Lock() defer m.Unlock() - entry, ok := m.keyConn[key] + entry, ok := m.keyConn[addr.String()] if ok { - delete(m.keyConn, key) + delete(m.keyConn, addr.String()) return entry } return nil @@ -444,12 +496,11 @@ func (m *natmap) del(key string) *natconn { // Add adds a new UDP NAT entry to the natmap and returns a closure to delete // the entry. -func (m *natmap) Add(addr net.Addr, pc *natconn) func() { - key := addr.String() - m.set(key, pc) - return func() { - m.del(key) - } +func (m *natmap) Add(addr net.Addr, pc *natconn) { + m.Lock() + defer m.Unlock() + + m.keyConn[addr.String()] = pc } func (m *natmap) Close() error { @@ -466,31 +517,6 @@ func (m *natmap) Close() error { return err } -// packetConnWrapper wraps a [net.PacketConn] and provides a [net.Conn] interface -// with a given remote address. -type packetConnWrapper struct { - net.PacketConn - raddr net.Addr -} - -var _ net.Conn = (*packetConnWrapper)(nil) - -// ReadFrom reads data from the connection. -func (pcw *packetConnWrapper) Read(b []byte) (n int, err error) { - n, _, err = pcw.PacketConn.ReadFrom(b) - return -} - -// WriteTo writes data to the connection. -func (pcw *packetConnWrapper) Write(b []byte) (n int, err error) { - return pcw.PacketConn.WriteTo(b, pcw.raddr) -} - -// RemoteAddr returns the remote network address. -func (pcw *packetConnWrapper) RemoteAddr() net.Addr { - return pcw.raddr -} - // Get the maximum length of the shadowsocks address header by parsing // and serializing an IPv6 address from the example range. var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) diff --git a/service/udp_test.go b/service/udp_test.go index 160c9697..8a7eb043 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -195,6 +195,24 @@ func startTestHandler() (AssociationHandler, func(target net.Addr, payload []byt }, targetConn } +func TestNatconnCloseWhileReading(t *testing.T) { + nc := &natconn{ + PacketConn: makePacketConn(), + raddr: &clientAddr, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), + } + go func() { + buf := make([]byte, 1024) + nc.Read(buf) + }() + + err := nc.Close() + + assert.NoError(t, err, "Close should not panic or return an error") +} + func TestAssociationHandler_Handle_IPFilter(t *testing.T) { t.Run("RequirePublicIP blocks localhost", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() @@ -462,56 +480,56 @@ func TestTimedPacketConn(t *testing.T) { func TestNATMap(t *testing.T) { t.Run("Empty", func(t *testing.T) { - nat := newNATmap() - if nat.Get("foo") != nil { + nm := newNATmap() + if nm.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } }) t.Run("Add", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} conn1 := &natconn{} - nat.Add(addr1, conn1) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection") + nm.Add(addr, conn1) + assert.Equal(t, conn1, nm.Get(addr.String()), "Get should return the correct connection") conn2 := &natconn{} - nat.Add(addr1, conn2) - assert.Equal(t, conn2, nat.Get(addr1.String()), "Adding with the same address should overwrite the entry") + nm.Add(addr, conn2) + assert.Equal(t, conn2, nm.Get(addr.String()), "Adding with the same address should overwrite the entry") }) t.Run("Get", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - nat.Add(addr1, conn1) + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection for an existing address") + assert.Equal(t, conn, nm.Get(addr.String()), "Get should return the correct connection for an existing address") addr2 := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5678} - assert.Nil(t, nat.Get(addr2.String()), "Get should return nil for a non-existent address") + assert.Nil(t, nm.Get(addr2.String()), "Get should return nil for a non-existent address") }) - t.Run("closure_deletes", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - deleteEntry := nat.Add(addr1, conn1) + t.Run("Del", func(t *testing.T) { + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - deleteEntry() + nm.Del(addr) - assert.Nil(t, nat.Get(addr1.String()), "Get should return nil after deleting the entry") + assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry") }) t.Run("Close", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} pc := makePacketConn() - conn1 := &natconn{Conn: &packetConnWrapper{PacketConn: pc, raddr: addr1}} - nat.Add(addr1, conn1) + conn := &natconn{PacketConn: pc, raddr: addr} + nm.Add(addr, conn) - err := nat.Close() + err := nm.Close() assert.NoError(t, err, "Close should not return an error") // The underlying connection should be scheduled to close immediately.