From 83d99b3eaf7130a6bed00fe2e2d6dfc82e816f29 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 18 Dec 2024 15:38:03 +0100 Subject: [PATCH] add support for CSHAKE --- cng/hash.go | 147 +++++++++++++++++++++ cng/hash_test.go | 205 ++++++++++++++++++++++++++++++ internal/bcrypt/bcrypt_windows.go | 30 +++-- 3 files changed, 371 insertions(+), 11 deletions(-) diff --git a/cng/hash.go b/cng/hash.go index 87b1c95..223088a 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -11,6 +11,7 @@ import ( "crypto" "hash" "runtime" + "slices" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" @@ -304,3 +305,149 @@ func (h *hashX) Sum(in []byte) []byte { } return append(in, h.buf...) } + +// SumSHAKE128 applies the SHAKE128 extendable output function to data and +// returns an output of the given length in bytes. +func SumSHAKE128(data []byte, length int) []byte { + out := make([]byte, length) + if err := hashOneShot(bcrypt.CSHAKE128_ALGORITHM, data, out); err != nil { + panic("bcrypt: CSHAKE128_ALGORITHM failed") + } + return out +} + +// SumSHAKE256 applies the SHAKE256 extendable output function to data and +// returns an output of the given length in bytes. +func SumSHAKE256(data []byte, length int) []byte { + out := make([]byte, length) + if err := hashOneShot(bcrypt.CSHAKE256_ALGORITHM, data, out); err != nil { + panic("bcrypt: CSHAKE128_ALGORITHM failed") + } + return out +} + +// SHAKE is an instance of a SHAKE extendable output function. +type SHAKE struct { + alg *hashAlgorithm + ctx bcrypt.HASH_HANDLE + n, s []byte +} + +func newShake(id string, N, S []byte) *SHAKE { + alg, err := loadHash(id, bcrypt.ALG_NONE_FLAG) + if err != nil { + panic(err) + } + h := &SHAKE{alg: alg, n: slices.Clone(N), s: slices.Clone(S)} + err = bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, 0) + if err != nil { + panic(err) + } + if len(N) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), N, 0); err != nil { + panic(err) + } + } + if len(S) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), S, 0); err != nil { + panic(err) + } + } + runtime.SetFinalizer(h, (*SHAKE).finalize) + return h +} + +// NewSHAKE128 creates a new SHAKE128 XOF. +func NewSHAKE128() *SHAKE { + return newShake(bcrypt.CSHAKE128_ALGORITHM, nil, nil) +} + +// NewSHAKE256 creates a new SHAKE256 XOF. +func NewSHAKE256() *SHAKE { + return newShake(bcrypt.CSHAKE256_ALGORITHM, nil, nil) +} + +// NewCSHAKE128 creates a new cSHAKE128 XOF. +// +// N is used to define functions based on cSHAKE, it can be empty when plain +// cSHAKE is desired. S is a customization byte string used for domain +// separation. When N and S are both empty, this is equivalent to NewSHAKE128. +func NewCSHAKE128(N, S []byte) *SHAKE { + return newShake(bcrypt.CSHAKE128_ALGORITHM, N, S) +} + +// NewCSHAKE256 creates a new cSHAKE256 XOF. +// +// N is used to define functions based on cSHAKE, it can be empty when plain +// cSHAKE is desired. S is a customization byte string used for domain +// separation. When N and S are both empty, this is equivalent to NewSHAKE256. +func NewCSHAKE256(N, S []byte) *SHAKE { + return newShake(bcrypt.CSHAKE256_ALGORITHM, N, S) +} + +func (h *SHAKE) finalize() { + bcrypt.DestroyHash(h.ctx) +} + +// Write absorbs more data into the XOF's state. +// +// It panics if any output has already been read. +func (s *SHAKE) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + defer runtime.KeepAlive(s) + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.HashData(s.ctx, p[n:n+nn], 0) + n += nn + } + if err != nil { + panic(err) + } + return len(p), nil +} + +// Read squeezes more output from the XOF. +// +// Any call to Write after a call to Read will panic. +func (s *SHAKE) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + defer runtime.KeepAlive(s) + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.FinishHash(s.ctx, p[n:n+nn], bcrypt.HASH_DONT_RESET_FLAG) + n += nn + } + if err != nil { + panic(err) + } + return len(p), nil +} + +// Reset resets the XOF to its initial state. +func (s *SHAKE) Reset() { + defer runtime.KeepAlive(s) + bcrypt.DestroyHash(s.ctx) + err := bcrypt.CreateHash(s.alg.handle, &s.ctx, nil, nil, 0) + if err != nil { + panic(err) + } + if len(s.n) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), s.n, 0); err != nil { + panic(err) + } + } + if len(s.s) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), s.s, 0); err != nil { + panic(err) + } + } +} + +// BlockSize returns the rate of the XOF. +func (s *SHAKE) BlockSize() int { + return int(s.alg.blockSize) +} diff --git a/cng/hash_test.go b/cng/hash_test.go index 21a7fa8..c692ca6 100644 --- a/cng/hash_test.go +++ b/cng/hash_test.go @@ -9,8 +9,10 @@ package cng_test import ( "bytes" "crypto" + "encoding/hex" "hash" "io" + "math/rand" "testing" "github.com/microsoft/go-crypto-winnative/cng" @@ -212,3 +214,206 @@ func BenchmarkSHA256_OneShot(b *testing.B) { cng.SHA256(buf) } } + +// testShakes contains functions that return *sha3.SHAKE instances for +// with output-length equal to the KAT length. +var testShakes = map[string]struct { + constructor func(N []byte, S []byte) *cng.SHAKE + defAlgoName string + defCustomStr string +}{ + // NewCSHAKE without customization produces same result as SHAKE + "SHAKE128": {cng.NewCSHAKE128, "", ""}, + "SHAKE256": {cng.NewCSHAKE256, "", ""}, + "cSHAKE128": {cng.NewCSHAKE128, "CSHAKE128", "CustomString"}, + "cSHAKE256": {cng.NewCSHAKE256, "CSHAKE256", "CustomString"}, +} + +// TestCSHAKESqueezing checks that squeezing the full output a single time produces +// the same output as repeatedly squeezing the instance. +func TestCSHAKESqueezing(t *testing.T) { + const testString = "brekeccakkeccak koax koax" + for algo, v := range testShakes { + d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d0.Write([]byte(testString)) + ref := make([]byte, 32) + d0.Read(ref) + + d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d1.Write([]byte(testString)) + var multiple []byte + for range ref { + d1.Read(make([]byte, 0)) + one := make([]byte, 1) + d1.Read(one) + multiple = append(multiple, one...) + } + if !bytes.Equal(ref, multiple) { + t.Errorf("%s: squeezing %d bytes one at a time failed", algo, len(ref)) + } + } +} + +// sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing. +func sequentialBytes(size int) []byte { + alignmentOffset := rand.Intn(8) + result := make([]byte, size+alignmentOffset)[alignmentOffset:] + for i := range result { + result[i] = byte(i) + } + return result +} + +func TestCSHAKEReset(t *testing.T) { + out1 := make([]byte, 32) + out2 := make([]byte, 32) + + for _, v := range testShakes { + // Calculate hash for the first time + c := v.constructor(nil, []byte{0x99, 0x98}) + c.Write(sequentialBytes(0x100)) + c.Read(out1) + + // Calculate hash again + c.Reset() + c.Write(sequentialBytes(0x100)) + c.Read(out2) + + if !bytes.Equal(out1, out2) { + t.Error("\nExpected:\n", out1, "\ngot:\n", out2) + } + } +} + +func TestCSHAKEAccumulated(t *testing.T) { + t.Run("CSHAKE128", func(t *testing.T) { + testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8, + "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") + }) + t.Run("CSHAKE256", func(t *testing.T) { + testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8, + "0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef") + }) +} + +func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE, rate int64, exp string) { + rnd := newCSHAKE(nil, nil) + acc := newCSHAKE(nil, nil) + for n := 0; n < 200; n++ { + N := make([]byte, n) + rnd.Read(N) + for s := 0; s < 200; s++ { + S := make([]byte, s) + rnd.Read(S) + + c := newCSHAKE(N, S) + io.CopyN(c, rnd, 100 /* < rate */) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, rate) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, 200 /* > rate */) + io.CopyN(acc, c, 200) + } + } + out := make([]byte, 32) + acc.Read(out) + if got := hex.EncodeToString(out); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + +func TestCSHAKELargeS(t *testing.T) { + const s = (1<<32)/8 + 1000 // s * 8 > 2^32 + S := make([]byte, s) + rnd := cng.NewSHAKE128() + rnd.Read(S) + c := cng.NewCSHAKE128(nil, S) + io.CopyN(c, rnd, 1000) + out := make([]byte, 32) + c.Read(out) + + exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0" + if got := hex.EncodeToString(out); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + +func TestCSHAKESum(t *testing.T) { + const testString = "hello world" + t.Run("CSHAKE128", func(t *testing.T) { + h := cng.NewCSHAKE128(nil, nil) + h.Write([]byte(testString[:5])) + h.Write([]byte(testString[5:])) + want := make([]byte, 32) + h.Read(want) + got := cng.SumSHAKE128([]byte(testString), 32) + if !bytes.Equal(got, want) { + t.Errorf("got:%x want:%x", got, want) + } + }) + t.Run("CSHAKE256", func(t *testing.T) { + h := cng.NewCSHAKE256(nil, nil) + h.Write([]byte(testString[:5])) + h.Write([]byte(testString[5:])) + want := make([]byte, 32) + h.Read(want) + got := cng.SumSHAKE256([]byte(testString), 32) + if !bytes.Equal(got, want) { + t.Errorf("got:%x want:%x", got, want) + } + }) +} + +// benchmarkHash tests the speed to hash num buffers of buflen each. +func benchmarkHash(b *testing.B, h hash.Hash, size, num int) { + b.StopTimer() + h.Reset() + data := sequentialBytes(size) + b.SetBytes(int64(size * num)) + b.StartTimer() + + var state []byte + for i := 0; i < b.N; i++ { + for j := 0; j < num; j++ { + h.Write(data) + } + state = h.Sum(state[:0]) + } + b.StopTimer() + h.Reset() +} + +// benchmarkCSHAKE is specialized to the Shake instances, which don't +// require a copy on reading output. +func benchmarkCSHAKE(b *testing.B, h *cng.SHAKE, size, num int) { + b.StopTimer() + h.Reset() + data := sequentialBytes(size) + d := make([]byte, 32) + + b.SetBytes(int64(size * num)) + b.StartTimer() + + for i := 0; i < b.N; i++ { + h.Reset() + for j := 0; j < num; j++ { + h.Write(data) + } + h.Read(d) + } +} + +func BenchmarkSHA3_512_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1350, 1) } +func BenchmarkSHA3_384_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_384(), 1350, 1) } +func BenchmarkSHA3_256_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_256(), 1350, 1) } + +func BenchmarkCSHAKE128_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE128(), 1350, 1) } +func BenchmarkCSHAKE256_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1350, 1) } +func BenchmarkCSHAKE256_16x(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 16, 1024) } +func BenchmarkCSHAKE256_1MiB(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1024, 1024) } + +func BenchmarkCSHA3_512_1MiB(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1024, 1024) } diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 090c74a..e3255e2 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -22,6 +22,8 @@ const ( SHA3_256_ALGORITHM = "SHA3-256" SHA3_384_ALGORITHM = "SHA3-384" SHA3_512_ALGORITHM = "SHA3-512" + CSHAKE128_ALGORITHM = "CSHAKE128" + CSHAKE256_ALGORITHM = "CSHAKE256" AES_ALGORITHM = "AES" RC4_ALGORITHM = "RC4" RSA_ALGORITHM = "RSA" @@ -47,17 +49,19 @@ const ( ) const ( - HASH_LENGTH = "HashDigestLength" - HASH_BLOCK_LENGTH = "HashBlockLength" - CHAINING_MODE = "ChainingMode" - CHAIN_MODE_ECB = "ChainingModeECB" - CHAIN_MODE_CBC = "ChainingModeCBC" - CHAIN_MODE_GCM = "ChainingModeGCM" - KEY_LENGTH = "KeyLength" - KEY_LENGTHS = "KeyLengths" - SIGNATURE_LENGTH = "SignatureLength" - BLOCK_LENGTH = "BlockLength" - ECC_CURVE_NAME = "ECCCurveName" + HASH_LENGTH = "HashDigestLength" + HASH_BLOCK_LENGTH = "HashBlockLength" + CHAINING_MODE = "ChainingMode" + CHAIN_MODE_ECB = "ChainingModeECB" + CHAIN_MODE_CBC = "ChainingModeCBC" + CHAIN_MODE_GCM = "ChainingModeGCM" + KEY_LENGTH = "KeyLength" + KEY_LENGTHS = "KeyLengths" + SIGNATURE_LENGTH = "SignatureLength" + BLOCK_LENGTH = "BlockLength" + ECC_CURVE_NAME = "ECCCurveName" + FUNCTION_NAME_STRING = "FunctionNameString" + CUSTOMIZATION_STRING = "CustomizationString" ) const ( @@ -113,6 +117,10 @@ const ( USE_SYSTEM_PREFERRED_RNG = 0x00000002 ) +const ( + HASH_DONT_RESET_FLAG = 0x00000001 +) + const ( KDF_RAW_SECRET = "TRUNCATE" )