From 14d99000a4a26139a2f6ea7e4aa8622907177125 Mon Sep 17 00:00:00 2001 From: Homayoon Alimohammadi Date: Thu, 24 Oct 2024 12:31:22 +0400 Subject: [PATCH] Spawn a pod to make request to cluster-agent (#71) --- controllers/configs.go | 4 +- controllers/reconcile.go | 22 ++- go.mod | 4 +- go.sum | 4 +- pkg/clusteragent/clusteragent.go | 234 +++++++++++++++++++------- pkg/clusteragent/clusteragent_test.go | 139 +++++++++------ pkg/clusteragent/remove_node.go | 6 +- pkg/clusteragent/remove_node_test.go | 60 ++++--- pkg/control/wait.go | 50 ++++++ pkg/control/wait_test.go | 58 +++++++ pkg/httptest/httptest.go | 62 ------- pkg/images/images.go | 5 + 12 files changed, 442 insertions(+), 206 deletions(-) create mode 100644 pkg/control/wait.go create mode 100644 pkg/control/wait_test.go delete mode 100644 pkg/httptest/httptest.go create mode 100644 pkg/images/images.go diff --git a/controllers/configs.go b/controllers/configs.go index 12e0776..c233e49 100644 --- a/controllers/configs.go +++ b/controllers/configs.go @@ -28,7 +28,7 @@ import ( "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/util/connrotation" - "k8s.io/utils/pointer" + "k8s.io/utils/ptr" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -241,7 +241,7 @@ func (r *MicroK8sControlPlaneReconciler) generateMicroK8sConfig(ctx context.Cont Kind: "MicroK8sControlPlane", Name: tcp.Name, UID: tcp.UID, - BlockOwnerDeletion: pointer.BoolPtr(true), + BlockOwnerDeletion: ptr.To(true), } bootstrapConfig := &bootstrapv1beta1.MicroK8sConfig{ diff --git a/controllers/reconcile.go b/controllers/reconcile.go index 17d6bf8..e7f8259 100644 --- a/controllers/reconcile.go +++ b/controllers/reconcile.go @@ -11,7 +11,9 @@ import ( clusterv1beta1 "github.com/canonical/cluster-api-control-plane-provider-microk8s/api/v1beta1" "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/clusteragent" + "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/images" "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/token" + "github.com/go-logr/logr" "golang.org/x/mod/semver" "github.com/pkg/errors" @@ -33,9 +35,8 @@ import ( ) const ( - defaultClusterAgentPort string = "25000" - defaultDqlitePort string = "19001" - defaultClusterAgentClientTimeout time.Duration = 10 * time.Second + defaultClusterAgentPort string = "25000" + defaultDqlitePort string = "19001" ) type errServiceUnhealthy struct { @@ -601,7 +602,14 @@ func (r *MicroK8sControlPlaneReconciler) scaleDownControlPlane(ctx context.Conte if len(machines) > 2 { portRemap := tcp != nil && tcp.Spec.ControlPlaneConfig.ClusterConfiguration != nil && tcp.Spec.ControlPlaneConfig.ClusterConfiguration.PortCompatibilityRemap - if clusterAgentClient, err := getClusterAgentClient(machines, deleteMachine, portRemap); err == nil { + kubeclient, err := r.kubeconfigForCluster(ctx, cluster) + if err != nil { + return ctrl.Result{RequeueAfter: 5 * time.Second}, fmt.Errorf("failed to get kubeconfig for cluster: %w", err) + } + + defer kubeclient.Close() //nolint:errcheck + + if clusterAgentClient, err := getClusterAgentClient(kubeclient, logger, machines, deleteMachine, portRemap); err == nil { if err := r.removeNodeFromDqlite(ctx, clusterAgentClient, cluster, deleteMachine, portRemap); err != nil { logger.Error(err, "failed to remove node from dqlite: %w", "machineName", deleteMachine.Name, "nodeName", node.Name) } @@ -627,7 +635,7 @@ func (r *MicroK8sControlPlaneReconciler) scaleDownControlPlane(ctx context.Conte return ctrl.Result{Requeue: true}, nil } -func getClusterAgentClient(machines []clusterv1.Machine, delMachine clusterv1.Machine, portRemap bool) (*clusteragent.Client, error) { +func getClusterAgentClient(kubeclient *kubernetesClient, logger logr.Logger, machines []clusterv1.Machine, delMachine clusterv1.Machine, portRemap bool) (*clusteragent.Client, error) { opts := clusteragent.Options{ // NOTE(hue): We want to pick a random machine's IP to call POST /dqlite/remove on its cluster agent endpoint. // This machine should preferably not be the itself, although this is not forced by Microk8s. @@ -640,7 +648,7 @@ func getClusterAgentClient(machines []clusterv1.Machine, delMachine clusterv1.Ma port = "30000" } - clusterAgentClient, err := clusteragent.NewClient(machines, port, defaultClusterAgentClientTimeout, opts) + clusterAgentClient, err := clusteragent.NewClient(kubeclient, logger, machines, port, opts) if err != nil { return nil, fmt.Errorf("failed to initialize cluster agent client: %w", err) } @@ -696,7 +704,7 @@ func createUpgradePod(ctx context.Context, kubeclient *kubernetesClient, nodeNam Containers: []corev1.Container{ { Name: "upgrade", - Image: "curlimages/curl:7.87.0", + Image: images.CurlImage, Command: []string{ "su", "-c", diff --git a/go.mod b/go.mod index 68f5586..53fc2f0 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/go-logr/logr v1.2.3 // indirect + github.com/go-logr/logr v1.2.3 github.com/go-logr/zapr v1.2.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect @@ -83,7 +83,7 @@ require ( k8s.io/component-base v0.25.3 // indirect k8s.io/klog/v2 v2.80.1 // indirect k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 // indirect - k8s.io/utils v0.0.0-20221012122500-cfd413dd9e85 + k8s.io/utils v0.0.0-20240921022957-49e7df575cb6 sigs.k8s.io/cluster-api v1.2.4 sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect diff --git a/go.sum b/go.sum index 29e0a04..a802505 100644 --- a/go.sum +++ b/go.sum @@ -745,8 +745,8 @@ k8s.io/klog/v2 v2.80.1 h1:atnLQ121W371wYYFawwYx1aEY2eUfs4l3J72wtgAwV4= k8s.io/klog/v2 v2.80.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 h1:+70TFaan3hfJzs+7VK2o+OGxg8HsuBr/5f6tVAjDu6E= k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280/go.mod h1:+Axhij7bCpeqhklhUTe3xmOn6bWxolyZEeyaFpjGtl4= -k8s.io/utils v0.0.0-20221012122500-cfd413dd9e85 h1:cTdVh7LYu82xeClmfzGtgyspNh6UxpwLWGi8R4sspNo= -k8s.io/utils v0.0.0-20221012122500-cfd413dd9e85/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/utils v0.0.0-20240921022957-49e7df575cb6 h1:MDF6h2H/h4tbzmtIKTuctcwZmY0tY9mD9fNT47QO6HI= +k8s.io/utils v0.0.0-20240921022957-49e7df575cb6/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/pkg/clusteragent/clusteragent.go b/pkg/clusteragent/clusteragent.go index dce3dca..7c0d5db 100644 --- a/pkg/clusteragent/clusteragent.go +++ b/pkg/clusteragent/clusteragent.go @@ -1,114 +1,230 @@ package clusteragent import ( - "bytes" "context" - "crypto/tls" "encoding/json" "errors" "fmt" "net" "net/http" - "time" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/utils/ptr" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" + + "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/control" + "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/images" +) + +const ( + CallerPodNameFormat string = "cluster-agent-caller-%s" + DefaultPodNameSpace string = "default" ) // Options should be used when initializing a new client. type Options struct { - // IgnoreMachineNames is a set of ignored machine names that we don't want to pick their IP for the cluster agent endpoint. + // IgnoreMachineNames is a set of ignored machine names that we don't want to pick their node name for the cluster agent. IgnoreMachineNames sets.String + // SkipSucceededCheck skips the waiting for succeeded phase on pod. Mostly used for testing purposes. + SkipSucceededCheck bool + // SkipPodCleanup skips the pod cleanup after the request is done. Mostly used for testing purposes. + SkipPodCleanup bool } +// KubeClient is an interface for the Kubernetes client. +type KubeClient interface { + CoreV1() corev1client.CoreV1Interface +} + +// Client is a client for the cluster agent. type Client struct { - ip, port string - client *http.Client + KubeClient + logger Logger + nodeName string + namespace string + ip string + port string + + skipSucceededCheck bool + skipPodCleanup bool +} + +// Logger is an interface for logging. +type Logger interface { + Info(msg string, keysAndValues ...interface{}) + Error(err error, msg string, keysAndValues ...interface{}) } -// NewClient picks an IP from one of the given machines and creates a new client for the cluster agent -// with that IP. -func NewClient(machines []clusterv1.Machine, port string, timeout time.Duration, opts Options) (*Client, error) { +// NewClient picks a node name and IP from one of the given machines and creates a new client for the cluster agent. +func NewClient(kubeclient KubeClient, logger Logger, machines []clusterv1.Machine, port string, opts Options) (*Client, error) { + var nodeName string var ip string for _, m := range machines { + if !m.DeletionTimestamp.IsZero() { + continue + } + if opts.IgnoreMachineNames.Has(m.Name) { continue } - for _, addr := range m.Status.Addresses { - if net.ParseIP(addr.Address) != nil { - ip = addr.Address - break + if m.Status.NodeRef != nil { + nodeName = m.Status.NodeRef.Name + for _, addr := range m.Status.Addresses { + if net.ParseIP(addr.Address) != nil { + ip = addr.Address + break + } } + break } - break + } + + if nodeName == "" { + return nil, errors.New("failed to find a node for cluster agent") } if ip == "" { - return nil, errors.New("failed to find an IP for cluster agent") + return nil, errors.New("failed to find an IP address for cluster agent") } return &Client{ - ip: ip, - port: port, - client: &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - // TODO(Hue): Workaround for now, address later on - // get the certificate fingerprint from the matching node through a resource in the cluster (TBD), - // and validate it in the TLSClientConfig - InsecureSkipVerify: true, - }, - }, - }, + KubeClient: kubeclient, + logger: logger, + nodeName: nodeName, + namespace: DefaultPodNameSpace, + ip: ip, + port: port, + skipSucceededCheck: opts.SkipSucceededCheck, + skipPodCleanup: opts.SkipPodCleanup, }, nil } -func (c *Client) Endpoint() string { - return fmt.Sprintf("https://%s:%s", c.ip, c.port) -} +// do sends a request to the cluster agent. +func (c *Client) do(ctx context.Context, method, endpoint string, header http.Header, data map[string]any) error { + pod, err := c.createPod(ctx, method, endpoint, header, data) + if err != nil { + return fmt.Errorf("failed to create pod: %w", err) + } + + if !c.skipPodCleanup { + defer func() { + if err := c.deletePod(ctx, pod.Name); err != nil { + c.logger.Error(err, "failed to delete pod") + } + }() + } + + if c.skipSucceededCheck { + return nil + } -// do makes a request to the given endpoint with the given method. It marshals the request and unmarshals -// server response body if the provided response is not nil. -// The endpoint should _not_ have a leading slash. -func (c *Client) do(ctx context.Context, method, endpoint string, request any, header map[string][]string, response any) error { - url := fmt.Sprintf("https://%s:%s/%s", c.ip, c.port, endpoint) + podName := pod.Name + if err := control.WaitUntilReady(ctx, func() (bool, error) { + pod, err := c.CoreV1().Pods(c.namespace).Get(ctx, podName, metav1.GetOptions{}) + if err != nil { + return false, fmt.Errorf("failed to get pod: %w", err) + } - requestBody, err := json.Marshal(request) - if err != nil { - return fmt.Errorf("failed to prepare worker info request: %w", err) + if pod.Status.Phase == corev1.PodSucceeded { + return true, nil + } + if pod.Status.Phase == corev1.PodFailed { + return false, fmt.Errorf("pod failed") + } + + return false, nil + }, control.WaitOptions{NumRetries: ptr.To(120)}); err != nil { + return fmt.Errorf("failed to wait for pod to succeed: %w", err) } - req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(requestBody)) + return nil +} + +// createPod creates a pod that runs a curl command. +func (c *Client) createPod(ctx context.Context, method, endpoint string, header http.Header, data map[string]any) (*corev1.Pod, error) { + curl, err := c.createCURLString(method, endpoint, header, data) if err != nil { - return fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("failed to create curl string: %w", err) } - req.Header = http.Header(header) + c.logger.Info("creating curl pod", "cmd", curl) - res, err := c.client.Do(req) + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf(CallerPodNameFormat, c.nodeName), + }, + Spec: corev1.PodSpec{ + NodeName: c.nodeName, + Containers: []corev1.Container{ + { + Name: "caller", + Image: images.CurlImage, + Command: []string{ + "su", + "-c", + curl, + }, + SecurityContext: &corev1.SecurityContext{Privileged: ptr.To(true), RunAsUser: ptr.To(int64(0))}, + }, + }, + RestartPolicy: corev1.RestartPolicyNever, + }, + } + + pod, err = c.CoreV1().Pods(c.namespace).Create(ctx, pod, metav1.CreateOptions{}) if err != nil { - return fmt.Errorf("failed to call cluster agent: %w", err) + return nil, fmt.Errorf("failed to create pod: %w", err) } - defer res.Body.Close() - - if res.StatusCode != http.StatusOK { - // NOTE(hue): Marshal and print any response that we got since it might contain valuable information - // on why the request failed. - // Ignore JSON errors to prevent unnecessarily complicated error handling. - anyResp := make(map[string]any) - _ = json.NewDecoder(res.Body).Decode(&anyResp) - b, _ := json.Marshal(anyResp) - resStr := string(b) - - return fmt.Errorf("HTTP request to cluster agent failed with status code %d, got response: %q", res.StatusCode, resStr) + + return pod, nil +} + +// createCURLString creates a curl command string. +func (c *Client) createCURLString(method, endpoint string, header http.Header, data map[string]any) (string, error) { + // Method + req := fmt.Sprintf("curl -k -X %s", method) + + // Headers + for h, vv := range header { + for _, v := range vv { + req += fmt.Sprintf(" -H \"%s: %s\"", h, v) + } + } + + // Data + if data != nil { + dataB, err := json.Marshal(data) + if err != nil { + return "", fmt.Errorf("failed to marshal data: %w", err) + } + req += fmt.Sprintf(" -d '%s'", string(dataB)) + } + + // Endpoint + req += fmt.Sprintf(" https://%s:%s/%s", c.ip, c.port, endpoint) + + return req, nil +} + +// deletePod deletes a pod. +func (c *Client) deletePod(ctx context.Context, podName string) error { + deleteOptions := metav1.DeleteOptions{ + GracePeriodSeconds: ptr.To(int64(0)), } - if response != nil { - if err := json.NewDecoder(res.Body).Decode(response); err != nil { - return fmt.Errorf("failed to decode response: %w", err) + if err := control.WaitUntilReady(ctx, func() (bool, error) { + err := c.CoreV1().Pods(c.namespace).Delete(ctx, podName, deleteOptions) + if err != nil && !apierrors.IsNotFound(err) { + return false, nil } + return true, nil + }, control.WaitOptions{NumRetries: ptr.To(120)}); err != nil { + return fmt.Errorf("failed to wait for pod deletion: %w", err) } return nil diff --git a/pkg/clusteragent/clusteragent_test.go b/pkg/clusteragent/clusteragent_test.go index 25dfa69..ab0c298 100644 --- a/pkg/clusteragent/clusteragent_test.go +++ b/pkg/clusteragent/clusteragent_test.go @@ -4,27 +4,24 @@ import ( "context" "fmt" "math/rand" - "net" - "net/http" - "strings" "testing" - "time" + "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/images" . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/kubernetes/fake" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" - - "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/httptest" ) func TestClient(t *testing.T) { t.Run("CanNotFindAddress", func(t *testing.T) { g := NewWithT(t) - // Machines don't have any addresses. + // Machines don't have node refs machines := []clusterv1.Machine{{}, {}} - _, err := NewClient(machines, "25000", time.Second, Options{}) + _, err := NewClient(nil, newLogger(), machines, "25000", Options{}) g.Expect(err).To(HaveOccurred()) @@ -36,15 +33,13 @@ func TestClient(t *testing.T) { Name: ignoreName, }, Status: clusterv1.MachineStatus{ - Addresses: clusterv1.MachineAddresses{ - { - Address: "1.1.1.1", - }, + NodeRef: &corev1.ObjectReference{ + Name: "node", }, }, }, } - _, err = NewClient(machines, "25000", time.Second, Options{IgnoreMachineNames: sets.NewString(ignoreName)}) + _, err = NewClient(nil, nil, machines, "25000", Options{IgnoreMachineNames: sets.NewString(ignoreName)}) g.Expect(err).To(HaveOccurred()) }) @@ -53,37 +48,45 @@ func TestClient(t *testing.T) { g := NewWithT(t) port := "30000" - firstAddr := "1.1.1.1" - secondAddr := "2.2.2.2" - thirdAddr := "3.3.3.3" + firstNodeName := "node-1" + secondNodeName := "node-2" + thirdNodeName := "node-3" + + firstAddr := "1.2.3.4" + secondAddr := "2.3.4.5" + thirdAddr := "3.4.5.6" ignoreName := "ignore" - ignoreAddr := "8.8.8.8" + ignoreAddr := "9.8.7.6" + ignoreNodeName := "node-ignore" machines := []clusterv1.Machine{ { Status: clusterv1.MachineStatus{ + NodeRef: &corev1.ObjectReference{ + Name: firstNodeName, + }, Addresses: clusterv1.MachineAddresses{ - { - Address: firstAddr, - }, + {Address: firstAddr}, }, }, }, { Status: clusterv1.MachineStatus{ + NodeRef: &corev1.ObjectReference{ + Name: secondNodeName, + }, Addresses: clusterv1.MachineAddresses{ - { - Address: secondAddr, - }, + {Address: secondAddr}, }, }, }, { Status: clusterv1.MachineStatus{ + NodeRef: &corev1.ObjectReference{ + Name: thirdNodeName, + }, Addresses: clusterv1.MachineAddresses{ - { - Address: thirdAddr, - }, + {Address: thirdAddr}, }, }, }, @@ -92,10 +95,11 @@ func TestClient(t *testing.T) { Name: ignoreName, }, Status: clusterv1.MachineStatus{ + NodeRef: &corev1.ObjectReference{ + Name: ignoreNodeName, + }, Addresses: clusterv1.MachineAddresses{ - { - Address: ignoreAddr, - }, + {Address: ignoreAddr}, }, }, }, @@ -105,54 +109,73 @@ func TestClient(t *testing.T) { IgnoreMachineNames: sets.NewString(ignoreName), } - // NOTE(Hue): Repeat the test to make sure the ignored machine's IP is not picked by chance (reduce flakiness). + // NOTE(Hue): Repeat the test to make sure the ignored machine's node name is not picked by chance (reduce flakiness). for i := 0; i < 100; i++ { machines = shuffleMachines(machines) - c, err := NewClient(machines, port, time.Second, opts) + c, err := NewClient(nil, nil, machines, port, opts) g.Expect(err).ToNot(HaveOccurred()) // Check if the endpoint is one of the expected ones and not the ignored one. - g.Expect([]string{fmt.Sprintf("https://%s:%s", firstAddr, port), fmt.Sprintf("https://%s:%s", secondAddr, port), fmt.Sprintf("https://%s:%s", thirdAddr, port)}).To(ContainElement(c.Endpoint())) - g.Expect(c.Endpoint()).ToNot(Equal(fmt.Sprintf("https://%s:%s", ignoreAddr, port))) + g.Expect([]string{firstNodeName, secondNodeName, thirdNodeName}).To(ContainElement(c.nodeName)) + g.Expect([]string{firstAddr, secondAddr, thirdAddr}).To(ContainElement(c.ip)) + g.Expect(c.nodeName).ToNot(Equal(ignoreNodeName)) } - }) } func TestDo(t *testing.T) { g := NewWithT(t) - path := "/random/path" - method := http.MethodPost - resp := map[string]string{ - "key": "value", + kubeclient := fake.NewSimpleClientset() + nodeName := "node" + nodeAddress := "5.6.7.8" + port := "1234" + method := "POST" + endpoint := "my/endpoint" + dataKey, dataValue := "dkey", "dvalue" + data := map[string]any{ + dataKey: dataValue, + } + headerKey, headerValue := "hkey", "hvalue" + header := map[string][]string{ + headerKey: {headerValue}, } - servM := httptest.NewServerMock(method, path, resp) - defer servM.Srv.Close() - ip, port, err := net.SplitHostPort(strings.TrimPrefix(servM.Srv.URL, "https://")) - g.Expect(err).ToNot(HaveOccurred()) - c, err := NewClient([]clusterv1.Machine{ + c, err := NewClient(kubeclient, newLogger(), []clusterv1.Machine{ { Status: clusterv1.MachineStatus{ + NodeRef: &corev1.ObjectReference{ + Name: nodeName, + }, Addresses: clusterv1.MachineAddresses{ { - Address: ip, + Address: nodeAddress, }, }, }, }, - }, port, time.Second, Options{}) + }, port, Options{SkipSucceededCheck: true, SkipPodCleanup: true}) g.Expect(err).ToNot(HaveOccurred()) - response := make(map[string]string) - req := map[string]string{"req": "value"} - path = strings.TrimPrefix(path, "/") - g.Expect(c.do(context.Background(), method, path, req, nil, &response)).To(Succeed()) + g.Expect(c.do(context.Background(), method, endpoint, header, data)).To(Succeed()) - g.Expect(response).To(Equal(resp)) + pod, err := kubeclient.CoreV1().Pods(DefaultPodNameSpace).Get(context.Background(), fmt.Sprintf(CallerPodNameFormat, nodeName), v1.GetOptions{}) + g.Expect(err).ToNot(HaveOccurred()) + + g.Expect(pod.Spec.NodeName).To(Equal(nodeName)) + g.Expect(pod.Spec.Containers).To(HaveLen(1)) + + container := pod.Spec.Containers[0] + g.Expect(container.Image).To(Equal(images.CurlImage)) + g.Expect(*container.SecurityContext.Privileged).To(BeTrue()) + g.Expect(*container.SecurityContext.RunAsUser).To(Equal(int64(0))) + g.Expect(container.Command).To(HaveLen(3)) + g.Expect(container.Command[2]).To(Equal(fmt.Sprintf( + "curl -k -X %s -H \"%s: %s\" -d '{\"%s\":\"%s\"}' https://%s:%s/%s", + method, headerKey, headerValue, dataKey, dataValue, nodeAddress, port, endpoint, + ))) } func shuffleMachines(src []clusterv1.Machine) []clusterv1.Machine { @@ -163,3 +186,19 @@ func shuffleMachines(src []clusterv1.Machine) []clusterv1.Machine { } return dest } + +func newLogger() Logger { + return &mockLogger{} +} + +type mockLogger struct { + entries []string +} + +func (l *mockLogger) Info(msg string, keysAndValues ...interface{}) { + l.entries = append(l.entries, msg) +} + +func (l *mockLogger) Error(err error, msg string, keysAndValues ...interface{}) { + l.entries = append(l.entries, msg) +} diff --git a/pkg/clusteragent/remove_node.go b/pkg/clusteragent/remove_node.go index 59ca691..a7f3b47 100644 --- a/pkg/clusteragent/remove_node.go +++ b/pkg/clusteragent/remove_node.go @@ -6,11 +6,11 @@ import ( ) // RemoveNodeFromDqlite calls the /v2/dqlite/remove endpoint on cluster agent to remove the given address from Dqlite. -// The endpoint should be in the format of "address:port". +// The removeEp should be in the format of "address:port". func (p *Client) RemoveNodeFromDqlite(ctx context.Context, token string, removeEp string) error { - request := map[string]string{"remove_endpoint": removeEp} + request := map[string]any{"remove_endpoint": removeEp} header := map[string][]string{ AuthTokenHeader: {token}, } - return p.do(ctx, http.MethodPost, "cluster/api/v2.0/dqlite/remove", request, header, nil) + return p.do(ctx, http.MethodPost, "cluster/api/v2.0/dqlite/remove", header, request) } diff --git a/pkg/clusteragent/remove_node_test.go b/pkg/clusteragent/remove_node_test.go index ef79633..053712e 100644 --- a/pkg/clusteragent/remove_node_test.go +++ b/pkg/clusteragent/remove_node_test.go @@ -2,44 +2,66 @@ package clusteragent_test import ( "context" - "net" - "net/http" - "strings" + "fmt" "testing" - "time" . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/clusteragent" - "github.com/canonical/cluster-api-control-plane-provider-microk8s/pkg/httptest" ) func TestRemoveFromDqlite(t *testing.T) { g := NewWithT(t) - path := "/cluster/api/v2.0/dqlite/remove" - token := "myRandomToken" - method := http.MethodPost - servM := httptest.NewServerMock(method, path, nil) - defer servM.Srv.Close() - - ip, port, err := net.SplitHostPort(strings.TrimPrefix(servM.Srv.URL, "https://")) - g.Expect(err).ToNot(HaveOccurred()) - c, err := clusteragent.NewClient([]clusterv1.Machine{ + token := "token" + port := "1234" + removeEp := "1.1.1.1:9876" + machineIP := "5.6.7.8" + nodeName := "node-1" + kubeclient := fake.NewSimpleClientset() + c, err := clusteragent.NewClient(kubeclient, newLogger(), []clusterv1.Machine{ { Status: clusterv1.MachineStatus{ + NodeRef: &corev1.ObjectReference{ + Name: nodeName, + }, Addresses: clusterv1.MachineAddresses{ { - Address: ip, + Address: machineIP, }, }, }, }, - }, port, time.Second, clusteragent.Options{}) + }, port, clusteragent.Options{SkipSucceededCheck: true, SkipPodCleanup: true}) g.Expect(err).ToNot(HaveOccurred()) - g.Expect(c.RemoveNodeFromDqlite(context.Background(), token, "1.1.1.1:1234")).To(Succeed()) - g.Expect(servM.Request).To(HaveKeyWithValue("remove_endpoint", "1.1.1.1:1234")) - g.Expect(servM.Header.Get(clusteragent.AuthTokenHeader)).To(Equal(token)) + g.Expect(c.RemoveNodeFromDqlite(context.Background(), token, removeEp)).To(Succeed()) + + pod, err := kubeclient.CoreV1().Pods(clusteragent.DefaultPodNameSpace).Get(context.Background(), fmt.Sprintf(clusteragent.CallerPodNameFormat, nodeName), v1.GetOptions{}) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(pod.Spec.Containers).To(HaveLen(1)) + + container := pod.Spec.Containers[0] + g.Expect(container.Command).To(HaveLen(3)) + g.Expect(container.Command[2]).To(Equal(fmt.Sprintf("curl -k -X POST -H \"capi-auth-token: %s\" -d '{\"remove_endpoint\":\"%s\"}' https://%s:%s/cluster/api/v2.0/dqlite/remove", token, removeEp, machineIP, port))) +} + +func newLogger() clusteragent.Logger { + return &mockLogger{} +} + +type mockLogger struct { + entries []string +} + +func (l *mockLogger) Info(msg string, keysAndValues ...interface{}) { + l.entries = append(l.entries, msg) +} + +func (l *mockLogger) Error(err error, msg string, keysAndValues ...interface{}) { + l.entries = append(l.entries, msg) } diff --git a/pkg/control/wait.go b/pkg/control/wait.go new file mode 100644 index 0000000..c37a3b1 --- /dev/null +++ b/pkg/control/wait.go @@ -0,0 +1,50 @@ +package control + +import ( + "context" + "fmt" + "time" +) + +const ( + defaultWaitInterval = 1 * time.Second + defaultNumRetries = 0 // 0 means infinite retries +) + +type WaitOptions struct { + NumRetries *int + WaitInterval *time.Duration +} + +// WaitUntilReady waits until the specified condition becomes true. +// checkFunc can return an error to return early. +func WaitUntilReady(ctx context.Context, checkFunc func() (bool, error), opts ...WaitOptions) error { + var opt WaitOptions + if len(opts) > 0 { + opt = opts[0] + } + + waitInterval := defaultWaitInterval + if opt.WaitInterval != nil { + waitInterval = *opt.WaitInterval + } + + numRetries := defaultNumRetries + if opt.NumRetries != nil { + numRetries = *opt.NumRetries + } + + for i := 0; i < numRetries || numRetries == 0; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitInterval): + if ok, err := checkFunc(); err != nil { + return fmt.Errorf("wait check failed: %w", err) + } else if ok { + return nil + } + } + } + return fmt.Errorf("check was not successful after %d attempts", numRetries) +} diff --git a/pkg/control/wait_test.go b/pkg/control/wait_test.go new file mode 100644 index 0000000..2b18758 --- /dev/null +++ b/pkg/control/wait_test.go @@ -0,0 +1,58 @@ +package control + +import ( + "context" + "errors" + "testing" + "time" +) + +// Mock check function that returns true after 2 iterations. +func mockCheckFunc() (bool, error) { + return true, nil +} + +var errTest = errors.New("test error") + +// Mock check function that returns an error. +func mockErrorCheckFunc() (bool, error) { + return false, errTest +} + +func TestWaitUntilReady(t *testing.T) { + // Test case 1: Successful completion + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second*5) + defer cancel1() + + err1 := WaitUntilReady(ctx1, mockCheckFunc) + if err1 != nil { + t.Errorf("Expected no error, got: %v", err1) + } + + // Test case 2: Context cancellation + ctx2, cancel2 := context.WithCancel(context.Background()) + cancel2() // Cancel the context immediately + + err2 := WaitUntilReady(ctx2, mockCheckFunc) + if err2 == nil || !errors.Is(err2, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err2) + } + + // Test case 3: Timeout + ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second*2) + defer cancel3() + + err3 := WaitUntilReady(ctx3, func() (bool, error) { return false, nil }) + if err3 == nil || !errors.Is(err3, context.DeadlineExceeded) { + t.Errorf("Expected context.DeadlineExceeded error, got: %v", err3) + } + + // Test case 4: CheckFunc returns an error + ctx4, cancel4 := context.WithTimeout(context.Background(), time.Second*5) + defer cancel4() + + err4 := WaitUntilReady(ctx4, mockErrorCheckFunc) + if err4 == nil || !errors.Is(err4, errTest) { + t.Errorf("Expected test error, got: %v", err4) + } +} diff --git a/pkg/httptest/httptest.go b/pkg/httptest/httptest.go deleted file mode 100644 index 055bfe3..0000000 --- a/pkg/httptest/httptest.go +++ /dev/null @@ -1,62 +0,0 @@ -package httptest - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" -) - -type serverMock struct { - Method string - Path string - Response any - Request map[string]any - Header http.Header - Srv *httptest.Server -} - -// NewServerMock creates a test server that responds with the given response when called with the given method and path. -// Make sure to close the server after the test is done. -// Server will try to decode the request body into a map[string]any. -func NewServerMock(method string, path string, response any) *serverMock { - req := make(map[string]any) - header := make(map[string][]string) - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != path { - http.NotFound(w, r) - return - } - if r.Method != method { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - for k, vv := range map[string][]string(r.Header) { - header[k] = vv - } - - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - if response != nil { - if err := json.NewEncoder(w).Encode(response); err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - } - w.WriteHeader(http.StatusOK) - })) - - fmt.Println("header:", header) - return &serverMock{ - Method: method, - Path: path, - Response: response, - Header: header, - Request: req, - Srv: ts, - } -} diff --git a/pkg/images/images.go b/pkg/images/images.go new file mode 100644 index 0000000..a050342 --- /dev/null +++ b/pkg/images/images.go @@ -0,0 +1,5 @@ +package images + +const ( + CurlImage string = "curlimages/curl:7.87.0" +)