From 48b28c672d43b37b4c8cad7748a9c8be9596df38 Mon Sep 17 00:00:00 2001 From: jgould Date: Sat, 14 Jan 2023 17:43:33 -0800 Subject: [PATCH] uacp: add support for read timeouts This prevents io.ReadFull to hang. --- config.go | 12 +++++++++++- uacp/conn.go | 36 ++++++++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/config.go b/config.go index 0dd4aae2..6bbf8312 100644 --- a/config.go +++ b/config.go @@ -81,7 +81,7 @@ func NewDialer(cfg *Config) *uacp.Dialer { // ApplyConfig applies the config options to the default configuration. // todo(fs): Can we find a better name? // -// Note: Starting with v0.5 this function will will return an error. +// Note: Starting with v0.5 this function will return an error. func ApplyConfig(opts ...Option) *Config { cfg := &Config{ sechan: DefaultClientConfig(), @@ -501,6 +501,16 @@ func DialTimeout(d time.Duration) Option { } } +// ReadTimeout sets the timeout for every read operation. +func ReadTimeout(d time.Duration) Option { + return func(cfg *Config) { + initDialer(cfg) + cfg.dialer.ReadTimeout = d + } +} + } +} + // MaxMessageSize sets the maximum message size for the UACP handshake. func MaxMessageSize(n uint32) Option { return func(cfg *Config) { diff --git a/uacp/conn.go b/uacp/conn.go index a9eda06a..943a37e0 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -64,6 +64,10 @@ type Dialer struct { // ClientACK defines the connection parameters requested by the client. // Defaults to DefaultClientACK. ClientACK *Acknowledge + + // ReadTimeout sets a read timeout for reading a full response from the + // underlying network connection. ReadTimeout is ignored if it is <= 0. + ReadTimeout time.Duration } func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) { @@ -88,6 +92,7 @@ func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) { c.Close() return nil, err } + conn.readTimeout = d.ReadTimeout debug.Printf("uacp %d: start HEL/ACK handshake", conn.id) if err := conn.Handshake(ctx, endpoint); err != nil { @@ -174,7 +179,8 @@ type Conn struct { id uint32 ack *Acknowledge - closeOnce sync.Once + closeOnce sync.Once + readTimeout time.Duration } func NewConn(c *net.TCPConn, ack *Acknowledge) (*Conn, error) { @@ -351,15 +357,25 @@ const hdrlen = 8 // The size of b must be at least ReceiveBufSize. Otherwise, // the function returns an error. func (c *Conn) Receive() ([]byte, error) { - // TODO(kung-foo): allow user-specified buffer - // TODO(kung-foo): sync.Pool + // todo(kung-foo): allow user-specified buffer + // todo(kung-foo): sync.Pool b := make([]byte, c.ack.ReceiveBufSize) - if _, err := io.ReadFull(c, b[:hdrlen]); err != nil { + if c.readTimeout > 0 { + if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { + return nil, errors.Errorf("uacp: failed to set read timeout: %w", err) + } + } + + n, err := c.Read(b[:hdrlen]) + if err != nil { // todo(fs): do not wrap this error since it hides io.EOF // todo(fs): use golang.org/x/xerrors return nil, err } + if n != hdrlen { + return nil, errors.Errorf("uacp: short read on header. got %d bytes, want %d ", n, hdrlen) + } var h Header if _, err := h.Decode(b[:hdrlen]); err != nil { @@ -370,18 +386,26 @@ func (c *Conn) Receive() ([]byte, error) { return nil, errors.Errorf("uacp: message too large: %d > %d bytes", h.MessageSize, c.ack.ReceiveBufSize) } - if _, err := io.ReadFull(c, b[hdrlen:h.MessageSize]); err != nil { + n, err = c.Read(b[hdrlen:h.MessageSize]) + if err != nil { // todo(fs): do not wrap this error since it hides io.EOF // todo(fs): use golang.org/x/xerrors return nil, err } + // clear the deadline + c.SetReadDeadline(time.Time{}) + + if uint32(n) != h.MessageSize-hdrlen { + return nil, fmt.Errorf("uacp %d: short read on message. got %d bytes, want %d", c.id, n, h.MessageSize-hdrlen) + } + debug.Printf("uacp %d: recv %s%c with %d bytes", c.id, h.MessageType, h.ChunkType, h.MessageSize) if h.MessageType == "ERR" { errf := new(Error) if _, err := errf.Decode(b[hdrlen:h.MessageSize]); err != nil { - return nil, errors.Errorf("uacp: failed to decode ERRF message: %s", err) + return nil, errors.Errorf("uacp: failed to decode ERRF message: %w", err) } return nil, errf }