diff --git a/group/edwards25519/point.go b/group/edwards25519/point.go index 27009191d..e04f0d3a1 100644 --- a/group/edwards25519/point.go +++ b/group/edwards25519/point.go @@ -263,3 +263,30 @@ func (P *point) HasSmallOrder() bool { return (k>>8)&1 > 0 } + +// IsCanonical determines whether the group element is canonical +// +// Checks whether group element s is less than p, according to RFC8032§5.1.3.1 +// https://tools.ietf.org/html/rfc8032#section-5.1.3 +// +// Taken from +// https://github.com/jedisct1/libsodium/blob/4744636721d2e420f8bbe2d563f31b1f5e682229/src/libsodium/crypto_core/ed25519/ref10/ed25519_ref10.c#L1113 +// +// The meethod accepts a buffer instead of calling `MarshalBianry` on the receiver +// because that always returns a value modulo `prime`. +func (P *point) IsCanonical(s []byte) bool { + if len(s) != 32 { + return false + } + + c := (s[31] & 0x7f) ^ 0x7f + for i := 30; i > 0; i-- { + c |= s[i] ^ 0xff + } + + // subtraction might underflow + c = byte((uint16(c) - 1) >> 8) + d := byte((0xed - 1 - uint16(s[0])) >> 8) + + return 1-(c&d&1) == 1 +} diff --git a/group/edwards25519/point_test.go b/group/edwards25519/point_test.go index 9d2a26803..4a5b958d5 100644 --- a/group/edwards25519/point_test.go +++ b/group/edwards25519/point_test.go @@ -23,3 +23,42 @@ func TestPoint_HasSmallOrder(t *testing.T) { require.True(t, p.HasSmallOrder(), fmt.Sprintf("%s should be considered to have a small order", hex.EncodeToString(key))) } } + +// Test_PointIsCanonical ensures that elements >= p are considered +// non canonical +func Test_PointIsCanonical(t *testing.T) { + + // buffer stores the candidate points (in little endian) that we'll test + // against, starting with `prime` + buffer := prime.Bytes() + for i, j := 0, len(buffer)-1; i < j; i, j = i+1, j-1 { + buffer[i], buffer[j] = buffer[j], buffer[i] + } + + // Iterate over the 19*2 finite field elements + point := point{} + actualNonCanonicalCount := 0 + expectedNonCanonicalCount := 24 + for i := 0; i < 19; i++ { + buffer[0] = byte(237 + i) + buffer[31] = byte(127) + + // Check if it's a valid point on the curve that's + // not canonical + err := point.UnmarshalBinary(buffer) + if err == nil && !point.IsCanonical(buffer) { + actualNonCanonicalCount++ + } + + // flip bit + buffer[31] |= 128 + + // Check if it's a valid point on the curve that's + // not canonical + err = point.UnmarshalBinary(buffer) + if err == nil && !point.IsCanonical(buffer) { + actualNonCanonicalCount++ + } + } + require.Equal(t, expectedNonCanonicalCount, actualNonCanonicalCount, "Incorrect number of non canonical points detected") +} diff --git a/group/edwards25519/scalar.go b/group/edwards25519/scalar.go index 30caa0253..ac26a68c5 100644 --- a/group/edwards25519/scalar.go +++ b/group/edwards25519/scalar.go @@ -2230,3 +2230,37 @@ func scReduce(out *[32]byte, s *[64]byte) { out[30] = byte(s11 >> 9) out[31] = byte(s11 >> 17) } + +// IsCanonical whether scalar s is in the range 0<=s= 0; i-- { + // subtraction might lead to an underflow which needs + // to be accounted for in the right shift + c |= byte((uint16(sb[i])-uint16(L[i]))>>8) & n + n &= byte((uint16(sb[i]) ^ uint16(L[i]) - 1) >> 8) + } + + return c != 0 +} diff --git a/group/edwards25519/scalar_test.go b/group/edwards25519/scalar_test.go index 0e8459c9a..5c50864a6 100644 --- a/group/edwards25519/scalar_test.go +++ b/group/edwards25519/scalar_test.go @@ -2,6 +2,7 @@ package edwards25519 import ( "fmt" + "math/big" "testing" "github.com/stretchr/testify/require" @@ -457,3 +458,24 @@ func scSubFact(s, a, c *[32]byte) { scReduceLimbs(limbs) } + +// Test_ScalarIsCanonical ensures that scalars >= primeOrder are +// considered non canonical. +func Test_ScalarIsCanonical(t *testing.T) { + candidate := big.NewInt(-2) + candidate.Add(candidate, primeOrder) + + candidateBuf := candidate.Bytes() + for i, j := 0, len(candidateBuf)-1; i < j; i, j = i+1, j-1 { + candidateBuf[i], candidateBuf[j] = candidateBuf[j], candidateBuf[i] + } + + expected := []bool{true, true, false, false} + scalar := scalar{} + + // We check in range [L-2, L+4) + for i := 0; i < 4; i++ { + require.Equal(t, expected[i], scalar.IsCanonical(candidateBuf), fmt.Sprintf("`lMinus2 + %d` does not pass canonicality test", i)) + candidateBuf[0]++ + } +} diff --git a/sign/eddsa/eddsa.go b/sign/eddsa/eddsa.go index e7bfbf135..0e8c9633f 100644 --- a/sign/eddsa/eddsa.go +++ b/sign/eddsa/eddsa.go @@ -7,7 +7,6 @@ import ( "crypto/sha512" "errors" "fmt" - "math/big" "go.dedis.ch/kyber/v3" "go.dedis.ch/kyber/v3/group/edwards25519" @@ -15,11 +14,6 @@ import ( var group = new(edwards25519.Curve) -// TODO: maybe export prime and primeOrder from edwards25519/const or allow it to be -// retrieved from the curve? -var prime, _ = new(big.Int).SetString("57896044618658097711785492504343953926634992332820282019728792003956564819949", 10) -var primeOrder, _ = new(big.Int).SetString("7237005577332262213973186563042994240857116359379907606001950938285454250989", 10) - // EdDSA is a structure holding the data necessary to make a series of // EdDSA signatures. type EdDSA struct { @@ -32,15 +26,6 @@ type EdDSA struct { prefix []byte } -// edDSAPoint is used to verify signatures -// with checks around canonicality and group order -type edDSAPoint interface { - kyber.Point - // HasSmallOrder checks if the given buffer (in little endian) - // represents a point with a small order - HasSmallOrder() bool -} - // NewEdDSA will return a freshly generated key pair to use for generating // EdDSA signatures. func NewEdDSA(stream cipher.Stream) *EdDSA { @@ -143,21 +128,28 @@ func VerifyWithChecks(pub, msg, sig []byte) error { if len(sig) != 64 { return fmt.Errorf("signature length invalid, expect 64 but got %v", len(sig)) } - if !scalarIsCanonical(sig[32:]) { + + type scalarCanCheckCanonical interface { + IsCanonical(b []byte) bool + } + + if !group.Scalar().(scalarCanCheckCanonical).IsCanonical(sig[32:]) { return fmt.Errorf("signature is not canonical") } - if !pointIsCanonical(pub) { - return fmt.Errorf("public key is not canonical") + + type pointCanCheckCanonicalAndSmallOrder interface { + HasSmallOrder() bool + IsCanonical(b []byte) bool } - if !pointIsCanonical(sig[:32]) { + R := group.Point() + if !R.(pointCanCheckCanonicalAndSmallOrder).IsCanonical(sig[:32]) { return fmt.Errorf("R is not canonical") } - R := group.Point() if err := R.UnmarshalBinary(sig[:32]); err != nil { return fmt.Errorf("got R invalid point: %s", err) } - if R.(edDSAPoint).HasSmallOrder() { + if R.(pointCanCheckCanonicalAndSmallOrder).HasSmallOrder() { return fmt.Errorf("R has small order") } @@ -167,10 +159,13 @@ func VerifyWithChecks(pub, msg, sig []byte) error { } public := group.Point() + if !public.(pointCanCheckCanonicalAndSmallOrder).IsCanonical(pub) { + return fmt.Errorf("public key is not canonical") + } if err := public.UnmarshalBinary(pub); err != nil { return fmt.Errorf("invalid public key: %s", err) } - if public.(edDSAPoint).HasSmallOrder() { + if public.(pointCanCheckCanonicalAndSmallOrder).HasSmallOrder() { return fmt.Errorf("public key has small order") } @@ -201,59 +196,3 @@ func Verify(public kyber.Point, msg, sig []byte) error { } return VerifyWithChecks(PBuf, msg, sig) } - -// scalarIsCanonical whether scalar s is in the range 0<=s= 0; i-- { - // subtraction might lead to an underflow which needs - // to be accounted for in the right shift - c |= byte((uint16(sb[i])-uint16(L[i]))>>8) & n - n &= byte((uint16(sb[i]) ^ uint16(L[i]) - 1) >> 8) - } - - return c != 0 -} - -// pointIsCanonical determines whether the group element is canonical -// -// Checks whether group element s is less than p, according to RFC8032§5.1.3.1 -// https://tools.ietf.org/html/rfc8032#section-5.1.3 -// -// Taken from -// https://github.com/jedisct1/libsodium/blob/4744636721d2e420f8bbe2d563f31b1f5e682229/src/libsodium/crypto_core/ed25519/ref10/ed25519_ref10.c#L1113 -func pointIsCanonical(s []byte) bool { - if len(s) != 32 { - return false - } - - c := (s[31] & 0x7f) ^ 0x7f - for i := 30; i > 0; i-- { - c |= s[i] ^ 0xff - } - - // subtraction might underflow - c = byte((uint16(c) - 1) >> 8) - d := byte((0xed - 1 - uint16(s[0])) >> 8) - - return 1-(c&d&1) == 1 -} diff --git a/sign/eddsa/eddsa_test.go b/sign/eddsa/eddsa_test.go index 126139e76..2f7b88334 100644 --- a/sign/eddsa/eddsa_test.go +++ b/sign/eddsa/eddsa_test.go @@ -6,8 +6,6 @@ import ( "compress/gzip" "crypto/cipher" "encoding/hex" - "fmt" - "math/big" "math/rand" "os" "strings" @@ -346,60 +344,3 @@ func TestGolden(t *testing.T) { t.Fatalf("error reading test data: %s", err) } } - -// Test_pointIsCanonical ensures that elements >= p are considered -// non canonical -func Test_pointIsCanonical(t *testing.T) { - - // buffer stores the candidate points (in little endian) that we'll test - // against, starting with `prime` - buffer := prime.Bytes() - for i, j := 0, len(buffer)-1; i < j; i, j = i+1, j-1 { - buffer[i], buffer[j] = buffer[j], buffer[i] - } - - // Iterate over the 19*2 finite field elements - point := group.Point() - actualNonCanonicalCount := 0 - expectedNonCanonicalCount := 24 - for i := 0; i < 19; i++ { - buffer[0] = byte(237 + i) - buffer[31] = byte(127) - - // Check if it's a valid point on the curve that's - // not canonical - err := point.UnmarshalBinary(buffer) - if err == nil && !pointIsCanonical(buffer) { - actualNonCanonicalCount++ - } - - // flip bit - buffer[31] |= 128 - - // Check if it's a valid point on the curve that's - // not canonical - err = point.UnmarshalBinary(buffer) - if err == nil && !pointIsCanonical(buffer) { - actualNonCanonicalCount++ - } - } - require.Equal(t, expectedNonCanonicalCount, actualNonCanonicalCount, "Incorrect number of non canonical points detected") -} - -func Test_scalarIsCanonical(t *testing.T) { - candidate := big.NewInt(-2) - candidate.Add(candidate, primeOrder) - - candidateBuf := candidate.Bytes() - for i, j := 0, len(candidateBuf)-1; i < j; i, j = i+1, j-1 { - candidateBuf[i], candidateBuf[j] = candidateBuf[j], candidateBuf[i] - } - - expected := []bool{true, true, false, false} - - // We check in range [L-2, L+4) - for i := 0; i < 4; i++ { - require.Equal(t, expected[i], scalarIsCanonical(candidateBuf), fmt.Sprintf("`lMinus2 + %d` does not pass canonicality test", i)) - candidateBuf[0]++ - } -} diff --git a/sign/schnorr/schnorr.go b/sign/schnorr/schnorr.go index 9d0dae6b8..969cb9b01 100644 --- a/sign/schnorr/schnorr.go +++ b/sign/schnorr/schnorr.go @@ -18,8 +18,6 @@ import ( "fmt" "go.dedis.ch/kyber/v3" - "go.dedis.ch/kyber/v3/group/edwards25519" - "go.dedis.ch/kyber/v3/sign/eddsa" ) // Suite represents the set of functionalities needed by the package schnorr. @@ -65,8 +63,13 @@ func Sign(s Suite, private kyber.Scalar, msg []byte) ([]byte, error) { // additional checks around the canonicality and ensures the public key // does not have a small order when using `edwards25519` group. func VerifyWithChecks(g kyber.Group, pub, msg, sig []byte) error { - if _, ok := g.(*edwards25519.SuiteEd25519); ok { - return eddsa.VerifyWithChecks(pub, msg, sig) + type scalarCanCheckCanonical interface { + IsCanonical(b []byte) bool + } + + type pointCanCheckCanonicalAndSmallOrder interface { + HasSmallOrder() bool + IsCanonical(b []byte) bool } R := g.Point() @@ -80,6 +83,17 @@ func VerifyWithChecks(g kyber.Group, pub, msg, sig []byte) error { if err := R.UnmarshalBinary(sig[:pointSize]); err != nil { return err } + if p, ok := R.(pointCanCheckCanonicalAndSmallOrder); ok { + if !p.IsCanonical(sig[:pointSize]) { + return fmt.Errorf("R is not canonical") + } + if p.HasSmallOrder() { + return fmt.Errorf("R has small order") + } + } + if s, ok := g.Scalar().(scalarCanCheckCanonical); ok && !s.IsCanonical(sig[pointSize:]) { + return fmt.Errorf("signature is not canonical") + } if err := s.UnmarshalBinary(sig[pointSize:]); err != nil { return err } @@ -89,6 +103,14 @@ func VerifyWithChecks(g kyber.Group, pub, msg, sig []byte) error { if err != nil { return fmt.Errorf("schnorr: error unmarshalling public key") } + if p, ok := public.(pointCanCheckCanonicalAndSmallOrder); ok { + if !p.IsCanonical(pub) { + return fmt.Errorf("public key is not canonical") + } + if p.HasSmallOrder() { + return fmt.Errorf("public key has small order") + } + } // recompute hash(public || R || msg) h, err := hash(g, public, R, msg) if err != nil {