diff --git a/pkg/server/server.go b/pkg/server/server.go index 7c19df4..73b7b07 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -94,6 +94,7 @@ func (s *Server) listAllModules(downConn net.Conn) error { for name := range s.modules { modules = append(modules, name) } + timeout := s.WriteTimeout s.reloadLock.RUnlock() sort.Strings(modules) @@ -102,7 +103,7 @@ func (s *Server) listAllModules(downConn net.Conn) error { buf.WriteRune('\n') } buf.Write(RsyncdExit) - _, _ = s.writeWithTimeout(downConn, buf.Bytes()) + _, _ = writeWithTimeout(downConn, buf.Bytes(), timeout) return nil } @@ -113,7 +114,10 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error { // nolint:staticcheck defer s.bufPool.Put(buf) - n, err := s.readWithTimeout(downConn, buf) + writeTimeout := s.WriteTimeout + readTimeout := s.ReadTimeout + + n, err := readLine(downConn, buf, readTimeout) if err != nil { return fmt.Errorf("read version from client: %w", err) } @@ -122,12 +126,12 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error { return fmt.Errorf("unknown version from client: %s", data) } - _, err = s.writeWithTimeout(downConn, RsyncdVersion) + _, err = writeWithTimeout(downConn, RsyncdVersion, writeTimeout) if err != nil { return fmt.Errorf("send version to client: %w", err) } - n, err = s.readWithTimeout(downConn, buf) + n, err = readLine(downConn, buf, readTimeout) if err != nil { return fmt.Errorf("read module from client: %w", err) } @@ -146,8 +150,8 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error { s.reloadLock.RUnlock() if !ok { - _, _ = s.writeWithTimeout(downConn, []byte(fmt.Sprintf("unknown module: %s\n", moduleName))) - _, _ = s.writeWithTimeout(downConn, RsyncdExit) + _, _ = writeWithTimeout(downConn, []byte(fmt.Sprintf("unknown module: %s\n", moduleName)), writeTimeout) + _, _ = writeWithTimeout(downConn, RsyncdExit, writeTimeout) return nil } @@ -158,12 +162,12 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error { upConn := conn.(*net.TCPConn) defer upConn.Close() - _, err = s.writeWithTimeout(upConn, RsyncdVersion) + _, err = writeWithTimeout(upConn, RsyncdVersion, writeTimeout) if err != nil { return fmt.Errorf("send version to upstream: %w", err) } - n, err = s.readWithTimeout(upConn, buf) + n, err = readLine(upConn, buf, readTimeout) if err != nil { return fmt.Errorf("read version from upstream: %w", err) } @@ -172,7 +176,7 @@ func (s *Server) relay(ctx context.Context, downConn *net.TCPConn) error { return fmt.Errorf("unknown version from upstream: %s", data) } - _, err = s.writeWithTimeout(upConn, []byte(moduleName+"\n")) + _, err = writeWithTimeout(upConn, []byte(moduleName+"\n"), writeTimeout) if err != nil { return fmt.Errorf("send module to upstream: %w", err) } diff --git a/pkg/server/utils.go b/pkg/server/utils.go index a4a0a98..57f35a2 100644 --- a/pkg/server/utils.go +++ b/pkg/server/utils.go @@ -5,18 +5,34 @@ import ( "time" ) -func (s *Server) readWithTimeout(conn net.Conn, buf []byte) (n int, err error) { - if s.ReadTimeout > 0 { - _ = conn.SetReadDeadline(time.Now().Add(s.ReadTimeout)) +func writeWithTimeout(conn net.Conn, buf []byte, timeout time.Duration) (n int, err error) { + if timeout > 0 { + _ = conn.SetWriteDeadline(time.Now().Add(timeout)) } - n, err = conn.Read(buf) + n, err = conn.Write(buf) return } -func (s *Server) writeWithTimeout(conn net.Conn, buf []byte) (n int, err error) { - if s.WriteTimeout > 0 { - _ = conn.SetWriteDeadline(time.Now().Add(s.ReadTimeout)) +func readLine(conn net.Conn, buf []byte, timeout time.Duration) (n int, err error) { + // 这个只是特殊场景下的 readLine + // rsync 在握手过程中除了 protocol version 跟 module name 以外并不会传输其他数据,而这些数据又是以 '\n' 分割 + // 所以可以直接尽力读满传进来的 buffer 直到读到 '\n' 为止 + max := len(buf) + for { + if timeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + } + var nr int + nr, err = conn.Read(buf[n:]) + n += nr + if n > 0 && buf[n-1] == '\n' { + return n, nil + } + if n == max { + return n, nil + } + if err != nil { + return + } } - n, err = conn.Write(buf) - return } diff --git a/pkg/server/utils_test.go b/pkg/server/utils_test.go new file mode 100644 index 0000000..0b7edc7 --- /dev/null +++ b/pkg/server/utils_test.go @@ -0,0 +1,67 @@ +package server + +import ( + "net" + "reflect" + "testing" + "time" +) + +type fakeConn struct { + fragments [][]byte +} + +func (c *fakeConn) Read(b []byte) (n int, err error) { + for _, frag := range c.fragments { + nw := copy(b[n:], frag) + n += nw + } + return +} + +func (c *fakeConn) Write(b []byte) (n int, err error) { + panic("implement me") +} + +func (c *fakeConn) Close() error { + panic("implement me") +} + +func (c *fakeConn) LocalAddr() net.Addr { + panic("implement me") +} + +func (c *fakeConn) RemoteAddr() net.Addr { + panic("implement me") +} + +func (c *fakeConn) SetDeadline(t time.Time) error { + panic("implement me") +} + +func (c *fakeConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *fakeConn) SetWriteDeadline(t time.Time) error { + panic("implement me") +} + +func TestReadLine(t *testing.T) { + c := &fakeConn{fragments: [][]byte{ + RsyncdVersionPrefix, + []byte(" 31.0"), + {'\n'}, + }} + + buf := make([]byte, TCPBufferSize) + n, err := readLine(c, buf, time.Minute) + if err != nil { + t.Error(err) + } + got := buf[:n] + expected := []byte("@RSYNCD: 31.0\n") + if !reflect.DeepEqual(got, expected) { + t.Errorf("Unexpected data\nExpected: %s\nGot: %s\n", string(expected), string(got)) + } +}