diff --git a/rsa_test.go b/rsa_test.go index 5b92025e..a1df9970 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -257,33 +257,39 @@ func TestRSASignVerifyRSAPSS(t *testing.T) { // Test cases taken from // https://github.com/golang/go/blob/54182ff54a687272dd7632c3a963e036ce03cb7c/src/crypto/rsa/pss_test.go#L200. const keyBits = 2048 + sha256 := openssl.NewSHA256() var saltLengthCombinations = []struct { signSaltLength, verifySaltLength int - good bool + good, fipsGood bool }{ - {rsa.PSSSaltLengthAuto, rsa.PSSSaltLengthAuto, true}, - {rsa.PSSSaltLengthEqualsHash, rsa.PSSSaltLengthAuto, true}, - {rsa.PSSSaltLengthEqualsHash, rsa.PSSSaltLengthEqualsHash, true}, - {rsa.PSSSaltLengthEqualsHash, 8, false}, - {rsa.PSSSaltLengthAuto, rsa.PSSSaltLengthEqualsHash, false}, - {8, 8, true}, - {rsa.PSSSaltLengthAuto, keyBits/8 - 2 - 32, true}, // simulate Go PSSSaltLengthAuto algorithm (32 = sha256 size) - {rsa.PSSSaltLengthAuto, 20, false}, - {rsa.PSSSaltLengthAuto, -2, false}, + {rsa.PSSSaltLengthAuto, rsa.PSSSaltLengthAuto, true, true}, + {rsa.PSSSaltLengthEqualsHash, rsa.PSSSaltLengthAuto, true, true}, + {rsa.PSSSaltLengthEqualsHash, rsa.PSSSaltLengthEqualsHash, true, true}, + {rsa.PSSSaltLengthEqualsHash, 8, false, false}, + {rsa.PSSSaltLengthAuto, rsa.PSSSaltLengthEqualsHash, false, false}, + {8, 8, true, true}, + // In FIPS mode, PSSSaltLengthAuto is capped at PSSSaltLengthEqualsHash. + {rsa.PSSSaltLengthAuto, rsa.PSSSaltLengthEqualsHash, false, true}, + {rsa.PSSSaltLengthAuto, keyBits/8 - 2 - sha256.Size(), true, false}, // simulate Go PSSSaltLengthAuto algorithm + {rsa.PSSSaltLengthAuto, sha256.Size(), false, true}, + {rsa.PSSSaltLengthAuto, -2, false, false}, } - sha256 := openssl.NewSHA256() priv, pub := newRSAKey(t, keyBits) sha256.Write([]byte("testing")) hashed := sha256.Sum(nil) for i, test := range saltLengthCombinations { signed, err := openssl.SignRSAPSS(priv, crypto.SHA256, hashed, test.signSaltLength) if err != nil { - t.Errorf("#%d: error while signing: %s", i, err) + t.Errorf("#%d: error while signing: %v", i, err) continue } err = openssl.VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, test.verifySaltLength) - if (err == nil) != test.good { - t.Errorf("#%d: bad result, wanted: %t, got: %s", i, test.good, err) + good := test.good + if openssl.FIPS() { + good = test.fipsGood + } + if (err == nil) != good { + t.Errorf("#%d: bad result, wanted: %t, got: %v", i, good, err) } } }