numerics 0.1.0
Loading...
Searching...
No Matches
matrix.cpp
Go to the documentation of this file.
1/// @file core/backends/opt/matrix.cpp
2/// @brief SIMD backend -- hand-written AVX2/NEON matmul and matvec
3
#include "core/matrix.hpp"
#include "../seq/impl.hpp"
#include <algorithm>
#include <cassert>
#include <type_traits>

#ifdef NUMERICS_HAS_AVX2
    #include <immintrin.h>
#endif

#ifdef NUMERICS_HAS_NEON
    #include <arm_neon.h>
#endif
16
18
19static_assert(sizeof(real) == 8, "SIMD kernels require real == double");
20
21#ifdef NUMERICS_HAS_AVX2
22
23static inline void avx_tile_4x4(const Matrix& A,
24 const Matrix& B,
25 Matrix& C,
26 idx ir,
27 idx jr,
28 idx kk,
29 idx k_lim) {
30 const idx N = B.cols();
31 real* Crow = C.data() + ir * N;
32
33 __m256d c0 = _mm256_loadu_pd(Crow + 0 * N + jr);
34 __m256d c1 = _mm256_loadu_pd(Crow + 1 * N + jr);
35 __m256d c2 = _mm256_loadu_pd(Crow + 2 * N + jr);
36 __m256d c3 = _mm256_loadu_pd(Crow + 3 * N + jr);
37
38 for (idx k = kk; k < k_lim; ++k) {
39 __m256d b = _mm256_loadu_pd(B.data() + k * N + jr);
40 c0 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 0, k)), b, c0);
41 c1 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 1, k)), b, c1);
42 c2 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 2, k)), b, c2);
43 c3 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 3, k)), b, c3);
44 }
45
46 _mm256_storeu_pd(Crow + 0 * N + jr, c0);
47 _mm256_storeu_pd(Crow + 1 * N + jr, c1);
48 _mm256_storeu_pd(Crow + 2 * N + jr, c2);
49 _mm256_storeu_pd(Crow + 3 * N + jr, c3);
50}
51
52static void matmul_avx(const Matrix& A, const Matrix& B, Matrix& C, idx block_size) {
53 const idx M = A.rows(), K = A.cols(), N = B.cols();
54 std::fill_n(C.data(), M * N, real(0));
55
56 for (idx ii = 0; ii < M; ii += block_size) {
57 const idx i_lim = std::min(ii + block_size, M);
58 for (idx jj = 0; jj < N; jj += block_size) {
59 const idx j_lim = std::min(jj + block_size, N);
60 for (idx kk = 0; kk < K; kk += block_size) {
61 const idx k_lim = std::min(kk + block_size, K);
62 idx ir = ii;
63 for (; ir + 4 <= i_lim; ir += 4) {
64 idx jr = jj;
65 for (; jr + 4 <= j_lim; jr += 4)
66 avx_tile_4x4(A, B, C, ir, jr, kk, k_lim);
67 for (; jr < j_lim; ++jr) {
68 real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
69 real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
70 for (idx k = kk; k < k_lim; ++k) {
71 real b = B(k, jr);
72 c0 += A(ir + 0, k) * b;
73 c1 += A(ir + 1, k) * b;
74 c2 += A(ir + 2, k) * b;
75 c3 += A(ir + 3, k) * b;
76 }
77 C(ir + 0, jr) = c0;
78 C(ir + 1, jr) = c1;
79 C(ir + 2, jr) = c2;
80 C(ir + 3, jr) = c3;
81 }
82 }
83 for (; ir < i_lim; ++ir) {
84 for (idx k = kk; k < k_lim; ++k) {
85 const real a_ik = A(ir, k);
86 for (idx j = jj; j < j_lim; ++j)
87 C(ir, j) += a_ik * B(k, j);
88 }
89 }
90 }
91 }
92 }
93}
94
95static void matvec_avx(const Matrix& A, const Vector& x, Vector& y) {
96 const idx M = A.rows(), N = A.cols();
97 for (idx i = 0; i < M; ++i) {
98 __m256d acc = _mm256_setzero_pd();
99 idx j = 0;
100 for (; j + 4 <= N; j += 4) {
101 __m256d a = _mm256_loadu_pd(A.data() + i * N + j);
102 __m256d xv = _mm256_loadu_pd(x.data() + j);
103 acc = _mm256_fmadd_pd(a, xv, acc);
104 }
105 __m128d lo = _mm256_castpd256_pd128(acc);
106 __m128d hi = _mm256_extractf128_pd(acc, 1);
107 __m128d sum = _mm_add_pd(lo, hi);
108 sum = _mm_hadd_pd(sum, sum);
109 real result = _mm_cvtsd_f64(sum);
110 for (; j < N; ++j)
111 result += A(i, j) * x[j];
112 y[i] = result;
113 }
114}
115
116#endif // NUMERICS_HAS_AVX2
117
118#ifdef NUMERICS_HAS_NEON
119
/// @brief NEON micro-kernel: accumulates a 4x4 tile of C += A * B.
///
/// Updates C(ir..ir+3, jr..jr+3) with the partial products for k in
/// [kk, k_lim). A 128-bit NEON register holds two doubles, so each tile row
/// is kept as a low (columns jr, jr+1) and a high (columns jr+2, jr+3)
/// accumulator.
///
/// B and C are addressed through data() with a row stride of B.cols();
/// A is read through operator().
///
/// @param A      left operand
/// @param B      right operand
/// @param C      accumulator matrix, updated in place
/// @param ir     first row of the tile
/// @param jr     first column of the tile
/// @param kk     first inner index (inclusive)
/// @param k_lim  last inner index (exclusive)
static inline void neon_tile_4x4(const Matrix& A,
                                 const Matrix& B,
                                 Matrix&       C,
                                 idx           ir,
                                 idx           jr,
                                 idx           kk,
                                 idx           k_lim) {
    const idx stride = B.cols();

    // One pointer per tile row, each already offset to column jr.
    real* const out0 = C.data() + ir * stride + jr;
    real* const out1 = out0 + stride;
    real* const out2 = out1 + stride;
    real* const out3 = out2 + stride;

    float64x2_t r0_lo = vld1q_f64(out0), r0_hi = vld1q_f64(out0 + 2);
    float64x2_t r1_lo = vld1q_f64(out1), r1_hi = vld1q_f64(out1 + 2);
    float64x2_t r2_lo = vld1q_f64(out2), r2_hi = vld1q_f64(out2 + 2);
    float64x2_t r3_lo = vld1q_f64(out3), r3_hi = vld1q_f64(out3 + 2);

    for (idx k = kk; k < k_lim; ++k) {
        // Four consecutive entries of row k of B, split into two registers.
        const real* const b_ptr = B.data() + k * stride + jr;
        const float64x2_t b_lo = vld1q_f64(b_ptr);
        const float64x2_t b_hi = vld1q_f64(b_ptr + 2);

        // One broadcast scalar of A per tile row.
        const float64x2_t s0 = vdupq_n_f64(A(ir + 0, k));
        const float64x2_t s1 = vdupq_n_f64(A(ir + 1, k));
        const float64x2_t s2 = vdupq_n_f64(A(ir + 2, k));
        const float64x2_t s3 = vdupq_n_f64(A(ir + 3, k));

        r0_lo = vfmaq_f64(r0_lo, s0, b_lo);
        r0_hi = vfmaq_f64(r0_hi, s0, b_hi);
        r1_lo = vfmaq_f64(r1_lo, s1, b_lo);
        r1_hi = vfmaq_f64(r1_hi, s1, b_hi);
        r2_lo = vfmaq_f64(r2_lo, s2, b_lo);
        r2_hi = vfmaq_f64(r2_hi, s2, b_hi);
        r3_lo = vfmaq_f64(r3_lo, s3, b_lo);
        r3_hi = vfmaq_f64(r3_hi, s3, b_hi);
    }

    vst1q_f64(out0, r0_lo);
    vst1q_f64(out0 + 2, r0_hi);
    vst1q_f64(out1, r1_lo);
    vst1q_f64(out1 + 2, r1_hi);
    vst1q_f64(out2, r2_lo);
    vst1q_f64(out2 + 2, r2_hi);
    vst1q_f64(out3, r3_lo);
    vst1q_f64(out3 + 2, r3_hi);
}
163
/// @brief Cache-blocked matrix product C = A * B built on the 4x4 NEON kernel.
///
/// C is zeroed first and then accumulated into block by block, so it must not
/// alias A or B. Within every (ii, jj, kk) block:
///   - full 4x4 tiles go through neon_tile_4x4,
///   - leftover columns of a 4-row strip use a scalar 4x1 kernel,
///   - leftover rows (fewer than 4) use a plain scalar loop.
///
/// @param A           left operand, M x K
/// @param B           right operand, K x N
/// @param C           output, M x N; overwritten
/// @param block_size  edge length of the cache blocks; 0 is treated as
///                    "one block spanning the whole product" (previously a
///                    zero block size never advanced the loops and hung)
static void matmul_neon(const Matrix& A, const Matrix& B, Matrix& C, idx block_size) {
    const idx M = A.rows(), K = A.cols(), N = B.cols();
    assert(B.rows() == K && "matmul_neon: A.cols() must equal B.rows()");
    assert(C.rows() == M && C.cols() == N && "matmul_neon: C must be A.rows() x B.cols()");

    // A zero step would make every blocking loop below spin forever.
    const idx bs = (block_size != 0) ? block_size : std::max({M, N, K, idx(1)});

    std::fill_n(C.data(), M * N, real(0));

    for (idx ii = 0; ii < M; ii += bs) {
        const idx i_lim = std::min(ii + bs, M);
        for (idx jj = 0; jj < N; jj += bs) {
            const idx j_lim = std::min(jj + bs, N);
            for (idx kk = 0; kk < K; kk += bs) {
                const idx k_lim = std::min(kk + bs, K);

                // Strips of four rows: SIMD tiles first, then leftover columns.
                idx ir = ii;
                for (; ir + 4 <= i_lim; ir += 4) {
                    idx jr = jj;
                    for (; jr + 4 <= j_lim; jr += 4)
                        neon_tile_4x4(A, B, C, ir, jr, kk, k_lim);

                    // Column remainder (< 4 columns): scalar 4x1 kernel that
                    // keeps the four running sums in locals.
                    for (; jr < j_lim; ++jr) {
                        real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
                        real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
                        for (idx k = kk; k < k_lim; ++k) {
                            const real b = B(k, jr);
                            c0 += A(ir + 0, k) * b;
                            c1 += A(ir + 1, k) * b;
                            c2 += A(ir + 2, k) * b;
                            c3 += A(ir + 3, k) * b;
                        }
                        C(ir + 0, jr) = c0;
                        C(ir + 1, jr) = c1;
                        C(ir + 2, jr) = c2;
                        C(ir + 3, jr) = c3;
                    }
                }

                // Row remainder (< 4 rows): i-k-j order so B is walked row-wise.
                for (; ir < i_lim; ++ir) {
                    for (idx k = kk; k < k_lim; ++k) {
                        const real a_ik = A(ir, k);
                        for (idx j = jj; j < j_lim; ++j)
                            C(ir, j) += a_ik * B(k, j);
                    }
                }
            }
        }
    }
}
206
/// @brief NEON matrix-vector product y = A * x.
///
/// Each row is reduced two columns at a time into a 128-bit accumulator; the
/// two lanes are then added together, and an odd trailing column (if any) is
/// added with scalar arithmetic.
///
/// @param A  matrix, read through data() with a row stride of A.cols()
/// @param x  input vector, indexed 0 .. A.cols()-1
/// @param y  output vector, entries 0 .. A.rows()-1 are overwritten
static void matvec_neon(const Matrix& A, const Vector& x, Vector& y) {
    const idx n_rows = A.rows();
    const idx n_cols = A.cols();
    const idx simd_end = n_cols - (n_cols % 2);  // columns handled by full vectors

    for (idx row = 0; row < n_rows; ++row) {
        float64x2_t partial = vdupq_n_f64(0.0);

        for (idx col = 0; col < simd_end; col += 2) {
            const float64x2_t a_vec = vld1q_f64(A.data() + row * n_cols + col);
            const float64x2_t x_vec = vld1q_f64(x.data() + col);
            partial = vfmaq_f64(partial, a_vec, x_vec);
        }

        // Collapse the two lanes into one scalar.
        const real lane0 = vgetq_lane_f64(partial, 0);
        const real lane1 = vgetq_lane_f64(partial, 1);
        real dot = lane0 + lane1;

        // Scalar tail for the last (n_cols % 2) column.
        for (idx col = simd_end; col < n_cols; ++col)
            dot += A(row, col) * x[col];

        y[row] = dot;
    }
}
223
224#endif // NUMERICS_HAS_NEON
225
226void matmul(const Matrix& A, const Matrix& B, Matrix& C, idx block_size) {
227#if defined(NUMERICS_HAS_AVX2)
228 matmul_avx(A, B, C, block_size);
229#elif defined(NUMERICS_HAS_NEON)
230 matmul_neon(A, B, C, block_size);
231#else
232 num::backends::seq::matmul_blocked(A, B, C, block_size);
233#endif
234}
235
236void matvec(const Matrix& A, const Vector& x, Vector& y) {
237#if defined(NUMERICS_HAS_AVX2)
238 matvec_avx(A, x, y);
239#elif defined(NUMERICS_HAS_NEON)
240 matvec_neon(A, x, y);
241#else
243#endif
244}
245
246} // namespace num::backends::simd
constexpr idx rows() const noexcept
Definition matrix.hpp:87
constexpr idx cols() const noexcept
Definition matrix.hpp:88
Dense row-major matrix templated over scalar type T.
void matvec(const Matrix &A, const Vector &x, Vector &y)
Definition matrix.cpp:20
void matmul_blocked(const Matrix &A, const Matrix &B, Matrix &C, idx block_size)
Definition matrix.cpp:28
void matvec(const Matrix &A, const Vector &x, Vector &y)
Definition matrix.cpp:236
void matmul(const Matrix &A, const Matrix &B, Matrix &C, idx block_size)
Definition matrix.cpp:226
double real
Definition types.hpp:10
std::size_t idx
Definition types.hpp:11