diff --git a/pkg/wsman/client/wsman.go b/pkg/wsman/client/wsman.go index 8b9ade9b..ff16e6f6 100644 --- a/pkg/wsman/client/wsman.go +++ b/pkg/wsman/client/wsman.go @@ -65,6 +65,7 @@ type Target struct { bufferPool sync.Pool UseTLS bool InsecureSkipVerify bool + PinnedCert string } const timeout = 10 * time.Second @@ -174,7 +175,6 @@ func (t *Target) GetServerCertificate() (*tls.Certificate, error) { nohttps := strings.Replace(t.endpoint, "https://", "", 1) nohttps = strings.Replace(nohttps, "/wsman", "", 1) - // Perform a connection to trigger the TLS handshake conn, err := tls.Dial("tcp", nohttps, tlsConfig) if err != nil { return nil, err diff --git a/pkg/wsman/client/wsman_tcp.go b/pkg/wsman/client/wsman_tcp.go index 021573ad..6c0b3966 100644 --- a/pkg/wsman/client/wsman_tcp.go +++ b/pkg/wsman/client/wsman_tcp.go @@ -1,7 +1,10 @@ package client import ( + "crypto/sha256" "crypto/tls" + "crypto/x509" + "encoding/hex" "fmt" "net" "sync" @@ -22,6 +25,7 @@ func NewWsmanTCP(cp Parameters) *Target { challenge: &AuthChallenge{}, UseTLS: cp.UseTLS, InsecureSkipVerify: cp.SelfSignedAllowed, + PinnedCert: cp.PinnedCert, bufferPool: sync.Pool{ New: func() interface{} { return make([]byte, 4096) // Adjust size according to your needs. @@ -33,10 +37,35 @@ func NewWsmanTCP(cp Parameters) *Target { // Connect establishes a TCP connection to the endpoint specified in the Target struct. func (t *Target) Connect() error { var err error + if t.UseTLS { - t.conn, err = tls.Dial("tcp", t.endpoint, &tls.Config{ - InsecureSkipVerify: t.InsecureSkipVerify, - }) + // check if pinnedCert is not null and not empty + var config *tls.Config + if len(t.PinnedCert) > 0 { + config = &tls.Config{ + InsecureSkipVerify: t.InsecureSkipVerify, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + for _, rawCert := range rawCerts { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return err + } + + // Compare the current certificate with the pinned certificate + sha256Fingerprint := sha256.Sum256(cert.Raw) + if hex.EncodeToString(sha256Fingerprint[:]) == t.PinnedCert { + return nil // Success: The certificate matches the pinned certificate + } + } + + return fmt.Errorf("certificate pinning failed") + }, + } + } else { + config = &tls.Config{InsecureSkipVerify: t.InsecureSkipVerify} + } + + t.conn, err = tls.Dial("tcp", t.endpoint, config) } else { t.conn, err = net.Dial("tcp", t.endpoint) }