numerics 0.1.0
Loading...
Searching...
No Matches
impl.hpp
Go to the documentation of this file.
1/// @file spectral/backends/opt/impl.hpp
2/// @brief Handwritten AVX2 / NEON butterfly for the FFT.
3///
4/// Applies \f$(u,v)\leftarrow(u+vw,u-vw)\f$ with AVX2, NEON, or scalar code.
5#pragma once
6#include "../seq/impl.hpp"
7#include "spectral/fft.hpp"
8#include <cmath>
9#include <stdexcept>
10#include <vector>
11
12#ifdef NUMERICS_HAS_AVX2
13 #include <immintrin.h>
14#endif
15#ifdef NUMERICS_HAS_NEON
16 #include <arm_neon.h>
17#endif
18
19namespace backends {
20namespace opt {
21
/// One full turn in radians (2*pi); divided by a stage span to get the twiddle angle.
static constexpr double TWO_PI{6.283185307179586476925286766559};
23
24// Platform butterfly implementations
25
26#ifdef NUMERICS_HAS_AVX2
27
28/// Butterfly on hlen complex pairs using AVX2 (2 pairs per iteration).
29static inline void butterfly(num::cplx* __restrict__ u,
30 num::cplx* __restrict__ v,
31 const num::cplx* __restrict__ tw,
32 int hlen) {
33 auto* ud = reinterpret_cast<double*>(u);
34 auto* vd = reinterpret_cast<double*>(v);
35 auto* wd = reinterpret_cast<const double*>(tw);
36
37 int j = 0;
38 for (; j + 1 < hlen; j += 2) {
39 __m256d U = _mm256_loadu_pd(ud + 2 * j);
40 __m256d V = _mm256_loadu_pd(vd + 2 * j);
41 __m256d W = _mm256_loadu_pd(wd + 2 * j);
42 // complex multiply V * W
43 __m256d Wre = _mm256_unpacklo_pd(W, W); // [wr0,wr0,wr1,wr1]
44 __m256d Wim = _mm256_unpackhi_pd(W, W); // [wi0,wi0,wi1,wi1]
45 __m256d Vsw = _mm256_permute_pd(V, 0x5); // [vi0,vr0,vi1,vr1]
46 __m256d T = _mm256_addsub_pd(_mm256_mul_pd(V, Wre), _mm256_mul_pd(Vsw, Wim));
47 _mm256_storeu_pd(ud + 2 * j, _mm256_add_pd(U, T));
48 _mm256_storeu_pd(vd + 2 * j, _mm256_sub_pd(U, T));
49 }
50 // scalar tail for odd hlen
51 for (; j < hlen; ++j) {
52 num::cplx t = v[j] * tw[j];
53 num::cplx uu = u[j];
54 u[j] = uu + t;
55 v[j] = uu - t;
56 }
57}
58
59#elif defined(NUMERICS_HAS_NEON)
60
61/// Butterfly on hlen complex pairs using NEON vld2/vst2 (SoA deinterleave).
62static inline void butterfly(num::cplx* __restrict__ u,
63 num::cplx* __restrict__ v,
64 const num::cplx* __restrict__ tw,
65 int hlen) {
66 auto* ud = reinterpret_cast<double*>(u);
67 auto* vd = reinterpret_cast<double*>(v);
68 auto* wd = reinterpret_cast<const double*>(tw);
69
70 int j = 0;
71 for (; j + 1 < hlen; j += 2) {
72 // Deinterleaved load: .val[0] = [re0,re1], .val[1] = [im0,im1]
73 float64x2x2_t U = vld2q_f64(ud + 2 * j);
74 float64x2x2_t V = vld2q_f64(vd + 2 * j);
75 float64x2x2_t W = vld2q_f64(wd + 2 * j);
76
77 // T = V * W (complex multiply, component-wise on SoA data)
78 // Tr = Vr*Wr - Vi*Wi
79 float64x2_t Tr = vfmsq_f64(vmulq_f64(V.val[0], W.val[0]), V.val[1], W.val[1]);
80 // Ti = Vr*Wi + Vi*Wr
81 float64x2_t Ti = vfmaq_f64(vmulq_f64(V.val[0], W.val[1]), V.val[1], W.val[0]);
82
83 float64x2x2_t Ru, Rv;
84 Ru.val[0] = vaddq_f64(U.val[0], Tr);
85 Ru.val[1] = vaddq_f64(U.val[1], Ti);
86 Rv.val[0] = vsubq_f64(U.val[0], Tr);
87 Rv.val[1] = vsubq_f64(U.val[1], Ti);
88 vst2q_f64(ud + 2 * j, Ru);
89 vst2q_f64(vd + 2 * j, Rv);
90 }
91 // scalar tail
92 for (; j < hlen; ++j) {
93 num::cplx t = v[j] * tw[j];
94 num::cplx uu = u[j];
95 u[j] = uu + t;
96 v[j] = uu - t;
97 }
98}
99
100#else
101
102/// Scalar fallback when no SIMD ISA is available.
103static inline void butterfly(num::cplx* u, num::cplx* v, const num::cplx* tw, int hlen) {
104 for (int j = 0; j < hlen; ++j) {
105 num::cplx t = v[j] * tw[j];
106 num::cplx uu = u[j];
107 u[j] = uu + t;
108 v[j] = uu - t;
109 }
110}
111
112#endif // NUMERICS_HAS_AVX2 / NEON
113
114// FFTPlanImpl
115
116/// Precomputed twiddle factors + SIMD butterfly execution.
118 int n;
119 bool invert;
120 std::vector<std::vector<num::cplx>> twiddles;
121
122 FFTPlanImpl(int n_, bool inv)
123 : n(n_),
124 invert(inv) {
125 if (n_ == 0 || (n_ & (n_ - 1)))
126 throw std::invalid_argument("FFTPlan: length must be a power of two");
127 for (int len = 2; len <= n_; len <<= 1) {
128 double ang = TWO_PI / static_cast<double>(len) * (inv ? 1.0 : -1.0);
129 num::cplx wlen{std::cos(ang), std::sin(ang)};
130 std::vector<num::cplx> tw(len / 2);
131 num::cplx w{1.0, 0.0};
132 for (int j = 0; j < len / 2; ++j) {
133 tw[j] = w;
134 w *= wlen;
135 }
136 twiddles.push_back(std::move(tw));
137 }
138 }
139
140 void execute(num::CVector& a) const {
142 num::cplx* data = a.data();
143 int stage = 0;
144 for (int len = 2; len <= n; len <<= 1, ++stage) {
145 int hlen = len / 2;
146 const num::cplx* tw = twiddles[stage].data();
147 for (int i = 0; i < n; i += len)
148 butterfly(data + i, data + i + hlen, tw, hlen);
149 }
150 }
151};
152
153inline void fft(const num::CVector& in, num::CVector& out) {
154 int n = static_cast<int>(in.size());
155 for (int i = 0; i < n; ++i)
156 out[i] = in[i];
157 FFTPlanImpl plan(n, false);
158 plan.execute(out);
159}
160
161inline void ifft(const num::CVector& in, num::CVector& out) {
162 int n = static_cast<int>(in.size());
163 for (int i = 0; i < n; ++i)
164 out[i] = in[i];
165 FFTPlanImpl plan(n, true);
166 plan.execute(out);
167}
168
169inline void rfft(const num::Vector& in, num::CVector& out) {
170 int n = static_cast<int>(in.size());
171 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
172 for (int i = 0; i < n; ++i)
173 tmp[i] = {in[i], 0.0};
174 FFTPlanImpl plan(n, false);
175 plan.execute(tmp);
176 for (int k = 0; k < n / 2 + 1; ++k)
177 out[k] = tmp[k];
178}
179
180inline void irfft(const num::CVector& in, int n, num::Vector& out) {
181 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
182 for (int k = 0; k < n / 2 + 1; ++k)
183 tmp[k] = in[k];
184 for (int k = 1; k < (n - 1) / 2 + 1; ++k)
185 tmp[n - k] = std::conj(in[k]);
186 FFTPlanImpl plan(n, true);
187 plan.execute(tmp);
188 for (int i = 0; i < n; ++i)
189 out[i] = tmp[i].real();
190}
191
192} // namespace opt
193} // namespace backends
Dense owning vector.
Definition vector.hpp:16
constexpr idx size() const noexcept
Definition vector.hpp:83
FFT interface with backend dispatch.
void irfft(const num::CVector &in, int n, num::Vector &out)
Definition impl.hpp:180
void ifft(const num::CVector &in, num::CVector &out)
Definition impl.hpp:161
void fft(const num::CVector &in, num::CVector &out)
Definition impl.hpp:153
void rfft(const num::Vector &in, num::CVector &out)
Definition impl.hpp:169
void bit_reverse(num::CVector &a)
Definition impl.hpp:15
std::size_t idx
Definition types.hpp:11
std::complex< real > cplx
Definition types.hpp:12
Precomputed twiddle factors + SIMD butterfly execution.
Definition impl.hpp:117
void execute(num::CVector &a) const
Definition impl.hpp:140
std::vector< std::vector< num::cplx > > twiddles
Definition impl.hpp:120
FFTPlanImpl(int n_, bool inv)
Definition impl.hpp:122