From c09c1047f75a1f4f36b302516eede26933373b39 Mon Sep 17 00:00:00 2001 From: Yutaro Hayakawa Date: Sun, 26 May 2024 01:38:57 +0900 Subject: [PATCH] Implement device detection mechanism This commit tries to accomplish two features. 1. React to the device MAC address change and change Link Layer Address option appropreately. 2. React to the device down, stop RA, and restart RA when the device comes back. The implementation is netlink-based device detection mechanism. It subscribes to the rtnetlink event and detects device state update and device down/deletion. Signed-off-by: Yutaro Hayakawa --- advertiser.go | 110 +++++++++++++++++++++++++++++++++++-------------- daemon.go | 16 +++++-- daemon_test.go | 87 ++++++++++++++++++++++++++++++-------- device.go | 63 ++++++++++++++++++++++++++++ fake_device.go | 41 ++++++++++++++++++ 5 files changed, 265 insertions(+), 52 deletions(-) create mode 100644 device.go create mode 100644 fake_device.go diff --git a/advertiser.go b/advertiser.go index 049bc66..61cf7fc 100644 --- a/advertiser.go +++ b/advertiser.go @@ -10,11 +10,11 @@ import ( "log/slog" "net/netip" "reflect" + "slices" "sync" "time" "github.com/mdlayher/ndp" - "github.com/sethvargo/go-retry" "golang.org/x/sys/unix" ) @@ -28,10 +28,10 @@ type advertiser struct { ifaceStatus *InterfaceStatus ifaceStatusLock sync.RWMutex - reloadCh chan *InterfaceConfig - stopCh chan any - sock socket - socketCtor socketCtor + reloadCh chan *InterfaceConfig + stopCh chan any + socketCtor socketCtor + deviceWatcher deviceWatcher } // An internal structure to represent RS @@ -40,7 +40,7 @@ type rsMsg struct { from netip.Addr } -func newAdvertiser(initialConfig *InterfaceConfig, ctor socketCtor, logger *slog.Logger) *advertiser { +func newAdvertiser(initialConfig *InterfaceConfig, ctor socketCtor, devWatcher deviceWatcher, logger *slog.Logger) *advertiser { return &advertiser{ logger: logger.With(slog.String("interface", initialConfig.Name)), initialConfig: initialConfig, @@ -48,10 +48,11 @@ func newAdvertiser(initialConfig *InterfaceConfig, ctor socketCtor, logger *slog reloadCh: make(chan *InterfaceConfig), stopCh: make(chan any), socketCtor: ctor, + deviceWatcher: devWatcher, } } -func (s *advertiser) createRAMsg(config *InterfaceConfig) *ndp.RouterAdvertisement { +func (s *advertiser) createRAMsg(config *InterfaceConfig, deviceState *deviceState) *ndp.RouterAdvertisement { return &ndp.RouterAdvertisement{ CurrentHopLimit: uint8(config.CurrentHopLimit), ManagedConfiguration: config.Managed, @@ -60,15 +61,15 @@ func (s *advertiser) createRAMsg(config *InterfaceConfig) *ndp.RouterAdvertiseme RouterLifetime: time.Duration(config.RouterLifetimeSeconds) * time.Second, ReachableTime: time.Duration(config.ReachableTimeMilliseconds) * time.Millisecond, RetransmitTimer: time.Duration(config.RetransmitTimeMilliseconds) * time.Millisecond, - Options: s.createOptions(config), + Options: s.createOptions(config, deviceState), } } -func (s *advertiser) createOptions(config *InterfaceConfig) []ndp.Option { +func (s *advertiser) createOptions(config *InterfaceConfig, deviceState *deviceState) []ndp.Option { options := []ndp.Option{ &ndp.LinkLayerAddress{ Direction: ndp.Source, - Addr: s.sock.hardwareAddr(), + Addr: deviceState.addr, }, } @@ -197,38 +198,60 @@ func (s *advertiser) run(ctx context.Context) { // The current desired configuration config := s.initialConfig + // The current device state + devState := deviceState{} + // Set a timestamp for the first "update" s.setLastUpdate() - // Create the socket - err := retry.Constant(ctx, time.Second, func(ctx context.Context) error { - var err error - - s.sock, err = s.socketCtor(config.Name) - if err != nil { - // These are the unrecoverable errors we're aware of now. - if errors.Is(err, unix.EPERM) || errors.Is(err, unix.EINVAL) { - return fmt.Errorf("cannot create socket: %w", err) - } - - s.reportFailing(err) + // Watch the device state + devCh, err := s.deviceWatcher.watch(ctx, config.Name) + if err != nil { + s.reportStopped(err) + return + } - return retry.RetryableError(err) +waitDevice: + // Wait for the device to be present and up + for { + select { + case <-ctx.Done(): + s.reportStopped(ctx.Err()) + return + case dev := <-devCh: + // Update the device state + devState = dev + + // If the device is up, we can proceed with the socket creation + if dev.isUp { + break waitDevice + } } + } - return nil - }) + // Create the socket + sock, err := s.socketCtor(config.Name) if err != nil { - s.reportStopped(err) - return + // These are the unrecoverable errors we're aware of now. + if errors.Is(err, unix.EPERM) || errors.Is(err, unix.EINVAL) { + s.reportStopped(fmt.Errorf("cannot create socket: %w", err)) + return + } + // Otherwise, we'll retry + s.reportFailing(err) + goto waitDevice } // Launch the RS receiver rsCh := make(chan *rsMsg) + receiverCtx, cancelReceiver := context.WithCancel(ctx) go func() { for { - rs, addr, err := s.sock.recvRS(ctx) + rs, addr, err := sock.recvRS(receiverCtx) if err != nil { + if receiverCtx.Err() != nil { + return + } s.reportFailing(err) continue } @@ -241,7 +264,7 @@ func (s *advertiser) run(ctx context.Context) { reload: for { // RA message - msg := s.createRAMsg(config) + msg := s.createRAMsg(config, &devState) // For unsolicited RA ticker := time.NewTicker(time.Duration(config.RAIntervalMilliseconds) * time.Millisecond) @@ -252,7 +275,7 @@ reload: // Reply to RS // // TODO: Rate limit this to mitigate RS flooding attack - err := s.sock.sendRA(ctx, rs.from, msg) + err := sock.sendRA(ctx, rs.from, msg) if err != nil { s.reportFailing(err) continue @@ -261,7 +284,7 @@ reload: s.reportRunning() case <-ticker.C: // Send unsolicited RA - err := s.sock.sendRA(ctx, netip.IPv6LinkLocalAllNodes(), msg) + err := sock.sendRA(ctx, netip.IPv6LinkLocalAllNodes(), msg) if err != nil { s.reportFailing(err) continue @@ -277,6 +300,28 @@ reload: s.reportReloading() s.setLastUpdate() continue reload + case dev := <-devCh: + // Save the old address for comparison + oldAddr := devState.addr + + // Update the device state + devState = dev + + // Device is stopped. Stop the advertisement + // and wait for the device to be up again. + if !devState.isUp { + cancelReceiver() + s.reportFailing(fmt.Errorf("device is down")) + goto waitDevice + } + + // Device address has changed. We need to + // change the Link Layer Address option in the + // RA message. Reload internally. + if !slices.Equal(oldAddr, dev.addr) { + s.reportReloading() + continue reload + } case <-ctx.Done(): s.reportStopped(ctx.Err()) break reload @@ -288,7 +333,8 @@ reload: } - s.sock.close() + cancelReceiver() + sock.close() } func (s *advertiser) status() *InterfaceStatus { diff --git a/daemon.go b/daemon.go index ef24dce..141d56c 100644 --- a/daemon.go +++ b/daemon.go @@ -17,6 +17,7 @@ type Daemon struct { reloadCh chan *Config logger *slog.Logger socketConstructor socketCtor + deviceWatcher deviceWatcher advertisers map[string]*advertiser advertisersLock sync.RWMutex @@ -39,6 +40,7 @@ func NewDaemon(config *Config, opts ...DaemonOption) (*Daemon, error) { reloadCh: make(chan *Config), logger: slog.Default(), socketConstructor: newSocket, + deviceWatcher: newDeviceWatcher(), advertisers: map[string]*advertiser{}, } @@ -89,9 +91,9 @@ reload: // Add new per-interface jobs for _, c := range toAdd { d.logger.Info("Adding new RA sender", slog.String("interface", c.Name)) - sender := newAdvertiser(c, d.socketConstructor, d.logger) - go sender.run(ctx) - d.advertisers[c.Name] = sender + advertiser := newAdvertiser(c, d.socketConstructor, d.deviceWatcher, d.logger) + go advertiser.run(ctx) + d.advertisers[c.Name] = advertiser } // Update (reload) existing workers @@ -188,3 +190,11 @@ func withSocketConstructor(c socketCtor) DaemonOption { d.socketConstructor = c } } + +// withDeviceWatcher overrides the default device watcher with the provided +// one. For testing purposes only. +func withDeviceWatcher(w deviceWatcher) DaemonOption { + return func(d *Daemon) { + d.deviceWatcher = w + } +} diff --git a/daemon_test.go b/daemon_test.go index 10bfb8c..3eb9106 100644 --- a/daemon_test.go +++ b/daemon_test.go @@ -5,7 +5,9 @@ package ra import ( "context" + "net" "net/netip" + "slices" "testing" "time" @@ -20,7 +22,7 @@ func eventully(t *testing.T, f func() bool) { require.Eventually(t, f, time.Second*1, time.Millisecond*10) } -func assertRAInterval(t *testing.T, sock *fakeSock, interval time.Duration) bool { +func assertRAInterval(ct *assert.CollectT, sock *fakeSock, interval time.Duration) bool { // wait until we get 3 RAs timeout, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -30,7 +32,7 @@ outer: select { case <-timeout.Done(): cancel() - return assert.Fail(t, "couldn't get 3 RAs in time") + return assert.Fail(ct, "couldn't get 3 RAs in time") case ra := <-sock.txMulticastCh(): ras = append(ras, ra) if len(ras) == 3 { @@ -45,7 +47,7 @@ outer: diff0 := ras[1].tstamp.Sub(ras[0].tstamp) diff1 := ras[2].tstamp.Sub(ras[1].tstamp) - return assert.InDelta(t, interval, diff0, mergin) && assert.InDelta(t, interval, diff1, mergin) + return assert.InDelta(ct, interval, diff0, mergin) && assert.InDelta(ct, interval, diff1, mergin) } func TestDaemonHappyPath(t *testing.T) { @@ -105,7 +107,16 @@ func TestDaemonHappyPath(t *testing.T) { reg := newFakeSockRegistry() - d, err := NewDaemon(config, withSocketConstructor(reg.newSock)) + // Create a fake device watcher and inject an initial device state + devWatcher := newFakeDeviceWatcher("net0", "net1") + devWatcher.update("net0", deviceState{isUp: true, addr: net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}}) + devWatcher.update("net1", deviceState{isUp: true, addr: net.HardwareAddr{0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}}) + + d, err := NewDaemon( + config, + withSocketConstructor(reg.newSock), + withDeviceWatcher(devWatcher), + ) require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -129,11 +140,15 @@ func TestDaemonHappyPath(t *testing.T) { sock, err = reg.getSock("net0") require.NoError(t, err) - require.True(t, assertRAInterval(t, sock, time.Millisecond*100)) + require.EventuallyWithT(t, func(ct *assert.CollectT) { + assertRAInterval(ct, sock, time.Millisecond*100) + }, time.Second*1, time.Millisecond*100) sock, err = reg.getSock("net1") require.NoError(t, err) - require.True(t, assertRAInterval(t, sock, time.Millisecond*100)) + require.EventuallyWithT(t, func(ct *assert.CollectT) { + assertRAInterval(ct, sock, time.Millisecond*100) + }, time.Second*1, time.Millisecond*100) }) t.Run("Ensure the RA parameter is reflected to the packet", func(t *testing.T) { @@ -165,6 +180,17 @@ func TestDaemonHappyPath(t *testing.T) { require.NotNil(t, mtuOption, "MTU option is not advertised") require.Equal(t, uint32(1500), mtuOption.MTU, "Invalid MTU") + // Find and check Source Link-Layer Address option + var slaOption *ndp.LinkLayerAddress + for _, option := range ra.msg.Options { + if opt, ok := option.(*ndp.LinkLayerAddress); ok { + slaOption = opt + break + } + } + require.NotNil(t, slaOption, "Source Link-Layer Address option is not advertised") + require.Equal(t, net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}, slaOption.Addr) + // Find and check Prefix Information options prefixOptions := map[netip.Addr]*ndp.PrefixInformation{} for _, option := range ra.msg.Options { @@ -236,6 +262,32 @@ func TestDaemonHappyPath(t *testing.T) { assert.Equal(t, Running, status.Interfaces[1].State) }) + t.Run("Ensure Source Link Layer Address option is updated after device MAC address change", func(t *testing.T) { + // Update the MAC address of net0 + devWatcher.update("net0", deviceState{isUp: true, addr: net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x78}}) + + sock, err := reg.getSock("net0") + require.NoError(t, err) + + eventully(t, func() bool { + // Sampling one RA + ra := <-sock.txMulticastCh() + + // Find and check Source Link-Layer Address option + var slaOption *ndp.LinkLayerAddress + for _, option := range ra.msg.Options { + if opt, ok := option.(*ndp.LinkLayerAddress); ok { + slaOption = opt + break + } + } + + require.NotNil(t, slaOption, "Source Link-Layer Address option is not advertised") + + return slices.Equal(net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x78}, slaOption.Addr) + }) + }) + t.Run("Ensure unsolicited RA interval is updated after reload", func(t *testing.T) { // Update the interval of net1. net0 should remain the same. config.Interfaces[1].RAIntervalMilliseconds = 200 @@ -246,18 +298,18 @@ func TestDaemonHappyPath(t *testing.T) { require.NoError(t, err) cancelTimeout() - eventully(t, func() bool { + require.EventuallyWithT(t, func(ct *assert.CollectT) { sock0, err := reg.getSock("net0") if !assert.NoError(t, err) { - return false + return } sock1, err := reg.getSock("net1") if !assert.NoError(t, err) { - return false + return } - return assertRAInterval(t, sock0, time.Millisecond*100) && - assertRAInterval(t, sock1, time.Millisecond*200) - }) + assertRAInterval(ct, sock0, time.Millisecond*100) + assertRAInterval(ct, sock1, time.Millisecond*200) + }, time.Second*1, time.Millisecond*100) }) t.Run("Ensure RS is replied with unicast RA", func(t *testing.T) { @@ -292,17 +344,18 @@ func TestDaemonHappyPath(t *testing.T) { require.NoError(t, err) cancelTimeout() - eventully(t, func() bool { + require.EventuallyWithT(t, func(ct *assert.CollectT) { sock0, err := reg.getSock("net0") if !assert.NoError(t, err) { - return false + return } sock1, err := reg.getSock("net1") if !assert.NoError(t, err) { - return false + return } - return assertRAInterval(t, sock0, time.Millisecond*100) && assert.True(t, sock1.isClosed()) - }) + assertRAInterval(ct, sock0, time.Millisecond*100) + assert.True(ct, sock1.isClosed()) + }, time.Second*1, time.Millisecond*100) }) t.Run("Ensure unsolicited RA is stopped after stopping the daemon", func(t *testing.T) { diff --git a/device.go b/device.go new file mode 100644 index 0000000..a824dd6 --- /dev/null +++ b/device.go @@ -0,0 +1,63 @@ +package ra + +import ( + "context" + "net" + + "github.com/vishvananda/netlink" +) + +type deviceState struct { + isUp bool + addr net.HardwareAddr +} + +type deviceWatcher interface { + watch(ctx context.Context, name string) (<-chan deviceState, error) +} + +type netlinkDeviceWatcher struct{} + +var _ deviceWatcher = &netlinkDeviceWatcher{} + +func newDeviceWatcher() deviceWatcher { + return &netlinkDeviceWatcher{} +} + +func (w *netlinkDeviceWatcher) watch(ctx context.Context, name string) (<-chan deviceState, error) { + linkCh := make(chan netlink.LinkUpdate) + + if err := netlink.LinkSubscribeWithOptions( + linkCh, + ctx.Done(), + netlink.LinkSubscribeOptions{ + ErrorCallback: func(err error) {}, + ListExisting: true, + }, + ); err != nil { + return nil, err + } + + devCh := make(chan deviceState) + + go func() { + defer close(linkCh) + defer close(devCh) + for { + select { + case <-ctx.Done(): + return + case link := <-linkCh: + if link.Attrs().Name != name { + continue + } + devCh <- deviceState{ + isUp: link.Flags&uint32(net.FlagUp) != 0, + addr: link.Attrs().HardwareAddr, + } + } + } + }() + + return devCh, nil +} diff --git a/fake_device.go b/fake_device.go new file mode 100644 index 0000000..0484cad --- /dev/null +++ b/fake_device.go @@ -0,0 +1,41 @@ +package ra + +import "context" + +type fakeDeviceWatcher struct { + watchers map[string]chan deviceState +} + +var _ deviceWatcher = &fakeDeviceWatcher{} + +func newFakeDeviceWatcher(devs ...string) *fakeDeviceWatcher { + fdw := &fakeDeviceWatcher{ + watchers: make(map[string]chan deviceState), + } + for _, dev := range devs { + fdw.watchers[dev] = make(chan deviceState, 1) + } + return fdw +} + +func (w *fakeDeviceWatcher) watch(ctx context.Context, name string) (<-chan deviceState, error) { + devCh := make(chan deviceState) + + go func() { + defer close(devCh) + for { + select { + case <-ctx.Done(): + return + case dev := <-w.watchers[name]: + devCh <- dev + } + } + }() + + return devCh, nil +} + +func (w *fakeDeviceWatcher) update(name string, dev deviceState) { + w.watchers[name] <- dev +}