Skip to content

Commit

Permalink
Merge branch 'sbruens/udp-split-serving' into sbruens/websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Dec 13, 2024
2 parents e210fb0 + e0547f2 commit 3c1277c
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 118 deletions.
210 changes: 118 additions & 92 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type UDPAssocationMetrics interface {
// Max UDP buffer size for the server code.
const serverUDPBufferSize = 64 * 1024

// Buffer pool used for reading UDP packets.
var readBufPool = slicepool.MakePool(serverUDPBufferSize)

// Wrapper for slog.Debug during UDP proxying.
func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
// This is an optimization to reduce unnecessary allocations due to an interaction
Expand Down Expand Up @@ -153,57 +156,100 @@ type AssocationHandleFunc func(assocation net.Conn)
// function for each association. It uses a NAT map to track active associations
// and handles their lifecycle.
func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics NATMetrics) {
// This goroutine continuously reads from clientConn and sends the received data
// to readCh. It uses a buffer pool (readBufPool) to efficiently manage buffers
// and minimize allocations. The LazySlice is sent along with the read event
// to allow the receiver to release the buffer back to the pool after processing.
readCh := make(chan readEvent, 10)
go func() {
for {
lazySlice := readBufPool.LazySlice()
buffer := lazySlice.Acquire()
n, addr, err := clientConn.ReadFrom(buffer)
if err != nil {
lazySlice.Release()
if errors.Is(err, net.ErrClosed) {
readCh <- readEvent{err: err}
return
}
slog.Warn("Failed to read from client. Continuing to listen.", "err", err)
continue
}
readCh <- readEvent{
poolSlice: lazySlice,
pkt: buffer[:n],
addr: addr,
}
}
}()

nm := newNATmap()
defer nm.Close()
buffer := make([]byte, serverUDPBufferSize)
// This loop handles events from closeCh (connection closures) and readCh
// (incoming data). It removes NAT entries for closed connections and processes
// incoming data packets. The loop also ensures that buffers acquired from
// the readBufPool are released back to the pool after processing is complete.
closeCh := make(chan net.Addr, 10)
for {
n, addr, err := clientConn.ReadFrom(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
select {
case addr := <-closeCh:
metrics.RemoveNATEntry()
nm.Del(addr)
case read := <-readCh:
if read.err != nil {
return
}
slog.Warn("Failed to read from client. Continuing to listen.", "err", err)
continue
}
pkt := buffer[:n]

// TODO: Include server address in the NAT key as well.
conn := nm.Get(addr.String())
if conn == nil {
conn = &natconn{
Conn: &packetConnWrapper{PacketConn: clientConn, raddr: addr},
readBufCh: make(chan []byte, 1),
bytesReadCh: make(chan int, 1),

poolSlice := read.poolSlice
pkt := read.pkt
addr := read.addr

// TODO: Include server address in the NAT key as well.
conn := nm.Get(addr.String())
if conn == nil {
conn = &natconn{
PacketConn: clientConn,
raddr: addr,
closeCh: closeCh,
doneCh: make(chan struct{}),
readBufCh: make(chan []byte, 1),
bytesReadCh: make(chan int, 1),
}
metrics.AddNATEntry()
nm.Add(addr, conn)
go handle(conn)
}
metrics.AddNATEntry()
deleteEntry := nm.Add(addr, conn)
go func(conn *natconn) {
defer func() {
conn.Close()
deleteEntry()
metrics.RemoveNATEntry()
}()
handle(conn)
}(conn)
}
readBuf, ok := <-conn.readBufCh
if !ok {
continue
readBuf, ok := <-conn.readBufCh
if !ok {
poolSlice.Release()
continue
}
copy(readBuf, pkt)
poolSlice.Release()
conn.bytesReadCh <- len(pkt)
}
copy(readBuf, pkt)
conn.bytesReadCh <- n
}
}

type readEvent struct {
poolSlice slicepool.LazySlice
pkt []byte
addr net.Addr
err error
}

// natconn adapts a [net.Conn] to provide a synchronized reading mechanism for NAT traversal.
//
// The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!)
// which minimizes buffer allocations and copying.
type natconn struct {
net.Conn
net.PacketConn
raddr net.Addr
closeCh chan net.Addr
doneCh chan struct{}

// readBufCh provides a buffer to copy incoming packet data into.
readBufCh chan []byte
readBufCh chan []byte

// bytesReadCh is used to signal the availability of new data and carries
// the length of the received packet.
Expand All @@ -213,35 +259,48 @@ type natconn struct {
var _ net.Conn = (*natconn)(nil)

func (c *natconn) Read(p []byte) (int, error) {
c.readBufCh <- p
n, ok := <-c.bytesReadCh
if !ok {
select {
case <-c.doneCh:
c.closeCh <- c.raddr
return 0, net.ErrClosed
case c.readBufCh <- p:
n, ok := <-c.bytesReadCh
if !ok {
c.closeCh <- c.raddr
return 0, net.ErrClosed
}
return n, nil
}
return n, nil
}

func (c *natconn) Write(b []byte) (n int, err error) {
return c.PacketConn.WriteTo(b, c.raddr)
}

func (c *natconn) Close() error {
close(c.readBufCh)
close(c.doneCh)
close(c.bytesReadCh)
c.Conn.Close()
return nil
}

func (c *natconn) RemoteAddr() net.Addr {
return c.raddr
}

func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPAssocationMetrics) {
if connMetrics == nil {
connMetrics = &NoOpUDPAssocationMetrics{}
}

targetConn, err := h.targetConnFactory()
if err != nil {
slog.Error("UDP: failed to create target connection", slog.Any("err", err))
h.logger.Error("UDP: failed to create target connection", slog.Any("err", err))
return
}

cipherLazySlice := h.bufPool.LazySlice()
cipherBuf := cipherLazySlice.Acquire()
defer cipherLazySlice.Release()
cipherSlice := h.bufPool.LazySlice()
cipherBuf := cipherSlice.Acquire()
defer cipherSlice.Release()

textLazySlice := h.bufPool.LazySlice()

Expand All @@ -250,18 +309,18 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA
for {
clientProxyBytes, err := clientAssociation.Read(cipherBuf)
if errors.Is(err, net.ErrClosed) {
cipherLazySlice.Release()
cipherSlice.Release()
return
}
debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes))

connError := func() *onet.ConnectionError {
defer func() {
if r := recover(); r != nil {
slog.Error("Panic in UDP loop. Continuing to listen.", "err", r)
h.logger.Error("Panic in UDP loop. Continuing to listen.", "err", r)
debug.PrintStack()
}
slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String()))
h.logger.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String()))
}()

cipherData := cipherBuf[:clientProxyBytes]
Expand All @@ -285,9 +344,10 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA

connMetrics.AddAuthenticated(keyID)
go func() {
defer connMetrics.AddClosed()
timedCopy(clientAssociation, targetConn, cryptoKey, connMetrics, h.logger)
connMetrics.AddClosed()
targetConn.Close()
clientAssociation.Close()
}()

} else {
Expand Down Expand Up @@ -422,34 +482,25 @@ func (m *natmap) Get(key string) *natconn {
return m.keyConn[key]
}

func (m *natmap) set(key string, pc *natconn) {
m.Lock()
defer m.Unlock()

m.keyConn[key] = pc
return
}

func (m *natmap) del(key string) *natconn {
func (m *natmap) Del(addr net.Addr) net.PacketConn {
m.Lock()
defer m.Unlock()

entry, ok := m.keyConn[key]
entry, ok := m.keyConn[addr.String()]
if ok {
delete(m.keyConn, key)
delete(m.keyConn, addr.String())
return entry
}
return nil
}

// Add adds a new UDP NAT entry to the natmap and returns a closure to delete
// the entry.
func (m *natmap) Add(addr net.Addr, pc *natconn) func() {
key := addr.String()
m.set(key, pc)
return func() {
m.del(key)
}
func (m *natmap) Add(addr net.Addr, pc *natconn) {
m.Lock()
defer m.Unlock()

m.keyConn[addr.String()] = pc
}

func (m *natmap) Close() error {
Expand All @@ -466,31 +517,6 @@ func (m *natmap) Close() error {
return err
}

// packetConnWrapper wraps a [net.PacketConn] and provides a [net.Conn] interface
// with a given remote address.
type packetConnWrapper struct {
net.PacketConn
raddr net.Addr
}

var _ net.Conn = (*packetConnWrapper)(nil)

// ReadFrom reads data from the connection.
func (pcw *packetConnWrapper) Read(b []byte) (n int, err error) {
n, _, err = pcw.PacketConn.ReadFrom(b)
return
}

// WriteTo writes data to the connection.
func (pcw *packetConnWrapper) Write(b []byte) (n int, err error) {
return pcw.PacketConn.WriteTo(b, pcw.raddr)
}

// RemoteAddr returns the remote network address.
func (pcw *packetConnWrapper) RemoteAddr() net.Addr {
return pcw.raddr
}

// Get the maximum length of the shadowsocks address header by parsing
// and serializing an IPv6 address from the example range.
var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345"))
Expand Down
Loading

0 comments on commit 3c1277c

Please sign in to comment.