numerics 0.1.0
Loading...
Searching...
No Matches
lu.cpp
Go to the documentation of this file.
1/// @file linalg/factorization/lu.cpp
2/// @brief LU dispatcher + utility functions (lu_solve, lu_det, lu_inv).
3
5#include "backends/lapack/impl.hpp"
6#include "backends/seq/impl.hpp"
7
8namespace num {
9
10LUResult lu(const Matrix& A, Backend backend) {
11 switch (backend) {
12 case Backend::lapack:
13 return backends::lapack::lu(A);
14 default:
15 return backends::seq::lu(A);
16 }
17}
18
19void lu_solve(const LUResult& f, const Vector& b, Vector& x) {
20 const idx n = f.LU.rows();
21 const Matrix& M = f.LU;
22
23 Vector y = b;
24
25 for (idx k = 0; k < n; ++k)
26 if (f.piv[k] != k)
27 std::swap(y[k], y[f.piv[k]]);
28
29 for (idx i = 1; i < n; ++i)
30 for (idx j = 0; j < i; ++j)
31 y[i] -= M(i, j) * y[j];
32
33 for (idx i = n; i-- > 0;) {
34 for (idx j = i + 1; j < n; ++j)
35 y[i] -= M(i, j) * y[j];
36 y[i] /= M(i, i);
37 }
38
39 x = std::move(y);
40}
41
42void lu_solve(const LUResult& f, const Matrix& B, Matrix& X) {
43 const idx nrhs = B.cols();
44 const idx n = B.rows();
45 Vector col(n), xcol(n);
46 for (idx j = 0; j < nrhs; ++j) {
47 for (idx i = 0; i < n; ++i)
48 col[i] = B(i, j);
49 lu_solve(f, col, xcol);
50 for (idx i = 0; i < n; ++i)
51 X(i, j) = xcol[i];
52 }
53}
54
55// lu_det() -- determinant from diagonal of U and pivot parity
56
57real lu_det(const LUResult& f) {
58 const idx n = f.LU.rows();
59 real det = real(1);
60 for (idx i = 0; i < n; ++i)
61 det *= f.LU(i, i);
62 idx swaps = 0;
63 for (idx k = 0; k < n; ++k)
64 if (f.piv[k] != k)
65 ++swaps;
66 return (swaps % 2 == 0) ? det : -det;
67}
68
69// lu_inv() -- A^{-1} by solving A * X = I column by column
70
72 const idx n = f.LU.rows();
73 Matrix inv(n, n, real(0));
74 Vector e(n, real(0)), col(n);
75 for (idx j = 0; j < n; ++j) {
76 e[j] = real(1);
77 lu_solve(f, e, col);
78 for (idx i = 0; i < n; ++i)
79 inv(i, j) = col[i];
80 e[j] = real(0);
81 }
82 return inv;
83}
84
85} // namespace num
constexpr idx rows() const noexcept
Definition matrix.hpp:87
constexpr idx cols() const noexcept
Definition matrix.hpp:88
LU factorization with partial pivoting.
LUResult lu(const Matrix &A)
Definition lu.cpp:18
LUResult lu(const Matrix &A)
Definition lu.cpp:9
double real
Definition types.hpp:10
Backend
Definition policy.hpp:7
std::size_t idx
Definition types.hpp:11
real lu_det(const LUResult &f)
Compute $\det(A)$.
Definition lu.cpp:57
constexpr real e
Definition math.hpp:44
Matrix lu_inv(const LUResult &f)
Compute $A^{-1}$ by solving $A X = I$.
Definition lu.cpp:71
void lu_solve(const LUResult &f, const Vector &b, Vector &x)
Solve $A x = b$ from a precomputed factorization.
Definition lu.cpp:19
LUResult lu(const Matrix &A, Backend backend=lapack_backend)
Definition lu.cpp:10
Packed factorization $PA = LU$.
Definition lu.hpp:12
Matrix LU
Definition lu.hpp:13
std::vector< idx > piv
Definition lu.hpp:14