From 1b0bdef3e33cd5711572300905ef428ea07e21db Mon Sep 17 00:00:00 2001 From: jsmoon Date: Sat, 15 Jun 2024 22:32:05 +0900 Subject: [PATCH 01/10] Add PKCS#11 Wrapper Signed-off-by: jsmoon --- wrappers/pkcs11/go.mod | 18 + wrappers/pkcs11/go.sum | 21 ++ wrappers/pkcs11/options.go | 140 +++++++ wrappers/pkcs11/pkcs11.go | 568 +++++++++++++++++++++++++++++ wrappers/pkcs11/pkcs11_acc_test.go | 49 +++ 5 files changed, 796 insertions(+) create mode 100644 wrappers/pkcs11/go.mod create mode 100644 wrappers/pkcs11/go.sum create mode 100644 wrappers/pkcs11/options.go create mode 100644 wrappers/pkcs11/pkcs11.go create mode 100644 wrappers/pkcs11/pkcs11_acc_test.go diff --git a/wrappers/pkcs11/go.mod b/wrappers/pkcs11/go.mod new file mode 100644 index 00000000..085a3d04 --- /dev/null +++ b/wrappers/pkcs11/go.mod @@ -0,0 +1,18 @@ +module github.com/openbao/go-kms-wrapping/wrappers/pkcs11/v2 + +go 1.22.1 + +replace github.com/openbao/go-kms-wrapping/v2 => ../../ + +require ( + github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b + github.com/openbao/go-kms-wrapping/v2 v2.0.0-00010101000000-000000000000 +) + +require ( + github.com/hashicorp/go-uuid v1.0.3 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/protobuf v1.31.0 // indirect +) + +retract [v2.0.0, v2.0.2] diff --git a/wrappers/pkcs11/go.sum b/wrappers/pkcs11/go.sum new file mode 100644 index 00000000..c4a4124d --- /dev/null +++ b/wrappers/pkcs11/go.sum @@ -0,0 +1,21 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/wrappers/pkcs11/options.go b/wrappers/pkcs11/options.go new file mode 100644 index 00000000..09578f93 --- /dev/null +++ b/wrappers/pkcs11/options.go @@ -0,0 +1,140 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "strconv" + + wrapping "github.com/openbao/go-kms-wrapping/v2" +) + +// getOpts iterates the inbound Options and returns a struct +func getOpts(opt ...wrapping.Option) (*options, error) { + // First, separate out options into local and global + opts := getDefaultOptions() + var wrappingOptions []wrapping.Option + var localOptions []OptionFunc + for _, o := range opt { + if o == nil { + continue + } + iface := o() + switch to := iface.(type) { + case wrapping.OptionFunc: + wrappingOptions = append(wrappingOptions, o) + case OptionFunc: + localOptions = append(localOptions, to) + } + } + + // Parse the global options + var err error + opts.Options, err = wrapping.GetOpts(wrappingOptions...) + if err != nil { + return nil, err + } + + // Don't ever return blank options + if opts.Options == nil { + opts.Options = new(wrapping.Options) + } + + // Local options can be provided either via the WithConfigMap field + // (for over the plugin barrier or embedding) or via local option functions + // (for embedding). First pull from the option. + if opts.WithConfigMap != nil { + for k, v := range opts.WithConfigMap { + switch k { + case "kms_key_id": // deprecated backend-specific value, set global + opts.WithKeyId = v + case "slot": + var err error + var slot uint64 + slot, err = strconv.ParseUint(v, 10, 64) + if err != nil { + return nil, err + } + opts.withSlot = uint(slot) + case "pin": + opts.withPin = v + case "module": + opts.withModule = v + case "label": + opts.withLabel = v + case "mechanism": + opts.withMechanism = v + } + } + } + + // Now run the local options functions. This may overwrite options set by + // the options above. + for _, o := range localOptions { + if o != nil { + if err := o(&opts); err != nil { + return nil, err + } + } + } + + return &opts, nil +} + +// OptionFunc holds a function with local options +type OptionFunc func(*options) error + +// options = how options are represented +type options struct { + *wrapping.Options + + withSlot uint + withPin string + withModule string + withLabel string + withMechanism string +} + +func getDefaultOptions() options { + return options{} +} + +// WithSlot sets the slot +func WithSlot(slot uint) OptionFunc { + return func(o *options) error { + o.withSlot = slot + return nil + } +} + +// WithPin sets the pin +func WithPin(pin string) OptionFunc { + return func(o *options) error { + o.withPin = pin + return nil + } +} + +// WithModule sets the module +func WithModule(module string) OptionFunc { + return func(o *options) error { + o.withModule = module + return nil + } +} + +// WithLabel sets the label +func WithLabel(label string) OptionFunc { + return func(o *options) error { + o.withLabel = label + return nil + } +} + +// WithMechanism sets the mechanism +func WithMechanism(mechanism string) OptionFunc { + return func(o *options) error { + o.withMechanism = mechanism + return nil + } +} diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go new file mode 100644 index 00000000..175a4357 --- /dev/null +++ b/wrappers/pkcs11/pkcs11.go @@ -0,0 +1,568 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "context" + "encoding/binary" + "encoding/hex" + "fmt" + "os" + "strconv" + "sync/atomic" + + uuid "github.com/hashicorp/go-uuid" + pkcs11 "github.com/miekg/pkcs11" + wrapping "github.com/openbao/go-kms-wrapping/v2" +) + +// These constants contain the accepted env vars; the Vault one is for backwards compat +const ( + EnvPkcs11WrapperKeyId = "PKCS11_WRAPPER_KEY_ID" + EnvVaultPkcs11SealKeyId = "VAULT_PKCS11_SEAL_KEY_ID" +) + +// Wrapper is a Wrapper that uses PKCS11 +type Wrapper struct { + client *pkcs11KMS + keyId string + currentKeyId *atomic.Value +} + +// Ensure that we are implementing Wrapper +var _ wrapping.Wrapper = (*Wrapper)(nil) + +// NewWrapper creates a new PKCS11 Wrapper +func NewWrapper() *Wrapper { + k := &Wrapper{ + currentKeyId: new(atomic.Value), + } + k.currentKeyId.Store("") + return k +} + +// SetConfig sets the fields on the Pkcs11Wrapper object based on +// values from the config parameter. +// +// Order of precedence Pkcs11 values: +// * Environment variable +// * Value from Vault configuration file +// * Instance metadata role (access key and secret key) +func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) { + opts, err := getOpts(opt...) + if err != nil { + return nil, err + } + + // Check and set KeyId + switch { + case os.Getenv(EnvPkcs11WrapperKeyId) != "" && !opts.Options.WithDisallowEnvVars: + k.keyId = os.Getenv(EnvPkcs11WrapperKeyId) + case os.Getenv(EnvVaultPkcs11SealKeyId) != "" && !opts.Options.WithDisallowEnvVars: + k.keyId = os.Getenv(EnvVaultPkcs11SealKeyId) + case opts.WithKeyId != "": + k.keyId = opts.WithKeyId + default: + return nil, fmt.Errorf("key id not found (env or config) for pkcs11 wrapper configuration") + } + + // Set and check k.client + if k.client == nil { + k.client = &pkcs11KMS{} + + if !opts.Options.WithDisallowEnvVars && os.Getenv("PKCS11_SLOT") != "" { + var err error + var slot uint64 + slot, err = strconv.ParseUint(os.Getenv("PKCS11_SLOT"), 10, 64) + if err != nil { + return nil, err + } + opts.withSlot = uint(slot) + } + if k.client.slot == 0 { + k.client.slot = opts.withSlot + } + + if !opts.Options.WithDisallowEnvVars { + k.client.pin = os.Getenv("PKCS11_PIN") + } + if k.client.pin == "" { + k.client.pin = opts.withPin + } + + if !opts.Options.WithDisallowEnvVars { + k.client.module = os.Getenv("PKCS11_MODULE") + } + if k.client.module == "" { + k.client.module = opts.withModule + } + + if !opts.Options.WithDisallowEnvVars { + k.client.label = os.Getenv("PKCS11_LABEL") + } + if k.client.label == "" { + k.client.label = opts.withLabel + } + + if !opts.Options.WithDisallowEnvVars { + mechanismName := os.Getenv("PKCS11_MECHANISM") + if mechanismName != "" { + k.client.mechanism, err = MechanisFromString(mechanismName) + if err != nil { + return nil, err + } + } + } + if k.client.mechanism == 0 { + if opts.withMechanism != "" { + k.client.mechanism, err = MechanisFromString(opts.withMechanism) + if err != nil { + return nil, err + } + } + } + + k.client.keyId = k.keyId + + p := pkcs11.New(k.client.module) + err := p.Initialize() + if err != nil { + return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) + } + defer p.Destroy() + defer p.Finalize() + + session, err := p.OpenSession(k.client.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + return nil, fmt.Errorf("failed to open session: %w", err) + } + defer p.CloseSession(session) + + err = p.Login(session, pkcs11.CKU_USER, k.client.pin) + if err != nil { + return nil, fmt.Errorf("failed to login: %w", err) + } + defer p.Logout(session) + + } + // Store the current key id. If using a key alias, this will point to the actual + // unique key that that was used for this encrypt operation. + k.currentKeyId.Store(k.keyId) + + // Map that holds non-sensitive configuration info + wrapConfig := new(wrapping.WrapperConfig) + wrapConfig.Metadata = make(map[string]string) + wrapConfig.Metadata["kms_key_id"] = k.keyId + wrapConfig.Metadata["slot"] = strconv.Itoa(int(k.client.slot)) + if k.client.label != "" { + wrapConfig.Metadata["label"] = k.client.label + } + if k.client.mechanism != 0 { + wrapConfig.Metadata["mechanism"] = MechanisString(k.client.mechanism) + } + + return wrapConfig, nil +} + +// Type returns the type for this particular wrapper implementation +func (k *Wrapper) Type(_ context.Context) (wrapping.WrapperType, error) { + return wrapping.WrapperTypePkcs11, nil +} + +// KeyId returns the last known key id +func (k *Wrapper) KeyId(_ context.Context) (string, error) { + return k.currentKeyId.Load().(string), nil +} + +// Encrypt is used to encrypt the master key using the the PKCS11. +// This returns the ciphertext, and/or any errors from this +// call. This should be called after the KMS client has been instantiated. +func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) { + if plaintext == nil { + return nil, fmt.Errorf("given plaintext for encryption is nil") + } + + env, err := wrapping.EnvelopeEncrypt(plaintext, opt...) + if err != nil { + return nil, fmt.Errorf("error wrapping data: %w", err) + } + + WrappedKey, err := k.client.EncryptDEK(context.Background(), env.Key) + if err != nil { + return nil, fmt.Errorf("error encrypting data: %w", err) + } + + // Store the current key id. + k.currentKeyId.Store(k.keyId) + + ret := &wrapping.BlobInfo{ + Ciphertext: env.Ciphertext, + Iv: env.Iv, + KeyInfo: &wrapping.KeyInfo{ + KeyId: k.keyId, + WrappedKey: WrappedKey, + }, + } + + return ret, nil +} + +// Decrypt is used to decrypt the ciphertext. This should be called after Init. +func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapping.Option) ([]byte, error) { + if in == nil { + return nil, fmt.Errorf("given input for decryption is nil") + } + + keyBytes, err := k.client.DecryptDEK(context.Background(), in.KeyInfo.WrappedKey) + if err != nil { + return nil, fmt.Errorf("error decrypting data encryption key: %w", err) + } + + envInfo := &wrapping.EnvelopeInfo{ + Key: keyBytes, + Iv: in.Iv, + Ciphertext: in.Ciphertext, + } + plaintext, err := wrapping.EnvelopeDecrypt(envInfo, opt...) + if err != nil { + return nil, fmt.Errorf("error decrypting data: %w", err) + } + + return plaintext, nil +} + +func GetKeyTypeFromMech(mech uint) (uint, error) { + switch mech { + case pkcs11.CKM_RSA_PKCS: + return pkcs11.CKK_RSA, nil + case pkcs11.CKM_AES_CBC_PAD: + return pkcs11.CKK_AES, nil + default: + return 0, fmt.Errorf("unsupported mechanism: %d", mech) + } +} + +func MechanisString(mech uint) string { + switch mech { + case pkcs11.CKM_RSA_PKCS: + return "CKM_RSA_PKCS" + case pkcs11.CKM_AES_CBC_PAD: + return "CKM_AES_CBC_PAD" + default: + return "Unknown" + } +} + +func IsIvNeeded(mech uint) bool { + switch mech { + case pkcs11.CKM_AES_CBC_PAD: + return true + default: + return false + } +} + +func MechanisFromString(mech string) (uint, error) { + switch mech { + case "CKM_RSA_PKCS": + return pkcs11.CKM_RSA_PKCS, nil + case "CKM_AES_CBC_PAD": + return pkcs11.CKM_AES_CBC_PAD, nil + default: + return 0, fmt.Errorf("unsupported mechanism: %s", mech) + } +} + +type pkcs11KMS struct { + // standard PKCS11 configuration options + slot uint + pin string + module string + keyId string + label string + mechanism uint +} + +// EncryptDEK uses the PKCS11 encrypt operation to encrypt the DEK. +func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, error) { + p := pkcs11.New(kms.module) + err := p.Initialize() + if err != nil { + return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) + } + + defer p.Destroy() + defer p.Finalize() + + session, err := p.OpenSession(kms.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + return nil, fmt.Errorf("failed to open session: %w", err) + } + defer p.CloseSession(session) + + err = p.Login(session, pkcs11.CKU_USER, kms.pin) + if err != nil { + return nil, fmt.Errorf("failed to login: %w", err) + } + defer p.Logout(session) + + keyIdBytes, err := hex.DecodeString(kms.keyId) + if err != nil { + return nil, fmt.Errorf("failed to decode key id: %w", err) + } + template := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), + } + if kms.label != "" { + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, kms.label)) + } + if kms.mechanism != 0 { + keyTypeString, err := GetKeyTypeFromMech(kms.mechanism) + if err != nil { + return nil, fmt.Errorf("failed to get key type from mechanism: %s", err) + } + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) + } + if err := p.FindObjectsInit(session, template); err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) + } + obj, _, err := p.FindObjects(session, 2) + if err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) + } + if err := p.FindObjectsFinal(session); err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) + } + + if len(obj) != 1 { + return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) + } + key := obj[0] + + template = []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), + } + attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + if err != nil { + return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) + } + + attrMap := GetAttributesMap(attr) + keyType := GetValueAsInt(attrMap[pkcs11.CKA_KEY_TYPE]) + + mechanism := uint(0) + switch keyType { + case pkcs11.CKK_AES: + if kms.mechanism != 0 { + mechanism = kms.mechanism + } else { + mechanism = pkcs11.CKM_AES_CBC_PAD + } + case pkcs11.CKK_RSA: + if kms.mechanism != 0 { + mechanism = kms.mechanism + } else { + mechanism = pkcs11.CKM_RSA_PKCS + } + default: + return nil, fmt.Errorf("unsupported key type: %d", keyType) + } + + var iv []byte + if IsIvNeeded(mechanism) { + template = []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, nil), + } + attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + if err != nil { + return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) + } + attrMap := GetAttributesMap(attr) + + ivLength := 0 + ivLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) + + iv, err = uuid.GenerateRandomBytes(ivLength) + if err != nil { + return nil, err + } + } + + if err = p.EncryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + return nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) + } + var ciphertext []byte + if ciphertext, err = p.Encrypt(session, plainDEK); err != nil { + return nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) + } + + if iv != nil { + return append(iv, ciphertext...), nil + } else { + return ciphertext, nil + } +} + +// DecryptDEK uses the PKCS11 decrypt operation to decrypt the DEK. +func (kms *pkcs11KMS) DecryptDEK(ctx context.Context, encryptedDEK []byte) ([]byte, error) { + p := pkcs11.New(kms.module) + err := p.Initialize() + if err != nil { + return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) + } + + defer p.Destroy() + defer p.Finalize() + + session, err := p.OpenSession(kms.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + return nil, fmt.Errorf("failed to open session: %w", err) + } + defer p.CloseSession(session) + + err = p.Login(session, pkcs11.CKU_USER, kms.pin) + if err != nil { + return nil, fmt.Errorf("failed to login: %w", err) + } + defer p.Logout(session) + + keyIdBytes, err := hex.DecodeString(kms.keyId) + if err != nil { + return nil, fmt.Errorf("failed to decode key id: %w", err) + } + template := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), + } + if kms.label != "" { + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(kms.label))) + } + if kms.mechanism != 0 { + keyTypeString, err := GetKeyTypeFromMech(kms.mechanism) + if err != nil { + return nil, fmt.Errorf("failed to get key type from mechanism: %s", err) + } + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) + } + if err := p.FindObjectsInit(session, template); err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) + } + + obj, _, err := p.FindObjects(session, 2) + if err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) + } + if err := p.FindObjectsFinal(session); err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) + } + + if len(obj) != 1 { + return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) + } + key := obj[0] + + template = []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), + } + attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + if err != nil { + return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) + } + + attrMap := GetAttributesMap(attr) + keyType := GetValueAsInt(attrMap[pkcs11.CKA_KEY_TYPE]) + + mechanism := uint(0) + switch keyType { + case pkcs11.CKK_AES: + if kms.mechanism != 0 { + mechanism = kms.mechanism + } else { + mechanism = pkcs11.CKM_AES_CBC_PAD + } + case pkcs11.CKK_RSA: + if kms.mechanism != 0 { + mechanism = kms.mechanism + } else { + mechanism = pkcs11.CKM_RSA_PKCS + } + default: + return nil, fmt.Errorf("unsupported key type: %d", keyType) + } + + var iv []byte + if IsIvNeeded(mechanism) { + template = []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, nil), + } + attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + if err != nil { + return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) + } + attrMap := GetAttributesMap(attr) + + ivLength := 0 + ivLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) + + if len(encryptedDEK) < ivLength { + return nil, fmt.Errorf("encrypted DEK is too short") + } + + iv = encryptedDEK[:ivLength] + encryptedDEK = encryptedDEK[ivLength:] + } + + if err = p.DecryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) + } + + var decrypted []byte + if decrypted, err = p.Decrypt(session, encryptedDEK); err != nil { + return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) + } + return decrypted, nil +} + +func GetAttributesMap(attrs []*pkcs11.Attribute) map[uint][]byte { + m := make(map[uint][]byte, len(attrs)) + for _, a := range attrs { + m[a.Type] = a.Value + } + return m +} + +func GetValueAsInt(value []byte) int64 { + if value == nil { + return 0 + } + switch len(value) { + case 1: + return int64(value[0]) + case 2: + return int64(binary.NativeEndian.Uint16(value)) + case 4: + return int64(binary.NativeEndian.Uint32(value)) + case 8: + return int64(binary.NativeEndian.Uint64(value)) + } + return 0 +} + +func GetValueAsUint(value []byte) uint64 { + if value == nil { + return 0 + } + switch len(value) { + case 1: + return uint64(value[0]) + case 2: + return uint64(binary.NativeEndian.Uint16(value)) + case 4: + return uint64(binary.NativeEndian.Uint32(value)) + case 8: + return uint64(binary.NativeEndian.Uint64(value)) + } + return 0 +} diff --git a/wrappers/pkcs11/pkcs11_acc_test.go b/wrappers/pkcs11/pkcs11_acc_test.go new file mode 100644 index 00000000..c9335678 --- /dev/null +++ b/wrappers/pkcs11/pkcs11_acc_test.go @@ -0,0 +1,49 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "context" + "os" + "reflect" + "testing" +) + +// This test executes real calls. The calls themselves should be free, +// but the KMS key used is generally not free. +// +// To run this test, the following env variables need to be set: +// - VAULT_PKCS11_SEAL_KEY_ID or PKCS11_WRAPPING_KEY_ID +// - PKCS11_WRAPPER_KEY_ID +// - PKCS11_SLOT +// - PKCS11_PIN +// - PKCS11_MODULE +// - PKCS11_LABEL +// - PKCS11_MECHANISM +func TestAccPkcs11Wrapper_Lifecycle(t *testing.T) { + if os.Getenv("VAULT_ACC") == "" && os.Getenv("KMS_ACC_TESTS") == "" { + t.SkipNow() + } + + s := NewWrapper() + _, err := s.SetConfig(context.Background()) + if err != nil { + t.Fatalf("err : %s", err) + } + + input := []byte("foo") + swi, err := s.Encrypt(context.Background(), input) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + pt, err := s.Decrypt(context.Background(), swi) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + if !reflect.DeepEqual(input, pt) { + t.Fatalf("expected %s, got %s", input, pt) + } +} From 1f56ebce8f28586f2ccf85a59f85bc040a170db3 Mon Sep 17 00:00:00 2001 From: jsmoon Date: Sat, 22 Jun 2024 01:38:34 +0900 Subject: [PATCH 02/10] Remove unnecessary conditional statements Signed-off-by: jsmoon --- wrappers/pkcs11/pkcs11.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index 175a4357..0edd068e 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -534,9 +534,6 @@ func GetAttributesMap(attrs []*pkcs11.Attribute) map[uint][]byte { } func GetValueAsInt(value []byte) int64 { - if value == nil { - return 0 - } switch len(value) { case 1: return int64(value[0]) @@ -551,9 +548,6 @@ func GetValueAsInt(value []byte) int64 { } func GetValueAsUint(value []byte) uint64 { - if value == nil { - return 0 - } switch len(value) { case 1: return uint64(value[0]) From 61ab7a1e20bcf5428113895085fc94628255a5ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Thu, 4 Jul 2024 12:15:59 +0200 Subject: [PATCH 03/10] [fix] change IsIvNeeded to return also the IV length, if fixed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some parts of the code are commented for future improvements. In case of algorithm which would requires more inputs to know the IV length, the IsIvNeeded can return "0". Clearly I'm not very fan of it, but I don't know if giving access to the PKCS#11 session is better. Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/pkcs11.go | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index 0edd068e..be14a3a4 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -78,7 +78,7 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin if err != nil { return nil, err } - opts.withSlot = uint(slot) + k.client.slot = uint(slot) } if k.client.slot == 0 { k.client.slot = opts.withSlot @@ -254,12 +254,12 @@ func MechanisString(mech uint) string { } } -func IsIvNeeded(mech uint) bool { +func IsIvNeeded(mech uint) (bool, int) { switch mech { case pkcs11.CKM_AES_CBC_PAD: - return true + return true, 16 default: - return false + return false, 0 } } @@ -371,7 +371,9 @@ func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, } var iv []byte - if IsIvNeeded(mechanism) { + needIV, ivLength := IsIvNeeded(mechanism) + /* + if needIV && ivLength == 0 { template = []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, nil), } @@ -381,9 +383,11 @@ func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, } attrMap := GetAttributesMap(attr) - ivLength := 0 - ivLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) - + keyLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) + ivLength = GetIvSize(mechanism, keyLength) + } + */ + if needIV && ivLength > 0 { iv, err = uuid.GenerateRandomBytes(ivLength) if err != nil { return nil, err @@ -493,7 +497,9 @@ func (kms *pkcs11KMS) DecryptDEK(ctx context.Context, encryptedDEK []byte) ([]by } var iv []byte - if IsIvNeeded(mechanism) { + needIV, ivLength := IsIvNeeded(mechanism) + /* + if needIV && ivLength == 0 { template = []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, nil), } @@ -503,9 +509,11 @@ func (kms *pkcs11KMS) DecryptDEK(ctx context.Context, encryptedDEK []byte) ([]by } attrMap := GetAttributesMap(attr) - ivLength := 0 - ivLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) - + keyLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) + ivLength = GetIvSize(mechanism, keyLength) + } + */ + if needIV && ivLength > 0 { if len(encryptedDEK) < ivLength { return nil, fmt.Errorf("encrypted DEK is too short") } From 315b9462ceb788e6b4a3a7cc6b2098d800ef41dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Thu, 4 Jul 2024 12:15:59 +0200 Subject: [PATCH 04/10] [feat] Update PKCS#11 options to be more compatible with Vault. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/options.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wrappers/pkcs11/options.go b/wrappers/pkcs11/options.go index 09578f93..3abc7580 100644 --- a/wrappers/pkcs11/options.go +++ b/wrappers/pkcs11/options.go @@ -58,8 +58,10 @@ func getOpts(opt ...wrapping.Option) (*options, error) { opts.withSlot = uint(slot) case "pin": opts.withPin = v + case "lib": case "module": opts.withModule = v + case "key_label": case "label": opts.withLabel = v case "mechanism": From c4560b364a22d749f5546ce139e76b5ae606e04d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:20:58 +0100 Subject: [PATCH 05/10] [refactor] first part of refactoring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Moving PKCS11 client into the struct * Creating the client during configuration loading * Support of Finalize to destroy the client * Adding tools functions Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/pkcs11.go | 302 +++++++++++++++++--------------------- 1 file changed, 137 insertions(+), 165 deletions(-) diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index be14a3a4..b07cb9ac 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -42,6 +42,17 @@ func NewWrapper() *Wrapper { return k } +// Init is called during core.Initialize +func (s *Wrapper) Init(_ context.Context) error { + return nil +} + +// Finalize is called during shutdown +func (s *Wrapper) Finalize(_ context.Context) error { + s.client.DestroyClient() + return nil +} + // SetConfig sets the fields on the Pkcs11Wrapper object based on // values from the config parameter. // @@ -108,7 +119,7 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin if !opts.Options.WithDisallowEnvVars { mechanismName := os.Getenv("PKCS11_MECHANISM") if mechanismName != "" { - k.client.mechanism, err = MechanisFromString(mechanismName) + k.client.mechanism, err = MechanismFromString(mechanismName) if err != nil { return nil, err } @@ -116,7 +127,7 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin } if k.client.mechanism == 0 { if opts.withMechanism != "" { - k.client.mechanism, err = MechanisFromString(opts.withMechanism) + k.client.mechanism, err = MechanismFromString(opts.withMechanism) if err != nil { return nil, err } @@ -125,26 +136,17 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin k.client.keyId = k.keyId - p := pkcs11.New(k.client.module) - err := p.Initialize() + // Initialize the client + _, err = k.client.GetClient() if err != nil { - return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) + return nil, err } - defer p.Destroy() - defer p.Finalize() - - session, err := p.OpenSession(k.client.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + // Validate credentials for session establishment + session, err := k.client.GetSession() if err != nil { - return nil, fmt.Errorf("failed to open session: %w", err) - } - defer p.CloseSession(session) - - err = p.Login(session, pkcs11.CKU_USER, k.client.pin) - if err != nil { - return nil, fmt.Errorf("failed to login: %w", err) + return nil, err } - defer p.Logout(session) - + defer k.client.CloseSession(session) } // Store the current key id. If using a key alias, this will point to the actual // unique key that that was used for this encrypt operation. @@ -153,7 +155,7 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin // Map that holds non-sensitive configuration info wrapConfig := new(wrapping.WrapperConfig) wrapConfig.Metadata = make(map[string]string) - wrapConfig.Metadata["kms_key_id"] = k.keyId + wrapConfig.Metadata["key_id"] = k.keyId wrapConfig.Metadata["slot"] = strconv.Itoa(int(k.client.slot)) if k.client.label != "" { wrapConfig.Metadata["label"] = k.client.label @@ -175,12 +177,12 @@ func (k *Wrapper) KeyId(_ context.Context) (string, error) { return k.currentKeyId.Load().(string), nil } -// Encrypt is used to encrypt the master key using the the PKCS11. +// Encrypt is used to encrypt data using the the PKCS11 key. // This returns the ciphertext, and/or any errors from this // call. This should be called after the KMS client has been instantiated. func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) { - if plaintext == nil { - return nil, fmt.Errorf("given plaintext for encryption is nil") + if len(plaintext) == 0 { + return nil, fmt.Errorf("given plaintext for encryption is empty") } env, err := wrapping.EnvelopeEncrypt(plaintext, opt...) @@ -188,7 +190,7 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O return nil, fmt.Errorf("error wrapping data: %w", err) } - WrappedKey, err := k.client.EncryptDEK(context.Background(), env.Key) + WrappedKey, err := k.client.Encrypt(context.Background(), env.Key) if err != nil { return nil, fmt.Errorf("error encrypting data: %w", err) } @@ -214,7 +216,7 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp return nil, fmt.Errorf("given input for decryption is nil") } - keyBytes, err := k.client.DecryptDEK(context.Background(), in.KeyInfo.WrappedKey) + keyBytes, err := k.client.Decrypt(context.Background(), in.KeyInfo.WrappedKey) if err != nil { return nil, fmt.Errorf("error decrypting data encryption key: %w", err) } @@ -234,8 +236,12 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp func GetKeyTypeFromMech(mech uint) (uint, error) { switch mech { + case pkcs11.CKM_RSA_PKCS_OAEP: + return pkcs11.CKK_RSA, nil case pkcs11.CKM_RSA_PKCS: return pkcs11.CKK_RSA, nil + case pkcs11.CKM_AES_GCM: + return pkcs11.CKK_AES, nil case pkcs11.CKM_AES_CBC_PAD: return pkcs11.CKK_AES, nil default: @@ -245,8 +251,12 @@ func GetKeyTypeFromMech(mech uint) (uint, error) { func MechanisString(mech uint) string { switch mech { + case pkcs11.CKM_RSA_PKCS_OAEP: + return "CKM_RSA_PKCS_OAEP" case pkcs11.CKM_RSA_PKCS: return "CKM_RSA_PKCS" + case pkcs11.CKM_AES_GCM: + return "CKM_AES_GCM" case pkcs11.CKM_AES_CBC_PAD: return "CKM_AES_CBC_PAD" default: @@ -256,6 +266,8 @@ func MechanisString(mech uint) string { func IsIvNeeded(mech uint) (bool, int) { switch mech { + case pkcs11.CKM_AES_GCM: + return true, 16 case pkcs11.CKM_AES_CBC_PAD: return true, 16 default: @@ -263,10 +275,14 @@ func IsIvNeeded(mech uint) (bool, int) { } } -func MechanisFromString(mech string) (uint, error) { +func MechanismFromString(mech string) (uint, error) { switch mech { + case "CKM_RSA_PKCS_OAEP": + return pkcs11.CKM_RSA_PKCS_OAEP, nil case "CKM_RSA_PKCS": return pkcs11.CKM_RSA_PKCS, nil + case "CKM_AES_GCM": + return pkcs11.CKM_AES_GCM, nil case "CKM_AES_CBC_PAD": return pkcs11.CKM_AES_CBC_PAD, nil default: @@ -275,6 +291,7 @@ func MechanisFromString(mech string) (uint, error) { } type pkcs11KMS struct { + client *pkcs11.Ctx // standard PKCS11 configuration options slot uint pin string @@ -284,39 +301,63 @@ type pkcs11KMS struct { mechanism uint } -// EncryptDEK uses the PKCS11 encrypt operation to encrypt the DEK. -func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, error) { - p := pkcs11.New(kms.module) - err := p.Initialize() +// Create a PKCS11 client for the configured module. +func (kms *pkcs11KMS) GetClient() (*pkcs11.Ctx, error) { + if kms.client != nil { + return kms.client, nil + } + kms.client = pkcs11.New(kms.module) + err := kms.client.Initialize() if err != nil { + kms.client = nil return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) } + return kms.client, nil +} - defer p.Destroy() - defer p.Finalize() +// Open a session and perform the authentication process. +func (kms *pkcs11KMS) GetSession() (pkcs11.SessionHandle, error) { + if kms.client == nil { + return 0, fmt.Errorf("PKCS11 not initialized") + } - session, err := p.OpenSession(kms.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + session, err := kms.client.OpenSession(kms.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) if err != nil { - return nil, fmt.Errorf("failed to open session: %w", err) + return 0, fmt.Errorf("failed to open session: %w", err) } - defer p.CloseSession(session) - - err = p.Login(session, pkcs11.CKU_USER, kms.pin) + err = kms.client.Login(session, pkcs11.CKU_USER, kms.pin) if err != nil { - return nil, fmt.Errorf("failed to login: %w", err) + return 0, fmt.Errorf("failed to login: %w", err) } - defer p.Logout(session) + return session, nil +} - keyIdBytes, err := hex.DecodeString(kms.keyId) - if err != nil { - return nil, fmt.Errorf("failed to decode key id: %w", err) +func (kms *pkcs11KMS) CloseSession(session pkcs11.SessionHandle) { + if kms.client == nil { + return + } + kms.client.Logout(session) + kms.client.CloseSession(session) +} + +func (kms *pkcs11KMS) DestroyClient() { + if kms.client == nil { + return } + kms.client.Finalize() + kms.client.Destroy() + kms.client = nil +} + +// +func (kms *pkcs11KMS) FindKey(session pkcs11.SessionHandle, typ uint) ([]pkcs11.ObjectHandle, error) { template := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes), - pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(kms.label)), + pkcs11.NewAttribute(typ, true), } - if kms.label != "" { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, kms.label)) + keyIdBytes, err := hex.DecodeString(kms.keyId) + if err == nil { + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes)) } if kms.mechanism != 0 { keyTypeString, err := GetKeyTypeFromMech(kms.mechanism) @@ -325,28 +366,28 @@ func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, } template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) } - if err := p.FindObjectsInit(session, template); err != nil { + + if err := kms.client.FindObjectsInit(session, template); err != nil { return nil, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) } - obj, _, err := p.FindObjects(session, 2) + obj, _, err := kms.client.FindObjects(session, 2) if err != nil { return nil, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) } - if err := p.FindObjectsFinal(session); err != nil { + if err := kms.client.FindObjectsFinal(session); err != nil { return nil, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) } - if len(obj) != 1 { - return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) - } - key := obj[0] + return obj, nil +} - template = []*pkcs11.Attribute{ +func (kms *pkcs11KMS) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) (uint, error) { + template := []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), } - attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + attr, err := kms.client.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) if err != nil { - return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) + return 0, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) } attrMap := GetAttributesMap(attr) @@ -358,35 +399,45 @@ func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, if kms.mechanism != 0 { mechanism = kms.mechanism } else { - mechanism = pkcs11.CKM_AES_CBC_PAD + mechanism = pkcs11.CKM_AES_GCM } case pkcs11.CKK_RSA: if kms.mechanism != 0 { mechanism = kms.mechanism } else { - mechanism = pkcs11.CKM_RSA_PKCS + mechanism = pkcs11.CKM_RSA_PKCS_OAEP } default: - return nil, fmt.Errorf("unsupported key type: %d", keyType) + return 0, fmt.Errorf("unsupported key type: %d", keyType) } - var iv []byte - needIV, ivLength := IsIvNeeded(mechanism) - /* - if needIV && ivLength == 0 { - template = []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, nil), - } - attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) - if err != nil { - return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) - } - attrMap := GetAttributesMap(attr) + return mechanism, nil +} + +// +func (kms *pkcs11KMS) Encrypt(ctx context.Context, plainDEK []byte) ([]byte, error) { + session, err := kms.GetSession() + if err != nil { + return nil, err + } + defer kms.CloseSession(session) + + obj, err := kms.FindKey(session, pkcs11.CKA_ENCRYPT) + if err != nil { + return nil, err + } + if len(obj) != 1 { + return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) + } + key := obj[0] - keyLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) - ivLength = GetIvSize(mechanism, keyLength) + mechanism, err := kms.GetKeyMechanism(session, key) + if err != nil { + return nil, err } - */ + + var iv []byte + needIV, ivLength := IsIvNeeded(mechanism) if needIV && ivLength > 0 { iv, err = uuid.GenerateRandomBytes(ivLength) if err != nil { @@ -394,11 +445,11 @@ func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, } } - if err = p.EncryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + if err = kms.client.EncryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { return nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) } var ciphertext []byte - if ciphertext, err = p.Encrypt(session, plainDEK); err != nil { + if ciphertext, err = kms.client.Encrypt(session, plainDEK); err != nil { return nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) } @@ -409,57 +460,17 @@ func (kms *pkcs11KMS) EncryptDEK(ctx context.Context, plainDEK []byte) ([]byte, } } -// DecryptDEK uses the PKCS11 decrypt operation to decrypt the DEK. -func (kms *pkcs11KMS) DecryptDEK(ctx context.Context, encryptedDEK []byte) ([]byte, error) { - p := pkcs11.New(kms.module) - err := p.Initialize() - if err != nil { - return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) - } - - defer p.Destroy() - defer p.Finalize() - - session, err := p.OpenSession(kms.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) +// Decrypt uses the PKCS11 decrypt operation to decrypt the DEK. +func (kms *pkcs11KMS) Decrypt(ctx context.Context, encryptedDEK []byte) ([]byte, error) { + session, err := kms.GetSession() if err != nil { - return nil, fmt.Errorf("failed to open session: %w", err) - } - defer p.CloseSession(session) - - err = p.Login(session, pkcs11.CKU_USER, kms.pin) - if err != nil { - return nil, fmt.Errorf("failed to login: %w", err) - } - defer p.Logout(session) - - keyIdBytes, err := hex.DecodeString(kms.keyId) - if err != nil { - return nil, fmt.Errorf("failed to decode key id: %w", err) - } - template := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes), - pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), - } - if kms.label != "" { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(kms.label))) - } - if kms.mechanism != 0 { - keyTypeString, err := GetKeyTypeFromMech(kms.mechanism) - if err != nil { - return nil, fmt.Errorf("failed to get key type from mechanism: %s", err) - } - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) - } - if err := p.FindObjectsInit(session, template); err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) + return nil, err } + defer kms.CloseSession(session) - obj, _, err := p.FindObjects(session, 2) + obj, err := kms.FindKey(session, pkcs11.CKA_DECRYPT) if err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) - } - if err := p.FindObjectsFinal(session); err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) + return nil, err } if len(obj) != 1 { @@ -467,52 +478,13 @@ func (kms *pkcs11KMS) DecryptDEK(ctx context.Context, encryptedDEK []byte) ([]by } key := obj[0] - template = []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), - } - attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + mechanism, err := kms.GetKeyMechanism(session, key) if err != nil { - return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) - } - - attrMap := GetAttributesMap(attr) - keyType := GetValueAsInt(attrMap[pkcs11.CKA_KEY_TYPE]) - - mechanism := uint(0) - switch keyType { - case pkcs11.CKK_AES: - if kms.mechanism != 0 { - mechanism = kms.mechanism - } else { - mechanism = pkcs11.CKM_AES_CBC_PAD - } - case pkcs11.CKK_RSA: - if kms.mechanism != 0 { - mechanism = kms.mechanism - } else { - mechanism = pkcs11.CKM_RSA_PKCS - } - default: - return nil, fmt.Errorf("unsupported key type: %d", keyType) + return nil, err } var iv []byte needIV, ivLength := IsIvNeeded(mechanism) - /* - if needIV && ivLength == 0 { - template = []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, nil), - } - attr, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) - if err != nil { - return nil, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) - } - attrMap := GetAttributesMap(attr) - - keyLength = int(GetValueAsInt(attrMap[pkcs11.CKA_VALUE_LEN])) - ivLength = GetIvSize(mechanism, keyLength) - } - */ if needIV && ivLength > 0 { if len(encryptedDEK) < ivLength { return nil, fmt.Errorf("encrypted DEK is too short") @@ -522,12 +494,12 @@ func (kms *pkcs11KMS) DecryptDEK(ctx context.Context, encryptedDEK []byte) ([]by encryptedDEK = encryptedDEK[ivLength:] } - if err = p.DecryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + if err = kms.client.DecryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) } var decrypted []byte - if decrypted, err = p.Decrypt(session, encryptedDEK); err != nil { + if decrypted, err = kms.client.Decrypt(session, encryptedDEK); err != nil { return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) } return decrypted, nil From e3f1099b7302c0a7c2c20e240c4972034a143594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Tue, 12 Nov 2024 15:14:34 +0100 Subject: [PATCH 06/10] [refactor] second part of refactoring. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Move client into a specific file * Support of setting: token label * Encrypt do not wrap anymore * Decrypt do not unwrap anymore * Encrypt store the key identifier in the KeyInfo * Decrypt and read the key identifier from the KeyInfo * Client Encrypt and Decrypt return and take a key identifier * Pkcs11Key struct to store the key identifier (label and id) Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/options.go | 67 ++-- wrappers/pkcs11/pkcs11.go | 475 ++------------------------ wrappers/pkcs11/pkcs11_client.go | 553 +++++++++++++++++++++++++++++++ 3 files changed, 623 insertions(+), 472 deletions(-) create mode 100644 wrappers/pkcs11/pkcs11_client.go diff --git a/wrappers/pkcs11/options.go b/wrappers/pkcs11/options.go index 3abc7580..867f8779 100644 --- a/wrappers/pkcs11/options.go +++ b/wrappers/pkcs11/options.go @@ -4,8 +4,6 @@ package pkcs11 import ( - "strconv" - wrapping "github.com/openbao/go-kms-wrapping/v2" ) @@ -46,24 +44,19 @@ func getOpts(opt ...wrapping.Option) (*options, error) { if opts.WithConfigMap != nil { for k, v := range opts.WithConfigMap { switch k { - case "kms_key_id": // deprecated backend-specific value, set global - opts.WithKeyId = v + // case "key_id", "kms_key_id": // deprecated backend-specific value, set global + case "key_id": + opts.withKeyId = v case "slot": - var err error - var slot uint64 - slot, err = strconv.ParseUint(v, 10, 64) - if err != nil { - return nil, err - } - opts.withSlot = uint(slot) + opts.withSlot = v case "pin": opts.withPin = v - case "lib": - case "module": - opts.withModule = v - case "key_label": - case "label": - opts.withLabel = v + case "lib", "module": + opts.withLib = v + case "token", "token_label": + opts.withTokenLabel = v + case "label", "key_label": + opts.withKeyLabel = v case "mechanism": opts.withMechanism = v } @@ -90,11 +83,13 @@ type OptionFunc func(*options) error type options struct { *wrapping.Options - withSlot uint - withPin string - withModule string - withLabel string - withMechanism string + withSlot string + withPin string + withLib string + withKeyId string + withKeyLabel string + withTokenLabel string + withMechanism string } func getDefaultOptions() options { @@ -102,13 +97,21 @@ func getDefaultOptions() options { } // WithSlot sets the slot -func WithSlot(slot uint) OptionFunc { +func WithSlot(slot string) OptionFunc { return func(o *options) error { o.withSlot = slot return nil } } +// WithSlot sets the slot +func WithTokenLabel(slot string) OptionFunc { + return func(o *options) error { + o.withTokenLabel = slot + return nil + } +} + // WithPin sets the pin func WithPin(pin string) OptionFunc { return func(o *options) error { @@ -117,18 +120,26 @@ func WithPin(pin string) OptionFunc { } } -// WithModule sets the module -func WithModule(module string) OptionFunc { +// WithLib sets the module +func WithLib(lib string) OptionFunc { + return func(o *options) error { + o.withLib = lib + return nil + } +} + +// WithLabel sets the label +func WithKeyId(keyId string) OptionFunc { return func(o *options) error { - o.withModule = module + o.withKeyId = keyId return nil } } // WithLabel sets the label -func WithLabel(label string) OptionFunc { +func WithKeyLabel(label string) OptionFunc { return func(o *options) error { - o.withLabel = label + o.withKeyLabel = label return nil } } diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index b07cb9ac..c648982e 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -5,15 +5,9 @@ package pkcs11 import ( "context" - "encoding/binary" - "encoding/hex" "fmt" - "os" - "strconv" "sync/atomic" - uuid "github.com/hashicorp/go-uuid" - pkcs11 "github.com/miekg/pkcs11" wrapping "github.com/openbao/go-kms-wrapping/v2" ) @@ -25,7 +19,7 @@ const ( // Wrapper is a Wrapper that uses PKCS11 type Wrapper struct { - client *pkcs11KMS + client pkcs11ClientEncryptor keyId string currentKeyId *atomic.Value } @@ -43,125 +37,34 @@ func NewWrapper() *Wrapper { } // Init is called during core.Initialize -func (s *Wrapper) Init(_ context.Context) error { +func (k *Wrapper) Init(_ context.Context) error { return nil } // Finalize is called during shutdown -func (s *Wrapper) Finalize(_ context.Context) error { - s.client.DestroyClient() +func (k *Wrapper) Finalize(_ context.Context) error { + k.client.Close() return nil } -// SetConfig sets the fields on the Pkcs11Wrapper object based on -// values from the config parameter. -// -// Order of precedence Pkcs11 values: -// * Environment variable -// * Value from Vault configuration file -// * Instance metadata role (access key and secret key) +// SetConfig processes the config info from the server config func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) { opts, err := getOpts(opt...) if err != nil { return nil, err } - // Check and set KeyId - switch { - case os.Getenv(EnvPkcs11WrapperKeyId) != "" && !opts.Options.WithDisallowEnvVars: - k.keyId = os.Getenv(EnvPkcs11WrapperKeyId) - case os.Getenv(EnvVaultPkcs11SealKeyId) != "" && !opts.Options.WithDisallowEnvVars: - k.keyId = os.Getenv(EnvVaultPkcs11SealKeyId) - case opts.WithKeyId != "": - k.keyId = opts.WithKeyId - default: - return nil, fmt.Errorf("key id not found (env or config) for pkcs11 wrapper configuration") - } - - // Set and check k.client - if k.client == nil { - k.client = &pkcs11KMS{} - - if !opts.Options.WithDisallowEnvVars && os.Getenv("PKCS11_SLOT") != "" { - var err error - var slot uint64 - slot, err = strconv.ParseUint(os.Getenv("PKCS11_SLOT"), 10, 64) - if err != nil { - return nil, err - } - k.client.slot = uint(slot) - } - if k.client.slot == 0 { - k.client.slot = opts.withSlot - } - - if !opts.Options.WithDisallowEnvVars { - k.client.pin = os.Getenv("PKCS11_PIN") - } - if k.client.pin == "" { - k.client.pin = opts.withPin - } - - if !opts.Options.WithDisallowEnvVars { - k.client.module = os.Getenv("PKCS11_MODULE") - } - if k.client.module == "" { - k.client.module = opts.withModule - } - - if !opts.Options.WithDisallowEnvVars { - k.client.label = os.Getenv("PKCS11_LABEL") - } - if k.client.label == "" { - k.client.label = opts.withLabel - } - - if !opts.Options.WithDisallowEnvVars { - mechanismName := os.Getenv("PKCS11_MECHANISM") - if mechanismName != "" { - k.client.mechanism, err = MechanismFromString(mechanismName) - if err != nil { - return nil, err - } - } - } - if k.client.mechanism == 0 { - if opts.withMechanism != "" { - k.client.mechanism, err = MechanismFromString(opts.withMechanism) - if err != nil { - return nil, err - } - } - } - - k.client.keyId = k.keyId - - // Initialize the client - _, err = k.client.GetClient() - if err != nil { - return nil, err - } - // Validate credentials for session establishment - session, err := k.client.GetSession() - if err != nil { - return nil, err - } - defer k.client.CloseSession(session) + client, wrapConfig, err := newPkcs11Client(opts) + if err != nil { + return nil, err } - // Store the current key id. If using a key alias, this will point to the actual - // unique key that that was used for this encrypt operation. - k.currentKeyId.Store(k.keyId) + k.client = client + k.keyId = client.GetCurrentKey().String() - // Map that holds non-sensitive configuration info - wrapConfig := new(wrapping.WrapperConfig) - wrapConfig.Metadata = make(map[string]string) - wrapConfig.Metadata["key_id"] = k.keyId - wrapConfig.Metadata["slot"] = strconv.Itoa(int(k.client.slot)) - if k.client.label != "" { - wrapConfig.Metadata["label"] = k.client.label - } - if k.client.mechanism != 0 { - wrapConfig.Metadata["mechanism"] = MechanisString(k.client.mechanism) + // Send a value to test the wrapper and to set the current key id + if _, err := k.Encrypt(context.Background(), []byte("a")); err != nil { + client.Close() + return nil, err } return wrapConfig, nil @@ -181,32 +84,20 @@ func (k *Wrapper) KeyId(_ context.Context) (string, error) { // This returns the ciphertext, and/or any errors from this // call. This should be called after the KMS client has been instantiated. func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) { - if len(plaintext) == 0 { - return nil, fmt.Errorf("given plaintext for encryption is empty") - } - - env, err := wrapping.EnvelopeEncrypt(plaintext, opt...) + ciphertext, key, err := k.client.Encrypt(plaintext) if err != nil { - return nil, fmt.Errorf("error wrapping data: %w", err) - } - - WrappedKey, err := k.client.Encrypt(context.Background(), env.Key) - if err != nil { - return nil, fmt.Errorf("error encrypting data: %w", err) + return nil, err } - - // Store the current key id. - k.currentKeyId.Store(k.keyId) + + keyId := key.String() + k.currentKeyId.Store(keyId) ret := &wrapping.BlobInfo{ - Ciphertext: env.Ciphertext, - Iv: env.Iv, + Ciphertext: ciphertext, KeyInfo: &wrapping.KeyInfo{ - KeyId: k.keyId, - WrappedKey: WrappedKey, + KeyId: keyId, }, } - return ret, nil } @@ -216,327 +107,23 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp return nil, fmt.Errorf("given input for decryption is nil") } - keyBytes, err := k.client.Decrypt(context.Background(), in.KeyInfo.WrappedKey) - if err != nil { - return nil, fmt.Errorf("error decrypting data encryption key: %w", err) - } - - envInfo := &wrapping.EnvelopeInfo{ - Key: keyBytes, - Iv: in.Iv, - Ciphertext: in.Ciphertext, - } - plaintext, err := wrapping.EnvelopeDecrypt(envInfo, opt...) - if err != nil { - return nil, fmt.Errorf("error decrypting data: %w", err) - } - - return plaintext, nil -} - -func GetKeyTypeFromMech(mech uint) (uint, error) { - switch mech { - case pkcs11.CKM_RSA_PKCS_OAEP: - return pkcs11.CKK_RSA, nil - case pkcs11.CKM_RSA_PKCS: - return pkcs11.CKK_RSA, nil - case pkcs11.CKM_AES_GCM: - return pkcs11.CKK_AES, nil - case pkcs11.CKM_AES_CBC_PAD: - return pkcs11.CKK_AES, nil - default: - return 0, fmt.Errorf("unsupported mechanism: %d", mech) - } -} - -func MechanisString(mech uint) string { - switch mech { - case pkcs11.CKM_RSA_PKCS_OAEP: - return "CKM_RSA_PKCS_OAEP" - case pkcs11.CKM_RSA_PKCS: - return "CKM_RSA_PKCS" - case pkcs11.CKM_AES_GCM: - return "CKM_AES_GCM" - case pkcs11.CKM_AES_CBC_PAD: - return "CKM_AES_CBC_PAD" - default: - return "Unknown" - } -} - -func IsIvNeeded(mech uint) (bool, int) { - switch mech { - case pkcs11.CKM_AES_GCM: - return true, 16 - case pkcs11.CKM_AES_CBC_PAD: - return true, 16 - default: - return false, 0 - } -} - -func MechanismFromString(mech string) (uint, error) { - switch mech { - case "CKM_RSA_PKCS_OAEP": - return pkcs11.CKM_RSA_PKCS_OAEP, nil - case "CKM_RSA_PKCS": - return pkcs11.CKM_RSA_PKCS, nil - case "CKM_AES_GCM": - return pkcs11.CKM_AES_GCM, nil - case "CKM_AES_CBC_PAD": - return pkcs11.CKM_AES_CBC_PAD, nil - default: - return 0, fmt.Errorf("unsupported mechanism: %s", mech) - } -} - -type pkcs11KMS struct { - client *pkcs11.Ctx - // standard PKCS11 configuration options - slot uint - pin string - module string - keyId string - label string - mechanism uint -} - -// Create a PKCS11 client for the configured module. -func (kms *pkcs11KMS) GetClient() (*pkcs11.Ctx, error) { - if kms.client != nil { - return kms.client, nil - } - kms.client = pkcs11.New(kms.module) - err := kms.client.Initialize() - if err != nil { - kms.client = nil - return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) - } - return kms.client, nil -} - -// Open a session and perform the authentication process. -func (kms *pkcs11KMS) GetSession() (pkcs11.SessionHandle, error) { - if kms.client == nil { - return 0, fmt.Errorf("PKCS11 not initialized") - } - - session, err := kms.client.OpenSession(kms.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) - if err != nil { - return 0, fmt.Errorf("failed to open session: %w", err) - } - err = kms.client.Login(session, pkcs11.CKU_USER, kms.pin) - if err != nil { - return 0, fmt.Errorf("failed to login: %w", err) - } - return session, nil -} - -func (kms *pkcs11KMS) CloseSession(session pkcs11.SessionHandle) { - if kms.client == nil { - return - } - kms.client.Logout(session) - kms.client.CloseSession(session) -} - -func (kms *pkcs11KMS) DestroyClient() { - if kms.client == nil { - return - } - kms.client.Finalize() - kms.client.Destroy() - kms.client = nil -} - -// -func (kms *pkcs11KMS) FindKey(session pkcs11.SessionHandle, typ uint) ([]pkcs11.ObjectHandle, error) { - template := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(kms.label)), - pkcs11.NewAttribute(typ, true), - } - keyIdBytes, err := hex.DecodeString(kms.keyId) - if err == nil { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes)) - } - if kms.mechanism != 0 { - keyTypeString, err := GetKeyTypeFromMech(kms.mechanism) - if err != nil { - return nil, fmt.Errorf("failed to get key type from mechanism: %s", err) - } - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) - } - - if err := kms.client.FindObjectsInit(session, template); err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) - } - obj, _, err := kms.client.FindObjects(session, 2) - if err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) - } - if err := kms.client.FindObjectsFinal(session); err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) - } - - return obj, nil -} - -func (kms *pkcs11KMS) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) (uint, error) { - template := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), - } - attr, err := kms.client.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) - if err != nil { - return 0, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) - } - - attrMap := GetAttributesMap(attr) - keyType := GetValueAsInt(attrMap[pkcs11.CKA_KEY_TYPE]) - - mechanism := uint(0) - switch keyType { - case pkcs11.CKK_AES: - if kms.mechanism != 0 { - mechanism = kms.mechanism - } else { - mechanism = pkcs11.CKM_AES_GCM + if in.KeyInfo == nil { + in.KeyInfo = &wrapping.KeyInfo{ + KeyId: k.keyId, } - case pkcs11.CKK_RSA: - if kms.mechanism != 0 { - mechanism = kms.mechanism - } else { - mechanism = pkcs11.CKM_RSA_PKCS_OAEP - } - default: - return 0, fmt.Errorf("unsupported key type: %d", keyType) - } - - return mechanism, nil -} - -// -func (kms *pkcs11KMS) Encrypt(ctx context.Context, plainDEK []byte) ([]byte, error) { - session, err := kms.GetSession() - if err != nil { - return nil, err } - defer kms.CloseSession(session) - - obj, err := kms.FindKey(session, pkcs11.CKA_ENCRYPT) + keyId, err := newPkcs11Key(in.KeyInfo.KeyId) if err != nil { return nil, err } - if len(obj) != 1 { - return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) - } - key := obj[0] - - mechanism, err := kms.GetKeyMechanism(session, key) + plaintext, err := k.client.Decrypt(in.Ciphertext, keyId) if err != nil { return nil, err } - - var iv []byte - needIV, ivLength := IsIvNeeded(mechanism) - if needIV && ivLength > 0 { - iv, err = uuid.GenerateRandomBytes(ivLength) - if err != nil { - return nil, err - } - } - - if err = kms.client.EncryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { - return nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) - } - var ciphertext []byte - if ciphertext, err = kms.client.Encrypt(session, plainDEK); err != nil { - return nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) - } - - if iv != nil { - return append(iv, ciphertext...), nil - } else { - return ciphertext, nil - } -} - -// Decrypt uses the PKCS11 decrypt operation to decrypt the DEK. -func (kms *pkcs11KMS) Decrypt(ctx context.Context, encryptedDEK []byte) ([]byte, error) { - session, err := kms.GetSession() - if err != nil { - return nil, err - } - defer kms.CloseSession(session) - - obj, err := kms.FindKey(session, pkcs11.CKA_DECRYPT) - if err != nil { - return nil, err - } - - if len(obj) != 1 { - return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) - } - key := obj[0] - - mechanism, err := kms.GetKeyMechanism(session, key) - if err != nil { - return nil, err - } - - var iv []byte - needIV, ivLength := IsIvNeeded(mechanism) - if needIV && ivLength > 0 { - if len(encryptedDEK) < ivLength { - return nil, fmt.Errorf("encrypted DEK is too short") - } - - iv = encryptedDEK[:ivLength] - encryptedDEK = encryptedDEK[ivLength:] - } - - if err = kms.client.DecryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { - return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) - } - - var decrypted []byte - if decrypted, err = kms.client.Decrypt(session, encryptedDEK); err != nil { - return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) - } - return decrypted, nil -} - -func GetAttributesMap(attrs []*pkcs11.Attribute) map[uint][]byte { - m := make(map[uint][]byte, len(attrs)) - for _, a := range attrs { - m[a.Type] = a.Value - } - return m -} - -func GetValueAsInt(value []byte) int64 { - switch len(value) { - case 1: - return int64(value[0]) - case 2: - return int64(binary.NativeEndian.Uint16(value)) - case 4: - return int64(binary.NativeEndian.Uint32(value)) - case 8: - return int64(binary.NativeEndian.Uint64(value)) - } - return 0 + return plaintext, nil } -func GetValueAsUint(value []byte) uint64 { - switch len(value) { - case 1: - return uint64(value[0]) - case 2: - return uint64(binary.NativeEndian.Uint16(value)) - case 4: - return uint64(binary.NativeEndian.Uint32(value)) - case 8: - return uint64(binary.NativeEndian.Uint64(value)) - } - return 0 -} +// GetClient returns the pkcs11 Wrapper's pkcs11ClientEncryptor +func (k *Wrapper) GetClient() pkcs11ClientEncryptor { + return k.client +} \ No newline at end of file diff --git a/wrappers/pkcs11/pkcs11_client.go b/wrappers/pkcs11/pkcs11_client.go new file mode 100644 index 00000000..69a89447 --- /dev/null +++ b/wrappers/pkcs11/pkcs11_client.go @@ -0,0 +1,553 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "fmt" + "strconv" + "strings" + "encoding/binary" + "encoding/hex" + + "github.com/openbao/openbao/api" + uuid "github.com/hashicorp/go-uuid" + pkcs11 "github.com/miekg/pkcs11" + wrapping "github.com/openbao/go-kms-wrapping/v2" +) + +type Pkcs11Key struct { + label string + id string +} +func (k Pkcs11Key) String() string { + return fmt.Sprintf("%s:%s", k.label, k.id) +} +func newPkcs11Key(v string) (*Pkcs11Key, error) { + pos := strings.LastIndex(v, ":") + if pos <= 0 { + return nil, fmt.Errorf("Invalid key format") + } + k := &Pkcs11Key{ + label: v[:pos], + id: v[pos+1:], + } + return k, nil +} +func (k Pkcs11Key) Set(v string) error { + pos := strings.LastIndex(v, ":") + if pos <= 0 { + return fmt.Errorf("Invalid key format") + } + k.label = v[:pos] + k.id = v[pos+1:] + return nil +} + +type pkcs11ClientEncryptor interface { + Close() + Encrypt(plaintext []byte) (ciphertext []byte, keyId *Pkcs11Key, err error) + Decrypt(ciphertext []byte, keyId *Pkcs11Key) (plaintext []byte, err error) +} + +type Pkcs11Client struct { + client *pkcs11.Ctx + lib string + slot uint + tokenLabel string + pin string + keyLabel string + keyId string + mechanism uint +} + +const ( + EnvHsmWrapperLib = "BAO_HSM_LIB" + EnvHsmWrapperSlot = "BAO_HSM_SLOT" + EnvHsmWrapperTokenLabel = "BAO_HSM_TOKEN_LABEL" + EnvHsmWrapperPin = "BAO_HSM_PIN" + EnvHsmWrapperKeyLabel = "BAO_HSM_KEY_LABEL" + EnvHsmWrapperKeyId = "BAO_HSM_KEY_ID" + EnvHsmWrapperMechanism = "BAO_HSM_MECHANISM" +) + +func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, error) { + var lib, slot, keyId, tokenLabel, pin, keyLabel, mechanism string + var slotNum, mechanismNum uint64 + var err error + + switch { + case api.ReadBaoVariable(EnvHsmWrapperLib) != "" && !opts.Options.WithDisallowEnvVars: + lib = api.ReadBaoVariable(EnvHsmWrapperLib) + case opts.withLib != "": + lib = opts.withLib + default: + return nil, nil, fmt.Errorf("lib is required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperSlot) != "" && !opts.Options.WithDisallowEnvVars: + slot = api.ReadBaoVariable(EnvHsmWrapperSlot) + case opts.withSlot != "": + slot = opts.withSlot + default: + slot = "" + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperTokenLabel) != "" && !opts.Options.WithDisallowEnvVars: + tokenLabel = api.ReadBaoVariable(EnvHsmWrapperTokenLabel) + case opts.withTokenLabel != "": + tokenLabel = opts.withTokenLabel + default: + tokenLabel = "" + } + + if slot == "" && tokenLabel == "" { + return nil, nil, fmt.Errorf("slot or token label required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperKeyId) != "" && !opts.Options.WithDisallowEnvVars: + keyId = api.ReadBaoVariable(EnvHsmWrapperKeyId) + case opts.withKeyId != "": + keyId = opts.withKeyId + default: + keyId = "" + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperPin) != "" && !opts.Options.WithDisallowEnvVars: + pin = api.ReadBaoVariable(EnvHsmWrapperPin) + case opts.withPin != "": + pin = opts.withPin + default: + return nil, nil, fmt.Errorf("pin is required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperKeyLabel) != "" && !opts.Options.WithDisallowEnvVars: + keyLabel = api.ReadBaoVariable(EnvHsmWrapperKeyLabel) + case opts.withKeyLabel != "": + keyLabel = opts.withKeyLabel + default: + return nil, nil, fmt.Errorf("key label is required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperMechanism) != "" && !opts.Options.WithDisallowEnvVars: + mechanism = api.ReadBaoVariable(EnvHsmWrapperMechanism) + case opts.withMechanism != "": + mechanism = opts.withMechanism + default: + mechanism = "" + } + + if slot != "" { + if slotNum, err = numberAutoParse(slot, 32); err != nil { + return nil, nil, fmt.Errorf("Invalid slot number") + } + } else { + slotNum = 0 + } + + if mechanism != "" { + if mechanismNum, err = MechanismFromString(mechanism); err != nil { + return nil, nil, err + } + } else { + mechanismNum = 0 + } + + client := &Pkcs11Client{ + client: nil, + lib: lib, + slot: uint(slotNum), + pin: pin, + tokenLabel: tokenLabel, + keyId: keyId, + keyLabel: keyLabel, + mechanism: uint(mechanismNum), + } + + // Initialize the client + _, err = client.GetClient() + if err != nil { + return nil, nil, err + } + // Validate credentials for session establishment + session, err := client.GetSession() + if err != nil { + return nil, nil, err + } + defer client.CloseSession(session) + + wrapConfig := new(wrapping.WrapperConfig) + wrapConfig.Metadata = make(map[string]string) + wrapConfig.Metadata["lib"] = lib + wrapConfig.Metadata["key_label"] = keyLabel + wrapConfig.Metadata["key_id"] = keyId + if slotNum != 0 { + wrapConfig.Metadata["slot"] = strconv.Itoa(int(slotNum)) + } + if tokenLabel != "" { + wrapConfig.Metadata["token_label"] = tokenLabel + } + if mechanismNum != 0 { + wrapConfig.Metadata["mechanism"] = MechanismString(uint(mechanismNum)) + } + + return client, wrapConfig, nil +} + + +func (c *Pkcs11Client) Close() { + if c.client == nil { + return + } + c.client.Finalize() + c.client.Destroy() + c.client = nil +} + +func (c *Pkcs11Client) Encrypt(plaintext []byte) ([]byte, *Pkcs11Key, error) { + session, err := c.GetSession() + if err != nil { + return nil, nil, err + } + defer c.CloseSession(session) + + keyId := Pkcs11Key{ label: c.keyLabel, id: c.keyId } + obj, err := c.FindKey(session, keyId, pkcs11.CKA_ENCRYPT) + if err != nil { + return nil, nil, err + } + if len(obj) != 1 { + return nil, nil, fmt.Errorf("expected 1 object, got %d", len(obj)) + } + key := obj[0] + + mechanism, err := c.GetKeyMechanism(session, key) + if err != nil { + return nil, nil, err + } + + var iv []byte + needIV, ivLength := IsIvNeeded(mechanism) + if needIV && ivLength > 0 { + iv, err = uuid.GenerateRandomBytes(ivLength) + if err != nil { + return nil, nil, err + } + } + + if err = c.client.EncryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + return nil, nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) + } + var ciphertext []byte + if ciphertext, err = c.client.Encrypt(session, plaintext); err != nil { + return nil, nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) + } + + if iv != nil { + return append(iv, ciphertext...), &keyId, nil + } else { + return ciphertext, &keyId, nil + } +} + +func (c *Pkcs11Client) Decrypt(ciphertext []byte, keyId *Pkcs11Key) ([]byte, error) { + session, err := c.GetSession() + if err != nil { + return nil, err + } + defer c.CloseSession(session) + + if keyId == nil { + keyId = &Pkcs11Key{ label: c.keyLabel, id: c.keyId } + } + + obj, err := c.FindKey(session, *keyId, pkcs11.CKA_DECRYPT) + if err != nil { + return nil, err + } + + if len(obj) != 1 { + return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) + } + key := obj[0] + + mechanism, err := c.GetKeyMechanism(session, key) + if err != nil { + return nil, err + } + + var iv []byte + needIV, ivLength := IsIvNeeded(mechanism) + if needIV && ivLength > 0 { + if len(ciphertext) < ivLength { + return nil, fmt.Errorf("encrypted DEK is too short") + } + + iv = ciphertext[:ivLength] + ciphertext = ciphertext[ivLength:] + } + + if err = c.client.DecryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) + } + + var decrypted []byte + if decrypted, err = c.client.Decrypt(session, ciphertext); err != nil { + return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) + } + return decrypted, nil +} + +// Create a PKCS11 client for the configured module. +func (c *Pkcs11Client) GetClient() (*pkcs11.Ctx, error) { + if c.client != nil { + return c.client, nil + } + c.client = pkcs11.New(c.lib) + err := c.client.Initialize() + if err != nil { + c.client = nil + return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) + } + return c.client, nil +} + +func (c *Pkcs11Client) GetSlotForLabel() (uint, error) { + if c.slot != 0 { + return c.slot, nil + } + if c.tokenLabel == "" { + return 0, fmt.Errorf("not token label configured") + } + slots, _ := c.client.GetSlotList(true) + for _, slot := range slots { + tokenInfo, err := c.client.GetTokenInfo(slot) + if err == nil && tokenInfo.Label == c.tokenLabel { + c.slot = slot + break + } + } + if c.slot == 0 { + return 0, fmt.Errorf("failed to find token with label: %s", c.tokenLabel) + } + return c.slot, nil +} + +// Open a session and perform the authentication process. +func (c *Pkcs11Client) GetSession() (pkcs11.SessionHandle, error) { + if c.client == nil { + return 0, fmt.Errorf("PKCS11 not initialized") + } + + if c.slot == 0 { + _, err := c.GetSlotForLabel() + if err != nil { + return 0, err + } + } + + session, err := c.client.OpenSession(c.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + return 0, fmt.Errorf("failed to open session: %w", err) + } + err = c.client.Login(session, pkcs11.CKU_USER, c.pin) + if err != nil { + return 0, fmt.Errorf("failed to login: %w", err) + } + return session, nil +} + +func (c *Pkcs11Client) CloseSession(session pkcs11.SessionHandle) { + if c.client == nil { + return + } + c.client.Logout(session) + c.client.CloseSession(session) +} + +// +func (c *Pkcs11Client) FindKey(session pkcs11.SessionHandle, key Pkcs11Key, typ uint) ([]pkcs11.ObjectHandle, error) { + template := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(key.label)), + pkcs11.NewAttribute(typ, true), + } + if keyIdBytes, err := hex.DecodeString(key.id); err == nil { + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes)) + } + if c.mechanism != 0 { + keyTypeString, err := GetKeyTypeFromMech(c.mechanism) + if err != nil { + return nil, fmt.Errorf("failed to get key type from mechanism: %s", err) + } + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) + } + + if err := c.client.FindObjectsInit(session, template); err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) + } + obj, _, err := c.client.FindObjects(session, 2) + if err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) + } + if err := c.client.FindObjectsFinal(session); err != nil { + return nil, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) + } + + return obj, nil +} + +func (c *Pkcs11Client) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) (uint, error) { + template := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), + } + attr, err := c.client.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + if err != nil { + return 0, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) + } + + attrMap := GetAttributesMap(attr) + keyType := GetValueAsInt(attrMap[pkcs11.CKA_KEY_TYPE]) + + mechanism := uint(0) + switch keyType { + case pkcs11.CKK_AES: + if c.mechanism != 0 { + mechanism = c.mechanism + } else { + mechanism = pkcs11.CKM_AES_GCM + } + case pkcs11.CKK_RSA: + if c.mechanism != 0 { + mechanism = c.mechanism + } else { + mechanism = pkcs11.CKM_RSA_PKCS_OAEP + } + default: + return 0, fmt.Errorf("unsupported key type: %d", keyType) + } + + return mechanism, nil +} + +func (c *Pkcs11Client) GetCurrentKey() Pkcs11Key { + return Pkcs11Key{ + label: c.keyLabel, + id: c.keyId, + } +} + +func IsIvNeeded(mech uint) (bool, int) { + switch mech { + case pkcs11.CKM_AES_GCM: + return true, 16 + case pkcs11.CKM_AES_CBC_PAD: + return true, 16 + default: + return false, 0 + } +} + +func GetKeyTypeFromMech(mech uint) (uint, error) { + switch mech { + case pkcs11.CKM_RSA_PKCS_OAEP: + return pkcs11.CKK_RSA, nil + case pkcs11.CKM_RSA_PKCS: + return pkcs11.CKK_RSA, nil + case pkcs11.CKM_AES_GCM: + return pkcs11.CKK_AES, nil + case pkcs11.CKM_AES_CBC_PAD: + return pkcs11.CKK_AES, nil + default: + return 0, fmt.Errorf("unsupported mechanism: %d", mech) + } +} + +func MechanismString(mech uint) string { + switch mech { + case pkcs11.CKM_RSA_PKCS_OAEP: + return "CKM_RSA_PKCS_OAEP" + case pkcs11.CKM_RSA_PKCS: + return "CKM_RSA_PKCS" + case pkcs11.CKM_AES_GCM: + return "CKM_AES_GCM" + case pkcs11.CKM_AES_CBC_PAD: + return "CKM_AES_CBC_PAD" + default: + return "Unknown" + } +} + +func MechanismFromString(mech string) (uint64, error) { + switch mech { + case "CKM_RSA_PKCS_OAEP", "RSA_PKCS_OAEP": + return pkcs11.CKM_RSA_PKCS_OAEP, nil + case "CKM_RSA_PKCS", "RSA_PKCS": + return pkcs11.CKM_RSA_PKCS, nil + case "CKM_AES_GCM", "AES_GCM": + return pkcs11.CKM_AES_GCM, nil + case "CKM_AES_CBC_PAD", "AES_CBC_PAD": + return pkcs11.CKM_AES_CBC_PAD, nil + default: + if mechanismNum, err := numberAutoParse(mech, 32); err == nil { + if _, err = GetKeyTypeFromMech(uint(mechanismNum)); err == nil { + return mechanismNum, nil + } + } + return 0, fmt.Errorf("unsupported mechanism: %s", mech) + } +} + + +func GetAttributesMap(attrs []*pkcs11.Attribute) map[uint][]byte { + m := make(map[uint][]byte, len(attrs)) + for _, a := range attrs { + m[a.Type] = a.Value + } + return m +} + +func GetValueAsInt(value []byte) int64 { + switch len(value) { + case 1: + return int64(value[0]) + case 2: + return int64(binary.NativeEndian.Uint16(value)) + case 4: + return int64(binary.NativeEndian.Uint32(value)) + case 8: + return int64(binary.NativeEndian.Uint64(value)) + } + return 0 +} + +func GetValueAsUint(value []byte) uint64 { + switch len(value) { + case 1: + return uint64(value[0]) + case 2: + return uint64(binary.NativeEndian.Uint16(value)) + case 4: + return uint64(binary.NativeEndian.Uint32(value)) + case 8: + return uint64(binary.NativeEndian.Uint64(value)) + } + return 0 +} + +func numberAutoParse(value string, bitSize int) (uint64, error) { + var ret uint64 + var err error + value = strings.ToLower(value) + if strings.HasPrefix(value, "0x") { + ret, err = strconv.ParseUint(value[2:], 16, bitSize) + } else { + ret, err = strconv.ParseUint(value, 10, bitSize) + } + return ret, err +} \ No newline at end of file From 02488dd05e8951b32ec117223989e033e8e48df6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:01:43 +0100 Subject: [PATCH 07/10] [refactor] validation of unit test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/go.mod | 33 +++++++-- wrappers/pkcs11/go.sum | 98 ++++++++++++++++++++++--- wrappers/pkcs11/options.go | 70 ++++++++++-------- wrappers/pkcs11/options_test.go | 112 +++++++++++++++++++++++++++++ wrappers/pkcs11/pkcs11.go | 6 -- wrappers/pkcs11/pkcs11_acc_test.go | 13 ++-- wrappers/pkcs11/pkcs11_client.go | 4 +- 7 files changed, 281 insertions(+), 55 deletions(-) create mode 100644 wrappers/pkcs11/options_test.go diff --git a/wrappers/pkcs11/go.mod b/wrappers/pkcs11/go.mod index 085a3d04..c906375c 100644 --- a/wrappers/pkcs11/go.mod +++ b/wrappers/pkcs11/go.mod @@ -5,14 +5,39 @@ go 1.22.1 replace github.com/openbao/go-kms-wrapping/v2 => ../../ require ( + github.com/hashicorp/go-uuid v1.0.3 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b - github.com/openbao/go-kms-wrapping/v2 v2.0.0-00010101000000-000000000000 + github.com/openbao/go-kms-wrapping/v2 v2.1.0 + github.com/openbao/openbao/api/v2 v2.0.1 + github.com/stretchr/testify v1.8.4 ) require ( - github.com/hashicorp/go-uuid v1.0.3 // indirect - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect - google.golang.org/protobuf v1.31.0 // indirect + github.com/cenkalti/backoff/v3 v3.0.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-jose/go-jose/v3 v3.0.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/hashicorp/go-retryablehttp v0.7.7 // indirect + github.com/hashicorp/go-rootcerts v1.0.2 // indirect + github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 // indirect + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect + github.com/hashicorp/go-sockaddr v1.0.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/ryanuber/go-glob v1.0.0 // indirect + golang.org/x/crypto v0.24.0 // indirect + golang.org/x/net v0.26.0 // indirect + golang.org/x/text v0.16.0 // indirect + golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) retract [v2.0.0, v2.0.2] diff --git a/wrappers/pkcs11/go.sum b/wrappers/pkcs11/go.sum index c4a4124d..b8f0c38d 100644 --- a/wrappers/pkcs11/go.sum +++ b/wrappers/pkcs11/go.sum @@ -1,21 +1,103 @@ +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/cenkalti/backoff/v3 v3.0.0 h1:ske+9nBpD9qZsTBoF41nW5L+AIuFBKMeze18XQ3eG1c= +github.com/cenkalti/backoff/v3 v3.0.0/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4rc0ij+ULvLYs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= +github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= +github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 h1:om4Al8Oy7kCm/B86rLCLah4Dt5Aa0Fr5rYBG60OzwHQ= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= +github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/openbao/openbao/api/v2 v2.0.1 h1:oyDqLa8m+XY3YBbgQ4YnX5o+/4/ybShiDPMC/7WomtE= +github.com/openbao/openbao/api/v2 v2.0.1/go.mod h1:qIp3G8D5vaW+r7TG2YoCCEo/5HxTvidwZA0GiwA1iJ8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w= +golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/wrappers/pkcs11/options.go b/wrappers/pkcs11/options.go index 867f8779..fba5df26 100644 --- a/wrappers/pkcs11/options.go +++ b/wrappers/pkcs11/options.go @@ -97,57 +97,71 @@ func getDefaultOptions() options { } // WithSlot sets the slot -func WithSlot(slot string) OptionFunc { - return func(o *options) error { - o.withSlot = slot - return nil +func WithSlot(slot string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withSlot = slot + return nil + }) } } // WithSlot sets the slot -func WithTokenLabel(slot string) OptionFunc { - return func(o *options) error { - o.withTokenLabel = slot - return nil +func WithTokenLabel(slot string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withTokenLabel = slot + return nil + }) } } // WithPin sets the pin -func WithPin(pin string) OptionFunc { - return func(o *options) error { - o.withPin = pin - return nil +func WithPin(pin string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withPin = pin + return nil + }) } } // WithLib sets the module -func WithLib(lib string) OptionFunc { - return func(o *options) error { - o.withLib = lib - return nil +func WithLib(lib string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withLib = lib + return nil + }) } } // WithLabel sets the label -func WithKeyId(keyId string) OptionFunc { - return func(o *options) error { - o.withKeyId = keyId - return nil +func WithKeyId(keyId string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withKeyId = keyId + return nil + }) } } // WithLabel sets the label -func WithKeyLabel(label string) OptionFunc { - return func(o *options) error { - o.withKeyLabel = label - return nil +func WithKeyLabel(label string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withKeyLabel = label + return nil + }) } } // WithMechanism sets the mechanism -func WithMechanism(mechanism string) OptionFunc { - return func(o *options) error { - o.withMechanism = mechanism - return nil +func WithMechanism(mechanism string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withMechanism = mechanism + return nil + }) } } diff --git a/wrappers/pkcs11/options_test.go b/wrappers/pkcs11/options_test.go new file mode 100644 index 00000000..fa9569cd --- /dev/null +++ b/wrappers/pkcs11/options_test.go @@ -0,0 +1,112 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_GetOpts provides unit tests for GetOpts and all the options +func Test_GetOpts(t *testing.T) { + t.Parallel() + t.Run("WithKeyId", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withKeyId = "" + assert.Equal(opts, testOpts) + + const with = "testKeyId" + opts, err = getOpts(WithKeyId(with)) + require.NoError(err) + testOpts.withKeyId = with + assert.Equal(opts, testOpts) + }) + t.Run("WithSlot", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withSlot = "" + assert.Equal(opts, testOpts) + + const with = "1024" + opts, err = getOpts(WithSlot(with)) + require.NoError(err) + testOpts.withSlot = with + assert.Equal(opts, testOpts) + }) + t.Run("WithPin", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withPin = "" + assert.Equal(opts, testOpts) + + const with = "000000" + opts, err = getOpts(WithPin(with)) + require.NoError(err) + testOpts.withPin = with + assert.Equal(opts, testOpts) + }) + t.Run("WithLib", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withLib = "" + assert.Equal(opts, testOpts) + + const with = "/usr/lib/pkcs11.so" + opts, err = getOpts(WithLib(with)) + require.NoError(err) + testOpts.withLib = with + assert.Equal(opts, testOpts) + }) + t.Run("WithTokenLabel", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withTokenLabel = "" + assert.Equal(opts, testOpts) + + const with = "labelTest" + opts, err = getOpts(WithTokenLabel(with)) + require.NoError(err) + testOpts.withTokenLabel = with + assert.Equal(opts, testOpts) + }) + t.Run("WithMechanism", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withMechanism = "" + assert.Equal(opts, testOpts) + + const with = "CKM_AES_GCM" + opts, err = getOpts(WithMechanism(with)) + require.NoError(err) + testOpts.withMechanism = with + assert.Equal(opts, testOpts) + }) +} diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index c648982e..3428e038 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -11,12 +11,6 @@ import ( wrapping "github.com/openbao/go-kms-wrapping/v2" ) -// These constants contain the accepted env vars; the Vault one is for backwards compat -const ( - EnvPkcs11WrapperKeyId = "PKCS11_WRAPPER_KEY_ID" - EnvVaultPkcs11SealKeyId = "VAULT_PKCS11_SEAL_KEY_ID" -) - // Wrapper is a Wrapper that uses PKCS11 type Wrapper struct { client pkcs11ClientEncryptor diff --git a/wrappers/pkcs11/pkcs11_acc_test.go b/wrappers/pkcs11/pkcs11_acc_test.go index c9335678..566ca661 100644 --- a/wrappers/pkcs11/pkcs11_acc_test.go +++ b/wrappers/pkcs11/pkcs11_acc_test.go @@ -14,13 +14,12 @@ import ( // but the KMS key used is generally not free. // // To run this test, the following env variables need to be set: -// - VAULT_PKCS11_SEAL_KEY_ID or PKCS11_WRAPPING_KEY_ID -// - PKCS11_WRAPPER_KEY_ID -// - PKCS11_SLOT -// - PKCS11_PIN -// - PKCS11_MODULE -// - PKCS11_LABEL -// - PKCS11_MECHANISM +// - BAO_HSM_SLOT +// - BAO_HSM_PIN +// - BAO_HSM_LIB +// - BAO_HSM_KEY_LABEL +// - BAO_HSM_KEY_ID +// - BAO_HSM_MECHANISM func TestAccPkcs11Wrapper_Lifecycle(t *testing.T) { if os.Getenv("VAULT_ACC") == "" && os.Getenv("KMS_ACC_TESTS") == "" { t.SkipNow() diff --git a/wrappers/pkcs11/pkcs11_client.go b/wrappers/pkcs11/pkcs11_client.go index 69a89447..f936c3b4 100644 --- a/wrappers/pkcs11/pkcs11_client.go +++ b/wrappers/pkcs11/pkcs11_client.go @@ -10,7 +10,7 @@ import ( "encoding/binary" "encoding/hex" - "github.com/openbao/openbao/api" + "github.com/openbao/openbao/api/v2" uuid "github.com/hashicorp/go-uuid" pkcs11 "github.com/miekg/pkcs11" wrapping "github.com/openbao/go-kms-wrapping/v2" @@ -550,4 +550,4 @@ func numberAutoParse(value string, bitSize int) (uint64, error) { ret, err = strconv.ParseUint(value, 10, bitSize) } return ret, err -} \ No newline at end of file +} From 66e5465df15c91b01fb9eba8ef320a27fba71b42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Wed, 13 Nov 2024 13:50:58 +0100 Subject: [PATCH 08/10] [fix] fixing copyright MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/options.go | 1 + wrappers/pkcs11/options_test.go | 1 + wrappers/pkcs11/pkcs11.go | 2 +- wrappers/pkcs11/pkcs11_acc_test.go | 1 + wrappers/pkcs11/pkcs11_client.go | 2 +- 5 files changed, 5 insertions(+), 2 deletions(-) diff --git a/wrappers/pkcs11/options.go b/wrappers/pkcs11/options.go index fba5df26..2f865c54 100644 --- a/wrappers/pkcs11/options.go +++ b/wrappers/pkcs11/options.go @@ -1,3 +1,4 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 diff --git a/wrappers/pkcs11/options_test.go b/wrappers/pkcs11/options_test.go index fa9569cd..5f8e827d 100644 --- a/wrappers/pkcs11/options_test.go +++ b/wrappers/pkcs11/options_test.go @@ -1,3 +1,4 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index 3428e038..3bf5f5a3 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -1,4 +1,4 @@ -// Copyright (c) HashiCorp, Inc. +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC // SPDX-License-Identifier: MPL-2.0 package pkcs11 diff --git a/wrappers/pkcs11/pkcs11_acc_test.go b/wrappers/pkcs11/pkcs11_acc_test.go index 566ca661..9cd91ae6 100644 --- a/wrappers/pkcs11/pkcs11_acc_test.go +++ b/wrappers/pkcs11/pkcs11_acc_test.go @@ -1,3 +1,4 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 diff --git a/wrappers/pkcs11/pkcs11_client.go b/wrappers/pkcs11/pkcs11_client.go index f936c3b4..667a12d2 100644 --- a/wrappers/pkcs11/pkcs11_client.go +++ b/wrappers/pkcs11/pkcs11_client.go @@ -1,4 +1,4 @@ -// Copyright (c) HashiCorp, Inc. +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC // SPDX-License-Identifier: MPL-2.0 package pkcs11 From c2547bde1e698ac4cafa0e9a02ab00fc36a40b2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:51:18 +0100 Subject: [PATCH 09/10] [fix] fix support of AES GCM and RSA OAEP. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add rsa_oaep_hash parameter. * New functions for the creation of PKCS11 encrypt/decrypt parameters. * FindKey now return only one key handle. Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/options.go | 13 ++ wrappers/pkcs11/pkcs11.go | 7 +- wrappers/pkcs11/pkcs11_client.go | 245 ++++++++++++++++++++----------- 3 files changed, 176 insertions(+), 89 deletions(-) diff --git a/wrappers/pkcs11/options.go b/wrappers/pkcs11/options.go index 2f865c54..43fa0d1e 100644 --- a/wrappers/pkcs11/options.go +++ b/wrappers/pkcs11/options.go @@ -60,6 +60,8 @@ func getOpts(opt ...wrapping.Option) (*options, error) { opts.withKeyLabel = v case "mechanism": opts.withMechanism = v + case "rsa_oaep_hash": + opts.withRsaOaepHash = v } } } @@ -91,6 +93,7 @@ type options struct { withKeyLabel string withTokenLabel string withMechanism string + withRsaOaepHash string } func getDefaultOptions() options { @@ -166,3 +169,13 @@ func WithMechanism(mechanism string) wrapping.Option { }) } } + +// WithRsaOaepHash sets the RSA OAEP Hash mechanism +func WithRsaOaepHash(hashMechanisme string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withRsaOaepHash = hashMechanisme + return nil + }) + } +} \ No newline at end of file diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index 3bf5f5a3..91db3cc4 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -43,6 +43,7 @@ func (k *Wrapper) Finalize(_ context.Context) error { // SetConfig processes the config info from the server config func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) { + // Option validation is performed by newPkcs11Client(...). opts, err := getOpts(opt...) if err != nil { return nil, err @@ -55,12 +56,6 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin k.client = client k.keyId = client.GetCurrentKey().String() - // Send a value to test the wrapper and to set the current key id - if _, err := k.Encrypt(context.Background(), []byte("a")); err != nil { - client.Close() - return nil, err - } - return wrapConfig, nil } diff --git a/wrappers/pkcs11/pkcs11_client.go b/wrappers/pkcs11/pkcs11_client.go index 667a12d2..09696251 100644 --- a/wrappers/pkcs11/pkcs11_client.go +++ b/wrappers/pkcs11/pkcs11_client.go @@ -59,20 +59,22 @@ type Pkcs11Client struct { keyLabel string keyId string mechanism uint + rsaOaepHash string } const ( - EnvHsmWrapperLib = "BAO_HSM_LIB" - EnvHsmWrapperSlot = "BAO_HSM_SLOT" - EnvHsmWrapperTokenLabel = "BAO_HSM_TOKEN_LABEL" - EnvHsmWrapperPin = "BAO_HSM_PIN" - EnvHsmWrapperKeyLabel = "BAO_HSM_KEY_LABEL" - EnvHsmWrapperKeyId = "BAO_HSM_KEY_ID" - EnvHsmWrapperMechanism = "BAO_HSM_MECHANISM" + EnvHsmWrapperLib = "BAO_HSM_LIB" + EnvHsmWrapperSlot = "BAO_HSM_SLOT" + EnvHsmWrapperTokenLabel = "BAO_HSM_TOKEN_LABEL" + EnvHsmWrapperPin = "BAO_HSM_PIN" + EnvHsmWrapperKeyLabel = "BAO_HSM_KEY_LABEL" + EnvHsmWrapperKeyId = "BAO_HSM_KEY_ID" + EnvHsmWrapperMechanism = "BAO_HSM_MECHANISM" + EnvHsmWrapperRsaOaepHash = "BAO_HSM_RSA_OAEP_HASH" ) func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, error) { - var lib, slot, keyId, tokenLabel, pin, keyLabel, mechanism string + var lib, slot, keyId, tokenLabel, pin, keyLabel, mechanism, rsaOaepHash string var slotNum, mechanismNum uint64 var err error @@ -115,6 +117,10 @@ func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, err default: keyId = "" } + // Remove the 0x prefix. + if strings.HasPrefix(keyId, "0x") { + keyId = keyId[2:] + } switch { case api.ReadBaoVariable(EnvHsmWrapperPin) != "" && !opts.Options.WithDisallowEnvVars: @@ -143,6 +149,15 @@ func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, err mechanism = "" } + switch { + case api.ReadBaoVariable(EnvHsmWrapperRsaOaepHash) != "" && !opts.Options.WithDisallowEnvVars: + rsaOaepHash = strings.ToLower(api.ReadBaoVariable(EnvHsmWrapperRsaOaepHash)) + case opts.withRsaOaepHash != "": + rsaOaepHash = strings.ToLower(opts.withRsaOaepHash) + default: + rsaOaepHash = "" + } + if slot != "" { if slotNum, err = numberAutoParse(slot, 32); err != nil { return nil, nil, fmt.Errorf("Invalid slot number") @@ -160,18 +175,19 @@ func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, err } client := &Pkcs11Client{ - client: nil, - lib: lib, - slot: uint(slotNum), - pin: pin, - tokenLabel: tokenLabel, - keyId: keyId, - keyLabel: keyLabel, - mechanism: uint(mechanismNum), + client: nil, + lib: lib, + slot: uint(slotNum), + pin: pin, + tokenLabel: tokenLabel, + keyId: keyId, + keyLabel: keyLabel, + mechanism: uint(mechanismNum), + rsaOaepHash: rsaOaepHash, } // Initialize the client - _, err = client.GetClient() + err = client.InitializeClient() if err != nil { return nil, nil, err } @@ -196,6 +212,9 @@ func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, err if mechanismNum != 0 { wrapConfig.Metadata["mechanism"] = MechanismString(uint(mechanismNum)) } + if rsaOaepHash != "" { + wrapConfig.Metadata["rsa_oaep_hash"] = rsaOaepHash + } return client, wrapConfig, nil } @@ -218,30 +237,16 @@ func (c *Pkcs11Client) Encrypt(plaintext []byte) ([]byte, *Pkcs11Key, error) { defer c.CloseSession(session) keyId := Pkcs11Key{ label: c.keyLabel, id: c.keyId } - obj, err := c.FindKey(session, keyId, pkcs11.CKA_ENCRYPT) + key, err := c.FindKey(session, keyId, pkcs11.CKA_ENCRYPT) if err != nil { return nil, nil, err } - if len(obj) != 1 { - return nil, nil, fmt.Errorf("expected 1 object, got %d", len(obj)) - } - key := obj[0] - mechanism, err := c.GetKeyMechanism(session, key) + prefix, params, err := c.CreatePkcsParamsForKey(session, key) if err != nil { return nil, nil, err } - - var iv []byte - needIV, ivLength := IsIvNeeded(mechanism) - if needIV && ivLength > 0 { - iv, err = uuid.GenerateRandomBytes(ivLength) - if err != nil { - return nil, nil, err - } - } - - if err = c.client.EncryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + if err = c.client.EncryptInit(session, []*pkcs11.Mechanism{params}, key); err != nil { return nil, nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) } var ciphertext []byte @@ -249,8 +254,8 @@ func (c *Pkcs11Client) Encrypt(plaintext []byte) ([]byte, *Pkcs11Key, error) { return nil, nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) } - if iv != nil { - return append(iv, ciphertext...), &keyId, nil + if prefix != nil { + return append(prefix, ciphertext...), &keyId, nil } else { return ciphertext, &keyId, nil } @@ -267,33 +272,30 @@ func (c *Pkcs11Client) Decrypt(ciphertext []byte, keyId *Pkcs11Key) ([]byte, err keyId = &Pkcs11Key{ label: c.keyLabel, id: c.keyId } } - obj, err := c.FindKey(session, *keyId, pkcs11.CKA_DECRYPT) + key, err := c.FindKey(session, *keyId, pkcs11.CKA_DECRYPT) if err != nil { return nil, err } - if len(obj) != 1 { - return nil, fmt.Errorf("expected 1 object, got %d", len(obj)) - } - key := obj[0] - - mechanism, err := c.GetKeyMechanism(session, key) + var prefix []byte + mechanism, prefixSize, err := c.GetMechAndPrefixSizeForKey(session, key) if err != nil { return nil, err } - - var iv []byte - needIV, ivLength := IsIvNeeded(mechanism) - if needIV && ivLength > 0 { - if len(ciphertext) < ivLength { - return nil, fmt.Errorf("encrypted DEK is too short") + if prefixSize > 0 { + if len(ciphertext) < prefixSize { + return nil, fmt.Errorf("encrypted content is too short") } - iv = ciphertext[:ivLength] - ciphertext = ciphertext[ivLength:] + prefix = ciphertext[:prefixSize] + ciphertext = ciphertext[prefixSize:] + } + params, err := c.GetPkcsParamsFromPrefix(session, key, mechanism, prefix) + if err != nil { + return nil, err } - if err = c.client.DecryptInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, iv)}, key); err != nil { + if err = c.client.DecryptInit(session, []*pkcs11.Mechanism{params}, key); err != nil { return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) } @@ -305,17 +307,17 @@ func (c *Pkcs11Client) Decrypt(ciphertext []byte, keyId *Pkcs11Key) ([]byte, err } // Create a PKCS11 client for the configured module. -func (c *Pkcs11Client) GetClient() (*pkcs11.Ctx, error) { +func (c *Pkcs11Client) InitializeClient() (error) { if c.client != nil { - return c.client, nil + return nil } c.client = pkcs11.New(c.lib) err := c.client.Initialize() if err != nil { c.client = nil - return nil, fmt.Errorf("failed to initialize PKCS11: %w", err) + return fmt.Errorf("failed to initialize PKCS11: %w", err) } - return c.client, nil + return nil } func (c *Pkcs11Client) GetSlotForLabel() (uint, error) { @@ -371,8 +373,8 @@ func (c *Pkcs11Client) CloseSession(session pkcs11.SessionHandle) { c.client.CloseSession(session) } -// -func (c *Pkcs11Client) FindKey(session pkcs11.SessionHandle, key Pkcs11Key, typ uint) ([]pkcs11.ObjectHandle, error) { +// Find on key for the given Label, ID and Mechanism. +func (c *Pkcs11Client) FindKey(session pkcs11.SessionHandle, key Pkcs11Key, typ uint) (pkcs11.ObjectHandle, error) { template := []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(key.label)), pkcs11.NewAttribute(typ, true), @@ -383,23 +385,29 @@ func (c *Pkcs11Client) FindKey(session pkcs11.SessionHandle, key Pkcs11Key, typ if c.mechanism != 0 { keyTypeString, err := GetKeyTypeFromMech(c.mechanism) if err != nil { - return nil, fmt.Errorf("failed to get key type from mechanism: %s", err) + return 0, fmt.Errorf("failed to get key type from mechanism: %s", err) } template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) } if err := c.client.FindObjectsInit(session, template); err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) + return 0, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) } obj, _, err := c.client.FindObjects(session, 2) if err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) + return 0, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) } if err := c.client.FindObjectsFinal(session); err != nil { - return nil, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) + return 0, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) + } + if len(obj) == 0 { + return 0, fmt.Errorf("no key found for the label: %s", key.label) + } + if len(obj) != 1 { + return 0, fmt.Errorf("got more than 1 key for the label: %s", key.label) } - return obj, nil + return obj[0], nil } func (c *Pkcs11Client) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) (uint, error) { @@ -435,21 +443,71 @@ func (c *Pkcs11Client) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11. return mechanism, nil } -func (c *Pkcs11Client) GetCurrentKey() Pkcs11Key { - return Pkcs11Key{ - label: c.keyLabel, - id: c.keyId, +func (c *Pkcs11Client) GetMechAndPrefixSizeForKey(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) (uint, int, error) { + mechanism, err := c.GetKeyMechanism(session, key) + if err != nil { + return 0, 0, err + } + switch mechanism { + case pkcs11.CKM_AES_GCM: + return mechanism, 16, nil + // Deprecated + case pkcs11.CKM_AES_CBC_PAD: + return mechanism, 16, nil + case pkcs11.CKM_AES_CBC: + return mechanism, 16, nil + // Consider others with no prefix + default: + return mechanism, 0, nil } } -func IsIvNeeded(mech uint) (bool, int) { - switch mech { +func (c *Pkcs11Client) CreatePkcsParamsForKey(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) ([]byte, *pkcs11.Mechanism, error) { + mechanismNum, prefixSize, err := c.GetMechAndPrefixSizeForKey(session, key) + if err != nil { + return nil, nil, err + } + + // Prefix is the IV or Nonce. + prefix, err := uuid.GenerateRandomBytes(prefixSize) + if err != nil { + return nil, nil, err + } + + mechanism, err := c.GetPkcsParamsFromPrefix(session, key, mechanismNum, prefix) + if err != nil { + return nil, nil, err + } + return prefix, mechanism, nil +} + +func (c *Pkcs11Client) GetPkcsParamsFromPrefix(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, mechanism uint, prefix []byte) (*pkcs11.Mechanism, error) { + var params interface{} + switch mechanism { case pkcs11.CKM_AES_GCM: - return true, 16 - case pkcs11.CKM_AES_CBC_PAD: - return true, 16 + params = pkcs11.NewGCMParams(prefix, nil, 128) + case pkcs11.CKM_RSA_PKCS_OAEP: + var rsaOaepHash string + if c.rsaOaepHash != "" { + rsaOaepHash = c.rsaOaepHash + } else { + rsaOaepHash = "sha256" + } + hash, mgf_hash, err := RsaHashMechFromString(rsaOaepHash) + if err != nil { + return nil, err + } + params = pkcs11.NewOAEPParams(hash, mgf_hash, pkcs11.CKZ_DATA_SPECIFIED, nil) default: - return false, 0 + params = prefix + } + return pkcs11.NewMechanism(mechanism, params), nil +} + +func (c *Pkcs11Client) GetCurrentKey() Pkcs11Key { + return Pkcs11Key{ + label: c.keyLabel, + id: c.keyId, } } @@ -457,12 +515,12 @@ func GetKeyTypeFromMech(mech uint) (uint, error) { switch mech { case pkcs11.CKM_RSA_PKCS_OAEP: return pkcs11.CKK_RSA, nil - case pkcs11.CKM_RSA_PKCS: - return pkcs11.CKK_RSA, nil case pkcs11.CKM_AES_GCM: return pkcs11.CKK_AES, nil - case pkcs11.CKM_AES_CBC_PAD: - return pkcs11.CKK_AES, nil + // Deprecated mechanisms + case pkcs11.CKM_RSA_PKCS, pkcs11.CKM_AES_CBC, pkcs11.CKM_AES_CBC_PAD: + return 0, fmt.Errorf("deprecated mechanism: %s (%d)", MechanismString(mech), mech) + // Other are unsupported default: return 0, fmt.Errorf("unsupported mechanism: %d", mech) } @@ -472,10 +530,13 @@ func MechanismString(mech uint) string { switch mech { case pkcs11.CKM_RSA_PKCS_OAEP: return "CKM_RSA_PKCS_OAEP" - case pkcs11.CKM_RSA_PKCS: - return "CKM_RSA_PKCS" case pkcs11.CKM_AES_GCM: return "CKM_AES_GCM" + // Deprecated mechanisms + case pkcs11.CKM_RSA_PKCS: + return "CKM_RSA_PKCS" + case pkcs11.CKM_AES_CBC: + return "CKM_AES_CBC" case pkcs11.CKM_AES_CBC_PAD: return "CKM_AES_CBC_PAD" default: @@ -487,13 +548,14 @@ func MechanismFromString(mech string) (uint64, error) { switch mech { case "CKM_RSA_PKCS_OAEP", "RSA_PKCS_OAEP": return pkcs11.CKM_RSA_PKCS_OAEP, nil - case "CKM_RSA_PKCS", "RSA_PKCS": - return pkcs11.CKM_RSA_PKCS, nil case "CKM_AES_GCM", "AES_GCM": return pkcs11.CKM_AES_GCM, nil - case "CKM_AES_CBC_PAD", "AES_CBC_PAD": - return pkcs11.CKM_AES_CBC_PAD, nil + // Deprecated mechanisms + case "CKM_RSA_PKCS", "RSA_PKCS", "CKM_AES_CBC_PAD", "AES_CBC_PAD": + return 0, fmt.Errorf("deprecated mechanism: %s", mech) + // Other mechanisms default: + // Try to extract the mechanism PKCS11 raw value. if mechanismNum, err := numberAutoParse(mech, 32); err == nil { if _, err = GetKeyTypeFromMech(uint(mechanismNum)); err == nil { return mechanismNum, nil @@ -503,6 +565,23 @@ func MechanismFromString(mech string) (uint64, error) { } } +func RsaHashMechFromString(mech string) (uint, uint, error) { + mech = strings.ToLower(mech) + switch mech { + case "sha1": + return pkcs11.CKM_SHA_1, pkcs11.CKG_MGF1_SHA1, nil + case "sha224": + return pkcs11.CKM_SHA224, pkcs11.CKG_MGF1_SHA224, nil + case "sha256": + return pkcs11.CKM_SHA256, pkcs11.CKG_MGF1_SHA256, nil + case "sha384": + return pkcs11.CKM_SHA384, pkcs11.CKG_MGF1_SHA384, nil + case "sha512": + return pkcs11.CKM_SHA512, pkcs11.CKG_MGF1_SHA512, nil + default: + return 0, 0, fmt.Errorf("unsupported mechanism: %s", mech) + } +} func GetAttributesMap(attrs []*pkcs11.Attribute) map[uint][]byte { m := make(map[uint][]byte, len(attrs)) From cb60aeafb9e7d7ad02e0d5739e6866fe9a95aade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20GLATIGNY?= <14180893+glatigny@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:15:32 +0100 Subject: [PATCH 10/10] [refactor] handle GCM specific behavior for parameters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Do not concat the IV with the encrypted content anymore, use the BlobInfo instead. * Split AES GCM and RSA OAEP into different functions. * Handle Specific cases for GCM (some HSM don't read the IV/nonce and generate it instead). * Use the PKCS#11 random generator and propose an interface for it. * Move constants into a dedicated section. Signed-off-by: Jérôme GLATIGNY <14180893+glatigny@users.noreply.github.com> --- wrappers/pkcs11/pkcs11.go | 5 +- wrappers/pkcs11/pkcs11_client.go | 220 ++++++++++++++++++------------- 2 files changed, 131 insertions(+), 94 deletions(-) diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go index 91db3cc4..396158a1 100644 --- a/wrappers/pkcs11/pkcs11.go +++ b/wrappers/pkcs11/pkcs11.go @@ -73,7 +73,7 @@ func (k *Wrapper) KeyId(_ context.Context) (string, error) { // This returns the ciphertext, and/or any errors from this // call. This should be called after the KMS client has been instantiated. func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) { - ciphertext, key, err := k.client.Encrypt(plaintext) + ciphertext, iv, key, err := k.client.Encrypt(plaintext) if err != nil { return nil, err } @@ -83,6 +83,7 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O ret := &wrapping.BlobInfo{ Ciphertext: ciphertext, + Iv: iv, KeyInfo: &wrapping.KeyInfo{ KeyId: keyId, }, @@ -105,7 +106,7 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp if err != nil { return nil, err } - plaintext, err := k.client.Decrypt(in.Ciphertext, keyId) + plaintext, err := k.client.Decrypt(in.Ciphertext, in.Iv, keyId) if err != nil { return nil, err } diff --git a/wrappers/pkcs11/pkcs11_client.go b/wrappers/pkcs11/pkcs11_client.go index 09696251..9820c148 100644 --- a/wrappers/pkcs11/pkcs11_client.go +++ b/wrappers/pkcs11/pkcs11_client.go @@ -11,7 +11,6 @@ import ( "encoding/hex" "github.com/openbao/openbao/api/v2" - uuid "github.com/hashicorp/go-uuid" pkcs11 "github.com/miekg/pkcs11" wrapping "github.com/openbao/go-kms-wrapping/v2" ) @@ -46,8 +45,9 @@ func (k Pkcs11Key) Set(v string) error { type pkcs11ClientEncryptor interface { Close() - Encrypt(plaintext []byte) (ciphertext []byte, keyId *Pkcs11Key, err error) - Decrypt(ciphertext []byte, keyId *Pkcs11Key) (plaintext []byte, err error) + GenerateRandom(length int) ([]byte, error) + Encrypt(plaintext []byte) (ciphertext []byte, nonce []byte, keyId *Pkcs11Key, err error) + Decrypt(ciphertext []byte, nonce []byte, keyId *Pkcs11Key) (plaintext []byte, err error) } type Pkcs11Client struct { @@ -72,6 +72,14 @@ const ( EnvHsmWrapperMechanism = "BAO_HSM_MECHANISM" EnvHsmWrapperRsaOaepHash = "BAO_HSM_RSA_OAEP_HASH" ) +const ( + DefaultAesMechanism = pkcs11.CKM_AES_GCM + DefaultRsaMechanism = pkcs11.CKM_RSA_PKCS_OAEP + DefaultRsaOaepHash = "sha256" + + CryptoAesGcmNonceSize = 12 + CryptoAesGcmOverhead = 16 +) func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, error) { var lib, slot, keyId, tokenLabel, pin, keyLabel, mechanism, rsaOaepHash string @@ -185,7 +193,7 @@ func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, err mechanism: uint(mechanismNum), rsaOaepHash: rsaOaepHash, } - + // Initialize the client err = client.InitializeClient() if err != nil { @@ -229,39 +237,101 @@ func (c *Pkcs11Client) Close() { c.client = nil } -func (c *Pkcs11Client) Encrypt(plaintext []byte) ([]byte, *Pkcs11Key, error) { +func (c *Pkcs11Client) GenerateRandom(length int) ([]byte, error) { session, err := c.GetSession() if err != nil { - return nil, nil, err + return nil, err + } + defer c.CloseSession(session) + + return c.client.GenerateRandom(session, length) +} + +func (c *Pkcs11Client) Encrypt(plaintext []byte) ([]byte, []byte, *Pkcs11Key, error) { + session, err := c.GetSession() + if err != nil { + return nil, nil, nil, err } defer c.CloseSession(session) keyId := Pkcs11Key{ label: c.keyLabel, id: c.keyId } key, err := c.FindKey(session, keyId, pkcs11.CKA_ENCRYPT) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - prefix, params, err := c.CreatePkcsParamsForKey(session, key) + mechanism, err := c.GetKeyMechanism(session, key) if err != nil { - return nil, nil, err + return nil, nil, nil, err + } + + switch mechanism { + case pkcs11.CKM_AES_GCM: + return c.EncryptAesGcm(session, key, keyId, plaintext) + case pkcs11.CKM_RSA_PKCS_OAEP: + return c.EncryptRsaOaep(session, key, keyId, plaintext) } - if err = c.client.EncryptInit(session, []*pkcs11.Mechanism{params}, key); err != nil { - return nil, nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) + return nil, nil, nil, fmt.Errorf("unsupported mechanism") +} + +// Encryption for AES GCM algorithm +func (c *Pkcs11Client) EncryptAesGcm(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, keyId Pkcs11Key, plaintext []byte) ([]byte, []byte, *Pkcs11Key, error) { + nonce, err := c.client.GenerateRandom(session, CryptoAesGcmNonceSize) + if err != nil { + return nil, nil, nil, err + } + + // Some HSM will ignore the given nonce and generate their own. + // That's why we need to free manually the GCM parameters. + params := pkcs11.NewGCMParams(nonce, nil, CryptoAesGcmOverhead*8) + defer params.Free() + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_AES_GCM, params)} + + if err = c.client.EncryptInit(session, mech, key); err != nil { + return nil, nil, nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) } var ciphertext []byte if ciphertext, err = c.client.Encrypt(session, plaintext); err != nil { - return nil, nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) + return nil, nil, nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) } - if prefix != nil { - return append(prefix, ciphertext...), &keyId, nil + // Some HSM (CloudHSM) does not read the nonce/IV and generate its own. + // Since it's append, we need to extract it. + if len(ciphertext) == CryptoAesGcmNonceSize + len(plaintext) + CryptoAesGcmOverhead { + nonce = ciphertext[len(ciphertext)-CryptoAesGcmNonceSize:] + ciphertext = ciphertext[:len(ciphertext)-CryptoAesGcmNonceSize] + } + + return ciphertext, nonce, &keyId, nil +} + +func (c *Pkcs11Client) EncryptRsaOaep(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, keyId Pkcs11Key, plaintext []byte) ([]byte, []byte, *Pkcs11Key, error) { + var rsaOaepHash string + if c.rsaOaepHash != "" { + rsaOaepHash = c.rsaOaepHash } else { - return ciphertext, &keyId, nil + rsaOaepHash = DefaultRsaOaepHash + } + hash, mgf_hash, err := RsaHashMechFromString(rsaOaepHash) + if err != nil { + return nil, nil, nil, err + } + params := pkcs11.NewOAEPParams(hash, mgf_hash, pkcs11.CKZ_DATA_SPECIFIED, nil) + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_OAEP, params)} + + if err = c.client.EncryptInit(session, mech, key); err != nil { + return nil, nil, nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) + } + var ciphertext []byte + if ciphertext, err = c.client.Encrypt(session, plaintext); err != nil { + return nil, nil, nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) } + + return ciphertext, nil, &keyId, nil } -func (c *Pkcs11Client) Decrypt(ciphertext []byte, keyId *Pkcs11Key) ([]byte, error) { +func (c *Pkcs11Client) Decrypt(ciphertext []byte, nonce []byte, keyId *Pkcs11Key) ([]byte, error) { session, err := c.GetSession() if err != nil { return nil, err @@ -277,28 +347,55 @@ func (c *Pkcs11Client) Decrypt(ciphertext []byte, keyId *Pkcs11Key) ([]byte, err return nil, err } - var prefix []byte - mechanism, prefixSize, err := c.GetMechAndPrefixSizeForKey(session, key) + mechanism, err := c.GetKeyMechanism(session, key) if err != nil { return nil, err } - if prefixSize > 0 { - if len(ciphertext) < prefixSize { - return nil, fmt.Errorf("encrypted content is too short") - } - prefix = ciphertext[:prefixSize] - ciphertext = ciphertext[prefixSize:] + switch mechanism { + case pkcs11.CKM_AES_GCM: + return c.DecryptAesGcm(session, key, nonce, ciphertext) + case pkcs11.CKM_RSA_PKCS_OAEP: + return c.DecryptRsaOaep(session, key, nonce, ciphertext) + } + return nil, fmt.Errorf("unsupported mechanism") +} + +func (c *Pkcs11Client) DecryptAesGcm(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, nonce []byte, ciphertext []byte) ([]byte, error) { + params := pkcs11.NewGCMParams(nonce, nil, CryptoAesGcmOverhead*8) + defer params.Free() + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_AES_GCM, params)} + + var err error + if err = c.client.DecryptInit(session, mech, key); err != nil { + return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) + } + var decrypted []byte + if decrypted, err = c.client.Decrypt(session, ciphertext); err != nil { + return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) } - params, err := c.GetPkcsParamsFromPrefix(session, key, mechanism, prefix) + return decrypted, nil +} + +func (c *Pkcs11Client) DecryptRsaOaep(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, _ []byte, ciphertext []byte) ([]byte, error) { + var rsaOaepHash string + if c.rsaOaepHash != "" { + rsaOaepHash = c.rsaOaepHash + } else { + rsaOaepHash = DefaultRsaOaepHash + } + hash, mgf_hash, err := RsaHashMechFromString(rsaOaepHash) if err != nil { return nil, err } + params := pkcs11.NewOAEPParams(hash, mgf_hash, pkcs11.CKZ_DATA_SPECIFIED, nil) + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_OAEP, params)} - if err = c.client.DecryptInit(session, []*pkcs11.Mechanism{params}, key); err != nil { + if err = c.client.DecryptInit(session, mech, key); err != nil { return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) } - var decrypted []byte if decrypted, err = c.client.Decrypt(session, ciphertext); err != nil { return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) @@ -383,11 +480,11 @@ func (c *Pkcs11Client) FindKey(session pkcs11.SessionHandle, key Pkcs11Key, typ template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes)) } if c.mechanism != 0 { - keyTypeString, err := GetKeyTypeFromMech(c.mechanism) + keyType, err := GetKeyTypeFromMech(c.mechanism) if err != nil { return 0, fmt.Errorf("failed to get key type from mechanism: %s", err) } - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyTypeString)) + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyType)) } if err := c.client.FindObjectsInit(session, template); err != nil { @@ -428,13 +525,13 @@ func (c *Pkcs11Client) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11. if c.mechanism != 0 { mechanism = c.mechanism } else { - mechanism = pkcs11.CKM_AES_GCM + mechanism = DefaultAesMechanism } case pkcs11.CKK_RSA: if c.mechanism != 0 { mechanism = c.mechanism } else { - mechanism = pkcs11.CKM_RSA_PKCS_OAEP + mechanism = DefaultRsaMechanism } default: return 0, fmt.Errorf("unsupported key type: %d", keyType) @@ -443,67 +540,6 @@ func (c *Pkcs11Client) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11. return mechanism, nil } -func (c *Pkcs11Client) GetMechAndPrefixSizeForKey(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) (uint, int, error) { - mechanism, err := c.GetKeyMechanism(session, key) - if err != nil { - return 0, 0, err - } - switch mechanism { - case pkcs11.CKM_AES_GCM: - return mechanism, 16, nil - // Deprecated - case pkcs11.CKM_AES_CBC_PAD: - return mechanism, 16, nil - case pkcs11.CKM_AES_CBC: - return mechanism, 16, nil - // Consider others with no prefix - default: - return mechanism, 0, nil - } -} - -func (c *Pkcs11Client) CreatePkcsParamsForKey(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) ([]byte, *pkcs11.Mechanism, error) { - mechanismNum, prefixSize, err := c.GetMechAndPrefixSizeForKey(session, key) - if err != nil { - return nil, nil, err - } - - // Prefix is the IV or Nonce. - prefix, err := uuid.GenerateRandomBytes(prefixSize) - if err != nil { - return nil, nil, err - } - - mechanism, err := c.GetPkcsParamsFromPrefix(session, key, mechanismNum, prefix) - if err != nil { - return nil, nil, err - } - return prefix, mechanism, nil -} - -func (c *Pkcs11Client) GetPkcsParamsFromPrefix(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, mechanism uint, prefix []byte) (*pkcs11.Mechanism, error) { - var params interface{} - switch mechanism { - case pkcs11.CKM_AES_GCM: - params = pkcs11.NewGCMParams(prefix, nil, 128) - case pkcs11.CKM_RSA_PKCS_OAEP: - var rsaOaepHash string - if c.rsaOaepHash != "" { - rsaOaepHash = c.rsaOaepHash - } else { - rsaOaepHash = "sha256" - } - hash, mgf_hash, err := RsaHashMechFromString(rsaOaepHash) - if err != nil { - return nil, err - } - params = pkcs11.NewOAEPParams(hash, mgf_hash, pkcs11.CKZ_DATA_SPECIFIED, nil) - default: - params = prefix - } - return pkcs11.NewMechanism(mechanism, params), nil -} - func (c *Pkcs11Client) GetCurrentKey() Pkcs11Key { return Pkcs11Key{ label: c.keyLabel,