Skip to content

Commit

Permalink
note
Browse files Browse the repository at this point in the history
  • Loading branch information
lukedirtwalker committed Dec 23, 2024
1 parent cd48ed8 commit 8bbf361
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 12 deletions.
45 changes: 33 additions & 12 deletions pkg/daemon/topology.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,29 @@ type ReloadingTopology struct {

// NewReloadingTopology creates a new ReloadingTopology that reloads the
// interface information periodically. The Run method must be called for
// interface information to be populated.
// interface information to be populated. NOTE: The reloading topology does not
// clean up old interface information, so if you have a lot of interface churn,
// you may want to use a different implementation.
func NewReloadingTopology(ctx context.Context, conn Connector) (*ReloadingTopology, error) {
topo, err := LoadTopology(ctx, conn)
ia, err := conn.LocalIA(ctx)
if err != nil {
return nil, serrors.Wrap("loading local ISD-AS", err)
}
start, end, err := conn.PortRange(ctx)
if err != nil {
return nil, serrors.Wrap("loading port range", err)
}
t := &ReloadingTopology{
conn: conn,
baseTopology: snet.Topology{
LocalIA: ia,
PortRange: snet.TopologyPortRange{Start: start, End: end},
},
}
if err := t.loadInterfaces(ctx); err != nil {
return nil, err
}
return &ReloadingTopology{
conn: conn,
baseTopology: topo,
}, nil
return t, nil
}

func (t *ReloadingTopology) Topology() snet.Topology {
Expand All @@ -96,15 +109,12 @@ func (t *ReloadingTopology) Run(ctx context.Context, period time.Duration) {
defer ticker.Stop()

reload := func() {
intfs, err := t.conn.Interfaces(ctx)
if err != nil {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
if err := t.loadInterfaces(ctx); err != nil {
log.FromCtx(ctx).Error("Failed to reload interfaces", "err", err)
}
for ifID, addr := range intfs {
t.interfaces.Store(ifID, addr)
}
}

reload()
for {
select {
Expand All @@ -115,3 +125,14 @@ func (t *ReloadingTopology) Run(ctx context.Context, period time.Duration) {
}
}
}

func (t *ReloadingTopology) loadInterfaces(ctx context.Context) error {
intfs, err := t.conn.Interfaces(ctx)
if err != nil {
return err
}
for ifID, addr := range intfs {
t.interfaces.Store(ifID, addr)
}
return nil
}
107 changes: 107 additions & 0 deletions pkg/daemon/topology_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package daemon_test

import (
"context"
"net/netip"
"testing"
"time"

"github.com/golang/mock/gomock"
"github.com/scionproto/scion/pkg/addr"
"github.com/scionproto/scion/pkg/daemon"
"github.com/scionproto/scion/pkg/daemon/mock_daemon"
"github.com/scionproto/scion/pkg/snet"
"github.com/stretchr/testify/assert"
)

func TestLoadTopology(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
conn := mock_daemon.NewMockConnector(ctrl)
wantTopo := testTopology{
ia: addr.MustParseIA("1-ff00:0:110"),
start: uint16(4096),
end: uint16(8192),
interfaces: map[uint16]netip.AddrPort{
1: netip.MustParseAddrPort("10.0.0.1:5153"),
2: netip.MustParseAddrPort("10.0.0.2:6421"),
},
}
wantTopo.setupMockResponses(conn)

topo, err := daemon.LoadTopology(context.Background(), conn)
assert.NoError(t, err)
wantTopo.checkTopology(t, topo)
}

func TestReloadingTopology(t *testing.T) {
ctrl := gomock.NewController(t)
conn := mock_daemon.NewMockConnector(ctrl)

wantTopo := testTopology{
ia: addr.MustParseIA("1-ff00:0:110"),
start: uint16(4096),
end: uint16(8192),
interfaces: map[uint16]netip.AddrPort{
1: netip.MustParseAddrPort("10.0.0.1:5153"),
2: netip.MustParseAddrPort("10.0.0.2:6421"),
},
}
interfacesLater := map[uint16]netip.AddrPort{
2: netip.MustParseAddrPort("10.0.0.2:6421"),
3: netip.MustParseAddrPort("10.0.0.3:7539"),
}
calls := wantTopo.setupMockResponses(conn)
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
gomock.InOrder(
append(calls,
conn.EXPECT().Interfaces(gomock.Any()).DoAndReturn(func(context.Context) (map[uint16]netip.AddrPort, error) {
cancel()
return interfacesLater, nil
}).AnyTimes(),
)...,
)

loader, err := daemon.NewReloadingTopology(ctx, conn)
assert.NoError(t, err)
topo := loader.Topology()
wantTopo.checkTopology(t, topo)

go func() {
loader.Run(ctx, 100*time.Millisecond)
close(done)
}()
<-done
wantTopo.interfaces = interfacesLater
wantTopo.checkTopology(t, loader.Topology())
}

type testTopology struct {
ia addr.IA
start uint16
end uint16
interfaces map[uint16]netip.AddrPort
}

func (tt testTopology) setupMockResponses(c *mock_daemon.MockConnector) []*gomock.Call {
return []*gomock.Call{
c.EXPECT().LocalIA(gomock.Any()).Return(tt.ia, nil),
c.EXPECT().PortRange(gomock.Any()).Return(tt.start, tt.end, nil),
c.EXPECT().Interfaces(gomock.Any()).Return(tt.interfaces, nil),
}
}

func (tt testTopology) checkTopology(t *testing.T, topo snet.Topology) {
t.Helper()

assert.Equal(t, tt.ia, topo.LocalIA)
assert.Equal(t, tt.start, topo.PortRange.Start)
assert.Equal(t, tt.end, topo.PortRange.End)
for ifID, want := range tt.interfaces {
got, ok := topo.Interface(ifID)
assert.True(t, ok, "interface %d", ifID)
assert.Equal(t, want, got, "interface %d", ifID)
}
}

0 comments on commit 8bbf361

Please sign in to comment.