diff --git a/aes.go b/aes.go index ecda35a9..c097ef8f 100644 --- a/aes.go +++ b/aes.go @@ -56,11 +56,17 @@ func (c *aesCipher) BlockSize() int { } func (c *aesCipher) Encrypt(dst, src []byte) { - c.encrypt(dst, src) + if err := c.encrypt(dst, src); err != nil { + // Upstream expects that the panic message starts with "crypto/aes: ". + panic("crypto/aes: " + err.Error()) + } } func (c *aesCipher) Decrypt(dst, src []byte) { - c.decrypt(dst, src) + if err := c.decrypt(dst, src); err != nil { + // Upstream expects that the panic message starts with "crypto/aes: ". + panic("crypto/aes: " + err.Error()) + } } func (c *aesCipher) NewCBCEncrypter(iv []byte) cipher.BlockMode { diff --git a/aes_test.go b/aes_test.go index 14c798cd..39cc1600 100644 --- a/aes_test.go +++ b/aes_test.go @@ -9,6 +9,31 @@ import ( "github.com/golang-fips/openssl/v2" ) +func TestAESShortBlocks(t *testing.T) { + bytes := func(n int) []byte { return make([]byte, n) } + + c, _ := openssl.NewAESCipher(bytes(16)) + + mustPanic(t, "crypto/aes: input not full block", func() { c.Encrypt(bytes(1), bytes(1)) }) + mustPanic(t, "crypto/aes: input not full block", func() { c.Decrypt(bytes(1), bytes(1)) }) + mustPanic(t, "crypto/aes: input not full block", func() { c.Encrypt(bytes(100), bytes(1)) }) + mustPanic(t, "crypto/aes: input not full block", func() { c.Decrypt(bytes(100), bytes(1)) }) + mustPanic(t, "crypto/aes: output not full block", func() { c.Encrypt(bytes(1), bytes(100)) }) + mustPanic(t, "crypto/aes: output not full block", func() { c.Decrypt(bytes(1), bytes(100)) }) +} + +func mustPanic(t *testing.T, msg string, f func()) { + defer func() { + err := recover() + if err == nil { + t.Errorf("function did not panic, wanted %q", msg) + } else if err != msg { + t.Errorf("got panic %v, wanted %q", err, msg) + } + }() + f() +} + func TestNewGCMNonce(t *testing.T) { key := []byte("D249BF6DEC97B1EBD69BC4D6B3A3C49D") ci, err := openssl.NewAESCipher(key) diff --git a/cipher.go b/cipher.go index df6c40f1..c8828690 100644 --- a/cipher.go +++ b/cipher.go @@ -166,57 +166,59 @@ func (c *evpCipher) finalize() { } } -func (c *evpCipher) encrypt(dst, src []byte) { +func (c *evpCipher) encrypt(dst, src []byte) error { if len(src) < c.blockSize { - panic("crypto/cipher: input not full block") + return errors.New("input not full block") } if len(dst) < c.blockSize { - panic("crypto/cipher: output not full block") + return errors.New("output not full block") } // Only check for overlap between the parts of src and dst that will actually be used. // This matches Go standard library behavior. if inexactOverlap(dst[:c.blockSize], src[:c.blockSize]) { - panic("crypto/cipher: invalid buffer overlap") + return errors.New("invalid buffer overlap") } if c.enc_ctx == nil { var err error c.enc_ctx, err = newCipherCtx(c.kind, cipherModeECB, cipherOpEncrypt, c.key, nil) if err != nil { - panic(err) + return err } } if C.go_openssl_EVP_EncryptUpdate_wrapper(c.enc_ctx, base(dst), base(src), C.int(c.blockSize)) != 1 { - panic("crypto/cipher: EncryptUpdate failed") + return errors.New("EncryptUpdate failed") } runtime.KeepAlive(c) + return nil } -func (c *evpCipher) decrypt(dst, src []byte) { +func (c *evpCipher) decrypt(dst, src []byte) error { if len(src) < c.blockSize { - panic("crypto/cipher: input not full block") + return errors.New("input not full block") } if len(dst) < c.blockSize { - panic("crypto/cipher: output not full block") + return errors.New("output not full block") } // Only check for overlap between the parts of src and dst that will actually be used. // This matches Go standard library behavior. if inexactOverlap(dst[:c.blockSize], src[:c.blockSize]) { - panic("crypto/cipher: invalid buffer overlap") + return errors.New("invalid buffer overlap") } if c.dec_ctx == nil { var err error c.dec_ctx, err = newCipherCtx(c.kind, cipherModeECB, cipherOpDecrypt, c.key, nil) if err != nil { - panic(err) + return err } if C.go_openssl_EVP_CIPHER_CTX_set_padding(c.dec_ctx, 0) != 1 { - panic("crypto/cipher: could not disable cipher padding") + return errors.New("could not disable cipher padding") } } C.go_openssl_EVP_DecryptUpdate_wrapper(c.dec_ctx, base(dst), base(src), C.int(c.blockSize)) runtime.KeepAlive(c) + return nil } type cipherCBC struct {