diff --git a/Cargo.lock b/Cargo.lock index a5540d1d..748854e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,6 +145,7 @@ dependencies = [ name = "salsa20" version = "0.11.0-pre" dependencies = [ + "cfg-if", "cipher", "hex-literal", ] diff --git a/salsa20/Cargo.toml b/salsa20/Cargo.toml index 5307bf3a..7d611c3d 100644 --- a/salsa20/Cargo.toml +++ b/salsa20/Cargo.toml @@ -13,6 +13,7 @@ keywords = ["crypto", "stream-cipher", "trait", "xsalsa20"] categories = ["cryptography", "no-std"] [dependencies] +cfg-if = "1" cipher = "=0.5.0-pre.4" [dev-dependencies] diff --git a/salsa20/src/backends.rs b/salsa20/src/backends.rs new file mode 100644 index 00000000..49f13ee5 --- /dev/null +++ b/salsa20/src/backends.rs @@ -0,0 +1,20 @@ +use cfg_if::cfg_if; + +cfg_if! { + if #[cfg(salsa20_force_soft)] { + pub(crate) mod soft; + } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + cfg_if! { + if #[cfg(salsa20_force_sse2)] { + pub(crate) mod sse2; + } else if #[cfg(salsa20_force_soft)] { + pub(crate) mod soft; + } else { + pub(crate) mod sse2; + pub(crate) mod soft; + } + } + } else { + pub(crate) mod soft; + } +} diff --git a/salsa20/src/backends/soft.rs b/salsa20/src/backends/soft.rs new file mode 100644 index 00000000..c7c2a91c --- /dev/null +++ b/salsa20/src/backends/soft.rs @@ -0,0 +1,70 @@ +//! Portable implementation which does not rely on architecture-specific +//! intrinsics. + +use crate::{Block, SalsaCore, Unsigned, STATE_WORDS}; +use cipher::{ + consts::{U1, U64}, + BlockSizeUser, ParBlocksSizeUser, StreamBackend, StreamCipherSeekCore, +}; + +pub(crate) struct Backend<'a, R: Unsigned>(pub(crate) &'a mut SalsaCore); + +impl<'a, R: Unsigned> BlockSizeUser for Backend<'a, R> { + type BlockSize = U64; +} + +impl<'a, R: Unsigned> ParBlocksSizeUser for Backend<'a, R> { + type ParBlocksSize = U1; +} + +impl<'a, R: Unsigned> StreamBackend for Backend<'a, R> { + #[inline(always)] + fn gen_ks_block(&mut self, block: &mut Block) { + let res = run_rounds::(&self.0.state); + + self.0.set_block_pos(self.0.get_block_pos() + 1); + + for (chunk, val) in block.chunks_exact_mut(4).zip(res.iter()) { + chunk.copy_from_slice(&val.to_le_bytes()); + } + } +} + +#[inline] +#[allow(clippy::many_single_char_names)] +pub(crate) fn quarter_round( + a: usize, + b: usize, + c: usize, + d: usize, + state: &mut [u32; STATE_WORDS], +) { + state[b] ^= state[a].wrapping_add(state[d]).rotate_left(7); + state[c] ^= state[b].wrapping_add(state[a]).rotate_left(9); + state[d] ^= state[c].wrapping_add(state[b]).rotate_left(13); + state[a] ^= state[d].wrapping_add(state[c]).rotate_left(18); +} + +#[inline(always)] +fn run_rounds(state: &[u32; STATE_WORDS]) -> [u32; STATE_WORDS] { + let mut res = *state; + + for _ in 0..R::USIZE { + // column rounds + quarter_round(0, 4, 8, 12, &mut res); + quarter_round(5, 9, 13, 1, &mut res); + quarter_round(10, 14, 2, 6, &mut res); + quarter_round(15, 3, 7, 11, &mut res); + + // diagonal rounds + quarter_round(0, 1, 2, 3, &mut res); + quarter_round(5, 6, 7, 4, &mut res); + quarter_round(10, 11, 8, 9, &mut res); + quarter_round(15, 12, 13, 14, &mut res); + } + + for (s1, s0) in res.iter_mut().zip(state.iter()) { + *s1 = s1.wrapping_add(*s0); + } + res +} diff --git a/salsa20/src/backends/sse2.rs b/salsa20/src/backends/sse2.rs new file mode 100644 index 00000000..3e0199a8 --- /dev/null +++ b/salsa20/src/backends/sse2.rs @@ -0,0 +1,156 @@ +use crate::{Block, StreamClosure, Unsigned, STATE_WORDS}; +use cipher::{ + consts::{U1, U64}, + BlockSizeUser, ParBlocksSizeUser, StreamBackend, +}; +use core::marker::PhantomData; + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +#[inline] +#[target_feature(enable = "sse2")] +pub(crate) unsafe fn inner(state: &mut [u32; STATE_WORDS], f: F) +where + R: Unsigned, + F: StreamClosure, +{ + let state_ptr = state.as_ptr() as *const __m128i; + let mut backend = Backend:: { + v: [ + _mm_loadu_si128(state_ptr.add(0)), + _mm_loadu_si128(state_ptr.add(1)), + _mm_loadu_si128(state_ptr.add(2)), + _mm_loadu_si128(state_ptr.add(3)), + ], + _pd: PhantomData, + }; + + f.call(&mut backend); + + state[8] = _mm_cvtsi128_si32(backend.v[2]) as u32; +} + +struct Backend { + v: [__m128i; 4], + _pd: PhantomData, +} + +impl BlockSizeUser for Backend { + type BlockSize = U64; +} + +impl ParBlocksSizeUser for Backend { + type ParBlocksSize = U1; +} + +impl StreamBackend for Backend { + #[inline(always)] + fn gen_ks_block(&mut self, block: &mut Block) { + unsafe { + let res = rounds::(&self.v); + + self.v[2] = _mm_add_epi32(self.v[2], _mm_set_epi32(0, 0, 0, 1)); + let block_ptr = block.as_mut_ptr() as *mut __m128i; + + for (i, v) in res.iter().enumerate() { + _mm_storeu_si128(block_ptr.add(i), *v); + } + } + } +} + +#[inline] +#[target_feature(enable = "sse2")] +unsafe fn rounds(v: &[__m128i; 4]) -> [__m128i; 4] { + let mut res = *v; + + for _ in 0..R::USIZE { + double_round(&mut res); + } + + for i in 0..4 { + res[i] = _mm_add_epi32(res[i], v[i]); + } + + transpose(&mut res); + res[1] = _mm_shuffle_epi32(res[1], 0b_10_01_00_11); + res[2] = _mm_shuffle_epi32(res[2], 0b_01_00_11_10); + res[3] = _mm_shuffle_epi32(res[3], 0b_00_11_10_01); + transpose(&mut res); + + res +} + +/// The Salsa20 doubleround function for SSE2. +/// +/// https://users.rust-lang.org/t/can-the-compiler-infer-sse-instructions/59976 +#[inline] +#[target_feature(enable = "sse2")] +unsafe fn double_round([a, b, c, d]: &mut [__m128i; 4]) { + let mut t_sum: __m128i; + let mut t_rotl: __m128i; + + // Operate on "columns" + t_sum = _mm_add_epi32(*a, *d); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 7), _mm_srli_epi32(t_sum, 25)); + *b = _mm_xor_si128(*b, t_rotl); + + t_sum = _mm_add_epi32(*b, *a); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 9), _mm_srli_epi32(t_sum, 23)); + *c = _mm_xor_si128(*c, t_rotl); + + t_sum = _mm_add_epi32(*c, *b); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 13), _mm_srli_epi32(t_sum, 19)); + *d = _mm_xor_si128(*d, t_rotl); + + t_sum = _mm_add_epi32(*d, *c); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 18), _mm_srli_epi32(t_sum, 14)); + *a = _mm_xor_si128(*a, t_rotl); + + // Rearrange data. + *b = _mm_shuffle_epi32(*b, 0b_10_01_00_11); + *c = _mm_shuffle_epi32(*c, 0b_01_00_11_10); + *d = _mm_shuffle_epi32(*d, 0b_00_11_10_01); + + // Operate on "rows". + t_sum = _mm_add_epi32(*a, *b); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 7), _mm_srli_epi32(t_sum, 25)); + *d = _mm_xor_si128(*d, t_rotl); + + t_sum = _mm_add_epi32(*d, *a); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 9), _mm_srli_epi32(t_sum, 23)); + *c = _mm_xor_si128(*c, t_rotl); + + t_sum = _mm_add_epi32(*c, *d); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 13), _mm_srli_epi32(t_sum, 19)); + *b = _mm_xor_si128(*b, t_rotl); + + t_sum = _mm_add_epi32(*b, *c); + t_rotl = _mm_xor_si128(_mm_slli_epi32(t_sum, 18), _mm_srli_epi32(t_sum, 14)); + *a = _mm_xor_si128(*a, t_rotl); + + // Rearrange data. + *b = _mm_shuffle_epi32(*b, 0b_00_11_10_01); + *c = _mm_shuffle_epi32(*c, 0b_01_00_11_10); + *d = _mm_shuffle_epi32(*d, 0b_10_01_00_11); +} + +/// Transpose an integer 4 by 4 matrix in SSE2. +/// +/// https://randombit.net/bitbashing/posts/integer_matrix_transpose_in_sse2.html +#[inline] +#[target_feature(enable = "sse2")] +unsafe fn transpose([a, b, c, d]: &mut [__m128i; 4]) { + let t0 = _mm_unpacklo_epi32(*a, *b); + let t1 = _mm_unpacklo_epi32(*c, *d); + let t2 = _mm_unpackhi_epi32(*a, *b); + let t3 = _mm_unpackhi_epi32(*c, *d); + + *a = _mm_unpacklo_epi64(t0, t1); + *b = _mm_unpackhi_epi64(t0, t1); + *c = _mm_unpacklo_epi64(t2, t3); + *d = _mm_unpackhi_epi64(t2, t3); +} diff --git a/salsa20/src/lib.rs b/salsa20/src/lib.rs index 45bc9f0d..4e068220 100644 --- a/salsa20/src/lib.rs +++ b/salsa20/src/lib.rs @@ -61,6 +61,21 @@ //! assert_eq!(buffer, ciphertext); //! ``` //! +//! # Configuration Flags +//! +//! You can modify crate using the following configuration flags: +//! +//! - `salsa20_force_soft`: force software backend. +//! - `salsa20_force_sse2`: force SSE2 backend on x86/x86_64 targets. +//! Requires enabled SSE2 target feature. Ignored on non-x86(-64) targets. +//! +//! Salsa20 will run the SSE2 backend in x86(-64) targets unless `salsa20_force_soft` is set. +//! +//! The flags can be enabled using `RUSTFLAGS` environmental variable +//! (e.g. `RUSTFLAGS="--cfg salsa20_force_sse2"`) or by modifying `.cargo/config`. +//! +//! You SHOULD NOT enable several `force` flags simultaneously. +//! //! [Salsa]: https://en.wikipedia.org/wiki/Salsa20 #![no_std] @@ -72,20 +87,21 @@ )] #![warn(missing_docs, rust_2018_idioms, trivial_casts, unused_qualifications)] +use cfg_if::cfg_if; pub use cipher; use cipher::{ array::{typenum::Unsigned, Array}, - consts::{U1, U10, U24, U32, U4, U6, U64, U8}, - Block, BlockSizeUser, IvSizeUser, KeyIvInit, KeySizeUser, ParBlocksSizeUser, StreamBackend, - StreamCipherCore, StreamCipherCoreWrapper, StreamCipherSeekCore, StreamClosure, + consts::{U10, U24, U32, U4, U6, U64, U8}, + Block, BlockSizeUser, IvSizeUser, KeyIvInit, KeySizeUser, StreamCipherCore, + StreamCipherCoreWrapper, StreamCipherSeekCore, StreamClosure, }; use core::marker::PhantomData; #[cfg(feature = "zeroize")] use cipher::zeroize::{Zeroize, ZeroizeOnDrop}; -//mod backends; +mod backends; mod xsalsa; pub use xsalsa::{hsalsa, XSalsa12, XSalsa20, XSalsa8, XSalsaCore}; @@ -175,6 +191,19 @@ impl KeyIvInit for SalsaCore { state[15] = CONSTANTS[3]; + cfg_if! { + if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + #[cfg(not(salsa20_force_soft))] { + state = [ + state[0], state[5], state[10], state[15], + state[4], state[9], state[14], state[3], + state[8], state[13], state[2], state[7], + state[12], state[1], state[6], state[11], + ]; + } + } + } + Self { state, rounds: PhantomData, @@ -189,7 +218,23 @@ impl StreamCipherCore for SalsaCore { rem.try_into().ok() } fn process_with_backend(&mut self, f: impl StreamClosure) { - f.call(&mut Backend(self)); + cfg_if! { + if #[cfg(salsa20_force_soft)] { + f.call(&mut backends::soft::Backend(self)); + } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + cfg_if! { + if #[cfg(not(salsa20_force_soft))] { + unsafe { + backends::sse2::inner::(&mut self.state, f); + } + } else { + f.call(&mut backends::soft::Backend(self)); + } + } + } else { + f.call(&mut backends::soft::Backend(self)); + } + } } } @@ -198,13 +243,12 @@ impl StreamCipherSeekCore for SalsaCore { #[inline(always)] fn get_block_pos(&self) -> u64 { - (self.state[8] as u64) + ((self.state[9] as u64) << 32) + self.state[8] as u64 } #[inline(always)] fn set_block_pos(&mut self, pos: u64) { - self.state[8] = (pos & 0xffff_ffff) as u32; - self.state[9] = ((pos >> 32) & 0xffff_ffff) as u32; + self.state[8] = pos as u32; } } @@ -219,64 +263,3 @@ impl Drop for SalsaCore { #[cfg(feature = "zeroize")] #[cfg_attr(docsrs, doc(cfg(feature = "zeroize")))] impl ZeroizeOnDrop for SalsaCore {} - -struct Backend<'a, R: Unsigned>(&'a mut SalsaCore); - -impl<'a, R: Unsigned> BlockSizeUser for Backend<'a, R> { - type BlockSize = U64; -} - -impl<'a, R: Unsigned> ParBlocksSizeUser for Backend<'a, R> { - type ParBlocksSize = U1; -} - -impl<'a, R: Unsigned> StreamBackend for Backend<'a, R> { - #[inline(always)] - fn gen_ks_block(&mut self, block: &mut Block) { - let res = run_rounds::(&self.0.state); - self.0.set_block_pos(self.0.get_block_pos() + 1); - - for (chunk, val) in block.chunks_exact_mut(4).zip(res.iter()) { - chunk.copy_from_slice(&val.to_le_bytes()); - } - } -} - -#[inline] -#[allow(clippy::many_single_char_names)] -pub(crate) fn quarter_round( - a: usize, - b: usize, - c: usize, - d: usize, - state: &mut [u32; STATE_WORDS], -) { - state[b] ^= state[a].wrapping_add(state[d]).rotate_left(7); - state[c] ^= state[b].wrapping_add(state[a]).rotate_left(9); - state[d] ^= state[c].wrapping_add(state[b]).rotate_left(13); - state[a] ^= state[d].wrapping_add(state[c]).rotate_left(18); -} - -#[inline(always)] -fn run_rounds(state: &[u32; STATE_WORDS]) -> [u32; STATE_WORDS] { - let mut res = *state; - - for _ in 0..R::USIZE { - // column rounds - quarter_round(0, 4, 8, 12, &mut res); - quarter_round(5, 9, 13, 1, &mut res); - quarter_round(10, 14, 2, 6, &mut res); - quarter_round(15, 3, 7, 11, &mut res); - - // diagonal rounds - quarter_round(0, 1, 2, 3, &mut res); - quarter_round(5, 6, 7, 4, &mut res); - quarter_round(10, 11, 8, 9, &mut res); - quarter_round(15, 12, 13, 14, &mut res); - } - - for (s1, s0) in res.iter_mut().zip(state.iter()) { - *s1 = s1.wrapping_add(*s0); - } - res -} diff --git a/salsa20/src/xsalsa.rs b/salsa20/src/xsalsa.rs index 6316972b..fc8659a7 100644 --- a/salsa20/src/xsalsa.rs +++ b/salsa20/src/xsalsa.rs @@ -1,6 +1,6 @@ //! XSalsa20 is an extended nonce variant of Salsa20 -use super::{quarter_round, Key, Nonce, SalsaCore, Unsigned, XNonce, CONSTANTS}; +use super::{Key, Nonce, SalsaCore, Unsigned, XNonce, CONSTANTS, STATE_WORDS}; use cipher::{ array::Array, consts::{U10, U16, U24, U32, U4, U6, U64}, @@ -136,3 +136,18 @@ pub fn hsalsa(key: &Key, input: &Array) -> Array output } + +/// The Salsa20 quarter round function +// for simplicity this function is copied from the software backend +pub(crate) fn quarter_round( + a: usize, + b: usize, + c: usize, + d: usize, + state: &mut [u32; STATE_WORDS], +) { + state[b] ^= state[a].wrapping_add(state[d]).rotate_left(7); + state[c] ^= state[b].wrapping_add(state[a]).rotate_left(9); + state[d] ^= state[c].wrapping_add(state[b]).rotate_left(13); + state[a] ^= state[d].wrapping_add(state[c]).rotate_left(18); +} diff --git a/salsa20/tests/mod.rs b/salsa20/tests/mod.rs index 46dc2681..9388aedb 100644 --- a/salsa20/tests/mod.rs +++ b/salsa20/tests/mod.rs @@ -175,23 +175,3 @@ fn xsalsa20_encrypt_hello_world() { assert_eq!(buf, EXPECTED_XSALSA20_HELLO_WORLD); } - -#[test] -fn salsa20_regression_2024_03() { - use salsa20::{ - cipher::{typenum::U4, StreamCipherCore}, - SalsaCore, - }; - - type Salsa20_8 = SalsaCore; - - let mut x : [u8; 64] = hex!("8dcf83fa131d44aaa4241dc58a86d0851d5cb1815e05cc0b8da1f4a39b2ef6a5db2f2bec267136a57a78930da84da1e1984baeb30aca20642c4da8a4cb42fb4f"); - let t2: [u32; 16] = [ - 2123785505, 879699904, 959334342, 2115744216, 477309436, 1153321713, 2181596049, 488300870, - 1113186107, 4152514392, 2202170644, 2028366353, 2177718219, 2842602797, 3038675742, - 1716559436, - ]; - Salsa20_8::from_raw_state(t2).write_keystream_block((&mut x).into()); - - assert_eq!(x, hex!("66a3d4a32f86eb8eaefe5aa25cb5ff1aac91177dd03f114979d042f15658a505035b90d1559f1dd0c2ceaf3014129729fdd697cf94d16116588b271cd03d9b42")); -}