diff --git a/internal/message/base_test.go b/internal/message/base_test.go index 93b506ea..0c792dde 100644 --- a/internal/message/base_test.go +++ b/internal/message/base_test.go @@ -6,6 +6,7 @@ package message import ( + "crypto/tls" "errors" "fmt" "testing" @@ -30,12 +31,12 @@ func (c *MockClient) Post(msg string) ([]byte, error) { return response, c.Err } -func (c *MockClient) Send(data []byte) error { return nil } -func (c *MockClient) Receive() ([]byte, error) { return nil, nil } -func (c *MockClient) CloseConnection() error { return nil } -func (c *MockClient) Connect() error { return nil } -func (c *MockClient) IsAuthenticated() bool { return true } - +func (c *MockClient) Send(data []byte) error { return nil } +func (c *MockClient) Receive() ([]byte, error) { return nil, nil } +func (c *MockClient) CloseConnection() error { return nil } +func (c *MockClient) Connect() error { return nil } +func (c *MockClient) IsAuthenticated() bool { return true } +func (c *MockClient) GetServerCertificate() (*tls.Certificate, error) { return nil, nil } func TestBaseWithClient(t *testing.T) { mockWsmanMessageCreator := NewWSManMessageCreator("test-uri") mockClient := MockClient{} diff --git a/pkg/wsman/client/types.go b/pkg/wsman/client/types.go index 23287c55..a08ef33a 100644 --- a/pkg/wsman/client/types.go +++ b/pkg/wsman/client/types.go @@ -13,4 +13,5 @@ type Parameters struct { LogAMTMessages bool Transport http.RoundTripper IsRedirection bool + PinnedCert string } diff --git a/pkg/wsman/client/wsman.go b/pkg/wsman/client/wsman.go index 57a61a1c..ac3e3df8 100644 --- a/pkg/wsman/client/wsman.go +++ b/pkg/wsman/client/wsman.go @@ -7,13 +7,17 @@ package client import ( "bytes" + "crypto/sha256" "crypto/tls" + "crypto/x509" + "encoding/hex" "errors" "fmt" "io" "net" "net/http" "net/url" + "strings" "sync" "time" @@ -45,6 +49,7 @@ type WSMan interface { Receive() ([]byte, error) CloseConnection() error IsAuthenticated() bool + GetServerCertificate() (*tls.Certificate, error) } // Target is a thin wrapper around http.Target. @@ -90,12 +95,39 @@ func NewWsman(cp Parameters) *Target { res.Timeout = timeout if cp.Transport == nil { + // check if pinnedCert is not null and not empty + var config *tls.Config + if len(cp.PinnedCert) > 0 { + config = &tls.Config{ + InsecureSkipVerify: cp.SelfSignedAllowed, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + for _, rawCert := range rawCerts { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return err + } + + // Compare the current certificate with the pinned certificate + sha256Fingerprint := sha256.Sum256(cert.Raw) + if hex.EncodeToString(sha256Fingerprint[:]) == cp.PinnedCert { + return nil // Success: The certificate matches the pinned certificate + } + } + + return fmt.Errorf("certificate pinning failed") + }, + } + } else { + config = &tls.Config{InsecureSkipVerify: cp.SelfSignedAllowed} + } + res.Transport = &http.Transport{ MaxIdleConns: 10, IdleConnTimeout: 30 * time.Second, DisableKeepAlives: true, - TLSClientConfig: &tls.Config{InsecureSkipVerify: cp.SelfSignedAllowed}, + TLSClientConfig: config, } + } else { res.Transport = cp.Transport } @@ -111,6 +143,48 @@ func (t *Target) IsAuthenticated() bool { return t.challenge != nil && t.challenge.Realm != "" } +func (t *Target) GetServerCertificate() (*tls.Certificate, error) { + httpTransport, ok := t.Transport.(*http.Transport) + if !ok { + return nil, errors.New("transport does not support TLSClientConfig") + } + + tlsConfig := httpTransport.TLSClientConfig + if tlsConfig == nil { + return nil, errors.New("TLSClientConfig is nil") + } + + // Create a custom DialTLS to capture the server certificate + capturedCert := &tls.Certificate{} + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + if len(rawCerts) > 0 { + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return err + } + *capturedCert = tls.Certificate{ + Certificate: [][]byte{cert.Raw}, + } + } + return nil + } + + // Perform a connection to trigger the TLS handshake + nohttps := strings.Replace(t.endpoint, "https://", "", 1) + nohttps = strings.Replace(nohttps, "/wsman", "", 1) + conn, err := tls.Dial("tcp", nohttps, tlsConfig) + if err != nil { + return nil, err + } + defer conn.Close() + + if len(capturedCert.Certificate) == 0 { + return nil, errors.New("no server certificate captured") + } + + return capturedCert, nil +} + // Post overrides http.Client's Post method. func (t *Target) Post(msg string) (response []byte, err error) { msgBody := []byte(msg) diff --git a/pkg/wsman/client/wsman_test.go b/pkg/wsman/client/wsman_test.go index 7912e005..6a330a59 100644 --- a/pkg/wsman/client/wsman_test.go +++ b/pkg/wsman/client/wsman_test.go @@ -417,3 +417,34 @@ func TestClient_SimpleRountripper(t *testing.T) { t.Error("Failed to detect proper transport") } } + +func TestClient_GetServerCertificate(t *testing.T) { + // Setting up a mock server to simulate a TLS handshake and provide a certificate + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + cp := Parameters{ + Target: ts.URL, + Username: "user", + Password: "password", + UseDigest: false, + UseTLS: true, + SelfSignedAllowed: true, + LogAMTMessages: false, + } + + client := NewWsman(cp) + client.endpoint = ts.URL + + cert, err := client.GetServerCertificate() + if err != nil { + t.Errorf("Unexpected error during GetServerCertificate: %v", err) + } + + // Check that a certificate was indeed captured + if cert == nil || len(cert.Certificate) == 0 { + t.Error("Expected a server certificate, but none was captured") + } +} diff --git a/pkg/wsman/wsmantesting/clientMock.go b/pkg/wsman/wsmantesting/clientMock.go index 661c9bc1..c3f874f2 100644 --- a/pkg/wsman/wsmantesting/clientMock.go +++ b/pkg/wsman/wsmantesting/clientMock.go @@ -1,6 +1,7 @@ package wsmantesting import ( + "crypto/tls" "io" "os" "strings" @@ -41,7 +42,8 @@ func (c *MockClient) Post(msg string) ([]byte, error) { // Simulate a successful response for testing. return xmlData, nil } -func (c *MockClient) Send(data []byte) error { return nil } -func (c *MockClient) Receive() ([]byte, error) { return nil, nil } -func (c *MockClient) CloseConnection() error { return nil } -func (c *MockClient) Connect() error { return nil } +func (c *MockClient) Send(data []byte) error { return nil } +func (c *MockClient) Receive() ([]byte, error) { return nil, nil } +func (c *MockClient) CloseConnection() error { return nil } +func (c *MockClient) Connect() error { return nil } +func (c *MockClient) GetServerCertificate() (*tls.Certificate, error) { return nil, nil }