diff --git a/config.go b/config.go index 6bbf8312..76e494ec 100644 --- a/config.go +++ b/config.go @@ -508,6 +508,12 @@ func ReadTimeout(d time.Duration) Option { cfg.dialer.ReadTimeout = d } } + +// WriteTimeout sets the timeout for every write operation. +func WriteTimeout(d time.Duration) Option { + return func(cfg *Config) { + initDialer(cfg) + cfg.dialer.WriteTimeout = d } } diff --git a/uacp/conn.go b/uacp/conn.go index 943a37e0..9e9a5cbd 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -68,6 +68,10 @@ type Dialer struct { // ReadTimeout sets a read timeout for reading a full response from the // underlying network connection. ReadTimeout is ignored if it is <= 0. ReadTimeout time.Duration + + // WriteTimeout sets a write timeout for sending a request on the + // underlying network connection. WriteTimeout is ignored if it is <= 0. + WriteTimeout time.Duration } func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) { @@ -93,6 +97,7 @@ func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) { return nil, err } conn.readTimeout = d.ReadTimeout + conn.writeTimeout = d.WriteTimeout debug.Printf("uacp %d: start HEL/ACK handshake", conn.id) if err := conn.Handshake(ctx, endpoint); err != nil { @@ -179,8 +184,9 @@ type Conn struct { id uint32 ack *Acknowledge - closeOnce sync.Once + closeOnce sync.Once readTimeout time.Duration + writeTimeout time.Duration } func NewConn(c *net.TCPConn, ack *Acknowledge) (*Conn, error) { @@ -419,7 +425,7 @@ func (c *Conn) Send(typ string, msg interface{}) error { body, err := ua.Encode(msg) if err != nil { - return errors.Errorf("encode msg failed: %s", err) + return errors.Errorf("encode msg failed: %w", err) } h := Header{ @@ -437,12 +443,21 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("encode hdr failed: %s", err) } + if c.writeTimeout > 0 { + if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil { + return errors.Errorf("failed to set write timeout: %w", err) + } + } + b := append(hdr, body...) if _, err := c.Write(b); err != nil { return errors.Errorf("write failed: %s", err) } debug.Printf("uacp %d: sent %s with %d bytes", c.id, typ, len(b)) + // clear the deadline + c.SetWriteDeadline(time.Time{}) + return nil }