Skip to content

Commit

Permalink
Merge pull request #111 from corhere/fix-aes-gcm-102fips
Browse files Browse the repository at this point in the history
Fix AES-GCM decryption on OpenSSL 1.0.2-fips
  • Loading branch information
qmuntal authored Oct 19, 2023
2 parents 08f07a7 + 9b5d63e commit 61234a9
Show file tree
Hide file tree
Showing 14 changed files with 935 additions and 50 deletions.
65 changes: 46 additions & 19 deletions aes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,52 @@ func TestNewGCMNonce(t *testing.T) {
}

func TestSealAndOpen(t *testing.T) {
key := []byte("D249BF6DEC97B1EBD69BC4D6B3A3C49D")
ci, err := openssl.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
gcm, err := cipher.NewGCM(ci)
if err != nil {
t.Fatal(err)
}
nonce := []byte{0x91, 0xc7, 0xa7, 0x54, 0x52, 0xef, 0x10, 0xdb, 0x91, 0xa8, 0x6c, 0xf9}
plainText := []byte{0x01, 0x02, 0x03}
additionalData := []byte{0x05, 0x05, 0x07}
sealed := gcm.Seal(nil, nonce, plainText, additionalData)
decrypted, err := gcm.Open(nil, nonce, sealed, additionalData)
if err != nil {
t.Error(err)
}
if !bytes.Equal(decrypted, plainText) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText)
for _, tt := range aesGCMTests {
t.Run(tt.description, func(t *testing.T) {
ci, err := openssl.NewAESCipher(tt.key)
if err != nil {
t.Fatalf("NewAESCipher() err = %v", err)
}
gcm, err := cipher.NewGCM(ci)
if err != nil {
t.Fatalf("cipher.NewGCM() err = %v", err)
}

sealed := gcm.Seal(nil, tt.nonce, tt.plaintext, tt.aad)
if !bytes.Equal(sealed, tt.ciphertext) {
t.Errorf("unexpected sealed result\ngot: %#v\nexp: %#v", sealed, tt.ciphertext)
}

decrypted, err := gcm.Open(nil, tt.nonce, tt.ciphertext, tt.aad)
if err != nil {
t.Errorf("gcm.Open() err = %v", err)
}
if !bytes.Equal(decrypted, tt.plaintext) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, tt.plaintext)
}

// Test that open fails if the ciphertext is modified.
tt.ciphertext[0] ^= 0x80
_, err = gcm.Open(nil, tt.nonce, tt.ciphertext, tt.aad)
if err != openssl.ErrOpen {
t.Errorf("expected authentication error for tampered message\ngot: %#v", err)
}
tt.ciphertext[0] ^= 0x80

// Test that the ciphertext can be opened using a fresh context
// which was not previously used to seal the same message.
gcm, err = cipher.NewGCM(ci)
if err != nil {
t.Fatalf("cipher.NewGCM() err = %v", err)
}
decrypted, err = gcm.Open(nil, tt.nonce, tt.ciphertext, tt.aad)
if err != nil {
t.Errorf("fresh GCM instance: gcm.Open() err = %v", err)
}
if !bytes.Equal(decrypted, tt.plaintext) {
t.Errorf("fresh GCM instance: unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, tt.plaintext)
}
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ func newCipherCtx(kind cipherKind, mode cipherMode, encrypt cipherOp, key, iv []
cipher = nil
}
if C.go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, base(key), base(iv), C.int(encrypt)) != 1 {
return nil, fail("unable to initialize EVP cipher ctx")
return nil, newOpenSSLError("unable to initialize EVP cipher ctx")
}
return ctx, nil
}
Expand Down
116 changes: 116 additions & 0 deletions cmd/gentestvectors/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// gentestvectors emits cryptographic test vectors using the Go standard library
// cryptographic routines to test the OpenSSL bindings.
package main

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"flag"
"fmt"
"go/format"
"io"
"log"
"math/rand"
"os"
"path/filepath"
)

var outputPath = flag.String("out", "", "output path (default stdout)")

func init() {
log.SetFlags(log.Llongfile)
log.SetOutput(os.Stderr)
}

func main() {
flag.Parse()

var b bytes.Buffer
fmt.Fprint(&b, "// Code generated by cmd/gentestvectors. DO NOT EDIT.\n\n")
if *outputPath != "" {
fmt.Fprintf(&b, "//go"+":generate go run github.com/golang-fips/openssl/v2/cmd/gentestvectors -out %s\n\n", filepath.Base(*outputPath))
}

pkg := "openssl_test"
if gopackage := os.Getenv("GOPACKAGE"); gopackage != "" {
pkg = gopackage + "_test"
}
fmt.Fprintf(&b, "package %s\n\n", pkg)

aesGCM(&b)

generated, err := format.Source(b.Bytes())
if err != nil {
log.Fatalf("failed to format generated code: %v", err)
}

if *outputPath != "" {
err := os.WriteFile(*outputPath, generated, 0o644)
if err != nil {
log.Fatalf("failed to write output file: %v\n", err)
}
} else {
_, _ = os.Stdout.Write(generated)
}
}

func aesGCM(w io.Writer) {
r := rand.New(rand.NewSource(0))

fmt.Fprintln(w, `var aesGCMTests = []struct {
description string
key, nonce, plaintext, aad, ciphertext []byte
}{`)

for _, keyLen := range []int{16, 24, 32} {
for _, aadLen := range []int{0, 1, 3, 13, 30} {
for _, plaintextLen := range []int{0, 1, 3, 13, 16, 51} {
if aadLen == 0 && plaintextLen == 0 {
continue
}

key := randbytes(r, keyLen)
nonce := randbytes(r, 12)
plaintext := randbytes(r, plaintextLen)
aad := randbytes(r, aadLen)

c, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(c)
if err != nil {
panic(err)
}
ciphertext := aead.Seal(nil, nonce, plaintext, aad)

fmt.Fprint(w, "\t{\n")
fmt.Fprintf(w, "\t\tdescription: \"AES-%d/AAD=%d/Plaintext=%d\",\n", keyLen*8, aadLen, plaintextLen)
printBytesField(w, "key", key)
printBytesField(w, "nonce", nonce)
printBytesField(w, "plaintext", plaintext)
printBytesField(w, "aad", aad)
printBytesField(w, "ciphertext", ciphertext)
fmt.Fprint(w, "\t},\n")
}
}
}
fmt.Fprintln(w, "}")
}

func randbytes(r *rand.Rand, n int) []byte {
if n == 0 {
return nil
}
b := make([]byte, n)
r.Read(b)
return b
}

func printBytesField(w io.Writer, name string, b []byte) {
if len(b) == 0 {
return
}
fmt.Fprintf(w, "\t\t%s: %#v,\n", name, b)
}
4 changes: 2 additions & 2 deletions des.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (
// If CBC is also supported, then the returned cipher.Block
// will also implement NewCBCEncrypter and NewCBCDecrypter.
func SupportsDESCipher() bool {
// True for stock OpenSSL 1.
// True for stock OpenSSL 1 w/o FIPS.
// False for stock OpenSSL 3 unless the legacy provider is available.
return loadCipher(cipherDES, cipherModeECB) != nil
return (versionAtOrAbove(1, 1, 0) || !FIPS()) && loadCipher(cipherDES, cipherModeECB) != nil
}

// SupportsTripleDESCipher returns true if NewTripleDESCipher is supported,
Expand Down
2 changes: 1 addition & 1 deletion ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func SupportsEd25519() bool {
onceSupportsEd25519.Do(func() {
switch vMajor {
case 1:
supportsEd25519 = version1_1_1_or_above()
supportsEd25519 = versionAtOrAbove(1, 1, 1)
case 3:
name := C.CString("ED25519")
defer C.free(unsafe.Pointer(name))
Expand Down
16 changes: 10 additions & 6 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
}
switch ch {
case crypto.MD4:
return C.go_openssl_EVP_md4()
if versionAtOrAbove(1, 1, 0) || !FIPS() {
return C.go_openssl_EVP_md4()
}
case crypto.MD5:
return C.go_openssl_EVP_md5()
if versionAtOrAbove(1, 1, 0) || !FIPS() {
return C.go_openssl_EVP_md5()
}
case crypto.SHA1:
return C.go_openssl_EVP_sha1()
case crypto.SHA224:
Expand All @@ -86,19 +90,19 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
case crypto.SHA512:
return C.go_openssl_EVP_sha512()
case crypto.SHA3_224:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_224()
}
case crypto.SHA3_256:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_256()
}
case crypto.SHA3_384:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_384()
}
case crypto.SHA3_512:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_512()
}
}
Expand Down
4 changes: 2 additions & 2 deletions goopenssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ go_openssl_fips_enabled(void* handle)
// and assign them to their corresponding function pointer
// defined in goopenssl.h.
void
go_openssl_load_functions(void* handle, int major, int minor, int patch)
go_openssl_load_functions(void* handle, unsigned int major, unsigned int minor, unsigned int patch)
{
#define DEFINEFUNC_INTERNAL(name, func) \
_g_##name = dlsym(handle, func); \
if (_g_##name == NULL) { \
fprintf(stderr, "Cannot get required symbol " #func " from libcrypto version %d.%d\n", major, minor); \
fprintf(stderr, "Cannot get required symbol " #func " from libcrypto version %u.%u\n", major, minor); \
abort(); \
}
#define DEFINEFUNC(ret, func, args, argscall) \
Expand Down
28 changes: 23 additions & 5 deletions goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int go_openssl_version_major(void* handle);
int go_openssl_version_minor(void* handle);
int go_openssl_version_patch(void* handle);
int go_openssl_thread_setup(void);
void go_openssl_load_functions(void* handle, int major, int minor, int patch);
void go_openssl_load_functions(void* handle, unsigned int major, unsigned int minor, unsigned int patch);
const GO_EVP_MD_PTR go_openssl_EVP_md5_sha1_backport(void);

// Define pointers to all the used OpenSSL functions.
Expand Down Expand Up @@ -144,22 +144,40 @@ go_openssl_EVP_CIPHER_CTX_open_wrapper(const GO_EVP_CIPHER_CTX_PTR ctx,
const unsigned char *aad, int aad_len,
const unsigned char *tag)
{
if (in_len == 0) in = (const unsigned char *)"";
if (in_len == 0) {
in = (const unsigned char *)"";
// OpenSSL 1.0.2 in FIPS mode contains a bug: it will fail to verify
// unless EVP_DecryptUpdate is called at least once with a non-NULL
// output buffer. OpenSSL will not dereference the output buffer when
// the input length is zero, so set it to an arbitrary non-NULL pointer
// to satisfy OpenSSL when the caller only has authenticated additional
// data (AAD) to verify. While a stack-allocated buffer could be used,
// that would risk a stack-corrupting buffer overflow if OpenSSL
// unexpectedly dereferenced it. Instead pass a value which would
// segfault if dereferenced on any modern platform where a NULL-pointer
// dereference would also segfault.
if (out == NULL) out = (unsigned char *)1;
}
if (aad_len == 0) aad = (const unsigned char *)"";

if (go_openssl_EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, nonce) != 1)
return 0;

// OpenSSL 1.0.x FIPS Object Module 2.0 versions below 2.0.5 require that
// the tag be set before the ciphertext, otherwise EVP_DecryptUpdate returns
// an error. At least one extant commercially-supported, FIPS validated
// build of OpenSSL 1.0.2 uses FIPS module version 2.0.1. Set the tag first
// to maximize compatibility with all OpenSSL version combinations.
if (go_openssl_EVP_CIPHER_CTX_ctrl(ctx, GO_EVP_CTRL_GCM_SET_TAG, 16, (unsigned char *)(tag)) != 1)
return 0;

int discard_len, out_len;
if (go_openssl_EVP_DecryptUpdate(ctx, NULL, &discard_len, aad, aad_len) != 1
|| go_openssl_EVP_DecryptUpdate(ctx, out, &out_len, in, in_len) != 1)
{
return 0;
}

if (go_openssl_EVP_CIPHER_CTX_ctrl(ctx, GO_EVP_CTRL_GCM_SET_TAG, 16, (unsigned char *)(tag)) != 1)
return 0;

if (go_openssl_EVP_DecryptFinal_ex(ctx, out + out_len, &discard_len) != 1)
return 0;

Expand Down
2 changes: 1 addition & 1 deletion hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func SupportsHKDF() bool {
return version1_1_1_or_above()
return versionAtOrAbove(1, 1, 1)
}

func newHKDF(h func() hash.Hash, mode C.int) (*hkdf, error) {
Expand Down
13 changes: 7 additions & 6 deletions init.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
// as reported by the OpenSSL API.
//
// See Init() for details about file.
func opensslInit(file string) (major, minor, patch int, err error) {
func opensslInit(file string) (major, minor, patch uint, err error) {
// Load the OpenSSL shared library using dlopen.
handle, err := dlopen(file)
if err != nil {
Expand All @@ -24,12 +24,13 @@ func opensslInit(file string) (major, minor, patch int, err error) {
// Notice that major and minor could not match with the version parameter
// in case the name of the shared library file differs from the OpenSSL
// version it contains.
major = int(C.go_openssl_version_major(handle))
minor = int(C.go_openssl_version_minor(handle))
patch = int(C.go_openssl_version_patch(handle))
if major == -1 || minor == -1 || patch == -1 {
imajor := int(C.go_openssl_version_major(handle))
iminor := int(C.go_openssl_version_minor(handle))
ipatch := int(C.go_openssl_version_patch(handle))
if imajor < 0 || iminor < 0 || ipatch < 0 {
return 0, 0, 0, errors.New("openssl: can't retrieve OpenSSL version")
}
major, minor, patch = uint(imajor), uint(iminor), uint(ipatch)
var supported bool
if major == 1 {
supported = minor == 0 || minor == 1
Expand All @@ -43,7 +44,7 @@ func opensslInit(file string) (major, minor, patch int, err error) {

// Load the OpenSSL functions.
// See shims.go for the complete list of supported functions.
C.go_openssl_load_functions(handle, C.int(major), C.int(minor), C.int(patch))
C.go_openssl_load_functions(handle, C.uint(major), C.uint(minor), C.uint(patch))

// Initialize OpenSSL.
C.go_openssl_OPENSSL_init()
Expand Down
Loading

0 comments on commit 61234a9

Please sign in to comment.