numerics 0.1.0
Loading...
Searching...
No Matches
matrix.cpp
Go to the documentation of this file.
1/// @file core/backends/opt/matrix.cpp
2/// @brief SIMD backend -- hand-written AVX2/NEON matmul and matvec
3///
4/// Compile-time dispatch:
5/// NUMERICS_HAS_AVX2 -> AVX-256 + FMA (x86-64, 4 doubles/register)
6/// NUMERICS_HAS_NEON -> ARM NEON (AArch64, 2 doubles/register)
7/// neither -> falls back to cache-blocked scalar
8///
9/// Both backends use the same register-tile structure:
10/// Outer cache tile: ii -> jj -> kk (B tile stays in L2)
11/// Inner reg tile: 4 rows x 4 cols (AVX: 4 YMM regs; NEON: 8 Q-regs)
12/// Hot k loop: one vector FMA per row, zero loop overhead for j
13
14#include "core/matrix.hpp"
15#include "../seq/impl.hpp"
16#include <algorithm>
17#include <cassert>
18
19#ifdef NUMERICS_HAS_AVX2
20 #include <immintrin.h>
21#endif
22
23#ifdef NUMERICS_HAS_NEON
24 #include <arm_neon.h>
25#endif
26
28
29static_assert(sizeof(real) == 8, "SIMD kernels require real == double");
30
31// AVX-256 backend
32#ifdef NUMERICS_HAS_AVX2
33
34static inline void avx_tile_4x4(const Matrix& A,
35 const Matrix& B,
36 Matrix& C,
37 idx ir,
38 idx jr,
39 idx kk,
40 idx k_lim) {
41 const idx N = B.cols();
42 real* Crow = C.data() + ir * N;
43
44 __m256d c0 = _mm256_loadu_pd(Crow + 0 * N + jr);
45 __m256d c1 = _mm256_loadu_pd(Crow + 1 * N + jr);
46 __m256d c2 = _mm256_loadu_pd(Crow + 2 * N + jr);
47 __m256d c3 = _mm256_loadu_pd(Crow + 3 * N + jr);
48
49 for (idx k = kk; k < k_lim; ++k) {
50 __m256d b = _mm256_loadu_pd(B.data() + k * N + jr);
51 c0 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 0, k)), b, c0);
52 c1 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 1, k)), b, c1);
53 c2 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 2, k)), b, c2);
54 c3 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 3, k)), b, c3);
55 }
56
57 _mm256_storeu_pd(Crow + 0 * N + jr, c0);
58 _mm256_storeu_pd(Crow + 1 * N + jr, c1);
59 _mm256_storeu_pd(Crow + 2 * N + jr, c2);
60 _mm256_storeu_pd(Crow + 3 * N + jr, c3);
61}
62
/// Cache-blocked matmul with the AVX2 register-tile kernel: C = A * B.
///
/// Loop nest: ii -> jj -> kk cache tiles, then 4x4 register tiles over
/// (ir, jr) inside each tile. Partial tiles fall through to scalar
/// remainder loops, so any M, N, K and block_size are handled.
static void matmul_avx(const Matrix& A,
                       const Matrix& B,
                       Matrix& C,
                       idx block_size) {
    const idx M = A.rows(), K = A.cols(), N = B.cols();
    // C accumulates across successive kk tiles, so it must start at zero.
    std::fill_n(C.data(), M * N, real(0));

    for (idx ii = 0; ii < M; ii += block_size) {
        const idx i_lim = std::min(ii + block_size, M);
        for (idx jj = 0; jj < N; jj += block_size) {
            const idx j_lim = std::min(jj + block_size, N);
            for (idx kk = 0; kk < K; kk += block_size) {
                const idx k_lim = std::min(kk + block_size, K);
                idx ir = ii;
                // Full 4-row strips: vectorized 4x4 tiles first, then a
                // scalar column remainder that still reuses all 4 A rows.
                for (; ir + 4 <= i_lim; ir += 4) {
                    idx jr = jj;
                    for (; jr + 4 <= j_lim; jr += 4)
                        avx_tile_4x4(A, B, C, ir, jr, kk, k_lim);
                    for (; jr < j_lim; ++jr) {
                        real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
                        real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
                        for (idx k = kk; k < k_lim; ++k) {
                            real b = B(k, jr);
                            c0 += A(ir + 0, k) * b;
                            c1 += A(ir + 1, k) * b;
                            c2 += A(ir + 2, k) * b;
                            c3 += A(ir + 3, k) * b;
                        }
                        C(ir + 0, jr) = c0;
                        C(ir + 1, jr) = c1;
                        C(ir + 2, jr) = c2;
                        C(ir + 3, jr) = c3;
                    }
                }
                // Leftover rows (< 4): plain scalar i-k-j loop over the tile.
                for (; ir < i_lim; ++ir) {
                    for (idx k = kk; k < k_lim; ++k) {
                        const real a_ik = A(ir, k);
                        for (idx j = jj; j < j_lim; ++j)
                            C(ir, j) += a_ik * B(k, j);
                    }
                }
            }
        }
    }
}
108
109static void matvec_avx(const Matrix& A, const Vector& x, Vector& y) {
110 const idx M = A.rows(), N = A.cols();
111 for (idx i = 0; i < M; ++i) {
112 __m256d acc = _mm256_setzero_pd();
113 idx j = 0;
114 for (; j + 4 <= N; j += 4) {
115 __m256d a = _mm256_loadu_pd(A.data() + i * N + j);
116 __m256d xv = _mm256_loadu_pd(x.data() + j);
117 acc = _mm256_fmadd_pd(a, xv, acc);
118 }
119 __m128d lo = _mm256_castpd256_pd128(acc);
120 __m128d hi = _mm256_extractf128_pd(acc, 1);
121 __m128d sum = _mm_add_pd(lo, hi);
122 sum = _mm_hadd_pd(sum, sum);
123 real result = _mm_cvtsd_f64(sum);
124 for (; j < N; ++j)
125 result += A(i, j) * x[j];
126 y[i] = result;
127 }
128}
129
130#endif // NUMERICS_HAS_AVX2
131
132// ARM NEON backend
133#ifdef NUMERICS_HAS_NEON
134
/// Accumulate a 4x4 register tile of C += A * B over k in [kk, k_lim),
/// NEON version.
///
/// Each 4-double row segment of C spans two 128-bit Q registers (a lo/hi
/// pair), so the tile occupies eight accumulators in total.
static inline void neon_tile_4x4(const Matrix& A,
                                 const Matrix& B,
                                 Matrix& C,
                                 idx ir,
                                 idx jr,
                                 idx kk,
                                 idx k_lim) {
    const idx ldc = B.cols();
    real* base = C.data() + ir * ldc + jr;

    float64x2_t lo[4], hi[4];
    for (int r = 0; r < 4; ++r) {
        lo[r] = vld1q_f64(base + r * ldc);
        hi[r] = vld1q_f64(base + r * ldc + 2);
    }

    for (idx k = kk; k < k_lim; ++k) {
        const real* bptr = B.data() + k * ldc + jr;
        const float64x2_t blo = vld1q_f64(bptr);
        const float64x2_t bhi = vld1q_f64(bptr + 2);
        // Broadcast A(row, k) once per row and feed both C halves; each
        // accumulator sees the same op sequence as the unrolled form.
        for (int r = 0; r < 4; ++r) {
            const float64x2_t a = vdupq_n_f64(A(ir + r, k));
            lo[r] = vfmaq_f64(lo[r], a, blo);
            hi[r] = vfmaq_f64(hi[r], a, bhi);
        }
    }

    for (int r = 0; r < 4; ++r) {
        vst1q_f64(base + r * ldc, lo[r]);
        vst1q_f64(base + r * ldc + 2, hi[r]);
    }
}
180
/// Cache-blocked matmul with the NEON register-tile kernel: C = A * B.
///
/// Identical tiling structure to matmul_avx: ii -> jj -> kk cache tiles,
/// 4x4 register tiles over (ir, jr), scalar remainder loops for partial
/// tiles, so any M, N, K and block_size are handled.
static void matmul_neon(const Matrix& A,
                        const Matrix& B,
                        Matrix& C,
                        idx block_size) {
    const idx M = A.rows(), K = A.cols(), N = B.cols();
    // C accumulates across successive kk tiles, so it must start at zero.
    std::fill_n(C.data(), M * N, real(0));

    for (idx ii = 0; ii < M; ii += block_size) {
        const idx i_lim = std::min(ii + block_size, M);
        for (idx jj = 0; jj < N; jj += block_size) {
            const idx j_lim = std::min(jj + block_size, N);
            for (idx kk = 0; kk < K; kk += block_size) {
                const idx k_lim = std::min(kk + block_size, K);
                idx ir = ii;
                // Full 4-row strips: vectorized 4x4 tiles first, then a
                // scalar column remainder that still reuses all 4 A rows.
                for (; ir + 4 <= i_lim; ir += 4) {
                    idx jr = jj;
                    for (; jr + 4 <= j_lim; jr += 4)
                        neon_tile_4x4(A, B, C, ir, jr, kk, k_lim);
                    for (; jr < j_lim; ++jr) {
                        real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
                        real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
                        for (idx k = kk; k < k_lim; ++k) {
                            real b = B(k, jr);
                            c0 += A(ir + 0, k) * b;
                            c1 += A(ir + 1, k) * b;
                            c2 += A(ir + 2, k) * b;
                            c3 += A(ir + 3, k) * b;
                        }
                        C(ir + 0, jr) = c0;
                        C(ir + 1, jr) = c1;
                        C(ir + 2, jr) = c2;
                        C(ir + 3, jr) = c3;
                    }
                }
                // Leftover rows (< 4): plain scalar i-k-j loop over the tile.
                for (; ir < i_lim; ++ir) {
                    for (idx k = kk; k < k_lim; ++k) {
                        const real a_ik = A(ir, k);
                        for (idx j = jj; j < j_lim; ++j)
                            C(ir, j) += a_ik * B(k, j);
                    }
                }
            }
        }
    }
}
226
/// y = A * x using 2-wide NEON FMA accumulation per row.
static void matvec_neon(const Matrix& A, const Vector& x, Vector& y) {
    const idx rows = A.rows(), cols = A.cols();
    for (idx r = 0; r < rows; ++r) {
        const real* row = A.data() + r * cols;
        float64x2_t vsum = vdupq_n_f64(0.0);
        idx c = 0;
        while (c + 2 <= cols) {
            vsum = vfmaq_f64(vsum, vld1q_f64(row + c), vld1q_f64(x.data() + c));
            c += 2;
        }
        // Reduce the two lanes, then finish any odd trailing column.
        real dot = vgetq_lane_f64(vsum, 0) + vgetq_lane_f64(vsum, 1);
        for (; c < cols; ++c)
            dot += A(r, c) * x[c];
        y[r] = dot;
    }
}
243
244#endif // NUMERICS_HAS_NEON
245
246// Implementations -- compile-time dispatch to best available SIMD backend
247
/// Public matmul entry point: dispatches at compile time to the best
/// available SIMD backend, falling back to the cache-blocked sequential
/// kernel when neither NUMERICS_HAS_AVX2 nor NUMERICS_HAS_NEON is defined.
void matmul(const Matrix& A, const Matrix& B, Matrix& C, idx block_size) {
#if defined(NUMERICS_HAS_AVX2)
    matmul_avx(A, B, C, block_size);
#elif defined(NUMERICS_HAS_NEON)
    matmul_neon(A, B, C, block_size);
#else
    num::backends::seq::matmul_blocked(A, B, C, block_size);
#endif
}
257
258void matvec(const Matrix& A, const Vector& x, Vector& y) {
259#if defined(NUMERICS_HAS_AVX2)
260 matvec_avx(A, x, y);
261#elif defined(NUMERICS_HAS_NEON)
262 matvec_neon(A, x, y);
263#else
265#endif
266}
267
268} // namespace num::backends::simd
Dense row-major matrix with optional GPU storage.
Definition matrix.hpp:12
real * data()
Definition matrix.hpp:28
constexpr idx rows() const noexcept
Definition matrix.hpp:24
constexpr idx cols() const noexcept
Definition matrix.hpp:25
Matrix operations.
void matvec(const Matrix &A, const Vector &x, Vector &y)
Definition matrix.cpp:24
void matmul_blocked(const Matrix &A, const Matrix &B, Matrix &C, idx block_size)
Definition matrix.cpp:77
void matvec(const Matrix &A, const Vector &x, Vector &y)
Definition matrix.cpp:258
void matmul(const Matrix &A, const Matrix &B, Matrix &C, idx block_size)
Definition matrix.cpp:248
double real
Definition types.hpp:10
std::size_t idx
Definition types.hpp:11