Skip to content

Commit

Permalink
Fix key_store tests
Browse files Browse the repository at this point in the history
Add key_store support to signer and verifier
  • Loading branch information
Captain-ALM committed Jun 9, 2024
1 parent 32cfa7a commit 6fbc9e3
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 42 deletions.
9 changes: 7 additions & 2 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,25 @@ type Signer interface {
Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error)
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
Issuer() string
PrivateKey() *rsa.PrivateKey
PrivateKeyOf(kID string) *rsa.PrivateKey
}

// Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey
PublicKeyOf(kID string) *rsa.PublicKey
GetKeyStore() KeyStore
}

// KeyStore is used for the kid header support in Signer and Verifier.
type KeyStore interface {
SetKey(kID string, prvKey *rsa.PrivateKey) bool
SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool
SetKey(kID string, prvKey *rsa.PrivateKey)
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
RemoveKey(kID string)
ListKeys() []string
GetKey(kID string) *rsa.PrivateKey
Expand Down
16 changes: 8 additions & 8 deletions key_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewMJwtKeyStore() KeyStore {

// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
// rsa keys; the kID is the filename of the key up to the first .
func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) {
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
// Create empty KeyStore
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
// List directory contents
Expand Down Expand Up @@ -78,7 +78,7 @@ func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt

// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
// extensions for public and private keys
func ExportKeyStore(ks KeyStore, directory string, keyPrvExt string, keyPubExt string) error {
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
if ks == nil {
return errors.New("ks is nil")
}
Expand Down Expand Up @@ -110,21 +110,21 @@ func ExportKeyStore(ks KeyStore, directory string, keyPrvExt string, keyPubExt s
}

// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) bool {
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
if d == nil || prvKey == nil {
return false
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
d.store[kID] = prvKey
d.storePub[kID] = &prvKey.PublicKey
return true
return
}

// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool {
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
if d == nil || pubKey == nil {
return false
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
Expand All @@ -133,7 +133,7 @@ func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bo
d.store[kID] = nil
}
d.storePub[kID] = pubKey
return true
return
}

// RemoveKey removes a specified kID from the KeyStore.
Expand Down
20 changes: 9 additions & 11 deletions key_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
const prvExt = "prv"
const pubExt = "pub"

func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)

Expand Down Expand Up @@ -43,7 +43,7 @@ func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
}
}

func commonSubTests(t *testing.T, kStore KeyStore) {
func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

Expand All @@ -54,14 +54,12 @@ func commonSubTests(t *testing.T, kStore KeyStore) {
const extraKID2 = "key5"

t.Run("TestSetKey", func(t *testing.T) {
b := kStore.SetKey(extraKID1, key4)
assert.True(t, b)
kStore.SetKey(extraKID1, key4)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})

t.Run("TestSetKeyPublic", func(t *testing.T) {
b := kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.True(t, b)
kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})

Expand Down Expand Up @@ -109,7 +107,7 @@ func commonSubTests(t *testing.T, kStore KeyStore) {
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()

tempDir, cleaner := setupTestDir(t, true)
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)

kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
Expand All @@ -121,15 +119,15 @@ func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
assert.Contains(t, kStore.ListKeys(), k)
}

commonSubTests(t, kStore)
commonSubTestsKeyStore(t, kStore)
}

func TestExportKeyStore(t *testing.T) {
t.Parallel()

tempDir, cleaner := setupTestDir(t, true)
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)
tempDir2, cleaner2 := setupTestDir(t, false)
tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
defer cleaner2(t)

kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
Expand All @@ -150,5 +148,5 @@ func TestExportKeyStore(t *testing.T) {
assert.Contains(t, kStore2.ListKeys(), k)
}

commonSubTests(t, kStore2)
commonSubTestsKeyStore(t, kStore2)
}
122 changes: 111 additions & 11 deletions signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mjwt
import (
"bytes"
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"io"
Expand All @@ -23,10 +24,16 @@ var _ Verifier = &defaultMJwtSigner{}

// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore())
}

// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey
// for no kID and a KeyStore for kID based keys
func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer {
return &defaultMJwtSigner{
issuer: issuer,
key: key,
verify: newMJwtVerifier(&key.PublicKey),
verify: NewMjwtVerifierWithKeyStore(&key.PublicKey, kStore).(*defaultMJwtVerifier),
}
}

Expand All @@ -44,38 +51,131 @@ func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits i
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
// rsa.PrivateKey file.
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "")
}

// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PrivateKey
func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt)
}

// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) {
var err error

// read key
key, err := rsaprivate.Read(file)
if err != nil {
return nil, err
var prv *rsa.PrivateKey = nil
if file != "" {
prv, err = rsaprivate.Read(file)
if err != nil {
return nil, err
}
}

// read KeyStore
var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}

// create signer using rsa.PrivateKey
return NewMJwtSigner(issuer, key), nil
return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil
}

// Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
func (d *defaultMJwtSigner) Issuer() string {
if d == nil {
return ""
}
return d.issuer
}

// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
}

// SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs
// GenerateJwt but is available for signing custom structs; uses the default key
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key)
}

// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
}

// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
pKey := d.verify.GetKeyStore().GetKey(kID)
if pKey == nil {
return "", errors.New("no private key found")
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = kID
return token.SignedString(pKey)
}

// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
if d == nil {
return nil, errors.New("signer nil")
}
return d.verify.VerifyJwt(token, claims)
}

func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { return d.key }
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { return d.verify.pub }
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey {
if d == nil {
return nil
}
return d.key
}
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey {
if d == nil {
return nil
}
return d.verify.pub
}

func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey {
if d == nil {
return nil
}
return d.verify.kStore.GetKeyPublic(kID)
}

func (d *defaultMJwtSigner) GetKeyStore() KeyStore {
if d == nil {
return nil
}
return d.verify.GetKeyStore()
}

func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey {
if d == nil {
return nil
}
return d.verify.kStore.GetKey(kID)
}

// readOrCreatePrivateKey returns the private key it the file already exists,
// generates a new private key and saves it to the file, or returns an error if
Expand Down
Loading

0 comments on commit 6fbc9e3

Please sign in to comment.