diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..5e6dec7 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,61 @@ +name: Test + +on: + push: + +env: + CARGO_TERM_COLOR: always + +jobs: + cargo-benches: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + override: true + - name: Install FFTW + run: sudo apt install -y libfftw3-dev + - name: Compile benches + run: cargo bench --no-run + + cargo-tests: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + + steps: + - uses: actions/checkout@v2 + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + override: true + - name: Test debug + run: cargo test + - name: Test debug serialization + run: cargo test --features=serde + - name: Test debug no-std + run: cargo test --no-default-features + + cargo-tests-nightly: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + + steps: + - uses: actions/checkout@v2 + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true + - name: Test debug nightly + run: cargo test --features=nightly diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..d52e61f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,46 @@ +# Check formatting using rustfmt +# and lint with clippy +name: Rustfmt and Clippy check + +on: + push: + +jobs: + rustfmt: + name: rustfmt + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + components: rustfmt + override: true + + - name: Run cargo fmt + uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check + + clippy: + name: clippy + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + components: clippy + override: true + + - uses: actions-rs/clippy-check@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + args: --all-targets --features=serde -- --no-deps -D warnings diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c6675ff --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "concrete-fft" +version = "0.1.0" +edition = "2021" +authors = ["sarah el kazdadi "] +description = "Concrete-FFT is a pure Rust high performance fast Fourier transform library." +readme = "README.md" +repository = "https://github.com/zama-ai/concrete-fft" +license = "BSD-3-Clause-Clear" +homepage = "https://zama.ai/" +keywords = ["fft"] + +[dependencies] +num-complex = "0.4" +dyn-stack = { version = "0.8", default-features = false } +aligned-vec = { version = "0.5", default-features = false } +serde = { version = "1.0", optional = true, default-features = false } + +[features] +default = ["std"] +nightly = [] +std = [] +serde = ["dep:serde", "num-complex/serde"] + +[dev-dependencies] +criterion = "0.3" +rustfft = "6.0" +fftw-sys = { version = "0.6", default-features = false, features = ["system"] } +rand = "0.8" +bincode = "1.3" + +[[bench]] +name = "fft" +harness = false + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--html-in-header", "katex-header.html", "--cfg", "docsrs"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f1ec11c --- /dev/null +++ b/LICENSE @@ -0,0 +1,33 @@ +BSD 3-Clause Clear License + +Copyright © 2022 ZAMA. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation and/or other +materials provided with the distribution. + +3. Neither the name of ZAMA nor the names of its contributors may be used to endorse +or promote products derived from this software without specific prior written permission. + +NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*. +THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive, +free and non-commercial license on all patents filed in its name relating to the open-source +code (the "Patents") for the sole purpose of evaluation, development, research, prototyping +and experimentation. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f3002c6 --- /dev/null +++ b/README.md @@ -0,0 +1,80 @@ +Concrete-FFT is a pure Rust high performance fast Fourier transform library +that processes vectors of sizes that are powers of two. It was made to be used +as a backend in Zama's `concrete` library. + +This library provides two FFT modules: + - The ordered module FFT applies a forward/inverse FFT that takes its input in standard + order, and outputs the result in standard order. For more detail on what the FFT + computes, check the ordered module-level documentation. + - The unordered module FFT applies a forward FFT that takes its input in standard order, + and outputs the result in a certain permuted order that may depend on the FFT plan. On the + other hand, the inverse FFT takes its input in that same permuted order and outputs its result + in standard order. This is useful for cases where the order of the coefficients in the + Fourier domain is not important. An example is using the Fourier transform for vector + convolution. The only operations that are performed in the Fourier domain are elementwise, and + so the order of the coefficients does not affect the results. + +## Features + + - `std` (default): This enables runtime arch detection for accelerated SIMD + instructions, and an FFT plan that measures the various implementations to + choose the fastest one at runtime. + - `nightly`: This enables unstable Rust features to further speed up the FFT, + by enabling AVX512F instructions on CPUs that support them. This feature + requires a nightly Rust + toolchain. + - `serde`: This enables serialization and deserialization functions for the + unordered plan. These allow for data in the Fourier domain to be serialized + from the permuted order to the standard order, and deserialized from the + standard order to the permuted order. This is needed since the inverse + transform must be used with the same plan that computed/deserialized the + forward transform (or more specifically, a plan with the same internal base + FFT size). + +## Example + +```rust +use concrete_fft::c64; +use concrete_fft::ordered::{Plan, Method}; +use dyn_stack::{DynStack, GlobalMemBuffer, ReborrowMut}; +use num_complex::ComplexFloat; +use std::time::Duration; + +const N: usize = 4; +let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); +let mut scratch_memory = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); +let mut stack = DynStack::new(&mut scratch_memory); + +let data = [ + c64::new(1.0, 0.0), + c64::new(2.0, 0.0), + c64::new(3.0, 0.0), + c64::new(4.0, 0.0), +]; + +let mut transformed_fwd = data; +plan.fwd(&mut transformed_fwd, stack.rb_mut()); + +let mut transformed_inv = transformed_fwd; +plan.inv(&mut transformed_inv, stack.rb_mut()); + +for (actual, expected) in transformed_inv.iter().map(|z| z / N as f64).zip(data) { + assert!((expected - actual).abs() < 1e-9); +} +``` + +## Links + + - [Zama](https://www.zama.ai/) + - [Concrete](https://github.com/zama-ai/concrete) + +## License + +This software is distributed under the BSD-3-Clause-Clear license with an +exemption that gives rights to use our patents for research, evaluation and +prototyping purposes, as well as for your personal projects. + +If you want to use Concrete in a commercial product however, you will need to +purchase a separate commercial licence. + +If you have any questions, please contact us at `hello@zama.ai.` diff --git a/benches/fft.rs b/benches/fft.rs new file mode 100644 index 0000000..7d93b6b --- /dev/null +++ b/benches/fft.rs @@ -0,0 +1,190 @@ +use concrete_fft::c64; +use core::ptr::NonNull; +use criterion::{criterion_group, criterion_main, Criterion}; +use dyn_stack::{DynStack, ReborrowMut, StackReq}; + +struct FftwAlloc { + bytes: NonNull, +} + +impl Drop for FftwAlloc { + fn drop(&mut self) { + unsafe { + fftw_sys::fftw_free(self.bytes.as_ptr()); + } + } +} + +impl FftwAlloc { + pub fn new(size_bytes: usize) -> FftwAlloc { + unsafe { + let bytes = fftw_sys::fftw_malloc(size_bytes); + if bytes.is_null() { + use std::alloc::{handle_alloc_error, Layout}; + handle_alloc_error(Layout::from_size_align_unchecked(size_bytes, 1)); + } + FftwAlloc { + bytes: NonNull::new_unchecked(bytes), + } + } + } +} + +pub struct PlanInterleavedC64 { + plan: fftw_sys::fftw_plan, + n: usize, +} + +impl Drop for PlanInterleavedC64 { + fn drop(&mut self) { + unsafe { + fftw_sys::fftw_destroy_plan(self.plan); + } + } +} + +pub enum Sign { + Forward, + Backward, +} + +impl PlanInterleavedC64 { + pub fn new(n: usize, sign: Sign) -> Self { + let size_bytes = n.checked_mul(core::mem::size_of::()).unwrap(); + let src = FftwAlloc::new(size_bytes); + let dst = FftwAlloc::new(size_bytes); + unsafe { + let p = fftw_sys::fftw_plan_dft_1d( + n.try_into().unwrap(), + src.bytes.as_ptr() as _, + dst.bytes.as_ptr() as _, + match sign { + Sign::Forward => fftw_sys::FFTW_FORWARD as _, + Sign::Backward => fftw_sys::FFTW_BACKWARD as _, + }, + fftw_sys::FFTW_MEASURE, + ); + PlanInterleavedC64 { plan: p, n } + } + } + + pub fn print(&self) { + unsafe { + fftw_sys::fftw_print_plan(self.plan); + } + } + + pub fn execute(&self, src: &mut [c64], dst: &mut [c64]) { + assert_eq!(src.len(), self.n); + assert_eq!(dst.len(), self.n); + let src = src.as_mut_ptr(); + let dst = dst.as_mut_ptr(); + unsafe { + use fftw_sys::{fftw_alignment_of, fftw_execute_dft}; + assert_eq!(fftw_alignment_of(src as _), 0); + assert_eq!(fftw_alignment_of(dst as _), 0); + fftw_execute_dft(self.plan, src as _, dst as _); + } + } +} + +pub fn criterion_benchmark(c: &mut Criterion) { + for n in [ + 1 << 8, + 1 << 9, + 1 << 10, + 1 << 11, + 1 << 12, + 1 << 13, + 1 << 14, + 1 << 15, + 1 << 16, + ] { + let mut mem = dyn_stack::GlobalMemBuffer::new( + StackReq::new_aligned::(n, 64) // scratch + .and( + StackReq::new_aligned::(2 * n, 64).or(StackReq::new_aligned::(n, 64)), // src | twiddles + ) + .and(StackReq::new_aligned::(n, 64)), // dst + ); + let mut stack = DynStack::new(&mut mem); + let z = c64::new(0.0, 0.0); + + { + use rustfft::FftPlannerAvx; + let mut planner = FftPlannerAvx::::new().unwrap(); + + let fwd_rustfft = planner.plan_fft_forward(n); + let mut scratch = []; + + let fwd_fftw = PlanInterleavedC64::new(n, Sign::Forward); + + let bench_duration = std::time::Duration::from_millis(10); + let ordered = concrete_fft::ordered::Plan::new( + n, + concrete_fft::ordered::Method::Measure(bench_duration), + ); + let unordered = concrete_fft::unordered::Plan::new( + n, + concrete_fft::unordered::Method::Measure(bench_duration), + ); + + { + let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + let (mut src, _) = stack.make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("rustfft-fwd-{}", n), |b| { + b.iter(|| { + fwd_rustfft.process_outofplace_with_scratch( + &mut src, + &mut dst, + &mut scratch, + ) + }) + }); + + c.bench_function(&format!("fftw-fwd-{}", n), |b| { + b.iter(|| { + fwd_fftw.execute(&mut src, &mut dst); + }) + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("concrete-fwd-{}", n), |b| { + b.iter(|| ordered.fwd(&mut *dst, stack.rb_mut())) + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("unordered-fwd-{}", n), |b| { + b.iter(|| unordered.fwd(&mut dst, stack.rb_mut())); + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("unordered-inv-{}", n), |b| { + b.iter(|| unordered.inv(&mut dst, stack.rb_mut())); + }); + } + } + + // memcpy + { + let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + let (src, _) = stack.make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("memcpy-{}", n), |b| { + b.iter(|| unsafe { + std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), n); + }) + }); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/benches/lib.rs b/benches/lib.rs new file mode 100644 index 0000000..2ee2f80 --- /dev/null +++ b/benches/lib.rs @@ -0,0 +1,3 @@ +#![allow(dead_code)] + +mod fft; diff --git a/katex-header.html b/katex-header.html new file mode 100644 index 0000000..be4a727 --- /dev/null +++ b/katex-header.html @@ -0,0 +1,15 @@ + + + + diff --git a/src/dif16.rs b/src/dif16.rs new file mode 100644 index 0000000..3e3601d --- /dev/null +++ b/src/dif16.rs @@ -0,0 +1,1024 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::dif2::end_2; +use crate::dif4::end_4; +use crate::dif8::end_8; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64Ext, FftSimd64X2, FftSimd64X4, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +unsafe fn core_( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 16; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 16; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + let big_n8 = big_n1 * 8; + let big_n9 = big_n1 * 9; + let big_na = big_n1 * 10; + let big_nb = big_n1 * 11; + let big_nc = big_n1 * 12; + let big_nd = big_n1 * 13; + let big_ne = big_n1 * 14; + let big_nf = big_n1 * 15; + + for p in 0..m { + let sp = s * p; + let s16p = 16 * sp; + + let w1p = I::splat(twid_t(16, big_n, 0x1, w, sp)); + let w2p = I::splat(twid_t(16, big_n, 0x2, w, sp)); + let w3p = I::splat(twid_t(16, big_n, 0x3, w, sp)); + let w4p = I::splat(twid_t(16, big_n, 0x4, w, sp)); + let w5p = I::splat(twid_t(16, big_n, 0x5, w, sp)); + let w6p = I::splat(twid_t(16, big_n, 0x6, w, sp)); + let w7p = I::splat(twid_t(16, big_n, 0x7, w, sp)); + let w8p = I::splat(twid_t(16, big_n, 0x8, w, sp)); + let w9p = I::splat(twid_t(16, big_n, 0x9, w, sp)); + let wap = I::splat(twid_t(16, big_n, 0xa, w, sp)); + let wbp = I::splat(twid_t(16, big_n, 0xb, w, sp)); + let wcp = I::splat(twid_t(16, big_n, 0xc, w, sp)); + let wdp = I::splat(twid_t(16, big_n, 0xd, w, sp)); + let wep = I::splat(twid_t(16, big_n, 0xe, w, sp)); + let wfp = I::splat(twid_t(16, big_n, 0xf, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s16p = y.add(q + s16p); + + let x0 = I::load(xq_sp.add(big_n0)); + let x1 = I::load(xq_sp.add(big_n1)); + let x2 = I::load(xq_sp.add(big_n2)); + let x3 = I::load(xq_sp.add(big_n3)); + let x4 = I::load(xq_sp.add(big_n4)); + let x5 = I::load(xq_sp.add(big_n5)); + let x6 = I::load(xq_sp.add(big_n6)); + let x7 = I::load(xq_sp.add(big_n7)); + let x8 = I::load(xq_sp.add(big_n8)); + let x9 = I::load(xq_sp.add(big_n9)); + let xa = I::load(xq_sp.add(big_na)); + let xb = I::load(xq_sp.add(big_nb)); + let xc = I::load(xq_sp.add(big_nc)); + let xd = I::load(xq_sp.add(big_nd)); + let xe = I::load(xq_sp.add(big_ne)); + let xf = I::load(xq_sp.add(big_nf)); + + let a08 = I::add(x0, x8); + let s08 = I::sub(x0, x8); + let a4c = I::add(x4, xc); + let s4c = I::sub(x4, xc); + let a2a = I::add(x2, xa); + let s2a = I::sub(x2, xa); + let a6e = I::add(x6, xe); + let s6e = I::sub(x6, xe); + let a19 = I::add(x1, x9); + let s19 = I::sub(x1, x9); + let a5d = I::add(x5, xd); + let s5d = I::sub(x5, xd); + let a3b = I::add(x3, xb); + let s3b = I::sub(x3, xb); + let a7f = I::add(x7, xf); + let s7f = I::sub(x7, xf); + + let js4c = I::xpj(fwd, s4c); + let js6e = I::xpj(fwd, s6e); + let js5d = I::xpj(fwd, s5d); + let js7f = I::xpj(fwd, s7f); + + let a08p1a4c = I::add(a08, a4c); + let s08mjs4c = I::sub(s08, js4c); + let a08m1a4c = I::sub(a08, a4c); + let s08pjs4c = I::add(s08, js4c); + let a2ap1a6e = I::add(a2a, a6e); + let s2amjs6e = I::sub(s2a, js6e); + let a2am1a6e = I::sub(a2a, a6e); + let s2apjs6e = I::add(s2a, js6e); + let a19p1a5d = I::add(a19, a5d); + let s19mjs5d = I::sub(s19, js5d); + let a19m1a5d = I::sub(a19, a5d); + let s19pjs5d = I::add(s19, js5d); + let a3bp1a7f = I::add(a3b, a7f); + let s3bmjs7f = I::sub(s3b, js7f); + let a3bm1a7f = I::sub(a3b, a7f); + let s3bpjs7f = I::add(s3b, js7f); + + let w8_s2amjs6e = I::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I::add(s19pjs5d, v8_s3bpjs7f); + + let h1_s19mjs5d_pw_s3bmjs7f = I::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + let w8_a19m1a5d_mj_a3bm1a7f = I::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + let h3_s19pjs5d_mv_s3bpjs7f = I::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + let j_a19p1a5d_m1_a3bp1a7f = I::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + let hd_s19mjs5d_mw_s3bmjs7f = I::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + let v8_a19m1a5d_pj_a3bm1a7f = I::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + let hf_s19pjs5d_pv_s3bpjs7f = I::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + + I::store( + yq_s16p.add(0), + I::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I::store( + yq_s16p.add(s), + I::mul(w1p, I::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f)), + ); + I::store( + yq_s16p.add(s * 0x2), + I::mul(w2p, I::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f)), + ); + I::store( + yq_s16p.add(s * 0x3), + I::mul(w3p, I::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f)), + ); + I::store( + yq_s16p.add(s * 0x4), + I::mul(w4p, I::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f)), + ); + I::store( + yq_s16p.add(s * 0x5), + I::mul(w5p, I::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f)), + ); + I::store( + yq_s16p.add(s * 0x6), + I::mul(w6p, I::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f)), + ); + I::store( + yq_s16p.add(s * 0x7), + I::mul(w7p, I::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f)), + ); + + I::store( + yq_s16p.add(s * 0x8), + I::mul(w8p, I::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f)), + ); + I::store( + yq_s16p.add(s * 0x9), + I::mul(w9p, I::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f)), + ); + I::store( + yq_s16p.add(s * 0xa), + I::mul(wap, I::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f)), + ); + I::store( + yq_s16p.add(s * 0xb), + I::mul(wbp, I::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f)), + ); + I::store( + yq_s16p.add(s * 0xc), + I::mul(wcp, I::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f)), + ); + I::store( + yq_s16p.add(s * 0xd), + I::mul(wdp, I::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f)), + ); + I::store( + yq_s16p.add(s * 0xe), + I::mul(wep, I::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f)), + ); + I::store( + yq_s16p.add(s * 0xf), + I::mul(wfp, I::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f)), + ); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 16; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + let big_n8 = big_n1 * 8; + let big_n9 = big_n1 * 9; + let big_na = big_n1 * 10; + let big_nb = big_n1 * 11; + let big_nc = big_n1 * 12; + let big_nd = big_n1 * 13; + let big_ne = big_n1 * 14; + let big_nf = big_n1 * 15; + + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_16p = y.add(16 * p); + + let x0 = I::load(x_p.add(big_n0)); + let x1 = I::load(x_p.add(big_n1)); + let x2 = I::load(x_p.add(big_n2)); + let x3 = I::load(x_p.add(big_n3)); + let x4 = I::load(x_p.add(big_n4)); + let x5 = I::load(x_p.add(big_n5)); + let x6 = I::load(x_p.add(big_n6)); + let x7 = I::load(x_p.add(big_n7)); + let x8 = I::load(x_p.add(big_n8)); + let x9 = I::load(x_p.add(big_n9)); + let xa = I::load(x_p.add(big_na)); + let xb = I::load(x_p.add(big_nb)); + let xc = I::load(x_p.add(big_nc)); + let xd = I::load(x_p.add(big_nd)); + let xe = I::load(x_p.add(big_ne)); + let xf = I::load(x_p.add(big_nf)); + + let a08 = I::add(x0, x8); + let s08 = I::sub(x0, x8); + let a4c = I::add(x4, xc); + let s4c = I::sub(x4, xc); + let a2a = I::add(x2, xa); + let s2a = I::sub(x2, xa); + let a6e = I::add(x6, xe); + let s6e = I::sub(x6, xe); + let a19 = I::add(x1, x9); + let s19 = I::sub(x1, x9); + let a5d = I::add(x5, xd); + let s5d = I::sub(x5, xd); + let a3b = I::add(x3, xb); + let s3b = I::sub(x3, xb); + let a7f = I::add(x7, xf); + let s7f = I::sub(x7, xf); + + let js4c = I::xpj(fwd, s4c); + let js6e = I::xpj(fwd, s6e); + let js5d = I::xpj(fwd, s5d); + let js7f = I::xpj(fwd, s7f); + + let a08p1a4c = I::add(a08, a4c); + let s08mjs4c = I::sub(s08, js4c); + let a08m1a4c = I::sub(a08, a4c); + let s08pjs4c = I::add(s08, js4c); + let a2ap1a6e = I::add(a2a, a6e); + let s2amjs6e = I::sub(s2a, js6e); + let a2am1a6e = I::sub(a2a, a6e); + let s2apjs6e = I::add(s2a, js6e); + let a19p1a5d = I::add(a19, a5d); + let s19mjs5d = I::sub(s19, js5d); + let a19m1a5d = I::sub(a19, a5d); + let s19pjs5d = I::add(s19, js5d); + let a3bp1a7f = I::add(a3b, a7f); + let s3bmjs7f = I::sub(s3b, js7f); + let a3bm1a7f = I::sub(a3b, a7f); + let s3bpjs7f = I::add(s3b, js7f); + + let w8_s2amjs6e = I::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I::add(s19pjs5d, v8_s3bpjs7f); + + let h1_s19mjs5d_pw_s3bmjs7f = I::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + let w8_a19m1a5d_mj_a3bm1a7f = I::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + let h3_s19pjs5d_mv_s3bpjs7f = I::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + let j_a19p1a5d_m1_a3bp1a7f = I::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + let hd_s19mjs5d_mw_s3bmjs7f = I::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + let v8_a19m1a5d_pj_a3bm1a7f = I::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + let hf_s19pjs5d_pv_s3bpjs7f = I::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + + let w1p = I::load(twid(16, big_n, 1, w, p)); + let w2p = I::load(twid(16, big_n, 2, w, p)); + let w3p = I::load(twid(16, big_n, 3, w, p)); + let w4p = I::load(twid(16, big_n, 4, w, p)); + let w5p = I::load(twid(16, big_n, 5, w, p)); + let w6p = I::load(twid(16, big_n, 6, w, p)); + let w7p = I::load(twid(16, big_n, 7, w, p)); + let w8p = I::load(twid(16, big_n, 8, w, p)); + let w9p = I::load(twid(16, big_n, 9, w, p)); + let wap = I::load(twid(16, big_n, 10, w, p)); + let wbp = I::load(twid(16, big_n, 11, w, p)); + let wcp = I::load(twid(16, big_n, 12, w, p)); + let wdp = I::load(twid(16, big_n, 13, w, p)); + let wep = I::load(twid(16, big_n, 14, w, p)); + let wfp = I::load(twid(16, big_n, 15, w, p)); + + let aa = I::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f); + let bb = I::mul(w1p, I::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f)); + let cc = I::mul(w2p, I::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f)); + let dd = I::mul(w3p, I::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f)); + let ee = I::mul(w4p, I::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f)); + let ff = I::mul(w5p, I::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f)); + let gg = I::mul(w6p, I::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f)); + let hh = I::mul(w7p, I::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f)); + + let ii = I::mul(w8p, I::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f)); + let jj = I::mul(w9p, I::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f)); + let kk = I::mul(wap, I::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f)); + let ll = I::mul(wbp, I::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f)); + let mm = I::mul(wcp, I::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f)); + let nn = I::mul(wdp, I::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f)); + let oo = I::mul(wep, I::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f)); + let pp = I::mul(wfp, I::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f)); + + { + let ab = I::catlo(aa, bb); + I::store(y_16p.add(0x00), ab); + let cd = I::catlo(cc, dd); + I::store(y_16p.add(0x02), cd); + let ef = I::catlo(ee, ff); + I::store(y_16p.add(0x04), ef); + let gh = I::catlo(gg, hh); + I::store(y_16p.add(0x06), gh); + let ij = I::catlo(ii, jj); + I::store(y_16p.add(0x08), ij); + let kl = I::catlo(kk, ll); + I::store(y_16p.add(0x0a), kl); + let mn = I::catlo(mm, nn); + I::store(y_16p.add(0x0c), mn); + let op = I::catlo(oo, pp); + I::store(y_16p.add(0x0e), op); + } + { + let ab = I::cathi(aa, bb); + I::store(y_16p.add(0x10), ab); + let cd = I::cathi(cc, dd); + I::store(y_16p.add(0x12), cd); + let ef = I::cathi(ee, ff); + I::store(y_16p.add(0x14), ef); + let gh = I::cathi(gg, hh); + I::store(y_16p.add(0x16), gh); + let ij = I::cathi(ii, jj); + I::store(y_16p.add(0x18), ij); + let kl = I::cathi(kk, ll); + I::store(y_16p.add(0x1a), kl); + let mn = I::cathi(mm, nn); + I::store(y_16p.add(0x1c), mn); + let op = I::cathi(oo, pp); + I::store(y_16p.add(0x1e), op); + } + + p += 2; + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + if n == 32 { + return core_x2::(fwd, n, s, x, y, w); + } + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 16; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + let big_n8 = big_n1 * 8; + let big_n9 = big_n1 * 9; + let big_na = big_n1 * 10; + let big_nb = big_n1 * 11; + let big_nc = big_n1 * 12; + let big_nd = big_n1 * 13; + let big_ne = big_n1 * 14; + let big_nf = big_n1 * 15; + + debug_assert_eq!(big_n1 % 4, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_16p = y.add(16 * p); + + let x0 = I4::load(x_p.add(big_n0)); + let x1 = I4::load(x_p.add(big_n1)); + let x2 = I4::load(x_p.add(big_n2)); + let x3 = I4::load(x_p.add(big_n3)); + let x4 = I4::load(x_p.add(big_n4)); + let x5 = I4::load(x_p.add(big_n5)); + let x6 = I4::load(x_p.add(big_n6)); + let x7 = I4::load(x_p.add(big_n7)); + let x8 = I4::load(x_p.add(big_n8)); + let x9 = I4::load(x_p.add(big_n9)); + let xa = I4::load(x_p.add(big_na)); + let xb = I4::load(x_p.add(big_nb)); + let xc = I4::load(x_p.add(big_nc)); + let xd = I4::load(x_p.add(big_nd)); + let xe = I4::load(x_p.add(big_ne)); + let xf = I4::load(x_p.add(big_nf)); + + let a08 = I4::add(x0, x8); + let s08 = I4::sub(x0, x8); + let a4c = I4::add(x4, xc); + let s4c = I4::sub(x4, xc); + let a2a = I4::add(x2, xa); + let s2a = I4::sub(x2, xa); + let a6e = I4::add(x6, xe); + let s6e = I4::sub(x6, xe); + let a19 = I4::add(x1, x9); + let s19 = I4::sub(x1, x9); + let a5d = I4::add(x5, xd); + let s5d = I4::sub(x5, xd); + let a3b = I4::add(x3, xb); + let s3b = I4::sub(x3, xb); + let a7f = I4::add(x7, xf); + let s7f = I4::sub(x7, xf); + + let js4c = I4::xpj(fwd, s4c); + let js6e = I4::xpj(fwd, s6e); + let js5d = I4::xpj(fwd, s5d); + let js7f = I4::xpj(fwd, s7f); + + let a08p1a4c = I4::add(a08, a4c); + let s08mjs4c = I4::sub(s08, js4c); + let a08m1a4c = I4::sub(a08, a4c); + let s08pjs4c = I4::add(s08, js4c); + let a2ap1a6e = I4::add(a2a, a6e); + let s2amjs6e = I4::sub(s2a, js6e); + let a2am1a6e = I4::sub(a2a, a6e); + let s2apjs6e = I4::add(s2a, js6e); + let a19p1a5d = I4::add(a19, a5d); + let s19mjs5d = I4::sub(s19, js5d); + let a19m1a5d = I4::sub(a19, a5d); + let s19pjs5d = I4::add(s19, js5d); + let a3bp1a7f = I4::add(a3b, a7f); + let s3bmjs7f = I4::sub(s3b, js7f); + let a3bm1a7f = I4::sub(a3b, a7f); + let s3bpjs7f = I4::add(s3b, js7f); + + let w8_s2amjs6e = I4::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I4::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I4::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I4::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I4::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I4::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I4::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I4::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I4::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I4::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I4::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I4::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I4::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I4::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I4::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I4::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I4::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I4::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I4::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I4::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I4::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I4::add(s19pjs5d, v8_s3bpjs7f); + + let h1_s19mjs5d_pw_s3bmjs7f = I4::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + let w8_a19m1a5d_mj_a3bm1a7f = I4::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + let h3_s19pjs5d_mv_s3bpjs7f = I4::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + let j_a19p1a5d_m1_a3bp1a7f = I4::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + let hd_s19mjs5d_mw_s3bmjs7f = I4::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + let v8_a19m1a5d_pj_a3bm1a7f = I4::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + let hf_s19pjs5d_pv_s3bpjs7f = I4::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + + let w1p = I4::load(twid(16, big_n, 1, w, p)); + let w2p = I4::load(twid(16, big_n, 2, w, p)); + let w3p = I4::load(twid(16, big_n, 3, w, p)); + let w4p = I4::load(twid(16, big_n, 4, w, p)); + let w5p = I4::load(twid(16, big_n, 5, w, p)); + let w6p = I4::load(twid(16, big_n, 6, w, p)); + let w7p = I4::load(twid(16, big_n, 7, w, p)); + let w8p = I4::load(twid(16, big_n, 8, w, p)); + let w9p = I4::load(twid(16, big_n, 9, w, p)); + let wap = I4::load(twid(16, big_n, 10, w, p)); + let wbp = I4::load(twid(16, big_n, 11, w, p)); + let wcp = I4::load(twid(16, big_n, 12, w, p)); + let wdp = I4::load(twid(16, big_n, 13, w, p)); + let wep = I4::load(twid(16, big_n, 14, w, p)); + let wfp = I4::load(twid(16, big_n, 15, w, p)); + + let a_ = I4::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f); + let b_ = I4::mul(w1p, I4::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f)); + let c_ = I4::mul(w2p, I4::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f)); + let d_ = I4::mul(w3p, I4::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f)); + let e_ = I4::mul(w4p, I4::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f)); + let f_ = I4::mul(w5p, I4::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f)); + let g_ = I4::mul(w6p, I4::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f)); + let h_ = I4::mul(w7p, I4::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f)); + + let i_ = I4::mul(w8p, I4::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f)); + let j_ = I4::mul(w9p, I4::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f)); + let k_ = I4::mul(wap, I4::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f)); + let l_ = I4::mul(wbp, I4::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f)); + let m_ = I4::mul(wcp, I4::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f)); + let n_ = I4::mul(wdp, I4::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f)); + let o_ = I4::mul(wep, I4::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f)); + let p_ = I4::mul(wfp, I4::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f)); + + let (abcd0, abcd1, abcd2, abcd3) = I4::transpose(a_, b_, c_, d_); + let (efgh0, efgh1, efgh2, efgh3) = I4::transpose(e_, f_, g_, h_); + let (ijkl0, ijkl1, ijkl2, ijkl3) = I4::transpose(i_, j_, k_, l_); + let (mnop0, mnop1, mnop2, mnop3) = I4::transpose(m_, n_, o_, p_); + + I4::store(y_16p.add(0x00), abcd0); + I4::store(y_16p.add(0x04), efgh0); + I4::store(y_16p.add(0x08), ijkl0); + I4::store(y_16p.add(0x0c), mnop0); + + I4::store(y_16p.add(0x10), abcd1); + I4::store(y_16p.add(0x14), efgh1); + I4::store(y_16p.add(0x18), ijkl1); + I4::store(y_16p.add(0x1c), mnop1); + + I4::store(y_16p.add(0x20), abcd2); + I4::store(y_16p.add(0x24), efgh2); + I4::store(y_16p.add(0x28), ijkl2); + I4::store(y_16p.add(0x2c), mnop2); + + I4::store(y_16p.add(0x30), abcd3); + I4::store(y_16p.add(0x34), efgh3); + I4::store(y_16p.add(0x38), ijkl3); + I4::store(y_16p.add(0x3c), mnop3); + + p += 4; + } +} + +#[inline(always)] +pub(crate) unsafe fn end16( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 16); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let x0 = I::load(xq.add(s * 0x0)); + let x1 = I::load(xq.add(s * 0x1)); + let x2 = I::load(xq.add(s * 0x2)); + let x3 = I::load(xq.add(s * 0x3)); + let x4 = I::load(xq.add(s * 0x4)); + let x5 = I::load(xq.add(s * 0x5)); + let x6 = I::load(xq.add(s * 0x6)); + let x7 = I::load(xq.add(s * 0x7)); + let x8 = I::load(xq.add(s * 0x8)); + let x9 = I::load(xq.add(s * 0x9)); + let xa = I::load(xq.add(s * 0xa)); + let xb = I::load(xq.add(s * 0xb)); + let xc = I::load(xq.add(s * 0xc)); + let xd = I::load(xq.add(s * 0xd)); + let xe = I::load(xq.add(s * 0xe)); + let xf = I::load(xq.add(s * 0xf)); + + let a08 = I::add(x0, x8); + let s08 = I::sub(x0, x8); + let a4c = I::add(x4, xc); + let s4c = I::sub(x4, xc); + let a2a = I::add(x2, xa); + let s2a = I::sub(x2, xa); + let a6e = I::add(x6, xe); + let s6e = I::sub(x6, xe); + let a19 = I::add(x1, x9); + let s19 = I::sub(x1, x9); + let a5d = I::add(x5, xd); + let s5d = I::sub(x5, xd); + let a3b = I::add(x3, xb); + let s3b = I::sub(x3, xb); + let a7f = I::add(x7, xf); + let s7f = I::sub(x7, xf); + + let js4c = I::xpj(fwd, s4c); + let js6e = I::xpj(fwd, s6e); + let js5d = I::xpj(fwd, s5d); + let js7f = I::xpj(fwd, s7f); + + let a08p1a4c = I::add(a08, a4c); + let s08mjs4c = I::sub(s08, js4c); + let a08m1a4c = I::sub(a08, a4c); + let s08pjs4c = I::add(s08, js4c); + let a2ap1a6e = I::add(a2a, a6e); + let s2amjs6e = I::sub(s2a, js6e); + let a2am1a6e = I::sub(a2a, a6e); + let s2apjs6e = I::add(s2a, js6e); + let a19p1a5d = I::add(a19, a5d); + let s19mjs5d = I::sub(s19, js5d); + let a19m1a5d = I::sub(a19, a5d); + let s19pjs5d = I::add(s19, js5d); + let a3bp1a7f = I::add(a3b, a7f); + let s3bmjs7f = I::sub(s3b, js7f); + let a3bm1a7f = I::sub(a3b, a7f); + let s3bpjs7f = I::add(s3b, js7f); + + let w8_s2amjs6e = I::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I::add(s19pjs5d, v8_s3bpjs7f); + + let h1_s19mjs5d_pw_s3bmjs7f = I::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + let w8_a19m1a5d_mj_a3bm1a7f = I::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + let h3_s19pjs5d_mv_s3bpjs7f = I::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + let j_a19p1a5d_m1_a3bp1a7f = I::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + let hd_s19mjs5d_mw_s3bmjs7f = I::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + let v8_a19m1a5d_pj_a3bm1a7f = I::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + let hf_s19pjs5d_pv_s3bpjs7f = I::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + + I::store( + zq.add(0), + I::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I::store( + zq.add(s), + I::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + I::store( + zq.add(s * 0x2), + I::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + I::store( + zq.add(s * 0x3), + I::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + I::store( + zq.add(s * 0x4), + I::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + I::store( + zq.add(s * 0x5), + I::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + I::store( + zq.add(s * 0x6), + I::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + I::store( + zq.add(s * 0x7), + I::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + + I::store( + zq.add(s * 0x8), + I::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I::store( + zq.add(s * 0x9), + I::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + I::store( + zq.add(s * 0xa), + I::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + I::store( + zq.add(s * 0xb), + I::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + I::store( + zq.add(s * 0xc), + I::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + I::store( + zq.add(s * 0xd), + I::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + I::store( + zq.add(s * 0xe), + I::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + I::store( + zq.add(s * 0xf), + I::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dif16_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$x1>(FWD, 1 << 2, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$x1>(FWD, 1 << 3, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + end16::<$x1>(FWD, 1 << 4, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 4, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + end_4::<$xn>(FWD, 1 << 2, 1 << 4, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + end_8::<$xn>(FWD, 1 << 3, 1 << 4, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + end16::<$xn>(FWD, 1 << 4, 1 << 4, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 4, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 8, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 06, 1 << 4, y, x, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 8, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 07, 1 << 04, y, x, w); + end_8::<$xn>(FWD, 1 << 03, 1 << 08, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 08, 1 << 04, y, x, w); + end16::<$xn>(FWD, 1 << 04, 1 << 08, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 04, y, x, w); + core_::<$xn>(FWD, 1 << 05, 1 << 08, x, y, w); + end_2::<$xn>(FWD, 1 << 01, 1 << 12, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 04, y, x, w); + core_::<$xn>(FWD, 1 << 06, 1 << 08, x, y, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 12, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 04, y, x, w); + core_::<$xn>(FWD, 1 << 07, 1 << 08, x, y, w); + end_8::<$xn>(FWD, 1 << 03, 1 << 12, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 04, y, x, w); + core_::<$xn>(FWD, 1 << 08, 1 << 08, x, y, w); + end16::<$xn>(FWD, 1 << 04, 1 << 12, y, x, true); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dif16_impl! { + pub static DIF16_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF16_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF16_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; + + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + pub static DIF16_AVX512 = Fft { + core_1: core_x4::, + native: Avx512X4, + x1: Avx512X1, + target: "avx512f", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return DIF16_AVX512; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIF16_FMA; + } else if x86_feature_detected!("avx") { + return DIF16_AVX; + } + + DIF16_SCALAR +} diff --git a/src/dif2.rs b/src/dif2.rs new file mode 100644 index 0000000..92bab99 --- /dev/null +++ b/src/dif2.rs @@ -0,0 +1,424 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64X2, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +unsafe fn core_( + _fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 2; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 2; + + for p in 0..m { + let sp = s * p; + let s2p = 2 * sp; + let w1p = I::splat(twid_t(2, big_n, 1, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s2p = y.add(q + s2p); + + let a = I::load(xq_sp.add(big_n0)); + let b = I::load(xq_sp.add(big_n1)); + + I::store(yq_s2p.add(s * 0), I::add(a, b)); + I::store(yq_s2p.add(s * 1), I::mul(w1p, I::sub(a, b))); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + _fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 2; + + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_2p = y.add(2 * p); + + let a = I::load(x_p.add(big_n0)); + let b = I::load(x_p.add(big_n1)); + + let w1p = I::load(twid(2, big_n, 1, w, p)); + + let aa = I::add(a, b); + let bb = I::mul(w1p, I::sub(a, b)); + + { + let ab = I::catlo(aa, bb); + I::store(y_2p.add(0), ab); + } + { + let ab = I::cathi(aa, bb); + I::store(y_2p.add(2), ab); + } + + p += 2; + } +} + +#[inline(always)] +pub unsafe fn end_2( + _fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 2); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let a = I::load(xq.add(0)); + let b = I::load(xq.add(s)); + + I::store(zq.add(0), I::add(a, b)); + I::store(zq.add(s), I::sub(a, b)); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dif2_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 2, 1 << 0, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 1, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 3, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 2, 1 << 1, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 2, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 4, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 1, y, x, w); + core_::<$xn>(FWD, 1 << 2, 1 << 2, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 3, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 1, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 2, 1 << 3, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 4, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 1, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 2, 1 << 4, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 5, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 1, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 2, 1 << 5, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 6, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 1, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 5, y, x, w); + core_::<$xn>(FWD, 1 << 2, 1 << 6, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 7, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 1, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 5, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 6, x, y, w); + core_::<$xn>(FWD, 1 << 2, 1 << 7, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 8, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 1, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 5, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 6, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 7, y, x, w); + core_::<$xn>(FWD, 1 << 2, 1 << 8, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 9, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 01, y, x, w); + core_::<$xn>(FWD, 1 << 9, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 2, 1 << 09, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 10, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 01, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 2, 1 << 10, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 11, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 01, y, x, w); + core_::<$xn>(FWD, 1 << 11, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 9, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 2, 1 << 11, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 12, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 01, y, x, w); + core_::<$xn>(FWD, 1 << 12, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 11, y, x, w); + core_::<$xn>(FWD, 1 << 2, 1 << 12, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 13, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 14, 1 << 01, y, x, w); + core_::<$xn>(FWD, 1 << 13, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 11, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 9, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 11, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 2, 1 << 13, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 14, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 15, 1 << 01, y, x, w); + core_::<$xn>(FWD, 1 << 14, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 12, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 11, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 13, y, x, w); + core_::<$xn>(FWD, 1 << 2, 1 << 14, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 15, y, x, true); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dif2_impl! { + pub static DIF2_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF2_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF2_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIF2_FMA; + } else if x86_feature_detected!("avx") { + return DIF2_AVX; + } + + DIF2_SCALAR +} diff --git a/src/dif4.rs b/src/dif4.rs new file mode 100644 index 0000000..47d4d63 --- /dev/null +++ b/src/dif4.rs @@ -0,0 +1,470 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::dif2::end_2; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64Ext, FftSimd64X2, FftSimd64X4, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +unsafe fn core_( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 4; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 4; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + + for p in 0..m { + let sp = s * p; + let s4p = 4 * sp; + let w1p = I::splat(twid_t(4, big_n, 1, w, sp)); + let w2p = I::splat(twid_t(4, big_n, 2, w, sp)); + let w3p = I::splat(twid_t(4, big_n, 3, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s4p = y.add(q + s4p); + + let a = I::load(xq_sp.add(big_n0)); + let c = I::load(xq_sp.add(big_n2)); + let apc = I::add(a, c); + let amc = I::sub(a, c); + + let b = I::load(xq_sp.add(big_n1)); + let d = I::load(xq_sp.add(big_n3)); + let bpd = I::add(b, d); + let jbmd = I::xpj(fwd, I::sub(b, d)); + + I::store(yq_s4p.add(s * 0), I::add(apc, bpd)); + I::store(yq_s4p.add(s * 1), I::mul(w1p, I::sub(amc, jbmd))); + I::store(yq_s4p.add(s * 2), I::mul(w2p, I::sub(apc, bpd))); + I::store(yq_s4p.add(s * 3), I::mul(w3p, I::add(amc, jbmd))); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 4; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + + debug_assert_eq!(big_n1 % 2, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_4p = y.add(4 * p); + + let a = I::load(x_p.add(big_n0)); + let c = I::load(x_p.add(big_n2)); + let apc = I::add(a, c); + let amc = I::sub(a, c); + + let b = I::load(x_p.add(big_n1)); + let d = I::load(x_p.add(big_n3)); + let bpd = I::add(b, d); + let jbmd = I::xpj(fwd, I::sub(b, d)); + + let w1p = I::load(twid(4, big_n, 1, w, p)); + let w2p = I::load(twid(4, big_n, 2, w, p)); + let w3p = I::load(twid(4, big_n, 3, w, p)); + + let aa = I::add(apc, bpd); + let bb = I::mul(w1p, I::sub(amc, jbmd)); + let cc = I::mul(w2p, I::sub(apc, bpd)); + let dd = I::mul(w3p, I::add(amc, jbmd)); + + { + let ab = I::catlo(aa, bb); + I::store(y_4p.add(0), ab); + let cd = I::catlo(cc, dd); + I::store(y_4p.add(2), cd); + } + { + let ab = I::cathi(aa, bb); + I::store(y_4p.add(4), ab); + let cd = I::cathi(cc, dd); + I::store(y_4p.add(6), cd); + } + + p += 2; + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + if n == 8 { + return core_x2::(fwd, n, s, x, y, w); + } + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 4; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + + debug_assert_eq!(big_n1 % 4, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_4p = y.add(4 * p); + + let a = I4::load(x_p.add(big_n0)); + let c = I4::load(x_p.add(big_n2)); + let apc = I4::add(a, c); + let amc = I4::sub(a, c); + + let b = I4::load(x_p.add(big_n1)); + let d = I4::load(x_p.add(big_n3)); + let bpd = I4::add(b, d); + let jbmd = I4::xpj(fwd, I4::sub(b, d)); + + let w1p = I4::load(twid(4, big_n, 1, w, p)); + let w2p = I4::load(twid(4, big_n, 2, w, p)); + let w3p = I4::load(twid(4, big_n, 3, w, p)); + + let aaaa = I4::add(apc, bpd); + let bbbb = I4::mul(w1p, I4::sub(amc, jbmd)); + let cccc = I4::mul(w2p, I4::sub(apc, bpd)); + let dddd = I4::mul(w3p, I4::add(amc, jbmd)); + + let (abcd0, abcd1, abcd2, abcd3) = I4::transpose(aaaa, bbbb, cccc, dddd); + I4::store(y_4p.add(0), abcd0); + I4::store(y_4p.add(4), abcd1); + I4::store(y_4p.add(8), abcd2); + I4::store(y_4p.add(12), abcd3); + + p += 4; + } +} + +#[inline(always)] +pub unsafe fn end_4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 4); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let a = I::load(xq.add(0)); + let b = I::load(xq.add(s)); + let c = I::load(xq.add(s * 2)); + let d = I::load(xq.add(s * 3)); + + let apc = I::add(a, c); + let amc = I::sub(a, c); + let bpd = I::add(b, d); + let jbmd = I::xpj(fwd, I::sub(b, d)); + + I::store(zq.add(s * 0), I::add(apc, bpd)); + I::store(zq.add(s * 1), I::sub(amc, jbmd)); + I::store(zq.add(s * 2), I::sub(apc, bpd)); + I::store(zq.add(s * 3), I::add(amc, jbmd)); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dif4_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$x1>(FWD, 1 << 2, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 3, 1 << 0, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 2, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 4, 1 << 0, x, y, w); + end_4::<$xn>(FWD, 1 << 2, 1 << 2, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 2, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 4, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 2, y, x, w); + end_4::<$xn>(FWD, 1 << 2, 1 << 4, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 2, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 4, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 6, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 2, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 4, x, y, w); + end_4::<$xn>(FWD, 1 << 2, 1 << 6, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 2, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 6, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 8, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 08, 1 << 2, y, x, w); + core_::<$xn>(FWD, 1 << 06, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 04, 1 << 6, y, x, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 8, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 02, y, x, w); + core_::<$xn>(FWD, 1 << 07, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 05, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 03, 1 << 08, x, y, w); + end_2::<$xn>(FWD, 1 << 01, 1 << 10, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 02, y, x, w); + core_::<$xn>(FWD, 1 << 08, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 06, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 04, 1 << 08, x, y, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 10, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 02, y, x, w); + core_::<$xn>(FWD, 1 << 09, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 07, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 05, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 03, 1 << 10, y, x, w); + end_2::<$xn>(FWD, 1 << 01, 1 << 12, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 02, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 08, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 06, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 04, 1 << 10, y, x, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 12, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 02, y, x, w); + core_::<$xn>(FWD, 1 << 11, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 07, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 05, 1 << 10, y, x, w); + core_::<$xn>(FWD, 1 << 03, 1 << 12, x, y, w); + end_2::<$xn>(FWD, 1 << 01, 1 << 14, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 14, 1 << 02, y, x, w); + core_::<$xn>(FWD, 1 << 12, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 08, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 06, 1 << 10, y, x, w); + core_::<$xn>(FWD, 1 << 04, 1 << 12, x, y, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 14, y, x, true); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dif4_impl! { + pub static DIF4_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF4_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF4_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; + + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + pub static DIF4_AVX512 = Fft { + core_1: core_x4::, + native: Avx512X4, + x1: Avx512X1, + target: "avx512f", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return DIF4_AVX512; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIF4_FMA; + } else if x86_feature_detected!("avx") { + return DIF4_AVX; + } + + DIF4_SCALAR +} diff --git a/src/dif8.rs b/src/dif8.rs new file mode 100644 index 0000000..4068349 --- /dev/null +++ b/src/dif8.rs @@ -0,0 +1,572 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::dif2::end_2; +use crate::dif4::end_4; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64Ext, FftSimd64X2, FftSimd64X4, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +#[rustfmt::skip] +unsafe fn core_( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 8; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 8; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + + for p in 0..m { + let sp = s * p; + let s8p = 8 * sp; + let w1p = I::splat(twid_t(8, big_n, 1, w, sp)); + let w2p = I::splat(twid_t(8, big_n, 2, w, sp)); + let w3p = I::splat(twid_t(8, big_n, 3, w, sp)); + let w4p = I::splat(twid_t(8, big_n, 4, w, sp)); + let w5p = I::splat(twid_t(8, big_n, 5, w, sp)); + let w6p = I::splat(twid_t(8, big_n, 6, w, sp)); + let w7p = I::splat(twid_t(8, big_n, 7, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s8p = y.add(q + s8p); + + let x0 = I::load(xq_sp.add(big_n0)); + let x1 = I::load(xq_sp.add(big_n1)); + let x2 = I::load(xq_sp.add(big_n2)); + let x3 = I::load(xq_sp.add(big_n3)); + let x4 = I::load(xq_sp.add(big_n4)); + let x5 = I::load(xq_sp.add(big_n5)); + let x6 = I::load(xq_sp.add(big_n6)); + let x7 = I::load(xq_sp.add(big_n7)); + + let a04 = I::add(x0, x4); + let s04 = I::sub(x0, x4); + let a26 = I::add(x2, x6); + let js26 = I::xpj(fwd, I::sub(x2, x6)); + let a15 = I::add(x1, x5); + let s15 = I::sub(x1, x5); + let a37 = I::add(x3, x7); + let js37 = I::xpj(fwd, I::sub(x3, x7)); + let a04_p1_a26 = I::add(a04, a26); + let s04_mj_s26 = I::sub(s04, js26); + let a04_m1_a26 = I::sub(a04, a26); + let s04_pj_s26 = I::add(s04, js26); + let a15_p1_a37 = I::add(a15, a37); + let w8_s15_mj_s37 = I::xw8(fwd, I::sub(s15, js37)); + let j_a15_m1_a37 = I::xpj(fwd, I::sub(a15, a37)); + let v8_s15_pj_s37 = I::xv8(fwd, I::add(s15, js37)); + + I::store(yq_s8p.add(s * 0), I::add(a04_p1_a26, a15_p1_a37)); + I::store(yq_s8p.add(s * 1), I::mul(w1p, I::add(s04_mj_s26, w8_s15_mj_s37))); + I::store(yq_s8p.add(s * 2), I::mul(w2p, I::sub(a04_m1_a26, j_a15_m1_a37))); + I::store(yq_s8p.add(s * 3), I::mul(w3p, I::sub(s04_pj_s26, v8_s15_pj_s37))); + I::store(yq_s8p.add(s * 4), I::mul(w4p, I::sub(a04_p1_a26, a15_p1_a37))); + I::store(yq_s8p.add(s * 5), I::mul(w5p, I::sub(s04_mj_s26, w8_s15_mj_s37))); + I::store(yq_s8p.add(s * 6), I::mul(w6p, I::add(a04_m1_a26, j_a15_m1_a37))); + I::store(yq_s8p.add(s * 7), I::mul(w7p, I::add(s04_pj_s26, v8_s15_pj_s37))); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 8; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + + debug_assert_eq!(big_n1 % 2, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_8p = y.add(8 * p); + + let x0 = I::load(x_p.add(big_n0)); + let x1 = I::load(x_p.add(big_n1)); + let x2 = I::load(x_p.add(big_n2)); + let x3 = I::load(x_p.add(big_n3)); + let x4 = I::load(x_p.add(big_n4)); + let x5 = I::load(x_p.add(big_n5)); + let x6 = I::load(x_p.add(big_n6)); + let x7 = I::load(x_p.add(big_n7)); + + let a04 = I::add(x0, x4); + let s04 = I::sub(x0, x4); + let a26 = I::add(x2, x6); + let js26 = I::xpj(fwd, I::sub(x2, x6)); + let a15 = I::add(x1, x5); + let s15 = I::sub(x1, x5); + let a37 = I::add(x3, x7); + let js37 = I::xpj(fwd, I::sub(x3, x7)); + + let a04_p1_a26 = I::add(a04, a26); + let s04_mj_s26 = I::sub(s04, js26); + let a04_m1_a26 = I::sub(a04, a26); + let s04_pj_s26 = I::add(s04, js26); + let a15_p1_a37 = I::add(a15, a37); + let w8_s15_mj_s37 = I::xw8(fwd, I::sub(s15, js37)); + let j_a15_m1_a37 = I::xpj(fwd, I::sub(a15, a37)); + let v8_s15_pj_s37 = I::xv8(fwd, I::add(s15, js37)); + + let w1p = I::load(twid(8, big_n, 1, w, p)); + let w2p = I::load(twid(8, big_n, 2, w, p)); + let w3p = I::load(twid(8, big_n, 3, w, p)); + let w4p = I::load(twid(8, big_n, 4, w, p)); + let w5p = I::load(twid(8, big_n, 5, w, p)); + let w6p = I::load(twid(8, big_n, 6, w, p)); + let w7p = I::load(twid(8, big_n, 7, w, p)); + + let aa = I::add(a04_p1_a26, a15_p1_a37); + let bb = I::mul(w1p, I::add(s04_mj_s26, w8_s15_mj_s37)); + let cc = I::mul(w2p, I::sub(a04_m1_a26, j_a15_m1_a37)); + let dd = I::mul(w3p, I::sub(s04_pj_s26, v8_s15_pj_s37)); + let ee = I::mul(w4p, I::sub(a04_p1_a26, a15_p1_a37)); + let ff = I::mul(w5p, I::sub(s04_mj_s26, w8_s15_mj_s37)); + let gg = I::mul(w6p, I::add(a04_m1_a26, j_a15_m1_a37)); + let hh = I::mul(w7p, I::add(s04_pj_s26, v8_s15_pj_s37)); + + { + let ab = I::catlo(aa, bb); + I::store(y_8p.add(0), ab); + let cd = I::catlo(cc, dd); + I::store(y_8p.add(2), cd); + let ef = I::catlo(ee, ff); + I::store(y_8p.add(4), ef); + let gh = I::catlo(gg, hh); + I::store(y_8p.add(6), gh); + } + { + let ab = I::cathi(aa, bb); + I::store(y_8p.add(8), ab); + let cd = I::cathi(cc, dd); + I::store(y_8p.add(10), cd); + let ef = I::cathi(ee, ff); + I::store(y_8p.add(12), ef); + let gh = I::cathi(gg, hh); + I::store(y_8p.add(14), gh); + } + + p += 2; + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + if n == 16 { + return core_x2::(fwd, n, s, x, y, w); + } + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 8; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + + debug_assert_eq!(big_n1 % 4, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_8p = y.add(8 * p); + + let x0 = I4::load(x_p.add(big_n0)); + let x1 = I4::load(x_p.add(big_n1)); + let x2 = I4::load(x_p.add(big_n2)); + let x3 = I4::load(x_p.add(big_n3)); + let x4 = I4::load(x_p.add(big_n4)); + let x5 = I4::load(x_p.add(big_n5)); + let x6 = I4::load(x_p.add(big_n6)); + let x7 = I4::load(x_p.add(big_n7)); + + let a04 = I4::add(x0, x4); + let s04 = I4::sub(x0, x4); + let a26 = I4::add(x2, x6); + let js26 = I4::xpj(fwd, I4::sub(x2, x6)); + let a15 = I4::add(x1, x5); + let s15 = I4::sub(x1, x5); + let a37 = I4::add(x3, x7); + let js37 = I4::xpj(fwd, I4::sub(x3, x7)); + + let a04_p1_a26 = I4::add(a04, a26); + let s04_mj_s26 = I4::sub(s04, js26); + let a04_m1_a26 = I4::sub(a04, a26); + let s04_pj_s26 = I4::add(s04, js26); + let a15_p1_a37 = I4::add(a15, a37); + let w8_s15_mj_s37 = I4::xw8(fwd, I4::sub(s15, js37)); + let j_a15_m1_a37 = I4::xpj(fwd, I4::sub(a15, a37)); + let v8_s15_pj_s37 = I4::xv8(fwd, I4::add(s15, js37)); + + let w1p = I4::load(twid(8, big_n, 1, w, p)); + let w2p = I4::load(twid(8, big_n, 2, w, p)); + let w3p = I4::load(twid(8, big_n, 3, w, p)); + let w4p = I4::load(twid(8, big_n, 4, w, p)); + let w5p = I4::load(twid(8, big_n, 5, w, p)); + let w6p = I4::load(twid(8, big_n, 6, w, p)); + let w7p = I4::load(twid(8, big_n, 7, w, p)); + + let a = I4::add(a04_p1_a26, a15_p1_a37); + let b = I4::mul(w1p, I4::add(s04_mj_s26, w8_s15_mj_s37)); + let c = I4::mul(w2p, I4::sub(a04_m1_a26, j_a15_m1_a37)); + let d = I4::mul(w3p, I4::sub(s04_pj_s26, v8_s15_pj_s37)); + let e = I4::mul(w4p, I4::sub(a04_p1_a26, a15_p1_a37)); + let f = I4::mul(w5p, I4::sub(s04_mj_s26, w8_s15_mj_s37)); + let g = I4::mul(w6p, I4::add(a04_m1_a26, j_a15_m1_a37)); + let h = I4::mul(w7p, I4::add(s04_pj_s26, v8_s15_pj_s37)); + + let (abcd0, abcd1, abcd2, abcd3) = I4::transpose(a, b, c, d); + let (efgh0, efgh1, efgh2, efgh3) = I4::transpose(e, f, g, h); + I4::store(y_8p.add(0), abcd0); + I4::store(y_8p.add(4), efgh0); + I4::store(y_8p.add(8), abcd1); + I4::store(y_8p.add(12), efgh1); + I4::store(y_8p.add(16), abcd2); + I4::store(y_8p.add(20), efgh2); + I4::store(y_8p.add(24), abcd3); + I4::store(y_8p.add(28), efgh3); + + p += 4; + } +} + +#[inline(always)] +pub(crate) unsafe fn end_8( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 8); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let x0 = I::load(xq.add(s * 0)); + let x1 = I::load(xq.add(s * 1)); + let x2 = I::load(xq.add(s * 2)); + let x3 = I::load(xq.add(s * 3)); + let x4 = I::load(xq.add(s * 4)); + let x5 = I::load(xq.add(s * 5)); + let x6 = I::load(xq.add(s * 6)); + let x7 = I::load(xq.add(s * 7)); + + let a04 = I::add(x0, x4); + let s04 = I::sub(x0, x4); + let a26 = I::add(x2, x6); + let js26 = I::xpj(fwd, I::sub(x2, x6)); + let a15 = I::add(x1, x5); + let s15 = I::sub(x1, x5); + let a37 = I::add(x3, x7); + let js37 = I::xpj(fwd, I::sub(x3, x7)); + + let a04_p1_a26 = I::add(a04, a26); + let s04_mj_s26 = I::sub(s04, js26); + let a04_m1_a26 = I::sub(a04, a26); + let s04_pj_s26 = I::add(s04, js26); + let a15_p1_a37 = I::add(a15, a37); + let w8_s15_mj_s37 = I::xw8(fwd, I::sub(s15, js37)); + let j_a15_m1_a37 = I::xpj(fwd, I::sub(a15, a37)); + let v8_s15_pj_s37 = I::xv8(fwd, I::add(s15, js37)); + + I::store(zq.add(0), I::add(a04_p1_a26, a15_p1_a37)); + I::store(zq.add(s), I::add(s04_mj_s26, w8_s15_mj_s37)); + I::store(zq.add(s * 2), I::sub(a04_m1_a26, j_a15_m1_a37)); + I::store(zq.add(s * 3), I::sub(s04_pj_s26, v8_s15_pj_s37)); + I::store(zq.add(s * 4), I::sub(a04_p1_a26, a15_p1_a37)); + I::store(zq.add(s * 5), I::sub(s04_mj_s26, w8_s15_mj_s37)); + I::store(zq.add(s * 6), I::add(a04_m1_a26, j_a15_m1_a37)); + I::store(zq.add(s * 7), I::add(s04_pj_s26, v8_s15_pj_s37)); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dif8_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$x1>(FWD, 1 << 2, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$x1>(FWD, 1 << 3, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 4, 1 << 0, x, y, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 3, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + end_4::<$xn>(FWD, 1 << 2, 1 << 3, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + end_8::<$xn>(FWD, 1 << 3, 1 << 3, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 3, y, x, w); + end_2::<$xn>(FWD, 1 << 1, 1 << 6, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 3, y, x, w); + end_4::<$xn>(FWD, 1 << 2, 1 << 6, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 3, y, x, w); + end_8::<$xn>(FWD, 1 << 3, 1 << 6, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + core_::<$xn>(FWD, 1 << 07, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 04, 1 << 6, x, y, w); + end_2::<$xn>(FWD, 1 << 01, 1 << 9, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 08, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 05, 1 << 06, x, y, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 09, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 06, 1 << 06, x, y, w); + end_8::<$xn>(FWD, 1 << 03, 1 << 09, y, x, true); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 07, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 04, 1 << 09, y, x, w); + end_2::<$xn>(FWD, 1 << 01, 1 << 12, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 08, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 05, 1 << 09, y, x, w); + end_4::<$xn>(FWD, 1 << 02, 1 << 12, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 09, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 06, 1 << 09, y, x, w); + end_8::<$xn>(FWD, 1 << 03, 1 << 12, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 07, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 04, 1 << 12, x, y, w); + end_2::<$xn>(FWD, 1 << 01, 1 << 15, y, x, true); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dif8_impl! { + pub static DIF8_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF8_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIF8_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; + + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + pub static DIF8_AVX512 = Fft { + core_1: core_x4::, + native: Avx512X4, + x1: Avx512X1, + target: "avx512f", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return DIF8_AVX512; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIF8_FMA; + } else if x86_feature_detected!("avx") { + return DIF8_AVX; + } + + DIF8_SCALAR +} diff --git a/src/dit16.rs b/src/dit16.rs new file mode 100644 index 0000000..92c5d1f --- /dev/null +++ b/src/dit16.rs @@ -0,0 +1,1116 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::dit4::{end_2, end_4}; +use crate::dit8::end_8; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64Ext, FftSimd64X2, FftSimd64X4, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +unsafe fn core_( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 16; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 16; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + let big_n8 = big_n1 * 8; + let big_n9 = big_n1 * 9; + let big_na = big_n1 * 10; + let big_nb = big_n1 * 11; + let big_nc = big_n1 * 12; + let big_nd = big_n1 * 13; + let big_ne = big_n1 * 14; + let big_nf = big_n1 * 15; + + for p in 0..m { + let sp = s * p; + let s16p = 16 * sp; + + let w1p = I::splat(twid_t(16, big_n, 0x1, w, sp)); + let w2p = I::splat(twid_t(16, big_n, 0x2, w, sp)); + let w3p = I::splat(twid_t(16, big_n, 0x3, w, sp)); + let w4p = I::splat(twid_t(16, big_n, 0x4, w, sp)); + let w5p = I::splat(twid_t(16, big_n, 0x5, w, sp)); + let w6p = I::splat(twid_t(16, big_n, 0x6, w, sp)); + let w7p = I::splat(twid_t(16, big_n, 0x7, w, sp)); + let w8p = I::splat(twid_t(16, big_n, 0x8, w, sp)); + let w9p = I::splat(twid_t(16, big_n, 0x9, w, sp)); + let wap = I::splat(twid_t(16, big_n, 0xa, w, sp)); + let wbp = I::splat(twid_t(16, big_n, 0xb, w, sp)); + let wcp = I::splat(twid_t(16, big_n, 0xc, w, sp)); + let wdp = I::splat(twid_t(16, big_n, 0xd, w, sp)); + let wep = I::splat(twid_t(16, big_n, 0xe, w, sp)); + let wfp = I::splat(twid_t(16, big_n, 0xf, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s16p = y.add(q + s16p); + + let y0 = I::load(yq_s16p.add(0)); + let y1 = I::mul(w1p, I::load(yq_s16p.add(s))); + let y2 = I::mul(w2p, I::load(yq_s16p.add(s * 0x2))); + let y3 = I::mul(w3p, I::load(yq_s16p.add(s * 0x3))); + let y4 = I::mul(w4p, I::load(yq_s16p.add(s * 0x4))); + let y5 = I::mul(w5p, I::load(yq_s16p.add(s * 0x5))); + let y6 = I::mul(w6p, I::load(yq_s16p.add(s * 0x6))); + let y7 = I::mul(w7p, I::load(yq_s16p.add(s * 0x7))); + let y8 = I::mul(w8p, I::load(yq_s16p.add(s * 0x8))); + let y9 = I::mul(w9p, I::load(yq_s16p.add(s * 0x9))); + let ya = I::mul(wap, I::load(yq_s16p.add(s * 0xa))); + let yb = I::mul(wbp, I::load(yq_s16p.add(s * 0xb))); + let yc = I::mul(wcp, I::load(yq_s16p.add(s * 0xc))); + let yd = I::mul(wdp, I::load(yq_s16p.add(s * 0xd))); + let ye = I::mul(wep, I::load(yq_s16p.add(s * 0xe))); + let yf = I::mul(wfp, I::load(yq_s16p.add(s * 0xf))); + + let a08 = I::add(y0, y8); + let s08 = I::sub(y0, y8); + let a4c = I::add(y4, yc); + let s4c = I::sub(y4, yc); + let a2a = I::add(y2, ya); + let s2a = I::sub(y2, ya); + let a6e = I::add(y6, ye); + let s6e = I::sub(y6, ye); + let a19 = I::add(y1, y9); + let s19 = I::sub(y1, y9); + let a5d = I::add(y5, yd); + let s5d = I::sub(y5, yd); + let a3b = I::add(y3, yb); + let s3b = I::sub(y3, yb); + let a7f = I::add(y7, yf); + let s7f = I::sub(y7, yf); + + let js4c = I::xpj(fwd, s4c); + let js6e = I::xpj(fwd, s6e); + let js5d = I::xpj(fwd, s5d); + let js7f = I::xpj(fwd, s7f); + + let a08p1a4c = I::add(a08, a4c); + let s08mjs4c = I::sub(s08, js4c); + let a08m1a4c = I::sub(a08, a4c); + let s08pjs4c = I::add(s08, js4c); + let a2ap1a6e = I::add(a2a, a6e); + let s2amjs6e = I::sub(s2a, js6e); + let a2am1a6e = I::sub(a2a, a6e); + let s2apjs6e = I::add(s2a, js6e); + let a19p1a5d = I::add(a19, a5d); + let s19mjs5d = I::sub(s19, js5d); + let a19m1a5d = I::sub(a19, a5d); + let s19pjs5d = I::add(s19, js5d); + let a3bp1a7f = I::add(a3b, a7f); + let s3bmjs7f = I::sub(s3b, js7f); + let a3bm1a7f = I::sub(a3b, a7f); + let s3bpjs7f = I::add(s3b, js7f); + + let w8_s2amjs6e = I::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I::add(s19pjs5d, v8_s3bpjs7f); + + I::store( + xq_sp.add(big_n0), + I::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I::store( + xq_sp.add(big_n8), + I::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + + let h1_s19mjs5d_pw_s3bmjs7f = I::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + I::store( + xq_sp.add(big_n1), + I::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + I::store( + xq_sp.add(big_n9), + I::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + + let w8_a19m1a5d_mj_a3bm1a7f = I::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + I::store( + xq_sp.add(big_n2), + I::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + I::store( + xq_sp.add(big_na), + I::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + + let h3_s19pjs5d_mv_s3bpjs7f = I::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + I::store( + xq_sp.add(big_n3), + I::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + I::store( + xq_sp.add(big_nb), + I::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + + let j_a19p1a5d_m1_a3bp1a7f = I::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + I::store( + xq_sp.add(big_n4), + I::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + I::store( + xq_sp.add(big_nc), + I::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + + let hd_s19mjs5d_mw_s3bmjs7f = I::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + I::store( + xq_sp.add(big_n5), + I::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + I::store( + xq_sp.add(big_nd), + I::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + + let v8_a19m1a5d_pj_a3bm1a7f = I::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + I::store( + xq_sp.add(big_n6), + I::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + I::store( + xq_sp.add(big_ne), + I::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + + let hf_s19pjs5d_pv_s3bpjs7f = I::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + I::store( + xq_sp.add(big_n7), + I::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + I::store( + xq_sp.add(big_nf), + I::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 16; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + let big_n8 = big_n1 * 8; + let big_n9 = big_n1 * 9; + let big_na = big_n1 * 10; + let big_nb = big_n1 * 11; + let big_nc = big_n1 * 12; + let big_nd = big_n1 * 13; + let big_ne = big_n1 * 14; + let big_nf = big_n1 * 15; + + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_16p = y.add(16 * p); + + let w1p = I::load(twid(16, big_n, 1, w, p)); + let w2p = I::load(twid(16, big_n, 2, w, p)); + let w3p = I::load(twid(16, big_n, 3, w, p)); + let w4p = I::load(twid(16, big_n, 4, w, p)); + let w5p = I::load(twid(16, big_n, 5, w, p)); + let w6p = I::load(twid(16, big_n, 6, w, p)); + let w7p = I::load(twid(16, big_n, 7, w, p)); + let w8p = I::load(twid(16, big_n, 8, w, p)); + let w9p = I::load(twid(16, big_n, 9, w, p)); + let wap = I::load(twid(16, big_n, 10, w, p)); + let wbp = I::load(twid(16, big_n, 11, w, p)); + let wcp = I::load(twid(16, big_n, 12, w, p)); + let wdp = I::load(twid(16, big_n, 13, w, p)); + let wep = I::load(twid(16, big_n, 14, w, p)); + let wfp = I::load(twid(16, big_n, 15, w, p)); + + let ab_0 = I::load(y_16p.add(0x00)); + let cd_0 = I::load(y_16p.add(0x02)); + let ef_0 = I::load(y_16p.add(0x04)); + let gh_0 = I::load(y_16p.add(0x06)); + let ij_0 = I::load(y_16p.add(0x08)); + let kl_0 = I::load(y_16p.add(0x0a)); + let mn_0 = I::load(y_16p.add(0x0c)); + let op_0 = I::load(y_16p.add(0x0e)); + let ab_1 = I::load(y_16p.add(0x10)); + let cd_1 = I::load(y_16p.add(0x12)); + let ef_1 = I::load(y_16p.add(0x14)); + let gh_1 = I::load(y_16p.add(0x16)); + let ij_1 = I::load(y_16p.add(0x18)); + let kl_1 = I::load(y_16p.add(0x1a)); + let mn_1 = I::load(y_16p.add(0x1c)); + let op_1 = I::load(y_16p.add(0x1e)); + + let y0 = I::catlo(ab_0, ab_1); + let y1 = I::mul(w1p, I::cathi(ab_0, ab_1)); + let y2 = I::mul(w2p, I::catlo(cd_0, cd_1)); + let y3 = I::mul(w3p, I::cathi(cd_0, cd_1)); + let y4 = I::mul(w4p, I::catlo(ef_0, ef_1)); + let y5 = I::mul(w5p, I::cathi(ef_0, ef_1)); + let y6 = I::mul(w6p, I::catlo(gh_0, gh_1)); + let y7 = I::mul(w7p, I::cathi(gh_0, gh_1)); + + let y8 = I::mul(w8p, I::catlo(ij_0, ij_1)); + let y9 = I::mul(w9p, I::cathi(ij_0, ij_1)); + let ya = I::mul(wap, I::catlo(kl_0, kl_1)); + let yb = I::mul(wbp, I::cathi(kl_0, kl_1)); + let yc = I::mul(wcp, I::catlo(mn_0, mn_1)); + let yd = I::mul(wdp, I::cathi(mn_0, mn_1)); + let ye = I::mul(wep, I::catlo(op_0, op_1)); + let yf = I::mul(wfp, I::cathi(op_0, op_1)); + + let a08 = I::add(y0, y8); + let s08 = I::sub(y0, y8); + let a4c = I::add(y4, yc); + let s4c = I::sub(y4, yc); + let a2a = I::add(y2, ya); + let s2a = I::sub(y2, ya); + let a6e = I::add(y6, ye); + let s6e = I::sub(y6, ye); + let a19 = I::add(y1, y9); + let s19 = I::sub(y1, y9); + let a5d = I::add(y5, yd); + let s5d = I::sub(y5, yd); + let a3b = I::add(y3, yb); + let s3b = I::sub(y3, yb); + let a7f = I::add(y7, yf); + let s7f = I::sub(y7, yf); + + let js4c = I::xpj(fwd, s4c); + let js6e = I::xpj(fwd, s6e); + let js5d = I::xpj(fwd, s5d); + let js7f = I::xpj(fwd, s7f); + + let a08p1a4c = I::add(a08, a4c); + let s08mjs4c = I::sub(s08, js4c); + let a08m1a4c = I::sub(a08, a4c); + let s08pjs4c = I::add(s08, js4c); + let a2ap1a6e = I::add(a2a, a6e); + let s2amjs6e = I::sub(s2a, js6e); + let a2am1a6e = I::sub(a2a, a6e); + let s2apjs6e = I::add(s2a, js6e); + let a19p1a5d = I::add(a19, a5d); + let s19mjs5d = I::sub(s19, js5d); + let a19m1a5d = I::sub(a19, a5d); + let s19pjs5d = I::add(s19, js5d); + let a3bp1a7f = I::add(a3b, a7f); + let s3bmjs7f = I::sub(s3b, js7f); + let a3bm1a7f = I::sub(a3b, a7f); + let s3bpjs7f = I::add(s3b, js7f); + + let w8_s2amjs6e = I::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I::add(s19pjs5d, v8_s3bpjs7f); + + I::store( + x_p.add(big_n0), + I::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I::store( + x_p.add(big_n8), + I::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + + let h1_s19mjs5d_pw_s3bmjs7f = I::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + I::store( + x_p.add(big_n1), + I::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + I::store( + x_p.add(big_n9), + I::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + + let w8_a19m1a5d_mj_a3bm1a7f = I::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + I::store( + x_p.add(big_n2), + I::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + I::store( + x_p.add(big_na), + I::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + + let h3_s19pjs5d_mv_s3bpjs7f = I::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + I::store( + x_p.add(big_n3), + I::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + I::store( + x_p.add(big_nb), + I::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + + let j_a19p1a5d_m1_a3bp1a7f = I::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + I::store( + x_p.add(big_n4), + I::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + I::store( + x_p.add(big_nc), + I::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + + let hd_s19mjs5d_mw_s3bmjs7f = I::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + I::store( + x_p.add(big_n5), + I::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + I::store( + x_p.add(big_nd), + I::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + + let v8_a19m1a5d_pj_a3bm1a7f = I::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + I::store( + x_p.add(big_n6), + I::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + I::store( + x_p.add(big_ne), + I::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + + let hf_s19pjs5d_pv_s3bpjs7f = I::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + I::store( + x_p.add(big_n7), + I::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + I::store( + x_p.add(big_nf), + I::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + + p += 2; + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + if n == 32 { + return core_x2::(fwd, n, s, x, y, w); + } + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 16; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + let big_n8 = big_n1 * 8; + let big_n9 = big_n1 * 9; + let big_na = big_n1 * 10; + let big_nb = big_n1 * 11; + let big_nc = big_n1 * 12; + let big_nd = big_n1 * 13; + let big_ne = big_n1 * 14; + let big_nf = big_n1 * 15; + + debug_assert_eq!(big_n1 % 4, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_16p = y.add(16 * p); + + let w1p = I4::load(twid(16, big_n, 1, w, p)); + let w2p = I4::load(twid(16, big_n, 2, w, p)); + let w3p = I4::load(twid(16, big_n, 3, w, p)); + let w4p = I4::load(twid(16, big_n, 4, w, p)); + let w5p = I4::load(twid(16, big_n, 5, w, p)); + let w6p = I4::load(twid(16, big_n, 6, w, p)); + let w7p = I4::load(twid(16, big_n, 7, w, p)); + let w8p = I4::load(twid(16, big_n, 8, w, p)); + let w9p = I4::load(twid(16, big_n, 9, w, p)); + let wap = I4::load(twid(16, big_n, 10, w, p)); + let wbp = I4::load(twid(16, big_n, 11, w, p)); + let wcp = I4::load(twid(16, big_n, 12, w, p)); + let wdp = I4::load(twid(16, big_n, 13, w, p)); + let wep = I4::load(twid(16, big_n, 14, w, p)); + let wfp = I4::load(twid(16, big_n, 15, w, p)); + + let abcd0 = I4::load(y_16p.add(0x00)); + let efgh0 = I4::load(y_16p.add(0x04)); + let ijkl0 = I4::load(y_16p.add(0x08)); + let mnop0 = I4::load(y_16p.add(0x0c)); + + let abcd1 = I4::load(y_16p.add(0x10)); + let efgh1 = I4::load(y_16p.add(0x14)); + let ijkl1 = I4::load(y_16p.add(0x18)); + let mnop1 = I4::load(y_16p.add(0x1c)); + + let abcd2 = I4::load(y_16p.add(0x20)); + let efgh2 = I4::load(y_16p.add(0x24)); + let ijkl2 = I4::load(y_16p.add(0x28)); + let mnop2 = I4::load(y_16p.add(0x2c)); + + let abcd3 = I4::load(y_16p.add(0x30)); + let efgh3 = I4::load(y_16p.add(0x34)); + let ijkl3 = I4::load(y_16p.add(0x38)); + let mnop3 = I4::load(y_16p.add(0x3c)); + + let (a_, b_, c_, d_) = I4::transpose(abcd0, abcd1, abcd2, abcd3); + let (e_, f_, g_, h_) = I4::transpose(efgh0, efgh1, efgh2, efgh3); + let (i_, j_, k_, l_) = I4::transpose(ijkl0, ijkl1, ijkl2, ijkl3); + let (m_, n_, o_, p_) = I4::transpose(mnop0, mnop1, mnop2, mnop3); + + let y0 = a_; + let y1 = I4::mul(w1p, b_); + let y2 = I4::mul(w2p, c_); + let y3 = I4::mul(w3p, d_); + let y4 = I4::mul(w4p, e_); + let y5 = I4::mul(w5p, f_); + let y6 = I4::mul(w6p, g_); + let y7 = I4::mul(w7p, h_); + + let y8 = I4::mul(w8p, i_); + let y9 = I4::mul(w9p, j_); + let ya = I4::mul(wap, k_); + let yb = I4::mul(wbp, l_); + let yc = I4::mul(wcp, m_); + let yd = I4::mul(wdp, n_); + let ye = I4::mul(wep, o_); + let yf = I4::mul(wfp, p_); + + let a08 = I4::add(y0, y8); + let s08 = I4::sub(y0, y8); + let a4c = I4::add(y4, yc); + let s4c = I4::sub(y4, yc); + let a2a = I4::add(y2, ya); + let s2a = I4::sub(y2, ya); + let a6e = I4::add(y6, ye); + let s6e = I4::sub(y6, ye); + let a19 = I4::add(y1, y9); + let s19 = I4::sub(y1, y9); + let a5d = I4::add(y5, yd); + let s5d = I4::sub(y5, yd); + let a3b = I4::add(y3, yb); + let s3b = I4::sub(y3, yb); + let a7f = I4::add(y7, yf); + let s7f = I4::sub(y7, yf); + + let js4c = I4::xpj(fwd, s4c); + let js6e = I4::xpj(fwd, s6e); + let js5d = I4::xpj(fwd, s5d); + let js7f = I4::xpj(fwd, s7f); + + let a08p1a4c = I4::add(a08, a4c); + let s08mjs4c = I4::sub(s08, js4c); + let a08m1a4c = I4::sub(a08, a4c); + let s08pjs4c = I4::add(s08, js4c); + let a2ap1a6e = I4::add(a2a, a6e); + let s2amjs6e = I4::sub(s2a, js6e); + let a2am1a6e = I4::sub(a2a, a6e); + let s2apjs6e = I4::add(s2a, js6e); + let a19p1a5d = I4::add(a19, a5d); + let s19mjs5d = I4::sub(s19, js5d); + let a19m1a5d = I4::sub(a19, a5d); + let s19pjs5d = I4::add(s19, js5d); + let a3bp1a7f = I4::add(a3b, a7f); + let s3bmjs7f = I4::sub(s3b, js7f); + let a3bm1a7f = I4::sub(a3b, a7f); + let s3bpjs7f = I4::add(s3b, js7f); + + let w8_s2amjs6e = I4::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I4::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I4::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I4::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I4::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I4::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I4::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I4::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I4::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I4::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I4::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I4::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I4::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I4::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I4::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I4::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I4::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I4::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I4::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I4::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I4::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I4::add(s19pjs5d, v8_s3bpjs7f); + + I4::store( + x_p.add(big_n0), + I4::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I4::store( + x_p.add(big_n8), + I4::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + + let h1_s19mjs5d_pw_s3bmjs7f = I4::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + I4::store( + x_p.add(big_n1), + I4::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + I4::store( + x_p.add(big_n9), + I4::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + + let w8_a19m1a5d_mj_a3bm1a7f = I4::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + I4::store( + x_p.add(big_n2), + I4::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + I4::store( + x_p.add(big_na), + I4::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + + let h3_s19pjs5d_mv_s3bpjs7f = I4::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + I4::store( + x_p.add(big_n3), + I4::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + I4::store( + x_p.add(big_nb), + I4::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + + let j_a19p1a5d_m1_a3bp1a7f = I4::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + I4::store( + x_p.add(big_n4), + I4::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + I4::store( + x_p.add(big_nc), + I4::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + + let hd_s19mjs5d_mw_s3bmjs7f = I4::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + I4::store( + x_p.add(big_n5), + I4::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + I4::store( + x_p.add(big_nd), + I4::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + + let v8_a19m1a5d_pj_a3bm1a7f = I4::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + I4::store( + x_p.add(big_n6), + I4::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + I4::store( + x_p.add(big_ne), + I4::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + + let hf_s19pjs5d_pv_s3bpjs7f = I4::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + I4::store( + x_p.add(big_n7), + I4::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + I4::store( + x_p.add(big_nf), + I4::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + + p += 4; + } +} + +#[inline(always)] +pub(crate) unsafe fn end16( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 16); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let z0 = I::load(zq.add(0x0)); + let z1 = I::load(zq.add(s)); + let z2 = I::load(zq.add(s * 0x2)); + let z3 = I::load(zq.add(s * 0x3)); + let z4 = I::load(zq.add(s * 0x4)); + let z5 = I::load(zq.add(s * 0x5)); + let z6 = I::load(zq.add(s * 0x6)); + let z7 = I::load(zq.add(s * 0x7)); + let z8 = I::load(zq.add(s * 0x8)); + let z9 = I::load(zq.add(s * 0x9)); + let za = I::load(zq.add(s * 0xa)); + let zb = I::load(zq.add(s * 0xb)); + let zc = I::load(zq.add(s * 0xc)); + let zd = I::load(zq.add(s * 0xd)); + let ze = I::load(zq.add(s * 0xe)); + let zf = I::load(zq.add(s * 0xf)); + + let a08 = I::add(z0, z8); + let s08 = I::sub(z0, z8); + let a4c = I::add(z4, zc); + let s4c = I::sub(z4, zc); + let a2a = I::add(z2, za); + let s2a = I::sub(z2, za); + let a6e = I::add(z6, ze); + let s6e = I::sub(z6, ze); + let a19 = I::add(z1, z9); + let s19 = I::sub(z1, z9); + let a5d = I::add(z5, zd); + let s5d = I::sub(z5, zd); + let a3b = I::add(z3, zb); + let s3b = I::sub(z3, zb); + let a7f = I::add(z7, zf); + let s7f = I::sub(z7, zf); + + let js4c = I::xpj(fwd, s4c); + let js6e = I::xpj(fwd, s6e); + let js5d = I::xpj(fwd, s5d); + let js7f = I::xpj(fwd, s7f); + + let a08p1a4c = I::add(a08, a4c); + let s08mjs4c = I::sub(s08, js4c); + let a08m1a4c = I::sub(a08, a4c); + let s08pjs4c = I::add(s08, js4c); + let a2ap1a6e = I::add(a2a, a6e); + let s2amjs6e = I::sub(s2a, js6e); + let a2am1a6e = I::sub(a2a, a6e); + let s2apjs6e = I::add(s2a, js6e); + let a19p1a5d = I::add(a19, a5d); + let s19mjs5d = I::sub(s19, js5d); + let a19m1a5d = I::sub(a19, a5d); + let s19pjs5d = I::add(s19, js5d); + let a3bp1a7f = I::add(a3b, a7f); + let s3bmjs7f = I::sub(s3b, js7f); + let a3bm1a7f = I::sub(a3b, a7f); + let s3bpjs7f = I::add(s3b, js7f); + + let w8_s2amjs6e = I::xw8(fwd, s2amjs6e); + let j_a2am1a6e = I::xpj(fwd, a2am1a6e); + let v8_s2apjs6e = I::xv8(fwd, s2apjs6e); + + let a08p1a4c_p1_a2ap1a6e = I::add(a08p1a4c, a2ap1a6e); + let s08mjs4c_pw_s2amjs6e = I::add(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_mj_a2am1a6e = I::sub(a08m1a4c, j_a2am1a6e); + let s08pjs4c_mv_s2apjs6e = I::sub(s08pjs4c, v8_s2apjs6e); + let a08p1a4c_m1_a2ap1a6e = I::sub(a08p1a4c, a2ap1a6e); + let s08mjs4c_mw_s2amjs6e = I::sub(s08mjs4c, w8_s2amjs6e); + let a08m1a4c_pj_a2am1a6e = I::add(a08m1a4c, j_a2am1a6e); + let s08pjs4c_pv_s2apjs6e = I::add(s08pjs4c, v8_s2apjs6e); + + let w8_s3bmjs7f = I::xw8(fwd, s3bmjs7f); + let j_a3bm1a7f = I::xpj(fwd, a3bm1a7f); + let v8_s3bpjs7f = I::xv8(fwd, s3bpjs7f); + + let a19p1a5d_p1_a3bp1a7f = I::add(a19p1a5d, a3bp1a7f); + let s19mjs5d_pw_s3bmjs7f = I::add(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_mj_a3bm1a7f = I::sub(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_mv_s3bpjs7f = I::sub(s19pjs5d, v8_s3bpjs7f); + let a19p1a5d_m1_a3bp1a7f = I::sub(a19p1a5d, a3bp1a7f); + let s19mjs5d_mw_s3bmjs7f = I::sub(s19mjs5d, w8_s3bmjs7f); + let a19m1a5d_pj_a3bm1a7f = I::add(a19m1a5d, j_a3bm1a7f); + let s19pjs5d_pv_s3bpjs7f = I::add(s19pjs5d, v8_s3bpjs7f); + + let h1_s19mjs5d_pw_s3bmjs7f = I::xh1(fwd, s19mjs5d_pw_s3bmjs7f); + let w8_a19m1a5d_mj_a3bm1a7f = I::xw8(fwd, a19m1a5d_mj_a3bm1a7f); + let h3_s19pjs5d_mv_s3bpjs7f = I::xh3(fwd, s19pjs5d_mv_s3bpjs7f); + let j_a19p1a5d_m1_a3bp1a7f = I::xpj(fwd, a19p1a5d_m1_a3bp1a7f); + let hd_s19mjs5d_mw_s3bmjs7f = I::xhd(fwd, s19mjs5d_mw_s3bmjs7f); + let v8_a19m1a5d_pj_a3bm1a7f = I::xv8(fwd, a19m1a5d_pj_a3bm1a7f); + let hf_s19pjs5d_pv_s3bpjs7f = I::xhf(fwd, s19pjs5d_pv_s3bpjs7f); + + I::store( + xq.add(0x0), + I::add(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I::store( + xq.add(s), + I::add(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + I::store( + xq.add(s * 0x2), + I::add(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + I::store( + xq.add(s * 0x3), + I::add(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + I::store( + xq.add(s * 0x4), + I::sub(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + I::store( + xq.add(s * 0x5), + I::sub(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + I::store( + xq.add(s * 0x6), + I::sub(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + I::store( + xq.add(s * 0x7), + I::sub(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + + I::store( + xq.add(s * 0x8), + I::sub(a08p1a4c_p1_a2ap1a6e, a19p1a5d_p1_a3bp1a7f), + ); + I::store( + xq.add(s * 0x9), + I::sub(s08mjs4c_pw_s2amjs6e, h1_s19mjs5d_pw_s3bmjs7f), + ); + I::store( + xq.add(s * 0xa), + I::sub(a08m1a4c_mj_a2am1a6e, w8_a19m1a5d_mj_a3bm1a7f), + ); + I::store( + xq.add(s * 0xb), + I::sub(s08pjs4c_mv_s2apjs6e, h3_s19pjs5d_mv_s3bpjs7f), + ); + I::store( + xq.add(s * 0xc), + I::add(a08p1a4c_m1_a2ap1a6e, j_a19p1a5d_m1_a3bp1a7f), + ); + I::store( + xq.add(s * 0xd), + I::add(s08mjs4c_mw_s2amjs6e, hd_s19mjs5d_mw_s3bmjs7f), + ); + I::store( + xq.add(s * 0xe), + I::add(a08m1a4c_pj_a2am1a6e, v8_a19m1a5d_pj_a3bm1a7f), + ); + I::store( + xq.add(s * 0xf), + I::add(s08pjs4c_pv_s2apjs6e, hf_s19pjs5d_pv_s3bpjs7f), + ); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dit16_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$x1>(FWD, 1 << 2, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$x1>(FWD, 1 << 3, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + end16::<$x1>(FWD, 1 << 4, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 4, y, x, true); + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 2, 1 << 4, y, x, true); + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$xn>(FWD, 1 << 3, 1 << 4, y, x, true); + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + end16::<$xn>(FWD, 1 << 4, 1 << 4, y, x, true); + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 8, x, y, false); + core_::<$xn>(FWD, 1 << 5, 1 << 4, y, x, w); + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 8, x, y, false); + core_::<$xn>(FWD, 1 << 06, 1 << 4, y, x, w); + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$xn>(FWD, 1 << 03, 1 << 08, x, y, false); + core_::<$xn>(FWD, 1 << 07, 1 << 04, y, x, w); + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + end16::<$xn>(FWD, 1 << 04, 1 << 08, x, y, false); + core_::<$xn>(FWD, 1 << 08, 1 << 04, y, x, w); + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 01, 1 << 12, y, x, true); + core_::<$xn>(FWD, 1 << 05, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 04, y, x, w); + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 12, y, x, true); + core_::<$xn>(FWD, 1 << 06, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 04, y, x, w); + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$xn>(FWD, 1 << 03, 1 << 12, y, x, true); + core_::<$xn>(FWD, 1 << 07, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 04, y, x, w); + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + end16::<$xn>(FWD, 1 << 04, 1 << 12, y, x, true); + core_::<$xn>(FWD, 1 << 08, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 04, y, x, w); + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dit16_impl! { + pub static DIT16_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT16_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT16_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; + + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + pub static DIT16_AVX512 = Fft { + core_1: core_x4::, + native: Avx512X4, + x1: Avx512X1, + target: "avx512f", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return DIT16_AVX512; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIT16_FMA; + } else if x86_feature_detected!("avx") { + return DIT16_AVX; + } + + DIT16_SCALAR +} diff --git a/src/dit2.rs b/src/dit2.rs new file mode 100644 index 0000000..d23f291 --- /dev/null +++ b/src/dit2.rs @@ -0,0 +1,418 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64X2, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +unsafe fn core_( + _fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 2; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 2; + + for p in 0..m { + let sp = s * p; + let s2p = 2 * sp; + let w1p = I::splat(twid_t(2, big_n, 1, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s2p = y.add(q + s2p); + + let a = I::load(yq_s2p.add(s * 0)); + let b = I::mul(w1p, I::load(yq_s2p.add(s * 1))); + + I::store(xq_sp.add(big_n0), I::add(a, b)); + I::store(xq_sp.add(big_n1), I::sub(a, b)); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + _fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 2; + + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_2p = y.add(2 * p); + + let w1p = I::load(twid(2, big_n, 1, w, p)); + + let ab0 = I::load(y_2p.add(0)); + let ab1 = I::load(y_2p.add(2)); + + let a = I::catlo(ab0, ab1); + let b = I::mul(w1p, I::cathi(ab0, ab1)); + + I::store(x_p.add(big_n0), I::add(a, b)); + I::store(x_p.add(big_n1), I::sub(a, b)); + + p += 2; + } +} + +#[inline(always)] +pub unsafe fn end_2( + _fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 2); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let a = I::load(zq.add(0)); + let b = I::load(zq.add(s)); + + I::store(xq.add(0), I::add(a, b)); + I::store(xq.add(s), I::sub(a, b)); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dit2_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 1, y, x, true); + $core1______(FWD, 1 << 2, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 2, x, y, false); + core_::<$xn>(FWD, 1 << 2, 1 << 1, y, x, w); + $core1______(FWD, 1 << 3, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 3, y, x, true); + core_::<$xn>(FWD, 1 << 2, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 1, y, x, w); + $core1______(FWD, 1 << 4, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 4, x, y, false); + core_::<$xn>(FWD, 1 << 2, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 1, y, x, w); + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 5, y, x, true); + core_::<$xn>(FWD, 1 << 2, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 1, y, x, w); + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 6, x, y, false); + core_::<$xn>(FWD, 1 << 2, 1 << 5, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 1, y, x, w); + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 7, y, x, true); + core_::<$xn>(FWD, 1 << 2, 1 << 6, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 5, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 1, y, x, w); + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 8, x, y, false); + core_::<$xn>(FWD, 1 << 2, 1 << 7, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 6, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 5, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 1, y, x, w); + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 9, y, x, true); + core_::<$xn>(FWD, 1 << 2, 1 << 8, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 7, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 6, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 5, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 3, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 2, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 1, y, x, w); + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 10, x, y, false); + core_::<$xn>(FWD, 1 << 2, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 9, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 01, y, x, w); + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 11, y, x, true); + core_::<$xn>(FWD, 1 << 2, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 01, y, x, w); + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 12, x, y, false); + core_::<$xn>(FWD, 1 << 2, 1 << 11, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 9, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 11, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 01, y, x, w); + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 13, y, x, true); + core_::<$xn>(FWD, 1 << 2, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 11, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 12, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 01, y, x, w); + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 14, x, y, false); + core_::<$xn>(FWD, 1 << 2, 1 << 13, y, x, w); + core_::<$xn>(FWD, 1 << 3, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 4, 1 << 11, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 7, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 8, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 9, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 11, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 13, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 14, 1 << 01, y, x, w); + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 15, y, x, true); + core_::<$xn>(FWD, 1 << 2, 1 << 14, x, y, w); + core_::<$xn>(FWD, 1 << 3, 1 << 13, y, x, w); + core_::<$xn>(FWD, 1 << 4, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 11, y, x, w); + core_::<$xn>(FWD, 1 << 6, 1 << 10, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 8, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 9, 1 << 07, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 05, y, x, w); + core_::<$xn>(FWD, 1 << 12, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 03, y, x, w); + core_::<$xn>(FWD, 1 << 14, 1 << 02, x, y, w); + core_::<$xn>(FWD, 1 << 15, 1 << 01, y, x, w); + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dit2_impl! { + pub static DIT2_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT2_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT2_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIT2_FMA; + } else if x86_feature_detected!("avx") { + return DIT2_AVX; + } + + DIT2_SCALAR +} diff --git a/src/dit4.rs b/src/dit4.rs new file mode 100644 index 0000000..3f6a2dc --- /dev/null +++ b/src/dit4.rs @@ -0,0 +1,491 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64Ext, FftSimd64X2, FftSimd64X4, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +unsafe fn core_( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 4; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 4; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + + for p in 0..m { + let sp = s * p; + let s4p = 4 * sp; + let w1p = I::splat(twid_t(4, big_n, 1, w, sp)); + let w2p = I::splat(twid_t(4, big_n, 2, w, sp)); + let w3p = I::splat(twid_t(4, big_n, 3, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s4p = y.add(q + s4p); + + let a = I::load(yq_s4p.add(0)); + let b = I::mul(w1p, I::load(yq_s4p.add(s))); + let c = I::mul(w2p, I::load(yq_s4p.add(s * 2))); + let d = I::mul(w3p, I::load(yq_s4p.add(s * 3))); + + let apc = I::add(a, c); + let amc = I::sub(a, c); + + let bpd = I::add(b, d); + let jbmd = I::xpj(fwd, I::sub(b, d)); + + I::store(xq_sp.add(big_n0), I::add(apc, bpd)); + I::store(xq_sp.add(big_n1), I::sub(amc, jbmd)); + I::store(xq_sp.add(big_n2), I::sub(apc, bpd)); + I::store(xq_sp.add(big_n3), I::add(amc, jbmd)); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 4; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + + debug_assert_eq!(big_n1 % 2, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_4p = y.add(4 * p); + + let w1p = I::load(twid(4, big_n, 1, w, p)); + let w2p = I::load(twid(4, big_n, 2, w, p)); + let w3p = I::load(twid(4, big_n, 3, w, p)); + + let ab0 = I::load(y_4p.add(0)); + let cd0 = I::load(y_4p.add(2)); + let ab1 = I::load(y_4p.add(4)); + let cd1 = I::load(y_4p.add(6)); + + let a = I::catlo(ab0, ab1); + let b = I::mul(w1p, I::cathi(ab0, ab1)); + let c = I::mul(w2p, I::catlo(cd0, cd1)); + let d = I::mul(w3p, I::cathi(cd0, cd1)); + + let apc = I::add(a, c); + let amc = I::sub(a, c); + let bpd = I::add(b, d); + let jbmd = I::xpj(fwd, I::sub(b, d)); + + I::store(x_p.add(big_n0), I::add(apc, bpd)); + I::store(x_p.add(big_n1), I::sub(amc, jbmd)); + I::store(x_p.add(big_n2), I::sub(apc, bpd)); + I::store(x_p.add(big_n3), I::add(amc, jbmd)); + + p += 2; + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + if n == 8 { + return core_x2::(fwd, n, s, x, y, w); + } + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 4; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + + debug_assert_eq!(big_n1 % 4, 0); + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_4p = y.add(4 * p); + + let w1p = I4::load(twid(4, big_n, 1, w, p)); + let w2p = I4::load(twid(4, big_n, 2, w, p)); + let w3p = I4::load(twid(4, big_n, 3, w, p)); + + let abcd0 = I4::load(y_4p.add(0)); + let abcd1 = I4::load(y_4p.add(4)); + let abcd2 = I4::load(y_4p.add(8)); + let abcd3 = I4::load(y_4p.add(12)); + + let (a, b, c, d) = I4::transpose(abcd0, abcd1, abcd2, abcd3); + + let a = a; + let b = I4::mul(w1p, b); + let c = I4::mul(w2p, c); + let d = I4::mul(w3p, d); + + let apc = I4::add(a, c); + let amc = I4::sub(a, c); + let bpd = I4::add(b, d); + let jbmd = I4::xpj(fwd, I4::sub(b, d)); + + I4::store(x_p.add(big_n0), I4::add(apc, bpd)); + I4::store(x_p.add(big_n1), I4::sub(amc, jbmd)); + I4::store(x_p.add(big_n2), I4::sub(apc, bpd)); + I4::store(x_p.add(big_n3), I4::add(amc, jbmd)); + + p += 4; + } +} + +#[inline(always)] +pub unsafe fn end_4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 4); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let a = I::load(zq.add(0)); + let b = I::load(zq.add(s)); + let c = I::load(zq.add(s * 2)); + let d = I::load(zq.add(s * 3)); + + let apc = I::add(a, c); + let amc = I::sub(a, c); + let bpd = I::add(b, d); + let jbmd = I::xpj(fwd, I::sub(b, d)); + + I::store(xq.add(0), I::add(apc, bpd)); + I::store(xq.add(s), I::sub(amc, jbmd)); + I::store(xq.add(s * 2), I::sub(apc, bpd)); + I::store(xq.add(s * 3), I::add(amc, jbmd)); + + q += I::COMPLEX_PER_REG; + } +} + +#[inline(always)] +pub unsafe fn end_2( + _fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 2); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let a = I::load(zq.add(0)); + let b = I::load(zq.add(s)); + + I::store(xq.add(0), I::add(a, b)); + I::store(xq.add(s), I::sub(a, b)); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dit4_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$x1>(FWD, 1 << 2, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 2, y, x, true); + $core1______(FWD, 1 << 3, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 2, 1 << 2, y, x, true); + $core1______(FWD, 1 << 4, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 4, x, y, false); + core_::<$xn>(FWD, 1 << 3, 1 << 2, y, x, w); + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 2, 1 << 4, x, y, false); + core_::<$xn>(FWD, 1 << 4, 1 << 2, y, x, w); + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 6, y, x, true); + core_::<$xn>(FWD, 1 << 3, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 5, 1 << 2, y, x, w); + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 2, 1 << 6, y, x, true); + core_::<$xn>(FWD, 1 << 4, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 6, 1 << 2, y, x, w); + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 8, x, y, false); + core_::<$xn>(FWD, 1 << 3, 1 << 6, y, x, w); + core_::<$xn>(FWD, 1 << 5, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 7, 1 << 2, y, x, w); + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 8, x, y, false); + core_::<$xn>(FWD, 1 << 04, 1 << 6, y, x, w); + core_::<$xn>(FWD, 1 << 06, 1 << 4, x, y, w); + core_::<$xn>(FWD, 1 << 08, 1 << 2, y, x, w); + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 01, 1 << 10, y, x, true); + core_::<$xn>(FWD, 1 << 03, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 05, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 07, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 02, y, x, w); + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 10, y, x, true); + core_::<$xn>(FWD, 1 << 04, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 06, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 08, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 02, y, x, w); + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 01, 1 << 12, x, y, false); + core_::<$xn>(FWD, 1 << 03, 1 << 10, y, x, w); + core_::<$xn>(FWD, 1 << 05, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 07, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 09, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 02, y, x, w); + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 12, x, y, false); + core_::<$xn>(FWD, 1 << 04, 1 << 10, y, x, w); + core_::<$xn>(FWD, 1 << 06, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 08, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 02, y, x, w); + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 01, 1 << 14, y, x, true); + core_::<$xn>(FWD, 1 << 03, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 05, 1 << 10, y, x, w); + core_::<$xn>(FWD, 1 << 07, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 11, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 02, y, x, w); + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 14, y, x, true); + core_::<$xn>(FWD, 1 << 04, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 06, 1 << 10, y, x, w); + core_::<$xn>(FWD, 1 << 08, 1 << 08, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 06, y, x, w); + core_::<$xn>(FWD, 1 << 12, 1 << 04, x, y, w); + core_::<$xn>(FWD, 1 << 14, 1 << 02, y, x, w); + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dit4_impl! { + pub static DIT4_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT4_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT4_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; + + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + pub static DIT4_AVX512 = Fft { + core_1: core_x4::, + native: Avx512X4, + x1: Avx512X1, + target: "avx512f", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return DIT4_AVX512; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIT4_FMA; + } else if x86_feature_detected!("avx") { + return DIT4_AVX; + } + + DIT4_SCALAR +} diff --git a/src/dit8.rs b/src/dit8.rs new file mode 100644 index 0000000..11b2a6a --- /dev/null +++ b/src/dit8.rs @@ -0,0 +1,560 @@ +// Copyright (c) 2019 OK Ojisan(Takuya OKAHISA) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +use crate::c64; +use crate::dit4::{end_2, end_4}; +use crate::fft_simd::{twid, twid_t, FftSimd64, FftSimd64Ext, FftSimd64X2, FftSimd64X4, Scalar}; +use crate::x86_feature_detected; + +#[inline(always)] +#[rustfmt::skip] +unsafe fn core_( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let m = n / 8; + let big_n = n * s; + let big_n0 = 0; + let big_n1 = big_n / 8; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + + for p in 0..m { + let sp = s * p; + let s8p = 8 * sp; + let w1p = I::splat(twid_t(8, big_n, 1, w, sp)); + let w2p = I::splat(twid_t(8, big_n, 2, w, sp)); + let w3p = I::splat(twid_t(8, big_n, 3, w, sp)); + let w4p = I::splat(twid_t(8, big_n, 4, w, sp)); + let w5p = I::splat(twid_t(8, big_n, 5, w, sp)); + let w6p = I::splat(twid_t(8, big_n, 6, w, sp)); + let w7p = I::splat(twid_t(8, big_n, 7, w, sp)); + + let mut q = 0; + while q < s { + let xq_sp = x.add(q + sp); + let yq_s8p = y.add(q + s8p); + + let y0 = I::load(yq_s8p.add(0)); + let y1 = I::mul(w1p, I::load(yq_s8p.add(s))); + let y2 = I::mul(w2p, I::load(yq_s8p.add(s * 2))); + let y3 = I::mul(w3p, I::load(yq_s8p.add(s * 3))); + let y4 = I::mul(w4p, I::load(yq_s8p.add(s * 4))); + let y5 = I::mul(w5p, I::load(yq_s8p.add(s * 5))); + let y6 = I::mul(w6p, I::load(yq_s8p.add(s * 6))); + let y7 = I::mul(w7p, I::load(yq_s8p.add(s * 7))); + let a04 = I::add(y0, y4); + let s04 = I::sub(y0, y4); + let a26 = I::add(y2, y6); + let js26 = I::xpj(fwd, I::sub(y2, y6)); + let a15 = I::add(y1, y5); + let s15 = I::sub(y1, y5); + let a37 = I::add(y3, y7); + let js37 = I::xpj(fwd, I::sub(y3, y7)); + + let a04_p1_a26 = I::add(a04, a26); + let a15_p1_a37 = I::add(a15, a37); + I::store(xq_sp.add(big_n0), I::add(a04_p1_a26, a15_p1_a37)); + I::store(xq_sp.add(big_n4), I::sub(a04_p1_a26, a15_p1_a37)); + + let s04_mj_s26 = I::sub(s04, js26); + let w8_s15_mj_s37 = I::xw8(fwd, I::sub(s15, js37)); + I::store(xq_sp.add(big_n1), I::add(s04_mj_s26, w8_s15_mj_s37)); + I::store(xq_sp.add(big_n5), I::sub(s04_mj_s26, w8_s15_mj_s37)); + + let a04_m1_a26 = I::sub(a04, a26); + let j_a15_m1_a37 = I::xpj(fwd, I::sub(a15, a37)); + I::store(xq_sp.add(big_n2), I::sub(a04_m1_a26, j_a15_m1_a37)); + I::store(xq_sp.add(big_n6), I::add(a04_m1_a26, j_a15_m1_a37)); + + let s04_pj_s26 = I::add(s04, js26); + let v8_s15_pj_s37 = I::xv8(fwd, I::add(s15, js37)); + I::store(xq_sp.add(big_n3), I::sub(s04_pj_s26, v8_s15_pj_s37)); + I::store(xq_sp.add(big_n7), I::add(s04_pj_s26, v8_s15_pj_s37)); + + q += I::COMPLEX_PER_REG; + } + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x2( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + debug_assert_eq!(I::COMPLEX_PER_REG, 2); + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 8; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_8p = y.add(8 * p); + + let w1p = I::load(twid(8, big_n, 1, w, p)); + let w2p = I::load(twid(8, big_n, 2, w, p)); + let w3p = I::load(twid(8, big_n, 3, w, p)); + let w4p = I::load(twid(8, big_n, 4, w, p)); + let w5p = I::load(twid(8, big_n, 5, w, p)); + let w6p = I::load(twid(8, big_n, 6, w, p)); + let w7p = I::load(twid(8, big_n, 7, w, p)); + let ab_0 = I::load(y_8p.add(0)); + let cd_0 = I::load(y_8p.add(2)); + let ef_0 = I::load(y_8p.add(4)); + let gh_0 = I::load(y_8p.add(6)); + let ab_1 = I::load(y_8p.add(8)); + let cd_1 = I::load(y_8p.add(10)); + let ef_1 = I::load(y_8p.add(12)); + let gh_1 = I::load(y_8p.add(14)); + let y0 = I::catlo(ab_0, ab_1); + let y1 = I::mul(w1p, I::cathi(ab_0, ab_1)); + let y2 = I::mul(w2p, I::catlo(cd_0, cd_1)); + let y3 = I::mul(w3p, I::cathi(cd_0, cd_1)); + let y4 = I::mul(w4p, I::catlo(ef_0, ef_1)); + let y5 = I::mul(w5p, I::cathi(ef_0, ef_1)); + let y6 = I::mul(w6p, I::catlo(gh_0, gh_1)); + let y7 = I::mul(w7p, I::cathi(gh_0, gh_1)); + + let a04 = I::add(y0, y4); + let s04 = I::sub(y0, y4); + let a26 = I::add(y2, y6); + let js26 = I::xpj(fwd, I::sub(y2, y6)); + let a15 = I::add(y1, y5); + let s15 = I::sub(y1, y5); + let a37 = I::add(y3, y7); + let js37 = I::xpj(fwd, I::sub(y3, y7)); + + let a04_p1_a26 = I::add(a04, a26); + let a15_p1_a37 = I::add(a15, a37); + I::store(x_p.add(big_n0), I::add(a04_p1_a26, a15_p1_a37)); + I::store(x_p.add(big_n4), I::sub(a04_p1_a26, a15_p1_a37)); + + let s04_mj_s26 = I::sub(s04, js26); + let w8_s15_mj_s37 = I::xw8(fwd, I::sub(s15, js37)); + I::store(x_p.add(big_n1), I::add(s04_mj_s26, w8_s15_mj_s37)); + I::store(x_p.add(big_n5), I::sub(s04_mj_s26, w8_s15_mj_s37)); + + let a04_m1_a26 = I::sub(a04, a26); + let j_a15_m1_a37 = I::xpj(fwd, I::sub(a15, a37)); + I::store(x_p.add(big_n2), I::sub(a04_m1_a26, j_a15_m1_a37)); + I::store(x_p.add(big_n6), I::add(a04_m1_a26, j_a15_m1_a37)); + + let s04_pj_s26 = I::add(s04, js26); + let v8_s15_pj_s37 = I::xv8(fwd, I::add(s15, js37)); + I::store(x_p.add(big_n3), I::sub(s04_pj_s26, v8_s15_pj_s37)); + I::store(x_p.add(big_n7), I::add(s04_pj_s26, v8_s15_pj_s37)); + + p += 2; + } +} + +#[inline(always)] +#[allow(dead_code)] +unsafe fn core_x4( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + w: *const c64, +) { + debug_assert_eq!(s, 1); + if n == 16 { + return core_x2::(fwd, n, s, x, y, w); + } + + let big_n = n; + let big_n0 = 0; + let big_n1 = big_n / 8; + let big_n2 = big_n1 * 2; + let big_n3 = big_n1 * 3; + let big_n4 = big_n1 * 4; + let big_n5 = big_n1 * 5; + let big_n6 = big_n1 * 6; + let big_n7 = big_n1 * 7; + + let mut p = 0; + while p < big_n1 { + let x_p = x.add(p); + let y_8p = y.add(8 * p); + + let w1p = I4::load(twid(8, big_n, 1, w, p)); + let w2p = I4::load(twid(8, big_n, 2, w, p)); + let w3p = I4::load(twid(8, big_n, 3, w, p)); + let w4p = I4::load(twid(8, big_n, 4, w, p)); + let w5p = I4::load(twid(8, big_n, 5, w, p)); + let w6p = I4::load(twid(8, big_n, 6, w, p)); + let w7p = I4::load(twid(8, big_n, 7, w, p)); + + let abcd_0 = I4::load(y_8p.add(0)); + let efgh_0 = I4::load(y_8p.add(4)); + let abcd_1 = I4::load(y_8p.add(8)); + let efgh_1 = I4::load(y_8p.add(12)); + let abcd_2 = I4::load(y_8p.add(16)); + let efgh_2 = I4::load(y_8p.add(20)); + let abcd_3 = I4::load(y_8p.add(24)); + let efgh_3 = I4::load(y_8p.add(28)); + + let (a, b, c, d) = I4::transpose(abcd_0, abcd_1, abcd_2, abcd_3); + let (e, f, g, h) = I4::transpose(efgh_0, efgh_1, efgh_2, efgh_3); + + let y0 = a; + let y1 = I4::mul(w1p, b); + let y2 = I4::mul(w2p, c); + let y3 = I4::mul(w3p, d); + let y4 = I4::mul(w4p, e); + let y5 = I4::mul(w5p, f); + let y6 = I4::mul(w6p, g); + let y7 = I4::mul(w7p, h); + + let a04 = I4::add(y0, y4); + let s04 = I4::sub(y0, y4); + let a26 = I4::add(y2, y6); + let js26 = I4::xpj(fwd, I4::sub(y2, y6)); + let a15 = I4::add(y1, y5); + let s15 = I4::sub(y1, y5); + let a37 = I4::add(y3, y7); + let js37 = I4::xpj(fwd, I4::sub(y3, y7)); + + let a04_p1_a26 = I4::add(a04, a26); + let a15_p1_a37 = I4::add(a15, a37); + I4::store(x_p.add(big_n0), I4::add(a04_p1_a26, a15_p1_a37)); + I4::store(x_p.add(big_n4), I4::sub(a04_p1_a26, a15_p1_a37)); + + let s04_mj_s26 = I4::sub(s04, js26); + let w8_s15_mj_s37 = I4::xw8(fwd, I4::sub(s15, js37)); + I4::store(x_p.add(big_n1), I4::add(s04_mj_s26, w8_s15_mj_s37)); + I4::store(x_p.add(big_n5), I4::sub(s04_mj_s26, w8_s15_mj_s37)); + + let a04_m1_a26 = I4::sub(a04, a26); + let j_a15_m1_a37 = I4::xpj(fwd, I4::sub(a15, a37)); + I4::store(x_p.add(big_n2), I4::sub(a04_m1_a26, j_a15_m1_a37)); + I4::store(x_p.add(big_n6), I4::add(a04_m1_a26, j_a15_m1_a37)); + + let s04_pj_s26 = I4::add(s04, js26); + let v8_s15_pj_s37 = I4::xv8(fwd, I4::add(s15, js37)); + I4::store(x_p.add(big_n3), I4::sub(s04_pj_s26, v8_s15_pj_s37)); + I4::store(x_p.add(big_n7), I4::add(s04_pj_s26, v8_s15_pj_s37)); + + p += 4; + } +} + +#[inline(always)] +pub(crate) unsafe fn end_8( + fwd: bool, + n: usize, + s: usize, + x: *mut c64, + y: *mut c64, + eo: bool, +) { + debug_assert_eq!(n, 8); + debug_assert_eq!(s % I::COMPLEX_PER_REG, 0); + + let z = if eo { y } else { x }; + + let mut q = 0; + while q < s { + let xq = x.add(q); + let zq = z.add(q); + + let z0 = I::load(zq.add(0)); + let z1 = I::load(zq.add(s)); + let z2 = I::load(zq.add(s * 2)); + let z3 = I::load(zq.add(s * 3)); + let z4 = I::load(zq.add(s * 4)); + let z5 = I::load(zq.add(s * 5)); + let z6 = I::load(zq.add(s * 6)); + let z7 = I::load(zq.add(s * 7)); + let a04 = I::add(z0, z4); + let s04 = I::sub(z0, z4); + let a26 = I::add(z2, z6); + let js26 = I::xpj(fwd, I::sub(z2, z6)); + let a15 = I::add(z1, z5); + let s15 = I::sub(z1, z5); + let a37 = I::add(z3, z7); + let js37 = I::xpj(fwd, I::sub(z3, z7)); + let a04_p1_a26 = I::add(a04, a26); + let s04_mj_s26 = I::sub(s04, js26); + let a04_m1_a26 = I::sub(a04, a26); + let s04_pj_s26 = I::add(s04, js26); + let a15_p1_a37 = I::add(a15, a37); + let w8_s15_mj_s37 = I::xw8(fwd, I::sub(s15, js37)); + let j_a15_m1_a37 = I::xpj(fwd, I::sub(a15, a37)); + let v8_s15_pj_s37 = I::xv8(fwd, I::add(s15, js37)); + I::store(xq.add(0), I::add(a04_p1_a26, a15_p1_a37)); + I::store(xq.add(s), I::add(s04_mj_s26, w8_s15_mj_s37)); + I::store(xq.add(s * 2), I::sub(a04_m1_a26, j_a15_m1_a37)); + I::store(xq.add(s * 3), I::sub(s04_pj_s26, v8_s15_pj_s37)); + I::store(xq.add(s * 4), I::sub(a04_p1_a26, a15_p1_a37)); + I::store(xq.add(s * 5), I::sub(s04_mj_s26, w8_s15_mj_s37)); + I::store(xq.add(s * 6), I::add(a04_m1_a26, j_a15_m1_a37)); + I::store(xq.add(s * 7), I::add(s04_pj_s26, v8_s15_pj_s37)); + + q += I::COMPLEX_PER_REG; + } +} + +macro_rules! dit8_impl { + ( + $( + $(#[$attr: meta])* + pub static $fft: ident = Fft { + core_1: $core1______: expr, + native: $xn: ty, + x1: $x1: ty, + $(target: $target: tt,)? + }; + )* + ) => { + $( + #[allow(missing_copy_implementations)] + #[allow(non_camel_case_types)] + #[allow(dead_code)] + $(#[$attr])* + struct $fft { + __private: (), + } + #[allow(unused_variables)] + #[allow(dead_code)] + $(#[$attr])* + impl $fft { + $(#[target_feature(enable = $target)])? + unsafe fn fft_00(x: *mut c64, y: *mut c64, w: *const c64) {} + $(#[target_feature(enable = $target)])? + unsafe fn fft_01(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$x1>(FWD, 1 << 1, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_02(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$x1>(FWD, 1 << 2, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_03(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$x1>(FWD, 1 << 3, 1 << 0, x, y, false); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_04(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 3, y, x, true); + $core1______(FWD, 1 << 4, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_05(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 2, 1 << 3, y, x, true); + $core1______(FWD, 1 << 5, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_06(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$xn>(FWD, 1 << 3, 1 << 3, y, x, true); + $core1______(FWD, 1 << 6, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_07(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 1, 1 << 6, x, y, false); + core_::<$xn>(FWD, 1 << 4, 1 << 3, y, x, w); + $core1______(FWD, 1 << 7, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_08(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 2, 1 << 6, x, y, false); + core_::<$xn>(FWD, 1 << 5, 1 << 3, y, x, w); + $core1______(FWD, 1 << 8, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_09(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$xn>(FWD, 1 << 3, 1 << 6, x, y, false); + core_::<$xn>(FWD, 1 << 6, 1 << 3, y, x, w); + $core1______(FWD, 1 << 9, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_10(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 01, 1 << 9, y, x, true); + core_::<$xn>(FWD, 1 << 04, 1 << 6, x, y, w); + core_::<$xn>(FWD, 1 << 07, 1 << 3, y, x, w); + $core1______(FWD, 1 << 10, 1 << 0, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_11(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 09, y, x, true); + core_::<$xn>(FWD, 1 << 05, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 08, 1 << 03, y, x, w); + $core1______(FWD, 1 << 11, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_12(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$xn>(FWD, 1 << 03, 1 << 09, y, x, true); + core_::<$xn>(FWD, 1 << 06, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 09, 1 << 03, y, x, w); + $core1______(FWD, 1 << 12, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_13(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 01, 1 << 12, x, y, false); + core_::<$xn>(FWD, 1 << 04, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 07, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 10, 1 << 03, y, x, w); + $core1______(FWD, 1 << 13, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_14(x: *mut c64, y: *mut c64, w: *const c64) { + end_4::<$xn>(FWD, 1 << 02, 1 << 12, x, y, false); + core_::<$xn>(FWD, 1 << 05, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 08, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 11, 1 << 03, y, x, w); + $core1______(FWD, 1 << 14, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_15(x: *mut c64, y: *mut c64, w: *const c64) { + end_8::<$xn>(FWD, 1 << 03, 1 << 12, x, y, false); + core_::<$xn>(FWD, 1 << 06, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 09, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 12, 1 << 03, y, x, w); + $core1______(FWD, 1 << 15, 1 << 00, x, y, w); + } + $(#[target_feature(enable = $target)])? + unsafe fn fft_16(x: *mut c64, y: *mut c64, w: *const c64) { + end_2::<$xn>(FWD, 1 << 01, 1 << 15, y, x, true); + core_::<$xn>(FWD, 1 << 04, 1 << 12, x, y, w); + core_::<$xn>(FWD, 1 << 07, 1 << 09, y, x, w); + core_::<$xn>(FWD, 1 << 10, 1 << 06, x, y, w); + core_::<$xn>(FWD, 1 << 13, 1 << 03, y, x, w); + $core1______(FWD, 1 << 16, 1 << 00, x, y, w); + } + } + $(#[$attr])* + pub(crate) static $fft: crate::FftImpl = crate::FftImpl { + fwd: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + inv: [ + <$fft>::fft_00::, + <$fft>::fft_01::, + <$fft>::fft_02::, + <$fft>::fft_03::, + <$fft>::fft_04::, + <$fft>::fft_05::, + <$fft>::fft_06::, + <$fft>::fft_07::, + <$fft>::fft_08::, + <$fft>::fft_09::, + <$fft>::fft_10::, + <$fft>::fft_11::, + <$fft>::fft_12::, + <$fft>::fft_13::, + <$fft>::fft_14::, + <$fft>::fft_15::, + <$fft>::fft_16::, + ], + }; + )* + }; +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::x86::*; + +dit8_impl! { + pub static DIT8_SCALAR = Fft { + core_1: core_::, + native: Scalar, + x1: Scalar, + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT8_AVX = Fft { + core_1: core_x2::, + native: AvxX2, + x1: AvxX1, + target: "avx", + }; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + pub static DIT8_FMA = Fft { + core_1: core_x2::, + native: FmaX2, + x1: FmaX1, + target: "fma", + }; + + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + pub static DIT8_AVX512 = Fft { + core_1: core_x4::, + native: Avx512X4, + x1: Avx512X1, + target: "avx512f", + }; +} + +pub(crate) fn runtime_fft() -> crate::FftImpl { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return DIT8_AVX512; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return DIT8_FMA; + } else if x86_feature_detected!("avx") { + return DIT8_AVX; + } + + DIT8_SCALAR +} diff --git a/src/fft_simd.rs b/src/fft_simd.rs new file mode 100644 index 0000000..3d9ffdc --- /dev/null +++ b/src/fft_simd.rs @@ -0,0 +1,263 @@ +use crate::c64; +use core::fmt::Debug; +use core::mem::transmute; + +// cos(-pi/8) +pub const H1X: f64 = 0.9238795325112867f64; +// sin(-pi/8) +pub const H1Y: f64 = -0.38268343236508984f64; + +pub trait FftSimd64 { + type Reg: Copy + Debug; + const COMPLEX_PER_REG: usize; + + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg; + unsafe fn splat(ptr: *const c64) -> Self::Reg; + unsafe fn load(ptr: *const c64) -> Self::Reg; + unsafe fn store(ptr: *mut c64, z: Self::Reg); + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg; + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; +} + +pub trait FftSimd64Ext: FftSimd64 { + #[inline(always)] + unsafe fn conj(xy: Self::Reg) -> Self::Reg { + let mask = Self::splat(&c64 { re: 0.0, im: -0.0 }); + Self::xor(xy, mask) + } + + #[inline(always)] + unsafe fn xpj(fwd: bool, xy: Self::Reg) -> Self::Reg { + if fwd { + Self::swap_re_im(Self::conj(xy)) + } else { + Self::conj(Self::swap_re_im(xy)) + } + } + + #[inline(always)] + unsafe fn xmj(fwd: bool, xy: Self::Reg) -> Self::Reg { + Self::xpj(!fwd, xy) + } + + #[inline(always)] + unsafe fn xv8(fwd: bool, xy: Self::Reg) -> Self::Reg { + let r = Self::splat_re_im(&core::f64::consts::FRAC_1_SQRT_2); + Self::cwise_mul(r, Self::add(xy, Self::xpj(fwd, xy))) + } + + #[inline(always)] + unsafe fn xw8(fwd: bool, xy: Self::Reg) -> Self::Reg { + Self::xv8(!fwd, xy) + } + + #[inline(always)] + unsafe fn xh1(fwd: bool, xy: Self::Reg) -> Self::Reg { + if fwd { + Self::mul(Self::splat(&c64 { re: H1X, im: H1Y }), xy) + } else { + Self::mul(Self::splat(&c64 { re: H1X, im: -H1Y }), xy) + } + } + + #[inline(always)] + unsafe fn xh3(fwd: bool, xy: Self::Reg) -> Self::Reg { + if fwd { + Self::mul(Self::splat(&c64 { re: -H1Y, im: -H1X }), xy) + } else { + Self::mul(Self::splat(&c64 { re: -H1Y, im: H1X }), xy) + } + } + + #[inline(always)] + unsafe fn xhf(fwd: bool, xy: Self::Reg) -> Self::Reg { + Self::xh1(!fwd, xy) + } + + #[inline(always)] + unsafe fn xhd(fwd: bool, xy: Self::Reg) -> Self::Reg { + Self::xh3(!fwd, xy) + } +} + +impl FftSimd64Ext for T {} + +pub trait FftSimd64X2: FftSimd64 { + unsafe fn catlo(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cathi(a: Self::Reg, b: Self::Reg) -> Self::Reg; +} + +pub trait FftSimd64X4: FftSimd64 { + unsafe fn transpose( + a: Self::Reg, + b: Self::Reg, + c: Self::Reg, + d: Self::Reg, + ) -> (Self::Reg, Self::Reg, Self::Reg, Self::Reg); +} + +#[derive(Copy, Clone, Debug)] +pub struct Scalar; + +impl FftSimd64 for Scalar { + type Reg = c64; + + const COMPLEX_PER_REG: usize = 1; + + #[inline(always)] + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg { + c64 { re: *ptr, im: *ptr } + } + + #[inline(always)] + unsafe fn splat(ptr: *const c64) -> Self::Reg { + *ptr + } + + #[inline(always)] + unsafe fn load(ptr: *const c64) -> Self::Reg { + *ptr + } + + #[inline(always)] + unsafe fn store(ptr: *mut c64, z: Self::Reg) { + *ptr = z; + } + + #[inline(always)] + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg { + transmute(transmute::(a) ^ transmute::(b)) + } + + #[inline(always)] + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg { + Self::Reg { + re: xy.im, + im: xy.re, + } + } + + #[inline(always)] + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg { + a + b + } + + #[inline(always)] + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg { + a - b + } + + #[inline(always)] + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + a * b + } + + #[inline(always)] + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + Self::Reg { + re: a.re * b.re, + im: a.im * b.im, + } + } +} + +#[inline(always)] +pub unsafe fn twid(r: usize, big_n: usize, k: usize, w: *const c64, p: usize) -> &'static c64 { + &*w.add(p + (k - 1) * (big_n / r)) +} + +#[inline(always)] +pub unsafe fn twid_t(r: usize, big_n: usize, k: usize, w: *const c64, p: usize) -> &'static c64 { + &*w.add(r * p + (big_n + k)) +} + +// https://stackoverflow.com/a/42792940 +pub fn sincospi64(mut a: f64) -> (f64, f64) { + let fma = f64::mul_add; + + // must be evaluated with IEEE-754 semantics + let az = a * 0.0; + + // for |a| >= 2**53, cospi(a) = 1.0, but cospi(Inf) = NaN + a = if a.abs() < 9007199254740992.0f64 { + a + } else { + az + }; + + // reduce argument to primary approximation interval (-0.25, 0.25) + let mut r = (a + a).round(); + let i = r as i64; + let t = f64::mul_add(-0.5, r, a); + + // compute core approximations + let s = t * t; + + // approximate cos(pi*x) for x in [-0.25,0.25] + + r = -1.0369917389758117e-4; + r = fma(r, s, 1.9294935641298806e-3); + r = fma(r, s, -2.5806887942825395e-2); + r = fma(r, s, 2.3533063028328211e-1); + r = fma(r, s, -1.3352627688538006e+0); + r = fma(r, s, 4.0587121264167623e+0); + r = fma(r, s, -4.9348022005446790e+0); + let mut c = fma(r, s, 1.0000000000000000e+0); + + // approximate sin(pi*x) for x in [-0.25,0.25] + r = 4.6151442520157035e-4; + r = fma(r, s, -7.3700183130883555e-3); + r = fma(r, s, 8.2145868949323936e-2); + r = fma(r, s, -5.9926452893214921e-1); + r = fma(r, s, 2.5501640398732688e+0); + r = fma(r, s, -5.1677127800499516e+0); + let s = s * t; + r *= s; + + let mut s = fma(t, 3.1415926535897931e+0, r); + // map results according to quadrant + + if (i & 2) != 0 { + s = 0.0 - s; // must be evaluated with IEEE-754 semantics + c = 0.0 - c; // must be evaluated with IEEE-754 semantics + } + if (i & 1) != 0 { + let t = 0.0 - s; // must be evaluated with IEEE-754 semantics + s = c; + c = t; + } + // IEEE-754: sinPi(+n) is +0 and sinPi(-n) is -0 for positive integers n + if a == a.floor() { + s = az + } + (s, c) +} + +pub fn init_wt(r: usize, n: usize, w: &mut [c64], w_inv: &mut [c64]) { + if n < r { + return; + } + + let nr = n / r; + let theta = -2.0 / n as f64; + + for wi in w.iter_mut() { + wi.re = f64::NAN; + wi.im = f64::NAN; + } + + for p in 0..nr { + for k in 1..r { + let (s, c) = sincospi64(theta * (k * p) as f64); + let z = c64::new(c, s); + w[p + (k - 1) * nr] = z; + w[n + r * p + k] = z; + w_inv[p + (k - 1) * nr] = z.conj(); + w_inv[n + r * p + k] = z.conj(); + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100755 index 0000000..cf76154 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,121 @@ +//! Concrete-FFT is a pure Rust high performance fast Fourier transform library that processes +//! vectors of sizes that are powers of two. +//! +//! This library provides two FFT modules: +//! - The ordered module FFT applies a forward/inverse FFT that takes its input in standard +//! order, and outputs the result in standard order. For more detail on what the FFT +//! computes, check the ordered module-level documentation. +//! - The unordered module FFT applies a forward FFT that takes its input in standard order, +//! and outputs the result in a certain permuted order that may depend on the FFT plan. On the +//! other hand, the inverse FFT takes its input in that same permuted order and outputs its result +//! in standard order. This is useful for cases where the order of the coefficients in the +//! Fourier domain is not important. An example is using the Fourier transform for vector +//! convolution. The only operations that are performed in the Fourier domain are elementwise, and +//! so the order of the coefficients does not affect the results. +//! +//! # Features +//! +//! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions, and +//! an FFT plan that measures the various implementations to choose the fastest one at runtime. +//! - `nightly`: This enables unstable Rust features to further speed up the FFT, by enabling +//! AVX512F instructions on CPUs that support them. This feature requires a nightly Rust +//! toolchain. +//! - `serde`: This enables serialization and deserialization functions for the unordered plan. +//! These allow for data in the Fourier domain to be serialized from the permuted order to the +//! standard order, and deserialized from the standard order to the permuted order. +//! This is needed since the inverse transform must be used with the same plan that +//! computed/deserialized the forward transform (or more specifically, a plan with the same +//! internal base FFT size). +//! +//! # Example +//! +#![cfg_attr(feature = "std", doc = "```")] +#![cfg_attr(not(feature = "std"), doc = "```ignore")] +//! use concrete_fft::c64; +//! use concrete_fft::ordered::{Plan, Method}; +//! use dyn_stack::{DynStack, GlobalMemBuffer, ReborrowMut}; +//! use num_complex::ComplexFloat; +//! use std::time::Duration; +//! +//! const N: usize = 4; +//! let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); +//! let mut scratch_memory = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); +//! let mut stack = DynStack::new(&mut scratch_memory); +//! +//! let data = [ +//! c64::new(1.0, 0.0), +//! c64::new(2.0, 0.0), +//! c64::new(3.0, 0.0), +//! c64::new(4.0, 0.0), +//! ]; +//! +//! let mut transformed_fwd = data; +//! plan.fwd(&mut transformed_fwd, stack.rb_mut()); +//! +//! let mut transformed_inv = transformed_fwd; +//! plan.inv(&mut transformed_inv, stack.rb_mut()); +//! +//! for (actual, expected) in transformed_inv.iter().map(|z| z / N as f64).zip(data) { +//! assert!((expected - actual).abs() < 1e-9); +//! } +//! ``` + +#![cfg_attr(not(feature = "std"), no_std)] +#![allow( + clippy::erasing_op, + clippy::identity_op, + clippy::zero_prefixed_literal, + clippy::excessive_precision, + clippy::type_complexity, + clippy::too_many_arguments +)] +#![cfg_attr(feature = "nightly", feature(stdsimd, avx512_target_feature))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +use num_complex::Complex64; + +mod fft_simd; + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +mod x86; + +pub(crate) mod dif16; +pub(crate) mod dif2; +pub(crate) mod dif4; +pub(crate) mod dif8; + +pub(crate) mod dit16; +pub(crate) mod dit2; +pub(crate) mod dit4; +pub(crate) mod dit8; + +pub mod ordered; +pub mod unordered; + +/// 64-bit complex floating point type. +#[allow(non_camel_case_types)] +pub type c64 = Complex64; + +type FnArray = [unsafe fn(*mut c64, *mut c64, *const c64); 17]; + +#[derive(Copy, Clone)] +struct FftImpl { + fwd: FnArray, + inv: FnArray, +} + +#[cfg(feature = "std")] +macro_rules! x86_feature_detected { + ($tt: tt) => { + is_x86_feature_detected!($tt) + }; +} + +#[cfg(not(feature = "std"))] +macro_rules! x86_feature_detected { + ($tt: tt) => { + cfg!(target_arch = $tt) + }; +} + +pub(crate) use x86_feature_detected; diff --git a/src/ordered.rs b/src/ordered.rs new file mode 100644 index 0000000..4cd24da --- /dev/null +++ b/src/ordered.rs @@ -0,0 +1,502 @@ +//! Ordered FFT module. +//! +//! This FFT is currently based on hte Stockham algorithm, and was ported from the +//! [OTFFT](http://wwwa.pikara.ne.jp/okojisan/otfft-en/) C++ library by Takuya OKAHISA. +//! +//! This module computes the forward or inverse FFT in standard ordering. +//! This means that given a buffer of complex numbers $[x_0, \dots, x_{n-1}]$, +//! the forward FFT $[X_0, \dots, X_{n-1}]$ is given by +//! $$X_p = \sum_{q = 0}^{n-1} \exp\left(-\frac{i 2\pi pq}{n}\right),$$ +//! and the inverse FFT $[Y_0, \dots, Y_{n-1}]$ is given by +//! $$Y_p = \sum_{q = 0}^{n-1} \exp\left(\frac{i 2\pi pq}{n}\right).$$ + +use crate::*; +use aligned_vec::avec; +use aligned_vec::ABox; +use aligned_vec::CACHELINE_ALIGN; + +#[cfg(feature = "std")] +use core::time::Duration; +use dyn_stack::{DynStack, SizeOverflow, StackReq}; +#[cfg(feature = "std")] +use dyn_stack::{GlobalMemBuffer, ReborrowMut}; + +/// Internal FFT algorithm. +/// +/// The FFT can use a decimation-in-frequency (DIF) or decimation-in-time (DIT) approach. +/// And the FFT radix can be any of 2, 4, 8, 16. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum FftAlgo { + /// Decimation in frequency with radix 2 + Dif2, + /// Decimation in time with radix 2 + Dit2, + /// Decimation in frequency with radix 4 + Dif4, + /// Decimation in time with radix 4 + Dit4, + /// Decimation in frequency with radix 8 + Dif8, + /// Decimation in time with radix 8 + Dit8, + /// Decimation in frequency with radix 16 + Dif16, + /// Decimation in time with radix 16 + Dit16, +} + +/// Method for selecting the ordered FFT plan. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Method { + /// Select the FFT plan by manually providing the underlying algorithm. + UserProvided(FftAlgo), + /// Select the FFT plan by measuring the running time of all the possible plans and selecting + /// the fastest one. The provided duration specifies how long the benchmark of each plan should + /// last. + #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] + Measure(Duration), +} + +#[cfg(feature = "std")] +fn measure_n_runs( + n_runs: u128, + algo: FftAlgo, + buf: &mut [c64], + twiddles: &[c64], + stack: DynStack, +) -> Duration { + let n = buf.len(); + let (mut scratch, _) = stack.make_aligned_uninit::(n, CACHELINE_ALIGN); + let scratch = scratch.as_mut_ptr() as *mut c64; + let [fwd, _] = get_fn_ptr(algo, n); + + use std::time::Instant; + let now = Instant::now(); + + for _ in 0..n_runs { + unsafe { + fwd(buf.as_mut_ptr(), scratch, twiddles.as_ptr()); + } + } + + now.elapsed() +} + +#[cfg(feature = "std")] +fn duration_div_f64(duration: Duration, n: f64) -> Duration { + Duration::from_secs_f64(duration.as_secs_f64() / n as f64) +} + +#[cfg(feature = "std")] +pub(crate) fn measure_fastest_scratch(n: usize) -> StackReq { + let align = CACHELINE_ALIGN; + StackReq::new_aligned::(2 * n, align) // twiddles + .and(StackReq::new_aligned::(n, align)) // buffer + .and(StackReq::new_aligned::(n, align)) +} + +#[cfg(feature = "std")] +pub(crate) fn measure_fastest( + min_bench_duration_per_algo: Duration, + n: usize, + stack: DynStack, +) -> (FftAlgo, Duration) { + const N_ALGOS: usize = 8; + const MIN_DURATION: Duration = Duration::from_millis(1); + + assert!(n.is_power_of_two()); + + let align = CACHELINE_ALIGN; + + let f = |_| c64::default(); + + let (twiddles, stack) = stack.make_aligned_with::(2 * n, align, f); + let (mut buf, mut stack) = stack.make_aligned_with::(n, align, f); + + { + // initialize scratch to load it in the cpu cache + drop(stack.rb_mut().make_aligned_with::(n, align, f)); + } + + let mut avg_durations = [Duration::ZERO; N_ALGOS]; + + let discriminant_to_algo = |i: usize| -> FftAlgo { + match i { + 0 => FftAlgo::Dif2, + 1 => FftAlgo::Dit2, + 2 => FftAlgo::Dif4, + 3 => FftAlgo::Dit4, + 4 => FftAlgo::Dif8, + 5 => FftAlgo::Dit8, + 6 => FftAlgo::Dif16, + 7 => FftAlgo::Dit16, + _ => unreachable!(), + } + }; + + for (i, avg) in (0..N_ALGOS).zip(&mut avg_durations) { + let algo = discriminant_to_algo(i); + + let (init_n_runs, approx_duration) = { + let mut n_runs: u128 = 1; + + loop { + let duration = measure_n_runs(n_runs, algo, &mut buf, &twiddles, stack.rb_mut()); + + if duration < MIN_DURATION { + n_runs *= 2; + } else { + break (n_runs, duration_div_f64(duration, n_runs as f64)); + } + } + }; + + let n_runs = (min_bench_duration_per_algo.as_secs_f64() / approx_duration.as_secs_f64()) + .ceil() as u128; + *avg = if n_runs <= init_n_runs { + approx_duration + } else { + let duration = measure_n_runs(n_runs, algo, &mut buf, &twiddles, stack.rb_mut()); + duration_div_f64(duration, n_runs as f64) + }; + } + + let best_time = avg_durations.iter().min().unwrap(); + let best_index = avg_durations + .iter() + .position(|elem| elem == best_time) + .unwrap(); + (discriminant_to_algo(best_index), *best_time) +} + +/// Ordered FFT plan. +/// +/// This type holds a forward and inverse FFT plan and twiddling factors for a specific size. +/// The size must be a power of two, and can be as large as `2^16` (inclusive). +#[derive(Clone)] +pub struct Plan { + fwd: unsafe fn(*mut c64, *mut c64, *const c64), + inv: unsafe fn(*mut c64, *mut c64, *const c64), + twiddles: ABox<[c64]>, + twiddles_inv: ABox<[c64]>, + algo: FftAlgo, +} + +impl core::fmt::Debug for Plan { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Plan") + .field("algo", &self.algo) + .field("fft_size", &self.fft_size()) + .finish() + } +} + +pub(crate) fn get_fn_ptr( + algo: FftAlgo, + n: usize, +) -> [unsafe fn(*mut c64, *mut c64, *const c64); 2] { + use FftAlgo::*; + + let fft = match algo { + Dif2 => dif2::runtime_fft(), + Dit2 => dit2::runtime_fft(), + Dif4 => dif4::runtime_fft(), + Dit4 => dit4::runtime_fft(), + Dif8 => dif8::runtime_fft(), + Dit8 => dit8::runtime_fft(), + Dif16 => dif16::runtime_fft(), + Dit16 => dit16::runtime_fft(), + }; + + let idx = n.trailing_zeros() as usize; + + [fft.fwd[idx], fft.inv[idx]] +} + +impl Plan { + /// Returns a new FFT plan for the given vector size, selected by the provided method. + /// + /// # Panics + /// + /// - Panics if `n` is not a power of two. + /// - Panics if `n` is greater than `2^16`. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::ordered::{Method, Plan}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// ``` + pub fn new(n: usize, method: Method) -> Self { + assert!(n.is_power_of_two()); + assert!(n.trailing_zeros() < 17); + + let algo = match method { + Method::UserProvided(algo) => algo, + #[cfg(feature = "std")] + Method::Measure(duration) => { + measure_fastest( + duration, + n, + DynStack::new(&mut GlobalMemBuffer::new(measure_fastest_scratch(n))), + ) + .0 + } + }; + + let [fwd, inv] = get_fn_ptr(algo, n); + + let mut twiddles = avec![c64::default(); 2 * n].into_boxed_slice(); + let mut twiddles_inv = avec![c64::default(); 2 * n].into_boxed_slice(); + use FftAlgo::*; + let r = match algo { + Dif2 | Dit2 => 2, + Dif4 | Dit4 => 4, + Dif8 | Dit8 => 8, + Dif16 | Dit16 => 16, + }; + fft_simd::init_wt(r, n, &mut twiddles, &mut twiddles_inv); + Self { + fwd, + inv, + twiddles, + algo, + twiddles_inv, + } + } + + /// Returns the vector size of the FFT. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::ordered::{Method, Plan}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// assert_eq!(plan.fft_size(), 4); + /// ``` + pub fn fft_size(&self) -> usize { + self.twiddles.len() / 2 + } + + /// Returns the algorithm that's internally used by the FFT. + /// + /// # Example + /// + /// ``` + /// use concrete_fft::ordered::{FftAlgo, Method, Plan}; + /// + /// let plan = Plan::new(4, Method::UserProvided(FftAlgo::Dif2)); + /// assert_eq!(plan.algo(), FftAlgo::Dif2); + /// ``` + pub fn algo(&self) -> FftAlgo { + self.algo + } + + /// Returns the size and alignment of the scratch memory needed to perform an FFT. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::ordered::{Method, Plan}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// let scratch = plan.fft_scratch().unwrap(); + /// ``` + pub fn fft_scratch(&self) -> Result { + StackReq::try_new_aligned::(self.fft_size(), CACHELINE_ALIGN) + } + + /// Performs a forward FFT in place, using the provided stack as scratch space. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::c64; + /// use concrete_fft::ordered::{Method, Plan}; + /// use dyn_stack::{DynStack, GlobalMemBuffer}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// + /// let mut memory = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); + /// let stack = DynStack::new(&mut memory); + /// + /// let mut buf = [c64::default(); 4]; + /// plan.fwd(&mut buf, stack); + /// ``` + pub fn fwd(&self, buf: &mut [c64], stack: DynStack) { + let n = self.fft_size(); + assert_eq!(n, buf.len()); + let (mut scratch, _) = stack.make_aligned_uninit::(n, CACHELINE_ALIGN); + let buf = buf.as_mut_ptr(); + let scratch = scratch.as_mut_ptr(); + unsafe { (self.fwd)(buf, scratch as *mut c64, self.twiddles.as_ptr()) } + } + + /// Performs an inverse FFT in place, using the provided stack as scratch space. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::c64; + /// use concrete_fft::ordered::{Method, Plan}; + /// use dyn_stack::{DynStack, GlobalMemBuffer, ReborrowMut}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// + /// let mut memory = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); + /// let mut stack = DynStack::new(&mut memory); + /// + /// let mut buf = [c64::default(); 4]; + /// plan.fwd(&mut buf, stack.rb_mut()); + /// plan.inv(&mut buf, stack); + /// ``` + pub fn inv(&self, buf: &mut [c64], stack: DynStack) { + let n = self.fft_size(); + assert_eq!(n, buf.len()); + let (mut scratch, _) = stack.make_aligned_uninit::(n, CACHELINE_ALIGN); + let buf = buf.as_mut_ptr(); + let scratch = scratch.as_mut_ptr(); + unsafe { (self.inv)(buf, scratch as *mut c64, self.twiddles_inv.as_ptr()) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dif16::*; + use crate::dif2::*; + use crate::dif4::*; + use crate::dif8::*; + use crate::dit16::*; + use crate::dit2::*; + use crate::dit4::*; + use crate::dit8::*; + use crate::fft_simd::init_wt; + use crate::x86_feature_detected; + use num_complex::ComplexFloat; + use rand::random; + use rustfft::FftPlanner; + + extern crate alloc; + use alloc::vec; + + #[test] + fn test_fft() { + unsafe { + for (can_run, r, arr) in [ + (true, 2, &DIT2_SCALAR), + (true, 4, &DIT4_SCALAR), + (true, 8, &DIT8_SCALAR), + (true, 16, &DIT16_SCALAR), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 2, &DIT2_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 4, &DIT4_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 8, &DIT8_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 16, &DIT16_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 2, &DIT2_FMA), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 4, &DIT4_FMA), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 8, &DIT8_FMA), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 16, &DIT16_FMA), + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + (x86_feature_detected!("avx512f"), 4, &DIT4_AVX512), + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + (x86_feature_detected!("avx512f"), 8, &DIT8_AVX512), + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + (x86_feature_detected!("avx512f"), 16, &DIT16_AVX512), + (true, 2, &DIF2_SCALAR), + (true, 4, &DIF4_SCALAR), + (true, 8, &DIF8_SCALAR), + (true, 16, &DIF16_SCALAR), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 2, &DIF2_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 4, &DIF4_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 8, &DIF8_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("avx"), 16, &DIF16_AVX), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 2, &DIF2_FMA), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 4, &DIF4_FMA), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 8, &DIF8_FMA), + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + (x86_feature_detected!("fma"), 16, &DIF16_FMA), + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + (x86_feature_detected!("avx512f"), 4, &DIF4_AVX512), + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + (x86_feature_detected!("avx512f"), 8, &DIF8_AVX512), + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + (x86_feature_detected!("avx512f"), 16, &DIF16_AVX512), + ] { + if can_run { + for exp in 0..17 { + let n: usize = 1 << exp; + let fwd = arr.fwd[n.trailing_zeros() as usize]; + let inv = arr.inv[n.trailing_zeros() as usize]; + + let mut scratch = vec![c64::default(); n]; + let mut twiddles = vec![c64::default(); 2 * n]; + let mut twiddles_inv = vec![c64::default(); 2 * n]; + + init_wt(r, n, &mut twiddles, &mut twiddles_inv); + + let mut x = vec![c64::default(); n]; + for z in &mut x { + *z = c64::new(random(), random()); + } + + let orig = x.clone(); + + fwd(x.as_mut_ptr(), scratch.as_mut_ptr(), twiddles.as_ptr()); + + // compare with rustfft + { + let mut planner = FftPlanner::new(); + let plan = planner.plan_fft_forward(n); + let mut y = orig.clone(); + plan.process(&mut y); + + for (z_expected, z_actual) in y.iter().zip(&x) { + assert!((*z_expected - *z_actual).abs() < 1e-12); + } + } + + inv(x.as_mut_ptr(), scratch.as_mut_ptr(), twiddles_inv.as_ptr()); + + for z in &mut x { + *z /= n as f64; + } + + for (z_expected, z_actual) in orig.iter().zip(&x) { + assert!((*z_expected - *z_actual).abs() < 1e-14); + } + } + } + } + } + } +} diff --git a/src/unordered.rs b/src/unordered.rs new file mode 100644 index 0000000..962a76b --- /dev/null +++ b/src/unordered.rs @@ -0,0 +1,1259 @@ +//! Unordered FFT module. +//! +//! This module computes the forward or inverse FFT in a similar fashion to the ordered module, +//! with two crucial differences. +//! Given an FFT plan, the forward transform takes its inputs in standard order, and outputs the +//! forward FFT terms in an unspecified order. And the backward transform takes its inputs in the +//! aforementioned order, and outputs the inverse FFT in the standard order. + +use crate::fft_simd::{init_wt, sincospi64, FftSimd64, FftSimd64Ext}; +use crate::x86_feature_detected; +use crate::{c64, ordered::FftAlgo}; +use aligned_vec::{avec, ABox, CACHELINE_ALIGN}; +#[cfg(feature = "std")] +use core::time::Duration; +use dyn_stack::{DynStack, SizeOverflow, StackReq}; +#[cfg(feature = "std")] +use dyn_stack::{GlobalMemBuffer, ReborrowMut}; + +#[inline(always)] +unsafe fn fwd_butterfly_x2(z0: I::Reg, z1: I::Reg, w1: I::Reg) -> (I::Reg, I::Reg) { + (I::add(z0, z1), I::mul(w1, I::sub(z0, z1))) +} + +#[inline(always)] +unsafe fn inv_butterfly_x2(z0: I::Reg, z1: I::Reg, w1: I::Reg) -> (I::Reg, I::Reg) { + let z1 = I::mul(w1, z1); + (I::add(z0, z1), I::sub(z0, z1)) +} + +#[inline(always)] +unsafe fn fwd_butterfly_x4( + z0: I::Reg, + z1: I::Reg, + z2: I::Reg, + z3: I::Reg, + w1: I::Reg, + w2: I::Reg, + w3: I::Reg, +) -> (I::Reg, I::Reg, I::Reg, I::Reg) { + let z0p2 = I::add(z0, z2); + let z0m2 = I::sub(z0, z2); + let z1p3 = I::add(z1, z3); + let jz1m3 = I::xpj(true, I::sub(z1, z3)); + + ( + I::add(z0p2, z1p3), + I::mul(w1, I::sub(z0m2, jz1m3)), + I::mul(w2, I::sub(z0p2, z1p3)), + I::mul(w3, I::add(z0m2, jz1m3)), + ) +} + +#[inline(always)] +unsafe fn inv_butterfly_x4( + z0: I::Reg, + z1: I::Reg, + z2: I::Reg, + z3: I::Reg, + w1: I::Reg, + w2: I::Reg, + w3: I::Reg, +) -> (I::Reg, I::Reg, I::Reg, I::Reg) { + let z0 = z0; + let z1 = I::mul(w1, z1); + let z2 = I::mul(w2, z2); + let z3 = I::mul(w3, z3); + + let z0p2 = I::add(z0, z2); + let z0m2 = I::sub(z0, z2); + let z1p3 = I::add(z1, z3); + let jz1m3 = I::xpj(false, I::sub(z1, z3)); + + ( + I::add(z0p2, z1p3), + I::sub(z0m2, jz1m3), + I::sub(z0p2, z1p3), + I::add(z0m2, jz1m3), + ) +} + +#[inline(always)] +unsafe fn fwd_butterfly_x8( + z0: I::Reg, + z1: I::Reg, + z2: I::Reg, + z3: I::Reg, + z4: I::Reg, + z5: I::Reg, + z6: I::Reg, + z7: I::Reg, + w1: I::Reg, + w2: I::Reg, + w3: I::Reg, + w4: I::Reg, + w5: I::Reg, + w6: I::Reg, + w7: I::Reg, +) -> ( + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, +) { + let z0p4 = I::add(z0, z4); + let z0m4 = I::sub(z0, z4); + let z2p6 = I::add(z2, z6); + let jz2m6 = I::xpj(true, I::sub(z2, z6)); + + let z1p5 = I::add(z1, z5); + let z1m5 = I::sub(z1, z5); + let z3p7 = I::add(z3, z7); + let jz3m7 = I::xpj(true, I::sub(z3, z7)); + + // z0 + z2 + z4 + z6 + let t0 = I::add(z0p4, z2p6); + // z1 + z3 + z5 + z7 + let t1 = I::add(z1p5, z3p7); + // z0 + w4z2 + z4 + w4z6 + let t2 = I::sub(z0p4, z2p6); + // w2z1 + w6z3 + w2z5 + w6z7 + let t3 = I::xpj(true, I::sub(z1p5, z3p7)); + // z0 + w2z2 + z4 + w6z6 + let t4 = I::sub(z0m4, jz2m6); + // w1z1 + w3z3 + w5z5 + w7z7 + let t5 = I::xw8(true, I::sub(z1m5, jz3m7)); + // z0 + w2z2 + w4z4 + w6z6 + let t6 = I::add(z0m4, jz2m6); + // w7z1 + w1z3 + w3z5 + w5z7 + let t7 = I::xv8(true, I::add(z1m5, jz3m7)); + + ( + I::add(t0, t1), + I::mul(w1, I::add(t4, t5)), + I::mul(w2, I::sub(t2, t3)), + I::mul(w3, I::sub(t6, t7)), + I::mul(w4, I::sub(t0, t1)), + I::mul(w5, I::sub(t4, t5)), + I::mul(w6, I::add(t2, t3)), + I::mul(w7, I::add(t6, t7)), + ) +} + +#[inline(always)] +unsafe fn inv_butterfly_x8( + z0: I::Reg, + z1: I::Reg, + z2: I::Reg, + z3: I::Reg, + z4: I::Reg, + z5: I::Reg, + z6: I::Reg, + z7: I::Reg, + w1: I::Reg, + w2: I::Reg, + w3: I::Reg, + w4: I::Reg, + w5: I::Reg, + w6: I::Reg, + w7: I::Reg, +) -> ( + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, + I::Reg, +) { + let z0 = z0; + let z1 = I::mul(w1, z1); + let z2 = I::mul(w2, z2); + let z3 = I::mul(w3, z3); + let z4 = I::mul(w4, z4); + let z5 = I::mul(w5, z5); + let z6 = I::mul(w6, z6); + let z7 = I::mul(w7, z7); + + let z0p4 = I::add(z0, z4); + let z0m4 = I::sub(z0, z4); + let z2p6 = I::add(z2, z6); + let jz2m6 = I::xpj(false, I::sub(z2, z6)); + + let z1p5 = I::add(z1, z5); + let z1m5 = I::sub(z1, z5); + let z3p7 = I::add(z3, z7); + let jz3m7 = I::xpj(false, I::sub(z3, z7)); + + // z0 + z2 + z4 + z6 + let t0 = I::add(z0p4, z2p6); + // z1 + z3 + z5 + z7 + let t1 = I::add(z1p5, z3p7); + // z0 + w4z2 + z4 + w4z6 + let t2 = I::sub(z0p4, z2p6); + // w2z1 + w6z3 + w2z5 + w6z7 + let t3 = I::xpj(false, I::sub(z1p5, z3p7)); + // z0 + w2z2 + z4 + w6z6 + let t4 = I::sub(z0m4, jz2m6); + // w1z1 + w3z3 + w5z5 + w7z7 + let t5 = I::xw8(false, I::sub(z1m5, jz3m7)); + // z0 + w2z2 + w4z4 + w6z6 + let t6 = I::add(z0m4, jz2m6); + // w7z1 + w1z3 + w3z5 + w5z7 + let t7 = I::xv8(false, I::add(z1m5, jz3m7)); + + ( + I::add(t0, t1), + I::add(t4, t5), + I::sub(t2, t3), + I::sub(t6, t7), + I::sub(t0, t1), + I::sub(t4, t5), + I::add(t2, t3), + I::add(t6, t7), + ) +} + +#[inline(always)] +unsafe fn fwd_process_x2(n: usize, z: *mut c64, w: *const c64) { + let m = n / 2; + let z0 = z.add(m * 0); + let z1 = z.add(m * 1); + debug_assert_eq!(m % I::COMPLEX_PER_REG, 0); + let mut p = 0; + while p < m { + let w1 = I::load(w.add(p + I::COMPLEX_PER_REG * 0)); + + let z00 = I::load(z0.add(p)); + let z01 = I::load(z1.add(p)); + + let (z00, z01) = fwd_butterfly_x2::(z00, z01, w1); + + I::store(z0.add(p), z00); + I::store(z1.add(p), z01); + + p += I::COMPLEX_PER_REG; + } +} + +#[inline(always)] +unsafe fn inv_process_x2(n: usize, z: *mut c64, w: *const c64) { + let m = n / 2; + let z0 = z.add(m * 0); + let z1 = z.add(m * 1); + debug_assert_eq!(m % I::COMPLEX_PER_REG, 0); + let mut p = 0; + while p < m { + let w1 = I::load(w.add(p + I::COMPLEX_PER_REG * 0)); + + let z00 = I::load(z0.add(p)); + let z01 = I::load(z1.add(p)); + + let (z00, z01) = inv_butterfly_x2::(z00, z01, w1); + + I::store(z0.add(p), z00); + I::store(z1.add(p), z01); + + p += I::COMPLEX_PER_REG; + } +} + +#[inline(always)] +unsafe fn fwd_process_x4(n: usize, z: *mut c64, w: *const c64) { + let m = n / 4; + let z0 = z.add(m * 0); + let z1 = z.add(m * 1); + let z2 = z.add(m * 2); + let z3 = z.add(m * 3); + debug_assert_eq!(m % I::COMPLEX_PER_REG, 0); + let mut p = 0; + while p < m { + let w1 = I::load(w.add(3 * p + I::COMPLEX_PER_REG * 0)); + let w2 = I::load(w.add(3 * p + I::COMPLEX_PER_REG * 1)); + let w3 = I::load(w.add(3 * p + I::COMPLEX_PER_REG * 2)); + + let z00 = I::load(z0.add(p)); + let z01 = I::load(z1.add(p)); + let z02 = I::load(z2.add(p)); + let z03 = I::load(z3.add(p)); + + let (z00, z01, z02, z03) = fwd_butterfly_x4::(z00, z01, z02, z03, w1, w2, w3); + + I::store(z0.add(p), z00); + I::store(z1.add(p), z02); + I::store(z2.add(p), z01); + I::store(z3.add(p), z03); + + p += I::COMPLEX_PER_REG; + } +} + +#[inline(always)] +unsafe fn inv_process_x4(n: usize, z: *mut c64, w: *const c64) { + let m = n / 4; + let z0 = z.add(m * 0); + let z1 = z.add(m * 1); + let z2 = z.add(m * 2); + let z3 = z.add(m * 3); + debug_assert_eq!(m % I::COMPLEX_PER_REG, 0); + let mut p = 0; + while p < m { + let w1 = I::load(w.add(3 * p + I::COMPLEX_PER_REG * 0)); + let w2 = I::load(w.add(3 * p + I::COMPLEX_PER_REG * 1)); + let w3 = I::load(w.add(3 * p + I::COMPLEX_PER_REG * 2)); + + let z00 = I::load(z0.add(p)); + let z01 = I::load(z2.add(p)); + let z02 = I::load(z1.add(p)); + let z03 = I::load(z3.add(p)); + + let (z00, z01, z02, z03) = inv_butterfly_x4::(z00, z01, z02, z03, w1, w2, w3); + + I::store(z0.add(p), z00); + I::store(z1.add(p), z01); + I::store(z2.add(p), z02); + I::store(z3.add(p), z03); + + p += I::COMPLEX_PER_REG; + } +} + +#[inline(always)] +unsafe fn fwd_process_x8(n: usize, z: *mut c64, w: *const c64) { + let m = n / 8; + let z0 = z.add(m * 0); + let z1 = z.add(m * 1); + let z2 = z.add(m * 2); + let z3 = z.add(m * 3); + let z4 = z.add(m * 4); + let z5 = z.add(m * 5); + let z6 = z.add(m * 6); + let z7 = z.add(m * 7); + + debug_assert_eq!(m % I::COMPLEX_PER_REG, 0); + let mut p = 0; + while p < m { + let w1 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 0)); + let w2 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 1)); + let w3 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 2)); + let w4 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 3)); + let w5 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 4)); + let w6 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 5)); + let w7 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 6)); + + let z00 = I::load(z0.add(p)); + let z01 = I::load(z1.add(p)); + let z02 = I::load(z2.add(p)); + let z03 = I::load(z3.add(p)); + let z04 = I::load(z4.add(p)); + let z05 = I::load(z5.add(p)); + let z06 = I::load(z6.add(p)); + let z07 = I::load(z7.add(p)); + + let (z00, z01, z02, z03, z04, z05, z06, z07) = fwd_butterfly_x8::( + z00, z01, z02, z03, z04, z05, z06, z07, w1, w2, w3, w4, w5, w6, w7, + ); + + I::store(z0.add(p), z00); + I::store(z1.add(p), z04); + I::store(z2.add(p), z02); + I::store(z3.add(p), z06); + I::store(z4.add(p), z01); + I::store(z5.add(p), z05); + I::store(z6.add(p), z03); + I::store(z7.add(p), z07); + + p += I::COMPLEX_PER_REG; + } +} + +#[inline(always)] +unsafe fn inv_process_x8(n: usize, z: *mut c64, w: *const c64) { + let m = n / 8; + let z0 = z.add(m * 0); + let z1 = z.add(m * 1); + let z2 = z.add(m * 2); + let z3 = z.add(m * 3); + let z4 = z.add(m * 4); + let z5 = z.add(m * 5); + let z6 = z.add(m * 6); + let z7 = z.add(m * 7); + + debug_assert_eq!(m % I::COMPLEX_PER_REG, 0); + let mut p = 0; + while p < m { + let w1 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 0)); + let w2 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 1)); + let w3 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 2)); + let w4 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 3)); + let w5 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 4)); + let w6 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 5)); + let w7 = I::load(w.add(7 * p + I::COMPLEX_PER_REG * 6)); + + let z00 = I::load(z0.add(p)); + let z01 = I::load(z4.add(p)); + let z02 = I::load(z2.add(p)); + let z03 = I::load(z6.add(p)); + let z04 = I::load(z1.add(p)); + let z05 = I::load(z5.add(p)); + let z06 = I::load(z3.add(p)); + let z07 = I::load(z7.add(p)); + + let (z00, z01, z02, z03, z04, z05, z06, z07) = inv_butterfly_x8::( + z00, z01, z02, z03, z04, z05, z06, z07, w1, w2, w3, w4, w5, w6, w7, + ); + + I::store(z0.add(p), z00); + I::store(z1.add(p), z01); + I::store(z2.add(p), z02); + I::store(z3.add(p), z03); + I::store(z4.add(p), z04); + I::store(z5.add(p), z05); + I::store(z6.add(p), z06); + I::store(z7.add(p), z07); + + p += I::COMPLEX_PER_REG; + } +} + +macro_rules! dispatcher { + ($name: ident, $impl: ident) => { + #[allow(non_camel_case_types)] + struct $name { + __private: (), + } + impl $name { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + #[target_feature(enable = "avx512f")] + unsafe fn avx512f(n: usize, z: *mut c64, w: *const c64) { + $impl::(n, z, w); + } + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + #[target_feature(enable = "fma")] + unsafe fn fma(n: usize, z: *mut c64, w: *const c64) { + $impl::(n, z, w); + } + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + #[target_feature(enable = "avx")] + unsafe fn avx(n: usize, z: *mut c64, w: *const c64) { + $impl::(n, z, w); + } + } + fn $name() -> unsafe fn(usize, *mut c64, *const c64) { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return $name::avx512f; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return $name::fma; + } else if x86_feature_detected!("avx") { + return $name::avx; + } + + $impl:: + } + }; +} + +dispatcher!(get_fwd_process_x2, fwd_process_x2); +dispatcher!(get_fwd_process_x4, fwd_process_x4); +dispatcher!(get_fwd_process_x8, fwd_process_x8); + +dispatcher!(get_inv_process_x2, inv_process_x2); +dispatcher!(get_inv_process_x4, inv_process_x4); +dispatcher!(get_inv_process_x8, inv_process_x8); + +fn get_complex_per_reg() -> usize { + #[cfg(all(feature = "nightly", any(target_arch = "x86_64", target_arch = "x86")))] + if x86_feature_detected!("avx512f") { + return crate::x86::Avx512X4::COMPLEX_PER_REG; + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if x86_feature_detected!("fma") { + return ::COMPLEX_PER_REG; + } else if x86_feature_detected!("avx") { + return ::COMPLEX_PER_REG; + } + + ::COMPLEX_PER_REG +} + +fn init_twiddles( + n: usize, + complex_per_reg: usize, + base_n: usize, + base_r: usize, + w: &mut [c64], + w_inv: &mut [c64], +) { + let theta = 2.0 / n as f64; + if n <= base_n { + init_wt(base_r, n, w, w_inv); + } else { + // FIXME + let r = if n == 2 * base_n { + 2 + } else if n == 4 * base_n { + 4 + } else { + 8 + }; + + let m = n / r; + let (w, w_next) = w.split_at_mut((r - 1) * m); + let (w_inv_next, w_inv) = w_inv.split_at_mut(w_inv.len() - (r - 1) * m); + + let mut p = 0; + while p < m { + for i in 0..complex_per_reg { + for k in 1..r { + let (sk, ck) = sincospi64(theta * (k * (p + i)) as f64); + let idx = (r - 1) * p + (k - 1) * complex_per_reg + i; + w[idx] = c64 { re: ck, im: -sk }; + w_inv[idx] = c64 { re: ck, im: sk }; + } + } + + p += complex_per_reg; + } + + init_twiddles(n / r, complex_per_reg, base_n, base_r, w_next, w_inv_next); + } +} + +#[inline(never)] +unsafe fn fwd_depth( + n: usize, + z: *mut c64, + w: *const c64, + base_fn: unsafe fn(*mut c64, *mut c64, *const c64), + base_n: usize, + base_scratch: *mut c64, + fwd_process_x2: unsafe fn(usize, *mut c64, *const c64), + fwd_process_x4: unsafe fn(usize, *mut c64, *const c64), + fwd_process_x8: unsafe fn(usize, *mut c64, *const c64), +) { + if n == base_n { + base_fn(z, base_scratch, w) + } else { + let r = if n == 2 * base_n { + fwd_process_x2(n, z, w); + 2 + } else if n == 4 * base_n { + fwd_process_x4(n, z, w); + 4 + } else { + fwd_process_x8(n, z, w); + 8 + }; + + let m = n / r; + for i in 0..r { + fwd_depth( + m, + z.add(m * i), + w.add((r - 1) * m), + base_fn, + base_n, + base_scratch, + fwd_process_x2, + fwd_process_x4, + fwd_process_x8, + ); + } + } +} + +#[inline(never)] +unsafe fn inv_depth( + n: usize, + z: *mut c64, + w: *const c64, + base_fn: unsafe fn(*mut c64, *mut c64, *const c64), + base_n: usize, + base_scratch: *mut c64, + inv_process_x2: unsafe fn(usize, *mut c64, *const c64), + inv_process_x4: unsafe fn(usize, *mut c64, *const c64), + inv_process_x8: unsafe fn(usize, *mut c64, *const c64), +) { + if n == base_n { + base_fn(z, base_scratch, w.sub(2 * n)) + } else { + let r = if n == 2 * base_n { + 2 + } else if n == 4 * base_n { + 4 + } else { + 8 + }; + + let m = n / r; + let w = w.sub((r - 1) * m); + for i in 0..r { + inv_depth( + m, + z.add(m * i), + w, + base_fn, + base_n, + base_scratch, + inv_process_x2, + inv_process_x4, + inv_process_x8, + ); + } + + if r == 2 { + inv_process_x2(n, z, w); + } else if r == 4 { + inv_process_x4(n, z, w); + } else { + inv_process_x8(n, z, w); + } + } +} + +/// Unordered FFT plan. +/// +/// This type holds a forward and inverse FFT plan and twiddling factors for a specific size. +/// The size must be a power of two. +#[derive(Clone)] +pub struct Plan { + twiddles: ABox<[c64]>, + twiddles_inv: ABox<[c64]>, + fwd_process_x2: unsafe fn(usize, *mut c64, *const c64), + fwd_process_x4: unsafe fn(usize, *mut c64, *const c64), + fwd_process_x8: unsafe fn(usize, *mut c64, *const c64), + inv_process_x2: unsafe fn(usize, *mut c64, *const c64), + inv_process_x4: unsafe fn(usize, *mut c64, *const c64), + inv_process_x8: unsafe fn(usize, *mut c64, *const c64), + base_n: usize, + base_fn_fwd: unsafe fn(*mut c64, *mut c64, *const c64), + base_fn_inv: unsafe fn(*mut c64, *mut c64, *const c64), + base_algo: FftAlgo, + n: usize, +} + +impl core::fmt::Debug for Plan { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Plan") + .field("base_algo", &self.base_algo) + .field("base_size", &self.base_n) + .field("fft_size", &self.fft_size()) + .finish() + } +} + +/// Method for selecting the unordered FFT plan. +#[derive(Clone, Copy, Debug)] +pub enum Method { + /// Select the FFT plan by manually providing the underlying algorithm. + /// The unordered FFT works by using an internal ordered FFT plan, whose size and algorithm can + /// be specified by the user. + UserProvided { base_algo: FftAlgo, base_n: usize }, + /// Select the FFT plan by measuring the running time of all the possible plans and selecting + /// the fastest one. The provided duration specifies how long the benchmark of each plan should + /// last. + #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] + Measure(Duration), +} + +#[cfg(feature = "std")] +fn measure_fastest_scratch(n: usize) -> StackReq { + if n <= 512 { + crate::ordered::measure_fastest_scratch(n) + } else { + let base_n = 4096; + crate::ordered::measure_fastest_scratch(base_n) + .and(StackReq::new_aligned::(n + base_n, CACHELINE_ALIGN)) // twiddles + .and(StackReq::new_aligned::(n, CACHELINE_ALIGN)) // buf + .and(StackReq::new_aligned::(base_n, CACHELINE_ALIGN)) // scratch + } +} + +#[cfg(feature = "std")] +fn measure_fastest( + mut min_bench_duration_per_algo: Duration, + n: usize, + mut stack: DynStack, +) -> (FftAlgo, usize, Duration) { + const MIN_DURATION: Duration = Duration::from_millis(1); + min_bench_duration_per_algo = min_bench_duration_per_algo.max(MIN_DURATION); + + if n <= 512 { + let (algo, duration) = + crate::ordered::measure_fastest(min_bench_duration_per_algo, n, stack); + (algo, n, duration) + } else { + // bench + + let bases = [512, 1024, 2048, 4096]; + let mut algos: [Option; 4] = [None; 4]; + let mut avg_durations: [Option; 4] = [None; 4]; + let fwd_process_x2 = get_fwd_process_x2(); + let fwd_process_x4 = get_fwd_process_x4(); + let fwd_process_x8 = get_fwd_process_x8(); + + let mut n_algos = 0; + for (i, base_n) in bases.into_iter().enumerate() { + if n < base_n { + break; + } + + n_algos += 1; + + // we'll measure the corresponding plan + let (base_algo, duration) = crate::ordered::measure_fastest( + min_bench_duration_per_algo, + base_n, + stack.rb_mut(), + ); + + algos[i] = Some(base_algo); + + if n == base_n { + avg_durations[i] = Some(duration); + continue; + } + + // get the forward base algo + let base_fn = crate::ordered::get_fn_ptr(base_algo, base_n)[0]; + + let (w, stack) = + stack + .rb_mut() + .make_aligned_with::(n + base_n, CACHELINE_ALIGN, |_| { + Default::default() + }); + let (mut scratch, stack) = + stack.make_aligned_with::(base_n, CACHELINE_ALIGN, |_| Default::default()); + let (mut z, _) = + stack.make_aligned_with::(n, CACHELINE_ALIGN, |_| Default::default()); + + let n_runs = min_bench_duration_per_algo.as_secs_f64() + / (duration.as_secs_f64() * (n / base_n) as f64); + + let n_runs = n_runs.ceil() as u32; + + use std::time::Instant; + let now = Instant::now(); + for _ in 0..n_runs { + unsafe { + fwd_depth( + n, + z.as_mut_ptr(), + w.as_ptr(), + base_fn, + base_n, + scratch.as_mut_ptr(), + fwd_process_x2, + fwd_process_x4, + fwd_process_x8, + ); + } + } + let duration = now.elapsed(); + avg_durations[i] = Some(duration / n_runs); + } + + let best_time = avg_durations[..n_algos].iter().min().unwrap().unwrap(); + let best_index = avg_durations[..n_algos] + .iter() + .position(|elem| elem.unwrap() == best_time) + .unwrap(); + + (algos[best_index].unwrap(), bases[best_index], best_time) + } +} + +impl Plan { + /// Returns a new FFT plan for the given vector size, selected by the provided method. + /// + /// # Panics + /// + /// - Panics if `n` is not a power of two. + /// - If the method is user-provided, panics if `n` is not equal to the base ordered FFT size, + /// and the base FFT size is less than `32`. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::unordered::{Method, Plan}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// ``` + pub fn new(n: usize, method: Method) -> Self { + assert!(n.is_power_of_two()); + + let (base_algo, base_n) = match method { + Method::UserProvided { base_algo, base_n } => { + assert!(base_n.is_power_of_two()); + assert!(base_n <= n); + if base_n != n { + assert!(base_n >= 32); + } + assert!(base_n.trailing_zeros() < 17); + (base_algo, base_n) + } + + #[cfg(feature = "std")] + Method::Measure(duration) => { + let (algo, base_n, _) = measure_fastest( + duration, + n, + DynStack::new(&mut GlobalMemBuffer::new(measure_fastest_scratch(n))), + ); + (algo, base_n) + } + }; + + let [base_fn_fwd, base_fn_inv] = crate::ordered::get_fn_ptr(base_algo, base_n); + + let mut twiddles = avec![c64::default(); n + base_n].into_boxed_slice(); + let mut twiddles_inv = avec![c64::default(); n + base_n].into_boxed_slice(); + + use crate::ordered::FftAlgo::*; + let base_r = match base_algo { + Dif2 | Dit2 => 2, + Dif4 | Dit4 => 4, + Dif8 | Dit8 => 8, + Dif16 | Dit16 => 16, + }; + + init_twiddles( + n, + get_complex_per_reg(), + base_n, + base_r, + &mut twiddles, + &mut twiddles_inv, + ); + + Self { + twiddles, + twiddles_inv, + fwd_process_x2: get_fwd_process_x2(), + fwd_process_x4: get_fwd_process_x4(), + fwd_process_x8: get_fwd_process_x8(), + inv_process_x2: get_inv_process_x2(), + inv_process_x4: get_inv_process_x4(), + inv_process_x8: get_inv_process_x8(), + base_n, + base_fn_fwd, + base_fn_inv, + n, + base_algo, + } + } + + /// Returns the vector size of the FFT. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::unordered::{Method, Plan}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// assert_eq!(plan.fft_size(), 4); + /// ``` + pub fn fft_size(&self) -> usize { + self.n + } + + /// Returns the algorithm and size of the internal ordered FFT plan. + /// + /// # Example + /// + /// ``` + /// use concrete_fft::ordered::FftAlgo; + /// use concrete_fft::unordered::{Method, Plan}; + /// + /// let plan = Plan::new( + /// 4, + /// Method::UserProvided{ + /// base_algo: FftAlgo::Dif2, + /// base_n: 4, + /// }, + /// ); + /// assert_eq!(plan.algo(), (FftAlgo::Dif2, 4)); + /// ``` + pub fn algo(&self) -> (FftAlgo, usize) { + (self.base_algo, self.base_n) + } + + /// Returns the size and alignment of the scratch memory needed to perform an FFT. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::unordered::{Method, Plan}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// let scratch = plan.fft_scratch().unwrap(); + /// ``` + pub fn fft_scratch(&self) -> Result { + StackReq::try_new_aligned::(self.algo().1, CACHELINE_ALIGN) + } + + /// Performs a forward FFT in place, using the provided stack as scratch space. + /// + /// # Note + /// + /// The values in `buf` must be in standard order prior to calling this function. + /// When this function returns, the values in `buf` will contain the terms of the forward + /// transform in permuted order. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::c64; + /// use concrete_fft::unordered::{Method, Plan}; + /// use dyn_stack::{DynStack, GlobalMemBuffer}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// + /// let mut memory = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); + /// let stack = DynStack::new(&mut memory); + /// + /// let mut buf = [c64::default(); 4]; + /// plan.fwd(&mut buf, stack); + /// ``` + pub fn fwd(&self, buf: &mut [c64], stack: DynStack) { + assert_eq!(self.fft_size(), buf.len()); + let (mut scratch, _) = stack.make_aligned_uninit::(self.algo().1, CACHELINE_ALIGN); + unsafe { + fwd_depth( + self.n, + buf.as_mut_ptr(), + self.twiddles.as_ptr(), + self.base_fn_fwd, + self.base_n, + scratch.as_mut_ptr() as *mut c64, + self.fwd_process_x2, + self.fwd_process_x4, + self.fwd_process_x8, + ); + } + } + + /// Performs an inverse FFT in place, using the provided stack as scratch space. + /// + /// # Note + /// + /// The values in `buf` must be in permuted order prior to calling this function. + /// When this function returns, the values in `buf` will contain the terms of the forward + /// transform in standard order. + /// + /// # Example + /// + #[cfg_attr(feature = "std", doc = " ```")] + #[cfg_attr(not(feature = "std"), doc = " ```ignore")] + /// use concrete_fft::c64; + /// use concrete_fft::unordered::{Method, Plan}; + /// use dyn_stack::{DynStack, GlobalMemBuffer, ReborrowMut}; + /// use core::time::Duration; + /// + /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); + /// + /// let mut memory = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); + /// let mut stack = DynStack::new(&mut memory); + /// + /// let mut buf = [c64::default(); 4]; + /// plan.fwd(&mut buf, stack.rb_mut()); + /// plan.inv(&mut buf, stack); + /// ``` + pub fn inv(&self, buf: &mut [c64], stack: DynStack) { + assert_eq!(self.fft_size(), buf.len()); + let (mut scratch, _) = stack.make_aligned_uninit::(self.algo().1, CACHELINE_ALIGN); + unsafe { + inv_depth( + self.n, + buf.as_mut_ptr(), + self.twiddles_inv.as_ptr().add(self.n + self.base_n), + self.base_fn_inv, + self.base_n, + scratch.as_mut_ptr() as *mut c64, + self.inv_process_x2, + self.inv_process_x4, + self.inv_process_x8, + ); + } + } + + /// Serialize a buffer containing data in the Fourier domain that is stored in the + /// plan-specific permuted order, and store the result with the serializer in the standard + /// order. + /// + /// # Panics + /// + /// - Panics if the length of `buf` is not equal to the FFT size. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + pub fn serialize_fourier_buffer( + &self, + serializer: S, + buf: &[c64], + ) -> Result { + use serde::ser::SerializeSeq; + + let n = self.n; + let base_n = self.base_n; + assert_eq!(n, buf.len()); + + let mut seq = serializer.serialize_seq(Some(n))?; + + let nbits = n.trailing_zeros(); + let base_nbits = base_n.trailing_zeros(); + + for i in 0..n { + seq.serialize_element(&buf[bit_rev_twice(nbits, base_nbits, i)])?; + } + + seq.end() + } + + /// Deserialize data in the Fourier domain that is produced by the deserializer in the standard + /// order into a buffer so that it will contain the data in the plan-specific permuted order + /// + /// # Panics + /// + /// - Panics if the length of `buf` is not equal to the FFT size. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + pub fn deserialize_fourier_buffer<'de, D: serde::Deserializer<'de>>( + &self, + deserializer: D, + buf: &mut [c64], + ) -> Result<(), D::Error> { + use serde::de::{SeqAccess, Visitor}; + + let n = self.n; + let base_n = self.base_n; + assert_eq!(n, buf.len()); + + struct SeqVisitor<'a> { + buf: &'a mut [c64], + base_n: usize, + } + + impl<'de, 'a> Visitor<'de> for SeqVisitor<'a> { + type Value = (); + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + write!( + formatter, + "a sequence of {} 64-bit complex numbers", + self.buf.len() + ) + } + + fn visit_seq(self, mut seq: S) -> Result + where + S: SeqAccess<'de>, + { + let n = self.buf.len(); + let nbits = n.trailing_zeros(); + let base_nbits = self.base_n.trailing_zeros(); + + let mut i = 0; + + while let Some(value) = seq.next_element::()? { + if i < n { + self.buf[bit_rev_twice(nbits, base_nbits, i)] = value; + } + + i += 1; + } + + if i != n { + Err(serde::de::Error::invalid_length(i, &self)) + } else { + Ok(()) + } + } + } + + deserializer.deserialize_seq(SeqVisitor { buf, base_n }) + } +} + +#[cfg(any(test, feature = "serde"))] +#[inline] +fn bit_rev(nbits: u32, i: usize) -> usize { + i.reverse_bits() >> (usize::BITS - nbits) +} + +#[cfg(any(test, feature = "serde"))] +#[inline] +fn bit_rev_twice(nbits: u32, base_nbits: u32, i: usize) -> usize { + let i_rev = bit_rev(nbits, i); + let bottom_mask = (1 << base_nbits) - 1; + let bottom_bits = bit_rev(base_nbits, i_rev); + (i_rev & !bottom_mask) | bottom_bits +} + +#[cfg(test)] +mod tests { + use super::*; + use alloc::vec; + use dyn_stack::GlobalMemBuffer; + use dyn_stack::ReborrowMut; + use num_complex::ComplexFloat; + use rand::random; + + extern crate alloc; + + #[test] + fn test_fwd() { + for n in [256, 512, 1024] { + let mut z = vec![c64::default(); n]; + + for z in &mut z { + z.re = random(); + z.im = random(); + } + + let mut z_ref = z.clone(); + let mut planner = rustfft::FftPlanner::new(); + let fwd = planner.plan_fft_forward(n); + fwd.process(&mut z_ref); + + let plan = Plan::new( + n, + Method::UserProvided { + base_algo: FftAlgo::Dif4, + base_n: 32, + }, + ); + let base_n = plan.algo().1; + let mut mem = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); + let stack = DynStack::new(&mut *mem); + plan.fwd(&mut z, stack); + + for i in 0..n { + assert!( + (z[bit_rev_twice(n.trailing_zeros(), base_n.trailing_zeros(), i)] - z_ref[i]) + .abs() + < 1e-12 + ); + } + } + } + + #[test] + fn test_roundtrip() { + for n in [32, 64, 256, 512, 1024] { + let mut z = vec![c64::default(); n]; + + for z in &mut z { + z.re = random(); + z.im = random(); + } + + let orig = z.clone(); + + let plan = Plan::new( + n, + Method::UserProvided { + base_algo: FftAlgo::Dif4, + base_n: 32, + }, + ); + let mut mem = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); + let mut stack = DynStack::new(&mut *mem); + plan.fwd(&mut z, stack.rb_mut()); + plan.inv(&mut z, stack); + + for z in &mut z { + *z /= n as f64; + } + + for (z_actual, z_expected) in z.iter().zip(&orig) { + assert!((z_actual - z_expected).abs() < 1e-12); + } + } + } +} + +#[cfg(all(test, feature = "serde"))] +mod tests_serde { + use super::*; + use dyn_stack::GlobalMemBuffer; + use num_complex::ComplexFloat; + use rand::random; + + #[test] + fn test_serde() { + for n in [64, 128, 256, 512, 1024] { + let mut z = vec![c64::default(); n]; + + for z in &mut z { + z.re = random(); + z.im = random(); + } + + let orig = z.clone(); + + let plan1 = Plan::new( + n, + Method::UserProvided { + base_algo: FftAlgo::Dif4, + base_n: 32, + }, + ); + let plan2 = Plan::new( + n, + Method::UserProvided { + base_algo: FftAlgo::Dif4, + base_n: 64, + }, + ); + + let mut mem = GlobalMemBuffer::new( + plan1 + .fft_scratch() + .unwrap() + .or(plan2.fft_scratch().unwrap()), + ); + let mut stack = DynStack::new(&mut *mem); + + plan1.fwd(&mut z, stack.rb_mut()); + + let mut buf = Vec::::new(); + let mut serializer = bincode::Serializer::new(&mut buf, bincode::options()); + plan1.serialize_fourier_buffer(&mut serializer, &z).unwrap(); + + let mut deserializer = bincode::de::Deserializer::from_slice(&buf, bincode::options()); + plan2 + .deserialize_fourier_buffer(&mut deserializer, &mut z) + .unwrap(); + + plan2.inv(&mut z, stack); + + for z in &mut z { + *z /= n as f64; + } + + for (z_actual, z_expected) in z.iter().zip(&orig) { + assert!((z_actual - z_expected).abs() < 1e-12); + } + } + } +} diff --git a/src/x86.rs b/src/x86.rs new file mode 100644 index 0000000..f5a207f --- /dev/null +++ b/src/x86.rs @@ -0,0 +1,377 @@ +use crate::c64; +use crate::fft_simd::{FftSimd64, FftSimd64X2}; + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; + +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +#[derive(Copy, Clone, Debug)] +pub struct AvxX2; +#[derive(Copy, Clone, Debug)] +pub struct AvxX1; + +#[derive(Copy, Clone, Debug)] +pub struct FmaX2; +#[derive(Copy, Clone, Debug)] +pub struct FmaX1; + +#[cfg(feature = "nightly")] +pub struct Avx512X4; +#[cfg(feature = "nightly")] +pub struct Avx512X2; +#[cfg(feature = "nightly")] +pub struct Avx512X1; + +macro_rules! reimpl { + ( + as $ty: ty: + $( + unsafe fn $name: ident($($arg_name: ident: $arg_ty: ty),* $(,)?) $(-> $ret: ty)?; + )* + ) => { + $( + #[inline(always)] + unsafe fn $name($($arg_name: $arg_ty),*) $(-> $ret)? { + <$ty>::$name($($arg_name),*) + } + )* + }; +} + +impl FftSimd64 for AvxX1 { + type Reg = __m128d; + + const COMPLEX_PER_REG: usize = 1; + + #[inline(always)] + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg { + _mm_set1_pd(*ptr) + } + + #[inline(always)] + unsafe fn splat(ptr: *const crate::c64) -> Self::Reg { + Self::load(ptr) + } + + #[inline(always)] + unsafe fn load(ptr: *const crate::c64) -> Self::Reg { + _mm_loadu_pd(ptr as _) + } + + #[inline(always)] + unsafe fn store(ptr: *mut crate::c64, z: Self::Reg) { + _mm_storeu_pd(ptr as _, z); + } + + #[inline(always)] + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm_xor_pd(a, b) + } + + #[inline(always)] + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg { + _mm_permute_pd::<0b01>(xy) + } + + #[inline(always)] + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm_add_pd(a, b) + } + + #[inline(always)] + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm_sub_pd(a, b) + } + + #[inline(always)] + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm_mul_pd(a, b) + } + + #[inline(always)] + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + let ab = a; + let xy = b; + let aa = _mm_unpacklo_pd(ab, ab); + let bb = _mm_unpackhi_pd(ab, ab); + let yx = Self::swap_re_im(xy); + _mm_addsub_pd(_mm_mul_pd(aa, xy), _mm_mul_pd(bb, yx)) + } +} + +impl FftSimd64 for AvxX2 { + type Reg = __m256d; + + const COMPLEX_PER_REG: usize = 2; + + #[inline(always)] + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg { + _mm256_set1_pd(*ptr) + } + + #[inline(always)] + unsafe fn splat(ptr: *const crate::c64) -> Self::Reg { + let tmp = _mm_loadu_pd(ptr as _); + _mm256_broadcast_pd(&tmp) + } + + #[inline(always)] + unsafe fn load(ptr: *const crate::c64) -> Self::Reg { + _mm256_loadu_pd(ptr as _) + } + + #[inline(always)] + unsafe fn store(ptr: *mut crate::c64, z: Self::Reg) { + _mm256_storeu_pd(ptr as _, z); + } + + #[inline(always)] + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm256_xor_pd(a, b) + } + + #[inline(always)] + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg { + _mm256_permute_pd::<0b0101>(xy) + } + + #[inline(always)] + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm256_add_pd(a, b) + } + + #[inline(always)] + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm256_sub_pd(a, b) + } + + #[inline(always)] + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm256_mul_pd(a, b) + } + + #[inline(always)] + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + let ab = a; + let xy = b; + let aa = _mm256_unpacklo_pd(ab, ab); + let bb = _mm256_unpackhi_pd(ab, ab); + let yx = Self::swap_re_im(xy); + _mm256_addsub_pd(_mm256_mul_pd(aa, xy), _mm256_mul_pd(bb, yx)) + } +} + +impl FftSimd64 for FmaX1 { + type Reg = __m128d; + + const COMPLEX_PER_REG: usize = 1; + + reimpl! { as AvxX1: + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg; + unsafe fn splat(ptr: *const c64) -> Self::Reg; + unsafe fn load(ptr: *const c64) -> Self::Reg; + unsafe fn store(ptr: *mut c64, z: Self::Reg); + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg; + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; + } + + #[inline(always)] + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + let ab = a; + let xy = b; + let aa = _mm_unpacklo_pd(ab, ab); + let bb = _mm_unpackhi_pd(ab, ab); + let yx = Self::swap_re_im(xy); + _mm_fmaddsub_pd(aa, xy, _mm_mul_pd(bb, yx)) + } +} + +impl FftSimd64 for FmaX2 { + type Reg = __m256d; + + const COMPLEX_PER_REG: usize = 2; + + reimpl! { as AvxX2: + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg; + unsafe fn splat(ptr: *const c64) -> Self::Reg; + unsafe fn load(ptr: *const c64) -> Self::Reg; + unsafe fn store(ptr: *mut c64, z: Self::Reg); + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg; + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; + } + + #[inline(always)] + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + let ab = a; + let xy = b; + let aa = _mm256_unpacklo_pd(ab, ab); + let bb = _mm256_unpackhi_pd(ab, ab); + let yx = Self::swap_re_im(xy); + _mm256_fmaddsub_pd(aa, xy, _mm256_mul_pd(bb, yx)) + } +} + +#[cfg(feature = "nightly")] +impl FftSimd64 for Avx512X1 { + type Reg = __m128d; + + const COMPLEX_PER_REG: usize = 1; + + reimpl! { as FmaX1: + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg; + unsafe fn splat(ptr: *const c64) -> Self::Reg; + unsafe fn load(ptr: *const c64) -> Self::Reg; + unsafe fn store(ptr: *mut c64, z: Self::Reg); + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg; + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; + } +} + +#[cfg(feature = "nightly")] +impl FftSimd64 for Avx512X2 { + type Reg = __m256d; + + const COMPLEX_PER_REG: usize = 2; + + reimpl! { as FmaX2: + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg; + unsafe fn splat(ptr: *const c64) -> Self::Reg; + unsafe fn load(ptr: *const c64) -> Self::Reg; + unsafe fn store(ptr: *mut c64, z: Self::Reg); + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg; + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg; + } +} + +#[cfg(feature = "nightly")] +impl FftSimd64 for Avx512X4 { + type Reg = __m512d; + + const COMPLEX_PER_REG: usize = 4; + + #[inline(always)] + unsafe fn splat_re_im(ptr: *const f64) -> Self::Reg { + _mm512_set1_pd(*ptr) + } + + #[inline(always)] + unsafe fn splat(ptr: *const crate::c64) -> Self::Reg { + _mm512_castps_pd(_mm512_broadcast_f32x4(_mm_castpd_ps(_mm_loadu_pd( + ptr as _, + )))) + } + + #[inline(always)] + unsafe fn load(ptr: *const crate::c64) -> Self::Reg { + _mm512_loadu_pd(ptr as _) + } + + #[inline(always)] + unsafe fn store(ptr: *mut crate::c64, z: Self::Reg) { + _mm512_storeu_pd(ptr as _, z); + } + + #[inline(always)] + unsafe fn xor(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm512_castsi512_pd(_mm512_xor_si512( + _mm512_castpd_si512(a), + _mm512_castpd_si512(b), + )) + } + + #[inline(always)] + unsafe fn swap_re_im(xy: Self::Reg) -> Self::Reg { + _mm512_permute_pd::<0b01010101>(xy) + } + + #[inline(always)] + unsafe fn add(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm512_add_pd(a, b) + } + + #[inline(always)] + unsafe fn sub(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm512_sub_pd(a, b) + } + + #[inline(always)] + unsafe fn cwise_mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm512_mul_pd(a, b) + } + + #[inline(always)] + unsafe fn mul(a: Self::Reg, b: Self::Reg) -> Self::Reg { + let ab = a; + let xy = b; + let aa = _mm512_unpacklo_pd(ab, ab); + let bb = _mm512_unpackhi_pd(ab, ab); + let yx = Self::swap_re_im(xy); + _mm512_fmaddsub_pd(aa, xy, _mm512_mul_pd(bb, yx)) + } +} + +impl FftSimd64X2 for AvxX2 { + #[inline(always)] + unsafe fn catlo(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm256_permute2f128_pd::<0b00100000>(a, b) + } + + #[inline(always)] + unsafe fn cathi(a: Self::Reg, b: Self::Reg) -> Self::Reg { + _mm256_permute2f128_pd::<0b00110001>(a, b) + } +} + +impl FftSimd64X2 for FmaX2 { + reimpl! { as AvxX2: + unsafe fn catlo(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cathi(a: Self::Reg, b: Self::Reg) -> Self::Reg; + } +} + +#[cfg(feature = "nightly")] +impl FftSimd64X2 for Avx512X2 { + reimpl! { as AvxX2: + unsafe fn catlo(a: Self::Reg, b: Self::Reg) -> Self::Reg; + unsafe fn cathi(a: Self::Reg, b: Self::Reg) -> Self::Reg; + } +} + +#[cfg(feature = "nightly")] +impl crate::fft_simd::FftSimd64X4 for Avx512X4 { + #[inline(always)] + unsafe fn transpose( + r0: Self::Reg, + r1: Self::Reg, + r2: Self::Reg, + r3: Self::Reg, + ) -> (Self::Reg, Self::Reg, Self::Reg, Self::Reg) { + let t0 = _mm512_shuffle_f64x2::<0b10001000>(r0, r1); + let t1 = _mm512_shuffle_f64x2::<0b11011101>(r0, r1); + let t2 = _mm512_shuffle_f64x2::<0b10001000>(r2, r3); + let t3 = _mm512_shuffle_f64x2::<0b11011101>(r2, r3); + + let s0 = _mm512_shuffle_f64x2::<0b10001000>(t0, t2); + let s1 = _mm512_shuffle_f64x2::<0b11011101>(t0, t2); + let s2 = _mm512_shuffle_f64x2::<0b10001000>(t1, t3); + let s3 = _mm512_shuffle_f64x2::<0b11011101>(t1, t3); + + (s0, s2, s1, s3) + } +}