Skip to content

Commit

Permalink
Merge pull request #64 from microsoft/dev/qmuntal/aesfix
Browse files Browse the repository at this point in the history
aes: encrypt and decrypt one block at a time
  • Loading branch information
qmuntal authored Sep 29, 2024
2 parents f9843f6 + 145cfd3 commit 3e2be6d
Show file tree
Hide file tree
Showing 7 changed files with 759 additions and 77 deletions.
30 changes: 17 additions & 13 deletions cng/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,28 @@ func (c *aesCipher) finalize() {

func (c *aesCipher) BlockSize() int { return aesBlockSize }

func (c *aesCipher) Encrypt(dst, src []byte) {
if subtle.InexactOverlap(dst, src) {
panic("crypto/cipher: invalid buffer overlap")
}
// validateAndClipInputs checks that dst and src meet the [cipher.Block]
// interface requirements and clips them to a single block.
func (c *aesCipher) validateAndClipInputs(dst, src []byte) (d, s []byte) {
if len(src) < aesBlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < aesBlockSize {
panic("crypto/aes: output not full block")
}
// cypher.Block methods are documented to operate on
// one block at a time, so we truncate the input and output
// to the block size.
d, s = dst[:aesBlockSize], src[:aesBlockSize]
if subtle.InexactOverlap(d, s) {
panic("crypto/aes: invalid buffer overlap")
}
return d, s
}

func (c *aesCipher) Encrypt(dst, src []byte) {
dst, src = c.validateAndClipInputs(dst, src)

var ret uint32
err := bcrypt.Encrypt(c.kh, src, nil, nil, dst, &ret, 0)
if err != nil {
Expand All @@ -62,15 +74,7 @@ func (c *aesCipher) Encrypt(dst, src []byte) {
}

func (c *aesCipher) Decrypt(dst, src []byte) {
if subtle.InexactOverlap(dst, src) {
panic("crypto/cipher: invalid buffer overlap")
}
if len(src) < aesBlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < aesBlockSize {
panic("crypto/aes: output not full block")
}
dst, src = c.validateAndClipInputs(dst, src)

var ret uint32
err := bcrypt.Decrypt(c.kh, src, nil, nil, dst, &ret, 0)
Expand Down
148 changes: 101 additions & 47 deletions cng/aes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,34 @@
//go:build windows
// +build windows

package cng
package cng_test

import (
"bytes"
"crypto/cipher"
"fmt"
"strings"
"testing"

"github.com/microsoft/go-crypto-winnative/cng"
"github.com/microsoft/go-crypto-winnative/internal/cryptotest"
)

var key = []byte("D249BF6DEC97B1EBD69BC4D6B3A3C49D")

const (
gcmTagSize = 16
gcmStandardNonceSize = 12
)

func TestNewGCMNonce(t *testing.T) {
ci, err := NewAESCipher(key)
ci, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
c := ci.(*aesCipher)
c := ci.(interface {
NewGCM(nonceSize, tagSize int) (cipher.AEAD, error)
})
_, err = c.NewGCM(gcmStandardNonceSize-1, gcmTagSize-1)
if err == nil {
t.Error("expected error for non-standard tag and nonce size at the same time, got none")
Expand All @@ -39,12 +51,11 @@ func TestNewGCMNonce(t *testing.T) {
}

func TestSealAndOpen(t *testing.T) {
ci, err := NewAESCipher(key)
ci, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
c := ci.(*aesCipher)
gcm, err := c.NewGCM(gcmStandardNonceSize, gcmTagSize)
gcm, err := cipher.NewGCMWithTagSize(ci, gcmTagSize)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -89,16 +100,16 @@ func TestSealAndOpenTLS(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ci, err := NewAESCipher(key)
ci, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
var gcm cipher.AEAD
switch tt.tls {
case "1.2":
gcm, err = NewGCMTLS(ci)
gcm, err = cng.NewGCMTLS(ci)
case "1.3":
gcm, err = NewGCMTLS13(ci)
gcm, err = cng.NewGCMTLS13(ci)
}
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -157,12 +168,11 @@ func TestSealAndOpenTLS(t *testing.T) {
}

func TestSealAndOpenAuthenticationError(t *testing.T) {
ci, err := NewAESCipher(key)
ci, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
c := ci.(*aesCipher)
gcm, err := c.NewGCM(gcmStandardNonceSize, gcmTagSize)
gcm, err := cipher.NewGCMWithTagSize(ci, gcmTagSize)
if err != nil {
t.Fatal(err)
}
Expand All @@ -171,7 +181,7 @@ func TestSealAndOpenAuthenticationError(t *testing.T) {
additionalData := []byte{0x05, 0x05, 0x07}
sealed := gcm.Seal(nil, nonce, plainText, additionalData)
_, err = gcm.Open(nil, nonce, sealed, nil)
if err != errOpen {
if !strings.Contains(err.Error(), "cipher: message authentication failed") {
t.Errorf("expected authentication error, got: %#v", err)
}
}
Expand All @@ -187,12 +197,11 @@ func assertPanic(t *testing.T, f func()) {
}

func TestSealPanic(t *testing.T) {
ci, err := NewAESCipher(key)
ci, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
c := ci.(*aesCipher)
gcm, err := c.NewGCM(gcmStandardNonceSize, gcmTagSize)
gcm, err := cipher.NewGCMWithTagSize(ci, gcmTagSize)
if err != nil {
t.Fatal(err)
}
Expand All @@ -209,14 +218,14 @@ func TestSealPanic(t *testing.T) {
}

func TestAESInvalidKeySize(t *testing.T) {
_, err := NewAESCipher([]byte{1})
_, err := cng.NewAESCipher([]byte{1})
if err == nil {
t.Error("error expected")
}
}

func TestEncryptAndDecrypt(t *testing.T) {
ci, err := NewAESCipher(key)
ci, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
Expand All @@ -236,7 +245,7 @@ func TestCBCBlobEncryptBasicBlockEncryption(t *testing.T) {
key := []byte{0x24, 0xcd, 0x8b, 0x13, 0x37, 0xc5, 0xc1, 0xb1, 0x0, 0xbb, 0x27, 0x40, 0x4f, 0xab, 0x5f, 0x7b, 0x2d, 0x0, 0x20, 0xf5, 0x1, 0x84, 0x4, 0xbf, 0xe3, 0xbd, 0xa1, 0xc4, 0xbf, 0x61, 0x2f, 0xc5}
iv := []byte{0x91, 0xc7, 0xa7, 0x54, 0x52, 0xef, 0x10, 0xdb, 0x91, 0xa8, 0x6c, 0xf9, 0x79, 0xd5, 0xac, 0x74}

block, err := NewAESCipher(key)
block, err := cng.NewAESCipher(key)
if err != nil {
t.Errorf("expected no error for aes.NewCipher, got: %s", err)
}
Expand All @@ -245,19 +254,14 @@ func TestCBCBlobEncryptBasicBlockEncryption(t *testing.T) {
if blockSize != 16 {
t.Errorf("unexpected block size, expected 16 got: %d", blockSize)
}
var encryptor cipher.BlockMode
if c, ok := block.(*aesCipher); ok {
encryptor = c.NewCBCEncrypter(iv)
if encryptor == nil {
t.Error("unable to create new CBC encrypter")
}
}

encrypter := cipher.NewCBCEncrypter(block, iv)

encrypted := make([]byte, 32)

// First block. 16 bytes.
srcBlock1 := bytes.Repeat([]byte{0x01}, 16)
encryptor.CryptBlocks(encrypted, srcBlock1)
encrypter.CryptBlocks(encrypted, srcBlock1)
if !bytes.Equal([]byte{
0x14, 0xb7, 0x3e, 0x2f, 0xd9, 0xe7, 0x69, 0x7e, 0xb7, 0xd2, 0xc3, 0x5b, 0x31, 0x9c, 0xf0, 0x59,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
Expand All @@ -267,21 +271,15 @@ func TestCBCBlobEncryptBasicBlockEncryption(t *testing.T) {

// Second block. 16 bytes.
srcBlock2 := bytes.Repeat([]byte{0x02}, 16)
encryptor.CryptBlocks(encrypted[16:], srcBlock2)
encrypter.CryptBlocks(encrypted[16:], srcBlock2)
if !bytes.Equal([]byte{
0x14, 0xb7, 0x3e, 0x2f, 0xd9, 0xe7, 0x69, 0x7e, 0xb7, 0xd2, 0xc3, 0x5b, 0x31, 0x9c, 0xf0, 0x59,
0xbb, 0xd4, 0x95, 0x25, 0x21, 0x56, 0x87, 0x3b, 0xe6, 0x22, 0xe8, 0xd0, 0x19, 0xa8, 0xed, 0xcd,
}, encrypted) {
t.Error("unexpected CryptBlocks result for second block")
}

var decrypter cipher.BlockMode
if c, ok := block.(*aesCipher); ok {
decrypter = c.NewCBCDecrypter(iv)
if decrypter == nil {
t.Error("unable to create new CBC decrypter")
}
}
decrypter := cipher.NewCBCDecrypter(block, iv)
plainText := append(srcBlock1, srcBlock2...)
decrypted := make([]byte, len(plainText))
decrypter.CryptBlocks(decrypted, encrypted[:16])
Expand All @@ -299,7 +297,7 @@ func TestCBCDecryptSimple(t *testing.T) {
0xe3, 0xbd, 0xa1, 0xc4, 0xbf, 0x61, 0x2f, 0xc5,
}

block, err := NewAESCipher(key)
block, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
Expand All @@ -308,17 +306,9 @@ func TestCBCDecryptSimple(t *testing.T) {
0x91, 0xc7, 0xa7, 0x54, 0x52, 0xef, 0x10, 0xdb,
0x91, 0xa8, 0x6c, 0xf9, 0x79, 0xd5, 0xac, 0x74,
}
var encrypter, decrypter cipher.BlockMode
if c, ok := block.(*aesCipher); ok {
encrypter = c.NewCBCEncrypter(iv)
if encrypter == nil {
t.Error("unable to create new CBC encrypter")
}
decrypter = c.NewCBCDecrypter(iv)
if decrypter == nil {
t.Error("unable to create new CBC decrypter")
}
}

encrypter := cipher.NewCBCEncrypter(block, iv)
decrypter := cipher.NewCBCDecrypter(block, iv)

plainText := []byte{
0x54, 0x68, 0x65, 0x72, 0x65, 0x20, 0x69, 0x73,
Expand Down Expand Up @@ -375,3 +365,67 @@ func TestCBCDecryptSimple(t *testing.T) {
t.Errorf("decryption incorrect\nexp %v, got %v\n", plainText, decrypted)
}
}

// Test AES against the general cipher.Block interface tester.
func TestAESBlock(t *testing.T) {
for _, keylen := range []int{128, 192, 256} {
t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) {
cryptotest.TestBlock(t, keylen/8, cng.NewAESCipher)
})
}
}

func TestAESBlockMode(t *testing.T) {
for _, keylen := range []int{128, 192, 256} {
t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) {
rng := newRandReader(t)

key := make([]byte, keylen/8)
rng.Read(key)

block, err := cng.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}

cryptotest.TestBlockMode(t, block, cipher.NewCBCEncrypter, cipher.NewCBCDecrypter)
})
}
}

// Test GCM against the general cipher.AEAD interface tester.
func TestAESGCMAEAD(t *testing.T) {
minTagSize := 12

for _, keySize := range []int{128, 192, 256} {
// Use AES as underlying block cipher at different key sizes for GCM.
t.Run(fmt.Sprintf("AES-%d", keySize), func(t *testing.T) {
rng := newRandReader(t)

key := make([]byte, keySize/8)
rng.Read(key)

block, err := cng.NewAESCipher(key)
if err != nil {
panic(err)
}

// Test GCM with the current AES block with the standard nonce and tag
// sizes.
cryptotest.TestAEAD(t, func() (cipher.AEAD, error) { return cipher.NewGCM(block) })

// Test non-standard tag sizes.
t.Run("MinTagSize", func(t *testing.T) {
cryptotest.TestAEAD(t, func() (cipher.AEAD, error) { return cipher.NewGCMWithTagSize(block, minTagSize) })
})

// Test non-standard nonce sizes.
for _, nonceSize := range []int{1, 16, 100} {
t.Run(fmt.Sprintf("NonceSize-%d", nonceSize), func(t *testing.T) {

cryptotest.TestAEAD(t, func() (cipher.AEAD, error) { return cipher.NewGCMWithNonceSize(block, nonceSize) })
})
}
})
}
}
9 changes: 9 additions & 0 deletions cng/cng_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ package cng_test

import (
"fmt"
"io"
"math/rand"
"os"
"strconv"
"testing"
"time"

"github.com/microsoft/go-crypto-winnative/cng"
)
Expand Down Expand Up @@ -39,3 +42,9 @@ func TestFIPS(t *testing.T) {
}
}
}

func newRandReader(t *testing.T) io.Reader {
seed := time.Now().UnixNano()
t.Logf("Deterministic RNG seed: 0x%x", seed)
return rand.New(rand.NewSource(seed))
}
Loading

0 comments on commit 3e2be6d

Please sign in to comment.