Skip to content

Commit

Permalink
add semaphore instead workers
Browse files Browse the repository at this point in the history
  • Loading branch information
z0rr0 committed Sep 19, 2024
1 parent 5c5fad9 commit c264e44
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 85 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ARG GOLANG_VERSION="1.22.1"
ARG GOLANG_VERSION="1.23.1"

FROM golang:${GOLANG_VERSION}-alpine as builder
ARG LDFLAGS
Expand All @@ -7,7 +7,7 @@ COPY . .
RUN echo "LDFLAGS = $LDFLAGS"
RUN GOOS=linux go build -ldflags "$LDFLAGS" -o ./gsocks5

FROM alpine:3.19
FROM alpine:3.20
LABEL org.opencontainers.image.authors="[email protected]" \
org.opencontainers.image.url="https://hub.docker.com/r/z0rr0/gsocks5" \
org.opencontainers.image.documentation="https://github.com/z0rr0/gsocks5" \
Expand Down
12 changes: 6 additions & 6 deletions args/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ const (
mintPort = 1
maxPort = 65535

minConcurrent = 1
maxConcurrent = 10_000
minConcurrent uint64 = 1
maxConcurrent uint64 = 1_000_000
)

// IsFile checks that the value is a file.
Expand Down Expand Up @@ -45,8 +45,8 @@ func IsPort(value string, result *uint16) error {
}

// IsConcurrent checks that the value is a valid number of concurrent connections.
func IsConcurrent(value string, result *int) error {
integer, err := strconv.Atoi(value)
func IsConcurrent(value string, result *uint32) error {
integer, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return err
}
Expand All @@ -55,7 +55,7 @@ func IsConcurrent(value string, result *int) error {
return fmt.Errorf("value is out of range")
}

*result = integer
*result = uint32(integer)
return nil
}

Expand All @@ -65,7 +65,7 @@ func PortDescription(value uint16) string {
}

// ConcurrentDescription returns a description of the concurrent argument.
func ConcurrentDescription(value int) string {
func ConcurrentDescription(value uint32) string {
return fmt.Sprintf(
"number of concurrent connections in range [%d, %d] (default %d)",
minConcurrent, maxConcurrent, value,
Expand Down
12 changes: 6 additions & 6 deletions args/args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,27 @@ func TestIsConcurrent(t *testing.T) {
testCases := []struct {
name string
value string
want int
want uint64
wantErr bool
}{
{name: "ValidConcurrent", value: "100", want: 100},
{name: "TooLowConcurrent", value: "0", wantErr: true},
{name: "TooHighConcurrent", value: "100001", wantErr: true},
{name: "TooHighConcurrent", value: "4294967296", wantErr: true},
{name: "NonNumericConcurrent", value: "abc", wantErr: true},
}

for i := range testCases {
tc := testCases[i]

t.Run(tc.name, func(t *testing.T) {
var result int
var result uint32
err := IsConcurrent(tc.value, &result)

if (err != nil) != tc.wantErr {
t.Errorf("IsConcurrent() error = %v, wantErr %v", err, tc.wantErr)
return
}
if result != tc.want {
if uint64(result) != tc.want {
t.Errorf("IsConcurrent() = %v, want %v", result, tc.want)
}
})
Expand All @@ -126,8 +126,8 @@ func TestPortDescription(t *testing.T) {
}

func TestConcurrentDescription(t *testing.T) {
result := ConcurrentDescription(100)
expected := "number of concurrent connections in range [1, 10000] (default 100)"
result := ConcurrentDescription(10000)
expected := "number of concurrent connections in range [1, 1000000] (default 10000)"

if result != expected {
t.Errorf("ConcurrentDescription() = %v, want %v", result, expected)
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module github.com/z0rr0/gsocks5

go 1.22
go 1.23

require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
golang.org/x/net v0.22.0
golang.org/x/net v0.29.0
)
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
2 changes: 1 addition & 1 deletion gsocks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func main() {
host string
version bool
debugMode bool
connections = 100
connections uint32 = 1024
port uint16 = 1080
timeout = 3 * time.Minute
timeoutDNS = 5 * time.Second
Expand Down
108 changes: 59 additions & 49 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ type Server struct {
// Params is a start parameters for the server.
type Params struct {
Addr string
Connections int
Done chan struct{}
Connections uint32
Done chan struct{} // only for testing
Sigint chan os.Signal
Timeout time.Duration
setReady sync.Once
Expand Down Expand Up @@ -59,49 +59,34 @@ func (s *Server) ListenAndServe(p *Params) error {
}()

done := make(chan struct{})
connections, err := s.listen(ctx, p, done)
connections, semaphore, err := s.listen(ctx, p, done)
if err != nil {
return err
}

s.logDebug.Printf("listener started on %s", p.Addr)
p.Ready()
s.startWorkers(p, connections)
go s.start(p, connections, semaphore)

return s.waitClose(p, done)
}

// accept accepts a new connection.
func (s *Server) accept(listener net.Listener, p *Params) (net.Conn, error) {
conn, err := listener.Accept()
if err != nil {
return nil, fmt.Errorf("failed to accept connection: %w", err)
}

if p.Timeout > 0 {
if err = conn.SetReadDeadline(time.Now().Add(p.Timeout)); err != nil {
return nil, fmt.Errorf("failed to set deadline for connection: %w", err)
}
}

s.logDebug.Printf("accepted connection from %s with timeout %v", conn.RemoteAddr().String(), p.Timeout)
return conn, nil
}

// listen starts goroutine to accept incoming connections and sends them to a returned channel.
func (s *Server) listen(ctx context.Context, p *Params, done chan<- struct{}) (<-chan net.Conn, error) {
func (s *Server) listen(ctx context.Context, p *Params, done chan<- struct{}) (<-chan net.Conn, <-chan struct{}, error) {
var lc net.ListenConfig

listener, err := lc.Listen(ctx, "tcp", p.Addr)
if err != nil {
return nil, fmt.Errorf("failed to listen on %s: %w", p.Addr, err)
return nil, nil, fmt.Errorf("failed to listen on %s: %w", p.Addr, err)
}

p.listener = listener // to close it later
connections := make(chan net.Conn)
semaphore := make(chan struct{}, p.Connections)

go func() {
for {
semaphore <- struct{}{} // limit connections
if conn, e := s.accept(listener, p); e != nil {
if errors.Is(e, net.ErrClosed) {
break
Expand All @@ -114,38 +99,63 @@ func (s *Server) listen(ctx context.Context, p *Params, done chan<- struct{}) (<

s.logDebug.Printf("listener stopped")
close(connections) // finish workers
close(semaphore) // no new incoming connections
close(done)
}()

return connections, nil
return connections, semaphore, nil
}

// startWorkers starts workers to handle incoming connections.
func (s *Server) startWorkers(p *Params, connections <-chan net.Conn) {
for i := 0; i < p.Connections; i++ {
go func() {
var (
client string
err error
t time.Time
)

for conn := range connections {
p.wg.Add(1)

t = time.Now()
client = conn.RemoteAddr().String()
s.logDebug.Printf("accepted connection from %s", client)

if err = s.S.ServeConn(conn); err != nil {
s.logInfo.Printf("failed to serve connection from client %q: %v", client, err)
} else {
s.logDebug.Printf("connection served from %s during %v", client, time.Since(t))
}
// accept accepts a new connection.
func (s *Server) accept(listener net.Listener, p *Params) (net.Conn, error) {
conn, err := listener.Accept()
if err != nil {
return nil, fmt.Errorf("failed to accept connection: %w", err)
}

if p.Timeout > 0 {
if err = conn.SetReadDeadline(time.Now().Add(p.Timeout)); err != nil {
return nil, fmt.Errorf("failed to set deadline for connection: %w", err)
}
}

s.logDebug.Printf("accepted connection from %s with timeout %v", conn.RemoteAddr().String(), p.Timeout)
return conn, nil
}

p.wg.Done()
// start starts workers to handle incoming connections.
func (s *Server) start(p *Params, connections <-chan net.Conn, semaphore <-chan struct{}) {
for conn := range connections {
go s.handle(p, conn, semaphore)
}
s.logInfo.Printf("finished connections handling cycle")
}

func (s *Server) handle(p *Params, conn net.Conn, semaphore <-chan struct{}) {
var (
t = time.Now()
client = conn.RemoteAddr().String()
err error
)
p.wg.Add(1)

s.logDebug.Printf("accepted connection from %s", client)
defer func() {
<-semaphore // release the limitation
if closeErr := conn.Close(); closeErr != nil {
if !errors.Is(closeErr, net.ErrClosed) {
s.logInfo.Printf("failed to close connection from client %q: %v", client, closeErr)
} else {
s.logDebug.Printf("connection from %s is closed", client)
}
}()
}
p.wg.Done()
}()

if err = s.S.ServeConn(conn); err != nil {
s.logInfo.Printf("failed to serve connection from client %q: %v", client, err)
} else {
s.logDebug.Printf("connection served from %s during %v", client, time.Since(t))
}
}

Expand All @@ -161,7 +171,7 @@ func (s *Server) waitClose(p *Params, done <-chan struct{}) error {

<-done // wait for listener accept was stopped
p.wg.Wait() // wait for all connections to be handled

s.logInfo.Printf("all connections are handled")

return nil
}
71 changes: 56 additions & 15 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ import (
"golang.org/x/net/proxy"
)

var logger = log.New(os.Stdout, "[test] ", log.LstdFlags|log.Lshortfile)
const timeout = 2 * time.Second

var logger = log.New(os.Stdout, "[test] ", log.LstdFlags|log.Lshortfile|log.Lmicroseconds)

func run(t *testing.T, s *Server, i, port int, isErr bool) (string, chan os.Signal) {
params := &Params{
Addr: net.JoinHostPort("localhost", strconv.Itoa(port)),
Connections: 1,
Done: make(chan struct{}),
Sigint: make(chan os.Signal),
Timeout: time.Second,
Timeout: timeout,
}

go func() {
Expand All @@ -35,19 +37,55 @@ func run(t *testing.T, s *Server, i, port int, isErr bool) (string, chan os.Sign
return params.Addr, params.Sigint
}

type testHost struct {
host string
port int
close bool
}

func TestNew(t *testing.T) {
cases := []struct {
name string
port int
host string
err bool
name string
port int
hosts []testHost
err bool
}{
{name: "empty", port: 1080, host: "github.com:443"},
{name: "one", port: 1080, hosts: []testHost{{host: "github.com", port: 443}}},
{
name: "two",
port: 1080,
hosts: []testHost{
{host: "github.com", port: 443},
{host: "leetcode.com", port: 443, close: true},
},
},
{
name: "three",
port: 1080,
hosts: []testHost{
{host: "github.com", port: 443, close: true},
{host: "leetcode.com", port: 443, close: true},
{host: "leetcode.com", port: 80},
},
},
{
name: "many",
port: 1080,
hosts: []testHost{
{host: "github.com", port: 443, close: true},
{host: "github.com", port: 80},
{host: "leetcode.com", port: 443, close: true},
{host: "leetcode.com", port: 80},
},
},
{name: "badPort", port: 131072, err: true},
}
for i, c := range cases {
t.Run(c.name, func(tt *testing.T) {
cfg := &socks5.Config{Logger: logger}
var (
conn net.Conn
cfg = &socks5.Config{Logger: logger}
)
s, err := New(cfg, logger, logger)
if err != nil {
tt.Errorf("case [%d] %s: unexpected error: %v", i, c.name, err)
Expand All @@ -65,15 +103,18 @@ func TestNew(t *testing.T) {
tt.Errorf("case [%d] %s: unexpected error: %v", i, c.name, err)
}

conn, err := dialer.Dial("tcp", c.host)
if err != nil {
tt.Fatalf("set connection, case [%d] %s: %v", i, c.name, err)
}
for _, h := range c.hosts {
conn, err = dialer.Dial("tcp", net.JoinHostPort(h.host, strconv.Itoa(h.port)))
if err != nil {
tt.Errorf("case [%d] %s: %v", i, c.name, err)
}

if err = conn.Close(); err != nil {
tt.Errorf("close connection, case [%d] %s: %v", i, c.name, err)
if h.close {
if err = conn.Close(); err != nil {
tt.Errorf("case [%d] %s: %v", i, c.name, err)
}
}
}

sigint <- os.Interrupt
})
}
Expand Down

0 comments on commit c264e44

Please sign in to comment.