Skip to content

Commit

Permalink
fix: read until '\n'
Browse files Browse the repository at this point in the history
Signed-off-by: knight42 <[email protected]>
  • Loading branch information
knight42 committed Jul 5, 2020
1 parent 9f058d1 commit 166beab
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 18 deletions.
22 changes: 13 additions & 9 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func (s *Server) listAllModules(downConn net.Conn) error {
for name := range s.modules {
modules = append(modules, name)
}
timeout := s.WriteTimeout
s.reloadLock.RUnlock()

sort.Strings(modules)
Expand All @@ -102,7 +103,7 @@ func (s *Server) listAllModules(downConn net.Conn) error {
buf.WriteRune('\n')
}
buf.Write(RsyncdExit)
_, _ = s.writeWithTimeout(downConn, buf.Bytes())
_, _ = writeWithTimeout(downConn, buf.Bytes(), timeout)
return nil
}

Expand All @@ -113,7 +114,10 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error {
// nolint:staticcheck
defer s.bufPool.Put(buf)

n, err := s.readWithTimeout(downConn, buf)
writeTimeout := s.WriteTimeout
readTimeout := s.ReadTimeout

n, err := readLine(downConn, buf, readTimeout)
if err != nil {
return fmt.Errorf("read version from client: %w", err)
}
Expand All @@ -122,12 +126,12 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error {
return fmt.Errorf("unknown version from client: %s", data)
}

_, err = s.writeWithTimeout(downConn, RsyncdVersion)
_, err = writeWithTimeout(downConn, RsyncdVersion, writeTimeout)
if err != nil {
return fmt.Errorf("send version to client: %w", err)
}

n, err = s.readWithTimeout(downConn, buf)
n, err = readLine(downConn, buf, readTimeout)
if err != nil {
return fmt.Errorf("read module from client: %w", err)
}
Expand All @@ -146,8 +150,8 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error {
s.reloadLock.RUnlock()

if !ok {
_, _ = s.writeWithTimeout(downConn, []byte(fmt.Sprintf("unknown module: %s\n", moduleName)))
_, _ = s.writeWithTimeout(downConn, RsyncdExit)
_, _ = writeWithTimeout(downConn, []byte(fmt.Sprintf("unknown module: %s\n", moduleName)), writeTimeout)
_, _ = writeWithTimeout(downConn, RsyncdExit, writeTimeout)
return nil
}

Expand All @@ -158,12 +162,12 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error {
upConn := conn.(*net.TCPConn)
defer upConn.Close()

_, err = s.writeWithTimeout(upConn, RsyncdVersion)
_, err = writeWithTimeout(upConn, RsyncdVersion, writeTimeout)
if err != nil {
return fmt.Errorf("send version to upstream: %w", err)
}

n, err = s.readWithTimeout(upConn, buf)
n, err = readLine(upConn, buf, readTimeout)
if err != nil {
return fmt.Errorf("read version from upstream: %w", err)
}
Expand All @@ -172,7 +176,7 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error {
return fmt.Errorf("unknown version from upstream: %s", data)
}

_, err = s.writeWithTimeout(upConn, []byte(moduleName+"\n"))
_, err = writeWithTimeout(upConn, []byte(moduleName+"\n"), writeTimeout)
if err != nil {
return fmt.Errorf("send module to upstream: %w", err)
}
Expand Down
34 changes: 25 additions & 9 deletions pkg/server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,34 @@ import (
"time"
)

func (s *Server) readWithTimeout(conn net.Conn, buf []byte) (n int, err error) {
if s.ReadTimeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(s.ReadTimeout))
func writeWithTimeout(conn net.Conn, buf []byte, timeout time.Duration) (n int, err error) {
if timeout > 0 {
_ = conn.SetWriteDeadline(time.Now().Add(timeout))
}
n, err = conn.Read(buf)
n, err = conn.Write(buf)
return
}

func (s *Server) writeWithTimeout(conn net.Conn, buf []byte) (n int, err error) {
if s.WriteTimeout > 0 {
_ = conn.SetWriteDeadline(time.Now().Add(s.ReadTimeout))
func readLine(conn net.Conn, buf []byte, timeout time.Duration) (n int, err error) {
// 这个只是特殊场景下的 readLine
// rsync 在握手过程中除了 protocol version 跟 module name 以外并不会传输其他数据,而这些数据又是以 '\n' 分割
// 所以可以直接尽力读满传进来的 buffer 直到读到 '\n' 为止
max := len(buf)
for {
if timeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(timeout))
}
var nr int
nr, err = conn.Read(buf[n:])
n += nr
if n > 0 && buf[n-1] == '\n' {
return n, nil
}
if n == max {
return n, nil
}
if err != nil {
return
}
}
n, err = conn.Write(buf)
return
}
67 changes: 67 additions & 0 deletions pkg/server/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package server

import (
"net"
"reflect"
"testing"
"time"
)

type fakeConn struct {
fragments [][]byte
}

func (c *fakeConn) Read(b []byte) (n int, err error) {
for _, frag := range c.fragments {
nw := copy(b[n:], frag)
n += nw
}
return
}

func (c *fakeConn) Write(b []byte) (n int, err error) {
panic("implement me")
}

func (c *fakeConn) Close() error {
panic("implement me")
}

func (c *fakeConn) LocalAddr() net.Addr {
panic("implement me")
}

func (c *fakeConn) RemoteAddr() net.Addr {
panic("implement me")
}

func (c *fakeConn) SetDeadline(t time.Time) error {
panic("implement me")
}

func (c *fakeConn) SetReadDeadline(t time.Time) error {
return nil
}

func (c *fakeConn) SetWriteDeadline(t time.Time) error {
panic("implement me")
}

func TestReadLine(t *testing.T) {
c := &fakeConn{fragments: [][]byte{
RsyncdVersionPrefix,
[]byte(" 31.0"),
{'\n'},
}}

buf := make([]byte, TCPBufferSize)
n, err := readLine(c, buf, time.Minute)
if err != nil {
t.Error(err)
}
got := buf[:n]
expected := []byte("@RSYNCD: 31.0\n")
if !reflect.DeepEqual(got, expected) {
t.Errorf("Unexpected data\nExpected: %s\nGot: %s\n", string(expected), string(got))
}
}

0 comments on commit 166beab

Please sign in to comment.