Skip to content

Commit

Permalink
feat: support getting server certificate (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsdmike authored Aug 22, 2024
1 parent dbad927 commit abfed6d
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 11 deletions.
13 changes: 7 additions & 6 deletions internal/message/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package message

import (
"crypto/tls"
"errors"
"fmt"
"testing"
Expand All @@ -30,12 +31,12 @@ func (c *MockClient) Post(msg string) ([]byte, error) {

return response, c.Err
}
func (c *MockClient) Send(data []byte) error { return nil }
func (c *MockClient) Receive() ([]byte, error) { return nil, nil }
func (c *MockClient) CloseConnection() error { return nil }
func (c *MockClient) Connect() error { return nil }
func (c *MockClient) IsAuthenticated() bool { return true }

func (c *MockClient) Send(data []byte) error { return nil }
func (c *MockClient) Receive() ([]byte, error) { return nil, nil }
func (c *MockClient) CloseConnection() error { return nil }
func (c *MockClient) Connect() error { return nil }
func (c *MockClient) IsAuthenticated() bool { return true }
func (c *MockClient) GetServerCertificate() (*tls.Certificate, error) { return nil, nil }
func TestBaseWithClient(t *testing.T) {
mockWsmanMessageCreator := NewWSManMessageCreator("test-uri")
mockClient := MockClient{}
Expand Down
1 change: 1 addition & 0 deletions pkg/wsman/client/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ type Parameters struct {
LogAMTMessages bool
Transport http.RoundTripper
IsRedirection bool
PinnedCert string
}
76 changes: 75 additions & 1 deletion pkg/wsman/client/wsman.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ package client

import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -45,6 +49,7 @@ type WSMan interface {
Receive() ([]byte, error)
CloseConnection() error
IsAuthenticated() bool
GetServerCertificate() (*tls.Certificate, error)
}

// Target is a thin wrapper around http.Target.
Expand Down Expand Up @@ -90,12 +95,39 @@ func NewWsman(cp Parameters) *Target {
res.Timeout = timeout

if cp.Transport == nil {
// check if pinnedCert is not null and not empty
var config *tls.Config
if len(cp.PinnedCert) > 0 {
config = &tls.Config{
InsecureSkipVerify: cp.SelfSignedAllowed,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
for _, rawCert := range rawCerts {
cert, err := x509.ParseCertificate(rawCert)
if err != nil {
return err
}

// Compare the current certificate with the pinned certificate
sha256Fingerprint := sha256.Sum256(cert.Raw)
if hex.EncodeToString(sha256Fingerprint[:]) == cp.PinnedCert {
return nil // Success: The certificate matches the pinned certificate
}
}

return fmt.Errorf("certificate pinning failed")
},
}
} else {
config = &tls.Config{InsecureSkipVerify: cp.SelfSignedAllowed}
}

res.Transport = &http.Transport{
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
DisableKeepAlives: true,
TLSClientConfig: &tls.Config{InsecureSkipVerify: cp.SelfSignedAllowed},
TLSClientConfig: config,
}

} else {
res.Transport = cp.Transport
}
Expand All @@ -111,6 +143,48 @@ func (t *Target) IsAuthenticated() bool {
return t.challenge != nil && t.challenge.Realm != ""
}

func (t *Target) GetServerCertificate() (*tls.Certificate, error) {
httpTransport, ok := t.Transport.(*http.Transport)
if !ok {
return nil, errors.New("transport does not support TLSClientConfig")
}

tlsConfig := httpTransport.TLSClientConfig
if tlsConfig == nil {
return nil, errors.New("TLSClientConfig is nil")
}

// Create a custom DialTLS to capture the server certificate
capturedCert := &tls.Certificate{}
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if len(rawCerts) > 0 {
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return err
}
*capturedCert = tls.Certificate{
Certificate: [][]byte{cert.Raw},
}
}
return nil
}

// Perform a connection to trigger the TLS handshake
nohttps := strings.Replace(t.endpoint, "https://", "", 1)
nohttps = strings.Replace(nohttps, "/wsman", "", 1)
conn, err := tls.Dial("tcp", nohttps, tlsConfig)
if err != nil {
return nil, err
}
defer conn.Close()

if len(capturedCert.Certificate) == 0 {
return nil, errors.New("no server certificate captured")
}

return capturedCert, nil
}

// Post overrides http.Client's Post method.
func (t *Target) Post(msg string) (response []byte, err error) {
msgBody := []byte(msg)
Expand Down
31 changes: 31 additions & 0 deletions pkg/wsman/client/wsman_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,34 @@ func TestClient_SimpleRountripper(t *testing.T) {
t.Error("Failed to detect proper transport")
}
}

func TestClient_GetServerCertificate(t *testing.T) {
// Setting up a mock server to simulate a TLS handshake and provide a certificate
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

cp := Parameters{
Target: ts.URL,
Username: "user",
Password: "password",
UseDigest: false,
UseTLS: true,
SelfSignedAllowed: true,
LogAMTMessages: false,
}

client := NewWsman(cp)
client.endpoint = ts.URL

cert, err := client.GetServerCertificate()
if err != nil {
t.Errorf("Unexpected error during GetServerCertificate: %v", err)
}

// Check that a certificate was indeed captured
if cert == nil || len(cert.Certificate) == 0 {
t.Error("Expected a server certificate, but none was captured")
}
}
10 changes: 6 additions & 4 deletions pkg/wsman/wsmantesting/clientMock.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package wsmantesting

import (
"crypto/tls"
"io"
"os"
"strings"
Expand Down Expand Up @@ -41,7 +42,8 @@ func (c *MockClient) Post(msg string) ([]byte, error) {
// Simulate a successful response for testing.
return xmlData, nil
}
func (c *MockClient) Send(data []byte) error { return nil }
func (c *MockClient) Receive() ([]byte, error) { return nil, nil }
func (c *MockClient) CloseConnection() error { return nil }
func (c *MockClient) Connect() error { return nil }
func (c *MockClient) Send(data []byte) error { return nil }
func (c *MockClient) Receive() ([]byte, error) { return nil, nil }
func (c *MockClient) CloseConnection() error { return nil }
func (c *MockClient) Connect() error { return nil }
func (c *MockClient) GetServerCertificate() (*tls.Certificate, error) { return nil, nil }

0 comments on commit abfed6d

Please sign in to comment.