diff --git a/forward_test.go b/forward_test.go index 6aa07cd..b8c1e8c 100644 --- a/forward_test.go +++ b/forward_test.go @@ -323,3 +323,45 @@ func Test_L2L_0(t *testing.T) { time.Sleep(time.Second) as.Equal(birdge.ConnNum(), 0) } + +func Test_L2L_1(t *testing.T) { + as := assert.New(t, true) + + ll := new(L2L) + ll.MaxConn(0) + ll.KeptIdeConn(1) + defer ll.Close() + + birdge, err := ll.Transport(addra, addrb) + as.NotError(err) + defer birdge.Close() + + addra.Local = ll.alisten.Addr() + addrb.Local = ll.blisten.Addr() + + go func() { + defer birdge.Close() + + conna, err := net.Dial(addra.Network, addra.Local.String()) + as.NotError(err) + defer conna.Close() + + connb, err := net.Dial(addrb.Network, addrb.Local.String()) + as.NotError(err) + defer connb.Close() + + // 发送 + b := make([]byte, 10) + rand.Reader.Read(b) + conna.Write(b) + + connb.SetReadDeadline(time.Now().Add(time.Second)) + p := make([]byte, 10) + n, err := connb.Read(p) + as.NotError(err).Equal(b[:n], p[:n]) + }() + + birdge.Swap() + time.Sleep(time.Second) + as.Equal(birdge.ConnNum(), 0) +} diff --git a/l2l.go b/l2l.go index 4ab8a0d..fa33e8a 100644 --- a/l2l.go +++ b/l2l.go @@ -74,7 +74,7 @@ func (T *L2LSwap) Swap() error { // T.ll.logf("%s 池中读取连接错误: %s", T.ll.blisten.Addr().String(), err) atomic.AddInt32(&T.ll.currUseConn, -2) // 重新进池 - err = T.ll.acp.Put(conna, conna.LocalAddr()) + err = T.ll.acp.Put(conna, T.ll.alisten.Addr()) if err != nil { // T.ll.logf("%s 连接加入 %s 池中读取连接错误: %s", conna.RemoteAddr().String(), conna.LocalAddr().String(), err) conna.Close() @@ -227,11 +227,11 @@ func (T *L2L) bufConn(l net.Listener, cp *vconnpool.ConnPool, verify *func(net.C } tempDelay = 0 - go T.examineConn(conn, verify, cp) + go T.examineConn(conn, l.Addr(), verify, cp) } } -func (T *L2L) examineConn(conn net.Conn, verify *func(net.Conn) bool, cp *vconnpool.ConnPool) { +func (T *L2L) examineConn(conn net.Conn, addr net.Addr, verify *func(net.Conn) bool, cp *vconnpool.ConnPool) { // 连接最大限制,正在使用+池中空闲 if cp.MaxConn != 0 && T.currUseConns()+cp.ConnNum() >= cp.MaxConn { // T.logf("%s 池中数量达到最大 %s 连接不能入池", conn.LocalAddr().String(), conn.RemoteAddr().String()) @@ -245,7 +245,7 @@ func (T *L2L) examineConn(conn net.Conn, verify *func(net.Conn) bool, cp *vconnp return } - if err := cp.Put(conn, conn.LocalAddr()); err != nil { + if err := cp.Put(conn, addr); err != nil { // 池中受最大连接限制,无法加入池中。 // T.logf("%s 连接加入 %s 池中读取连接错误: %s", conn.RemoteAddr().String(), conn.LocalAddr().String(), err) conn.Close()