diff --git a/wrappers/pkcs11/go.mod b/wrappers/pkcs11/go.mod new file mode 100644 index 00000000..c906375c --- /dev/null +++ b/wrappers/pkcs11/go.mod @@ -0,0 +1,43 @@ +module github.com/openbao/go-kms-wrapping/wrappers/pkcs11/v2 + +go 1.22.1 + +replace github.com/openbao/go-kms-wrapping/v2 => ../../ + +require ( + github.com/hashicorp/go-uuid v1.0.3 + github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b + github.com/openbao/go-kms-wrapping/v2 v2.1.0 + github.com/openbao/openbao/api/v2 v2.0.1 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/cenkalti/backoff/v3 v3.0.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-jose/go-jose/v3 v3.0.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/hashicorp/go-retryablehttp v0.7.7 // indirect + github.com/hashicorp/go-rootcerts v1.0.2 // indirect + github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 // indirect + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect + github.com/hashicorp/go-sockaddr v1.0.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/ryanuber/go-glob v1.0.0 // indirect + golang.org/x/crypto v0.24.0 // indirect + golang.org/x/net v0.26.0 // indirect + golang.org/x/text v0.16.0 // indirect + golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +retract [v2.0.0, v2.0.2] diff --git a/wrappers/pkcs11/go.sum b/wrappers/pkcs11/go.sum new file mode 100644 index 00000000..b8f0c38d --- /dev/null +++ b/wrappers/pkcs11/go.sum @@ -0,0 +1,103 @@ +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/cenkalti/backoff/v3 v3.0.0 h1:ske+9nBpD9qZsTBoF41nW5L+AIuFBKMeze18XQ3eG1c= +github.com/cenkalti/backoff/v3 v3.0.0/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4rc0ij+ULvLYs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= +github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= +github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 h1:om4Al8Oy7kCm/B86rLCLah4Dt5Aa0Fr5rYBG60OzwHQ= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= +github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/openbao/openbao/api/v2 v2.0.1 h1:oyDqLa8m+XY3YBbgQ4YnX5o+/4/ybShiDPMC/7WomtE= +github.com/openbao/openbao/api/v2 v2.0.1/go.mod h1:qIp3G8D5vaW+r7TG2YoCCEo/5HxTvidwZA0GiwA1iJ8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w= +golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/wrappers/pkcs11/options.go b/wrappers/pkcs11/options.go new file mode 100644 index 00000000..43fa0d1e --- /dev/null +++ b/wrappers/pkcs11/options.go @@ -0,0 +1,181 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + wrapping "github.com/openbao/go-kms-wrapping/v2" +) + +// getOpts iterates the inbound Options and returns a struct +func getOpts(opt ...wrapping.Option) (*options, error) { + // First, separate out options into local and global + opts := getDefaultOptions() + var wrappingOptions []wrapping.Option + var localOptions []OptionFunc + for _, o := range opt { + if o == nil { + continue + } + iface := o() + switch to := iface.(type) { + case wrapping.OptionFunc: + wrappingOptions = append(wrappingOptions, o) + case OptionFunc: + localOptions = append(localOptions, to) + } + } + + // Parse the global options + var err error + opts.Options, err = wrapping.GetOpts(wrappingOptions...) + if err != nil { + return nil, err + } + + // Don't ever return blank options + if opts.Options == nil { + opts.Options = new(wrapping.Options) + } + + // Local options can be provided either via the WithConfigMap field + // (for over the plugin barrier or embedding) or via local option functions + // (for embedding). First pull from the option. + if opts.WithConfigMap != nil { + for k, v := range opts.WithConfigMap { + switch k { + // case "key_id", "kms_key_id": // deprecated backend-specific value, set global + case "key_id": + opts.withKeyId = v + case "slot": + opts.withSlot = v + case "pin": + opts.withPin = v + case "lib", "module": + opts.withLib = v + case "token", "token_label": + opts.withTokenLabel = v + case "label", "key_label": + opts.withKeyLabel = v + case "mechanism": + opts.withMechanism = v + case "rsa_oaep_hash": + opts.withRsaOaepHash = v + } + } + } + + // Now run the local options functions. This may overwrite options set by + // the options above. + for _, o := range localOptions { + if o != nil { + if err := o(&opts); err != nil { + return nil, err + } + } + } + + return &opts, nil +} + +// OptionFunc holds a function with local options +type OptionFunc func(*options) error + +// options = how options are represented +type options struct { + *wrapping.Options + + withSlot string + withPin string + withLib string + withKeyId string + withKeyLabel string + withTokenLabel string + withMechanism string + withRsaOaepHash string +} + +func getDefaultOptions() options { + return options{} +} + +// WithSlot sets the slot +func WithSlot(slot string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withSlot = slot + return nil + }) + } +} + +// WithSlot sets the slot +func WithTokenLabel(slot string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withTokenLabel = slot + return nil + }) + } +} + +// WithPin sets the pin +func WithPin(pin string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withPin = pin + return nil + }) + } +} + +// WithLib sets the module +func WithLib(lib string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withLib = lib + return nil + }) + } +} + +// WithLabel sets the label +func WithKeyId(keyId string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withKeyId = keyId + return nil + }) + } +} + +// WithLabel sets the label +func WithKeyLabel(label string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withKeyLabel = label + return nil + }) + } +} + +// WithMechanism sets the mechanism +func WithMechanism(mechanism string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withMechanism = mechanism + return nil + }) + } +} + +// WithRsaOaepHash sets the RSA OAEP Hash mechanism +func WithRsaOaepHash(hashMechanisme string) wrapping.Option { + return func() interface{} { + return OptionFunc(func(o *options) error { + o.withRsaOaepHash = hashMechanisme + return nil + }) + } +} \ No newline at end of file diff --git a/wrappers/pkcs11/options_test.go b/wrappers/pkcs11/options_test.go new file mode 100644 index 00000000..5f8e827d --- /dev/null +++ b/wrappers/pkcs11/options_test.go @@ -0,0 +1,113 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_GetOpts provides unit tests for GetOpts and all the options +func Test_GetOpts(t *testing.T) { + t.Parallel() + t.Run("WithKeyId", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withKeyId = "" + assert.Equal(opts, testOpts) + + const with = "testKeyId" + opts, err = getOpts(WithKeyId(with)) + require.NoError(err) + testOpts.withKeyId = with + assert.Equal(opts, testOpts) + }) + t.Run("WithSlot", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withSlot = "" + assert.Equal(opts, testOpts) + + const with = "1024" + opts, err = getOpts(WithSlot(with)) + require.NoError(err) + testOpts.withSlot = with + assert.Equal(opts, testOpts) + }) + t.Run("WithPin", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withPin = "" + assert.Equal(opts, testOpts) + + const with = "000000" + opts, err = getOpts(WithPin(with)) + require.NoError(err) + testOpts.withPin = with + assert.Equal(opts, testOpts) + }) + t.Run("WithLib", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withLib = "" + assert.Equal(opts, testOpts) + + const with = "/usr/lib/pkcs11.so" + opts, err = getOpts(WithLib(with)) + require.NoError(err) + testOpts.withLib = with + assert.Equal(opts, testOpts) + }) + t.Run("WithTokenLabel", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withTokenLabel = "" + assert.Equal(opts, testOpts) + + const with = "labelTest" + opts, err = getOpts(WithTokenLabel(with)) + require.NoError(err) + testOpts.withTokenLabel = with + assert.Equal(opts, testOpts) + }) + t.Run("WithMechanism", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // test default of 0 + opts, err := getOpts() + require.NoError(err) + testOpts, err := getOpts() + require.NoError(err) + testOpts.withMechanism = "" + assert.Equal(opts, testOpts) + + const with = "CKM_AES_GCM" + opts, err = getOpts(WithMechanism(with)) + require.NoError(err) + testOpts.withMechanism = with + assert.Equal(opts, testOpts) + }) +} diff --git a/wrappers/pkcs11/pkcs11.go b/wrappers/pkcs11/pkcs11.go new file mode 100644 index 00000000..396158a1 --- /dev/null +++ b/wrappers/pkcs11/pkcs11.go @@ -0,0 +1,119 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "context" + "fmt" + "sync/atomic" + + wrapping "github.com/openbao/go-kms-wrapping/v2" +) + +// Wrapper is a Wrapper that uses PKCS11 +type Wrapper struct { + client pkcs11ClientEncryptor + keyId string + currentKeyId *atomic.Value +} + +// Ensure that we are implementing Wrapper +var _ wrapping.Wrapper = (*Wrapper)(nil) + +// NewWrapper creates a new PKCS11 Wrapper +func NewWrapper() *Wrapper { + k := &Wrapper{ + currentKeyId: new(atomic.Value), + } + k.currentKeyId.Store("") + return k +} + +// Init is called during core.Initialize +func (k *Wrapper) Init(_ context.Context) error { + return nil +} + +// Finalize is called during shutdown +func (k *Wrapper) Finalize(_ context.Context) error { + k.client.Close() + return nil +} + +// SetConfig processes the config info from the server config +func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) { + // Option validation is performed by newPkcs11Client(...). + opts, err := getOpts(opt...) + if err != nil { + return nil, err + } + + client, wrapConfig, err := newPkcs11Client(opts) + if err != nil { + return nil, err + } + k.client = client + k.keyId = client.GetCurrentKey().String() + + return wrapConfig, nil +} + +// Type returns the type for this particular wrapper implementation +func (k *Wrapper) Type(_ context.Context) (wrapping.WrapperType, error) { + return wrapping.WrapperTypePkcs11, nil +} + +// KeyId returns the last known key id +func (k *Wrapper) KeyId(_ context.Context) (string, error) { + return k.currentKeyId.Load().(string), nil +} + +// Encrypt is used to encrypt data using the the PKCS11 key. +// This returns the ciphertext, and/or any errors from this +// call. This should be called after the KMS client has been instantiated. +func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) { + ciphertext, iv, key, err := k.client.Encrypt(plaintext) + if err != nil { + return nil, err + } + + keyId := key.String() + k.currentKeyId.Store(keyId) + + ret := &wrapping.BlobInfo{ + Ciphertext: ciphertext, + Iv: iv, + KeyInfo: &wrapping.KeyInfo{ + KeyId: keyId, + }, + } + return ret, nil +} + +// Decrypt is used to decrypt the ciphertext. This should be called after Init. +func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapping.Option) ([]byte, error) { + if in == nil { + return nil, fmt.Errorf("given input for decryption is nil") + } + + if in.KeyInfo == nil { + in.KeyInfo = &wrapping.KeyInfo{ + KeyId: k.keyId, + } + } + keyId, err := newPkcs11Key(in.KeyInfo.KeyId) + if err != nil { + return nil, err + } + plaintext, err := k.client.Decrypt(in.Ciphertext, in.Iv, keyId) + if err != nil { + return nil, err + } + return plaintext, nil +} + +// GetClient returns the pkcs11 Wrapper's pkcs11ClientEncryptor +func (k *Wrapper) GetClient() pkcs11ClientEncryptor { + return k.client +} \ No newline at end of file diff --git a/wrappers/pkcs11/pkcs11_acc_test.go b/wrappers/pkcs11/pkcs11_acc_test.go new file mode 100644 index 00000000..9cd91ae6 --- /dev/null +++ b/wrappers/pkcs11/pkcs11_acc_test.go @@ -0,0 +1,49 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "context" + "os" + "reflect" + "testing" +) + +// This test executes real calls. The calls themselves should be free, +// but the KMS key used is generally not free. +// +// To run this test, the following env variables need to be set: +// - BAO_HSM_SLOT +// - BAO_HSM_PIN +// - BAO_HSM_LIB +// - BAO_HSM_KEY_LABEL +// - BAO_HSM_KEY_ID +// - BAO_HSM_MECHANISM +func TestAccPkcs11Wrapper_Lifecycle(t *testing.T) { + if os.Getenv("VAULT_ACC") == "" && os.Getenv("KMS_ACC_TESTS") == "" { + t.SkipNow() + } + + s := NewWrapper() + _, err := s.SetConfig(context.Background()) + if err != nil { + t.Fatalf("err : %s", err) + } + + input := []byte("foo") + swi, err := s.Encrypt(context.Background(), input) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + pt, err := s.Decrypt(context.Background(), swi) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + if !reflect.DeepEqual(input, pt) { + t.Fatalf("expected %s, got %s", input, pt) + } +} diff --git a/wrappers/pkcs11/pkcs11_client.go b/wrappers/pkcs11/pkcs11_client.go new file mode 100644 index 00000000..9820c148 --- /dev/null +++ b/wrappers/pkcs11/pkcs11_client.go @@ -0,0 +1,668 @@ +// Copyright (c) 2024 OpenBao a Series of LF Projects, LLC +// SPDX-License-Identifier: MPL-2.0 + +package pkcs11 + +import ( + "fmt" + "strconv" + "strings" + "encoding/binary" + "encoding/hex" + + "github.com/openbao/openbao/api/v2" + pkcs11 "github.com/miekg/pkcs11" + wrapping "github.com/openbao/go-kms-wrapping/v2" +) + +type Pkcs11Key struct { + label string + id string +} +func (k Pkcs11Key) String() string { + return fmt.Sprintf("%s:%s", k.label, k.id) +} +func newPkcs11Key(v string) (*Pkcs11Key, error) { + pos := strings.LastIndex(v, ":") + if pos <= 0 { + return nil, fmt.Errorf("Invalid key format") + } + k := &Pkcs11Key{ + label: v[:pos], + id: v[pos+1:], + } + return k, nil +} +func (k Pkcs11Key) Set(v string) error { + pos := strings.LastIndex(v, ":") + if pos <= 0 { + return fmt.Errorf("Invalid key format") + } + k.label = v[:pos] + k.id = v[pos+1:] + return nil +} + +type pkcs11ClientEncryptor interface { + Close() + GenerateRandom(length int) ([]byte, error) + Encrypt(plaintext []byte) (ciphertext []byte, nonce []byte, keyId *Pkcs11Key, err error) + Decrypt(ciphertext []byte, nonce []byte, keyId *Pkcs11Key) (plaintext []byte, err error) +} + +type Pkcs11Client struct { + client *pkcs11.Ctx + lib string + slot uint + tokenLabel string + pin string + keyLabel string + keyId string + mechanism uint + rsaOaepHash string +} + +const ( + EnvHsmWrapperLib = "BAO_HSM_LIB" + EnvHsmWrapperSlot = "BAO_HSM_SLOT" + EnvHsmWrapperTokenLabel = "BAO_HSM_TOKEN_LABEL" + EnvHsmWrapperPin = "BAO_HSM_PIN" + EnvHsmWrapperKeyLabel = "BAO_HSM_KEY_LABEL" + EnvHsmWrapperKeyId = "BAO_HSM_KEY_ID" + EnvHsmWrapperMechanism = "BAO_HSM_MECHANISM" + EnvHsmWrapperRsaOaepHash = "BAO_HSM_RSA_OAEP_HASH" +) +const ( + DefaultAesMechanism = pkcs11.CKM_AES_GCM + DefaultRsaMechanism = pkcs11.CKM_RSA_PKCS_OAEP + DefaultRsaOaepHash = "sha256" + + CryptoAesGcmNonceSize = 12 + CryptoAesGcmOverhead = 16 +) + +func newPkcs11Client(opts *options) (*Pkcs11Client, *wrapping.WrapperConfig, error) { + var lib, slot, keyId, tokenLabel, pin, keyLabel, mechanism, rsaOaepHash string + var slotNum, mechanismNum uint64 + var err error + + switch { + case api.ReadBaoVariable(EnvHsmWrapperLib) != "" && !opts.Options.WithDisallowEnvVars: + lib = api.ReadBaoVariable(EnvHsmWrapperLib) + case opts.withLib != "": + lib = opts.withLib + default: + return nil, nil, fmt.Errorf("lib is required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperSlot) != "" && !opts.Options.WithDisallowEnvVars: + slot = api.ReadBaoVariable(EnvHsmWrapperSlot) + case opts.withSlot != "": + slot = opts.withSlot + default: + slot = "" + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperTokenLabel) != "" && !opts.Options.WithDisallowEnvVars: + tokenLabel = api.ReadBaoVariable(EnvHsmWrapperTokenLabel) + case opts.withTokenLabel != "": + tokenLabel = opts.withTokenLabel + default: + tokenLabel = "" + } + + if slot == "" && tokenLabel == "" { + return nil, nil, fmt.Errorf("slot or token label required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperKeyId) != "" && !opts.Options.WithDisallowEnvVars: + keyId = api.ReadBaoVariable(EnvHsmWrapperKeyId) + case opts.withKeyId != "": + keyId = opts.withKeyId + default: + keyId = "" + } + // Remove the 0x prefix. + if strings.HasPrefix(keyId, "0x") { + keyId = keyId[2:] + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperPin) != "" && !opts.Options.WithDisallowEnvVars: + pin = api.ReadBaoVariable(EnvHsmWrapperPin) + case opts.withPin != "": + pin = opts.withPin + default: + return nil, nil, fmt.Errorf("pin is required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperKeyLabel) != "" && !opts.Options.WithDisallowEnvVars: + keyLabel = api.ReadBaoVariable(EnvHsmWrapperKeyLabel) + case opts.withKeyLabel != "": + keyLabel = opts.withKeyLabel + default: + return nil, nil, fmt.Errorf("key label is required") + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperMechanism) != "" && !opts.Options.WithDisallowEnvVars: + mechanism = api.ReadBaoVariable(EnvHsmWrapperMechanism) + case opts.withMechanism != "": + mechanism = opts.withMechanism + default: + mechanism = "" + } + + switch { + case api.ReadBaoVariable(EnvHsmWrapperRsaOaepHash) != "" && !opts.Options.WithDisallowEnvVars: + rsaOaepHash = strings.ToLower(api.ReadBaoVariable(EnvHsmWrapperRsaOaepHash)) + case opts.withRsaOaepHash != "": + rsaOaepHash = strings.ToLower(opts.withRsaOaepHash) + default: + rsaOaepHash = "" + } + + if slot != "" { + if slotNum, err = numberAutoParse(slot, 32); err != nil { + return nil, nil, fmt.Errorf("Invalid slot number") + } + } else { + slotNum = 0 + } + + if mechanism != "" { + if mechanismNum, err = MechanismFromString(mechanism); err != nil { + return nil, nil, err + } + } else { + mechanismNum = 0 + } + + client := &Pkcs11Client{ + client: nil, + lib: lib, + slot: uint(slotNum), + pin: pin, + tokenLabel: tokenLabel, + keyId: keyId, + keyLabel: keyLabel, + mechanism: uint(mechanismNum), + rsaOaepHash: rsaOaepHash, + } + + // Initialize the client + err = client.InitializeClient() + if err != nil { + return nil, nil, err + } + // Validate credentials for session establishment + session, err := client.GetSession() + if err != nil { + return nil, nil, err + } + defer client.CloseSession(session) + + wrapConfig := new(wrapping.WrapperConfig) + wrapConfig.Metadata = make(map[string]string) + wrapConfig.Metadata["lib"] = lib + wrapConfig.Metadata["key_label"] = keyLabel + wrapConfig.Metadata["key_id"] = keyId + if slotNum != 0 { + wrapConfig.Metadata["slot"] = strconv.Itoa(int(slotNum)) + } + if tokenLabel != "" { + wrapConfig.Metadata["token_label"] = tokenLabel + } + if mechanismNum != 0 { + wrapConfig.Metadata["mechanism"] = MechanismString(uint(mechanismNum)) + } + if rsaOaepHash != "" { + wrapConfig.Metadata["rsa_oaep_hash"] = rsaOaepHash + } + + return client, wrapConfig, nil +} + + +func (c *Pkcs11Client) Close() { + if c.client == nil { + return + } + c.client.Finalize() + c.client.Destroy() + c.client = nil +} + +func (c *Pkcs11Client) GenerateRandom(length int) ([]byte, error) { + session, err := c.GetSession() + if err != nil { + return nil, err + } + defer c.CloseSession(session) + + return c.client.GenerateRandom(session, length) +} + +func (c *Pkcs11Client) Encrypt(plaintext []byte) ([]byte, []byte, *Pkcs11Key, error) { + session, err := c.GetSession() + if err != nil { + return nil, nil, nil, err + } + defer c.CloseSession(session) + + keyId := Pkcs11Key{ label: c.keyLabel, id: c.keyId } + key, err := c.FindKey(session, keyId, pkcs11.CKA_ENCRYPT) + if err != nil { + return nil, nil, nil, err + } + + mechanism, err := c.GetKeyMechanism(session, key) + if err != nil { + return nil, nil, nil, err + } + + switch mechanism { + case pkcs11.CKM_AES_GCM: + return c.EncryptAesGcm(session, key, keyId, plaintext) + case pkcs11.CKM_RSA_PKCS_OAEP: + return c.EncryptRsaOaep(session, key, keyId, plaintext) + } + return nil, nil, nil, fmt.Errorf("unsupported mechanism") +} + +// Encryption for AES GCM algorithm +func (c *Pkcs11Client) EncryptAesGcm(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, keyId Pkcs11Key, plaintext []byte) ([]byte, []byte, *Pkcs11Key, error) { + nonce, err := c.client.GenerateRandom(session, CryptoAesGcmNonceSize) + if err != nil { + return nil, nil, nil, err + } + + // Some HSM will ignore the given nonce and generate their own. + // That's why we need to free manually the GCM parameters. + params := pkcs11.NewGCMParams(nonce, nil, CryptoAesGcmOverhead*8) + defer params.Free() + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_AES_GCM, params)} + + if err = c.client.EncryptInit(session, mech, key); err != nil { + return nil, nil, nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) + } + var ciphertext []byte + if ciphertext, err = c.client.Encrypt(session, plaintext); err != nil { + return nil, nil, nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) + } + + // Some HSM (CloudHSM) does not read the nonce/IV and generate its own. + // Since it's append, we need to extract it. + if len(ciphertext) == CryptoAesGcmNonceSize + len(plaintext) + CryptoAesGcmOverhead { + nonce = ciphertext[len(ciphertext)-CryptoAesGcmNonceSize:] + ciphertext = ciphertext[:len(ciphertext)-CryptoAesGcmNonceSize] + } + + return ciphertext, nonce, &keyId, nil +} + +func (c *Pkcs11Client) EncryptRsaOaep(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, keyId Pkcs11Key, plaintext []byte) ([]byte, []byte, *Pkcs11Key, error) { + var rsaOaepHash string + if c.rsaOaepHash != "" { + rsaOaepHash = c.rsaOaepHash + } else { + rsaOaepHash = DefaultRsaOaepHash + } + hash, mgf_hash, err := RsaHashMechFromString(rsaOaepHash) + if err != nil { + return nil, nil, nil, err + } + params := pkcs11.NewOAEPParams(hash, mgf_hash, pkcs11.CKZ_DATA_SPECIFIED, nil) + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_OAEP, params)} + + if err = c.client.EncryptInit(session, mech, key); err != nil { + return nil, nil, nil, fmt.Errorf("failed to pkcs11 EncryptInit: %s", err) + } + var ciphertext []byte + if ciphertext, err = c.client.Encrypt(session, plaintext); err != nil { + return nil, nil, nil, fmt.Errorf("failed to pkcs11 Encrypt: %s", err) + } + + return ciphertext, nil, &keyId, nil +} + +func (c *Pkcs11Client) Decrypt(ciphertext []byte, nonce []byte, keyId *Pkcs11Key) ([]byte, error) { + session, err := c.GetSession() + if err != nil { + return nil, err + } + defer c.CloseSession(session) + + if keyId == nil { + keyId = &Pkcs11Key{ label: c.keyLabel, id: c.keyId } + } + + key, err := c.FindKey(session, *keyId, pkcs11.CKA_DECRYPT) + if err != nil { + return nil, err + } + + mechanism, err := c.GetKeyMechanism(session, key) + if err != nil { + return nil, err + } + + switch mechanism { + case pkcs11.CKM_AES_GCM: + return c.DecryptAesGcm(session, key, nonce, ciphertext) + case pkcs11.CKM_RSA_PKCS_OAEP: + return c.DecryptRsaOaep(session, key, nonce, ciphertext) + } + return nil, fmt.Errorf("unsupported mechanism") +} + +func (c *Pkcs11Client) DecryptAesGcm(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, nonce []byte, ciphertext []byte) ([]byte, error) { + params := pkcs11.NewGCMParams(nonce, nil, CryptoAesGcmOverhead*8) + defer params.Free() + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_AES_GCM, params)} + + var err error + if err = c.client.DecryptInit(session, mech, key); err != nil { + return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) + } + var decrypted []byte + if decrypted, err = c.client.Decrypt(session, ciphertext); err != nil { + return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) + } + return decrypted, nil +} + +func (c *Pkcs11Client) DecryptRsaOaep(session pkcs11.SessionHandle, key pkcs11.ObjectHandle, _ []byte, ciphertext []byte) ([]byte, error) { + var rsaOaepHash string + if c.rsaOaepHash != "" { + rsaOaepHash = c.rsaOaepHash + } else { + rsaOaepHash = DefaultRsaOaepHash + } + hash, mgf_hash, err := RsaHashMechFromString(rsaOaepHash) + if err != nil { + return nil, err + } + params := pkcs11.NewOAEPParams(hash, mgf_hash, pkcs11.CKZ_DATA_SPECIFIED, nil) + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_OAEP, params)} + + if err = c.client.DecryptInit(session, mech, key); err != nil { + return nil, fmt.Errorf("failed to pkcs11 DecryptInit: %s", err) + } + var decrypted []byte + if decrypted, err = c.client.Decrypt(session, ciphertext); err != nil { + return nil, fmt.Errorf("failed to pkcs11 Decrypt: %s", err) + } + return decrypted, nil +} + +// Create a PKCS11 client for the configured module. +func (c *Pkcs11Client) InitializeClient() (error) { + if c.client != nil { + return nil + } + c.client = pkcs11.New(c.lib) + err := c.client.Initialize() + if err != nil { + c.client = nil + return fmt.Errorf("failed to initialize PKCS11: %w", err) + } + return nil +} + +func (c *Pkcs11Client) GetSlotForLabel() (uint, error) { + if c.slot != 0 { + return c.slot, nil + } + if c.tokenLabel == "" { + return 0, fmt.Errorf("not token label configured") + } + slots, _ := c.client.GetSlotList(true) + for _, slot := range slots { + tokenInfo, err := c.client.GetTokenInfo(slot) + if err == nil && tokenInfo.Label == c.tokenLabel { + c.slot = slot + break + } + } + if c.slot == 0 { + return 0, fmt.Errorf("failed to find token with label: %s", c.tokenLabel) + } + return c.slot, nil +} + +// Open a session and perform the authentication process. +func (c *Pkcs11Client) GetSession() (pkcs11.SessionHandle, error) { + if c.client == nil { + return 0, fmt.Errorf("PKCS11 not initialized") + } + + if c.slot == 0 { + _, err := c.GetSlotForLabel() + if err != nil { + return 0, err + } + } + + session, err := c.client.OpenSession(c.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + return 0, fmt.Errorf("failed to open session: %w", err) + } + err = c.client.Login(session, pkcs11.CKU_USER, c.pin) + if err != nil { + return 0, fmt.Errorf("failed to login: %w", err) + } + return session, nil +} + +func (c *Pkcs11Client) CloseSession(session pkcs11.SessionHandle) { + if c.client == nil { + return + } + c.client.Logout(session) + c.client.CloseSession(session) +} + +// Find on key for the given Label, ID and Mechanism. +func (c *Pkcs11Client) FindKey(session pkcs11.SessionHandle, key Pkcs11Key, typ uint) (pkcs11.ObjectHandle, error) { + template := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(key.label)), + pkcs11.NewAttribute(typ, true), + } + if keyIdBytes, err := hex.DecodeString(key.id); err == nil { + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, keyIdBytes)) + } + if c.mechanism != 0 { + keyType, err := GetKeyTypeFromMech(c.mechanism) + if err != nil { + return 0, fmt.Errorf("failed to get key type from mechanism: %s", err) + } + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keyType)) + } + + if err := c.client.FindObjectsInit(session, template); err != nil { + return 0, fmt.Errorf("failed to pkcs11 FindObjectsInit: %s", err) + } + obj, _, err := c.client.FindObjects(session, 2) + if err != nil { + return 0, fmt.Errorf("failed to pkcs11 FindObjects: %s", err) + } + if err := c.client.FindObjectsFinal(session); err != nil { + return 0, fmt.Errorf("failed to pkcs11 FindObjectsFinal: %s", err) + } + if len(obj) == 0 { + return 0, fmt.Errorf("no key found for the label: %s", key.label) + } + if len(obj) != 1 { + return 0, fmt.Errorf("got more than 1 key for the label: %s", key.label) + } + + return obj[0], nil +} + +func (c *Pkcs11Client) GetKeyMechanism(session pkcs11.SessionHandle, key pkcs11.ObjectHandle) (uint, error) { + template := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), + } + attr, err := c.client.GetAttributeValue(session, pkcs11.ObjectHandle(key), template) + if err != nil { + return 0, fmt.Errorf("failed to pkcs11 GetAttributeValue: %s", err) + } + + attrMap := GetAttributesMap(attr) + keyType := GetValueAsInt(attrMap[pkcs11.CKA_KEY_TYPE]) + + mechanism := uint(0) + switch keyType { + case pkcs11.CKK_AES: + if c.mechanism != 0 { + mechanism = c.mechanism + } else { + mechanism = DefaultAesMechanism + } + case pkcs11.CKK_RSA: + if c.mechanism != 0 { + mechanism = c.mechanism + } else { + mechanism = DefaultRsaMechanism + } + default: + return 0, fmt.Errorf("unsupported key type: %d", keyType) + } + + return mechanism, nil +} + +func (c *Pkcs11Client) GetCurrentKey() Pkcs11Key { + return Pkcs11Key{ + label: c.keyLabel, + id: c.keyId, + } +} + +func GetKeyTypeFromMech(mech uint) (uint, error) { + switch mech { + case pkcs11.CKM_RSA_PKCS_OAEP: + return pkcs11.CKK_RSA, nil + case pkcs11.CKM_AES_GCM: + return pkcs11.CKK_AES, nil + // Deprecated mechanisms + case pkcs11.CKM_RSA_PKCS, pkcs11.CKM_AES_CBC, pkcs11.CKM_AES_CBC_PAD: + return 0, fmt.Errorf("deprecated mechanism: %s (%d)", MechanismString(mech), mech) + // Other are unsupported + default: + return 0, fmt.Errorf("unsupported mechanism: %d", mech) + } +} + +func MechanismString(mech uint) string { + switch mech { + case pkcs11.CKM_RSA_PKCS_OAEP: + return "CKM_RSA_PKCS_OAEP" + case pkcs11.CKM_AES_GCM: + return "CKM_AES_GCM" + // Deprecated mechanisms + case pkcs11.CKM_RSA_PKCS: + return "CKM_RSA_PKCS" + case pkcs11.CKM_AES_CBC: + return "CKM_AES_CBC" + case pkcs11.CKM_AES_CBC_PAD: + return "CKM_AES_CBC_PAD" + default: + return "Unknown" + } +} + +func MechanismFromString(mech string) (uint64, error) { + switch mech { + case "CKM_RSA_PKCS_OAEP", "RSA_PKCS_OAEP": + return pkcs11.CKM_RSA_PKCS_OAEP, nil + case "CKM_AES_GCM", "AES_GCM": + return pkcs11.CKM_AES_GCM, nil + // Deprecated mechanisms + case "CKM_RSA_PKCS", "RSA_PKCS", "CKM_AES_CBC_PAD", "AES_CBC_PAD": + return 0, fmt.Errorf("deprecated mechanism: %s", mech) + // Other mechanisms + default: + // Try to extract the mechanism PKCS11 raw value. + if mechanismNum, err := numberAutoParse(mech, 32); err == nil { + if _, err = GetKeyTypeFromMech(uint(mechanismNum)); err == nil { + return mechanismNum, nil + } + } + return 0, fmt.Errorf("unsupported mechanism: %s", mech) + } +} + +func RsaHashMechFromString(mech string) (uint, uint, error) { + mech = strings.ToLower(mech) + switch mech { + case "sha1": + return pkcs11.CKM_SHA_1, pkcs11.CKG_MGF1_SHA1, nil + case "sha224": + return pkcs11.CKM_SHA224, pkcs11.CKG_MGF1_SHA224, nil + case "sha256": + return pkcs11.CKM_SHA256, pkcs11.CKG_MGF1_SHA256, nil + case "sha384": + return pkcs11.CKM_SHA384, pkcs11.CKG_MGF1_SHA384, nil + case "sha512": + return pkcs11.CKM_SHA512, pkcs11.CKG_MGF1_SHA512, nil + default: + return 0, 0, fmt.Errorf("unsupported mechanism: %s", mech) + } +} + +func GetAttributesMap(attrs []*pkcs11.Attribute) map[uint][]byte { + m := make(map[uint][]byte, len(attrs)) + for _, a := range attrs { + m[a.Type] = a.Value + } + return m +} + +func GetValueAsInt(value []byte) int64 { + switch len(value) { + case 1: + return int64(value[0]) + case 2: + return int64(binary.NativeEndian.Uint16(value)) + case 4: + return int64(binary.NativeEndian.Uint32(value)) + case 8: + return int64(binary.NativeEndian.Uint64(value)) + } + return 0 +} + +func GetValueAsUint(value []byte) uint64 { + switch len(value) { + case 1: + return uint64(value[0]) + case 2: + return uint64(binary.NativeEndian.Uint16(value)) + case 4: + return uint64(binary.NativeEndian.Uint32(value)) + case 8: + return uint64(binary.NativeEndian.Uint64(value)) + } + return 0 +} + +func numberAutoParse(value string, bitSize int) (uint64, error) { + var ret uint64 + var err error + value = strings.ToLower(value) + if strings.HasPrefix(value, "0x") { + ret, err = strconv.ParseUint(value[2:], 16, bitSize) + } else { + ret, err = strconv.ParseUint(value, 10, bitSize) + } + return ret, err +}