diff --git a/controllers/service_controller.go b/controllers/service_controller.go index 41411db..1a2cb9c 100644 --- a/controllers/service_controller.go +++ b/controllers/service_controller.go @@ -98,7 +98,7 @@ func (r *ServiceReconciler) Reconcile(req ctrl.Request) (ctrl.Result, error) { } logger.WithValues("IPv4", configIPv4Enabled).WithValues("IPv6", configIPv6Enabled).Info("config network stack") - // Get Nodes's pod CIDR + // Get nodes's pod CIDR node := &corev1.Node{} if err := r.Client.Get(ctx, types.NamespacedName{Name: configNodeName}, node); err != nil { logger.Error(err, "failed to get the pod's node info from API server") @@ -108,7 +108,23 @@ func (r *ServiceReconciler) Reconcile(req ctrl.Request) (ctrl.Result, error) { logger.WithValues("pod CIDR IPV4", podCIDRIPv4).WithValues("pod CIDR IPv6", podCIDRIPv6).Info("pod CIDR") // Init iptables - initIptables(logger) + if err := initIptables(logger); err != nil { + logger.Error(err, "failed to init iptables") + os.Exit(1) + } + + // Get all services + svcs := &corev1.ServiceList{} + if err := r.Client.List(ctx, svcs, client.InNamespace("")); err != nil { + logger.Error(err, "failed to get all services from API server") + os.Exit(1) + } + + // Cleanup iptables for deleted services + if err := cleanupIptables(logger, svcs, podCIDRIPv4, podCIDRIPv6); err != nil { + logger.Error(err, "failed to cleanup iptables") + os.Exit(1) + } } // Get service info @@ -132,7 +148,7 @@ func (r *ServiceReconciler) Reconcile(req ctrl.Request) (ctrl.Result, error) { // Delete iptables rules logger.WithValues("externalIP", oldExternalIP).Info("delete iptables rules") - if err := deleteIptablesRules(logger, req, oldClusterIP, oldExternalIP, podCIDRIPv4, podCIDRIPv6); err != nil { + if err := deleteIptablesRules(logger, &req, oldClusterIP, oldExternalIP, podCIDRIPv4, podCIDRIPv6); err != nil { return ctrl.Result{}, err } } @@ -148,7 +164,7 @@ func (r *ServiceReconciler) Reconcile(req ctrl.Request) (ctrl.Result, error) { return ctrl.Result{}, nil } - // Create or Delete iptables rules + // Create iptables rules for _, ingress := range svc.Status.LoadBalancer.Ingress { clusterIP := svc.Spec.ClusterIP externalIP := ingress.IP @@ -158,7 +174,7 @@ func (r *ServiceReconciler) Reconcile(req ctrl.Request) (ctrl.Result, error) { // Create iptables rules logger.WithValues("externalIP", externalIP).Info("create iptables rules") - if err := createIptablesRules(logger, req, clusterIP, externalIP, podCIDRIPv4, podCIDRIPv6); err != nil { + if err := createIptablesRules(logger, &req, clusterIP, externalIP, podCIDRIPv4, podCIDRIPv6); err != nil { return ctrl.Result{}, err } } @@ -174,6 +190,8 @@ func (r *ServiceReconciler) SetupWithManager(mgr ctrl.Manager) error { } func initIptables(logger logr.Logger) error { + logger.Info("create iptables chains") + // IPv4 if configIPv4Enabled { // Create chain in nat table @@ -238,10 +256,177 @@ func initIptables(logger logr.Logger) error { return nil } -func createIptablesRules(logger logr.Logger, req ctrl.Request, clusterIP, externalIP, podCIDRIPv4, podCIDRIPv6 string) error { +func cleanupIptables(logger logr.Logger, svcs *corev1.ServiceList, podCIDRIPv4, podCIDRIPv6 string) error { + logger.Info("cleanup iptables for deleted services") + + // IPv4 + if configIPv4Enabled { + // Make up service map + svcMap := make(map[string]*corev1.Service) + for _, svc := range svcs.Items { + if ip.IsIPv4Addr(svc.Spec.ClusterIP) && svc.Spec.Type == corev1.ServiceTypeLoadBalancer { + svcMap[svc.Namespace+"/"+svc.Name] = svc.DeepCopy() + } + } + + // Cleanup prerouting chain + preRules, err := iptables.GetRulesIPv4(iptables.TableNAT, ChainNATIPVSPrerouting) + if err != nil { + return err + } + for _, rule := range preRules { + // Get service info from rule and k8s, and delete iptables rules + nsName, src, dest, jump, dnatDest := getSvcInfoFromRule(rule) + svc, ok := svcMap[nsName] + if !ok { + logger.WithValues("rule", rule).Info("there is no service info in k8s. cleanup prerouting chain IPv4 rule") + out, err := iptables.DeleteRuleRawIPv4(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + continue + } + + // Compare service info and delete iptables rules + for _, ingress := range svc.Status.LoadBalancer.Ingress { + externalIP := ingress.IP + "/32" + if (jump == ChainNATKubeMasquerade && (src == podCIDRIPv4 && dest == externalIP)) || + (jump == "DNAT" && (src == podCIDRIPv4 && dest == externalIP && dnatDest == svc.Spec.ClusterIP)) { + continue + } + logger.WithValues("rule", rule).Info("service info is diff. cleanup prerouting chain IPv4 rule") + out, err := iptables.DeleteRuleRawIPv4(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + } + } + + // Cleanup output chain + outRules, err := iptables.GetRulesIPv4(iptables.TableNAT, ChainNATIPVSOutput) + if err != nil { + return err + } + for _, rule := range outRules { + // Get service info from rule and k8s, and delete iptables rules + nsName, _, dest, jump, dnatDest := getSvcInfoFromRule(rule) + svc, ok := svcMap[nsName] + if !ok { + logger.WithValues("rule", rule).Info("there is no service info in k8s. cleanup output chain IPv4 rule") + out, err := iptables.DeleteRuleRawIPv4(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + continue + } + + // Compare service info and delete diff iptables rules + for _, ingress := range svc.Status.LoadBalancer.Ingress { + externalIP := ingress.IP + "/32" + if (jump == ChainNATKubeMasquerade && dest == externalIP) || + (jump == "DNAT" && (dest == externalIP && dnatDest == svc.Spec.ClusterIP)) { + continue + } + logger.WithValues("rule", rule).Info("service info is diff. cleanup output chain IPv4 rule") + out, err := iptables.DeleteRuleRawIPv4(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + } + } + } + // IPv6 + if configIPv6Enabled { + // Make up service map + svcMap := make(map[string]*corev1.Service) + for _, svc := range svcs.Items { + if ip.IsIPv6Addr(svc.Spec.ClusterIP) && svc.Spec.Type == corev1.ServiceTypeLoadBalancer { + svcMap[svc.Namespace+"/"+svc.Name] = svc.DeepCopy() + } + } + + // Cleanup prerouting chain + preRules, err := iptables.GetRulesIPv6(iptables.TableNAT, ChainNATIPVSPrerouting) + if err != nil { + return err + } + for _, rule := range preRules { + // Get service info from rule and k8s, and delete iptables rules + nsName, src, dest, jump, dnatDest := getSvcInfoFromRule(rule) + svc, ok := svcMap[nsName] + if !ok { + logger.WithValues("rule", rule).Info("there is no service info in k8s. cleanup prerouting chain IPv6 rule") + out, err := iptables.DeleteRuleRawIPv6(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + continue + } + + // Compare service info and delete iptables rules + for _, ingress := range svc.Status.LoadBalancer.Ingress { + externalIP := ingress.IP + "/128" + if (jump == ChainNATKubeMasquerade && (src == podCIDRIPv6 && dest == externalIP)) || + (jump == "DNAT" && (src == podCIDRIPv6 && dest == externalIP && dnatDest == svc.Spec.ClusterIP)) { + continue + } + logger.WithValues("rule", rule).Info("service info is diff. cleanup prerouting chain IPv6 rule") + out, err := iptables.DeleteRuleRawIPv6(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + } + } + + // Cleanup output + outRules, err := iptables.GetRulesIPv6(iptables.TableNAT, ChainNATIPVSOutput) + if err != nil { + return err + } + for _, rule := range outRules { + // Get service info from rule and k8s, and delete iptables rules + nsName, _, dest, jump, dnatDest := getSvcInfoFromRule(rule) + svc, ok := svcMap[nsName] + if !ok { + logger.WithValues("rule", rule).Info("there is no service info in k8s. cleanup output chain IPv6 rule") + out, err := iptables.DeleteRuleRawIPv6(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + continue + } + + // Compare service info and delete diff iptables rules + for _, ingress := range svc.Status.LoadBalancer.Ingress { + externalIP := ingress.IP + "/128" + if (jump == ChainNATKubeMasquerade && dest == externalIP) || + (jump == "DNAT" && (dest == externalIP && dnatDest == svc.Spec.ClusterIP)) { + continue + } + logger.WithValues("rule", rule).Info("service info is diff. cleanup output chain IPv6 rule") + out, err := iptables.DeleteRuleRawIPv6(iptables.TableNAT, iptables.ChangeRuleToDelete(rule)...) + if err != nil { + logger.Error(err, out) + return err + } + } + } + } + + return nil +} + +func createIptablesRules(logger logr.Logger, req *ctrl.Request, clusterIP, externalIP, podCIDRIPv4, podCIDRIPv6 string) error { // Don't use spec.ipFamily to distingush between IPv4 and IPv6 Address // for kubernetes version that dosen't support IPv6 dualstack - if configIPv4Enabled && ip.IsIPv4Addr(externalIP) { + if configIPv4Enabled && ip.IsIPv4Addr(clusterIP) { // IPv4 // Set prerouting rulePreMasq := []string{"-s", podCIDRIPv4, "-d", externalIP, "-j", ChainNATKubeMasquerade} @@ -270,7 +455,7 @@ func createIptablesRules(logger logr.Logger, req ctrl.Request, clusterIP, extern logger.Error(err, out) return err } - } else if configIPv6Enabled && ip.IsIPv6Addr(externalIP) { + } else if configIPv6Enabled && ip.IsIPv6Addr(clusterIP) { // IPv6 // Set prerouting rulePreMasq := []string{"-s", podCIDRIPv6, "-d", externalIP, "-j", ChainNATKubeMasquerade} @@ -300,18 +485,18 @@ func createIptablesRules(logger logr.Logger, req ctrl.Request, clusterIP, extern return err } } else { - if ip.IsVaildIP(externalIP) { - logger.WithValues("externalIP", externalIP).Error(errors.New("invalid IP"), "invaild IP") + if ip.IsVaildIP(clusterIP) { + logger.WithValues("clusterIP", clusterIP).Error(errors.New("invalid IP"), "invaild IP") } } return nil } -func deleteIptablesRules(logger logr.Logger, req ctrl.Request, clusterIP, externalIP, podCIDRIPv4, podCIDRIPv6 string) error { +func deleteIptablesRules(logger logr.Logger, req *ctrl.Request, clusterIP, externalIP, podCIDRIPv4, podCIDRIPv6 string) error { // Don't use spec.ipFamily to distingush between IPv4 and IPv6 Address // for kubernetes version that dosen't support IPv6 dualstack - if configIPv4Enabled && ip.IsIPv4Addr(externalIP) { + if configIPv4Enabled && ip.IsIPv4Addr(clusterIP) { // IPv4 // Unset prerouting rulePreMasq := []string{"-s", podCIDRIPv4, "-d", externalIP, "-j", ChainNATKubeMasquerade} @@ -340,7 +525,7 @@ func deleteIptablesRules(logger logr.Logger, req ctrl.Request, clusterIP, extern logger.Error(err, out) return err } - } else if configIPv6Enabled && ip.IsIPv6Addr(externalIP) { + } else if configIPv6Enabled && ip.IsIPv6Addr(clusterIP) { // IPv6 // Unset prerouting rulePreMasq := []string{"-s", podCIDRIPv6, "-d", externalIP, "-j", ChainNATKubeMasquerade} @@ -370,8 +555,8 @@ func deleteIptablesRules(logger logr.Logger, req ctrl.Request, clusterIP, extern return err } } else { - if ip.IsVaildIP(externalIP) { - logger.WithValues("externalIP", externalIP).Error(errors.New("invalid IP"), "invaild IP") + if ip.IsVaildIP(clusterIP) { + logger.WithValues("clusterIP", clusterIP).Error(errors.New("invalid IP"), "invaild IP") } } return nil @@ -388,3 +573,12 @@ func getPodCIDR(cidrs []string) (ipv4CIDR string, ipv6CIDR string) { } return } + +func getSvcInfoFromRule(rule string) (nsName, src, dest, jump, dnatDest string) { + nsName = iptables.GetRuleComment(rule) + src = iptables.GetRuleSrc(rule) + dest = iptables.GetRuleDest(rule) + jump = iptables.GetRuleJump(rule) + dnatDest = iptables.GetRuleDNATDest(rule) + return +} diff --git a/pkg/configs/configs_test.go b/pkg/configs/configs_test.go index 238f620..cc836ef 100644 --- a/pkg/configs/configs_test.go +++ b/pkg/configs/configs_test.go @@ -6,7 +6,7 @@ import ( ) func TestGetConfigNodeName(t *testing.T) { - os.Setenv(EnvNodeName, "node") + _ = os.Setenv(EnvNodeName, "node") nodeName, _ := GetConfigNodeName() if nodeName != "node" { t.Errorf("wrong result - %s", "node") @@ -14,49 +14,49 @@ func TestGetConfigNodeName(t *testing.T) { } func TestGetConfigNetStack(t *testing.T) { - os.Setenv(EnvNetStack, "ipv4") + _ = os.Setenv(EnvNetStack, "ipv4") ipv4, ipv6, _ := GetConfigNetStack() if ipv4 != true || ipv6 != false { t.Errorf("wrong result - %s", "ipv4") } - os.Setenv(EnvNetStack, "IPV4") + _ = os.Setenv(EnvNetStack, "IPV4") ipv4, ipv6, _ = GetConfigNetStack() if ipv4 != true || ipv6 != false { t.Errorf("wrong result - %s", "IPV4") } - os.Setenv(EnvNetStack, "ipv6") + _ = os.Setenv(EnvNetStack, "ipv6") ipv4, ipv6, _ = GetConfigNetStack() if ipv4 != false || ipv6 != true { t.Errorf("wrong result - %s", "ipv6") } - os.Setenv(EnvNetStack, "IPV6") + _ = os.Setenv(EnvNetStack, "IPV6") ipv4, ipv6, _ = GetConfigNetStack() if ipv4 != false || ipv6 != true { t.Errorf("wrong result - %s", "IPV6") } - os.Setenv(EnvNetStack, "ipv5") + _ = os.Setenv(EnvNetStack, "ipv5") _, _, err := GetConfigNetStack() if err == nil { t.Errorf("wrong result - %s", "ipv5") } - os.Setenv(EnvNetStack, "ipv4,ipv6") + _ = os.Setenv(EnvNetStack, "ipv4,ipv6") ipv4, ipv6, _ = GetConfigNetStack() if ipv4 != true || ipv6 != true { t.Errorf("wrong result - %s", "ipv4,ipv6") } - os.Setenv(EnvNetStack, "ipv4, ipv6") + _ = os.Setenv(EnvNetStack, "ipv4, ipv6") ipv4, ipv6, _ = GetConfigNetStack() if ipv4 != true || ipv6 != true { t.Errorf("wrong result - %s", "ipv4, ipv6") } - os.Setenv(EnvNetStack, "ipv6, ipv4") + _ = os.Setenv(EnvNetStack, "ipv6, ipv4") ipv4, ipv6, _ = GetConfigNetStack() if ipv4 != true || ipv6 != true { t.Errorf("wrong result - %s", "ipv6,ipv4") diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index 7393e7b..deadc50 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -12,9 +12,11 @@ type Table string // Const const ( - iptablesCmdIPv4 = "iptables" - iptablesCmdIPv6 = "ip6tables" - iptablesErrNoRule = "No chain/target/match by that name" + iptablesCmdIPv4 = "iptables" + iptablesCmdIPv6 = "ip6tables" + iptablesSaveCmdIPv4 = "iptables-save" + iptablesSaveCmdIPv6 = "ip6tables-save" + iptablesErrNoRule = "No chain/target/match by that name" TableNAT Table = "nat" TableFilter Table = "filter" @@ -150,6 +152,40 @@ func isExistRule(iptablesCmd string, table Table, chain string, comment string, return true } +// GetRules +func GetRulesIPv4(table Table, chain string) ([]string, error) { + return getRules(iptablesSaveCmdIPv4, table, chain) +} + +func GetRulesIPv6(table Table, chain string) ([]string, error) { + return getRules(iptablesSaveCmdIPv6, table, chain) +} + +func getRules(iptablesSaveCmd string, table Table, chain string) ([]string, error) { + // Lock + lock.Lock() + defer lock.Unlock() + + // Set Common args + args := []string{"-t", string(table)} + + // Check rule + cmd := exec.Command(iptablesSaveCmd, args...) + out, err := cmd.CombinedOutput() + if err != nil { + return nil, err + } + + var result []string + for _, rule := range strings.Split(string(out), "\n") { + if strings.HasPrefix(rule, "-A "+chain) { + result = append(result, rule) + } + } + + return result, nil +} + // CreateRuleFirst func CreateRuleFirstIPv4(table Table, chain string, comment string, rule ...string) (string, error) { return createRuleFirst(iptablesCmdIPv4, table, chain, comment, rule...) @@ -265,3 +301,41 @@ func deleteRule(iptablesCmd string, table Table, chain string, comment string, r return string(out), nil } + +// DeleteRuleRaw +func DeleteRuleRawIPv4(table Table, rule ...string) (string, error) { + return deleteRuleRaw(iptablesCmdIPv4, table, rule...) +} + +func DeleteRuleRawIPv6(table Table, rule ...string) (string, error) { + return deleteRuleRaw(iptablesCmdIPv6, table, rule...) +} + +func deleteRuleRaw(iptablesCmd string, table Table, rule ...string) (string, error) { + // Lock + lock.Lock() + defer lock.Unlock() + + // Set tables args + args := []string{"-t", string(table)} + + // Check rule + cmd := exec.Command(iptablesCmd, append(append(args, "-C"), rule...)...) + out, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(out), iptablesErrNoRule) { + // If rule isn't exist, return success + return string(out), nil + } + return string(out), err + } + + // Delete rule + cmd = exec.Command(iptablesCmd, append(append(args, "-D"), rule...)...) + out, err = cmd.CombinedOutput() + if err != nil { + return string(out), err + } + + return string(out), nil +} diff --git a/pkg/iptables/iptables_test.go b/pkg/iptables/iptables_test.go index 7671e30..2da08b7 100644 --- a/pkg/iptables/iptables_test.go +++ b/pkg/iptables/iptables_test.go @@ -5,7 +5,8 @@ import ( ) const ( - chainTest = "TestChain" + chainTest = "TestChain" + commentTest = "TestComment" ) var ( @@ -31,38 +32,64 @@ func TestCreateChainIPv6(t *testing.T) { } } -func TestCreateDeleteRuleIPv4(t *testing.T) { +func TestCreateGetDeleteRuleIPv4(t *testing.T) { // Create - if out, err := CreateRuleFirstIPv4(TableNAT, chainTest, "create first rule", ruleDNATIPv4...); err != nil { + if out, err := CreateRuleFirstIPv4(TableNAT, chainTest, commentTest, ruleDNATIPv4...); err != nil { t.Errorf("create first rule IPv4 - out:%s", out) } - if !IsExistRuleIPv4(TableNAT, chainTest, "create first rule", ruleDNATIPv4...) { + if !IsExistRuleIPv4(TableNAT, chainTest, commentTest, ruleDNATIPv4...) { t.Errorf("check created rule IPv4") } + // Get + rules, err := GetRulesIPv4(TableNAT, chainTest) + if err != nil { + t.Errorf("get rules IPv4") + } + if len(rules) == 0 { + t.Errorf("no rules IPv4") + } + expectedRule := "-A " + chainTest + " -m comment --comment " + commentTest + " -j DNAT --to-destination 192.168.0.1" + if rules[0] != expectedRule { + t.Errorf("rule is different. expected:%s / actual:%s", expectedRule, rules[0]) + } + // Delete - if out, err := DeleteRuleIPv4(TableNAT, chainTest, "create first rule", ruleDNATIPv4...); err != nil { + if out, err := DeleteRuleIPv4(TableNAT, chainTest, commentTest, ruleDNATIPv4...); err != nil { t.Errorf("delete rule IPv4 - out:%s", out) } - if IsExistRuleIPv4(TableNAT, chainTest, "create first rule", ruleDNATIPv4...) { + if IsExistRuleIPv4(TableNAT, chainTest, commentTest, ruleDNATIPv4...) { t.Errorf("check deleted rule IPv4") } } -func TestCreateDeleteRuleIPv6(t *testing.T) { +func TestCreateGetDeleteRuleIPv6(t *testing.T) { // Create - if out, err := CreateRuleFirstIPv6(TableNAT, chainTest, "create first rule", ruleDNATIPv6...); err != nil { + if out, err := CreateRuleFirstIPv6(TableNAT, chainTest, commentTest, ruleDNATIPv6...); err != nil { t.Errorf("create first rule IPv6 - out:%s", out) } - if !IsExistRuleIPv6(TableNAT, chainTest, "create first rule", ruleDNATIPv6...) { + if !IsExistRuleIPv6(TableNAT, chainTest, commentTest, ruleDNATIPv6...) { t.Errorf("check created rule IPv6") } + // Get + rules, err := GetRulesIPv6(TableNAT, chainTest) + if err != nil { + t.Errorf("get rules IPv6") + } + if len(rules) == 0 { + t.Errorf("no rules IPv6") + } + expectedRule := "-A " + chainTest + " -m comment --comment " + commentTest + " -j DNAT --to-destination fdaa::1" + if rules[0] != expectedRule { + t.Errorf("rule is different. expected:%s / actual:%s", expectedRule, rules[0]) + } + // Delete - if out, err := DeleteRuleIPv6(TableNAT, chainTest, "create first rule", ruleDNATIPv6...); err != nil { + if out, err := DeleteRuleIPv6(TableNAT, chainTest, commentTest, ruleDNATIPv6...); err != nil { t.Errorf("delete rule IPv6 - out:%s", out) } - if IsExistRuleIPv6(TableNAT, chainTest, "create first rule", ruleDNATIPv6...) { + if IsExistRuleIPv6(TableNAT, chainTest, commentTest, ruleDNATIPv6...) { t.Errorf("check deleted rule IPv6") } } diff --git a/pkg/iptables/rules.go b/pkg/iptables/rules.go new file mode 100644 index 0000000..3d91d12 --- /dev/null +++ b/pkg/iptables/rules.go @@ -0,0 +1,43 @@ +package iptables + +import ( + "strings" +) + +func getValue(rule, opt string) string { + tokens := strings.Split(rule, " ") + for i, token := range tokens { + if token == opt { + if len(tokens)-1 >= i+1 { + return tokens[i+1] + } else { + return "" + } + } + } + return "" +} + +func GetRuleComment(rule string) string { + return strings.Trim(getValue(rule, "--comment"), "\"") +} + +func GetRuleSrc(rule string) string { + return getValue(rule, "-s") +} + +func GetRuleDest(rule string) string { + return getValue(rule, "-d") +} + +func GetRuleJump(rule string) string { + return getValue(rule, "-j") +} + +func GetRuleDNATDest(rule string) string { + return getValue(rule, "--to-destination") +} + +func ChangeRuleToDelete(rule string) []string { + return strings.Split(strings.ReplaceAll(rule[3:], "\"", ""), " ") +} diff --git a/pkg/iptables/rules_test.go b/pkg/iptables/rules_test.go new file mode 100644 index 0000000..00f40cb --- /dev/null +++ b/pkg/iptables/rules_test.go @@ -0,0 +1,49 @@ +package iptables + +import ( + "reflect" + "testing" +) + +const ( + ruleTest = "-A testChain -s 192.168.0.1 -d 192.168.0.2 -m comment --comment \"testComment\" --to-destination 192.168.0.3" +) + +var ( + ruleDeleteTest = []string{"testChain", "-s", "192.168.0.1", "-d", "192.168.0.2", "-m", "comment", "--comment", "testComment", "--to-destination", "192.168.0.3"} +) + +func TestGetValue(t *testing.T) { + value := getValue(ruleTest, "-A") + if value != "testChain" { + t.Errorf("get value") + } +} + +func TestGetRuleComment(t *testing.T) { + value := GetRuleComment(ruleTest) + if value != "testComment" { + t.Errorf("get rule comment") + } +} + +func TestGetRuleSrc(t *testing.T) { + value := GetRuleSrc(ruleTest) + if value != "192.168.0.1" { + t.Errorf("get rule src") + } +} + +func TestGetRuleDNATDest(t *testing.T) { + value := GetRuleDNATDest(ruleTest) + if value != "192.168.0.3" { + t.Errorf("get rule DNAT dest") + } +} + +func TestChangeRuleToDelete(t *testing.T) { + value := ChangeRuleToDelete(ruleTest) + if !reflect.DeepEqual(value, ruleDeleteTest) { + t.Errorf("change rule to delete. expected:%+v / actual:%+v", ruleDeleteTest, value) + } +}