numerics 0.1.0
Loading...
Searching...
No Matches
impl.hpp
Go to the documentation of this file.
1/// @file spectral/backends/stdsimd/impl.hpp
2/// @brief std::experimental::simd butterfly for FFT.
3///
4/// Applies \f$(u,v)\leftarrow(u+vw,u-vw)\f$ through
5/// \c std::experimental::simd when enabled.
6#pragma once
7#ifdef NUMERICS_HAS_STD_SIMD
8 #include "../seq/impl.hpp"
9 #include "spectral/fft.hpp"
10 #include <cmath>
11 #include <experimental/simd>
12 #include <stdexcept>
13 #include <vector>
14
15namespace stdx = std::experimental;
16
17namespace backends {
18namespace stdsimd {
19
/// \f$2\pi\f$. Declared `inline constexpr` so the header provides a single
/// definition shared by every translation unit instead of one copy per TU.
inline constexpr double TWO_PI = 6.283185307179586476925286766559;
21
22// FFTPlanImpl
23
24struct FFTPlanImpl {
25 int n;
26 bool invert;
27 std::vector<std::vector<num::cplx>> twiddles;
28
29 FFTPlanImpl(int n_, bool inv)
30 : n(n_),
31 invert(inv) {
32 if (n_ == 0 || (n_ & (n_ - 1)))
33 throw std::invalid_argument("FFTPlan: length must be a power of two");
34 for (int len = 2; len <= n_; len <<= 1) {
35 double ang = TWO_PI / static_cast<double>(len) * (inv ? 1.0 : -1.0);
36 num::cplx wlen{std::cos(ang), std::sin(ang)};
37 std::vector<num::cplx> tw(len / 2);
38 num::cplx w{1.0, 0.0};
39 for (int j = 0; j < len / 2; ++j) {
40 tw[j] = w;
41 w *= wlen;
42 }
43 twiddles.push_back(std::move(tw));
44 }
45 }
46
47 void execute(num::CVector& a) const {
48 using vd = stdx::simd<double, stdx::simd_abi::native<double>>;
49 constexpr int W = static_cast<int>(vd::size());
50
52 num::cplx* data = a.data();
53
54 int stage = 0;
55 for (int len = 2; len <= n; len <<= 1, ++stage) {
56 int hlen = len / 2;
57 const num::cplx* tw = twiddles[stage].data();
58
59 for (int i = 0; i < n; i += len) {
60 num::cplx* up = data + i;
61 num::cplx* vp = data + i + hlen;
62
63 int j = 0;
64 for (; j + W <= hlen; j += W) {
65 // Gather: split AoS complex into separate real/imag
66 // vectors.
67 vd ur([&](int k) -> double { return up[j + k].real(); });
68 vd ui([&](int k) -> double { return up[j + k].imag(); });
69 vd vr([&](int k) -> double { return vp[j + k].real(); });
70 vd vi([&](int k) -> double { return vp[j + k].imag(); });
71 vd wr([&](int k) -> double { return tw[j + k].real(); });
72 vd wi([&](int k) -> double { return tw[j + k].imag(); });
73
74 // Complex multiply: t = v * w
75 vd tr = vr * wr - vi * wi;
76 vd ti = vr * wi + vi * wr;
77
78 // Butterfly + scatter store
79 for (int k = 0; k < W; ++k) {
80 up[j + k] = {ur[k] + tr[k], ui[k] + ti[k]};
81 vp[j + k] = {ur[k] - tr[k], ui[k] - ti[k]};
82 }
83 }
84 // scalar tail
85 for (; j < hlen; ++j) {
86 num::cplx t = vp[j] * tw[j];
87 num::cplx uu = up[j];
88 up[j] = uu + t;
89 vp[j] = uu - t;
90 }
91 }
92 }
93 }
94};
95
96inline void fft(const num::CVector& in, num::CVector& out) {
97 int n = static_cast<int>(in.size());
98 for (int i = 0; i < n; ++i)
99 out[i] = in[i];
100 FFTPlanImpl plan(n, false);
101 plan.execute(out);
102}
103
104inline void ifft(const num::CVector& in, num::CVector& out) {
105 int n = static_cast<int>(in.size());
106 for (int i = 0; i < n; ++i)
107 out[i] = in[i];
108 FFTPlanImpl plan(n, true);
109 plan.execute(out);
110}
111
112inline void rfft(const num::Vector& in, num::CVector& out) {
113 int n = static_cast<int>(in.size());
114 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
115 for (int i = 0; i < n; ++i)
116 tmp[i] = {in[i], 0.0};
117 FFTPlanImpl plan(n, false);
118 plan.execute(tmp);
119 for (int k = 0; k < n / 2 + 1; ++k)
120 out[k] = tmp[k];
121}
122
123inline void irfft(const num::CVector& in, int n, num::Vector& out) {
124 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
125 for (int k = 0; k < n / 2 + 1; ++k)
126 tmp[k] = in[k];
127 for (int k = 1; k < (n - 1) / 2 + 1; ++k)
128 tmp[n - k] = std::conj(in[k]);
129 FFTPlanImpl plan(n, true);
130 plan.execute(tmp);
131 for (int i = 0; i < n; ++i)
132 out[i] = tmp[i].real();
133}
134
135} // namespace stdsimd
136} // namespace backends
137
138#endif // NUMERICS_HAS_STD_SIMD
Dense owning vector.
Definition vector.hpp:16
constexpr idx size() const noexcept
Definition vector.hpp:83
FFT interface with backend dispatch.
void bit_reverse(num::CVector &a)
Definition impl.hpp:15
void ifft(const CVector &in, CVector &out, FFTBackend b=default_fft_backend)
Definition fft.cpp:40
void fft(const CVector &in, CVector &out, FFTBackend b=default_fft_backend)
Definition fft.cpp:15
void irfft(const CVector &in, int n, Vector &out, FFTBackend b=default_fft_backend)
Definition fft.cpp:89
void rfft(const Vector &in, CVector &out, FFTBackend b=default_fft_backend)
Definition fft.cpp:65
double real
Definition types.hpp:10
std::size_t idx
Definition types.hpp:11
std::complex< real > cplx
Definition types.hpp:12