From d55a0aacd1ce94c093c3f16adb27db709e45d586 Mon Sep 17 00:00:00 2001 From: armfazh Date: Tue, 8 Oct 2024 18:06:12 -0700 Subject: [PATCH] Avoid to export unsafeSignInternal. --- sign/dilithium/gen.go | 35 ++++ sign/dilithium/templates/acvp.templ.go | 266 +++++++++++++++++++++++++ sign/dilithium/templates/pkg.templ.go | 6 +- sign/mldsa/{ => mldsa44}/acvp_test.go | 32 ++- sign/mldsa/mldsa44/dilithium.go | 2 +- sign/mldsa/mldsa65/acvp_test.go | 262 ++++++++++++++++++++++++ sign/mldsa/mldsa65/dilithium.go | 2 +- sign/mldsa/mldsa87/acvp_test.go | 262 ++++++++++++++++++++++++ sign/mldsa/mldsa87/dilithium.go | 2 +- 9 files changed, 845 insertions(+), 24 deletions(-) create mode 100644 sign/dilithium/templates/acvp.templ.go rename sign/mldsa/{ => mldsa44}/acvp_test.go (88%) create mode 100644 sign/mldsa/mldsa65/acvp_test.go create mode 100644 sign/mldsa/mldsa87/acvp_test.go diff --git a/sign/dilithium/gen.go b/sign/dilithium/gen.go index 6b136ec07..a817f4a41 100644 --- a/sign/dilithium/gen.go +++ b/sign/dilithium/gen.go @@ -145,6 +145,7 @@ var ( func main() { generateModePackageFiles() + generateACVPTest() generateParamsFiles() generateSourceFiles() } @@ -212,6 +213,40 @@ func generateModePackageFiles() { } } +// Generates modeX/dilithium.go from templates/pkg.templ.go +func generateACVPTest() { + tl, err := template.ParseFiles("templates/acvp.templ.go") + if err != nil { + panic(err) + } + + for _, mode := range Modes { + if !strings.HasPrefix(mode.Name, "ML-DSA") { + continue + } + + buf := new(bytes.Buffer) + err := tl.Execute(buf, mode) + if err != nil { + panic(err) + } + + res, err := format.Source(buf.Bytes()) + if err != nil { + panic("error formating code") + } + + offset := strings.Index(string(res), TemplateWarning) + if offset == -1 { + panic("Missing template warning in pkg.templ.go") + } + err = os.WriteFile(mode.PkgPath()+"/acvp_test.go", res[offset:], 0o644) + if err != nil { + panic(err) + } + } +} + // Copies mode3 source files to other modes func generateSourceFiles() { files := make(map[string][]byte) diff --git a/sign/dilithium/templates/acvp.templ.go b/sign/dilithium/templates/acvp.templ.go new file mode 100644 index 000000000..a86d8fd8b --- /dev/null +++ b/sign/dilithium/templates/acvp.templ.go @@ -0,0 +1,266 @@ +// +build ignore +// The previous line (and this one up to the warning below) is removed by the +// template generator. + +// Code generated from acvp.templ.go. DO NOT EDIT. + +package {{.Pkg}} + +import ( + "bytes" + "compress/gzip" + "encoding/hex" + "encoding/json" + "io" + "os" + "testing" +) + +// []byte but is encoded in hex for JSON +type HexBytes []byte + +func (b HexBytes) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(b)) +} + +func (b *HexBytes) UnmarshalJSON(data []byte) (err error) { + var s string + if err = json.Unmarshal(data, &s); err != nil { + return err + } + *b, err = hex.DecodeString(s) + return err +} + +func gunzip(in []byte) ([]byte, error) { + buf := bytes.NewBuffer(in) + r, err := gzip.NewReader(buf) + if err != nil { + return nil, err + } + return io.ReadAll(r) +} + +func readGzip(path string) ([]byte, error) { + buf, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return gunzip(buf) +} + +func TestACVP(t *testing.T) { + for _, sub := range []string{ + "keyGen", + "sigGen", + } { + t.Run(sub, func(t *testing.T) { + testACVP(t, sub) + }) + } +} + +// nolint:funlen,gocyclo +func testACVP(t *testing.T, sub string) { + buf, err := readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/prompt.json.gz") + if err != nil { + t.Fatal(err) + } + + var prompt struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err = json.Unmarshal(buf, &prompt); err != nil { + t.Fatal(err) + } + + buf, err = readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/expectedResults.json.gz") + if err != nil { + t.Fatal(err) + } + + var results struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err := json.Unmarshal(buf, &results); err != nil { + t.Fatal(err) + } + + rawResults := make(map[int]json.RawMessage) + + for _, rawGroup := range results.TestGroups { + var abstractGroup struct { + Tests []json.RawMessage `json:"tests"` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + for _, rawTest := range abstractGroup.Tests { + var abstractTest struct { + TcID int `json:"tcId"` + } + if err := json.Unmarshal(rawTest, &abstractTest); err != nil { + t.Fatal(err) + } + if _, exists := rawResults[abstractTest.TcID]; exists { + t.Fatalf("Duplicate test id: %d", abstractTest.TcID) + } + rawResults[abstractTest.TcID] = rawTest + } + } + + scheme := Scheme() + + for _, rawGroup := range prompt.TestGroups { + var abstractGroup struct { + TestType string `json:"testType"` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + switch { + case abstractGroup.TestType == "AFT" && sub == "keyGen": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Tests []struct { + TcID int `json:"tcId"` + Seed HexBytes `json:"seed"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + for _, test := range group.Tests { + var result struct { + Pk HexBytes `json:"pk"` + Sk HexBytes `json:"sk"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + pk, sk := scheme.DeriveKey(test.Seed) + + pk2, err := scheme.UnmarshalBinaryPublicKey(result.Pk) + if err != nil { + t.Fatalf("tc=%d: %v", test.TcID, err) + } + sk2, err := scheme.UnmarshalBinaryPrivateKey(result.Sk) + if err != nil { + t.Fatal(err) + } + + if !pk.Equal(pk2) { + t.Fatal("pk does not match") + } + if !sk.Equal(sk2) { + t.Fatal("sk does not match") + } + } + case abstractGroup.TestType == "AFT" && sub == "sigGen": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Deterministic bool `json:"deterministic"` + Tests []struct { + TcID int `json:"tcId"` + Sk HexBytes `json:"sk"` + Message HexBytes `json:"message"` + Rnd HexBytes `json:"rnd"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + for _, test := range group.Tests { + var result struct { + Signature HexBytes `json:"signature"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + sk, err := scheme.UnmarshalBinaryPrivateKey(test.Sk) + if err != nil { + t.Fatal(err) + } + + var rnd [32]byte + if !group.Deterministic { + copy(rnd[:], test.Rnd) + } + + sig2 := sk.(*PrivateKey).unsafeSignInternal(test.Message, rnd) + + if !bytes.Equal(sig2, result.Signature) { + t.Fatalf("signature doesn't match: %x ≠ %x", + sig2, result.Signature) + } + } + case abstractGroup.TestType == "AFT" && sub == "sigVer": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Pk HexBytes `json:"pk"` + Tests []struct { + TcID int `json:"tcId"` + Message HexBytes `json:"message"` + Signature HexBytes `json:"signature"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + pk, err := scheme.UnmarshalBinaryPublicKey(group.Pk) + if err != nil { + t.Fatal(err) + } + + for _, test := range group.Tests { + var result struct { + TestPassed bool `json:"testPassed"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + passed2 := scheme.Verify(pk, test.Message, test.Signature, nil) + if passed2 != result.TestPassed { + t.Fatalf("verification %v ≠ %v", passed2, result.TestPassed) + } + } + default: + t.Fatalf("unknown type %s for %s", abstractGroup.TestType, sub) + } + } +} diff --git a/sign/dilithium/templates/pkg.templ.go b/sign/dilithium/templates/pkg.templ.go index f07a08ccf..238de68c7 100644 --- a/sign/dilithium/templates/pkg.templ.go +++ b/sign/dilithium/templates/pkg.templ.go @@ -116,7 +116,7 @@ func SignTo(sk *PrivateKey, msg, sig []byte) { {{- if .NIST }} // Do not use. Implements ML-DSA.Sign_internal used for compatibility tests. -func (sk *PrivateKey) UnsafeSignInternal(msg []byte, rnd [32]byte) []byte { +func (sk *PrivateKey) unsafeSignInternal(msg []byte, rnd [32]byte) []byte { var ret [SignatureSize]byte internal.SignTo( (*internal.PrivateKey)(sk), @@ -386,7 +386,7 @@ func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (sign.PublicKey, error) { buf2 [PublicKeySize]byte ret PublicKey ) - + copy(buf2[:], buf) ret.Unpack(&buf2) return &ret, nil @@ -401,7 +401,7 @@ func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (sign.PrivateKey, error) { buf2 [PrivateKeySize]byte ret PrivateKey ) - + copy(buf2[:], buf) ret.Unpack(&buf2) return &ret, nil diff --git a/sign/mldsa/acvp_test.go b/sign/mldsa/mldsa44/acvp_test.go similarity index 88% rename from sign/mldsa/acvp_test.go rename to sign/mldsa/mldsa44/acvp_test.go index bc2dfa186..09ef2c1af 100644 --- a/sign/mldsa/acvp_test.go +++ b/sign/mldsa/mldsa44/acvp_test.go @@ -1,4 +1,6 @@ -package mldsa +// Code generated from acvp.templ.go. DO NOT EDIT. + +package mldsa44 import ( "bytes" @@ -8,8 +10,6 @@ import ( "io" "os" "testing" - - "github.com/cloudflare/circl/sign/schemes" ) // []byte but is encoded in hex for JSON @@ -58,7 +58,7 @@ func TestACVP(t *testing.T) { // nolint:funlen,gocyclo func testACVP(t *testing.T, sub string) { - buf, err := readGzip("testdata/ML-DSA-" + sub + "-FIPS204/prompt.json.gz") + buf, err := readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/prompt.json.gz") if err != nil { t.Fatal(err) } @@ -71,7 +71,7 @@ func testACVP(t *testing.T, sub string) { t.Fatal(err) } - buf, err = readGzip("testdata/ML-DSA-" + sub + "-FIPS204/expectedResults.json.gz") + buf, err = readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/expectedResults.json.gz") if err != nil { t.Fatal(err) } @@ -107,6 +107,8 @@ func testACVP(t *testing.T, sub string) { } } + scheme := Scheme() + for _, rawGroup := range prompt.TestGroups { var abstractGroup struct { TestType string `json:"testType"` @@ -128,9 +130,8 @@ func testACVP(t *testing.T, sub string) { t.Fatal(err) } - scheme := schemes.ByName(group.ParameterSet) - if scheme == nil { - t.Fatalf("No such scheme: %s", group.ParameterSet) + if group.ParameterSet != scheme.Name() { + continue } for _, test := range group.Tests { @@ -180,9 +181,8 @@ func testACVP(t *testing.T, sub string) { t.Fatal(err) } - scheme := schemes.ByName(group.ParameterSet) - if scheme == nil { - t.Fatalf("No such scheme: %s", group.ParameterSet) + if group.ParameterSet != scheme.Name() { + continue } for _, test := range group.Tests { @@ -207,10 +207,7 @@ func testACVP(t *testing.T, sub string) { copy(rnd[:], test.Rnd) } - isk := sk.(interface { - UnsafeSignInternal(msg []byte, rnd [32]byte) []byte - }) - sig2 := isk.UnsafeSignInternal(test.Message, rnd) + sig2 := sk.(*PrivateKey).unsafeSignInternal(test.Message, rnd) if !bytes.Equal(sig2, result.Signature) { t.Fatalf("signature doesn't match: %x ≠ %x", @@ -232,9 +229,8 @@ func testACVP(t *testing.T, sub string) { t.Fatal(err) } - scheme := schemes.ByName(group.ParameterSet) - if scheme == nil { - t.Fatalf("No such scheme: %s", group.ParameterSet) + if group.ParameterSet != scheme.Name() { + continue } pk, err := scheme.UnmarshalBinaryPublicKey(group.Pk) diff --git a/sign/mldsa/mldsa44/dilithium.go b/sign/mldsa/mldsa44/dilithium.go index 497e2c769..257d824ee 100644 --- a/sign/mldsa/mldsa44/dilithium.go +++ b/sign/mldsa/mldsa44/dilithium.go @@ -83,7 +83,7 @@ func SignTo(sk *PrivateKey, msg, ctx []byte, randomized bool, sig []byte) error } // Do not use. Implements ML-DSA.Sign_internal used for compatibility tests. -func (sk *PrivateKey) UnsafeSignInternal(msg []byte, rnd [32]byte) []byte { +func (sk *PrivateKey) unsafeSignInternal(msg []byte, rnd [32]byte) []byte { var ret [SignatureSize]byte internal.SignTo( (*internal.PrivateKey)(sk), diff --git a/sign/mldsa/mldsa65/acvp_test.go b/sign/mldsa/mldsa65/acvp_test.go new file mode 100644 index 000000000..24742cb0e --- /dev/null +++ b/sign/mldsa/mldsa65/acvp_test.go @@ -0,0 +1,262 @@ +// Code generated from acvp.templ.go. DO NOT EDIT. + +package mldsa65 + +import ( + "bytes" + "compress/gzip" + "encoding/hex" + "encoding/json" + "io" + "os" + "testing" +) + +// []byte but is encoded in hex for JSON +type HexBytes []byte + +func (b HexBytes) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(b)) +} + +func (b *HexBytes) UnmarshalJSON(data []byte) (err error) { + var s string + if err = json.Unmarshal(data, &s); err != nil { + return err + } + *b, err = hex.DecodeString(s) + return err +} + +func gunzip(in []byte) ([]byte, error) { + buf := bytes.NewBuffer(in) + r, err := gzip.NewReader(buf) + if err != nil { + return nil, err + } + return io.ReadAll(r) +} + +func readGzip(path string) ([]byte, error) { + buf, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return gunzip(buf) +} + +func TestACVP(t *testing.T) { + for _, sub := range []string{ + "keyGen", + "sigGen", + } { + t.Run(sub, func(t *testing.T) { + testACVP(t, sub) + }) + } +} + +// nolint:funlen,gocyclo +func testACVP(t *testing.T, sub string) { + buf, err := readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/prompt.json.gz") + if err != nil { + t.Fatal(err) + } + + var prompt struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err = json.Unmarshal(buf, &prompt); err != nil { + t.Fatal(err) + } + + buf, err = readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/expectedResults.json.gz") + if err != nil { + t.Fatal(err) + } + + var results struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err := json.Unmarshal(buf, &results); err != nil { + t.Fatal(err) + } + + rawResults := make(map[int]json.RawMessage) + + for _, rawGroup := range results.TestGroups { + var abstractGroup struct { + Tests []json.RawMessage `json:"tests"` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + for _, rawTest := range abstractGroup.Tests { + var abstractTest struct { + TcID int `json:"tcId"` + } + if err := json.Unmarshal(rawTest, &abstractTest); err != nil { + t.Fatal(err) + } + if _, exists := rawResults[abstractTest.TcID]; exists { + t.Fatalf("Duplicate test id: %d", abstractTest.TcID) + } + rawResults[abstractTest.TcID] = rawTest + } + } + + scheme := Scheme() + + for _, rawGroup := range prompt.TestGroups { + var abstractGroup struct { + TestType string `json:"testType"` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + switch { + case abstractGroup.TestType == "AFT" && sub == "keyGen": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Tests []struct { + TcID int `json:"tcId"` + Seed HexBytes `json:"seed"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + for _, test := range group.Tests { + var result struct { + Pk HexBytes `json:"pk"` + Sk HexBytes `json:"sk"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + pk, sk := scheme.DeriveKey(test.Seed) + + pk2, err := scheme.UnmarshalBinaryPublicKey(result.Pk) + if err != nil { + t.Fatalf("tc=%d: %v", test.TcID, err) + } + sk2, err := scheme.UnmarshalBinaryPrivateKey(result.Sk) + if err != nil { + t.Fatal(err) + } + + if !pk.Equal(pk2) { + t.Fatal("pk does not match") + } + if !sk.Equal(sk2) { + t.Fatal("sk does not match") + } + } + case abstractGroup.TestType == "AFT" && sub == "sigGen": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Deterministic bool `json:"deterministic"` + Tests []struct { + TcID int `json:"tcId"` + Sk HexBytes `json:"sk"` + Message HexBytes `json:"message"` + Rnd HexBytes `json:"rnd"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + for _, test := range group.Tests { + var result struct { + Signature HexBytes `json:"signature"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + sk, err := scheme.UnmarshalBinaryPrivateKey(test.Sk) + if err != nil { + t.Fatal(err) + } + + var rnd [32]byte + if !group.Deterministic { + copy(rnd[:], test.Rnd) + } + + sig2 := sk.(*PrivateKey).unsafeSignInternal(test.Message, rnd) + + if !bytes.Equal(sig2, result.Signature) { + t.Fatalf("signature doesn't match: %x ≠ %x", + sig2, result.Signature) + } + } + case abstractGroup.TestType == "AFT" && sub == "sigVer": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Pk HexBytes `json:"pk"` + Tests []struct { + TcID int `json:"tcId"` + Message HexBytes `json:"message"` + Signature HexBytes `json:"signature"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + pk, err := scheme.UnmarshalBinaryPublicKey(group.Pk) + if err != nil { + t.Fatal(err) + } + + for _, test := range group.Tests { + var result struct { + TestPassed bool `json:"testPassed"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + passed2 := scheme.Verify(pk, test.Message, test.Signature, nil) + if passed2 != result.TestPassed { + t.Fatalf("verification %v ≠ %v", passed2, result.TestPassed) + } + } + default: + t.Fatalf("unknown type %s for %s", abstractGroup.TestType, sub) + } + } +} diff --git a/sign/mldsa/mldsa65/dilithium.go b/sign/mldsa/mldsa65/dilithium.go index cbbfb0d09..e54459863 100644 --- a/sign/mldsa/mldsa65/dilithium.go +++ b/sign/mldsa/mldsa65/dilithium.go @@ -83,7 +83,7 @@ func SignTo(sk *PrivateKey, msg, ctx []byte, randomized bool, sig []byte) error } // Do not use. Implements ML-DSA.Sign_internal used for compatibility tests. -func (sk *PrivateKey) UnsafeSignInternal(msg []byte, rnd [32]byte) []byte { +func (sk *PrivateKey) unsafeSignInternal(msg []byte, rnd [32]byte) []byte { var ret [SignatureSize]byte internal.SignTo( (*internal.PrivateKey)(sk), diff --git a/sign/mldsa/mldsa87/acvp_test.go b/sign/mldsa/mldsa87/acvp_test.go new file mode 100644 index 000000000..b124d4553 --- /dev/null +++ b/sign/mldsa/mldsa87/acvp_test.go @@ -0,0 +1,262 @@ +// Code generated from acvp.templ.go. DO NOT EDIT. + +package mldsa87 + +import ( + "bytes" + "compress/gzip" + "encoding/hex" + "encoding/json" + "io" + "os" + "testing" +) + +// []byte but is encoded in hex for JSON +type HexBytes []byte + +func (b HexBytes) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(b)) +} + +func (b *HexBytes) UnmarshalJSON(data []byte) (err error) { + var s string + if err = json.Unmarshal(data, &s); err != nil { + return err + } + *b, err = hex.DecodeString(s) + return err +} + +func gunzip(in []byte) ([]byte, error) { + buf := bytes.NewBuffer(in) + r, err := gzip.NewReader(buf) + if err != nil { + return nil, err + } + return io.ReadAll(r) +} + +func readGzip(path string) ([]byte, error) { + buf, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return gunzip(buf) +} + +func TestACVP(t *testing.T) { + for _, sub := range []string{ + "keyGen", + "sigGen", + } { + t.Run(sub, func(t *testing.T) { + testACVP(t, sub) + }) + } +} + +// nolint:funlen,gocyclo +func testACVP(t *testing.T, sub string) { + buf, err := readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/prompt.json.gz") + if err != nil { + t.Fatal(err) + } + + var prompt struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err = json.Unmarshal(buf, &prompt); err != nil { + t.Fatal(err) + } + + buf, err = readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/expectedResults.json.gz") + if err != nil { + t.Fatal(err) + } + + var results struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err := json.Unmarshal(buf, &results); err != nil { + t.Fatal(err) + } + + rawResults := make(map[int]json.RawMessage) + + for _, rawGroup := range results.TestGroups { + var abstractGroup struct { + Tests []json.RawMessage `json:"tests"` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + for _, rawTest := range abstractGroup.Tests { + var abstractTest struct { + TcID int `json:"tcId"` + } + if err := json.Unmarshal(rawTest, &abstractTest); err != nil { + t.Fatal(err) + } + if _, exists := rawResults[abstractTest.TcID]; exists { + t.Fatalf("Duplicate test id: %d", abstractTest.TcID) + } + rawResults[abstractTest.TcID] = rawTest + } + } + + scheme := Scheme() + + for _, rawGroup := range prompt.TestGroups { + var abstractGroup struct { + TestType string `json:"testType"` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + switch { + case abstractGroup.TestType == "AFT" && sub == "keyGen": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Tests []struct { + TcID int `json:"tcId"` + Seed HexBytes `json:"seed"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + for _, test := range group.Tests { + var result struct { + Pk HexBytes `json:"pk"` + Sk HexBytes `json:"sk"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + pk, sk := scheme.DeriveKey(test.Seed) + + pk2, err := scheme.UnmarshalBinaryPublicKey(result.Pk) + if err != nil { + t.Fatalf("tc=%d: %v", test.TcID, err) + } + sk2, err := scheme.UnmarshalBinaryPrivateKey(result.Sk) + if err != nil { + t.Fatal(err) + } + + if !pk.Equal(pk2) { + t.Fatal("pk does not match") + } + if !sk.Equal(sk2) { + t.Fatal("sk does not match") + } + } + case abstractGroup.TestType == "AFT" && sub == "sigGen": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Deterministic bool `json:"deterministic"` + Tests []struct { + TcID int `json:"tcId"` + Sk HexBytes `json:"sk"` + Message HexBytes `json:"message"` + Rnd HexBytes `json:"rnd"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + for _, test := range group.Tests { + var result struct { + Signature HexBytes `json:"signature"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + sk, err := scheme.UnmarshalBinaryPrivateKey(test.Sk) + if err != nil { + t.Fatal(err) + } + + var rnd [32]byte + if !group.Deterministic { + copy(rnd[:], test.Rnd) + } + + sig2 := sk.(*PrivateKey).unsafeSignInternal(test.Message, rnd) + + if !bytes.Equal(sig2, result.Signature) { + t.Fatalf("signature doesn't match: %x ≠ %x", + sig2, result.Signature) + } + } + case abstractGroup.TestType == "AFT" && sub == "sigVer": + var group struct { + TgID int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Pk HexBytes `json:"pk"` + Tests []struct { + TcID int `json:"tcId"` + Message HexBytes `json:"message"` + Signature HexBytes `json:"signature"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + if group.ParameterSet != scheme.Name() { + continue + } + + pk, err := scheme.UnmarshalBinaryPublicKey(group.Pk) + if err != nil { + t.Fatal(err) + } + + for _, test := range group.Tests { + var result struct { + TestPassed bool `json:"testPassed"` + } + rawResult, ok := rawResults[test.TcID] + if !ok { + t.Fatalf("Missing result: %d", test.TcID) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + passed2 := scheme.Verify(pk, test.Message, test.Signature, nil) + if passed2 != result.TestPassed { + t.Fatalf("verification %v ≠ %v", passed2, result.TestPassed) + } + } + default: + t.Fatalf("unknown type %s for %s", abstractGroup.TestType, sub) + } + } +} diff --git a/sign/mldsa/mldsa87/dilithium.go b/sign/mldsa/mldsa87/dilithium.go index 1f17dce37..69ab919d8 100644 --- a/sign/mldsa/mldsa87/dilithium.go +++ b/sign/mldsa/mldsa87/dilithium.go @@ -83,7 +83,7 @@ func SignTo(sk *PrivateKey, msg, ctx []byte, randomized bool, sig []byte) error } // Do not use. Implements ML-DSA.Sign_internal used for compatibility tests. -func (sk *PrivateKey) UnsafeSignInternal(msg []byte, rnd [32]byte) []byte { +func (sk *PrivateKey) unsafeSignInternal(msg []byte, rnd [32]byte) []byte { var ret [SignatureSize]byte internal.SignTo( (*internal.PrivateKey)(sk),