numerics
Loading...
Searching...
No Matches
impl.hpp
Go to the documentation of this file.
1/// @file spectral/backends/opt/impl.hpp
2/// @brief Handwritten AVX2 / NEON butterfly for the FFT.
3///
4/// Processes 2 complex butterflies per SIMD iteration using 256-bit (AVX2) or
5/// 128-bit (NEON) registers. Falls back to the seq scalar butterfly when
6/// neither ISA extension is available at compile time.
7///
8/// AVX2 complex butterfly (2 pairs per __m256d):
9/// u = [ur0, ui0, ur1, ui1]
10/// v = [vr0, vi0, vr1, vi1]
11/// w = [wr0, wi0, wr1, wi1]
12///
13/// w_re = unpacklo(w, w) -> [wr0, wr0, wr1, wr1]
14/// w_im = unpackhi(w, w) -> [wi0, wi0, wi1, wi1]
15/// v_sw = permute(v, 0b0101) -> [vi0, vr0, vi1, vr1] (swap re/im)
16/// t = addsub(v*w_re, v_sw*w_im)
17/// = [vr0*wr0 - vi0*wi0, vr0*wi0 + vi0*wr0,
18/// vr1*wr1 - vi1*wi1, vr1*wi1 + vi1*wr1]
19///
20/// NEON deinterleaves with vld2q_f64 (SoA load), multiplies component-wise,
21/// and re-interleaves with vst2q_f64.
22#pragma once
23#include "spectral/fft.hpp"
24#include "../seq/impl.hpp"
25#include <cmath>
26#include <stdexcept>
27#include <vector>
28
29#ifdef NUMERICS_HAS_AVX2
30# include <immintrin.h>
31#endif
32#ifdef NUMERICS_HAS_NEON
33# include <arm_neon.h>
34#endif
35
36namespace backends {
37namespace opt {
38
/// Full-turn angle 2*pi in radians. The FFTPlanImpl constructor divides it by
/// each stage length to obtain that stage's base twiddle angle.
39static constexpr double TWO_PI = 6.283185307179586476925286766559;
40
41// ---- platform butterfly implementations ------------------------------------
42
43#ifdef NUMERICS_HAS_AVX2
44
45/// Butterfly on hlen complex pairs using AVX2 (2 pairs per iteration).
46static inline void butterfly(num::cplx* __restrict__ u,
47 num::cplx* __restrict__ v,
48 const num::cplx* __restrict__ tw,
49 int hlen)
50{
51 auto* ud = reinterpret_cast<double*>(u);
52 auto* vd = reinterpret_cast<double*>(v);
53 auto* wd = reinterpret_cast<const double*>(tw);
54
55 int j = 0;
56 for (; j + 1 < hlen; j += 2) {
57 __m256d U = _mm256_loadu_pd(ud + 2*j);
58 __m256d V = _mm256_loadu_pd(vd + 2*j);
59 __m256d W = _mm256_loadu_pd(wd + 2*j);
60 // complex multiply V * W
61 __m256d Wre = _mm256_unpacklo_pd(W, W); // [wr0,wr0,wr1,wr1]
62 __m256d Wim = _mm256_unpackhi_pd(W, W); // [wi0,wi0,wi1,wi1]
63 __m256d Vsw = _mm256_permute_pd(V, 0x5); // [vi0,vr0,vi1,vr1]
64 __m256d T = _mm256_addsub_pd(
65 _mm256_mul_pd(V, Wre),
66 _mm256_mul_pd(Vsw, Wim));
67 _mm256_storeu_pd(ud + 2*j, _mm256_add_pd(U, T));
68 _mm256_storeu_pd(vd + 2*j, _mm256_sub_pd(U, T));
69 }
70 // scalar tail for odd hlen
71 for (; j < hlen; ++j) {
72 num::cplx t = v[j] * tw[j];
73 num::cplx uu = u[j];
74 u[j] = uu + t;
75 v[j] = uu - t;
76 }
77}
78
79#elif defined(NUMERICS_HAS_NEON)
80
81/// Butterfly on hlen complex pairs using NEON vld2/vst2 (SoA deinterleave).
82static inline void butterfly(num::cplx* __restrict__ u,
83 num::cplx* __restrict__ v,
84 const num::cplx* __restrict__ tw,
85 int hlen)
86{
87 auto* ud = reinterpret_cast<double*>(u);
88 auto* vd = reinterpret_cast<double*>(v);
89 auto* wd = reinterpret_cast<const double*>(tw);
90
91 int j = 0;
92 for (; j + 1 < hlen; j += 2) {
93 // Deinterleaved load: .val[0] = [re0,re1], .val[1] = [im0,im1]
94 float64x2x2_t U = vld2q_f64(ud + 2*j);
95 float64x2x2_t V = vld2q_f64(vd + 2*j);
96 float64x2x2_t W = vld2q_f64(wd + 2*j);
97
98 // T = V * W (complex multiply, component-wise on SoA data)
99 // Tr = Vr*Wr - Vi*Wi
100 float64x2_t Tr = vfmsq_f64(vmulq_f64(V.val[0], W.val[0]),
101 V.val[1], W.val[1]);
102 // Ti = Vr*Wi + Vi*Wr
103 float64x2_t Ti = vfmaq_f64(vmulq_f64(V.val[0], W.val[1]),
104 V.val[1], W.val[0]);
105
106 float64x2x2_t Ru, Rv;
107 Ru.val[0] = vaddq_f64(U.val[0], Tr);
108 Ru.val[1] = vaddq_f64(U.val[1], Ti);
109 Rv.val[0] = vsubq_f64(U.val[0], Tr);
110 Rv.val[1] = vsubq_f64(U.val[1], Ti);
111 vst2q_f64(ud + 2*j, Ru);
112 vst2q_f64(vd + 2*j, Rv);
113 }
114 // scalar tail
115 for (; j < hlen; ++j) {
116 num::cplx t = v[j] * tw[j];
117 num::cplx uu = u[j];
118 u[j] = uu + t;
119 v[j] = uu - t;
120 }
121}
122
123#else
124
125/// Scalar fallback when no SIMD ISA is available.
126static inline void butterfly(num::cplx* u, num::cplx* v,
127 const num::cplx* tw, int hlen)
128{
129 for (int j = 0; j < hlen; ++j) {
130 num::cplx t = v[j] * tw[j];
131 num::cplx uu = u[j];
132 u[j] = uu + t;
133 v[j] = uu - t;
134 }
135}
136
137#endif // NUMERICS_HAS_AVX2 / NEON
138
139// ---- FFTPlanImpl -----------------------------------------------------------
140
141/// Precomputed twiddle factors + SIMD butterfly execution.
/// Transform length; the constructor only accepts a power of two.
143 int n;
/// Copy of the constructor's `inv` flag (true selects the positive twiddle angle).
144 bool invert;
/// twiddles[s] holds the len/2 factors of the stage whose butterfly length is
/// len = 2^(s+1); the constructor appends one table per stage in that order.
145 std::vector<std::vector<num::cplx>> twiddles;
146
147 FFTPlanImpl(int n_, bool inv) : n(n_), invert(inv) {
148 if (n_ == 0 || (n_ & (n_ - 1)))
149 throw std::invalid_argument("FFTPlan: length must be a power of two");
150 for (int len = 2; len <= n_; len <<= 1) {
151 double ang = TWO_PI / static_cast<double>(len) * (inv ? 1.0 : -1.0);
152 num::cplx wlen{std::cos(ang), std::sin(ang)};
153 std::vector<num::cplx> tw(len / 2);
154 num::cplx w{1.0, 0.0};
155 for (int j = 0; j < len / 2; ++j) { tw[j] = w; w *= wlen; }
156 twiddles.push_back(std::move(tw));
157 }
158 }
159
/// Runs the butterfly stages over `a` in place.
///
/// Stages go from len = 2 up to len = n, doubling each time. Stage number
/// `stage` uses twiddles[stage] and calls butterfly() once per len-sized block,
/// pairing the block's first hlen elements with its last hlen elements.
///
/// NOTE(review): this listing jumps from source line 160 to 162, so one
/// statement of the body is not visible here. Presumably it is the input
/// reordering (bit reversal) that this stage order relies on -- confirm
/// against the real file.
/// NOTE(review): `invert` is not read in the visible body and `a` is not
/// checked against `n`; callers presumably must pass exactly n elements --
/// confirm.
160 void execute(num::CVector& a) const {
162 num::cplx* data = a.data();
163 int stage = 0;
164 for (int len = 2; len <= n; len <<= 1, ++stage) {
165 int hlen = len / 2;
166 const num::cplx* tw = twiddles[stage].data();
// One butterfly call per block of `len` consecutive elements.
167 for (int i = 0; i < n; i += len)
168 butterfly(data + i, data + i + hlen, tw, hlen);
169 }
170 }
171};
172
173// ---- one-shot functions ----------------------------------------------------
174
175inline void fft(const num::CVector& in, num::CVector& out) {
176 int n = static_cast<int>(in.size());
177 for (int i = 0; i < n; ++i) out[i] = in[i];
178 FFTPlanImpl plan(n, false);
179 plan.execute(out);
180}
181
182inline void ifft(const num::CVector& in, num::CVector& out) {
183 int n = static_cast<int>(in.size());
184 for (int i = 0; i < n; ++i) out[i] = in[i];
185 FFTPlanImpl plan(n, true);
186 plan.execute(out);
187}
188
189inline void rfft(const num::Vector& in, num::CVector& out) {
190 int n = static_cast<int>(in.size());
191 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
192 for (int i = 0; i < n; ++i) tmp[i] = {in[i], 0.0};
193 FFTPlanImpl plan(n, false);
194 plan.execute(tmp);
195 for (int k = 0; k < n / 2 + 1; ++k) out[k] = tmp[k];
196}
197
198inline void irfft(const num::CVector& in, int n, num::Vector& out) {
199 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
200 for (int k = 0; k < n / 2 + 1; ++k) tmp[k] = in[k];
201 for (int k = 1; k < (n - 1) / 2 + 1; ++k)
202 tmp[n - k] = std::conj(in[k]);
203 FFTPlanImpl plan(n, true);
204 plan.execute(tmp);
205 for (int i = 0; i < n; ++i) out[i] = tmp[i].real();
206}
207
208} // namespace opt
209} // namespace backends
Dense vector with optional GPU storage, templated over scalar type T.
Definition vector.hpp:23
constexpr idx size() const noexcept
Definition vector.hpp:77
FFT interface with backend dispatch.
void irfft(const num::CVector &in, int n, num::Vector &out)
Definition impl.hpp:198
void ifft(const num::CVector &in, num::CVector &out)
Definition impl.hpp:182
void fft(const num::CVector &in, num::CVector &out)
Definition impl.hpp:175
void rfft(const num::Vector &in, num::CVector &out)
Definition impl.hpp:189
void bit_reverse(num::CVector &a)
Definition impl.hpp:15
std::size_t idx
Definition types.hpp:11
std::complex< real > cplx
Definition types.hpp:12
Precomputed twiddle factors + SIMD butterfly execution.
Definition impl.hpp:142
void execute(num::CVector &a) const
Definition impl.hpp:160
std::vector< std::vector< num::cplx > > twiddles
Definition impl.hpp:145
FFTPlanImpl(int n_, bool inv)
Definition impl.hpp:147