Skip to content

Commit

Permalink
fix: rwlock around cacheDerived
Browse files Browse the repository at this point in the history
  • Loading branch information
gak committed Aug 21, 2024
1 parent 153f674 commit 8a641c0
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions internal/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"strings"
"sync"

"github.com/alecthomas/types/optional"
awsv1kms "github.com/aws/aws-sdk-go/service/kms"
Expand Down Expand Up @@ -99,6 +100,7 @@ type KMSEncryptor struct {
root keyset.Handle
kekAEAD tink.AEAD
encryptedKeyset []byte
cachedDerivedMu sync.RWMutex
cachedDerived map[SubKey]tink.AEAD
}

Expand Down Expand Up @@ -206,7 +208,10 @@ func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) {
}

func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) {
if primitive, ok := k.cachedDerived[subKey]; ok {
k.cachedDerivedMu.RLock()
primitive, ok := k.cachedDerived[subKey]
k.cachedDerivedMu.RUnlock()
if ok {
return primitive, nil
}

Expand All @@ -215,12 +220,15 @@ func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) {
return nil, fmt.Errorf("failed to derive keyset: %w", err)
}

primitive, err := aead.New(derived)
primitive, err = aead.New(derived)
if err != nil {
return nil, fmt.Errorf("failed to create primitive: %w", err)
}

k.cachedDerivedMu.Lock()
k.cachedDerived[subKey] = primitive
k.cachedDerivedMu.Unlock()

return primitive, nil
}

Expand Down

0 comments on commit 8a641c0

Please sign in to comment.