Skip to content

Commit

Permalink
factor out EVP_PKEY_CTX functions
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Sep 4, 2023
1 parent 13f20f3 commit 4a7159f
Show file tree
Hide file tree
Showing 8 changed files with 461 additions and 339 deletions.
277 changes: 277 additions & 0 deletions bindings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
//go:build !cmd_go_bootstrap

package openssl

// #include "goopenssl.h"
import "C"
import "unsafe"

type evpPkeyCtx struct {
ptr C.GO_EVP_PKEY_CTX_PTR
}

func newEvpPkeyCtx(pkey C.GO_EVP_PKEY_PTR) (evpPkeyCtx, error) {
ctx := C.go_openssl_EVP_PKEY_CTX_new(pkey, nil)
if ctx == nil {
return evpPkeyCtx{}, newOpenSSLError("EVP_PKEY_CTX_new")
}
return evpPkeyCtx{ctx}, nil
}

func newEvpPkeyCtxFromID(id int) (evpPkeyCtx, error) {
ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.int(id), nil)
if ctx == nil {
return evpPkeyCtx{}, newOpenSSLError("EVP_PKEY_CTX_new_id")
}
return evpPkeyCtx{ctx}, nil
}

func (ctx evpPkeyCtx) free() {
if ctx.ptr != nil {
C.go_openssl_EVP_PKEY_CTX_free(ctx.ptr)
}
}

func (ctx evpPkeyCtx) ctrl(keytype int, optype int, cmd int, p1 int, p2 unsafe.Pointer) error {
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx.ptr, C.int(keytype), C.int(optype), C.int(cmd), C.int(p1), p2) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_ctrl")
}
return nil
}

func (ctx evpPkeyCtx) keygenInit() error {
if C.go_openssl_EVP_PKEY_keygen_init(ctx.ptr) != 1 {
return newOpenSSLError("EVP_PKEY_keygen_init")
}
return nil
}

func (ctx evpPkeyCtx) keygen() (pkey C.GO_EVP_PKEY_PTR, err error) {
if C.go_openssl_EVP_PKEY_keygen(ctx.ptr, &pkey) != 1 {
return nil, newOpenSSLError("EVP_PKEY_keygen")
}
return pkey, nil
}

func (ctx evpPkeyCtx) encryptInit() error {
if C.go_openssl_EVP_PKEY_encrypt_init(ctx.ptr) != 1 {
return newOpenSSLError("EVP_PKEY_encrypt_init")
}
return nil
}

func (ctx evpPkeyCtx) encrypt(out []byte, in []byte) ([]byte, error) {
outLen := C.size_t(len(out))
if ret := C.go_openssl_EVP_PKEY_encrypt(ctx.ptr, base(out), &outLen, base(in), C.size_t(len(in))); ret != 1 {
return nil, newOpenSSLError("EVP_PKEY_encrypt")
}
return out[:outLen], nil
}

func (ctx evpPkeyCtx) decryptInit() error {
if C.go_openssl_EVP_PKEY_decrypt_init(ctx.ptr) != 1 {
return newOpenSSLError("EVP_PKEY_decrypt_init")
}
return nil
}

func (ctx evpPkeyCtx) decrypt(out []byte, in []byte) ([]byte, error) {
outLen := C.size_t(len(out))
if ret := C.go_openssl_EVP_PKEY_decrypt(ctx.ptr, base(out), &outLen, base(in), C.size_t(len(in))); ret != 1 {
return nil, newOpenSSLError("EVP_PKEY_decrypt")
}
return out[:outLen], nil
}

func (ctx evpPkeyCtx) signInit() error {
if C.go_openssl_EVP_PKEY_sign_init(ctx.ptr) != 1 {
return newOpenSSLError("EVP_PKEY_sign_init")
}
return nil
}

func (ctx evpPkeyCtx) sign(out []byte, in []byte) ([]byte, error) {
outLen := C.size_t(len(out))
if ret := C.go_openssl_EVP_PKEY_sign(ctx.ptr, base(out), &outLen, base(in), C.size_t(len(in))); ret != 1 {
return nil, newOpenSSLError("EVP_PKEY_sign")
}
return out[:outLen], nil
}

func (ctx evpPkeyCtx) verifyInit() error {
if C.go_openssl_EVP_PKEY_verify_init(ctx.ptr) != 1 {
return newOpenSSLError("EVP_PKEY_verify_init")
}
return nil
}

func (ctx evpPkeyCtx) verify(sig []byte, in []byte) error {
if ret := C.go_openssl_EVP_PKEY_verify(ctx.ptr, base(sig), C.size_t(len(sig)), base(in), C.size_t(len(in))); ret != 1 {
return newOpenSSLError("EVP_PKEY_verify")
}
return nil
}

func (ctx evpPkeyCtx) fromdataInit() error {
if C.go_openssl_EVP_PKEY_fromdata_init(ctx.ptr) != 1 {
return newOpenSSLError("EVP_PKEY_fromdata_init")
}
return nil
}

func (ctx evpPkeyCtx) fromdata(selection int, params C.GO_OSSL_PARAM_PTR) (pkey C.GO_EVP_PKEY_PTR, err error) {
if C.go_openssl_EVP_PKEY_fromdata(ctx.ptr, &pkey, C.int(selection), params) != 1 {
return nil, newOpenSSLError("EVP_PKEY_fromdata")
}
return pkey, nil
}

func (ctx evpPkeyCtx) deriveInit() error {
if C.go_openssl_EVP_PKEY_derive_init(ctx.ptr) != 1 {
return newOpenSSLError("EVP_PKEY_derive_init")
}
return nil
}

func (ctx evpPkeyCtx) deriveSetPeer(peer C.GO_EVP_PKEY_PTR) error {
if C.go_openssl_EVP_PKEY_derive_set_peer(ctx.ptr, peer) != 1 {
return newOpenSSLError("EVP_PKEY_derive_set_peer")
}
return nil
}

func (ctx evpPkeyCtx) derive(out []byte) ([]byte, error) {
outLen := C.size_t(len(out))
if out == nil {
if C.go_openssl_EVP_PKEY_derive(ctx.ptr, nil, &outLen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
out = make([]byte, outLen)
}
if C.go_openssl_EVP_PKEY_derive(ctx.ptr, base(out), &outLen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
return out[:outLen], nil
}

func (ctx evpPkeyCtx) setHKDFProps(mode int, md C.GO_EVP_MD_PTR, key []byte, salt []byte, info []byte) error {
switch vMajor {
case 3:
if mode != 0 {
if C.go_openssl_EVP_PKEY_CTX_set_hkdf_mode(ctx.ptr, C.int(mode)) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode")
}
}
if md != nil {
if C.go_openssl_EVP_PKEY_CTX_set_hkdf_md(ctx.ptr, md) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md")
}
}
if key != nil {
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_key(ctx.ptr, base(key), C.int(len(key))) != 1 {
return newOpenSSLError("gEVP_PKEY_CTX_set1_hkdf_key")
}
}
if salt != nil {
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_salt(ctx.ptr, base(salt), C.int(len(salt))) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
}
}
if info != nil {
if C.go_openssl_EVP_PKEY_CTX_add1_hkdf_info(ctx.ptr, base(info), C.int(len(info))) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info")
}
}
return nil
case 1:
if mode != 0 {
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MODE, mode, nil); err != nil {
return err
}
}
if md != nil {
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MD, 0, unsafe.Pointer(md)); err != nil {
return err
}
}
if key != nil {
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_KEY, len(key), unsafe.Pointer(base(key))); err != nil {
return err
}
}
if salt != nil {
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_SALT, len(salt), unsafe.Pointer(base(salt))); err != nil {
return err
}
}
if info != nil {
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_INFO, len(info), unsafe.Pointer(base(info))); err != nil {
return err
}
}
default:
panic(errUnsupportedVersion())
}
return nil
}

func (ctx evpPkeyCtx) setTLS1PRFProps(md C.GO_EVP_MD_PTR, secret []byte, seeds ...[]byte) error {

switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set_tls1_prf_md(ctx.ptr, md) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
}
if C.go_openssl_EVP_PKEY_CTX_set1_tls1_prf_secret(ctx.ptr,
base(secret), C.int(len(secret))) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
}
for _, s := range seeds {
if C.go_openssl_EVP_PKEY_CTX_add1_tls1_prf_seed(ctx.ptr,
base(s), C.int(len(s))) != 1 {
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
}
}
case 1:
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_TLS_MD, 0, unsafe.Pointer(md)); err != nil {
return err
}
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_TLS_SECRET, len(secret), unsafe.Pointer(base(secret))); err != nil {
return err
}
for _, s := range seeds {
if err := ctx.ctrl(-1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_TLS_SEED, len(s), unsafe.Pointer(base(s))); err != nil {
return err
}
}
default:
panic(errUnsupportedVersion())
}
return nil
}

func (ctx evpPkeyCtx) setRSAOAEPLabel(label []byte) error {
// ctx takes ownership of label, so malloc a copy for OpenSSL to free.
// OpenSSL does not take ownership of the label if the length is zero,
// so better avoid the allocation.
var clabel *C.uchar
if len(label) > 0 {
clabel = (*C.uchar)(cryptoMalloc(len(label)))
copy((*[1 << 30]byte)(unsafe.Pointer(clabel))[:len(label)], label)
}
switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set0_rsa_oaep_label(ctx.ptr, unsafe.Pointer(clabel), C.int(len(label))) != 1 {
cryptoFree(unsafe.Pointer(clabel))
return newOpenSSLError("EVP_PKEY_CTX_set0_rsa_oaep_label")
}
case 1:
if err := ctx.ctrl(C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_RSA_OAEP_LABEL, len(label), unsafe.Pointer(clabel)); err != nil {
cryptoFree(unsafe.Pointer(clabel))
return err
}
default:
panic(errUnsupportedVersion())
}
return nil
}
28 changes: 10 additions & 18 deletions ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
}
defer C.go_openssl_OSSL_PARAM_BLD_free(bld)
C.go_openssl_OSSL_PARAM_BLD_push_utf8_string(bld, paramGroup, C.go_openssl_OBJ_nid2sn(nid), 0)
var selection C.int
var selection int
if isPrivate {
priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil)
if priv == nil {
Expand Down Expand Up @@ -258,26 +258,18 @@ func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error {
func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
defer runtime.KeepAlive(priv)
defer runtime.KeepAlive(pub)
ctx := C.go_openssl_EVP_PKEY_CTX_new(priv._pkey, nil)
if ctx == nil {
return nil, newOpenSSLError("EVP_PKEY_CTX_new")
}
defer C.go_openssl_EVP_PKEY_CTX_free(ctx)
if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
if C.go_openssl_EVP_PKEY_derive_set_peer(ctx, pub._pkey) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_set_peer")
ctx, err := newEvpPkeyCtx(priv._pkey)
if err != nil {
return nil, err
}
var outLen C.size_t
if C.go_openssl_EVP_PKEY_derive(ctx, nil, &outLen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
defer ctx.free()
if err := ctx.deriveInit(); err != nil {
return nil, err
}
out := make([]byte, outLen)
if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &outLen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
if err := ctx.deriveSetPeer(pub._pkey); err != nil {
return nil, err
}
return out, nil
return ctx.derive(nil)
}

func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
Expand Down
6 changes: 3 additions & 3 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (k *PrivateKeyECDSA) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func (k *PrivateKeyECDSA) withKey(f func(C.GO_EVP_PKEY_PTR) C.int) C.int {
func (k *PrivateKeyECDSA) withKey(f func(C.GO_EVP_PKEY_PTR) error) error {
defer runtime.KeepAlive(k)
return f(k._pkey)
}
Expand All @@ -33,7 +33,7 @@ func (k *PublicKeyECDSA) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func (k *PublicKeyECDSA) withKey(f func(C.GO_EVP_PKEY_PTR) C.int) C.int {
func (k *PublicKeyECDSA) withKey(f func(C.GO_EVP_PKEY_PTR) error) error {
defer runtime.KeepAlive(k)
return f(k._pkey)
}
Expand Down Expand Up @@ -199,7 +199,7 @@ func newECDSAKey3(nid C.int, bx, by, bd C.GO_BIGNUM_PTR) (C.GO_EVP_PKEY_PTR, err
cbytes := C.CBytes(pubBytes)
defer C.free(cbytes)
C.go_openssl_OSSL_PARAM_BLD_push_octet_string(bld, paramPubKey, cbytes, C.size_t(len(pubBytes)))
var selection C.int
var selection int
if bd != nil {
if C.go_openssl_OSSL_PARAM_BLD_push_BN(bld, paramPrivKey, bd) != 1 {
return nil, newOpenSSLError("OSSL_PARAM_BLD_push_BN")
Expand Down
Loading

0 comments on commit 4a7159f

Please sign in to comment.