diff --git a/cng/aes.go b/cng/aes.go index 654cb71..caac632 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -40,20 +40,27 @@ func (c *aesCipher) finalize() { func (c *aesCipher) BlockSize() int { return aesBlockSize } -func (c *aesCipher) Encrypt(dst, src []byte) { +// validateAndClipInputs checks that dst and src meet the [cipher.Block] +// interface requirements and clips them to a single block. +func (c *aesCipher) validateAndClipInputs(dst, src []byte) (d, s []byte) { if len(src) < aesBlockSize { panic("crypto/aes: input not full block") } if len(dst) < aesBlockSize { panic("crypto/aes: output not full block") } - - // cypher.Block.Encrypt() is documented to encrypt one full block - // at a time, so we truncate the input and output to the block size. - dst, src = dst[:aesBlockSize], src[:aesBlockSize] - if subtle.InexactOverlap(dst, src) { - panic("crypto/cipher: invalid buffer overlap") + // cypher.Block methods are documented to operate on + // one block at a time, so we truncate the input and output + // to the block size. + d, s = dst[:aesBlockSize], src[:aesBlockSize] + if subtle.InexactOverlap(d, s) { + panic("crypto/aes: invalid buffer overlap") } + return d, s +} + +func (c *aesCipher) Encrypt(dst, src []byte) { + dst, src = c.validateAndClipInputs(dst, src) var ret uint32 err := bcrypt.Encrypt(c.kh, src, nil, nil, dst, &ret, 0) @@ -67,19 +74,7 @@ func (c *aesCipher) Encrypt(dst, src []byte) { } func (c *aesCipher) Decrypt(dst, src []byte) { - if len(src) < aesBlockSize { - panic("crypto/aes: input not full block") - } - if len(dst) < aesBlockSize { - panic("crypto/aes: output not full block") - } - - // cypher.Block.Decrypt() is documented to decrypt one full block - // at a time, so we truncate the input and output to the block size. - dst, src = dst[:aesBlockSize], src[:aesBlockSize] - if subtle.InexactOverlap(dst, src) { - panic("crypto/cipher: invalid buffer overlap") - } + dst, src = c.validateAndClipInputs(dst, src) var ret uint32 err := bcrypt.Decrypt(c.kh, src, nil, nil, dst, &ret, 0)