diff --git a/ecdh.go b/ecdh.go index 62e23333..a1e627ef 100644 --- a/ecdh.go +++ b/ecdh.go @@ -269,12 +269,12 @@ func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) { if C.go_openssl_EVP_PKEY_derive_set_peer(ctx, pub._pkey) != 1 { return nil, newOpenSSLError("EVP_PKEY_derive_set_peer") } - var outLen C.size_t - if C.go_openssl_EVP_PKEY_derive(ctx, nil, &outLen) != 1 { + r := C.go_openssl_EVP_PKEY_derive_wrapper(ctx, nil, 0) + if r.result != 1 { return nil, newOpenSSLError("EVP_PKEY_derive_init") } - out := make([]byte, outLen) - if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &outLen) != 1 { + out := make([]byte, r.keylen) + if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), r.keylen).result != 1 { return nil, newOpenSSLError("EVP_PKEY_derive_init") } return out, nil diff --git a/ed25519.go b/ed25519.go index f66a2a1d..f74bd8f8 100644 --- a/ed25519.go +++ b/ed25519.go @@ -145,12 +145,12 @@ func NewPrivateKeyEd25519FromSeed(seed []byte) (*PrivateKeyEd25519, error) { } func extractPKEYPubEd25519(pkey C.GO_EVP_PKEY_PTR, pub []byte) error { - pubSize := C.size_t(publicKeySizeEd25519) - if C.go_openssl_EVP_PKEY_get_raw_public_key(pkey, base(pub), &pubSize) != 1 { + r := C.go_openssl_EVP_PKEY_get_raw_public_key_wrapper(pkey, base(pub), C.size_t(publicKeySizeEd25519)) + if r.result != 1 { return newOpenSSLError("EVP_PKEY_get_raw_public_key") } - if pubSize != publicKeySizeEd25519 { - return errors.New("ed25519: bad public key length: " + strconv.Itoa(int(pubSize))) + if r.len != publicKeySizeEd25519 { + return errors.New("ed25519: bad public key length: " + strconv.Itoa(int(r.len))) } return nil } @@ -159,12 +159,12 @@ func extractPKEYPrivEd25519(pkey C.GO_EVP_PKEY_PTR, priv []byte) error { if err := extractPKEYPubEd25519(pkey, priv[seedSizeEd25519:]); err != nil { return err } - privSize := C.size_t(seedSizeEd25519) - if C.go_openssl_EVP_PKEY_get_raw_private_key(pkey, base(priv), &privSize) != 1 { + r := C.go_openssl_EVP_PKEY_get_raw_private_key_wrapper(pkey, base(priv), C.size_t(seedSizeEd25519)) + if r.result != 1 { return newOpenSSLError("EVP_PKEY_get_raw_private_key") } - if privSize != seedSizeEd25519 { - return errors.New("ed25519: bad private key length: " + strconv.Itoa(int(privSize))) + if r.len != seedSizeEd25519 { + return errors.New("ed25519: bad private key length: " + strconv.Itoa(int(r.len))) } return nil } @@ -190,12 +190,12 @@ func signEd25519(priv *PrivateKeyEd25519, sig, message []byte) error { if C.go_openssl_EVP_DigestSignInit(ctx, nil, nil, nil, priv._pkey) != 1 { return newOpenSSLError("EVP_DigestSignInit") } - siglen := C.size_t(signatureSizeEd25519) - if C.go_openssl_EVP_DigestSign(ctx, base(sig), &siglen, base(message), C.size_t(len(message))) != 1 { + r := C.go_openssl_EVP_DigestSign_wrapper(ctx, base(sig), C.size_t(signatureSizeEd25519), base(message), C.size_t(len(message))) + if r.result != 1 { return newOpenSSLError("EVP_DigestSign") } - if siglen != signatureSizeEd25519 { - return errors.New("ed25519: bad signature length: " + strconv.Itoa(int(siglen))) + if r.siglen != signatureSizeEd25519 { + return errors.New("ed25519: bad signature length: " + strconv.Itoa(int(r.siglen))) } return nil } diff --git a/goopenssl.h b/goopenssl.h index 5b658ec9..e488bf20 100644 --- a/goopenssl.h +++ b/goopenssl.h @@ -107,6 +107,58 @@ go_openssl_EVP_CipherUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, unsigned char *ou return go_openssl_EVP_CipherUpdate(ctx, out, &len, in, in_len); } +// These wrappers also allocate length variables on the C stack to avoid escape to the heap, but do return the result. +// A struct is returned that contains multiple return values instead of OpenSSL's approach of using pointers. + +typedef struct +{ + int result; + size_t keylen; +} go_openssl_EVP_PKEY_derive_wrapper_out; + +static inline go_openssl_EVP_PKEY_derive_wrapper_out +go_openssl_EVP_PKEY_derive_wrapper(GO_EVP_PKEY_CTX_PTR ctx, unsigned char *key, size_t keylen) +{ + go_openssl_EVP_PKEY_derive_wrapper_out r = {0, keylen}; + r.result = go_openssl_EVP_PKEY_derive(ctx, key, &r.keylen); + return r; +} + +typedef struct +{ + int result; + size_t len; +} go_openssl_EVP_PKEY_get_raw_key_out; + +static inline go_openssl_EVP_PKEY_get_raw_key_out +go_openssl_EVP_PKEY_get_raw_public_key_wrapper(const GO_EVP_PKEY_PTR pkey, unsigned char *pub, size_t len) +{ + go_openssl_EVP_PKEY_get_raw_key_out r = {0, len}; + r.result = go_openssl_EVP_PKEY_get_raw_public_key(pkey, pub, &r.len); + return r; +} + +static inline go_openssl_EVP_PKEY_get_raw_key_out +go_openssl_EVP_PKEY_get_raw_private_key_wrapper(const GO_EVP_PKEY_PTR pkey, unsigned char *priv, size_t len) +{ + go_openssl_EVP_PKEY_get_raw_key_out r = {0, len}; + r.result = go_openssl_EVP_PKEY_get_raw_private_key(pkey, priv, &r.len); + return r; +} + +typedef struct +{ + int result; + size_t siglen; +} go_openssl_EVP_DigestSign_wrapper_out; + +static inline go_openssl_EVP_DigestSign_wrapper_out +go_openssl_EVP_DigestSign_wrapper(GO_EVP_MD_CTX_PTR ctx, unsigned char *sigret, size_t siglen, const unsigned char *tbs, size_t tbslen) +{ + go_openssl_EVP_DigestSign_wrapper_out r = {0, siglen}; + r.result = go_openssl_EVP_DigestSign(ctx, sigret, &r.siglen, tbs, tbslen); + return r; +} // These wrappers allocate out_len on the C stack, and check that it matches the expected // value, to avoid having to pass a pointer from Go, which would escape to the heap. diff --git a/hkdf.go b/hkdf.go index ac3fbba0..61cf483f 100644 --- a/hkdf.go +++ b/hkdf.go @@ -98,7 +98,7 @@ func (c *hkdf) Read(p []byte) (int, error) { } c.buf = append(c.buf, make([]byte, needLen)...) outLen := C.size_t(prevLen + needLen) - if C.go_openssl_EVP_PKEY_derive(c.ctx, base(c.buf), &outLen) != 1 { + if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(c.buf), outLen).result != 1 { return 0, newOpenSSLError("EVP_PKEY_derive") } n := copy(p, c.buf[prevLen:outLen]) @@ -132,15 +132,15 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) { return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt") } } - var outLen C.size_t - if C.go_openssl_EVP_PKEY_derive(c.ctx, nil, &outLen) != 1 { + r := C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, nil, 0) + if r.result != 1 { return nil, newOpenSSLError("EVP_PKEY_derive_init") } - out := make([]byte, outLen) - if C.go_openssl_EVP_PKEY_derive(c.ctx, base(out), &outLen) != 1 { + out := make([]byte, r.keylen) + if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(out), r.keylen).result != 1 { return nil, newOpenSSLError("EVP_PKEY_derive") } - return out[:outLen], nil + return out[:r.keylen], nil } func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) { diff --git a/tls1prf.go b/tls1prf.go index 3153fc81..5de62f95 100644 --- a/tls1prf.go +++ b/tls1prf.go @@ -90,7 +90,7 @@ func TLS1PRF(result, secret, label, seed []byte, h func() hash.Hash) error { } } outLen := C.size_t(len(result)) - if C.go_openssl_EVP_PKEY_derive(ctx, base(result), &outLen) != 1 { + if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(result), outLen).result != 1 { return newOpenSSLError("EVP_PKEY_derive") } // The Go standard library expects TLS1PRF to return the requested number of bytes,