From d6c123d678204ca3b3013c499527dfde869f11dd Mon Sep 17 00:00:00 2001 From: Daiki Ueno Date: Wed, 20 Sep 2023 21:02:52 +0900 Subject: [PATCH] Read signed message through io.Reader in HashSign/HashVerify When signed message is large, passing it through a byte array would cost memory; this switches to using io.Reader to avoid that. Using `io.Reader` in the function signatures could also help diversify those HashSign/HashVerify from Sign/Verify, through a separate interface, alongside the `crypto.Signer` interface. Signed-off-by: Daiki Ueno --- ecdsa.go | 5 +++-- ecdsa_test.go | 7 ++++--- evp.go | 35 +++++++++++++++++++++++++++++------ rsa.go | 5 +++-- rsa_test.go | 4 ++-- shims.h | 1 + 6 files changed, 42 insertions(+), 15 deletions(-) diff --git a/ecdsa.go b/ecdsa.go index 46b16abf..d35277bd 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -7,6 +7,7 @@ import "C" import ( "crypto" "errors" + "io" "runtime" ) @@ -109,7 +110,7 @@ func SignMarshalECDSA(priv *PrivateKeyECDSA, hash []byte) ([]byte, error) { return evpSign(priv.withKey, 0, 0, 0, hash) } -func HashSignECDSA(priv *PrivateKeyECDSA, h crypto.Hash, msg []byte) ([]byte, error) { +func HashSignECDSA(priv *PrivateKeyECDSA, h crypto.Hash, msg io.Reader) ([]byte, error) { return evpHashSign(priv.withKey, h, msg) } @@ -117,7 +118,7 @@ func VerifyECDSA(pub *PublicKeyECDSA, hash []byte, sig []byte) bool { return evpVerify(pub.withKey, 0, 0, 0, sig, hash) == nil } -func HashVerifyECDSA(pub *PublicKeyECDSA, h crypto.Hash, msg, sig []byte) bool { +func HashVerifyECDSA(pub *PublicKeyECDSA, h crypto.Hash, msg io.Reader, sig []byte) bool { return evpHashVerify(pub.withKey, h, msg, sig) == nil } diff --git a/ecdsa_test.go b/ecdsa_test.go index d67b8d01..ed9f308b 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -1,6 +1,7 @@ package openssl_test import ( + "bytes" "crypto" "crypto/ecdsa" "crypto/elliptic" @@ -74,15 +75,15 @@ func testECDSASignAndVerify(t *testing.T, c elliptic.Curve) { if openssl.VerifyECDSA(pub, hashed[:], signed) { t.Errorf("Verify succeeded despite intentionally invalid hash!") } - signed, err = openssl.HashSignECDSA(priv, crypto.SHA256, msg) + signed, err = openssl.HashSignECDSA(priv, crypto.SHA256, bytes.NewReader(msg)) if err != nil { t.Fatal(err) } - if !openssl.HashVerifyECDSA(pub, crypto.SHA256, msg, signed) { + if !openssl.HashVerifyECDSA(pub, crypto.SHA256, bytes.NewReader(msg), signed) { t.Errorf("Verify failed") } signed[0] ^= 0xff - if openssl.HashVerifyECDSA(pub, crypto.SHA256, msg, signed) { + if openssl.HashVerifyECDSA(pub, crypto.SHA256, bytes.NewReader(msg), signed) { t.Errorf("Verify failed") } } diff --git a/evp.go b/evp.go index c7f53e3e..95c7f169 100644 --- a/evp.go +++ b/evp.go @@ -8,6 +8,7 @@ import ( "crypto" "errors" "hash" + "io" "strconv" "sync" "unsafe" @@ -371,7 +372,7 @@ func evpVerify(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, return verifyEVP(withKey, padding, nil, nil, saltLen, h, verifyInit, verify, sig, hashed) } -func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) { +func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg io.Reader) ([]byte, error) { md := cryptoHashToMD(h) if md == nil { return nil, errors.New("unsupported hash function: " + strconv.Itoa(int(h))) @@ -388,8 +389,19 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) }) != 1 { return nil, newOpenSSLError("EVP_DigestSignInit failed") } - if C.go_openssl_EVP_DigestUpdate(ctx, unsafe.Pointer(base(msg)), C.size_t(len(msg))) != 1 { - return nil, newOpenSSLError("EVP_DigestUpdate failed") + var blockLen = C.go_openssl_EVP_MD_get_block_size(md) + var block = make([]byte, blockLen) + for { + n, err := msg.Read(block) + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + if C.go_openssl_EVP_DigestUpdate(ctx, unsafe.Pointer(base(block)), C.size_t(n)) != 1 { + return nil, newOpenSSLError("EVP_DigestUpdate failed") + } } // Obtain the signature length if C.go_openssl_EVP_DigestSignFinal(ctx, nil, &outLen) != 1 { @@ -403,7 +415,7 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) return out[:outLen], nil } -func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error { +func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg io.Reader, sig []byte) error { md := cryptoHashToMD(h) if md == nil { return errors.New("unsupported hash function: " + strconv.Itoa(int(h))) @@ -418,8 +430,19 @@ func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error { }) != 1 { return newOpenSSLError("EVP_DigestVerifyInit failed") } - if C.go_openssl_EVP_DigestUpdate(ctx, unsafe.Pointer(base(msg)), C.size_t(len(msg))) != 1 { - return newOpenSSLError("EVP_DigestUpdate failed") + var blockLen = C.go_openssl_EVP_MD_get_block_size(md) + var block = make([]byte, blockLen) + for { + n, err := msg.Read(block) + if err == io.EOF { + break + } + if err != nil { + return err + } + if C.go_openssl_EVP_DigestUpdate(ctx, unsafe.Pointer(base(block)), C.size_t(n)) != 1 { + return newOpenSSLError("EVP_DigestUpdate failed") + } } if C.go_openssl_EVP_DigestVerifyFinal(ctx, base(sig), C.size_t(len(sig))) != 1 { return newOpenSSLError("EVP_DigestVerifyFinal failed") diff --git a/rsa.go b/rsa.go index 5aef65b8..7a0415b4 100644 --- a/rsa.go +++ b/rsa.go @@ -9,6 +9,7 @@ import ( "crypto/subtle" "errors" "hash" + "io" "runtime" "unsafe" ) @@ -289,7 +290,7 @@ func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, return evpSign(priv.withKey, C.GO_RSA_PKCS1_PADDING, 0, h, hashed) } -func HashSignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, msg []byte) ([]byte, error) { +func HashSignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, msg io.Reader) ([]byte, error) { return evpHashSign(priv.withKey, h, msg) } @@ -306,7 +307,7 @@ func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) err return evpVerify(pub.withKey, C.GO_RSA_PKCS1_PADDING, 0, h, sig, hashed) } -func HashVerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, msg, sig []byte) error { +func HashVerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, msg io.Reader, sig []byte) error { return evpHashVerify(pub.withKey, h, msg, sig) } diff --git a/rsa_test.go b/rsa_test.go index c926e9cf..b5b45603 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -130,7 +130,7 @@ func TestSignVerifyPKCS1v15(t *testing.T) { if err != nil { t.Fatal(err) } - signed2, err := openssl.HashSignRSAPKCS1v15(priv, crypto.SHA256, msg) + signed2, err := openssl.HashSignRSAPKCS1v15(priv, crypto.SHA256, bytes.NewReader(msg)) if err != nil { t.Fatal(err) } @@ -141,7 +141,7 @@ func TestSignVerifyPKCS1v15(t *testing.T) { if err != nil { t.Fatal(err) } - err = openssl.HashVerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed2) + err = openssl.HashVerifyRSAPKCS1v15(pub, crypto.SHA256, bytes.NewReader(msg), signed2) if err != nil { t.Fatal(err) } diff --git a/shims.h b/shims.h index 858c47e7..7e26d16d 100644 --- a/shims.h +++ b/shims.h @@ -355,4 +355,5 @@ DEFINEFUNC(int, PKCS5_PBKDF2_HMAC, (const char *pass, int passlen, const unsigne DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set_tls1_prf_md, (GO_EVP_PKEY_CTX_PTR arg0, const GO_EVP_MD_PTR arg1), (arg0, arg1)) \ DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set1_tls1_prf_secret, (GO_EVP_PKEY_CTX_PTR arg0, const unsigned char *arg1, int arg2), (arg0, arg1, arg2)) \ DEFINEFUNC_3_0(int, EVP_PKEY_CTX_add1_tls1_prf_seed, (GO_EVP_PKEY_CTX_PTR arg0, const unsigned char *arg1, int arg2), (arg0, arg1, arg2)) \ +DEFINEFUNC_RENAMED_3_0(int, EVP_MD_get_block_size, EVP_MD_block_size, (const GO_EVP_MD_PTR md), (md)) \