diff --git a/port-forwarder/builder.go b/port-forwarder/builder.go index f88615264..906069e9a 100644 --- a/port-forwarder/builder.go +++ b/port-forwarder/builder.go @@ -152,6 +152,13 @@ func (s *PortForwardingService) ReloadConfigAndApplyChanges( return err } + return s.ApplyChangesByNewFwdList(&pflNew) +} + +func (s *PortForwardingService) ApplyChangesByNewFwdList( + pflNew *PortForwardingList, +) error { + to_be_closed := []string{} for old := range s.configPortForwardings { _, corresponding_new_exists := pflNew.configPortForwardings[old] diff --git a/port-forwarder/port_forwarder_tcp_test.go b/port-forwarder/port_forwarder_tcp_test.go index f22633198..2f86c37a4 100644 --- a/port-forwarder/port_forwarder_tcp_test.go +++ b/port-forwarder/port_forwarder_tcp_test.go @@ -27,17 +27,36 @@ func doTestTcpCommunication( assert.Equal(t, data_sent, buf[:n]) } +func doTestTcpCommunicationFail( + t *testing.T, + msg string, + senderConn net.Conn, + receiverConn net.Conn, +) { + data_sent := []byte(msg) + n, err := senderConn.Write(data_sent) + if err != nil { + return + } + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + buf := make([]byte, 100) + _, err = receiverConn.Read(buf) + assert.NotNil(t, err) +} + func TestTcpInOut2Clients(t *testing.T) { l := logrus.New() - server, client := service.CreateTwoConnectedServices() + server, client := service.CreateTwoConnectedServices(4247) defer client.Close() defer server.Close() server_pf, err := createPortForwarderFromConfigString(l, server, ` port_forwarding: inbound: - - listen_port: 4499 - dial_address: 127.0.0.1:5599 + - listen_port: 4495 + dial_address: 127.0.0.1:5595 protocols: [tcp] `) assert.Nil(t, err) @@ -47,17 +66,17 @@ port_forwarding: client_pf, err := createPortForwarderFromConfigString(l, client, ` port_forwarding: outbound: - - listen_address: 127.0.0.1:3399 - dial_address: 10.0.0.1:4499 + - listen_address: 127.0.0.1:3395 + dial_address: 10.0.0.1:4495 protocols: [tcp] `) assert.Nil(t, err) assert.Len(t, client_pf.portForwardings, 1) - client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3399") + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3395") assert.Nil(t, err) - server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5599") + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5595") assert.Nil(t, err) server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) @@ -87,3 +106,183 @@ port_forwarding: client2_conn, client2_server_side_conn) } + +func TestTcpInOut1ClientConfigReload(t *testing.T) { + l := logrus.New() + server, client := service.CreateTwoConnectedServices(4246) + defer client.Close() + defer server.Close() + + server_pf, err := createPortForwarderFromConfigString(l, server, ` +port_forwarding: + inbound: + - listen_port: 4497 + dial_address: 127.0.0.1:5597 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(l, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3397 + dial_address: 10.0.0.1:4497 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3397") + assert.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5597") + assert.Nil(t, err) + + server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) + assert.Nil(t, err) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + assert.Nil(t, err) + client1_server_side_conn, err := server_listen_conn.Accept() + assert.Nil(t, err) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_conn) + + doTestTcpCommunication(t, "Hello from client one side AGAIN!", + client1_conn, client1_server_side_conn) + + new_server_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + inbound: + - listen_port: 4496 + dial_address: 127.0.0.1:5596 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + new_client_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3396 + dial_address: 10.0.0.1:4496 + protocols: [tcp] +`) + assert.Nil(t, err) + + err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + + err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) +} + +func TestTcpInOut1ClientConfigReload_inverseCloseOrder(t *testing.T) { + l := logrus.New() + server, client := service.CreateTwoConnectedServices(4245) + defer client.Close() + defer server.Close() + + server_pf, err := createPortForwarderFromConfigString(l, server, ` +port_forwarding: + inbound: + - listen_port: 4499 + dial_address: 127.0.0.1:5599 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(l, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 10.0.0.1:4499 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3399") + assert.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5599") + assert.Nil(t, err) + + server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) + assert.Nil(t, err) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + assert.Nil(t, err) + client1_server_side_conn, err := server_listen_conn.Accept() + assert.Nil(t, err) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_conn) + + doTestTcpCommunication(t, "Hello from client one side AGAIN!", + client1_conn, client1_server_side_conn) + + new_server_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + inbound: + - listen_port: 4498 + dial_address: 127.0.0.1:5598 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + new_client_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3398 + dial_address: 10.0.0.1:4498 + protocols: [tcp] +`) + assert.Nil(t, err) + + err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + + err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) +} diff --git a/port-forwarder/port_forwarder_udp_test.go b/port-forwarder/port_forwarder_udp_test.go index 1768f8808..27ea5800d 100644 --- a/port-forwarder/port_forwarder_udp_test.go +++ b/port-forwarder/port_forwarder_udp_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) -func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) { +func loadPortFwdConfigFromString(l *logrus.Logger, configStr string) (*PortForwardingList, error) { c := config.NewC(l) err := c.LoadString(configStr) if err != nil { @@ -23,7 +23,17 @@ func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, return nil, err } - pf, err := ConstructFromInitialFwdList(srv, l, &fwd_list) + return &fwd_list, nil +} + +func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) { + + fwd_list, err := loadPortFwdConfigFromString(l, configStr) + if err != nil { + return nil, err + } + + pf, err := ConstructFromInitialFwdList(srv, l, fwd_list) if err != nil { return nil, err } @@ -64,7 +74,7 @@ func doTestUdpCommunication( func TestUdpInOut2Clients(t *testing.T) { l := logrus.New() - server, client := service.CreateTwoConnectedServices() + server, client := service.CreateTwoConnectedServices(4244) defer client.Close() defer server.Close() diff --git a/service/service_test.go b/service/service_test.go index da327af42..b9098c34f 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,7 +10,7 @@ import ( ) func TestService(t *testing.T) { - a, b := CreateTwoConnectedServices() + a, b := CreateTwoConnectedServices(4243) ln, err := a.Listen("tcp", ":1234") if err != nil { diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 01d226694..28661865b 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -1,6 +1,7 @@ package service import ( + "fmt" "net/netip" "time" @@ -75,7 +76,7 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, return s } -func CreateTwoConnectedServices() (*Service, *Service) { +func CreateTwoConnectedServices(port int) (*Service, *Service) { ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, @@ -84,12 +85,12 @@ func CreateTwoConnectedServices() (*Service, *Service) { }, "listen": m{ "host": "0.0.0.0", - "port": 4243, + "port": port, }, }) b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ - "10.0.0.1": []string{"localhost:4243"}, + "10.0.0.1": []string{fmt.Sprintf("localhost:%d", port)}, }, "lighthouse": m{ "hosts": []string{"10.0.0.1"},