diff --git a/rib/rib.go b/rib/rib.go index 371285a3..4f094772 100644 --- a/rib/rib.go +++ b/rib/rib.go @@ -604,8 +604,12 @@ func (r *RIB) DeleteEntry(ni string, op *spb.AFTOperation) ([]*OpResult, []*OpRe return referencingRIB, nil } - // TODO(robjs): currently, the post-change hook is not called for deletes. Add - // support for calling this hook after delete. + var ( + callHook bool + aft constants.AFT + key any + ) + switch { case err != nil: fails = append(fails, &OpResult{ @@ -622,6 +626,9 @@ func (r *RIB) DeleteEntry(ni string, op *spb.AFTOperation) ([]*OpResult, []*OpRe return nil, nil, err } referencingRIB.decNHGRefCount(originalv4.GetNextHopGroup()) + callHook = true + aft = constants.IPv4 + key = originalv4.GetPrefix() case originalNHG != nil: for id := range originalNHG.NextHop { niR.decNHRefCount(id) @@ -632,6 +639,9 @@ func (r *RIB) DeleteEntry(ni string, op *spb.AFTOperation) ([]*OpResult, []*OpRe return nil, nil, err } referencingRIB.decNHGRefCount(originalMPLS.GetNextHopGroup()) + callHook = true + aft = constants.MPLS + key = originalMPLS.GetLabel() } log.V(2).Infof("operation %d deleted from RIB successfully", op.GetId()) @@ -645,6 +655,12 @@ func (r *RIB) DeleteEntry(ni string, op *spb.AFTOperation) ([]*OpResult, []*OpRe Op: op, }) } + + if callHook { + if err := r.callResolvedEntryHook(constants.Delete, ni, aft, key); err != nil { + return oks, fails, fmt.Errorf("cannot run resolvedEntryHook, %v", err) + } + } return oks, fails, nil } diff --git a/rib/rib_test.go b/rib/rib_test.go index f6800c42..636c10af 100644 --- a/rib/rib_test.go +++ b/rib/rib_test.go @@ -3189,6 +3189,30 @@ func TestResolvedEntryHook(t *testing.T) { } gotCh := make(chan interface{}) + + stringHook := func(_ map[string]*aft.RIB, op constants.OpType, netinst string, aft constants.AFT, prefix any, _ ...ResolvedDetails) { + s := fmt.Sprintf("%s %s:%v->%v", op, aft, netinst, prefix) + t.Logf("writing to channel: %s", s) + gotCh <- s + } + + stringChecker := func(s string) func() error { + return func() error { + got := <-gotCh + switch t := got.(type) { + case error: + return t + case string: + if t != s { + return fmt.Errorf("got incorrect string, got: %s, want: %s", got, s) + } + return nil + default: + return fmt.Errorf("got unexpected type, got: %T, want: string", got) + } + } + } + tests := []struct { desc string inRIB *RIB @@ -3303,6 +3327,198 @@ func TestResolvedEntryHook(t *testing.T) { } return nil }, + }, { + desc: "delete ipv4 entry", + inRIB: func() *RIB { + r := baseRIB() + _, _, err := r.AddEntry(defName, &spb.AFTOperation{ + Entry: &spb.AFTOperation_Ipv4{ + Ipv4: &aftpb.Afts_Ipv4EntryKey{ + Prefix: "10.0.0.0/8", + Ipv4Entry: &aftpb.Afts_Ipv4Entry{ + NextHopGroup: &wpb.UintValue{Value: 1}, + }, + }, + }, + }) + if err != nil { + t.Fatalf("cannot build complete RIB, error: %v", err) + } + return r + }(), + inOperation: &spb.AFTOperation{ + Id: 42, + Op: spb.AFTOperation_DELETE, + Entry: &spb.AFTOperation_Ipv4{ + Ipv4: &aftpb.Afts_Ipv4EntryKey{ + Prefix: "10.0.0.0/8", + }, + }, + }, + inHook: stringHook, + checkFn: stringChecker("Delete IPv4:DEFAULT->10.0.0.0/8"), + }, { + desc: "delete mpls entry", + inRIB: func() *RIB { + r := baseRIB() + _, _, err := r.AddEntry(defName, &spb.AFTOperation{ + Entry: &spb.AFTOperation_Mpls{ + Mpls: &aftpb.Afts_LabelEntryKey{ + Label: &aftpb.Afts_LabelEntryKey_LabelUint64{ + LabelUint64: 42, + }, + LabelEntry: &aftpb.Afts_LabelEntry{ + NextHopGroup: &wpb.UintValue{Value: 1}, + }, + }, + }}) + if err != nil { + t.Fatalf("cannot build complete RIB, error: %v", err) + } + return r + }(), + inOperation: &spb.AFTOperation{ + Id: 42, + Op: spb.AFTOperation_DELETE, + Entry: &spb.AFTOperation_Mpls{ + Mpls: &aftpb.Afts_LabelEntryKey{ + Label: &aftpb.Afts_LabelEntryKey_LabelUint64{ + LabelUint64: 42, + }, + }, + }, + }, + inHook: stringHook, + checkFn: stringChecker("Delete MPLS:DEFAULT->42"), + }, { + desc: "replace ipv4 entry", + inRIB: func() *RIB { + r := baseRIB() + + ops := []*spb.AFTOperation{{ + Entry: &spb.AFTOperation_NextHop{ + NextHop: &aftpb.Afts_NextHopKey{ + Index: 2, + NextHop: &aftpb.Afts_NextHop{ + IpAddress: &wpb.StringValue{Value: "2.2.2.2"}, + }, + }, + }, + }, { + Entry: &spb.AFTOperation_NextHopGroup{ + NextHopGroup: &aftpb.Afts_NextHopGroupKey{ + Id: 2, + NextHopGroup: &aftpb.Afts_NextHopGroup{ + NextHop: []*aftpb.Afts_NextHopGroup_NextHopKey{{ + Index: 2, + NextHop: &aftpb.Afts_NextHopGroup_NextHop{ + Weight: &wpb.UintValue{Value: 32}, + }, + }}, + }, + }, + }, + }, { + Op: spb.AFTOperation_ADD, + Entry: &spb.AFTOperation_Ipv4{ + Ipv4: &aftpb.Afts_Ipv4EntryKey{ + Prefix: "10.0.0.0/8", + Ipv4Entry: &aftpb.Afts_Ipv4Entry{ + NextHopGroup: &wpb.UintValue{Value: 1}, + }, + }, + }, + }} + for i, op := range ops { + op.Id = uint64(i) + op.Op = spb.AFTOperation_ADD + + if _, _, err := r.AddEntry(defName, op); err != nil { + t.Fatalf("cannot add entry %s, %v", prototext.Format(op), err) + } + } + return r + }(), + inOperation: &spb.AFTOperation{ + Id: 42, + Op: spb.AFTOperation_REPLACE, + Entry: &spb.AFTOperation_Ipv4{ + Ipv4: &aftpb.Afts_Ipv4EntryKey{ + Prefix: "10.0.0.0/8", + Ipv4Entry: &aftpb.Afts_Ipv4Entry{ + NextHopGroup: &wpb.UintValue{Value: 2}, + }, + }, + }, + }, + inHook: stringHook, + checkFn: stringChecker("Add IPv4:DEFAULT->10.0.0.0/8"), + }, { + desc: "replace mpls entry", + inRIB: func() *RIB { + r := baseRIB() + + ops := []*spb.AFTOperation{{ + Entry: &spb.AFTOperation_NextHop{ + NextHop: &aftpb.Afts_NextHopKey{ + Index: 2, + NextHop: &aftpb.Afts_NextHop{ + IpAddress: &wpb.StringValue{Value: "2.2.2.2"}, + }, + }, + }, + }, { + Entry: &spb.AFTOperation_NextHopGroup{ + NextHopGroup: &aftpb.Afts_NextHopGroupKey{ + Id: 2, + NextHopGroup: &aftpb.Afts_NextHopGroup{ + NextHop: []*aftpb.Afts_NextHopGroup_NextHopKey{{ + Index: 2, + NextHop: &aftpb.Afts_NextHopGroup_NextHop{ + Weight: &wpb.UintValue{Value: 32}, + }, + }}, + }, + }, + }, + }, { + Entry: &spb.AFTOperation_Mpls{ + Mpls: &aftpb.Afts_LabelEntryKey{ + Label: &aftpb.Afts_LabelEntryKey_LabelUint64{ + LabelUint64: 42, + }, + LabelEntry: &aftpb.Afts_LabelEntry{ + NextHopGroup: &wpb.UintValue{Value: 1}, + }, + }, + }, + }} + for i, op := range ops { + op.Id = uint64(i) + op.Op = spb.AFTOperation_ADD + + if _, _, err := r.AddEntry(defName, op); err != nil { + t.Fatalf("cannot add entry %s, %v", prototext.Format(op), err) + } + } + return r + }(), + inOperation: &spb.AFTOperation{ + Id: 42, + Op: spb.AFTOperation_REPLACE, + Entry: &spb.AFTOperation_Mpls{ + Mpls: &aftpb.Afts_LabelEntryKey{ + Label: &aftpb.Afts_LabelEntryKey_LabelUint64{ + LabelUint64: 42, + }, + LabelEntry: &aftpb.Afts_LabelEntry{ + NextHopGroup: &wpb.UintValue{Value: 2}, + }, + }, + }, + }, + inHook: stringHook, + checkFn: stringChecker("Add MPLS:DEFAULT->42"), }} for _, tt := range tests { @@ -3318,13 +3534,22 @@ func TestResolvedEntryHook(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if err := tt.checkFn(); err != nil { - checkErr = err + if tt.checkFn != nil { + if err := tt.checkFn(); err != nil { + checkErr = err + } } }() - if _, fails, err := tt.inRIB.AddEntry(defName, tt.inOperation); err != nil || len(fails) != 0 { - t.Fatalf("did not successfully add entry, error: %v, fails: %v", err, fails) + switch tt.inOperation.Op { + case spb.AFTOperation_ADD, spb.AFTOperation_REPLACE: + if _, fails, err := tt.inRIB.AddEntry(defName, tt.inOperation); err != nil || len(fails) != 0 { + t.Fatalf("did not successfully add entry, error: %v, fails: %v", err, fails) + } + case spb.AFTOperation_DELETE: + if _, fails, err := tt.inRIB.DeleteEntry(defName, tt.inOperation); err != nil || len(fails) != 0 { + t.Fatalf("did not successfully delete entry, error: %v', fails: %v", err, fails) + } } wg.Wait() if checkErr != nil {