numerics 0.1.0
Loading...
Searching...
No Matches
impl.hpp
Go to the documentation of this file.
1/// @file spectral/backends/opt/impl.hpp
2/// @brief Handwritten AVX2 / NEON butterfly for the FFT.
3///
/// Processes 2 complex butterflies per SIMD iteration using 256-bit (AVX2) or
/// 128-bit (NEON) registers, with a scalar tail for an odd remaining pair.
/// When neither ISA extension is enabled at compile time, the portable scalar
/// butterfly defined in this file is used instead.
7///
8/// AVX2 complex butterfly (2 pairs per __m256d):
9/// u = [ur0, ui0, ur1, ui1]
10/// v = [vr0, vi0, vr1, vi1]
11/// w = [wr0, wi0, wr1, wi1]
12///
13/// w_re = unpacklo(w, w) -> [wr0, wr0, wr1, wr1]
14/// w_im = unpackhi(w, w) -> [wi0, wi0, wi1, wi1]
15/// v_sw = permute(v, 0b0101) -> [vi0, vr0, vi1, vr1] (swap re/im)
16/// t = addsub(v*w_re, v_sw*w_im)
17/// = [vr0*wr0 - vi0*wi0, vr0*wi0 + vi0*wr0,
18/// vr1*wr1 - vi1*wi1, vr1*wi1 + vi1*wr1]
19///
20/// NEON deinterleaves with vld2q_f64 (SoA load), multiplies component-wise,
21/// and re-interleaves with vst2q_f64.
22#pragma once
23#include "spectral/fft.hpp"
24#include "../seq/impl.hpp"
25#include <cmath>
26#include <stdexcept>
27#include <vector>
28
29#ifdef NUMERICS_HAS_AVX2
30 #include <immintrin.h>
31#endif
32#ifdef NUMERICS_HAS_NEON
33 #include <arm_neon.h>
34#endif
35
36namespace backends {
37namespace opt {
38
/// One full turn in radians (2*pi); used to build the twiddle-factor angles.
/// A const namespace-scope variable already has internal linkage, so no
/// `static` is needed for it to live safely in a header.
constexpr double TWO_PI = 6.283185307179586476925286766559;
40
41// Platform butterfly implementations
42
43#ifdef NUMERICS_HAS_AVX2
44
45/// Butterfly on hlen complex pairs using AVX2 (2 pairs per iteration).
46static inline void butterfly(num::cplx* __restrict__ u,
47 num::cplx* __restrict__ v,
48 const num::cplx* __restrict__ tw,
49 int hlen) {
50 auto* ud = reinterpret_cast<double*>(u);
51 auto* vd = reinterpret_cast<double*>(v);
52 auto* wd = reinterpret_cast<const double*>(tw);
53
54 int j = 0;
55 for (; j + 1 < hlen; j += 2) {
56 __m256d U = _mm256_loadu_pd(ud + 2 * j);
57 __m256d V = _mm256_loadu_pd(vd + 2 * j);
58 __m256d W = _mm256_loadu_pd(wd + 2 * j);
59 // complex multiply V * W
60 __m256d Wre = _mm256_unpacklo_pd(W, W); // [wr0,wr0,wr1,wr1]
61 __m256d Wim = _mm256_unpackhi_pd(W, W); // [wi0,wi0,wi1,wi1]
62 __m256d Vsw = _mm256_permute_pd(V, 0x5); // [vi0,vr0,vi1,vr1]
63 __m256d T = _mm256_addsub_pd(_mm256_mul_pd(V, Wre),
64 _mm256_mul_pd(Vsw, Wim));
65 _mm256_storeu_pd(ud + 2 * j, _mm256_add_pd(U, T));
66 _mm256_storeu_pd(vd + 2 * j, _mm256_sub_pd(U, T));
67 }
68 // scalar tail for odd hlen
69 for (; j < hlen; ++j) {
70 num::cplx t = v[j] * tw[j];
71 num::cplx uu = u[j];
72 u[j] = uu + t;
73 v[j] = uu - t;
74 }
75}
76
77#elif defined(NUMERICS_HAS_NEON)
78
79/// Butterfly on hlen complex pairs using NEON vld2/vst2 (SoA deinterleave).
80static inline void butterfly(num::cplx* __restrict__ u,
81 num::cplx* __restrict__ v,
82 const num::cplx* __restrict__ tw,
83 int hlen) {
84 auto* ud = reinterpret_cast<double*>(u);
85 auto* vd = reinterpret_cast<double*>(v);
86 auto* wd = reinterpret_cast<const double*>(tw);
87
88 int j = 0;
89 for (; j + 1 < hlen; j += 2) {
90 // Deinterleaved load: .val[0] = [re0,re1], .val[1] = [im0,im1]
91 float64x2x2_t U = vld2q_f64(ud + 2 * j);
92 float64x2x2_t V = vld2q_f64(vd + 2 * j);
93 float64x2x2_t W = vld2q_f64(wd + 2 * j);
94
95 // T = V * W (complex multiply, component-wise on SoA data)
96 // Tr = Vr*Wr - Vi*Wi
97 float64x2_t Tr =
98 vfmsq_f64(vmulq_f64(V.val[0], W.val[0]), V.val[1], W.val[1]);
99 // Ti = Vr*Wi + Vi*Wr
100 float64x2_t Ti =
101 vfmaq_f64(vmulq_f64(V.val[0], W.val[1]), V.val[1], W.val[0]);
102
103 float64x2x2_t Ru, Rv;
104 Ru.val[0] = vaddq_f64(U.val[0], Tr);
105 Ru.val[1] = vaddq_f64(U.val[1], Ti);
106 Rv.val[0] = vsubq_f64(U.val[0], Tr);
107 Rv.val[1] = vsubq_f64(U.val[1], Ti);
108 vst2q_f64(ud + 2 * j, Ru);
109 vst2q_f64(vd + 2 * j, Rv);
110 }
111 // scalar tail
112 for (; j < hlen; ++j) {
113 num::cplx t = v[j] * tw[j];
114 num::cplx uu = u[j];
115 u[j] = uu + t;
116 v[j] = uu - t;
117 }
118}
119
120#else
121
122/// Scalar fallback when no SIMD ISA is available.
123static inline void butterfly(num::cplx* u,
124 num::cplx* v,
125 const num::cplx* tw,
126 int hlen) {
127 for (int j = 0; j < hlen; ++j) {
128 num::cplx t = v[j] * tw[j];
129 num::cplx uu = u[j];
130 u[j] = uu + t;
131 v[j] = uu - t;
132 }
133}
134
135#endif // NUMERICS_HAS_AVX2 / NEON
136
137// FFTPlanImpl
138
139/// Precomputed twiddle factors + SIMD butterfly execution.
141 int n;
142 bool invert;
143 std::vector<std::vector<num::cplx>> twiddles;
144
145 FFTPlanImpl(int n_, bool inv)
146 : n(n_)
147 , invert(inv) {
148 if (n_ == 0 || (n_ & (n_ - 1)))
149 throw std::invalid_argument(
150 "FFTPlan: length must be a power of two");
151 for (int len = 2; len <= n_; len <<= 1) {
152 double ang = TWO_PI / static_cast<double>(len) * (inv ? 1.0 : -1.0);
153 num::cplx wlen{std::cos(ang), std::sin(ang)};
154 std::vector<num::cplx> tw(len / 2);
155 num::cplx w{1.0, 0.0};
156 for (int j = 0; j < len / 2; ++j) {
157 tw[j] = w;
158 w *= wlen;
159 }
160 twiddles.push_back(std::move(tw));
161 }
162 }
163
164 void execute(num::CVector& a) const {
166 num::cplx* data = a.data();
167 int stage = 0;
168 for (int len = 2; len <= n; len <<= 1, ++stage) {
169 int hlen = len / 2;
170 const num::cplx* tw = twiddles[stage].data();
171 for (int i = 0; i < n; i += len)
172 butterfly(data + i, data + i + hlen, tw, hlen);
173 }
174 }
175};
176
177// One-shot functions
178
179inline void fft(const num::CVector& in, num::CVector& out) {
180 int n = static_cast<int>(in.size());
181 for (int i = 0; i < n; ++i)
182 out[i] = in[i];
183 FFTPlanImpl plan(n, false);
184 plan.execute(out);
185}
186
187inline void ifft(const num::CVector& in, num::CVector& out) {
188 int n = static_cast<int>(in.size());
189 for (int i = 0; i < n; ++i)
190 out[i] = in[i];
191 FFTPlanImpl plan(n, true);
192 plan.execute(out);
193}
194
195inline void rfft(const num::Vector& in, num::CVector& out) {
196 int n = static_cast<int>(in.size());
197 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
198 for (int i = 0; i < n; ++i)
199 tmp[i] = {in[i], 0.0};
200 FFTPlanImpl plan(n, false);
201 plan.execute(tmp);
202 for (int k = 0; k < n / 2 + 1; ++k)
203 out[k] = tmp[k];
204}
205
206inline void irfft(const num::CVector& in, int n, num::Vector& out) {
207 num::CVector tmp(static_cast<num::idx>(n), num::cplx{0, 0});
208 for (int k = 0; k < n / 2 + 1; ++k)
209 tmp[k] = in[k];
210 for (int k = 1; k < (n - 1) / 2 + 1; ++k)
211 tmp[n - k] = std::conj(in[k]);
212 FFTPlanImpl plan(n, true);
213 plan.execute(tmp);
214 for (int i = 0; i < n; ++i)
215 out[i] = tmp[i].real();
216}
217
218} // namespace opt
219} // namespace backends
Dense vector with optional GPU storage, templated over scalar type T.
Definition vector.hpp:24
constexpr idx size() const noexcept
Definition vector.hpp:80
FFT interface with backend dispatch.
void irfft(const num::CVector &in, int n, num::Vector &out)
Definition impl.hpp:206
void ifft(const num::CVector &in, num::CVector &out)
Definition impl.hpp:187
void fft(const num::CVector &in, num::CVector &out)
Definition impl.hpp:179
void rfft(const num::Vector &in, num::CVector &out)
Definition impl.hpp:195
void bit_reverse(num::CVector &a)
Definition impl.hpp:15
std::size_t idx
Definition types.hpp:11
std::complex< real > cplx
Definition types.hpp:12
Precomputed twiddle factors + SIMD butterfly execution.
Definition impl.hpp:140
void execute(num::CVector &a) const
Definition impl.hpp:164
std::vector< std::vector< num::cplx > > twiddles
Definition impl.hpp:143
FFTPlanImpl(int n_, bool inv)
Definition impl.hpp:145