numerics 0.1.0
Loading...
Searching...
No Matches
svd.cpp
Go to the documentation of this file.
1/// @file svd/svd.cpp
2/// @brief SVD dispatcher + randomized truncated SVD.
3
4#include "linalg/svd/svd.hpp"
5#include "backends/lapack/impl.hpp"
6#include "backends/seq/impl.hpp"
8
9namespace num {
10
11SVDResult svd(const Matrix& A_in, Backend backend, real tol, idx max_sweeps) {
12 switch (backend) {
13 case Backend::lapack:
14 return backends::lapack::svd(A_in);
15 default:
16 return backends::seq::svd(A_in, tol, max_sweeps);
17 }
18}
19
21 idx k,
22 Backend backend,
23 idx oversampling,
24 Rng* rng) {
25 const idx m = A.rows(), n = A.cols();
26 if (k == 0 || k > std::min(m, n))
27 throw std::invalid_argument("svd_truncated: k out of range");
28
29 const idx l = k + oversampling;
30
31 Rng local_rng;
32 if (!rng)
33 rng = &local_rng;
34
35 Matrix Omega(n, l);
36 for (idx j = 0; j < l; ++j)
37 for (idx i = 0; i < n; ++i)
38 Omega(i, j) = rng_normal(rng, 0.0, 1.0);
39
40 Matrix Y(m, l, 0.0);
41 matmul(A, Omega, Y, backend);
42
43 QRResult qr_res = qr(Y);
44 const Matrix& Q = qr_res.Q;
45
46 Matrix B(l, n, 0.0);
47 for (idx i = 0; i < l; ++i)
48 for (idx kk = 0; kk < m; ++kk) {
49 const real q_ki = Q(kk, i);
50 for (idx j = 0; j < n; ++j)
51 B(i, j) += q_ki * A(kk, j);
52 }
53
54 SVDResult small = svd(B, backend);
55
56 Matrix U(m, k, 0.0);
57 for (idx j = 0; j < k; ++j)
58 for (idx i = 0; i < m; ++i)
59 for (idx ii = 0; ii < l; ++ii)
60 U(i, j) += Q(i, ii) * small.U(ii, j);
61
62 Vector S(k);
63 for (idx i = 0; i < k; ++i)
64 S[i] = small.S[i];
65
66 Matrix Vt(k, n, 0.0);
67 for (idx i = 0; i < k; ++i)
68 for (idx j = 0; j < n; ++j)
69 Vt(i, j) = small.Vt(i, j);
70
71 return {U, S, Vt, 0, true};
72}
73
74} // namespace num
constexpr idx rows() const noexcept
Definition matrix.hpp:87
constexpr idx cols() const noexcept
Definition matrix.hpp:88
SVDResult svd(const Matrix &A)
Definition svd.cpp:15
SVDResult svd(const Matrix &A, real tol, idx max_sweeps)
Definition svd.cpp:11
SVDResult svd_truncated(const Matrix &A, idx k, Backend backend=default_backend, idx oversampling=10, Rng *rng=nullptr)
Definition svd.cpp:20
double real
Definition types.hpp:10
Backend
Definition policy.hpp:7
QRResult qr(const Matrix &A, Backend backend=lapack_backend)
Factor $A$ as $A = QR$.
Definition qr.cpp:10
std::size_t idx
Definition types.hpp:11
SVDResult svd(const Matrix &A, Backend backend=lapack_backend, real tol=1e-12, idx max_sweeps=100)
Definition svd.cpp:11
real rng_normal(Rng *r, real mean, real stddev)
Normal (Gaussian) sample with given mean and standard deviation.
Definition math.hpp:298
void matmul(const Matrix &A, const Matrix &B, Matrix &C, Backend b=default_backend)
C = A * B.
Definition matrix.cpp:20
QR factorization via Householder reflections.
QR factorization $A = QR$.
Definition qr.hpp:11
Matrix Q
Definition qr.hpp:12
Seeded pseudo-random number generator (Mersenne Twister). Pass a pointer to rng_* functions to draw samples.
Definition math.hpp:281
Matrix Vt
Definition svd.hpp:19
Matrix U
Definition svd.hpp:17
Vector S
Definition svd.hpp:18
Dense and randomized truncated SVD.