From 54ad74be0deeb56903658995a08b79ce25750ef6 Mon Sep 17 00:00:00 2001 From: Matthew Sandoval Date: Mon, 28 Aug 2023 02:30:36 -0700 Subject: [PATCH] Unit tests for Ping.go inside netceptor (#810) Create interface for Ping & Packetconn and add unit tests to cover netceptor/ping.go Increase ping.go coverage from 0% to 92.5 --- pkg/netceptor/conn.go | 12 +- pkg/netceptor/mock_netceptor/packetconn.go | 208 +++++++++++ pkg/netceptor/mock_netceptor/ping.go | 161 +++++++++ pkg/netceptor/packetconn.go | 84 ++++- pkg/netceptor/ping.go | 35 +- pkg/netceptor/ping_test.go | 381 +++++++++++++++++++++ pkg/services/ip_router.go | 2 +- pkg/services/udp_proxy.go | 6 +- 8 files changed, 855 insertions(+), 34 deletions(-) create mode 100644 pkg/netceptor/mock_netceptor/packetconn.go create mode 100644 pkg/netceptor/mock_netceptor/ping.go create mode 100644 pkg/netceptor/ping_test.go diff --git a/pkg/netceptor/conn.go b/pkg/netceptor/conn.go index 1f2ec1397..0792c3ad5 100644 --- a/pkg/netceptor/conn.go +++ b/pkg/netceptor/conn.go @@ -36,7 +36,7 @@ type acceptResult struct { // Listener implements the net.Listener interface via the Receptor network. type Listener struct { s *Netceptor - pc *PacketConn + pc PacketConner ql quic.Listener acceptChan chan *acceptResult doneChan chan struct{} @@ -261,7 +261,7 @@ func (li *Listener) Addr() net.Addr { // Conn implements the net.Conn interface via the Receptor network. type Conn struct { s *Netceptor - pc *PacketConn + pc PacketConner qc quic.Connection qs quic.Stream doneChan chan struct{} @@ -380,7 +380,7 @@ func (s *Netceptor) DialContext(ctx context.Context, node string, service string // monitorUnreachable receives unreachable messages from the underlying PacketConn, and ends the connection // if the remote service has gone away. -func monitorUnreachable(pc *PacketConn, doneChan chan struct{}, remoteAddr Addr, cancel context.CancelFunc) { +func monitorUnreachable(pc PacketConner, doneChan chan struct{}, remoteAddr Addr, cancel context.CancelFunc) { msgCh := pc.SubscribeUnreachable(doneChan) if msgCh == nil { cancel() @@ -390,7 +390,7 @@ func monitorUnreachable(pc *PacketConn, doneChan chan struct{}, remoteAddr Addr, // read from channel until closed for msg := range msgCh { if msg.Problem == ProblemServiceUnknown && msg.ToNode == remoteAddr.node && msg.ToService == remoteAddr.service { - pc.s.Logger.Warning("remote service %s to node %s is unreachable", msg.ToService, msg.ToNode) + pc.GetLogger().Warning("remote service %s to node %s is unreachable", msg.ToService, msg.ToNode) cancel() } } @@ -421,11 +421,11 @@ func (c *Conn) Close() error { } func (c *Conn) CloseConnection() error { - c.pc.cancel() + c.pc.Cancel() c.doneOnce.Do(func() { close(c.doneChan) }) - c.s.Logger.Debug("closing connection from service %s to %s", c.pc.localService, c.RemoteAddr().String()) + c.s.Logger.Debug("closing connection from service %s to %s", c.pc.GetLocalService(), c.RemoteAddr().String()) return c.qc.CloseWithError(0, "normal close") } diff --git a/pkg/netceptor/mock_netceptor/packetconn.go b/pkg/netceptor/mock_netceptor/packetconn.go new file mode 100644 index 000000000..db461a8c5 --- /dev/null +++ b/pkg/netceptor/mock_netceptor/packetconn.go @@ -0,0 +1,208 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: pkg/netceptor/packetconn.go + +// Package mock_netceptor is a generated GoMock package. +package mock_netceptor + +import ( + context "context" + net "net" + reflect "reflect" + time "time" + + logger "github.com/ansible/receptor/pkg/logger" + netceptor "github.com/ansible/receptor/pkg/netceptor" + gomock "github.com/golang/mock/gomock" +) + +// MockPacketConner is a mock of PacketConner interface. +type MockPacketConner struct { + ctrl *gomock.Controller + recorder *MockPacketConnerMockRecorder +} + +// MockPacketConnerMockRecorder is the mock recorder for MockPacketConner. +type MockPacketConnerMockRecorder struct { + mock *MockPacketConner +} + +// NewMockPacketConner creates a new mock instance. +func NewMockPacketConner(ctrl *gomock.Controller) *MockPacketConner { + mock := &MockPacketConner{ctrl: ctrl} + mock.recorder = &MockPacketConnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPacketConner) EXPECT() *MockPacketConnerMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockPacketConner) Cancel() *context.CancelFunc { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(*context.CancelFunc) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockPacketConnerMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockPacketConner)(nil).Cancel)) +} + +// Close mocks base method. +func (m *MockPacketConner) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockPacketConnerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConner)(nil).Close)) +} + +// GetLocalService mocks base method. +func (m *MockPacketConner) GetLocalService() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLocalService") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetLocalService indicates an expected call of GetLocalService. +func (mr *MockPacketConnerMockRecorder) GetLocalService() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLocalService", reflect.TypeOf((*MockPacketConner)(nil).GetLocalService)) +} + +// GetLogger mocks base method. +func (m *MockPacketConner) GetLogger() *logger.ReceptorLogger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLogger") + ret0, _ := ret[0].(*logger.ReceptorLogger) + return ret0 +} + +// GetLogger indicates an expected call of GetLogger. +func (mr *MockPacketConnerMockRecorder) GetLogger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogger", reflect.TypeOf((*MockPacketConner)(nil).GetLogger)) +} + +// LocalAddr mocks base method. +func (m *MockPacketConner) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockPacketConnerMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConner)(nil).LocalAddr)) +} + +// ReadFrom mocks base method. +func (m *MockPacketConner) ReadFrom(p []byte) (int, net.Addr, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadFrom", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(net.Addr) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadFrom indicates an expected call of ReadFrom. +func (mr *MockPacketConnerMockRecorder) ReadFrom(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConner)(nil).ReadFrom), p) +} + +// SetDeadline mocks base method. +func (m *MockPacketConner) SetDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockPacketConnerMockRecorder) SetDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConner)(nil).SetDeadline), t) +} + +// SetHopsToLive mocks base method. +func (m *MockPacketConner) SetHopsToLive(hopsToLive byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetHopsToLive", hopsToLive) +} + +// SetHopsToLive indicates an expected call of SetHopsToLive. +func (mr *MockPacketConnerMockRecorder) SetHopsToLive(hopsToLive interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHopsToLive", reflect.TypeOf((*MockPacketConner)(nil).SetHopsToLive), hopsToLive) +} + +// SetReadDeadline mocks base method. +func (m *MockPacketConner) SetReadDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockPacketConnerMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConner)(nil).SetReadDeadline), t) +} + +// SetWriteDeadline mocks base method. +func (m *MockPacketConner) SetWriteDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockPacketConnerMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConner)(nil).SetWriteDeadline), t) +} + +// SubscribeUnreachable mocks base method. +func (m *MockPacketConner) SubscribeUnreachable(doneChan chan struct{}) chan netceptor.UnreachableNotification { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscribeUnreachable", doneChan) + ret0, _ := ret[0].(chan netceptor.UnreachableNotification) + return ret0 +} + +// SubscribeUnreachable indicates an expected call of SubscribeUnreachable. +func (mr *MockPacketConnerMockRecorder) SubscribeUnreachable(doneChan interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeUnreachable", reflect.TypeOf((*MockPacketConner)(nil).SubscribeUnreachable), doneChan) +} + +// WriteTo mocks base method. +func (m *MockPacketConner) WriteTo(p []byte, addr net.Addr) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteTo", p, addr) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WriteTo indicates an expected call of WriteTo. +func (mr *MockPacketConnerMockRecorder) WriteTo(p, addr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConner)(nil).WriteTo), p, addr) +} diff --git a/pkg/netceptor/mock_netceptor/ping.go b/pkg/netceptor/mock_netceptor/ping.go new file mode 100644 index 000000000..76abdde29 --- /dev/null +++ b/pkg/netceptor/mock_netceptor/ping.go @@ -0,0 +1,161 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: pkg/netceptor/ping.go + +// Package mock_netceptor is a generated GoMock package. +package mock_netceptor + +import ( + context "context" + reflect "reflect" + time "time" + + netceptor "github.com/ansible/receptor/pkg/netceptor" + gomock "github.com/golang/mock/gomock" +) + +// MockNetcForPing is a mock of NetcForPing interface. +type MockNetcForPing struct { + ctrl *gomock.Controller + recorder *MockNetcForPingMockRecorder +} + +// MockNetcForPingMockRecorder is the mock recorder for MockNetcForPing. +type MockNetcForPingMockRecorder struct { + mock *MockNetcForPing +} + +// NewMockNetcForPing creates a new mock instance. +func NewMockNetcForPing(ctrl *gomock.Controller) *MockNetcForPing { + mock := &MockNetcForPing{ctrl: ctrl} + mock.recorder = &MockNetcForPingMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetcForPing) EXPECT() *MockNetcForPingMockRecorder { + return m.recorder +} + +// Context mocks base method. +func (m *MockNetcForPing) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockNetcForPingMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockNetcForPing)(nil).Context)) +} + +// ListenPacket mocks base method. +func (m *MockNetcForPing) ListenPacket(service string) (netceptor.PacketConner, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListenPacket", service) + ret0, _ := ret[0].(netceptor.PacketConner) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListenPacket indicates an expected call of ListenPacket. +func (mr *MockNetcForPingMockRecorder) ListenPacket(service interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListenPacket", reflect.TypeOf((*MockNetcForPing)(nil).ListenPacket), service) +} + +// NewAddr mocks base method. +func (m *MockNetcForPing) NewAddr(target, service string) netceptor.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewAddr", target, service) + ret0, _ := ret[0].(netceptor.Addr) + return ret0 +} + +// NewAddr indicates an expected call of NewAddr. +func (mr *MockNetcForPingMockRecorder) NewAddr(target, service interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAddr", reflect.TypeOf((*MockNetcForPing)(nil).NewAddr), target, service) +} + +// NodeID mocks base method. +func (m *MockNetcForPing) NodeID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NodeID") + ret0, _ := ret[0].(string) + return ret0 +} + +// NodeID indicates an expected call of NodeID. +func (mr *MockNetcForPingMockRecorder) NodeID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeID", reflect.TypeOf((*MockNetcForPing)(nil).NodeID)) +} + +// MockNetcForTraceroute is a mock of NetcForTraceroute interface. +type MockNetcForTraceroute struct { + ctrl *gomock.Controller + recorder *MockNetcForTracerouteMockRecorder +} + +// MockNetcForTracerouteMockRecorder is the mock recorder for MockNetcForTraceroute. +type MockNetcForTracerouteMockRecorder struct { + mock *MockNetcForTraceroute +} + +// NewMockNetcForTraceroute creates a new mock instance. +func NewMockNetcForTraceroute(ctrl *gomock.Controller) *MockNetcForTraceroute { + mock := &MockNetcForTraceroute{ctrl: ctrl} + mock.recorder = &MockNetcForTracerouteMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetcForTraceroute) EXPECT() *MockNetcForTracerouteMockRecorder { + return m.recorder +} + +// Context mocks base method. +func (m *MockNetcForTraceroute) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockNetcForTracerouteMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockNetcForTraceroute)(nil).Context)) +} + +// MaxForwardingHops mocks base method. +func (m *MockNetcForTraceroute) MaxForwardingHops() byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaxForwardingHops") + ret0, _ := ret[0].(byte) + return ret0 +} + +// MaxForwardingHops indicates an expected call of MaxForwardingHops. +func (mr *MockNetcForTracerouteMockRecorder) MaxForwardingHops() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxForwardingHops", reflect.TypeOf((*MockNetcForTraceroute)(nil).MaxForwardingHops)) +} + +// Ping mocks base method. +func (m *MockNetcForTraceroute) Ping(ctx context.Context, target string, hopsToLive byte) (time.Duration, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx, target, hopsToLive) + ret0, _ := ret[0].(time.Duration) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Ping indicates an expected call of Ping. +func (mr *MockNetcForTracerouteMockRecorder) Ping(ctx, target, hopsToLive interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockNetcForTraceroute)(nil).Ping), ctx, target, hopsToLive) +} diff --git a/pkg/netceptor/packetconn.go b/pkg/netceptor/packetconn.go index 0ee64c4f3..70cbdf359 100644 --- a/pkg/netceptor/packetconn.go +++ b/pkg/netceptor/packetconn.go @@ -7,9 +7,25 @@ import ( "reflect" "time" + "github.com/ansible/receptor/pkg/logger" "github.com/ansible/receptor/pkg/utils" ) +type PacketConner interface { + SetHopsToLive(hopsToLive byte) + SubscribeUnreachable(doneChan chan struct{}) chan UnreachableNotification + ReadFrom(p []byte) (int, net.Addr, error) + WriteTo(p []byte, addr net.Addr) (n int, err error) + LocalAddr() net.Addr + Close() error + SetDeadline(t time.Time) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error + Cancel() *context.CancelFunc + GetLocalService() string + GetLogger() *logger.ReceptorLogger +} + // PacketConn implements the net.PacketConn interface via the Receptor network. type PacketConn struct { s *Netceptor @@ -25,9 +41,30 @@ type PacketConn struct { cancel context.CancelFunc } +func NewPacketConnWithConst(s *Netceptor, service string, advertise bool, adtags map[string]string, connTypeDatagram byte) *PacketConn { + npc := &PacketConn{ + s: s, + localService: service, + recvChan: make(chan *MessageData), + advertise: advertise, + adTags: adtags, + connType: connTypeDatagram, + hopsToLive: s.maxForwardingHops, + } + + npc.startUnreachable() + s.listenerRegistry[service] = npc + + return npc +} + +func NewPacketConn(s *Netceptor, service string, connTypeDatagram byte) *PacketConn { + return NewPacketConnWithConst(s, service, false, nil, connTypeDatagram) +} + // ListenPacket returns a datagram connection compatible with Go's net.PacketConn. // If service is blank, generates and uses an ephemeral service name. -func (s *Netceptor) ListenPacket(service string) (*PacketConn, error) { +func (s *Netceptor) ListenPacket(service string) (PacketConner, error) { if len(service) > 8 { return nil, fmt.Errorf("service name %s too long", service) } @@ -42,35 +79,46 @@ func (s *Netceptor) ListenPacket(service string) (*PacketConn, error) { return nil, fmt.Errorf("service %s is already listening", service) } _ = s.addNameHash(service) - pc := &PacketConn{ - s: s, - localService: service, - recvChan: make(chan *MessageData), - advertise: false, - adTags: nil, - connType: ConnTypeDatagram, - hopsToLive: s.maxForwardingHops, - } - pc.startUnreachable() - s.listenerRegistry[service] = pc + pc := NewPacketConn(s, service, ConnTypeDatagram) return pc, nil } // ListenPacketAndAdvertise returns a datagram listener, and also broadcasts service // advertisements to the Receptor network as long as the listener remains open. -func (s *Netceptor) ListenPacketAndAdvertise(service string, tags map[string]string) (*PacketConn, error) { - pc, err := s.ListenPacket(service) - if err != nil { - return nil, err +func (s *Netceptor) ListenPacketAndAdvertise(service string, tags map[string]string) (PacketConner, error) { + if len(service) > 8 { + return nil, fmt.Errorf("service name %s too long", service) + } + if service == "" { + service = s.getEphemeralService() } - pc.advertise = true - pc.adTags = tags + s.listenerLock.Lock() + defer s.listenerLock.Unlock() + _, isReserved := s.reservedServices[service] + _, isListening := s.listenerRegistry[service] + if isReserved || isListening { + return nil, fmt.Errorf("service %s is already listening and advertising", service) + } + pc := NewPacketConnWithConst(s, service, true, tags, ConnTypeDatagram) + s.addLocalServiceAdvertisement(service, ConnTypeDatagram, tags) return pc, nil } +func (pc *PacketConn) Cancel() *context.CancelFunc { + return &pc.cancel +} + +func (pc *PacketConn) GetLocalService() string { + return pc.localService +} + +func (pc *PacketConn) GetLogger() *logger.ReceptorLogger { + return pc.s.Logger +} + // startUnreachable starts monitoring the netceptor unreachable channel and forwarding relevant messages. func (pc *PacketConn) startUnreachable() { pc.context, pc.cancel = context.WithCancel(pc.s.context) diff --git a/pkg/netceptor/ping.go b/pkg/netceptor/ping.go index d1836c2db..af4acf021 100644 --- a/pkg/netceptor/ping.go +++ b/pkg/netceptor/ping.go @@ -7,8 +7,21 @@ import ( "time" ) -// Ping sends a single test packet and waits for a reply or error. +// NetcForPing should include all methods of Netceptor needed by the Ping function. +type NetcForPing interface { + ListenPacket(service string) (PacketConner, error) + NewAddr(target string, service string) Addr + NodeID() string + Context() context.Context +} + +// Ping calls SendPing to sends a single test packet and waits for a reply or error. func (s *Netceptor) Ping(ctx context.Context, target string, hopsToLive byte) (time.Duration, string, error) { + return SendPing(ctx, s, target, hopsToLive) +} + +// SendPing creates Ping by sending a single test packet and waits for a replay or error. +func SendPing(ctx context.Context, s NetcForPing, target string, hopsToLive byte) (time.Duration, string, error) { pc, err := s.ListenPacket("") if err != nil { return 0, "", err @@ -49,7 +62,7 @@ func (s *Netceptor) Ping(ctx context.Context, target string, hopsToLive byte) (t select { case replyChan <- fromNode: case <-ctxPing.Done(): - case <-s.context.Done(): + case <-s.Context().Done(): } } else { select { @@ -58,7 +71,7 @@ func (s *Netceptor) Ping(ctx context.Context, target string, hopsToLive byte) (t fromNode: fromNode, }: case <-ctx.Done(): - case <-s.context.Done(): + case <-s.Context().Done(): } } }() @@ -75,11 +88,17 @@ func (s *Netceptor) Ping(ctx context.Context, target string, hopsToLive byte) (t return time.Since(startTime), "", fmt.Errorf("timeout") case <-ctxPing.Done(): return time.Since(startTime), "", fmt.Errorf("user cancelled") - case <-s.context.Done(): + case <-s.Context().Done(): return time.Since(startTime), "", fmt.Errorf("netceptor shutdown") } } +type NetcForTraceroute interface { + MaxForwardingHops() byte + Ping(ctx context.Context, target string, hopsToLive byte) (time.Duration, string, error) + Context() context.Context +} + // TracerouteResult is the result of one hop of a traceroute. type TracerouteResult struct { From string @@ -87,8 +106,12 @@ type TracerouteResult struct { Err error } -// Traceroute returns a channel which will receive a series of hops between this node and the target. func (s *Netceptor) Traceroute(ctx context.Context, target string) <-chan *TracerouteResult { + return CreateTraceroute(ctx, s, target) +} + +// CreateTraceroute returns a channel which will receive a series of hops between this node and the target. +func CreateTraceroute(ctx context.Context, s NetcForTraceroute, target string) <-chan *TracerouteResult { results := make(chan *TracerouteResult) go func() { defer close(results) @@ -105,7 +128,7 @@ func (s *Netceptor) Traceroute(ctx context.Context, target string) <-chan *Trace case results <- res: case <-ctx.Done(): return - case <-s.context.Done(): + case <-s.Context().Done(): return } if res.Err != nil || err == nil { diff --git a/pkg/netceptor/ping_test.go b/pkg/netceptor/ping_test.go new file mode 100644 index 000000000..6ea48605e --- /dev/null +++ b/pkg/netceptor/ping_test.go @@ -0,0 +1,381 @@ +package netceptor_test + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/ansible/receptor/pkg/netceptor" + "github.com/ansible/receptor/pkg/netceptor/mock_netceptor" + "github.com/golang/mock/gomock" +) + +// setupTest sets up TestPing tests. +func setupTest(t *testing.T) (*gomock.Controller, *mock_netceptor.MockNetcForPing, *mock_netceptor.MockPacketConner) { + ctrl := gomock.NewController(t) + + // Prepare mocks + mockNetceptor := mock_netceptor.NewMockNetcForPing(ctrl) + mockPacketConn := mock_netceptor.NewMockPacketConner(ctrl) + + return ctrl, mockNetceptor, mockPacketConn +} + +// createChannel creates a channel that passes an error to errorChan inside of createPing. +func createChannel(mockPacketConn *mock_netceptor.MockPacketConner) { + mockUnreachableMessage := netceptor.UnreachableMessage{ + FromNode: "", + ToNode: "", + FromService: "", + ToService: "", + Problem: "test", + } + + mockUnreachableNotification := netceptor.UnreachableNotification{ + mockUnreachableMessage, + "test", + } + channel := make(chan netceptor.UnreachableNotification) + + mockPacketConn.EXPECT().SubscribeUnreachable(gomock.Any()).Return(channel) + go func() { + channel <- mockUnreachableNotification + }() +} + +// checkPing checks TestPing tests by comparing return values to expected values. +func checkPing(duration time.Duration, expectedDuration int, remote string, expectedRemote string, err error, expectedError error, t *testing.T) { + if expectedError == nil && err != nil { + t.Errorf("Expected no error, got: %v", err) + } else if expectedError != nil && (err == nil || err.Error() != expectedError.Error()) { + t.Errorf("Expected error: %s, got: %v", expectedError.Error(), err) + } + if expectedDuration != int(duration) && expectedDuration != 0 { + t.Errorf("Expected duration to be %v, got: %v", expectedDuration, duration) + } + if expectedRemote != remote && expectedRemote != "" { + t.Errorf("Expected remote to be %v, got: %v", expectedRemote, remote) + } +} + +func setupTestExpects(args ...interface{}) { + mockNetceptor := args[0].(*mock_netceptor.MockNetcForPing) + mockPacketConn := args[1].(*mock_netceptor.MockPacketConner) + testCase := args[2].(pingTestCaseStruct) + + testExpects := map[string]func(){ + "ListenPacketReturn": func() { + mockNetceptor.EXPECT().ListenPacket(gomock.Any()).Return(testCase.returnListenPacket.packetConn, testCase.returnListenPacket.err).Times(testCase.returnListenPacket.times) + }, + "SubscribeUnreachableReturn": func() { + mockPacketConn.EXPECT().SubscribeUnreachable(gomock.Any()).Return(make(chan netceptor.UnreachableNotification)) + }, + "WriteToReturn": func() { + mockPacketConn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Return(testCase.returnWriteTo.packetLen, testCase.returnWriteTo.err).Times(testCase.returnWriteTo.times) + }, + "ReadFromReturn": func() { + mockPacketConn.EXPECT().ReadFrom(gomock.Any()).Return(0, testCase.returnReadFrom.address, testCase.returnReadFrom.err).MaxTimes(testCase.returnReadFrom.times) + }, + "ReadFromDo": func() { + mockPacketConn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { + time.Sleep(time.Second * 11) + }).Times(testCase.returnReadFrom.times) + }, + "ReadFromDoAndReturn": func() { + mockPacketConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { + time.Sleep(time.Second * 2) + + return 0, testCase.returnReadFrom.address, testCase.returnReadFrom.err + }).MaxTimes(testCase.returnReadFrom.times) + }, + "ContextReturn": func() { + mockNetceptor.EXPECT().Context().Return(testCase.returnContext.ctx).MaxTimes(testCase.returnContext.times) + }, + "ContextDoAndReturn": func() { + mockNetceptor.EXPECT().Context().DoAndReturn(func() context.Context { + newCtx, ctxCancel := context.WithCancel(context.Background()) + ctxCancel() + + return newCtx + }).MaxTimes(testCase.returnContext.times) + }, + "SetHopsToLiveReturn": func() { mockPacketConn.EXPECT().SetHopsToLive(gomock.Any()).Times(testCase.returnSetHopsToLiveTimes) }, + "CloseReturn": func() { mockPacketConn.EXPECT().Close().Return(nil).Times(testCase.returnCloseTimes) }, + "NewAddrReturn": func() { + mockNetceptor.EXPECT().NewAddr(gomock.Any(), gomock.Any()).Return(netceptor.Addr{}).Times(testCase.returnNewAddrTimes) + }, + "NodeID": func() { mockNetceptor.EXPECT().NodeID().Return("nodeID") }, + "CreateChannel": func() { createChannel(mockPacketConn) }, + "SleepOneSecond": func() { time.Sleep(time.Second * 1) }, + } + + for _, expect := range testCase.expects { + testExpects[expect]() + } +} + +type listenPacketReturn struct { + packetConn netceptor.PacketConner + err error + times int + returnType string +} + +type writeToReturn struct { + packetLen int + err error + times int + returnType string +} + +type contextReturn struct { + ctx context.Context + times int + returnType string +} + +type readFromReturn struct { + data int + address net.Addr + err error + times int + returnType string +} + +type pingTestCaseStruct struct { + name string + pingTarget string + pingHopsToLive byte + returnSetHopsToLiveTimes int + returnCloseTimes int + returnNewAddrTimes int + returnListenPacket listenPacketReturn + returnWriteTo writeToReturn + returnContext contextReturn + returnReadFrom readFromReturn + expects []string + setupTestExpects func(args ...interface{}) + expectedDuration int + expectedRemote string + expectedError error +} + +// TestCreatePing tests CreatePing inside ping.go. +func TestCreatePing(t *testing.T) { + ctrl, mockNetceptor, mockPacketConn := setupTest(t) + + pingTestCases := []pingTestCaseStruct{ + { + "NetceptorShutdown Error", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{mockPacketConn, nil, 1, "return"}, + writeToReturn{0, nil, 1, "return"}, + contextReturn{context.Background(), 2, "doAndReturn"}, + readFromReturn{0, nil, nil, 1, "return"}, + []string{"ListenPacketReturn", "SetHopsToLiveReturn", "CloseReturn", "NewAddrReturn", "SubscribeUnreachableReturn", "WriteToReturn", "ReadFromReturn", "ContextDoAndReturn", "SleepOneSecond"}, + setupTestExpects, + 0, + "", + errors.New("netceptor shutdown"), + }, + { + "SubscribeUnreachable Error", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{mockPacketConn, nil, 1, "return"}, + writeToReturn{0, nil, 1, "return"}, + contextReturn{context.Background(), 2, "return"}, + readFromReturn{0, nil, nil, 1, "return"}, + []string{"CreateChannel", "ListenPacketReturn", "SetHopsToLiveReturn", "CloseReturn", "NewAddrReturn", "WriteToReturn", "ReadFromDoAndReturn", "ContextReturn"}, + setupTestExpects, + 0, + "", + errors.New("test"), + }, + { + "CreatePing Success", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{mockPacketConn, nil, 1, "return"}, + writeToReturn{0, nil, 1, "return"}, + contextReturn{context.Background(), 2, "return"}, + readFromReturn{0, &netceptor.Addr{}, nil, 1, "return"}, + []string{"ListenPacketReturn", "SetHopsToLiveReturn", "CloseReturn", "NewAddrReturn", "SubscribeUnreachableReturn", "WriteToReturn", "ReadFromReturn", "ContextReturn"}, + setupTestExpects, + 0, + ":", + nil, + }, + { + "ListenPacket Error", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{nil, errors.New("Catch ListenPacket error"), 1, "return"}, + writeToReturn{0, nil, 0, "return"}, + contextReturn{context.Background(), 0, "return"}, + readFromReturn{0, &netceptor.Addr{}, nil, 0, "return"}, + []string{"ListenPacketReturn"}, + setupTestExpects, + 0, + "", + errors.New("Catch ListenPacket error"), + }, + { + "ReadFrom Error", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{mockPacketConn, nil, 1, "return"}, + writeToReturn{0, nil, 1, "return"}, + contextReturn{context.Background(), 2, "return"}, + readFromReturn{0, nil, errors.New("ReadFrom error"), 1, "return"}, + []string{"ListenPacketReturn", "SetHopsToLiveReturn", "CloseReturn", "NewAddrReturn", "SubscribeUnreachableReturn", "WriteToReturn", "ReadFromReturn", "ContextReturn"}, + setupTestExpects, + 0, + "", + errors.New("ReadFrom error"), + }, + { + "WriteTo Error", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{mockPacketConn, nil, 1, "return"}, + writeToReturn{0, errors.New("WriteTo error"), 1, "return"}, + contextReturn{context.Background(), 2, "return"}, + readFromReturn{0, nil, nil, 1, "return"}, + []string{"ListenPacketReturn", "SetHopsToLiveReturn", "CloseReturn", "NewAddrReturn", "SubscribeUnreachableReturn", "WriteToReturn", "ReadFromReturn", "ContextReturn", "NodeID"}, + setupTestExpects, + 0, + "", + errors.New("WriteTo error"), + }, + { + "Timeout Error", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{mockPacketConn, nil, 1, "return"}, + writeToReturn{0, nil, 1, "return"}, + contextReturn{context.Background(), 2, "return"}, + readFromReturn{0, nil, nil, 1, "do"}, + []string{"ListenPacketReturn", "SetHopsToLiveReturn", "CloseReturn", "NewAddrReturn", "SubscribeUnreachableReturn", "WriteToReturn", "ReadFromDo", "ContextReturn"}, + setupTestExpects, + 0, + "", + errors.New("timeout"), + }, + { + "User Cancel Error", + "target", + byte(1), + 1, + 1, + 1, + listenPacketReturn{mockPacketConn, nil, 1, "return"}, + writeToReturn{0, nil, 1, "return"}, + contextReturn{context.Background(), 2, "return"}, + readFromReturn{0, nil, nil, 1, "doAndReturn"}, + []string{"ListenPacketReturn", "SetHopsToLiveReturn", "CloseReturn", "NewAddrReturn", "SubscribeUnreachableReturn", "WriteToReturn", "ReadFromDoAndReturn", "ContextReturn"}, + setupTestExpects, + 0, + "", + errors.New("user cancelled"), + }, + } + + for _, testCase := range pingTestCases { + ctx := context.Background() + t.Run(testCase.name, func(t *testing.T) { + testCase.setupTestExpects(mockNetceptor, mockPacketConn, testCase) + if testCase.name == "NetceptorShutdown Error" { + time.Sleep(time.Second * 1) + } + if testCase.name == "User Cancel Error" { + newCtx, ctxCancel := context.WithCancel(ctx) + + time.AfterFunc(1*time.Second, ctxCancel) + + duration, remote, err := netceptor.SendPing(newCtx, mockNetceptor, testCase.pingTarget, testCase.pingHopsToLive) + checkPing(duration, testCase.expectedDuration, remote, testCase.expectedRemote, err, testCase.expectedError, t) + } else { + duration, remote, err := netceptor.SendPing(ctx, mockNetceptor, testCase.pingTarget, testCase.pingHopsToLive) + checkPing(duration, testCase.expectedDuration, remote, testCase.expectedRemote, err, testCase.expectedError, t) + } + + ctrl.Finish() + ctx.Done() + }) + } +} + +type pingReturn struct { + duration time.Duration + remote string + err error +} + +type expectedResult struct { + from string + time time.Duration + err error +} + +// TestCreateTraceroute tests CreateTraceroute inside ping.go. +func TestCreateTraceroute(t *testing.T) { + ctrl := gomock.NewController(t) + + mockNetceptor := mock_netceptor.NewMockNetcForTraceroute(ctrl) + ctx := context.Background() + defer ctx.Done() + + createTracerouteTestCases := []struct { + name string + createTracerouteTarget string + returnPing pingReturn + expectedResult expectedResult + }{ + {"CreateTraceroute Success", "target", pingReturn{time.Since(time.Now()), "target", nil}, expectedResult{":", time.Since(time.Now()), nil}}, + {"CreateTraceroute Error", "target", pingReturn{time.Since(time.Now()), "target", errors.New("traceroute error")}, expectedResult{":", time.Since(time.Now()), errors.New("traceroute error")}}, + } + + for _, testCase := range createTracerouteTestCases { + t.Run(testCase.name, func(t *testing.T) { + mockNetceptor.EXPECT().Context().Return(context.Background()) + mockNetceptor.EXPECT().MaxForwardingHops().Return(byte(1)) + mockNetceptor.EXPECT().Ping(ctx, testCase.createTracerouteTarget, byte(0)).Return(testCase.returnPing.duration, testCase.returnPing.remote, testCase.returnPing.err) + + result := netceptor.CreateTraceroute(ctx, mockNetceptor, testCase.createTracerouteTarget) + for res := range result { + if testCase.expectedResult.err == nil && res.Err != nil { + t.Errorf("Expected no error, got: %v", res.Err.Error()) + } else if testCase.expectedResult.err != nil && (res.Err == nil || res.Err.Error() != testCase.expectedResult.err.Error()) { + t.Errorf("Expected error: %s, got: %v", testCase.expectedResult.err.Error(), res.Err) + } + } + + ctrl.Finish() + }) + } +} diff --git a/pkg/services/ip_router.go b/pkg/services/ip_router.go index 9f9f82113..b3926a123 100644 --- a/pkg/services/ip_router.go +++ b/pkg/services/ip_router.go @@ -39,7 +39,7 @@ type IPRouterService struct { destIP net.IP tunIf *water.Interface link netlink.Link - nConn *netceptor.PacketConn + nConn netceptor.PacketConner knownRoutes []ipRoute knownRoutesLock *sync.RWMutex } diff --git a/pkg/services/udp_proxy.go b/pkg/services/udp_proxy.go index 93b1073be..399b06076 100644 --- a/pkg/services/udp_proxy.go +++ b/pkg/services/udp_proxy.go @@ -15,7 +15,7 @@ import ( // UDPProxyServiceInbound listens on a UDP port and forwards packets to a remote Receptor service. func UDPProxyServiceInbound(s *netceptor.Netceptor, host string, port int, node string, service string) error { - connMap := make(map[string]*netceptor.PacketConn) + connMap := make(map[string]netceptor.PacketConner) buffer := make([]byte, utils.NormalBufferSize) addrStr := fmt.Sprintf("%s:%d", host, port) @@ -69,7 +69,7 @@ func UDPProxyServiceInbound(s *netceptor.Netceptor, host string, port int, node return nil } -func runNetceptorToUDPInbound(pc *netceptor.PacketConn, uc *net.UDPConn, udpAddr net.Addr, expectedAddr netceptor.Addr, logger *logger.ReceptorLogger) { +func runNetceptorToUDPInbound(pc netceptor.PacketConner, uc *net.UDPConn, udpAddr net.Addr, expectedAddr netceptor.Addr, logger *logger.ReceptorLogger) { buf := make([]byte, utils.NormalBufferSize) for { n, addr, err := pc.ReadFrom(buf) @@ -150,7 +150,7 @@ func UDPProxyServiceOutbound(s *netceptor.Netceptor, service string, address str return nil } -func runUDPToNetceptorOutbound(uc *net.UDPConn, pc *netceptor.PacketConn, addr net.Addr, logger *logger.ReceptorLogger) { +func runUDPToNetceptorOutbound(uc *net.UDPConn, pc netceptor.PacketConner, addr net.Addr, logger *logger.ReceptorLogger) { buf := make([]byte, utils.NormalBufferSize) for { n, err := uc.Read(buf)