diff --git a/advertiser.go b/advertiser.go index 049bc66..61cf7fc 100644 --- a/advertiser.go +++ b/advertiser.go @@ -10,11 +10,11 @@ import ( "log/slog" "net/netip" "reflect" + "slices" "sync" "time" "github.com/mdlayher/ndp" - "github.com/sethvargo/go-retry" "golang.org/x/sys/unix" ) @@ -28,10 +28,10 @@ type advertiser struct { ifaceStatus *InterfaceStatus ifaceStatusLock sync.RWMutex - reloadCh chan *InterfaceConfig - stopCh chan any - sock socket - socketCtor socketCtor + reloadCh chan *InterfaceConfig + stopCh chan any + socketCtor socketCtor + deviceWatcher deviceWatcher } // An internal structure to represent RS @@ -40,7 +40,7 @@ type rsMsg struct { from netip.Addr } -func newAdvertiser(initialConfig *InterfaceConfig, ctor socketCtor, logger *slog.Logger) *advertiser { +func newAdvertiser(initialConfig *InterfaceConfig, ctor socketCtor, devWatcher deviceWatcher, logger *slog.Logger) *advertiser { return &advertiser{ logger: logger.With(slog.String("interface", initialConfig.Name)), initialConfig: initialConfig, @@ -48,10 +48,11 @@ func newAdvertiser(initialConfig *InterfaceConfig, ctor socketCtor, logger *slog reloadCh: make(chan *InterfaceConfig), stopCh: make(chan any), socketCtor: ctor, + deviceWatcher: devWatcher, } } -func (s *advertiser) createRAMsg(config *InterfaceConfig) *ndp.RouterAdvertisement { +func (s *advertiser) createRAMsg(config *InterfaceConfig, deviceState *deviceState) *ndp.RouterAdvertisement { return &ndp.RouterAdvertisement{ CurrentHopLimit: uint8(config.CurrentHopLimit), ManagedConfiguration: config.Managed, @@ -60,15 +61,15 @@ func (s *advertiser) createRAMsg(config *InterfaceConfig) *ndp.RouterAdvertiseme RouterLifetime: time.Duration(config.RouterLifetimeSeconds) * time.Second, ReachableTime: time.Duration(config.ReachableTimeMilliseconds) * time.Millisecond, RetransmitTimer: time.Duration(config.RetransmitTimeMilliseconds) * time.Millisecond, - Options: s.createOptions(config), + Options: s.createOptions(config, deviceState), } } 
-func (s *advertiser) createOptions(config *InterfaceConfig) []ndp.Option { +func (s *advertiser) createOptions(config *InterfaceConfig, deviceState *deviceState) []ndp.Option { options := []ndp.Option{ &ndp.LinkLayerAddress{ Direction: ndp.Source, - Addr: s.sock.hardwareAddr(), + Addr: deviceState.addr, }, } @@ -197,38 +198,60 @@ func (s *advertiser) run(ctx context.Context) { // The current desired configuration config := s.initialConfig + // The current device state + devState := deviceState{} + // Set a timestamp for the first "update" s.setLastUpdate() - // Create the socket - err := retry.Constant(ctx, time.Second, func(ctx context.Context) error { - var err error - - s.sock, err = s.socketCtor(config.Name) - if err != nil { - // These are the unrecoverable errors we're aware of now. - if errors.Is(err, unix.EPERM) || errors.Is(err, unix.EINVAL) { - return fmt.Errorf("cannot create socket: %w", err) - } - - s.reportFailing(err) + // Watch the device state + devCh, err := s.deviceWatcher.watch(ctx, config.Name) + if err != nil { + s.reportStopped(err) + return + } - return retry.RetryableError(err) +waitDevice: + // Wait for the device to be present and up + for { + select { + case <-ctx.Done(): + s.reportStopped(ctx.Err()) + return + case dev := <-devCh: + // Update the device state + devState = dev + + // If the device is up, we can proceed with the socket creation + if dev.isUp { + break waitDevice + } } + } - return nil - }) + // Create the socket + sock, err := s.socketCtor(config.Name) if err != nil { - s.reportStopped(err) - return + // These are the unrecoverable errors we're aware of now. 
+ if errors.Is(err, unix.EPERM) || errors.Is(err, unix.EINVAL) { + s.reportStopped(fmt.Errorf("cannot create socket: %w", err)) + return + } + // Otherwise, we'll retry + s.reportFailing(err) + goto waitDevice } // Launch the RS receiver rsCh := make(chan *rsMsg) + receiverCtx, cancelReceiver := context.WithCancel(ctx) go func() { for { - rs, addr, err := s.sock.recvRS(ctx) + rs, addr, err := sock.recvRS(receiverCtx) if err != nil { + if receiverCtx.Err() != nil { + return + } s.reportFailing(err) continue } @@ -241,7 +264,7 @@ func (s *advertiser) run(ctx context.Context) { reload: for { // RA message - msg := s.createRAMsg(config) + msg := s.createRAMsg(config, &devState) // For unsolicited RA ticker := time.NewTicker(time.Duration(config.RAIntervalMilliseconds) * time.Millisecond) @@ -252,7 +275,7 @@ reload: // Reply to RS // // TODO: Rate limit this to mitigate RS flooding attack - err := s.sock.sendRA(ctx, rs.from, msg) + err := sock.sendRA(ctx, rs.from, msg) if err != nil { s.reportFailing(err) continue @@ -261,7 +284,7 @@ reload: s.reportRunning() case <-ticker.C: // Send unsolicited RA - err := s.sock.sendRA(ctx, netip.IPv6LinkLocalAllNodes(), msg) + err := sock.sendRA(ctx, netip.IPv6LinkLocalAllNodes(), msg) if err != nil { s.reportFailing(err) continue @@ -277,6 +300,28 @@ reload: s.reportReloading() s.setLastUpdate() continue reload + case dev := <-devCh: + // Save the old address for comparison + oldAddr := devState.addr + + // Update the device state + devState = dev + + // Device is stopped. Stop the advertisement + // and wait for the device to be up again. + if !devState.isUp { + cancelReceiver() + s.reportFailing(fmt.Errorf("device is down")) + goto waitDevice + } + + // Device address has changed. We need to + // change the Link Layer Address option in the + // RA message. Reload internally. 
+ if !slices.Equal(oldAddr, dev.addr) { + s.reportReloading() + continue reload + } case <-ctx.Done(): s.reportStopped(ctx.Err()) break reload @@ -288,7 +333,8 @@ reload: } - s.sock.close() + cancelReceiver() + sock.close() } func (s *advertiser) status() *InterfaceStatus { diff --git a/daemon.go b/daemon.go index ef24dce..141d56c 100644 --- a/daemon.go +++ b/daemon.go @@ -17,6 +17,7 @@ type Daemon struct { reloadCh chan *Config logger *slog.Logger socketConstructor socketCtor + deviceWatcher deviceWatcher advertisers map[string]*advertiser advertisersLock sync.RWMutex @@ -39,6 +40,7 @@ func NewDaemon(config *Config, opts ...DaemonOption) (*Daemon, error) { reloadCh: make(chan *Config), logger: slog.Default(), socketConstructor: newSocket, + deviceWatcher: newDeviceWatcher(), advertisers: map[string]*advertiser{}, } @@ -89,9 +91,9 @@ reload: // Add new per-interface jobs for _, c := range toAdd { d.logger.Info("Adding new RA sender", slog.String("interface", c.Name)) - sender := newAdvertiser(c, d.socketConstructor, d.logger) - go sender.run(ctx) - d.advertisers[c.Name] = sender + advertiser := newAdvertiser(c, d.socketConstructor, d.deviceWatcher, d.logger) + go advertiser.run(ctx) + d.advertisers[c.Name] = advertiser } // Update (reload) existing workers @@ -188,3 +190,11 @@ func withSocketConstructor(c socketCtor) DaemonOption { d.socketConstructor = c } } + +// withDeviceWatcher overrides the default device watcher with the provided +// one. For testing purposes only. 
+func withDeviceWatcher(w deviceWatcher) DaemonOption { + return func(d *Daemon) { + d.deviceWatcher = w + } +} diff --git a/daemon_test.go b/daemon_test.go index 10bfb8c..3eb9106 100644 --- a/daemon_test.go +++ b/daemon_test.go @@ -5,7 +5,9 @@ package ra import ( "context" + "net" "net/netip" + "slices" "testing" "time" @@ -20,7 +22,7 @@ func eventully(t *testing.T, f func() bool) { require.Eventually(t, f, time.Second*1, time.Millisecond*10) } -func assertRAInterval(t *testing.T, sock *fakeSock, interval time.Duration) bool { +func assertRAInterval(ct *assert.CollectT, sock *fakeSock, interval time.Duration) bool { // wait until we get 3 RAs timeout, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -30,7 +32,7 @@ outer: select { case <-timeout.Done(): cancel() - return assert.Fail(t, "couldn't get 3 RAs in time") + return assert.Fail(ct, "couldn't get 3 RAs in time") case ra := <-sock.txMulticastCh(): ras = append(ras, ra) if len(ras) == 3 { @@ -45,7 +47,7 @@ outer: diff0 := ras[1].tstamp.Sub(ras[0].tstamp) diff1 := ras[2].tstamp.Sub(ras[1].tstamp) - return assert.InDelta(t, interval, diff0, mergin) && assert.InDelta(t, interval, diff1, mergin) + return assert.InDelta(ct, interval, diff0, mergin) && assert.InDelta(ct, interval, diff1, mergin) } func TestDaemonHappyPath(t *testing.T) { @@ -105,7 +107,16 @@ func TestDaemonHappyPath(t *testing.T) { reg := newFakeSockRegistry() - d, err := NewDaemon(config, withSocketConstructor(reg.newSock)) + // Create a fake device watcher and inject an initial device state + devWatcher := newFakeDeviceWatcher("net0", "net1") + devWatcher.update("net0", deviceState{isUp: true, addr: net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}}) + devWatcher.update("net1", deviceState{isUp: true, addr: net.HardwareAddr{0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}}) + + d, err := NewDaemon( + config, + withSocketConstructor(reg.newSock), + withDeviceWatcher(devWatcher), + ) require.NoError(t, err) ctx, cancel := 
context.WithCancel(context.Background()) @@ -129,11 +140,15 @@ func TestDaemonHappyPath(t *testing.T) { sock, err = reg.getSock("net0") require.NoError(t, err) - require.True(t, assertRAInterval(t, sock, time.Millisecond*100)) + require.EventuallyWithT(t, func(ct *assert.CollectT) { + assertRAInterval(ct, sock, time.Millisecond*100) + }, time.Second*1, time.Millisecond*100) sock, err = reg.getSock("net1") require.NoError(t, err) - require.True(t, assertRAInterval(t, sock, time.Millisecond*100)) + require.EventuallyWithT(t, func(ct *assert.CollectT) { + assertRAInterval(ct, sock, time.Millisecond*100) + }, time.Second*1, time.Millisecond*100) }) t.Run("Ensure the RA parameter is reflected to the packet", func(t *testing.T) { @@ -165,6 +180,17 @@ func TestDaemonHappyPath(t *testing.T) { require.NotNil(t, mtuOption, "MTU option is not advertised") require.Equal(t, uint32(1500), mtuOption.MTU, "Invalid MTU") + // Find and check Source Link-Layer Address option + var slaOption *ndp.LinkLayerAddress + for _, option := range ra.msg.Options { + if opt, ok := option.(*ndp.LinkLayerAddress); ok { + slaOption = opt + break + } + } + require.NotNil(t, slaOption, "Source Link-Layer Address option is not advertised") + require.Equal(t, net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}, slaOption.Addr) + // Find and check Prefix Information options prefixOptions := map[netip.Addr]*ndp.PrefixInformation{} for _, option := range ra.msg.Options { @@ -236,6 +262,32 @@ func TestDaemonHappyPath(t *testing.T) { assert.Equal(t, Running, status.Interfaces[1].State) }) + t.Run("Ensure Source Link Layer Address option is updated after device MAC address change", func(t *testing.T) { + // Update the MAC address of net0 + devWatcher.update("net0", deviceState{isUp: true, addr: net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x78}}) + + sock, err := reg.getSock("net0") + require.NoError(t, err) + + eventully(t, func() bool { + // Sampling one RA + ra := <-sock.txMulticastCh() + + 
// Find and check Source Link-Layer Address option + var slaOption *ndp.LinkLayerAddress + for _, option := range ra.msg.Options { + if opt, ok := option.(*ndp.LinkLayerAddress); ok { + slaOption = opt + break + } + } + + require.NotNil(t, slaOption, "Source Link-Layer Address option is not advertised") + + return slices.Equal(net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x78}, slaOption.Addr) + }) + }) + t.Run("Ensure unsolicited RA interval is updated after reload", func(t *testing.T) { // Update the interval of net1. net0 should remain the same. config.Interfaces[1].RAIntervalMilliseconds = 200 @@ -246,18 +298,18 @@ func TestDaemonHappyPath(t *testing.T) { require.NoError(t, err) cancelTimeout() - eventully(t, func() bool { + require.EventuallyWithT(t, func(ct *assert.CollectT) { sock0, err := reg.getSock("net0") if !assert.NoError(t, err) { - return false + return } sock1, err := reg.getSock("net1") if !assert.NoError(t, err) { - return false + return } - return assertRAInterval(t, sock0, time.Millisecond*100) && - assertRAInterval(t, sock1, time.Millisecond*200) - }) + assertRAInterval(ct, sock0, time.Millisecond*100) + assertRAInterval(ct, sock1, time.Millisecond*200) + }, time.Second*1, time.Millisecond*100) }) t.Run("Ensure RS is replied with unicast RA", func(t *testing.T) { @@ -292,17 +344,18 @@ func TestDaemonHappyPath(t *testing.T) { require.NoError(t, err) cancelTimeout() - eventully(t, func() bool { + require.EventuallyWithT(t, func(ct *assert.CollectT) { sock0, err := reg.getSock("net0") if !assert.NoError(t, err) { - return false + return } sock1, err := reg.getSock("net1") if !assert.NoError(t, err) { - return false + return } - return assertRAInterval(t, sock0, time.Millisecond*100) && assert.True(t, sock1.isClosed()) - }) + assertRAInterval(ct, sock0, time.Millisecond*100) + assert.True(ct, sock1.isClosed()) + }, time.Second*1, time.Millisecond*100) }) t.Run("Ensure unsolicited RA is stopped after stopping the daemon", func(t *testing.T) 
{
diff --git a/device.go b/device.go
new file mode 100644
index 0000000..a824dd6
--- /dev/null
+++ b/device.go
@@ -0,0 +1,104 @@
+package ra
+
+import (
+	"context"
+	"net"
+
+	"github.com/vishvananda/netlink"
+)
+
+// deviceState is a snapshot of the link properties the advertiser reacts to.
+type deviceState struct {
+	// isUp reports whether the link has the "up" flag set.
+	isUp bool
+	// addr is the link-layer (hardware) address of the device.
+	addr net.HardwareAddr
+}
+
+// deviceWatcher is the seam between the advertiser and the platform's link
+// notification mechanism, so that tests can substitute a fake.
+type deviceWatcher interface {
+	// watch reports state updates for the device called name on the
+	// returned channel until ctx is cancelled, at which point the channel
+	// is closed.
+	watch(ctx context.Context, name string) (<-chan deviceState, error)
+}
+
+// netlinkDeviceWatcher implements deviceWatcher on top of a netlink link
+// subscription.
+type netlinkDeviceWatcher struct{}
+
+var _ deviceWatcher = &netlinkDeviceWatcher{}
+
+// newDeviceWatcher returns the netlink-backed deviceWatcher.
+func newDeviceWatcher() deviceWatcher {
+	return &netlinkDeviceWatcher{}
+}
+
+// watch subscribes to netlink link updates (including a dump of the links
+// that already exist) and forwards the ones for the device called name.
+// The returned channel is closed once ctx is cancelled.
+func (w *netlinkDeviceWatcher) watch(ctx context.Context, name string) (<-chan deviceState, error) {
+	linkCh := make(chan netlink.LinkUpdate)
+
+	if err := netlink.LinkSubscribeWithOptions(
+		linkCh,
+		ctx.Done(),
+		netlink.LinkSubscribeOptions{
+			ErrorCallback: func(err error) {},
+			ListExisting:  true,
+		},
+	); err != nil {
+		return nil, err
+	}
+
+	devCh := make(chan deviceState)
+
+	go func() {
+		// linkCh is closed by the netlink subscription itself, so it must
+		// not be closed here: a second close (or a send racing with our
+		// close) panics. Instead, keep receiving after we stop forwarding
+		// so that the subscription is never stuck on a send and can shut
+		// down.
+		defer func() {
+			close(devCh)
+			for range linkCh {
+			}
+		}()
+
+		// updates is set to nil once the subscription has ended so that
+		// the select below only waits for ctx.
+		var updates <-chan netlink.LinkUpdate = linkCh
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case link, ok := <-updates:
+				if !ok {
+					// The subscription ended before ctx was
+					// cancelled. A receive on the closed channel
+					// yields a LinkUpdate without a Link, which
+					// must not be dereferenced.
+					updates = nil
+					continue
+				}
+				attrs := link.Attrs()
+				if attrs.Name != name {
+					continue
+				}
+				state := deviceState{
+					isUp: attrs.Flags&net.FlagUp != 0,
+					addr: attrs.HardwareAddr,
+				}
+				// Do not block forever if the consumer is already gone.
+				select {
+				case devCh <- state:
+				case <-ctx.Done():
+					return
+				}
+			}
+		}
+	}()
+
+	return devCh, nil
+}
diff --git a/fake_device.go b/fake_device.go
new file mode 100644
index 0000000..0484cad
--- /dev/null
+++ b/fake_device.go
@@ -0,0 +1,66 @@
+package ra
+
+import "context"
+
+// fakeDeviceWatcher is an in-memory deviceWatcher for tests. Device states
+// are injected with update and delivered to whoever called watch.
+type fakeDeviceWatcher struct {
+	// watchers maps a device name to the channel its injected states are
+	// queued on. It is populated once by newFakeDeviceWatcher and only
+	// read afterwards.
+	watchers map[string]chan deviceState
+}
+
+var _ deviceWatcher = &fakeDeviceWatcher{}
+
+// newFakeDeviceWatcher returns a fake watcher that knows the devices devs.
+// Each device can hold one pending state that has not been picked up yet.
+func newFakeDeviceWatcher(devs ...string) *fakeDeviceWatcher {
+	fdw := &fakeDeviceWatcher{
+		watchers: make(map[string]chan deviceState, len(devs)),
+	}
+	for _, dev := range devs {
+		fdw.watchers[dev] = make(chan deviceState, 1)
+	}
+	return fdw
+}
+
+// watch forwards the states injected for the device called name until ctx
+// is cancelled, at which point the returned channel is closed. Watching a
+// device that was not registered yields a channel that never delivers.
+func (w *fakeDeviceWatcher) watch(ctx context.Context, name string) (<-chan deviceState, error) {
+	// A missing entry gives a nil channel, which blocks forever in select.
+	src := w.watchers[name]
+	devCh := make(chan deviceState)
+
+	go func() {
+		defer close(devCh)
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case dev := <-src:
+				// Do not block forever if the consumer is already gone.
+				select {
+				case devCh <- dev:
+				case <-ctx.Done():
+					return
+				}
+			}
+		}
+	}()
+
+	return devCh, nil
+}
+
+// update injects a new state for the device called name. It blocks while a
+// previously injected state is still pending. It panics if name was not
+// registered, because sending on the missing (nil) channel would otherwise
+// hang the calling test forever.
+func (w *fakeDeviceWatcher) update(name string, dev deviceState) {
+	ch, ok := w.watchers[name]
+	if !ok {
+		panic("fakeDeviceWatcher: update for unregistered device " + name)
+	}
+	ch <- dev
+}