Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom UDP connection identifiers #252

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Adrian Cable <[email protected]>
Atsushi Watanabe <[email protected]>
backkem <[email protected]>
cnderrauber <[email protected]>
Dan Mangum <[email protected]>
Hugo Arregui <[email protected]>
Jeremiah Millay <[email protected]>
Jozef Kralik <[email protected]>
Expand Down
26 changes: 24 additions & 2 deletions udp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type listener struct {
doneCh chan struct{}
doneOnce sync.Once
acceptFilter func([]byte) bool
connIDFn func(net.Addr, []byte) string

connLock sync.Mutex
conns map[string]*Conn
Expand Down Expand Up @@ -122,6 +123,11 @@ type ListenConfig struct {
// AcceptFilter determines whether the new conn should be made for
// the incoming packet. If not set, any packet creates new conn.
AcceptFilter func([]byte) bool

// ConnIDFn defines a custom connection identifier that will be used
// instead of the remote address. This allows for custom routing of
// received packets
ConnIDFn func(net.Addr, []byte) string
}

// Listen creates a new listener based on the ListenConfig.
Expand All @@ -141,6 +147,7 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener
conns: make(map[string]*Conn),
doneCh: make(chan struct{}),
acceptFilter: lc.AcceptFilter,
connIDFn: lc.ConnIDFn,
connWG: &sync.WaitGroup{},
readDoneCh: make(chan struct{}),
}
Expand Down Expand Up @@ -195,7 +202,22 @@ func (l *listener) readLoop() {
func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) {
l.connLock.Lock()
defer l.connLock.Unlock()
conn, ok := l.conns[raddr.String()]
connID := raddr.String()
if l.connIDFn != nil {
connID = l.connIDFn(raddr, buf)
// If we have a connection associated with this connection ID,
// make sure that the remote address matches.
if conn, ok := l.conns[connID]; ok && conn.RemoteAddr() != raddr {
conn.rAddr = raddr
}
Comment on lines +208 to +212
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am removing this as we should defer updating the remote address of the connection to the conn itself to allow for necessary checks, such as those required for peer address update in the DTLS Connection ID RFC: https://datatracker.ietf.org/doc/html/rfc9146#peer-address-update

However, this is tricky to do today because listener.Accept() just gives us back a net.Conn, which doesn't allow for modifying the internal state of the udp.Conn.

// If we have a connection associated with this remote address,
// update it to use the new ID.
if conn, ok := l.conns[raddr.String()]; ok {
l.conns[connID] = conn
delete(l.conns, raddr.String())
}
}
conn, ok := l.conns[connID]
if !ok {
if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok {
return nil, false, ErrClosedListener
Expand All @@ -208,7 +230,7 @@ func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) {
conn = l.newConn(raddr)
select {
case l.acceptCh <- conn:
l.conns[raddr.String()] = conn
l.conns[connID] = conn
default:
return nil, false, ErrListenQueueExceeded
}
Expand Down
205 changes: 196 additions & 9 deletions udp/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package udp

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -79,8 +80,7 @@ func stressDuplex(t *testing.T) {
MsgCount: 1, // Can't rely on UDP message order in CI
}

err = test.StressDuplex(ca, cb, opt)
if err != nil {
if err := test.StressDuplex(ca, cb, opt); err != nil {
t.Fatal(err)
}
}
Expand All @@ -99,14 +99,12 @@ func TestListenerCloseTimeout(t *testing.T) {
t.Fatal(err)
}

err = listener.Close()
if err != nil {
if err := listener.Close(); err != nil {
t.Fatal(err)
}

// Close client after server closes to cleanup
err = ca.Close()
if err != nil {
if err := ca.Close(); err != nil {
t.Fatal(err)
}
}
Expand Down Expand Up @@ -147,7 +145,7 @@ func TestListenerCloseUnaccepted(t *testing.T) {
time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop

// Unaccepted connections must be closed by listener.Close()
if err = listener.Close(); err != nil {
if err := listener.Close(); err != nil {
t.Fatal(err)
}
}
Expand Down Expand Up @@ -311,14 +309,203 @@ func TestListenerConcurrent(t *testing.T) {
}()

time.Sleep(100 * time.Millisecond) // Last Accept should be discarded
err = listener.Close()
if err != nil {
if err := listener.Close(); err != nil {
t.Fatal(err)
}

wg.Wait()
}

func TestListenerCustomConnID(t *testing.T) {
const helloPayload = "hello"
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

type pkt struct {
ID int
Payload string
}
network, addr := getConfig()
listener, err := (&ListenConfig{
ConnIDFn: func(raddr net.Addr, buf []byte) string {
p := &pkt{}
if err := json.Unmarshal(buf, p); err != nil {
t.Fatal(err)
}
if p.Payload == helloPayload {
return raddr.String()
}
return fmt.Sprint(p.ID)
},
}).Listen(network, addr)
if err != nil {
t.Fatal(err)
}

clientWg := sync.WaitGroup{}
var readFirst [5]chan struct{}
for i := range readFirst {
readFirst[i] = make(chan struct{})
}
var readSecond [5]chan struct{}
for i := range readSecond {
readSecond[i] = make(chan struct{})
}
serverWg := sync.WaitGroup{}
clientMap := map[string]struct{}{}
var clientMapMu sync.Mutex
for i := 0; i < 5; i++ {
serverWg.Add(1)
go func() {
defer serverWg.Done()
conn, err := listener.Accept()
if err != nil {
t.Error(err)
}
buf := make([]byte, 40)
n, rErr := conn.Read(buf)
if rErr != nil {
t.Error(err)
}
p := &pkt{}
if uErr := json.Unmarshal(buf[:n], p); uErr != nil {
t.Error(err)
}
// First message should be a hello and custom connection
// ID function will use remote address as identifier.
// Connection ID is extracted to signal that we are
// ready for the second message.
if p.Payload != helloPayload {
t.Error("Expected hello message")
}
connID := p.ID
close(readFirst[connID])
n, err = conn.Read(buf)
if err != nil {
t.Error(err)
}
if err := json.Unmarshal(buf[:n], p); err != nil {
t.Error(err)
}
// Second message should be a set and custom connection
// function will update the connection ID from remote
// address to the supplied ID.
if p.Payload != "set" {
t.Error("Expected set message")
}
if p.ID != connID {
t.Errorf("Expected connection ID %d, but got %d", connID, p.ID)
}
close(readSecond[connID])
for j := 0; j < 4; j++ {
n, err := conn.Read(buf)
if err != nil {
t.Error(err)
}
p := &pkt{}
if err := json.Unmarshal(buf[:n], p); err != nil {
t.Error(err)
}
if p.ID != connID {
t.Errorf("Expected connection ID %d, but got %d", connID, p.ID)
}
// Ensure we only ever receive one message from
// a given client.
clientMapMu.Lock()
if _, ok := clientMap[p.Payload]; ok {
t.Errorf("Multiple messages from single client %s", p.Payload)
}
clientMap[p.Payload] = struct{}{}
clientMapMu.Unlock()
}
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
}

for i := 0; i < 5; i++ {
clientWg.Add(1)
go func(connID int) {
defer clientWg.Done()
conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr))
if dErr != nil {
t.Error(dErr)
}
hbuf, err := json.Marshal(&pkt{
ID: connID,
Payload: helloPayload,
})
if err != nil {
t.Error(err)
}
if _, wErr := conn.Write(hbuf); wErr != nil {
t.Error(wErr)
}
// Ensure that the first message, which does not include
// a connection ID is read before sending additional
// messages.
<-readFirst[connID]
// Send a message to update the connection ID from the
// remote address to the provided ID.
buf, err := json.Marshal(&pkt{
ID: connID,
Payload: "set",
})
if err != nil {
t.Error(err)
}
if _, wErr := conn.Write(buf); wErr != nil {
t.Error(wErr)
}
if cErr := conn.Close(); cErr != nil {
t.Error(cErr)
}
}(i)
}

// Spawn 20 clients sending on 5 connections.
for i := 1; i <= 20; i++ {
clientWg.Add(1)
go func(connID int) {
defer clientWg.Done()
// Ensure that we are using a connection ID for packet
// routing prior to sending any messages.
<-readSecond[connID]
conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr))
if dErr != nil {
t.Error(dErr)
}
buf, err := json.Marshal(&pkt{
ID: connID,
Payload: conn.LocalAddr().String(),
})
if err != nil {
t.Error(err)
}
if _, wErr := conn.Write(buf); wErr != nil {
t.Error(wErr)
}
if cErr := conn.Close(); cErr != nil {
t.Error(cErr)
}
}(i % 5)
}

// Wait for clients to exit.
clientWg.Wait()
// Wait for servers to exit.
serverWg.Wait()
if err := listener.Close(); err != nil {
t.Fatal(err)
}
}

func pipe() (net.Listener, net.Conn, *net.UDPConn, error) {
// Start listening
network, addr := getConfig()
Expand Down