diff --git a/pkg/daemon/BUILD.bazel b/pkg/daemon/BUILD.bazel index cd87902ef5..7aaf7e012a 100644 --- a/pkg/daemon/BUILD.bazel +++ b/pkg/daemon/BUILD.bazel @@ -1,4 +1,4 @@ -load("//tools/lint:go.bzl", "go_library") +load("//tools/lint:go.bzl", "go_library", "go_test") go_library( name = "go_default_library", @@ -33,3 +33,16 @@ go_library( "@org_golang_google_protobuf//types/known/timestamppb:go_default_library", ], ) + +go_test( + name = "go_default_test", + srcs = ["topology_test.go"], + deps = [ + ":go_default_library", + "//pkg/addr:go_default_library", + "//pkg/daemon/mock_daemon:go_default_library", + "//pkg/snet:go_default_library", + "@com_github_golang_mock//gomock:go_default_library", + "@com_github_stretchr_testify//assert:go_default_library", + ], +) diff --git a/pkg/daemon/topology.go b/pkg/daemon/topology.go index 44def41eef..b1a76b6e96 100644 --- a/pkg/daemon/topology.go +++ b/pkg/daemon/topology.go @@ -17,7 +17,7 @@ package daemon import ( "context" "net/netip" - "sync" + "sync/atomic" "time" "github.com/scionproto/scion/pkg/log" @@ -59,14 +59,12 @@ func LoadTopology(ctx context.Context, conn Connector) (snet.Topology, error) { type ReloadingTopology struct { conn Connector baseTopology snet.Topology - interfaces sync.Map + interfaces atomic.Pointer[map[uint16]netip.AddrPort] } // NewReloadingTopology creates a new ReloadingTopology that reloads the // interface information periodically. The Run method must be called for -// interface information to be populated. NOTE: The reloading topology does not -// clean up old interface information, so if you have a lot of interface churn, -// you may want to use a different implementation. +// interface information to be populated. func NewReloadingTopology(ctx context.Context, conn Connector) (*ReloadingTopology, error) { ia, err := conn.LocalIA(ctx) if err != nil { @@ -95,11 +93,12 @@ func (t *ReloadingTopology) Topology() snet.Topology { LocalIA: base.LocalIA, PortRange: base.PortRange, Interface: func(ifID uint16) (netip.AddrPort, bool) { - a, ok := t.interfaces.Load(ifID) - if !ok { + m := t.interfaces.Load() + if m == nil { return netip.AddrPort{}, false } - return a.(netip.AddrPort), true + a, ok := (*m)[ifID] + return a, ok }, } } @@ -131,8 +130,6 @@ func (t *ReloadingTopology) loadInterfaces(ctx context.Context) error { if err != nil { return err } - for ifID, addr := range intfs { - t.interfaces.Store(ifID, addr) - } + t.interfaces.Store(&intfs) return nil } diff --git a/pkg/daemon/topology_test.go b/pkg/daemon/topology_test.go index fa46d6662b..05ce5988c5 100644 --- a/pkg/daemon/topology_test.go +++ b/pkg/daemon/topology_test.go @@ -76,6 +76,8 @@ func TestReloadingTopology(t *testing.T) { <-done wantTopo.interfaces = interfacesLater wantTopo.checkTopology(t, loader.Topology()) + _, ok := loader.Topology().Interface(1) + assert.False(t, ok) } type testTopology struct {