Skip to content

Commit

Permalink
Access log interceptor (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
py4chen authored Aug 30, 2021
1 parent 35a4f19 commit 43f1689
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 0 deletions.
86 changes: 86 additions & 0 deletions server/interceptor/access_log_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package interceptor

import (
"context"
"log"
"path"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

const (
// keyStartTimestamp is the key name for the starting time of a request (in seconds).
keyStartTimestamp = "sts"

// keyMethod is the key name for the grpc method.
keyMethod = "mtd"

// keyStatusCode is the key name for the response status code.
keyStatusCode = "st"

// keyDurationMs is the key name for the duration of the request (in milliseconds).
keyDurationMs = "dur"

// keyPrincipal is the key name for the common name of the peer certificate extracted from the context.
keyPrincipal = "prin"

// accessLogMsg is the special log message that will be used in access log so that it can
// be used to distinguish from other server logs.
accessLogMsg = "grpcAccessLog"
)

type accessLogInterceptor struct {
timeNow func() time.Time
}

func (i *accessLogInterceptor) Func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
startTime := i.timeNow()
resp, err := handler(ctx, req)
elapsedTime := i.timeNow().Sub(startTime)

log.Printf(`m=%s,%s=%s,%s=%f,%s=%s,%s=%d,%s=%d`,
accessLogMsg,
keyPrincipal, getPrincipalFromContext(ctx),
keyStartTimestamp, float64(startTime.UnixNano())/float64(time.Second),
keyMethod, path.Base(info.FullMethod),
keyStatusCode, getStatus(err),
keyDurationMs, elapsedTime.Milliseconds())
return resp, err
}

func getPrincipalFromContext(ctx context.Context) string {
p, ok := peer.FromContext(ctx)
if !ok || p == nil {
return "unknownPeer"
}
tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
if !ok {
return "unknownTLSInfo"
}
certs := tlsInfo.State.PeerCertificates
if len(certs) == 0 || certs[0] == nil {
return "peerCertificateNotFound"
}
return certs[0].Subject.CommonName
}

func getStatus(err error) uint32 {
statusErr := status.Convert(err)
return uint32(statusErr.Code())
}

func AccessLogInterceptor() grpc.UnaryServerInterceptor {
interceptor := &accessLogInterceptor{
timeNow: time.Now,
}
return interceptor.Func
}
301 changes: 301 additions & 0 deletions server/interceptor/access_log_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
package interceptor

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"log"
"math/big"
"net"
"os"
"strings"
"testing"
"time"

grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing"
pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/test/bufconn"
)

const (
caCName = "ca-example"
serverCName = "server-example"
clientCName = "client-example"
)

type fakeTimer struct {
t time.Time
}

func newFakeTimer() *fakeTimer {
return &fakeTimer{time.Unix(1234567890, 987654321)}
}

func (f *fakeTimer) now() time.Time {
now := f.t
f.t = f.t.Add(time.Millisecond * 1234)
return now
}

func signX509Cert(unsignedCert, caCert *x509.Certificate, pubKey *rsa.PublicKey,
caPrivKey *rsa.PrivateKey) (*x509.Certificate, []byte, error) {
certBytes, err := x509.CreateCertificate(rand.Reader, unsignedCert, caCert, pubKey, caPrivKey)
if err != nil {
return nil, nil, err
}

cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, err
}

b := pem.Block{Type: "CERTIFICATE", Bytes: certBytes}
pem := pem.EncodeToMemory(&b)

return cert, pem, nil
}

func genSelfSignedCAX509Cert() (*x509.Certificate, []byte, *rsa.PrivateKey, error) {
var unsignedCert = &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Country: []string{"US"},
CommonName: caCName,
},
DNSNames: []string{caCName},
NotBefore: time.Now().AddDate(-1, 0, 0),
NotAfter: time.Now().AddDate(10, 0, 0),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, nil, err
}
cert, pem, err := signX509Cert(unsignedCert, unsignedCert, &priv.PublicKey, priv)
if err != nil {
return nil, nil, nil, err
}
return cert, pem, priv, nil
}

func genAndSignX509Cert(cname string, caCert *x509.Certificate, caKey *rsa.PrivateKey) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
privPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})

var unsignedCert = &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Country: []string{"US"},
CommonName: cname,
},
DNSNames: []string{cname},
NotBefore: time.Now().Add(-10 * time.Second),
NotAfter: time.Now().AddDate(10, 0, 0),
KeyUsage: x509.KeyUsageCRLSign,
IsCA: false,
}

_, certPem, err := signX509Cert(unsignedCert, caCert, &priv.PublicKey, caKey)
if err != nil {
return nil, nil, err
}
return certPem, privPem, nil
}

func TestAccessLogInterceptor(t *testing.T) {
// Create ping request for testing.
ping := &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}

// Create CA credentials for mTLS.
ca, _, caPriv, err := genSelfSignedCAX509Cert()
if err != nil {
t.Fatalf("failed to gerenate self signed ca cert, err: %v", err)
}
caCertPool := x509.NewCertPool()
caCertPool.AddCert(ca)

tests := []struct {
name string
init func()
setupServer func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func())
setupClient func(ctx context.Context, svr *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient
wantLog string
}{
{
name: "happy path",
setupServer: func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func()) {
svrCertPem, svrPrivPem, err := genAndSignX509Cert(serverCName, ca, caPriv)
if err != nil {
t.Fatalf("failed to gerenate server cert, err: %v", err)
}

svrCertificate, err := tls.X509KeyPair(svrCertPem, svrPrivPem)
if err != nil {
t.Fatalf("failed to load x509 key pair, err: %v", err)
}

svrTLSConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{svrCertificate},
ClientCAs: caCertPool,
}

timer := newFakeTimer()
interceptor := &accessLogInterceptor{
timeNow: timer.now,
}

grpcServer := grpc.NewServer([]grpc.ServerOption{
grpc.Creds(credentials.NewTLS(svrTLSConfig)),
grpc.UnaryInterceptor(interceptor.Func),
}...)

testService := &grpc_testing.TestPingService{T: t}
pb_testproto.RegisterTestServiceServer(grpcServer, testService)

go func() {
if err := grpcServer.Serve(listener); err != nil {
panic(err)
}
}()

closer := func() {
listener.Close()
grpcServer.Stop()
}

return grpcServer, closer
},
setupClient: func(ctx context.Context, server *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient {
clientCertPem, clientPrivPem, err := genAndSignX509Cert(clientCName, ca, caPriv)
if err != nil {
t.Fatalf("failed to gerenate server cert, err: %v", err)
}

clientCertificate, err := tls.X509KeyPair(clientCertPem, clientPrivPem)
if err != nil {
t.Fatalf("failed to load x509 key pair, err: %v", err)
}

clientTLConfig := &tls.Config{
Certificates: []tls.Certificate{clientCertificate},
RootCAs: caCertPool,
ServerName: serverCName,
}

clientConn, _ := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return listener.Dial()
}), grpc.WithTransportCredentials(credentials.NewTLS(clientTLConfig)))
return pb_testproto.NewTestServiceClient(clientConn)
},
wantLog: "m=grpcAccessLog,prin=client-example,sts=1234567890.987654,mtd=Ping,st=0,dur=1234",
},
{
name: "unknown tls info, without time check",
setupServer: func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func()) {
grpcServer := grpc.NewServer([]grpc.ServerOption{
grpc.UnaryInterceptor(AccessLogInterceptor()),
}...)
testService := &grpc_testing.TestPingService{T: t}
pb_testproto.RegisterTestServiceServer(grpcServer, testService)
go func() {
if err := grpcServer.Serve(listener); err != nil {
panic(err)
}
}()
closer := func() {
listener.Close()
grpcServer.Stop()
}
return grpcServer, closer
},
setupClient: func(ctx context.Context, server *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient {
clientConn, _ := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return listener.Dial()
}), grpc.WithInsecure())
return pb_testproto.NewTestServiceClient(clientConn)
},
wantLog: "m=grpcAccessLog,prin=unknownTLSInfo",
},
{
name: "unknown tls info, without time check",
setupServer: func(ctx context.Context, t *testing.T, listener *bufconn.Listener) (*grpc.Server, func()) {
svrCertPem, svrPrivPem, err := genAndSignX509Cert(serverCName, ca, caPriv)
if err != nil {
t.Fatalf("failed to gerenate server cert, err: %v", err)
}

svrCertificate, err := tls.X509KeyPair(svrCertPem, svrPrivPem)
if err != nil {
t.Fatalf("failed to load x509 key pair, err: %v", err)
}

svrTLSConfig := &tls.Config{
Certificates: []tls.Certificate{svrCertificate},
ClientCAs: caCertPool,
}

grpcServer := grpc.NewServer([]grpc.ServerOption{
grpc.Creds(credentials.NewTLS(svrTLSConfig)),
grpc.UnaryInterceptor(AccessLogInterceptor()),
}...)
testService := &grpc_testing.TestPingService{T: t}
pb_testproto.RegisterTestServiceServer(grpcServer, testService)
go func() {
if err := grpcServer.Serve(listener); err != nil {
panic(err)
}
}()
closer := func() {
listener.Close()
grpcServer.Stop()
}
return grpcServer, closer
},
setupClient: func(ctx context.Context, server *grpc.Server, listener *bufconn.Listener) pb_testproto.TestServiceClient {
clientTLConfig := &tls.Config{ServerName: serverCName, RootCAs: caCertPool}
clientConn, _ := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return listener.Dial()
}), grpc.WithTransportCredentials(credentials.NewTLS(clientTLConfig)))
return pb_testproto.NewTestServiceClient(clientConn)
},
wantLog: "m=grpcAccessLog,prin=peerCertificateNotFound",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buffer := new(bytes.Buffer)
log.SetOutput(buffer)
defer log.SetOutput(os.Stderr)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
listener := bufconn.Listen(1024 * 1024)
grpcServer, cleanup := tt.setupServer(ctx, t, listener)
defer cleanup()
client := tt.setupClient(ctx, grpcServer, listener)
_, err := client.Ping(ctx, ping)
if err != nil {
require.NoError(t, err, "no error should occur")
}
actualLog := string(buffer.Bytes())

if !strings.Contains(actualLog, tt.wantLog) {
t.Fatalf("got: %v but want: %v", actualLog, tt.wantLog)
}
})
}
}
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ func Main(keyP crypki.KeyIDProcessor) {
var grpcServer *grpc.Server
interceptors := []grpc.UnaryServerInterceptor{
recovery.UnaryServerInterceptor(recovery.WithRecoveryHandler(recoveryHandler)),
interceptor.AccessLogInterceptor(),
}
if cfg.ShutdownOnInternalFailure {
criteria := cfg.ShutdownOnInternalFailureCriteria
Expand Down

0 comments on commit 43f1689

Please sign in to comment.