Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support override of the http.Transport in the http.Client - CASS-78 #730

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Changelog for Cass Operator, new PRs should update the `main / unreleased` secti

* [FEATURE] [#651](https://github.com/k8ssandra/cass-operator/issues/651) Add tsreload task for DSE deployments and ability to check if sync operation is available on the mgmt-api side
* [ENHANCEMENT] [#722](https://github.com/k8ssandra/cass-operator/issues/722) Add back the ability to track cleanup task before marking scale up as done. This is controlled by an annotation cassandra.datastax.com/track-cleanup-tasks
* [ENHANCEMENT] [#729](https://github.com/k8ssandra/cass-operator/issues/729) Modify NewMgmtClient to support additional transport option for the http.Client
* [BUGFIX] [#705](https://github.com/k8ssandra/cass-operator/issues/705) Ensure ConfigSecret has annotations map before trying to set a value

## v1.22.4
Expand Down
4 changes: 2 additions & 2 deletions internal/controllers/control/cassandratask_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ JobDefinition:
if err := r.replacePreProcess(taskConfig); err != nil {
return ctrl.Result{}, err
}
nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc)
nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc, nil)
if err != nil {
return ctrl.Result{}, err
}
Expand Down Expand Up @@ -634,7 +634,7 @@ func (r *CassandraTaskReconciler) reconcileEveryPodTask(ctx context.Context, dc
return dcPods[i].Name < dcPods[j].Name
})

nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc)
nodeMgmtClient, err := httphelper.NewMgmtClient(ctx, r.Client, dc, nil)
if err != nil {
return ctrl.Result{}, 0, 0, "", err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/httphelper/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ func (f *FeatureSet) Supports(feature Feature) bool {
return found
}

func NewMgmtClient(ctx context.Context, client client.Client, dc *cassdcapi.CassandraDatacenter) (NodeMgmtClient, error) {
func NewMgmtClient(ctx context.Context, client client.Client, dc *cassdcapi.CassandraDatacenter, customTransport *http.Transport) (NodeMgmtClient, error) {
logger := log.FromContext(ctx)

httpClient, err := BuildManagementApiHttpClient(dc, client, ctx)
httpClient, err := BuildManagementApiHttpClient(ctx, client, dc, customTransport)
if err != nil {
logger.Error(err, "error in BuildManagementApiHttpClient")
return NodeMgmtClient{}, err
Expand Down
44 changes: 44 additions & 0 deletions pkg/httphelper/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ package httphelper

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -556,3 +559,44 @@ var badPod = &corev1.Pod{
Name: "pod1",
},
}

func TestCustomTransport(t *testing.T) {
require := require.New(t)

called := false
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v0/ops/auth/role" {
w.WriteHeader(http.StatusOK)
called = true
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
defer testServer.Close()

testServerAddr := testServer.Listener.Addr().String()

customTransport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, testServerAddr)
},
}

dc := &api.CassandraDatacenter{
Spec: api.CassandraDatacenterSpec{
ClusterName: "test-cluster",
},
}

mockClient := mocks.NewClient(t)

mgmtClient, err := NewMgmtClient(context.TODO(), mockClient, dc, customTransport)
mgmtClient.Log = logr.Discard()
require.NoError(err)

// This should call http://1.2.3.4:8080/api/v0/ops/auth/role, but the custom transport will override the address
err = mgmtClient.CallCreateRoleEndpoint(goodPod, "role1", "password1", true)
require.NoError(err)
require.True(called)

}
41 changes: 29 additions & 12 deletions pkg/httphelper/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func GetManagementApiProtocol(dc *api.CassandraDatacenter) (string, error) {
return provider.GetProtocol(), nil
}

func BuildManagementApiHttpClient(dc *api.CassandraDatacenter, client client.Client, ctx context.Context) (HttpClient, error) {
func BuildManagementApiHttpClient(ctx context.Context, client client.Client, dc *api.CassandraDatacenter, customTransport *http.Transport) (HttpClient, error) {
provider, err := BuildManagementApiSecurityProvider(dc)
if err != nil {
return nil, err
}
return provider.BuildHttpClient(client, ctx)
return provider.BuildHttpClient(ctx, client, customTransport)
}

func AddManagementApiServerSecurity(dc *api.CassandraDatacenter, pod *corev1.PodTemplateSpec) error {
Expand Down Expand Up @@ -91,17 +91,17 @@ func ValidateManagementApiConfig(dc *api.CassandraDatacenter, client client.Clie
return []error{err}
}

return provider.ValidateConfig(client, ctx)
return provider.ValidateConfig(ctx, client)
}

// SPI for adding new mechanisms for securing the management API
type ManagementApiSecurityProvider interface {
BuildHttpClient(client client.Client, ctx context.Context) (HttpClient, error)
BuildHttpClient(ctx context.Context, client client.Client, transport *http.Transport) (HttpClient, error)
BuildMgmtApiGetAction(endpoint string, timeout int) *corev1.ExecAction
BuildMgmtApiPostAction(endpoint string, timeout int) *corev1.ExecAction
AddServerSecurity(pod *corev1.PodTemplateSpec) error
GetProtocol() string
ValidateConfig(client client.Client, ctx context.Context) []error
ValidateConfig(ctx context.Context, client client.Client) []error
}

type InsecureManagementApiSecurityProvider struct {
Expand All @@ -119,15 +119,21 @@ func (provider *InsecureManagementApiSecurityProvider) GetProtocol() string {
return "http"
}

func (provider *InsecureManagementApiSecurityProvider) BuildHttpClient(client client.Client, ctx context.Context) (HttpClient, error) {
return http.DefaultClient, nil
func (provider *InsecureManagementApiSecurityProvider) BuildHttpClient(ctx context.Context, client client.Client, transport *http.Transport) (HttpClient, error) {
c := http.DefaultClient

if transport != nil {
c.Transport = transport
}

return c, nil
}

func (provider *InsecureManagementApiSecurityProvider) AddServerSecurity(pod *corev1.PodTemplateSpec) error {
return nil
}

func (provider *InsecureManagementApiSecurityProvider) ValidateConfig(client client.Client, ctx context.Context) []error {
func (provider *InsecureManagementApiSecurityProvider) ValidateConfig(ctx context.Context, client client.Client) []error {
return []error{}
}

Expand Down Expand Up @@ -634,7 +640,7 @@ func validateSecret(secret *corev1.Secret) []error {
return validationErrors
}

func (provider *ManualManagementApiSecurityProvider) ValidateConfig(client client.Client, ctx context.Context) []error {
func (provider *ManualManagementApiSecurityProvider) ValidateConfig(ctx context.Context, client client.Client) []error {
var validationErrors []error

if provider.Config.SkipSecretValidation {
Expand Down Expand Up @@ -715,7 +721,12 @@ func (provider *ManualManagementApiSecurityProvider) ValidateConfig(client clien
return validationErrors
}

func (provider *ManualManagementApiSecurityProvider) BuildHttpClient(client client.Client, ctx context.Context) (HttpClient, error) {
func (provider *ManualManagementApiSecurityProvider) BuildHttpClient(ctx context.Context, client client.Client, transport *http.Transport) (HttpClient, error) {
httpClient := &http.Client{Transport: transport}
if transport != nil && transport.TLSClientConfig != nil {
return httpClient, nil
}

// Get the client Secret
secretNamespacedName := types.NamespacedName{
Name: provider.Config.ClientSecretName,
Expand Down Expand Up @@ -762,8 +773,14 @@ func (provider *ManualManagementApiSecurityProvider) BuildHttpClient(client clie
InsecureSkipVerify: true,
VerifyPeerCertificate: buildVerifyPeerCertificateNoHostCheck(caCertPool),
}
transport := &http.Transport{TLSClientConfig: tlsConfig}
httpClient := &http.Client{Transport: transport}

if transport != nil && transport.TLSClientConfig == nil {
transport.TLSClientConfig = tlsConfig
} else if transport == nil {
transport = &http.Transport{TLSClientConfig: tlsConfig}
}

httpClient.Transport = transport

return httpClient, nil
}
Expand Down
59 changes: 59 additions & 0 deletions pkg/httphelper/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
package httphelper

import (
"context"
"crypto/x509"
"encoding/pem"
"net/http"
"os"
"path/filepath"
"testing"

api "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/client-go/kubernetes/scheme"

"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/serializer"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

func helperLoadBytes(t *testing.T, name string) []byte {
Expand Down Expand Up @@ -98,3 +107,53 @@ func Test_validatePrivateKey(t *testing.T) {
t, 1, len(errs),
"Should consider an empty key as an invalid key")
}

// Create Datacenter with managementAuth set to manual and TLS enabled, test that the client is created with the correct TLS configuration using
// BuildManagementApiHttpClient method
func TestBuildMTLSClient(t *testing.T) {
require := require.New(t)
require.NoError(api.AddToScheme(scheme.Scheme))
decode := serializer.NewCodecFactory(scheme.Scheme).UniversalDeserializer().Decode

loadYaml := func(path string) (runtime.Object, error) {
bytes, err := os.ReadFile(path)
if err != nil {
return nil, err
}
obj, _, err := decode(bytes, nil, nil)
return obj, err
}

clientSecret, err := loadYaml(filepath.Join("..", "..", "tests", "testdata", "mtls-certs-client.yaml"))
require.NoError(err)

serverSecret, err := loadYaml(filepath.Join("..", "..", "tests", "testdata", "mtls-certs-server.yaml"))
require.NoError(err)

dc := &api.CassandraDatacenter{
Spec: api.CassandraDatacenterSpec{
ClusterName: "test-cluster",
ManagementApiAuth: api.ManagementApiAuthConfig{
Manual: &api.ManagementApiAuthManualConfig{
ClientSecretName: "mgmt-api-client-credentials",
ServerSecretName: "mgmt-api-server-credentials",
},
},
},
}

trackObjects := []runtime.Object{
clientSecret,
serverSecret,
dc,
}

client := fake.NewClientBuilder().WithRuntimeObjects(trackObjects...).Build()
ctx := context.TODO()

httpClient, err := BuildManagementApiHttpClient(ctx, client, dc, nil)
require.NoError(err)

tlsConfig := httpClient.(*http.Client).Transport.(*http.Transport).TLSClientConfig
require.NotNil(tlsConfig)
}
2 changes: 1 addition & 1 deletion pkg/reconciliation/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func CreateReconciliationContext(
log.IntoContext(ctx, rc.ReqLogger)

var err error
rc.NodeMgmtClient, err = httphelper.NewMgmtClient(rc.Ctx, cli, dc)
rc.NodeMgmtClient, err = httphelper.NewMgmtClient(rc.Ctx, cli, dc, nil)
if err != nil {
rc.ReqLogger.Error(err, "failed to build NodeMgmtClient")
return nil, err
Expand Down
Loading