Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap multi-return APIs using structs: avoid heap escape #131

Merged
merged 2 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,12 @@ func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
if C.go_openssl_EVP_PKEY_derive_set_peer(ctx, pub._pkey) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_set_peer")
}
var outLen C.size_t
if C.go_openssl_EVP_PKEY_derive(ctx, nil, &outLen) != 1 {
r := C.go_openssl_EVP_PKEY_derive_wrapper(ctx, nil, 0)
if r.result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
out := make([]byte, outLen)
if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &outLen) != 1 {
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), r.keylen).result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
return out, nil
Expand Down
12 changes: 6 additions & 6 deletions ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func NewPrivateKeyEd25519FromSeed(seed []byte) (*PrivateKeyEd25519, error) {

func extractPKEYPubEd25519(pkey C.GO_EVP_PKEY_PTR, pub []byte) error {
pubSize := C.size_t(publicKeySizeEd25519)
if C.go_openssl_EVP_PKEY_get_raw_public_key(pkey, base(pub), &pubSize) != 1 {
if C.go_openssl_EVP_PKEY_get_raw_public_key_wrapper(pkey, base(pub), pubSize) != 1 {
return newOpenSSLError("EVP_PKEY_get_raw_public_key")
}
if pubSize != publicKeySizeEd25519 {
Expand All @@ -160,7 +160,7 @@ func extractPKEYPrivEd25519(pkey C.GO_EVP_PKEY_PTR, priv []byte) error {
return err
}
privSize := C.size_t(seedSizeEd25519)
if C.go_openssl_EVP_PKEY_get_raw_private_key(pkey, base(priv), &privSize) != 1 {
if C.go_openssl_EVP_PKEY_get_raw_private_key_wrapper(pkey, base(priv), privSize) != 1 {
return newOpenSSLError("EVP_PKEY_get_raw_private_key")
}
if privSize != seedSizeEd25519 {
Expand Down Expand Up @@ -190,12 +190,12 @@ func signEd25519(priv *PrivateKeyEd25519, sig, message []byte) error {
if C.go_openssl_EVP_DigestSignInit(ctx, nil, nil, nil, priv._pkey) != 1 {
return newOpenSSLError("EVP_DigestSignInit")
}
siglen := C.size_t(signatureSizeEd25519)
if C.go_openssl_EVP_DigestSign(ctx, base(sig), &siglen, base(message), C.size_t(len(message))) != 1 {
r := C.go_openssl_EVP_DigestSign_wrapper(ctx, base(sig), C.size_t(signatureSizeEd25519), base(message), C.size_t(len(message)))
if r.result != 1 {
return newOpenSSLError("EVP_DigestSign")
}
if siglen != signatureSizeEd25519 {
return errors.New("ed25519: bad signature length: " + strconv.Itoa(int(siglen)))
if r.siglen != signatureSizeEd25519 {
return errors.New("ed25519: bad signature length: " + strconv.Itoa(int(r.siglen)))
}
return nil
}
Expand Down
45 changes: 44 additions & 1 deletion goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ go_hash_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out)
return go_openssl_EVP_DigestFinal(ctx2, out, NULL);
}

// These wrappers allocate out_len on the C stack to avoid having to pass a pointer from Go, which would escape to the heap.
// These wrappers allocate length variables on the C stack to avoid having to pass a pointer from Go, which would escape to the heap.
// Use them only in situations where the output length can be safely discarded.

static inline int
go_openssl_EVP_EncryptUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, unsigned char *out, const unsigned char *in, int in_len)
{
Expand All @@ -105,6 +106,48 @@ go_openssl_EVP_CipherUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, unsigned char *ou
return go_openssl_EVP_CipherUpdate(ctx, out, &len, in, in_len);
}

static inline int
go_openssl_EVP_PKEY_get_raw_public_key_wrapper(const GO_EVP_PKEY_PTR pkey, unsigned char *pub, size_t len)
dagood marked this conversation as resolved.
Show resolved Hide resolved
{
return go_openssl_EVP_PKEY_get_raw_public_key(pkey, pub, &len);
}

static inline int
go_openssl_EVP_PKEY_get_raw_private_key_wrapper(const GO_EVP_PKEY_PTR pkey, unsigned char *priv, size_t len)
{
return go_openssl_EVP_PKEY_get_raw_private_key(pkey, priv, &len);
}

// These wrappers also allocate length variables on the C stack to avoid escape to the heap, but do return the result.
// A struct is returned that contains multiple return values instead of OpenSSL's approach of using pointers.

typedef struct
{
int result;
size_t keylen;
} go_openssl_EVP_PKEY_derive_wrapper_out;

static inline go_openssl_EVP_PKEY_derive_wrapper_out
go_openssl_EVP_PKEY_derive_wrapper(GO_EVP_PKEY_CTX_PTR ctx, unsigned char *key, size_t keylen)
{
go_openssl_EVP_PKEY_derive_wrapper_out r = {0, keylen};
r.result = go_openssl_EVP_PKEY_derive(ctx, key, &r.keylen);
return r;
}

typedef struct
{
int result;
size_t siglen;
} go_openssl_EVP_DigestSign_wrapper_out;

static inline go_openssl_EVP_DigestSign_wrapper_out
go_openssl_EVP_DigestSign_wrapper(GO_EVP_MD_CTX_PTR ctx, unsigned char *sigret, size_t siglen, const unsigned char *tbs, size_t tbslen)
{
go_openssl_EVP_DigestSign_wrapper_out r = {0, siglen};
r.result = go_openssl_EVP_DigestSign(ctx, sigret, &r.siglen, tbs, tbslen);
return r;
}

// These wrappers allocate out_len on the C stack, and check that it matches the expected
// value, to avoid having to pass a pointer from Go, which would escape to the heap.
Expand Down
12 changes: 6 additions & 6 deletions hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (c *hkdf) Read(p []byte) (int, error) {
}
c.buf = append(c.buf, make([]byte, needLen)...)
outLen := C.size_t(prevLen + needLen)
if C.go_openssl_EVP_PKEY_derive(c.ctx, base(c.buf), &outLen) != 1 {
if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(c.buf), outLen).result != 1 {
return 0, newOpenSSLError("EVP_PKEY_derive")
}
n := copy(p, c.buf[prevLen:outLen])
Expand Down Expand Up @@ -132,15 +132,15 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
}
}
var outLen C.size_t
if C.go_openssl_EVP_PKEY_derive(c.ctx, nil, &outLen) != 1 {
r := C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, nil, 0)
if r.result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
out := make([]byte, outLen)
if C.go_openssl_EVP_PKEY_derive(c.ctx, base(out), &outLen) != 1 {
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(out), r.keylen).result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
return out[:outLen], nil
return out[:r.keylen], nil
}

func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) {
Expand Down
2 changes: 1 addition & 1 deletion tls1prf.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TLS1PRF(result, secret, label, seed []byte, h func() hash.Hash) error {
}
}
outLen := C.size_t(len(result))
if C.go_openssl_EVP_PKEY_derive(ctx, base(result), &outLen) != 1 {
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(result), outLen).result != 1 {
return newOpenSSLError("EVP_PKEY_derive")
}
// The Go standard library expects TLS1PRF to return the requested number of bytes,
Expand Down