diff --git a/ipset_linux.go b/ipset_linux.go index f4c05229..7730992e 100644 --- a/ipset_linux.go +++ b/ipset_linux.go @@ -147,9 +147,11 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename))) + cadtFlags := optionsToBitflag(options) + revision := options.Revision if revision == 0 { - revision = getIpsetDefaultWithTypeName(typename) + revision = getIpsetDefaultRevision(typename, cadtFlags) } req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(revision))) @@ -181,18 +183,6 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *timeout}) } - var cadtFlags uint32 - - if options.Comments { - cadtFlags |= nl.IPSET_FLAG_WITH_COMMENT - } - if options.Counters { - cadtFlags |= nl.IPSET_FLAG_WITH_COUNTERS - } - if options.Skbinfo { - cadtFlags |= nl.IPSET_FLAG_WITH_SKBINFO - } - if cadtFlags != 0 { data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER, Value: cadtFlags}) } @@ -395,14 +385,89 @@ func (h *Handle) newIpsetRequest(cmd int) *nl.NetlinkRequest { return req } -func getIpsetDefaultWithTypeName(typename string) uint8 { +// NOTE: This can't just take typename into account, it also has to take desired +// feature support into account, on a per-set-type basis, to return the correct revision, see e.g. +// https://github.com/Olipro/ipset/blob/9f145b49100104d6570fe5c31a5236816ebb4f8f/kernel/net/netfilter/ipset/ip_set_hash_ipport.c#L30 +// +// This means that whenever a new "type" of ipset is added, returning the "correct" default revision +// requires adding a new case here for that type, and consulting the ipset C code to figure out the correct +// combination of type name, feature bit flags, and revision ranges. +// +// Care should be taken as some types share the same revision ranges for the same features, and others do not. +// When in doubt, mimic the C code. +func getIpsetDefaultRevision(typename string, featureFlags uint32) uint8 { switch typename { case "hash:ip,port", - "hash:ip,port,ip", - "hash:ip,port,net", + "hash:ip,port,ip": + // Taken from + // - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipport.c + // - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipportip.c + if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 { + return 5 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 { + return 4 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 { + return 3 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_COUNTERS) != 0 { + return 2 + } + + // the min revision this library supports for this type + return 1 + + case "hash:ip,port,net", "hash:net,port": + // Taken from + // - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipportnet.c + // - ipset/kernel/net/netfilter/ipset/ip_set_hash_netport.c + if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 { + return 7 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 { + return 6 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 { + return 5 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_COUNTERS) != 0 { + return 4 + } + + if (featureFlags & nl.IPSET_FLAG_NOMATCH) != 0 { + return 3 + } + // the min revision this library supports for this type + return 2 + + case "hash:ip": + // Taken from + // - ipset/kernel/net/netfilter/ipset/ip_set_hash_ip.c + if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 { + return 4 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 { + return 3 + } + + if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 { + return 2 + } + + // the min revision this library supports for this type return 1 } + + // can't map the correct revision for this type. return 0 } @@ -579,3 +644,19 @@ func parseIPSetEntry(data []byte) (entry IPSetEntry) { } return } + +func optionsToBitflag(options IpsetCreateOptions) uint32 { + var cadtFlags uint32 + + if options.Comments { + cadtFlags |= nl.IPSET_FLAG_WITH_COMMENT + } + if options.Counters { + cadtFlags |= nl.IPSET_FLAG_WITH_COUNTERS + } + if options.Skbinfo { + cadtFlags |= nl.IPSET_FLAG_WITH_SKBINFO + } + + return cadtFlags +} diff --git a/ipset_linux_test.go b/ipset_linux_test.go index 298c3a3a..81a4086a 100644 --- a/ipset_linux_test.go +++ b/ipset_linux_test.go @@ -2,8 +2,8 @@ package netlink import ( "bytes" - "io/ioutil" "net" + "os" "testing" "github.com/vishvananda/netlink/nl" @@ -11,7 +11,7 @@ import ( ) func TestParseIpsetProtocolResult(t *testing.T) { - msgBytes, err := ioutil.ReadFile("testdata/ipset_protocol_result") + msgBytes, err := os.ReadFile("testdata/ipset_protocol_result") if err != nil { t.Fatalf("reading test fixture failed: %v", err) } @@ -23,7 +23,7 @@ func TestParseIpsetProtocolResult(t *testing.T) { } func TestParseIpsetListResult(t *testing.T) { - msgBytes, err := ioutil.ReadFile("testdata/ipset_list_result") + msgBytes, err := os.ReadFile("testdata/ipset_list_result") if err != nil { t.Fatalf("reading test fixture failed: %v", err) } @@ -759,3 +759,66 @@ func TestIpsetMaxElements(t *testing.T) { t.Fatalf("expected '%d' entry be created, got '%d'", maxElements, len(result.Entries)) } } + +func TestIpsetDefaultRevision(t *testing.T) { + testCases := []struct { + desc string + typename string + options IpsetCreateOptions + expectedRevision uint8 + }{ + { + desc: "Type-hash:ip,port", + typename: "hash:ip,port", + options: IpsetCreateOptions{ + Counters: true, + Comments: true, + Skbinfo: false, + }, + expectedRevision: 3, + }, + { + desc: "Type-hash:ip,port_nocomment", + typename: "hash:ip,port", + options: IpsetCreateOptions{ + Counters: true, + Comments: false, + Skbinfo: false, + }, + expectedRevision: 2, + }, + { + desc: "Type-hash:ip,port_skbinfo", + typename: "hash:ip,port", + options: IpsetCreateOptions{ + Counters: true, + Comments: false, + Skbinfo: true, + }, + expectedRevision: 5, + }, + { + desc: "Type-hash:ip,port,net", + typename: "hash:ip,port,net", + options: IpsetCreateOptions{ + Counters: true, + Comments: false, + Skbinfo: true, + }, + expectedRevision: 7, + }, + } + + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + + cadtFlags := optionsToBitflag(tC.options) + + defRev := getIpsetDefaultRevision(tC.typename, cadtFlags) + + if defRev != tC.expectedRevision { + t.Fatalf("expected default revision of '%d', got '%d'", tC.expectedRevision, defRev) + } + }) + } +}