-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
388 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package interceptor | ||
|
||
import ( | ||
"context" | ||
"log" | ||
"path" | ||
"time" | ||
|
||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/peer" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
const ( | ||
// keyStartTimestamp is the key name for the starting time of a request (in seconds). | ||
keyStartTimestamp = "sts" | ||
|
||
// keyMethod is the key name for the grpc method. | ||
keyMethod = "mtd" | ||
|
||
// keyStatusCode is the key name for the response status code. | ||
keyStatusCode = "st" | ||
|
||
// keyDurationMs is the key name for the duration of the request (in milliseconds). | ||
keyDurationMs = "dur" | ||
|
||
// keyPrincipal is the key name for the common name of the peer certificate extracted from the context. | ||
keyPrincipal = "prin" | ||
|
||
// accessLogMsg is the special log message that will be used in access log so that it can | ||
// be used to distinguish from other server logs. | ||
accessLogMsg = "grpcAccessLog" | ||
) | ||
|
||
type accessLogInterceptor struct { | ||
timeNow func() time.Time | ||
} | ||
|
||
func (i *accessLogInterceptor) Func( | ||
ctx context.Context, | ||
req interface{}, | ||
info *grpc.UnaryServerInfo, | ||
handler grpc.UnaryHandler, | ||
) (interface{}, error) { | ||
startTime := i.timeNow() | ||
resp, err := handler(ctx, req) | ||
elapsedTime := i.timeNow().Sub(startTime) | ||
|
||
log.Printf(`m=%s,%s=%s,%s=%f,%s=%s,%s=%d,%s=%d`, | ||
accessLogMsg, | ||
keyPrincipal, getPrincipalFromContext(ctx), | ||
keyStartTimestamp, float64(startTime.UnixNano())/float64(time.Second), | ||
keyMethod, path.Base(info.FullMethod), | ||
keyStatusCode, getStatus(err), | ||
keyDurationMs, elapsedTime.Milliseconds()) | ||
return resp, err | ||
} | ||
|
||
func getPrincipalFromContext(ctx context.Context) string { | ||
p, ok := peer.FromContext(ctx) | ||
if !ok || p == nil { | ||
return "unknownPeer" | ||
} | ||
tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo) | ||
if !ok { | ||
return "unknownTLSInfo" | ||
} | ||
certs := tlsInfo.State.PeerCertificates | ||
if len(certs) == 0 || certs[0] == nil { | ||
return "peerCertificateNotFound" | ||
} | ||
return certs[0].Subject.CommonName | ||
} | ||
|
||
func getStatus(err error) uint32 { | ||
statusErr := status.Convert(err) | ||
return uint32(statusErr.Code()) | ||
} | ||
|
||
func AccessLogInterceptor() grpc.UnaryServerInterceptor { | ||
interceptor := &accessLogInterceptor{ | ||
timeNow: time.Now, | ||
} | ||
return interceptor.Func | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,301 @@ | ||
package interceptor | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/pem" | ||
"log" | ||
"math/big" | ||
"net" | ||
"os" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" | ||
pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" | ||
"github.com/stretchr/testify/require" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/test/bufconn" | ||
) | ||
|
||
const ( | ||
caCName = "ca-example" | ||
serverCName = "server-example" | ||
clientCName = "client-example" | ||
) | ||
|
||
type fakeTimer struct { | ||
t time.Time | ||
} | ||
|
||
func newFakeTimer() *fakeTimer { | ||
return &fakeTimer{time.Unix(1234567890, 987654321)} | ||
} | ||
|
||
func (f *fakeTimer) now() time.Time { | ||
now := f.t | ||
f.t = f.t.Add(time.Millisecond * 1234) | ||
return now | ||
} | ||
|
||
func signX509Cert(unsignedCert, caCert *x509.Certificate, pubKey *rsa.PublicKey, | ||
caPrivKey *rsa.PrivateKey) (*x509.Certificate, []byte, error) { | ||
certBytes, err := x509.CreateCertificate(rand.Reader, unsignedCert, caCert, pubKey, caPrivKey) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
cert, err := x509.ParseCertificate(certBytes) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
b := pem.Block{Type: "CERTIFICATE", Bytes: certBytes} | ||
pem := pem.EncodeToMemory(&b) | ||
|
||
return cert, pem, nil | ||
} | ||
|
||
func genSelfSignedCAX509Cert() (*x509.Certificate, []byte, *rsa.PrivateKey, error) { | ||
var unsignedCert = &x509.Certificate{ | ||
SerialNumber: big.NewInt(1), | ||
Subject: pkix.Name{ | ||
Country: []string{"US"}, | ||
CommonName: caCName, | ||
}, | ||
DNSNames: []string{caCName}, | ||
NotBefore: time.Now().AddDate(-1, 0, 0), | ||
NotAfter: time.Now().AddDate(10, 0, 0), | ||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, | ||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, | ||
BasicConstraintsValid: true, | ||
IsCA: true, | ||
} | ||
priv, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
return nil, nil, nil, err | ||
} | ||
cert, pem, err := signX509Cert(unsignedCert, unsignedCert, &priv.PublicKey, priv) | ||
if err != nil { | ||
return nil, nil, nil, err | ||
} | ||
return cert, pem, priv, nil | ||
} | ||
|
||
func genAndSignX509Cert(cname string, caCert *x509.Certificate, caKey *rsa.PrivateKey) ([]byte, []byte, error) { | ||
priv, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
privPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) | ||
|
||
var unsignedCert = &x509.Certificate{ | ||
SerialNumber: big.NewInt(1), | ||
Subject: pkix.Name{ | ||
Country: []string{"US"}, | ||
CommonName: cname, | ||
}, | ||
DNSNames: []string{cname}, | ||
NotBefore: time.Now().Add(-10 * time.Second), | ||
NotAfter: time.Now().AddDate(10, 0, 0), | ||
KeyUsage: x509.KeyUsageCRLSign, | ||
IsCA: false, | ||
} | ||
|
||
_, certPem, err := signX509Cert(unsignedCert, caCert, &priv.PublicKey, caKey) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
return certPem, privPem, nil | ||
} | ||
|
||
func TestAccessLogInterceptor(t *testing.T) { | ||
// Create ping request for testing. | ||
ping := &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} | ||
|
||
// Create CA credentials for mTLS. | ||
ca, _, caPriv, err := genSelfSignedCAX509Cert() | ||
if err != nil { | ||
t.Fatalf("failed to gerenate self signed ca cert, err: %v", err) | ||
} | ||
caCertPool := x509.NewCertPool() | ||
caCertPool.AddCert(ca) | ||
|
||
tests := []struct { | ||
name string | ||
init func() | ||
setupServer func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func()) | ||
setupClient func(ctx context.Context, svr *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient | ||
wantLog string | ||
}{ | ||
{ | ||
name: "happy path", | ||
setupServer: func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func()) { | ||
svrCertPem, svrPrivPem, err := genAndSignX509Cert(serverCName, ca, caPriv) | ||
if err != nil { | ||
t.Fatalf("failed to gerenate server cert, err: %v", err) | ||
} | ||
|
||
svrCertificate, err := tls.X509KeyPair(svrCertPem, svrPrivPem) | ||
if err != nil { | ||
t.Fatalf("failed to load x509 key pair, err: %v", err) | ||
} | ||
|
||
svrTLSConfig := &tls.Config{ | ||
ClientAuth: tls.RequireAndVerifyClientCert, | ||
Certificates: []tls.Certificate{svrCertificate}, | ||
ClientCAs: caCertPool, | ||
} | ||
|
||
timer := newFakeTimer() | ||
interceptor := &accessLogInterceptor{ | ||
timeNow: timer.now, | ||
} | ||
|
||
grpcServer := grpc.NewServer([]grpc.ServerOption{ | ||
grpc.Creds(credentials.NewTLS(svrTLSConfig)), | ||
grpc.UnaryInterceptor(interceptor.Func), | ||
}...) | ||
|
||
testService := &grpc_testing.TestPingService{T: t} | ||
pb_testproto.RegisterTestServiceServer(grpcServer, testService) | ||
|
||
go func() { | ||
if err := grpcServer.Serve(listener); err != nil { | ||
panic(err) | ||
} | ||
}() | ||
|
||
closer := func() { | ||
listener.Close() | ||
grpcServer.Stop() | ||
} | ||
|
||
return grpcServer, closer | ||
}, | ||
setupClient: func(ctx context.Context, server *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient { | ||
clientCertPem, clientPrivPem, err := genAndSignX509Cert(clientCName, ca, caPriv) | ||
if err != nil { | ||
t.Fatalf("failed to gerenate server cert, err: %v", err) | ||
} | ||
|
||
clientCertificate, err := tls.X509KeyPair(clientCertPem, clientPrivPem) | ||
if err != nil { | ||
t.Fatalf("failed to load x509 key pair, err: %v", err) | ||
} | ||
|
||
clientTLConfig := &tls.Config{ | ||
Certificates: []tls.Certificate{clientCertificate}, | ||
RootCAs: caCertPool, | ||
ServerName: serverCName, | ||
} | ||
|
||
clientConn, _ := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { | ||
return listener.Dial() | ||
}), grpc.WithTransportCredentials(credentials.NewTLS(clientTLConfig))) | ||
return pb_testproto.NewTestServiceClient(clientConn) | ||
}, | ||
wantLog: "m=grpcAccessLog,prin=client-example,sts=1234567890.987654,mtd=Ping,st=0,dur=1234", | ||
}, | ||
{ | ||
name: "unknown tls info, without time check", | ||
setupServer: func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func()) { | ||
grpcServer := grpc.NewServer([]grpc.ServerOption{ | ||
grpc.UnaryInterceptor(AccessLogInterceptor()), | ||
}...) | ||
testService := &grpc_testing.TestPingService{T: t} | ||
pb_testproto.RegisterTestServiceServer(grpcServer, testService) | ||
go func() { | ||
if err := grpcServer.Serve(listener); err != nil { | ||
panic(err) | ||
} | ||
}() | ||
closer := func() { | ||
listener.Close() | ||
grpcServer.Stop() | ||
} | ||
return grpcServer, closer | ||
}, | ||
setupClient: func(ctx context.Context, server *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient { | ||
clientConn, _ := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { | ||
return listener.Dial() | ||
}), grpc.WithInsecure()) | ||
return pb_testproto.NewTestServiceClient(clientConn) | ||
}, | ||
wantLog: "m=grpcAccessLog,prin=unknownTLSInfo", | ||
}, | ||
{ | ||
name: "unknown tls info, without time check", | ||
setupServer: func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func()) { | ||
svrCertPem, svrPrivPem, err := genAndSignX509Cert(serverCName, ca, caPriv) | ||
if err != nil { | ||
t.Fatalf("failed to gerenate server cert, err: %v", err) | ||
} | ||
|
||
svrCertificate, err := tls.X509KeyPair(svrCertPem, svrPrivPem) | ||
if err != nil { | ||
t.Fatalf("failed to load x509 key pair, err: %v", err) | ||
} | ||
|
||
svrTLSConfig := &tls.Config{ | ||
Certificates: []tls.Certificate{svrCertificate}, | ||
ClientCAs: caCertPool, | ||
} | ||
|
||
grpcServer := grpc.NewServer([]grpc.ServerOption{ | ||
grpc.Creds(credentials.NewTLS(svrTLSConfig)), | ||
grpc.UnaryInterceptor(AccessLogInterceptor()), | ||
}...) | ||
testService := &grpc_testing.TestPingService{T: t} | ||
pb_testproto.RegisterTestServiceServer(grpcServer, testService) | ||
go func() { | ||
if err := grpcServer.Serve(listener); err != nil { | ||
panic(err) | ||
} | ||
}() | ||
closer := func() { | ||
listener.Close() | ||
grpcServer.Stop() | ||
} | ||
return grpcServer, closer | ||
}, | ||
setupClient: func(ctx context.Context, server *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient { | ||
clientTLConfig := &tls.Config{ServerName: serverCName, RootCAs: caCertPool} | ||
clientConn, _ := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { | ||
return listener.Dial() | ||
}), grpc.WithTransportCredentials(credentials.NewTLS(clientTLConfig))) | ||
return pb_testproto.NewTestServiceClient(clientConn) | ||
}, | ||
wantLog: "m=grpcAccessLog,prin=peerCertificateNotFound", | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
buffer := new(bytes.Buffer) | ||
log.SetOutput(buffer) | ||
defer log.SetOutput(os.Stderr) | ||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) | ||
defer cancel() | ||
listener := bufconn.Listen(1024 * 1024) | ||
grpcServer, cleanup := tt.setupServer(ctx, t, listener) | ||
defer cleanup() | ||
client := tt.setupClient(ctx, grpcServer, listener) | ||
_, err := client.Ping(ctx, ping) | ||
if err != nil { | ||
require.NoError(t, err, "no error should occur") | ||
} | ||
actualLog := string(buffer.Bytes()) | ||
|
||
if !strings.Contains(actualLog, tt.wantLog) { | ||
t.Fatalf("got: %v but want: %v", actualLog, tt.wantLog) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters