diff --git a/TODO b/TODO index 36e6c83..2b94512 100644 --- a/TODO +++ b/TODO @@ -1,7 +1,7 @@ - [x] Hook API - [x] Unit Tests - [x] Code Report -- [ ] Better Coverage -- [ ] Benchmark +- [x] Better Coverage +- [x] Benchmark - [x] Register arguments: maxConn, maxServeNum, maxServeTime - [x] Active user status \ No newline at end of file diff --git a/client.go b/client.go index eb69fff..ea8b728 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "sort" "sync" @@ -126,7 +127,12 @@ func register(conn net.Conn, info RegisterInfo, ctrl ControlInfo) error { // SetLogger Set customized logrus logger for the inner client. func (c *ProxyLiteClient) SetLogger(logger *log.Logger) { - c.logger = logger + if logger == nil { + c.logger = log.New() + c.logger.SetOutput(io.Discard) + } else { + c.logger = logger + } } // RegisterInnerService Register inner server to proxy server's outer port. diff --git a/client_test.go b/client_test.go deleted file mode 100644 index ae7819f..0000000 --- a/client_test.go +++ /dev/null @@ -1 +0,0 @@ -package proxylite diff --git a/context.go b/context.go index f6c7252..4f56fd3 100644 --- a/context.go +++ b/context.go @@ -28,7 +28,7 @@ func makeContext(tn *tunnel, user *net.Conn, data []byte, kvMap *sync.Map) *Cont } func (ctx *Context) AbortTunnel() error { - if ctx.tn.innerConn == nil { + if ctx.tn == nil || ctx.tn.innerConn == nil { return errors.New("cannot abort service because inner connection not exists") } return (*ctx.tn.innerConn).Close() diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..b1b2b40 --- /dev/null +++ b/context_test.go @@ -0,0 +1,37 @@ +package proxylite + +import "testing" + +func TestContextNil(t *testing.T) { + data := []byte{'a'} + ctx := makeContext(nil, nil, data, nil) + if ctx.AbortTunnel() == nil { + t.Error("AbortTunnel") + } + if ctx.AbortUser() == nil { + t.Error("AbortUser") + } + if ctx.DataBuffer() == nil { + t.Error("DataBuffer") + } + + blankInfo := ServiceInfo{} + if ctx.ServiceInfo() != blankInfo { + t.Error("ServiceInfo") + } + if ctx.UserLocalAddress() != nil { + t.Error("UserLocalAddress") + } + if ctx.UserRemoteAddress() != nil { + t.Error("UserRemoteAddress") + } + if ctx.InnerLocalConn() != nil { + t.Error("InnerLocalConn") + } + if ctx.InnerRemoteConn() != nil { + t.Error("InnerRemoteConn") + } + if _, ok := ctx.GetValue("k"); ok { + t.Error("GetValue") + } +} diff --git a/proxylite_bench_test.go b/proxylite_bench_test.go new file mode 100644 index 0000000..6890412 --- /dev/null +++ b/proxylite_bench_test.go @@ -0,0 +1,95 @@ +package proxylite + +import ( + "net" + "strings" + "testing" + "time" +) + +var msg = strings.Repeat("hello", 100) + +func init() { + if err := newTcpEchoServer(":9969", len(msg)); err != nil { + panic(err) + } +} + +func BenchmarkProxylite(b *testing.B) { + proxyServer := NewProxyLiteServer() + proxyServer.SetLogger(nil) + proxyServer.AddPort(9968, 9968) + go func() { + if err := proxyServer.Run(":9967"); err != nil { + panic(err) + } + }() + time.Sleep(time.Millisecond * 10) + + innerClient := NewProxyLiteClient(":9967") + innerClient.SetLogger(nil) + + cancelFunc, done, err := innerClient.RegisterInnerService( + RegisterInfo{ + OuterPort: 9968, + InnerAddr: ":9969", + Name: "Echo", + Message: "TCP Echo Server", + }, + ControlInfo{}, + ) + if err != nil { + panic(err) + } + defer func() { + cancelFunc() + <-done + proxyServer.Stop() + }() + time.Sleep(time.Millisecond * 10) + + user, err := net.Dial("tcp", ":9968") + if err != nil { + b.Fatal(err) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + err := written(user, []byte(msg), len(msg)) + if err != nil { + b.Error("write 1, ", err) + } + _, err = readn(user, len(msg)) + if err != nil { + b.Error("read 1, ", err) + } + } + b.StopTimer() + user.Close() + time.Sleep(time.Millisecond * 100) + select { + case <-done: + b.Error("unexpected quit") + default: + } +} + +func BenchmarkNoProxy(b *testing.B) { + user, err := net.Dial("tcp", ":9969") + if err != nil { + b.Fail() + } + b.StartTimer() + for i := 0; i < b.N; i++ { + err := written(user, []byte(msg), len(msg)) + if err != nil { + b.Error("write 1, ", err) + } + _, err = readn(user, len(msg)) + if err != nil { + b.Error("read 1, ", err) + } + } + b.StopTimer() + user.Close() +} diff --git a/proxylite_test.go b/proxylite_test.go index e97ea63..c445c4a 100644 --- a/proxylite_test.go +++ b/proxylite_test.go @@ -75,13 +75,15 @@ func init() { } func TestBasicUsage(t *testing.T) { - logger := logrus.New() logger.Level = logrus.FatalLevel - proxyServer := NewProxyLiteServer() + proxyServer := NewProxyLiteServer([2]int{9900, 9910}) proxyServer.SetLogger(logger) proxyServer.AddPort(9968, 9968) + if proxyServer.AddPort(99968, 99968) == true { + t.Error("AddPort") + } go func() { if err := proxyServer.Run(":9967"); err != nil { panic(err) @@ -91,6 +93,16 @@ func TestBasicUsage(t *testing.T) { innerClient := NewProxyLiteClient(":9967") innerClient.SetLogger(logger) + + ports, ok := innerClient.AvailablePorts() + if !ok || len(ports) != 12 { + t.Errorf("AvailablePorts") + } + port, ok := innerClient.AnyPort() + if !ok || port < 9900 || (port > 9910 && port != 9968) { + t.Errorf("AnyPort") + } + cancelFunc, done, err := innerClient.RegisterInnerService( RegisterInfo{ OuterPort: 9968, @@ -136,6 +148,11 @@ func TestBasicUsage(t *testing.T) { t.Error("unexpected quit") default: } + + services, err := innerClient.ActiveServices() + if err != nil || len(services) != 1 { + t.Error("ActiveServices") + } } func TestCancel(t *testing.T) { @@ -198,7 +215,7 @@ func TestCancel(t *testing.T) { select { case <-done: default: - t.Error("unexpected quit") + t.Error("unexpected continue") } <-done } @@ -272,7 +289,7 @@ func TestMultiplex(t *testing.T) { time.Sleep(time.Millisecond * 100) select { case <-done: - t.Error("unexpected quit") + t.Error("unexpected continue") default: } @@ -359,7 +376,7 @@ func TestMultiplexMaxTimeControl(t *testing.T) { case <-done: // must done default: - t.Error("unexpected quit") + t.Error("unexpected continue") } } @@ -743,3 +760,159 @@ func TestHookContext(t *testing.T) { default: } } + +func TestHookContextAbortUser(t *testing.T) { + + logger := logrus.New() + logger.Level = logrus.FatalLevel + + proxyServer := NewProxyLiteServer() + proxyServer.SetLogger(logger) + proxyServer.AddPort(9968, 9968) + + cnt := 0 + + proxyServer.OnForwardTunnelToUser(func(ctx *Context) { + cnt++ + if cnt == 2 { + ctx.AbortUser() + } + }) + + go func() { + if err := proxyServer.Run(":9967"); err != nil { + panic(err) + } + }() + time.Sleep(time.Millisecond * 10) + + innerClient := NewProxyLiteClient(":9967") + innerClient.SetLogger(logger) + cancelFunc, done, err := innerClient.RegisterInnerService( + RegisterInfo{ + OuterPort: 9968, + InnerAddr: ":9966", + Name: "Echo", + Message: "TCP Echo Server", + }, + ControlInfo{}, + ) + if err != nil { + panic(err) + } + defer func() { + cancelFunc() + <-done + time.Sleep(time.Millisecond * 10) // wait close + proxyServer.Stop() + if cnt != 2 { + t.Error("hook does not work well") + } + }() + time.Sleep(time.Millisecond * 10) + + user, err := net.Dial("tcp", ":9968") + if err != nil { + t.Fatal(err) + } + msg := "hello123" + var data []byte + for i := 0; i < 10; i++ { + err := written(user, []byte(msg), 8) + if err != nil { + break + } + data, err = readn(user, 8) + if err != nil { + break + } + if string(data) != msg { + break + } + } + user.Close() + time.Sleep(time.Millisecond * 100) + select { + case <-done: + t.Error("unexpected quit") + default: + } +} + +func TestHookContextAbortTunnel(t *testing.T) { + + logger := logrus.New() + logger.Level = logrus.FatalLevel + + proxyServer := NewProxyLiteServer() + proxyServer.SetLogger(logger) + proxyServer.AddPort(9968, 9968) + + cnt := 0 + + proxyServer.OnForwardTunnelToUser(func(ctx *Context) { + cnt++ + if cnt == 2 { + ctx.AbortTunnel() + } + }) + + go func() { + if err := proxyServer.Run(":9967"); err != nil { + panic(err) + } + }() + time.Sleep(time.Millisecond * 10) + + innerClient := NewProxyLiteClient(":9967") + innerClient.SetLogger(logger) + cancelFunc, done, err := innerClient.RegisterInnerService( + RegisterInfo{ + OuterPort: 9968, + InnerAddr: ":9966", + Name: "Echo", + Message: "TCP Echo Server", + }, + ControlInfo{}, + ) + if err != nil { + panic(err) + } + defer func() { + cancelFunc() + <-done + time.Sleep(time.Millisecond * 10) // wait close + proxyServer.Stop() + if cnt != 2 { + t.Error("hook does not work well") + } + }() + time.Sleep(time.Millisecond * 10) + + user, err := net.Dial("tcp", ":9968") + if err != nil { + t.Fatal(err) + } + msg := "hello123" + var data []byte + for i := 0; i < 10; i++ { + err := written(user, []byte(msg), 8) + if err != nil { + break + } + data, err = readn(user, 8) + if err != nil { + break + } + if string(data) != msg { + break + } + } + user.Close() + time.Sleep(time.Millisecond * 100) + select { + case <-done: + default: + t.Error("unexpected continue") + } +} diff --git a/server.go b/server.go index 04ccd63..cf96b46 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net" "sync" "sync/atomic" @@ -64,11 +65,6 @@ func (s *ProxyLiteServer) AddPort(from, to int) bool { return true } -// SetLogger Set customized logrus logger for the server. -func (s *ProxyLiteServer) SetLogger(logger *log.Logger) { - s.logger = logger -} - // Run Run the server and let it listen on given address. func (s *ProxyLiteServer) Run(addr string) error { var err error @@ -489,3 +485,13 @@ func (s *ProxyLiteServer) OnForwardTunnelToUser(f HookFunc) { func (s *ProxyLiteServer) OnForwardUserToTunnel(f HookFunc) { s.onForwardUserToTunnel = f } + +// SetLogger Set customized logrus logger for the server. +func (s *ProxyLiteServer) SetLogger(logger *log.Logger) { + if logger == nil { + s.logger = log.New() + s.logger.SetOutput(io.Discard) + } else { + s.logger = logger + } +}