Skip to content

Commit

Permalink
Modularize connection handling (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Mar 15, 2024
1 parent de41d3e commit b9cb68e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 32 deletions.
9 changes: 7 additions & 2 deletions service/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ func (c *measuredConn) Write(b []byte) (int, error) {
return n, err
}

func (c *measuredConn) ReadFrom(r io.Reader) (int64, error) {
n, err := io.Copy(c.StreamConn, r)
func (c *measuredConn) ReadFrom(r io.Reader) (n int64, err error) {
if rf, ok := c.StreamConn.(io.ReaderFrom); ok {
// Prefer ReadFrom if we are calling ReadFrom. Otherwise io.Copy will try WriteTo first.
n, err = rf.ReadFrom(r)
} else {
n, err = io.Copy(c.StreamConn, r)
}
*c.writeCount += n
return n, err
}
Expand Down
97 changes: 67 additions & 30 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,25 +239,22 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
logger.Debugf("Done with status %v, duration %v", status, connDuration)
}

func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
// Set a deadline to receive the address to the target.
clientConn.SetReadDeadline(time.Now().Add(h.readTimeout))

// 1. Find the cipher and acess key id.
func (h *tcpHandler) authenticate(clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, transport.StreamConn, *onet.ConnectionError) {
// TODO(fortuna): Offer alternative transports.
// Find the cipher and acess key id.
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), h.ciphers)
h.m.AddTCPCipherSearch(keyErr == nil, timeToCipher)
if keyErr != nil {
logger.Debugf("Failed to find a valid cipher after reading %v bytes: %v", proxyMetrics.ClientProxy, keyErr)
const status = "ERR_CIPHER"
h.absorbProbe(listenerPort, clientConn, status, proxyMetrics)
return "", onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
}
var id string
if cipherEntry != nil {
id = cipherEntry.ID
}

// 2. Check if the connection is a replay.
// Check if the connection is a replay.
isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
// Only check the cache if findAccessKey succeeded and the salt is unrecognized.
if isServerSalt || !h.replayCache.Add(cipherEntry.ID, clientSalt) {
Expand All @@ -267,38 +264,39 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
} else {
status = "ERR_REPLAY_CLIENT"
}
h.absorbProbe(listenerPort, clientConn, status, proxyMetrics)
logger.Debugf(status+": %v sent %d bytes", clientConn.RemoteAddr(), proxyMetrics.ClientProxy)
return id, onet.NewConnectionError(status, "Replay detected", nil)
return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
}

// 3. Read target address and dial it.
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
tgtAddr, err := socks.ReadAddr(ssr)
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
return id, transport.WrapConn(clientConn, ssr, ssw), nil
}

// Clear the deadline for the target address
clientConn.SetReadDeadline(time.Time{})
func getProxyRequest(clientConn transport.StreamConn) (string, error) {
// TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte:
// case 1, 3 or 4: Shadowsocks (address type)
// case 5: SOCKS5 (protocol version)
// case "C": HTTP CONNECT (first char of method)
tgtAddr, err := socks.ReadAddr(clientConn)
if err != nil {
// Drain to prevent a close on cipher error.
io.Copy(io.Discard, clientConn)
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
return "", err
}
tgtConn, dialErr := h.dialer.DialStream(ctx, tgtAddr.String())
return tgtAddr.String(), nil
}

func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr)
if dialErr != nil {
// We don't drain so dial errors and invalid addresses are communicated quickly.
return id, ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
}
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
defer tgtConn.Close()

// 4. Bridge the client and target connections
logger.Debugf("proxy %s <-> %s", clientConn.RemoteAddr().String(), tgtConn.RemoteAddr().String())
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)

fromClientErrCh := make(chan error)
go func() {
_, fromClientErr := ssr.WriteTo(tgtConn)
_, fromClientErr := io.Copy(tgtConn, clientConn)
if fromClientErr != nil {
// Drain to prevent a close in the case of a cipher error.
io.Copy(io.Discard, clientConn)
Expand All @@ -310,19 +308,58 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
tgtConn.CloseWrite()
fromClientErrCh <- fromClientErr
}()
_, fromTargetErr := ssw.ReadFrom(tgtConn)
_, fromTargetErr := io.Copy(clientConn, tgtConn)
// Send FIN to client.
clientConn.CloseWrite()
tgtConn.CloseRead()

fromClientErr := <-fromClientErrCh
if fromClientErr != nil {
return id, onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
return onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
}
if fromTargetErr != nil {
return id, onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
return onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
}
return id, nil
return nil
}

func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
// Set a deadline to receive the address to the target.
readDeadline := time.Now().Add(h.readTimeout)
if deadline, ok := ctx.Deadline(); ok {
outerConn.SetDeadline(deadline)
if deadline.Before(readDeadline) {
readDeadline = deadline
}
}
outerConn.SetReadDeadline(readDeadline)

id, innerConn, authErr := h.authenticate(outerConn, proxyMetrics)
if authErr != nil {
// Drain to protect against probing attacks.
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
return id, authErr
}

// Read target address and dial it.
tgtAddr, err := getProxyRequest(innerConn)
// Clear the deadline for the target address
outerConn.SetReadDeadline(time.Time{})
if err != nil {
// Drain to prevent a close on cipher error.
io.Copy(io.Discard, outerConn)
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
}

dialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
tgtConn, err := h.dialer.DialStream(ctx, tgtAddr)
if err != nil {
return nil, err
}
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
return tgtConn, nil
})
return id, proxyConnection(ctx, dialer, tgtAddr, innerConn)
}

// Keep the connection open until we hit the authentication deadline to protect against probing attacks
Expand Down

0 comments on commit b9cb68e

Please sign in to comment.