diff --git a/cng/rc4.go b/cng/rc4.go index e0d4507..f484a3e 100644 --- a/cng/rc4.go +++ b/cng/rc4.go @@ -47,6 +47,10 @@ func (c *RC4Cipher) XORKeyStream(dst, src []byte) { if c.kh == 0 || len(src) == 0 { return } + // rc4.Cipher.XORKeyStream throws an out of bounds panic if + // dst is smaller than src. Replicate the same behavior here. + _ = dst[len(src)-1] + if subtle.InexactOverlap(dst[:len(src)], src) { panic("crypto/rc4: invalid buffer overlap") } diff --git a/cng/rc4_test.go b/cng/rc4_test.go index 3920b15..8e70bfa 100644 --- a/cng/rc4_test.go +++ b/cng/rc4_test.go @@ -8,10 +8,12 @@ package cng_test import ( "bytes" + "crypto/cipher" "fmt" "testing" "github.com/microsoft/go-crypto-winnative/cng" + "github.com/microsoft/go-crypto-winnative/internal/cryptotest" ) type rc4Test struct { @@ -138,6 +140,13 @@ func TestRC4Block(t *testing.T) { } } +func TestRC4Stream(t *testing.T) { + cryptotest.TestStream(t, func() cipher.Stream { + c, _ := cng.NewRC4Cipher(golden[0].key) + return c + }) +} + func benchmarkRC4(b *testing.B, size int64) { buf := make([]byte, size) c, err := cng.NewRC4Cipher(golden[0].key) diff --git a/internal/cryptotest/block.go b/internal/cryptotest/block.go index 9aea689..a1c3bd2 100644 --- a/internal/cryptotest/block.go +++ b/internal/cryptotest/block.go @@ -10,8 +10,6 @@ import ( "testing" ) -// This file is a copy of https://github.com/golang/go/blob/9e9b1f57c26a6d13fdaebef67136718b8042cdba/src/crypto/internal/cryptotest/block.go. - type MakeBlock func(key []byte) (cipher.Block, error) // TestBlock performs a set of tests on cipher.Block implementations, checking diff --git a/internal/cryptotest/stream.go b/internal/cryptotest/stream.go new file mode 100644 index 0000000..fb9c553 --- /dev/null +++ b/internal/cryptotest/stream.go @@ -0,0 +1,241 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cryptotest + +import ( + "bytes" + "crypto/cipher" + "crypto/subtle" + "fmt" + "strings" + "testing" +) + +// Each test is executed with each of the buffer lengths in bufLens. +var ( + bufLens = []int{0, 1, 3, 4, 8, 10, 15, 16, 20, 32, 50, 4096, 5000} + bufCap = 10000 +) + +// MakeStream returns a cipher.Stream instance. +// +// Multiple calls to MakeStream must return equivalent instances, +// so for example the key and/or IV must be fixed. +type MakeStream func() cipher.Stream + +// TestStream performs a set of tests on cipher.Stream implementations, +// checking the documented requirements of XORKeyStream. +func TestStream(t *testing.T, ms MakeStream) { + + t.Run("XORSemantics", func(t *testing.T) { + if strings.Contains(t.Name(), "TestCFBStream") { + // This is ugly, but so is CFB's abuse of cipher.Stream. + // Don't want to make it easier for anyone else to do that. + t.Skip("CFB implements cipher.Stream but does not follow XOR semantics") + } + + // Test that XORKeyStream inverts itself for encryption/decryption. + t.Run("Roundtrip", func(t *testing.T) { + + for _, length := range bufLens { + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + rng := newRandReader(t) + + plaintext := make([]byte, length) + rng.Read(plaintext) + + ciphertext := make([]byte, length) + decrypted := make([]byte, length) + + ms().XORKeyStream(ciphertext, plaintext) // Encrypt plaintext + ms().XORKeyStream(decrypted, ciphertext) // Decrypt ciphertext + if !bytes.Equal(decrypted, plaintext) { + t.Errorf("plaintext is different after an encrypt/decrypt cycle; got %s, want %s", truncateHex(decrypted), truncateHex(plaintext)) + } + }) + } + }) + + // Test that XORKeyStream behaves the same as directly XORing + // plaintext with the stream. + t.Run("DirectXOR", func(t *testing.T) { + + for _, length := range bufLens { + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + rng := newRandReader(t) + + plaintext := make([]byte, length) + rng.Read(plaintext) + + // Encrypting all zeros should reveal the stream itself + stream, directXOR := make([]byte, length), make([]byte, length) + ms().XORKeyStream(stream, stream) + // Encrypt plaintext by directly XORing the stream + subtle.XORBytes(directXOR, stream, plaintext) + + // Encrypt plaintext with XORKeyStream + ciphertext := make([]byte, length) + ms().XORKeyStream(ciphertext, plaintext) + if !bytes.Equal(ciphertext, directXOR) { + t.Errorf("xor semantics were not preserved; got %s, want %s", truncateHex(ciphertext), truncateHex(directXOR)) + } + }) + } + }) + }) + + t.Run("AlterInput", func(t *testing.T) { + rng := newRandReader(t) + src, dst, before := make([]byte, bufCap), make([]byte, bufCap), make([]byte, bufCap) + rng.Read(src) + + for _, length := range bufLens { + + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + copy(before, src) + + ms().XORKeyStream(dst[:length], src[:length]) + if !bytes.Equal(src, before) { + t.Errorf("XORKeyStream modified src; got %s, want %s", truncateHex(src), truncateHex(before)) + } + }) + } + }) + + t.Run("Aliasing", func(t *testing.T) { + rng := newRandReader(t) + + buff, expectedOutput := make([]byte, bufCap), make([]byte, bufCap) + + for _, length := range bufLens { + // Record what output is when src and dst are different + rng.Read(buff) + ms().XORKeyStream(expectedOutput[:length], buff[:length]) + + // Check that the same output is generated when src=dst alias to the same + // memory + ms().XORKeyStream(buff[:length], buff[:length]) + if !bytes.Equal(buff[:length], expectedOutput[:length]) { + t.Errorf("block cipher produced different output when dst = src; got %x, want %x", buff[:length], expectedOutput[:length]) + } + } + }) + + t.Run("OutOfBoundsWrite", func(t *testing.T) { // Issue 21104 + rng := newRandReader(t) + + plaintext := make([]byte, bufCap) + rng.Read(plaintext) + ciphertext := make([]byte, bufCap) + + for _, length := range bufLens { + copy(ciphertext, plaintext) // Reset ciphertext buffer + + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + mustPanic(t, "output smaller than input", func() { ms().XORKeyStream(ciphertext[:length], plaintext) }) + + if !bytes.Equal(ciphertext[length:], plaintext[length:]) { + t.Errorf("XORKeyStream did out of bounds write; got %s, want %s", truncateHex(ciphertext[length:]), truncateHex(plaintext[length:])) + } + }) + } + }) + + t.Run("BufferOverlap", func(t *testing.T) { + rng := newRandReader(t) + + buff := make([]byte, bufCap) + rng.Read(buff) + + for _, length := range bufLens { + if length == 0 || length == 1 { + continue + } + + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + // Make src and dst slices point to same array with inexact overlap + src := buff[:length] + dst := buff[1 : length+1] + mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) }) + + // Only overlap on one byte + src = buff[:length] + dst = buff[length-1 : 2*length-1] + mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) }) + + // src comes after dst with one byte overlap + src = buff[length-1 : 2*length-1] + dst = buff[:length] + mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) }) + }) + } + }) + + t.Run("KeepState", func(t *testing.T) { + rng := newRandReader(t) + + plaintext := make([]byte, bufCap) + rng.Read(plaintext) + ciphertext := make([]byte, bufCap) + + // Make one long call to XORKeyStream + ms().XORKeyStream(ciphertext, plaintext) + + for _, step := range bufLens { + if step == 0 { + continue + } + stepMsg := fmt.Sprintf("step %d: ", step) + + dst := make([]byte, bufCap) + + // Make a bunch of small calls to (stateful) XORKeyStream + stream := ms() + i := 0 + for i+step < len(plaintext) { + stream.XORKeyStream(dst[i:], plaintext[i:i+step]) + i += step + } + stream.XORKeyStream(dst[i:], plaintext[i:]) + + if !bytes.Equal(dst, ciphertext) { + t.Errorf(stepMsg+"successive XORKeyStream calls returned a different result than a single one; got %s, want %s", truncateHex(dst), truncateHex(ciphertext)) + } + } + }) +} + +// TestStreamFromBlock creates a Stream from a cipher.Block used in a +// cipher.BlockMode. It addresses Issue 68377 by checking for a panic when the +// BlockMode uses an IV with incorrect length. +// For a valid IV, it also runs all TestStream tests on the resulting stream. +func TestStreamFromBlock(t *testing.T, block cipher.Block, blockMode func(b cipher.Block, iv []byte) cipher.Stream) { + + t.Run("WrongIVLen", func(t *testing.T) { + t.Skip("see Issue 68377") + + rng := newRandReader(t) + iv := make([]byte, block.BlockSize()+1) + rng.Read(iv) + mustPanic(t, "IV length must equal block size", func() { blockMode(block, iv) }) + }) + + t.Run("BlockModeStream", func(t *testing.T) { + rng := newRandReader(t) + iv := make([]byte, block.BlockSize()) + rng.Read(iv) + + TestStream(t, func() cipher.Stream { return blockMode(block, iv) }) + }) +} + +func truncateHex(b []byte) string { + numVals := 50 + + if len(b) <= numVals { + return fmt.Sprintf("%x", b) + } + return fmt.Sprintf("%x...", b[:numVals]) +}