diff --git a/network/protocol.go b/network/protocol.go index 45e337559..84df1e4ac 100644 --- a/network/protocol.go +++ b/network/protocol.go @@ -5,6 +5,8 @@ package network +import "strings" + type Protocol string const ( @@ -12,3 +14,7 @@ const ( UDP Protocol = "udp" ARP Protocol = "arp" ) + +func (p Protocol) String() string { + return strings.ToLower(string(p)) +} diff --git a/network/tc.go b/network/tc.go index 3f9025be0..a5b2794f0 100644 --- a/network/tc.go +++ b/network/tc.go @@ -204,11 +204,11 @@ func (t *tc) AddFilter(ifaces []string, parent string, handle uint32, srcIP, dst var params, filterProtocol string // match protocol if specified, default to tcp otherwise - switch protocol { - case TCP, UDP: + switch protocol.String() { + case TCP.String(), UDP.String(): filterProtocol = "ip" - params += fmt.Sprintf("ip_proto %s ", strings.ToLower(string(protocol))) - case ARP: + params += fmt.Sprintf("ip_proto %s ", protocol.String()) + case ARP.String(): filterProtocol = "arp" default: return 0, fmt.Errorf("unexpected protocol: %s", protocol) diff --git a/network/tc_test.go b/network/tc_test.go index 1021eb878..724b62a1b 100644 --- a/network/tc_test.go +++ b/network/tc_test.go @@ -78,7 +78,7 @@ var _ = Describe("Tc", func() { } srcPort = 12345 dstPort = 80 - protocol = TCP + protocol = "TCP" connState = ConnStateNew flowid = "1:2" })