diff --git a/CHANGELOG.md b/CHANGELOG.md index 67b90c58..91f67cd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Changelog for Cass Operator, new PRs should update the `main / unreleased` secti * [FEATURE] [#651](https://github.com/k8ssandra/cass-operator/issues/651) Add tsreload task for DSE deployments and ability to check if sync operation is available on the mgmt-api side * [ENHANCEMENT] [#722](https://github.com/k8ssandra/cass-operator/issues/722) Add back the ability to track cleanup task before marking scale up as done. This is controlled by an annotation cassandra.datastax.com/track-cleanup-tasks +* [ENHANCEMENT] [#729](https://github.com/k8ssandra/cass-operator/issues/729) Modify NewMgmtClient to support additional transport option for the http.Client * [BUGFIX] [#705](https://github.com/k8ssandra/cass-operator/issues/705) Ensure ConfigSecret has annotations map before trying to set a value ## v1.22.4 diff --git a/internal/controllers/control/cassandratask_controller.go b/internal/controllers/control/cassandratask_controller.go index 86a7a744..932c231b 100644 --- a/internal/controllers/control/cassandratask_controller.go +++ b/internal/controllers/control/cassandratask_controller.go @@ -355,7 +355,7 @@ JobDefinition: if err := r.replacePreProcess(taskConfig); err != nil { return ctrl.Result{}, err } - nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc) + nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc, nil) if err != nil { return ctrl.Result{}, err } @@ -634,7 +634,7 @@ func (r *CassandraTaskReconciler) reconcileEveryPodTask(ctx context.Context, dc return dcPods[i].Name < dcPods[j].Name }) - nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc) + nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc, nil) if err != nil { return ctrl.Result{}, 0, 0, "", err } diff --git a/pkg/httphelper/client.go b/pkg/httphelper/client.go index 34475f10..4457e4a6 100644 --- a/pkg/httphelper/client.go +++ b/pkg/httphelper/client.go @@ -166,10 +166,10 @@ func (f *FeatureSet) Supports(feature Feature) bool { return found } -func NewMgmtClient(ctx context.Context, client client.Client, dc *cassdcapi.CassandraDatacenter) (NodeMgmtClient, error) { +func NewMgmtClient(ctx context.Context, client client.Client, dc *cassdcapi.CassandraDatacenter, customTransport *http.Transport) (NodeMgmtClient, error) { logger := log.FromContext(ctx) - httpClient, err := BuildManagementApiHttpClient(dc, client, ctx) + httpClient, err := BuildManagementApiHttpClient(ctx, client, dc, customTransport) if err != nil { logger.Error(err, "error in BuildManagementApiHttpClient") return NodeMgmtClient{}, err diff --git a/pkg/httphelper/client_test.go b/pkg/httphelper/client_test.go index d2921ecf..1edf0ab4 100644 --- a/pkg/httphelper/client_test.go +++ b/pkg/httphelper/client_test.go @@ -5,10 +5,13 @@ package httphelper import ( "bytes" + "context" "encoding/json" "errors" "io" + "net" "net/http" + "net/http/httptest" "testing" "github.com/go-logr/logr" @@ -556,3 +559,44 @@ var badPod = &corev1.Pod{ Name: "pod1", }, } + +func TestCustomTransport(t *testing.T) { + require := require.New(t) + + called := false + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v0/ops/auth/role" { + w.WriteHeader(http.StatusOK) + called = true + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer testServer.Close() + + testServerAddr := testServer.Listener.Addr().String() + + customTransport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, testServerAddr) + }, + } + + dc := &api.CassandraDatacenter{ + Spec: api.CassandraDatacenterSpec{ + ClusterName: "test-cluster", + }, + } + + mockClient := mocks.NewClient(t) + + mgmtClient, err := NewMgmtClient(context.TODO(), mockClient, dc, customTransport) + mgmtClient.Log = logr.Discard() + require.NoError(err) + + // This should call http://1.2.3.4:8080/api/v0/ops/auth/role, but the custom transport will override the address + err = mgmtClient.CallCreateRoleEndpoint(goodPod, "role1", "password1", true) + require.NoError(err) + require.True(called) + +} diff --git a/pkg/httphelper/security.go b/pkg/httphelper/security.go index eb42ab2c..d204025c 100644 --- a/pkg/httphelper/security.go +++ b/pkg/httphelper/security.go @@ -41,12 +41,12 @@ func GetManagementApiProtocol(dc *api.CassandraDatacenter) (string, error) { return provider.GetProtocol(), nil } -func BuildManagementApiHttpClient(dc *api.CassandraDatacenter, client client.Client, ctx context.Context) (HttpClient, error) { +func BuildManagementApiHttpClient(ctx context.Context, client client.Client, dc *api.CassandraDatacenter, customTransport *http.Transport) (HttpClient, error) { provider, err := BuildManagementApiSecurityProvider(dc) if err != nil { return nil, err } - return provider.BuildHttpClient(client, ctx) + return provider.BuildHttpClient(ctx, client, customTransport) } func AddManagementApiServerSecurity(dc *api.CassandraDatacenter, pod *corev1.PodTemplateSpec) error { @@ -91,17 +91,17 @@ func ValidateManagementApiConfig(dc *api.CassandraDatacenter, client client.Clie return []error{err} } - return provider.ValidateConfig(client, ctx) + return provider.ValidateConfig(ctx, client) } // SPI for adding new mechanisms for securing the management API type ManagementApiSecurityProvider interface { - BuildHttpClient(client client.Client, ctx context.Context) (HttpClient, error) + BuildHttpClient(ctx context.Context, client client.Client, transport *http.Transport) (HttpClient, error) BuildMgmtApiGetAction(endpoint string, timeout int) *corev1.ExecAction BuildMgmtApiPostAction(endpoint string, timeout int) *corev1.ExecAction AddServerSecurity(pod *corev1.PodTemplateSpec) error GetProtocol() string - ValidateConfig(client client.Client, ctx context.Context) []error + ValidateConfig(ctx context.Context, client client.Client) []error } type InsecureManagementApiSecurityProvider struct { @@ -119,15 +119,21 @@ func (provider *InsecureManagementApiSecurityProvider) GetProtocol() string { return "http" } -func (provider *InsecureManagementApiSecurityProvider) BuildHttpClient(client client.Client, ctx context.Context) (HttpClient, error) { - return http.DefaultClient, nil +func (provider *InsecureManagementApiSecurityProvider) BuildHttpClient(ctx context.Context, client client.Client, transport *http.Transport) (HttpClient, error) { + c := http.DefaultClient + + if transport != nil { + c.Transport = transport + } + + return c, nil } func (provider *InsecureManagementApiSecurityProvider) AddServerSecurity(pod *corev1.PodTemplateSpec) error { return nil } -func (provider *InsecureManagementApiSecurityProvider) ValidateConfig(client client.Client, ctx context.Context) []error { +func (provider *InsecureManagementApiSecurityProvider) ValidateConfig(ctx context.Context, client client.Client) []error { return []error{} } @@ -634,7 +640,7 @@ func validateSecret(secret *corev1.Secret) []error { return validationErrors } -func (provider *ManualManagementApiSecurityProvider) ValidateConfig(client client.Client, ctx context.Context) []error { +func (provider *ManualManagementApiSecurityProvider) ValidateConfig(ctx context.Context, client client.Client) []error { var validationErrors []error if provider.Config.SkipSecretValidation { @@ -715,7 +721,12 @@ func (provider *ManualManagementApiSecurityProvider) ValidateConfig(client clien return validationErrors } -func (provider *ManualManagementApiSecurityProvider) BuildHttpClient(client client.Client, ctx context.Context) (HttpClient, error) { +func (provider *ManualManagementApiSecurityProvider) BuildHttpClient(ctx context.Context, client client.Client, transport *http.Transport) (HttpClient, error) { + httpClient := &http.Client{Transport: transport} + if transport != nil && transport.TLSClientConfig != nil { + return httpClient, nil + } + // Get the client Secret secretNamespacedName := types.NamespacedName{ Name: provider.Config.ClientSecretName, @@ -762,8 +773,14 @@ func (provider *ManualManagementApiSecurityProvider) BuildHttpClient(client clie InsecureSkipVerify: true, VerifyPeerCertificate: buildVerifyPeerCertificateNoHostCheck(caCertPool), } - transport := &http.Transport{TLSClientConfig: tlsConfig} - httpClient := &http.Client{Transport: transport} + + if transport != nil && transport.TLSClientConfig == nil { + transport.TLSClientConfig = tlsConfig + } else if transport == nil { + transport = &http.Transport{TLSClientConfig: tlsConfig} + } + + httpClient.Transport = transport return httpClient, nil } diff --git a/pkg/httphelper/security_test.go b/pkg/httphelper/security_test.go index fcff1741..e780b8d5 100644 --- a/pkg/httphelper/security_test.go +++ b/pkg/httphelper/security_test.go @@ -4,13 +4,22 @@ package httphelper import ( + "context" "crypto/x509" "encoding/pem" + "net/http" "os" "path/filepath" "testing" + api "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/client-go/kubernetes/scheme" + + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer" + "sigs.k8s.io/controller-runtime/pkg/client/fake" ) func helperLoadBytes(t *testing.T, name string) []byte { @@ -98,3 +107,53 @@ func Test_validatePrivateKey(t *testing.T) { t, 1, len(errs), "Should consider an empty key as an invalid key") } + +// Create Datacenter with managementAuth set to manual and TLS enabled, test that the client is created with the correct TLS configuration using +// BuildManagementApiHttpClient method +func TestBuildMTLSClient(t *testing.T) { + require := require.New(t) + require.NoError(api.AddToScheme(scheme.Scheme)) + decode := serializer.NewCodecFactory(scheme.Scheme).UniversalDeserializer().Decode + + loadYaml := func(path string) (runtime.Object, error) { + bytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + obj, _, err := decode(bytes, nil, nil) + return obj, err + } + + clientSecret, err := loadYaml(filepath.Join("..", "..", "tests", "testdata", "mtls-certs-client.yaml")) + require.NoError(err) + + serverSecret, err := loadYaml(filepath.Join("..", "..", "tests", "testdata", "mtls-certs-server.yaml")) + require.NoError(err) + + dc := &api.CassandraDatacenter{ + Spec: api.CassandraDatacenterSpec{ + ClusterName: "test-cluster", + ManagementApiAuth: api.ManagementApiAuthConfig{ + Manual: &api.ManagementApiAuthManualConfig{ + ClientSecretName: "mgmt-api-client-credentials", + ServerSecretName: "mgmt-api-server-credentials", + }, + }, + }, + } + + trackObjects := []runtime.Object{ + clientSecret, + serverSecret, + dc, + } + + client := fake.NewClientBuilder().WithRuntimeObjects(trackObjects...).Build() + ctx := context.TODO() + + httpClient, err := BuildManagementApiHttpClient(ctx, client, dc, nil) + require.NoError(err) + + tlsConfig := httpClient.(*http.Client).Transport.(*http.Transport).TLSClientConfig + require.NotNil(tlsConfig) +} diff --git a/pkg/reconciliation/context.go b/pkg/reconciliation/context.go index 1069103b..04017ef5 100644 --- a/pkg/reconciliation/context.go +++ b/pkg/reconciliation/context.go @@ -98,7 +98,7 @@ func CreateReconciliationContext( log.IntoContext(ctx, rc.ReqLogger) var err error - rc.NodeMgmtClient, err = httphelper.NewMgmtClient(rc.Ctx, cli, dc) + rc.NodeMgmtClient, err = httphelper.NewMgmtClient(rc.Ctx, cli, dc, nil) if err != nil { rc.ReqLogger.Error(err, "failed to build NodeMgmtClient") return nil, err