29static_assert(
sizeof(
real) == 8,
"SIMD kernels require real == double");
32#ifdef NUMERICS_HAS_AVX2
// ---------------------------------------------------------------------------
// avx_tile_4x4: accumulate a 4-row x 4-column tile of C += A * B over the
// k-range [kk, k_lim) using AVX2 fused multiply-add.
//
// NOTE(review): this chunk is a lossy extraction. The leading numerals
// (34, 44, 45, ...) are original source line numbers fused into the text,
// and several original lines are missing entirely: the remainder of the
// parameter list after `const Matrix& A` (presumably B, C, ir, jr, kk,
// k_lim), the declarations of `Crow` and `N`, and the closing braces.
// Restore from the original file before editing; do not treat this text
// as compilable.
// ---------------------------------------------------------------------------
34static inline void avx_tile_4x4(
const Matrix& A,
// Load the current 4x4 tile of C (row stride N doubles) into four 256-bit
// accumulators. Unaligned loads, so C needs no special alignment.
44 __m256d c0 = _mm256_loadu_pd(Crow + 0 * N + jr);
45 __m256d c1 = _mm256_loadu_pd(Crow + 1 * N + jr);
46 __m256d c2 = _mm256_loadu_pd(Crow + 2 * N + jr);
47 __m256d c3 = _mm256_loadu_pd(Crow + 3 * N + jr);
// Per k: broadcast the scalar A(ir + r, k) and FMA it against four packed
// elements of row k of B (a rank-1 update of the tile).
49 for (
idx k = kk; k < k_lim; ++k) {
50 __m256d b = _mm256_loadu_pd(B.
data() + k * N + jr);
51 c0 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 0, k)), b, c0);
52 c1 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 1, k)), b, c1);
53 c2 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 2, k)), b, c2);
54 c3 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 3, k)), b, c3);
// Write the updated tile back to C.
57 _mm256_storeu_pd(Crow + 0 * N + jr, c0);
58 _mm256_storeu_pd(Crow + 1 * N + jr, c1);
59 _mm256_storeu_pd(Crow + 2 * N + jr, c2);
60 _mm256_storeu_pd(Crow + 3 * N + jr, c3);
// ---------------------------------------------------------------------------
// matmul_avx: cache-blocked dense GEMM, C = A * B, dispatching 4x4
// micro-tiles to avx_tile_4x4 and falling back to scalar loops for the
// row/column remainders of each block.
//
// NOTE(review): lossy extraction — leading numerals are fused original line
// numbers, and several lines are missing: the rest of the signature after
// `const Matrix& A` (presumably B, C, block_size), the M/N/K dimension
// declarations, the initializations of `ir` and `jr` (original lines 76/78),
// the declaration of `b` used below (likely `const real b = B(k, jr);`,
// original line 85), the write-back of c0..c3 into C, and closing braces.
// Confirm against the original file; not compilable as shown.
// ---------------------------------------------------------------------------
63static void matmul_avx(
const Matrix& A,
// C is zero-initialized so the k-blocked loop can accumulate into it.
68 std::fill_n(C.
data(), M * N,
real(0));
// Classic three-level blocking over rows (ii), columns (jj), depth (kk);
// std::min clamps each block at the matrix edge.
70 for (
idx ii = 0; ii < M; ii += block_size) {
71 const idx i_lim = std::min(ii + block_size, M);
72 for (
idx jj = 0; jj < N; jj += block_size) {
73 const idx j_lim = std::min(jj + block_size, N);
74 for (
idx kk = 0; kk < K; kk += block_size) {
75 const idx k_lim = std::min(kk + block_size, K);
// Full 4-row strips: SIMD 4x4 tiles, then a scalar column remainder
// that still processes 4 rows at once.
77 for (; ir + 4 <= i_lim; ir += 4) {
79 for (; jr + 4 <= j_lim; jr += 4)
80 avx_tile_4x4(A, B, C, ir, jr, kk, k_lim);
81 for (; jr < j_lim; ++jr) {
82 real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
83 real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
84 for (
idx k = kk; k < k_lim; ++k) {
// `b` is declared on a dropped line (original 85) — presumably
// the shared column element B(k, jr); verify against the original.
86 c0 += A(ir + 0, k) * b;
87 c1 += A(ir + 1, k) * b;
88 c2 += A(ir + 2, k) * b;
89 c3 += A(ir + 3, k) * b;
// Leftover rows (< 4): plain scalar triple loop over the block.
97 for (; ir < i_lim; ++ir) {
98 for (
idx k = kk; k < k_lim; ++k) {
99 const real a_ik = A(ir, k);
100 for (
idx j = jj; j < j_lim; ++j)
101 C(ir, j) += a_ik * B(k, j);
// ---------------------------------------------------------------------------
// Fragment of an AVX2 matrix-vector product (presumably matvec_avx; the
// function signature on the dropped original lines ~107-110 is not visible
// here — confirm against the original file). For each row i: accumulate
// 4-wide dot-product partials, horizontally reduce, then finish the tail
// scalar-wise. The store of `result` into y[i] is also on a dropped line.
// ---------------------------------------------------------------------------
111 for (
idx i = 0; i < M; ++i) {
112 __m256d acc = _mm256_setzero_pd();
// `j` is declared on a dropped line (original 113). Vectorized dot-product
// over 4 columns per iteration, assuming row-major A with row stride N.
114 for (; j + 4 <= N; j += 4) {
115 __m256d a = _mm256_loadu_pd(A.
data() + i * N + j);
116 __m256d xv = _mm256_loadu_pd(x.
data() + j);
117 acc = _mm256_fmadd_pd(a, xv, acc);
// Horizontal sum of the 4 accumulator lanes: split into two 128-bit
// halves, add them, then hadd the remaining pair into lane 0.
119 __m128d lo = _mm256_castpd256_pd128(acc);
120 __m128d hi = _mm256_extractf128_pd(acc, 1);
121 __m128d sum = _mm_add_pd(lo, hi);
122 sum = _mm_hadd_pd(sum, sum);
123 real result = _mm_cvtsd_f64(sum);
// Scalar tail for the last N % 4 columns (loop header on a dropped line).
125 result += A(i, j) * x[j];
133#ifdef NUMERICS_HAS_NEON
// ---------------------------------------------------------------------------
// neon_tile_4x4: NEON counterpart of avx_tile_4x4 — accumulate a 4x4 tile
// of C += A * B over [kk, k_lim). NEON's float64x2_t holds only 2 doubles,
// so each 4-wide row of the tile is kept as a lo/hi register pair.
//
// NOTE(review): lossy extraction — leading numerals are fused original line
// numbers; missing lines include the parameter list after `const Matrix& A`
// (presumably B, C, ir, jr, kk, k_lim), the declarations of `Crow` and `N`,
// and the closing braces. Confirm against the original; not compilable.
// ---------------------------------------------------------------------------
135static inline void neon_tile_4x4(
const Matrix& A,
// Load the 4x4 tile of C: two 128-bit halves per row.
145 float64x2_t c0lo = vld1q_f64(Crow + 0 * N + jr);
146 float64x2_t c0hi = vld1q_f64(Crow + 0 * N + jr + 2);
147 float64x2_t c1lo = vld1q_f64(Crow + 1 * N + jr);
148 float64x2_t c1hi = vld1q_f64(Crow + 1 * N + jr + 2);
149 float64x2_t c2lo = vld1q_f64(Crow + 2 * N + jr);
150 float64x2_t c2hi = vld1q_f64(Crow + 2 * N + jr + 2);
151 float64x2_t c3lo = vld1q_f64(Crow + 3 * N + jr);
152 float64x2_t c3hi = vld1q_f64(Crow + 3 * N + jr + 2);
// Per k: load 4 elements of B's row k (two halves), broadcast the four
// A(ir + r, k) scalars, and FMA into each half of each row (vfmaq_f64).
154 for (
idx k = kk; k < k_lim; ++k) {
155 const real* Brow = B.
data() + k * N + jr;
156 float64x2_t blo = vld1q_f64(Brow), bhi = vld1q_f64(Brow + 2);
157 float64x2_t a0 = vdupq_n_f64(A(ir + 0, k)),
158 a1 = vdupq_n_f64(A(ir + 1, k));
159 float64x2_t a2 = vdupq_n_f64(A(ir + 2, k)),
160 a3 = vdupq_n_f64(A(ir + 3, k));
161 c0lo = vfmaq_f64(c0lo, a0, blo);
162 c0hi = vfmaq_f64(c0hi, a0, bhi);
163 c1lo = vfmaq_f64(c1lo, a1, blo);
164 c1hi = vfmaq_f64(c1hi, a1, bhi);
165 c2lo = vfmaq_f64(c2lo, a2, blo);
166 c2hi = vfmaq_f64(c2hi, a2, bhi);
167 c3lo = vfmaq_f64(c3lo, a3, blo);
168 c3hi = vfmaq_f64(c3hi, a3, bhi);
// Write the updated tile back to C.
171 vst1q_f64(Crow + 0 * N + jr, c0lo);
172 vst1q_f64(Crow + 0 * N + jr + 2, c0hi);
173 vst1q_f64(Crow + 1 * N + jr, c1lo);
174 vst1q_f64(Crow + 1 * N + jr + 2, c1hi);
175 vst1q_f64(Crow + 2 * N + jr, c2lo);
176 vst1q_f64(Crow + 2 * N + jr + 2, c2hi);
177 vst1q_f64(Crow + 3 * N + jr, c3lo);
178 vst1q_f64(Crow + 3 * N + jr + 2, c3hi);
// ---------------------------------------------------------------------------
// matmul_neon: cache-blocked GEMM, C = A * B — structurally identical to
// matmul_avx above but dispatching micro-tiles to neon_tile_4x4.
//
// NOTE(review): lossy extraction — leading numerals are fused original line
// numbers; missing lines include the rest of the signature after
// `const Matrix& A`, the M/N/K declarations, the `ir`/`jr` initializations
// (original lines ~194/196), the declaration of `b` used below (likely
// `const real b = B(k, jr);`, original line ~203), the write-back of
// c0..c3 into C, and closing braces. Confirm against the original file.
// ---------------------------------------------------------------------------
181static void matmul_neon(
const Matrix& A,
// Zero C so the k-blocked loop can accumulate.
186 std::fill_n(C.
data(), M * N,
real(0));
// Three-level blocking; std::min clamps each block at the matrix edge.
188 for (
idx ii = 0; ii < M; ii += block_size) {
189 const idx i_lim = std::min(ii + block_size, M);
190 for (
idx jj = 0; jj < N; jj += block_size) {
191 const idx j_lim = std::min(jj + block_size, N);
192 for (
idx kk = 0; kk < K; kk += block_size) {
193 const idx k_lim = std::min(kk + block_size, K);
// Full 4-row strips: SIMD tiles, then scalar column remainder.
195 for (; ir + 4 <= i_lim; ir += 4) {
197 for (; jr + 4 <= j_lim; jr += 4)
198 neon_tile_4x4(A, B, C, ir, jr, kk, k_lim);
199 for (; jr < j_lim; ++jr) {
200 real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
201 real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
202 for (
idx k = kk; k < k_lim; ++k) {
// `b` declared on a dropped line — presumably B(k, jr); verify.
204 c0 += A(ir + 0, k) * b;
205 c1 += A(ir + 1, k) * b;
206 c2 += A(ir + 2, k) * b;
207 c3 += A(ir + 3, k) * b;
// Leftover rows (< 4): scalar triple loop over the block.
215 for (; ir < i_lim; ++ir) {
216 for (
idx k = kk; k < k_lim; ++k) {
217 const real a_ik = A(ir, k);
218 for (
idx j = jj; j < j_lim; ++j)
219 C(ir, j) += a_ik * B(k, j);
// ---------------------------------------------------------------------------
// Fragment of a NEON matrix-vector product (presumably matvec_neon; the
// function signature on dropped original lines ~225-228 is not visible —
// confirm against the original file). Per row: 2-wide FMA accumulation,
// lane-extract reduction, then a scalar tail. The store of `result` into
// y[i] is also on a dropped line.
// ---------------------------------------------------------------------------
229 for (
idx i = 0; i < M; ++i) {
230 float64x2_t acc = vdupq_n_f64(0.0);
// `j` is declared on a dropped line (original ~231). Two columns per
// iteration, assuming row-major A with row stride N.
232 for (; j + 2 <= N; j += 2) {
233 float64x2_t a = vld1q_f64(A.
data() + i * N + j);
234 float64x2_t xv = vld1q_f64(x.
data() + j);
235 acc = vfmaq_f64(acc, a, xv);
// Horizontal sum: add the two lanes of the accumulator.
237 real result = vgetq_lane_f64(acc, 0) + vgetq_lane_f64(acc, 1);
// Scalar tail for an odd trailing column (loop header on a dropped line).
239 result += A(i, j) * x[j];
// ---------------------------------------------------------------------------
// Compile-time dispatch fragments: route matmul / matvec entry points to the
// AVX2 or NEON kernel chosen at build time.
//
// NOTE(review): lossy extraction — the enclosing function signatures, any
// scalar fallback (#else) branches, the #endif lines, and the matvec AVX
// call (original line ~260, between the two #if/#elif pairs below) were all
// dropped. Confirm against the original file before editing.
// ---------------------------------------------------------------------------
249#if defined(NUMERICS_HAS_AVX2)
250 matmul_avx(A, B, C, block_size);
251#elif defined(NUMERICS_HAS_NEON)
252 matmul_neon(A, B, C, block_size);
259#if defined(NUMERICS_HAS_AVX2)
261#elif defined(NUMERICS_HAS_NEON)
262 matvec_neon(A, x, y);