19static_assert(
sizeof(
real) == 8,
"SIMD kernels require real == double");
21#ifdef NUMERICS_HAS_AVX2
// Register-blocked 4x4 micro-kernel (AVX2): adds into the 4x4 tile of C at
// rows ir..ir+3, columns jr..jr+3, the partial product
//   sum over k in [kk, k_lim) of A(ir + r, k) * B(k, jr..jr+3).
// Only part of the parameter list is visible in this view; the blocked loop
// calls it as avx_tile_4x4(A, B, C, ir, jr, kk, k_lim).
// NOTE(review): Crow and N are set on lines not visible here. From the
// indexing (Crow + r * N + jr, B.data() + k * N + jr), Crow presumably points
// at row ir of C's row-major storage and N is the shared row stride of B and
// C. Confirm.
23static inline void avx_tile_4x4(
const Matrix& A,
// Load the four 4-wide row segments of the C tile (unaligned loads).
33 __m256d c0 = _mm256_loadu_pd(Crow + 0 * N + jr);
34 __m256d c1 = _mm256_loadu_pd(Crow + 1 * N + jr);
35 __m256d c2 = _mm256_loadu_pd(Crow + 2 * N + jr);
36 __m256d c3 = _mm256_loadu_pd(Crow + 3 * N + jr);
// One rank-1 update per k: load 4 consecutive elements of B's row k, then
// broadcast A(ir + r, k) and fused-multiply-add it into row r of the tile.
38 for (
idx k = kk; k < k_lim; ++k) {
39 __m256d b = _mm256_loadu_pd(B.
data() + k * N + jr);
40 c0 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 0, k)), b, c0);
41 c1 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 1, k)), b, c1);
42 c2 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 2, k)), b, c2);
43 c3 = _mm256_fmadd_pd(_mm256_set1_pd(A(ir + 3, k)), b, c3);
// Write the accumulated tile back to C (unaligned stores).
46 _mm256_storeu_pd(Crow + 0 * N + jr, c0);
47 _mm256_storeu_pd(Crow + 1 * N + jr, c1);
48 _mm256_storeu_pd(Crow + 2 * N + jr, c2);
49 _mm256_storeu_pd(Crow + 3 * N + jr, c3);
// Body of the AVX2 blocked matrix multiply. Its function header is not visible
// in this view; it is presumably matmul_avx, which the dispatcher calls with
// (A, B, C, block_size). The first M * N elements of C are zeroed and then
// accumulated into, so C's prior contents are overwritten.
54 std::fill_n(C.
data(), M * N,
real(0));
// Cache blocking: walk i, j and k in block_size steps. Each *_lim clamps the
// final block to the matrix edge.
// NOTE(review): block_size == 0 would stop the outer loop from ever advancing
// when M > 0. Presumably the caller rejects it; confirm.
56 for (
idx ii = 0; ii < M; ii += block_size) {
57 const idx i_lim = std::min(ii + block_size, M);
58 for (
idx jj = 0; jj < N; jj += block_size) {
59 const idx j_lim = std::min(jj + block_size, N);
60 for (
idx kk = 0; kk < K; kk += block_size) {
61 const idx k_lim = std::min(kk + block_size, K);
// Full 4-row strips. ir (and jr below) are initialised on lines not visible
// here, presumably to ii and jj; confirm.
63 for (; ir + 4 <= i_lim; ir += 4) {
// Full 4x4 tiles go to the AVX2 micro-kernel.
65 for (; jr + 4 <= j_lim; jr += 4)
66 avx_tile_4x4(A, B, C, ir, jr, kk, k_lim);
// Leftover columns (fewer than 4): scalar accumulation of the four rows, one
// column jr at a time. `b` is defined on a line not visible here, presumably
// as B(k, jr); confirm. The write-back of c0..c3 is also outside this view.
67 for (; jr < j_lim; ++jr) {
68 real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
69 real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
70 for (
idx k = kk; k < k_lim; ++k) {
72 c0 += A(ir + 0, k) * b;
73 c1 += A(ir + 1, k) * b;
74 c2 += A(ir + 2, k) * b;
75 c3 += A(ir + 3, k) * b;
// Leftover rows (fewer than 4): scalar i-k-j loop over this block's columns.
83 for (; ir < i_lim; ++ir) {
84 for (
idx k = kk; k < k_lim; ++k) {
85 const real a_ik = A(ir, k);
86 for (
idx j = jj; j < j_lim; ++j)
87 C(ir, j) += a_ik * B(k, j);
// Body of the AVX2 matrix-vector product (function header not visible in this
// view). For each row i it forms the dot product of row i of A (row-major,
// stride N) with x. Where `result` is stored is not visible; presumably y[i].
// Confirm.
97 for (
idx i = 0; i < M; ++i) {
98 __m256d acc = _mm256_setzero_pd();
// Vector loop: 4 columns per iteration, fused-multiply-add into a 4-lane
// accumulator. j is declared on a line not visible here (presumably as 0).
100 for (; j + 4 <= N; j += 4) {
101 __m256d a = _mm256_loadu_pd(A.
data() + i * N + j);
102 __m256d xv = _mm256_loadu_pd(x.
data() + j);
103 acc = _mm256_fmadd_pd(a, xv, acc);
// Horizontal sum of the 4 lanes: add the upper 128-bit half to the lower half,
// hadd the remaining two lanes, then extract lane 0 as a scalar.
105 __m128d lo = _mm256_castpd256_pd128(acc);
106 __m128d hi = _mm256_extractf128_pd(acc, 1);
107 __m128d sum = _mm_add_pd(lo, hi);
108 sum = _mm_hadd_pd(sum, sum);
109 real result = _mm_cvtsd_f64(sum);
// Scalar tail for the columns left over after the vector loop. The enclosing
// loop header is not visible here.
111 result += A(i, j) * x[j];
118#ifdef NUMERICS_HAS_NEON
// NEON counterpart of the 4x4 micro-kernel: adds into the tile of C at rows
// ir..ir+3, columns jr..jr+3, the partial product over k in [kk, k_lim).
// A float64x2_t holds two doubles, so each 4-wide row segment is kept as a lo
// half (columns jr, jr+1) and a hi half (columns jr+2, jr+3).
// Only part of the parameter list is visible; the blocked loop calls it as
// neon_tile_4x4(A, B, C, ir, jr, kk, k_lim).
// NOTE(review): Crow and N are set on lines not visible here. Presumably Crow
// points at row ir of C and N is the row stride shared by B and C; confirm.
120static inline void neon_tile_4x4(
const Matrix& A,
// Load the C tile: 4 rows, each as a (lo, hi) pair of 2-lane vectors.
130 float64x2_t c0lo = vld1q_f64(Crow + 0 * N + jr);
131 float64x2_t c0hi = vld1q_f64(Crow + 0 * N + jr + 2);
132 float64x2_t c1lo = vld1q_f64(Crow + 1 * N + jr);
133 float64x2_t c1hi = vld1q_f64(Crow + 1 * N + jr + 2);
134 float64x2_t c2lo = vld1q_f64(Crow + 2 * N + jr);
135 float64x2_t c2hi = vld1q_f64(Crow + 2 * N + jr + 2);
136 float64x2_t c3lo = vld1q_f64(Crow + 3 * N + jr);
137 float64x2_t c3hi = vld1q_f64(Crow + 3 * N + jr + 2);
// One rank-1 update per k: load the 4-element segment of B's row k, broadcast
// A(ir + r, k), and fused-multiply-add into both halves of row r.
139 for (
idx k = kk; k < k_lim; ++k) {
140 const real* Brow = B.
data() + k * N + jr;
141 float64x2_t blo = vld1q_f64(Brow), bhi = vld1q_f64(Brow + 2);
142 float64x2_t a0 = vdupq_n_f64(A(ir + 0, k)), a1 = vdupq_n_f64(A(ir + 1, k));
143 float64x2_t a2 = vdupq_n_f64(A(ir + 2, k)), a3 = vdupq_n_f64(A(ir + 3, k));
144 c0lo = vfmaq_f64(c0lo, a0, blo);
145 c0hi = vfmaq_f64(c0hi, a0, bhi);
146 c1lo = vfmaq_f64(c1lo, a1, blo);
147 c1hi = vfmaq_f64(c1hi, a1, bhi);
148 c2lo = vfmaq_f64(c2lo, a2, blo);
149 c2hi = vfmaq_f64(c2hi, a2, bhi);
150 c3lo = vfmaq_f64(c3lo, a3, blo);
151 c3hi = vfmaq_f64(c3hi, a3, bhi);
// Store the accumulated tile back to C.
154 vst1q_f64(Crow + 0 * N + jr, c0lo);
155 vst1q_f64(Crow + 0 * N + jr + 2, c0hi);
156 vst1q_f64(Crow + 1 * N + jr, c1lo);
157 vst1q_f64(Crow + 1 * N + jr + 2, c1hi);
158 vst1q_f64(Crow + 2 * N + jr, c2lo);
159 vst1q_f64(Crow + 2 * N + jr + 2, c2hi);
160 vst1q_f64(Crow + 3 * N + jr, c3lo);
161 vst1q_f64(Crow + 3 * N + jr + 2, c3hi);
// Body of the NEON blocked matrix multiply. Its function header is not visible
// in this view; it is presumably matmul_neon, which the dispatcher calls with
// (A, B, C, block_size). The first M * N elements of C are zeroed and then
// accumulated into, so C's prior contents are overwritten.
166 std::fill_n(C.
data(), M * N,
real(0));
// Cache blocking: walk i, j and k in block_size steps. Each *_lim clamps the
// final block to the matrix edge.
// NOTE(review): block_size == 0 would stop the outer loop from ever advancing
// when M > 0. Presumably the caller rejects it; confirm.
168 for (
idx ii = 0; ii < M; ii += block_size) {
169 const idx i_lim = std::min(ii + block_size, M);
170 for (
idx jj = 0; jj < N; jj += block_size) {
171 const idx j_lim = std::min(jj + block_size, N);
172 for (
idx kk = 0; kk < K; kk += block_size) {
173 const idx k_lim = std::min(kk + block_size, K);
// Full 4-row strips. ir (and jr below) are initialised on lines not visible
// here, presumably to ii and jj; confirm.
175 for (; ir + 4 <= i_lim; ir += 4) {
// Full 4x4 tiles go to the NEON micro-kernel.
177 for (; jr + 4 <= j_lim; jr += 4)
178 neon_tile_4x4(A, B, C, ir, jr, kk, k_lim);
// Leftover columns (fewer than 4): scalar accumulation of the four rows, one
// column jr at a time. `b` is defined on a line not visible here, presumably
// as B(k, jr); confirm. The write-back of c0..c3 is also outside this view.
179 for (; jr < j_lim; ++jr) {
180 real c0 = C(ir + 0, jr), c1 = C(ir + 1, jr);
181 real c2 = C(ir + 2, jr), c3 = C(ir + 3, jr);
182 for (
idx k = kk; k < k_lim; ++k) {
184 c0 += A(ir + 0, k) * b;
185 c1 += A(ir + 1, k) * b;
186 c2 += A(ir + 2, k) * b;
187 c3 += A(ir + 3, k) * b;
// Leftover rows (fewer than 4): scalar i-k-j loop over this block's columns.
195 for (; ir < i_lim; ++ir) {
196 for (
idx k = kk; k < k_lim; ++k) {
197 const real a_ik = A(ir, k);
198 for (
idx j = jj; j < j_lim; ++j)
199 C(ir, j) += a_ik * B(k, j);
// Body of the NEON matrix-vector product. The function header is not visible;
// it is presumably matvec_neon, which the dispatcher calls with (A, x, y).
// For each row i it forms the dot product of row i of A (row-major, stride N)
// with x. The store of `result` is not visible; presumably y[i]. Confirm.
209 for (
idx i = 0; i < M; ++i) {
210 float64x2_t acc = vdupq_n_f64(0.0);
// Vector loop: 2 columns per iteration (float64x2_t), fused-multiply-add into
// a 2-lane accumulator. j is declared on a line not visible here (presumably
// as 0).
212 for (; j + 2 <= N; j += 2) {
213 float64x2_t a = vld1q_f64(A.
data() + i * N + j);
214 float64x2_t xv = vld1q_f64(x.
data() + j);
215 acc = vfmaq_f64(acc, a, xv);
// Reduce the two accumulator lanes to a scalar.
217 real result = vgetq_lane_f64(acc, 0) + vgetq_lane_f64(acc, 1);
// Scalar tail: if j starts at 0, at most one column is left over (odd N).
// The enclosing loop header is not visible here.
219 result += A(i, j) * x[j];
// Compile-time backend selection for the matmul entry point (the enclosing
// function is not visible). It uses the AVX2 kernel if NUMERICS_HAS_AVX2 is
// defined, otherwise the NEON kernel if NUMERICS_HAS_NEON is defined. Any
// fallback branch lies outside this view.
227#if defined(NUMERICS_HAS_AVX2)
228 matmul_avx(A, B, C, block_size);
229#elif defined(NUMERICS_HAS_NEON)
230 matmul_neon(A, B, C, block_size);
// Compile-time backend selection for the matrix-vector entry point (the
// enclosing function is not visible).
// NOTE(review): no statement is visible under the AVX2 branch in this view.
// It is presumably a call to the AVX2 matvec kernel; confirm it is not actually
// missing, otherwise y would not be computed on AVX2 builds.
237#if defined(NUMERICS_HAS_AVX2)
239#elif defined(NUMERICS_HAS_NEON)
240 matvec_neon(A, x, y);