numerics 0.1.0
Loading...
Searching...
No Matches
impl.hpp
Go to the documentation of this file.
1/// @file spectral/backends/stdsimd/impl.hpp
2/// @brief std::experimental::simd butterfly for FFT.
3///
4/// Uses the portable C++ SIMD abstraction (<experimental/simd>, GCC 11+).
5/// On AVX2 platforms vd::size() == 4 (4 doubles/register), on NEON == 2.
6///
7/// The gather uses the generator-lambda constructor:
8/// simd<double, abi> ur([&](int k){ return a[j+k].real(); });
9/// which the compiler maps to SIMD gather instructions when possible.
10///
11/// The scatter-store back to AoS is the main cost difference vs. the
12/// handwritten backend -- the element-wise write loop is the comparison point.
13///
14/// Only compiled when NUMERICS_HAS_STD_SIMD is defined.
15#pragma once
16#ifdef NUMERICS_HAS_STD_SIMD
17 #include "spectral/fft.hpp"
18 #include "../seq/impl.hpp"
19 #include <experimental/simd>
20 #include <cmath>
21 #include <stdexcept>
22 #include <vector>
23
24namespace stdx = std::experimental;
25
26namespace backends {
27namespace stdsimd {
28
/// 2*pi, the full-circle angle in radians used to derive the twiddle angles.
/// Declared `inline constexpr` so that every translation unit including this
/// header shares one definition instead of a per-TU internal-linkage copy.
inline constexpr double TWO_PI = 6.283185307179586476925286766559;
30
31// FFTPlanImpl
32
33struct FFTPlanImpl {
34 int n;
35 bool invert;
36 std::vector<std::vector<num::cplx>> twiddles;
37
38 FFTPlanImpl(int n_, bool inv)
39 : n(n_)
40 , invert(inv) {
41 if (n_ == 0 || (n_ & (n_ - 1)))
42 throw std::invalid_argument(
43 "FFTPlan: length must be a power of two");
44 for (int len = 2; len <= n_; len <<= 1) {
45 double ang = TWO_PI / static_cast<double>(len) * (inv ? 1.0 : -1.0);
46 num::cplx wlen{std::cos(ang), std::sin(ang)};
47 std::vector<num::cplx> tw(len / 2);
48 num::cplx w{1.0, 0.0};
49 for (int j = 0; j < len / 2; ++j) {
50 tw[j] = w;
51 w *= wlen;
52 }
53 twiddles.push_back(std::move(tw));
54 }
55 }
56
57 void execute(num::CVector& a) const {
58 using vd = stdx::simd<double, stdx::simd_abi::native<double>>;
59 constexpr int W = static_cast<int>(vd::size());
60
62 num::cplx* data = a.data();
63
64 int stage = 0;
65 for (int len = 2; len <= n; len <<= 1, ++stage) {
66 int hlen = len / 2;
67 const num::cplx* tw = twiddles[stage].data();
68
69 for (int i = 0; i < n; i += len) {
70 num::cplx* up = data + i;
71 num::cplx* vp = data + i + hlen;
72
73 int j = 0;
74 for (; j + W <= hlen; j += W) {
75 // Gather: split AoS complex into separate real/imag
76 // vectors.
77 vd ur([&](int k) -> double { return up[j + k].real(); });
78 vd ui([&](int k) -> double { return up[j + k].imag(); });
79 vd vr([&](int k) -> double { return vp[j + k].real(); });
80 vd vi([&](int k) -> double { return vp[j + k].imag(); });
81 vd wr([&](int k) -> double { return tw[j + k].real(); });
82 vd wi([&](int k) -> double { return tw[j + k].imag(); });
83
84 // Complex multiply: t = v * w
85 vd tr = vr * wr - vi * wi;
86 vd ti = vr * wi + vi * wr;
87
88 // Butterfly + scatter store
89 for (int k = 0; k < W; ++k) {
90 up[j + k] = {ur[k] + tr[k], ui[k] + ti[k]};
91 vp[j + k] = {ur[k] - tr[k], ui[k] - ti[k]};
92 }
93 }
94 // scalar tail
95 for (; j < hlen; ++j) {
96 num::cplx t = vp[j] * tw[j];
97 num::cplx uu = up[j];
98 up[j] = uu + t;
99 vp[j] = uu - t;
100 }
101 }
102 }
103 }
104};
105
106// One-shot functions
107
108inline void fft(const num::CVector& in, num::CVector& out) {
109 int n = static_cast<int>(in.size());
110 for (int i = 0; i < n; ++i)
111 out[i] = in[i];
112 FFTPlanImpl plan(n, false);
113 plan.execute(out);
114}
115
116inline void ifft(const num::CVector& in, num::CVector& out) {
117 int n = static_cast<int>(in.size());
118 for (int i = 0; i < n; ++i)
119 out[i] = in[i];
120 FFTPlanImpl plan(n, true);
121 plan.execute(out);
122}
123
124inline void rfft(const num::Vector& in, num::CVector& out) {
125 int n = static_cast<int>(in.size());
126 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
127 for (int i = 0; i < n; ++i)
128 tmp[i] = {in[i], 0.0};
129 FFTPlanImpl plan(n, false);
130 plan.execute(tmp);
131 for (int k = 0; k < n / 2 + 1; ++k)
132 out[k] = tmp[k];
133}
134
135inline void irfft(const num::CVector& in, int n, num::Vector& out) {
136 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
137 for (int k = 0; k < n / 2 + 1; ++k)
138 tmp[k] = in[k];
139 for (int k = 1; k < (n - 1) / 2 + 1; ++k)
140 tmp[n - k] = std::conj(in[k]);
141 FFTPlanImpl plan(n, true);
142 plan.execute(tmp);
143 for (int i = 0; i < n; ++i)
144 out[i] = tmp[i].real();
145}
146
147} // namespace stdsimd
148} // namespace backends
149
150#endif // NUMERICS_HAS_STD_SIMD
Dense vector with optional GPU storage, templated over scalar type T.
Definition vector.hpp:24
constexpr idx size() const noexcept
Definition vector.hpp:80
FFT interface with backend dispatch.
void bit_reverse(num::CVector &a)
Definition impl.hpp:15
void ifft(const CVector &in, CVector &out, FFTBackend b=default_fft_backend)
Inverse complex DFT (unnormalised: result = n * true_inverse).
Definition fft.cpp:40
void fft(const CVector &in, CVector &out, FFTBackend b=default_fft_backend)
Forward complex DFT. out must be pre-allocated to in.size().
Definition fft.cpp:15
void irfft(const CVector &in, int n, Vector &out, FFTBackend b=default_fft_backend)
Complex-to-real inverse DFT (unnormalised).
Definition fft.cpp:88
void rfft(const Vector &in, CVector &out, FFTBackend b=default_fft_backend)
Real-to-complex forward DFT. out must be pre-allocated to n/2+1.
Definition fft.cpp:64
double real
Definition types.hpp:10
std::size_t idx
Definition types.hpp:11
std::complex< real > cplx
Definition types.hpp:12