From acf0d3833aa01405ce6084f8bb5b2c099cd3dc5b Mon Sep 17 00:00:00 2001 From: Mike Johanson Date: Thu, 22 Aug 2024 11:18:38 -0700 Subject: [PATCH] feat: enable cert pinning for redirection capabilities --- pkg/wsman/client/wsman.go | 6 +++++- pkg/wsman/client/wsman_tcp.go | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pkg/wsman/client/wsman.go b/pkg/wsman/client/wsman.go index ac3e3df8..b2d30265 100644 --- a/pkg/wsman/client/wsman.go +++ b/pkg/wsman/client/wsman.go @@ -65,6 +65,7 @@ type Target struct { bufferPool sync.Pool UseTLS bool InsecureSkipVerify bool + PinnedCert string } const timeout = 10 * time.Second @@ -127,7 +128,6 @@ func NewWsman(cp Parameters) *Target { DisableKeepAlives: true, TLSClientConfig: config, } - } else { res.Transport = cp.Transport } @@ -162,20 +162,24 @@ func (t *Target) GetServerCertificate() (*tls.Certificate, error) { if err != nil { return err } + *capturedCert = tls.Certificate{ Certificate: [][]byte{cert.Raw}, } } + return nil } // Perform a connection to trigger the TLS handshake nohttps := strings.Replace(t.endpoint, "https://", "", 1) nohttps = strings.Replace(nohttps, "/wsman", "", 1) + conn, err := tls.Dial("tcp", nohttps, tlsConfig) if err != nil { return nil, err } + defer conn.Close() if len(capturedCert.Certificate) == 0 { diff --git a/pkg/wsman/client/wsman_tcp.go b/pkg/wsman/client/wsman_tcp.go index 021573ad..6c0b3966 100644 --- a/pkg/wsman/client/wsman_tcp.go +++ b/pkg/wsman/client/wsman_tcp.go @@ -1,7 +1,10 @@ package client import ( + "crypto/sha256" "crypto/tls" + "crypto/x509" + "encoding/hex" "fmt" "net" "sync" @@ -22,6 +25,7 @@ func NewWsmanTCP(cp Parameters) *Target { challenge: &AuthChallenge{}, UseTLS: cp.UseTLS, InsecureSkipVerify: cp.SelfSignedAllowed, + PinnedCert: cp.PinnedCert, bufferPool: sync.Pool{ New: func() interface{} { return make([]byte, 4096) // Adjust size according to your needs. @@ -33,10 +37,35 @@ func NewWsmanTCP(cp Parameters) *Target { // Connect establishes a TCP connection to the endpoint specified in the Target struct. func (t *Target) Connect() error { var err error + if t.UseTLS { - t.conn, err = tls.Dial("tcp", t.endpoint, &tls.Config{ - InsecureSkipVerify: t.InsecureSkipVerify, - }) + // check if pinnedCert is not null and not empty + var config *tls.Config + if len(t.PinnedCert) > 0 { + config = &tls.Config{ + InsecureSkipVerify: t.InsecureSkipVerify, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + for _, rawCert := range rawCerts { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return err + } + + // Compare the current certificate with the pinned certificate + sha256Fingerprint := sha256.Sum256(cert.Raw) + if hex.EncodeToString(sha256Fingerprint[:]) == t.PinnedCert { + return nil // Success: The certificate matches the pinned certificate + } + } + + return fmt.Errorf("certificate pinning failed") + }, + } + } else { + config = &tls.Config{InsecureSkipVerify: t.InsecureSkipVerify} + } + + t.conn, err = tls.Dial("tcp", t.endpoint, config) } else { t.conn, err = net.Dial("tcp", t.endpoint) }