numerics 0.1.0
Loading...
Searching...
No Matches
lu.cpp
Go to the documentation of this file.
1/// @file linalg/factorization/lu.cpp
2/// @brief LU dispatcher + utility functions (lu_solve, lu_det, lu_inv).
3///
4/// Backend routing:
5/// Backend::lapack -> backends::lapack::lu (LAPACKE_dgetrf, blocked BLAS-3)
6/// everything else -> backends::seq::lu (Doolittle, partial pivoting)
7///
8/// Adding an omp backend: create backends/omp/lu.cpp, include its impl.hpp
9/// here, and add the case Backend::omp below.
10
12#include "backends/seq/impl.hpp"
13#include "backends/lapack/impl.hpp"
14
15namespace num {
16
17LUResult lu(const Matrix& A, Backend backend) {
18 switch (backend) {
19 case Backend::lapack:
20 return backends::lapack::lu(A);
21 default:
22 return backends::seq::lu(A);
23 }
24}
25
26// lu_solve() -- apply P, then forward/backward substitution
27
28void lu_solve(const LUResult& f, const Vector& b, Vector& x) {
29 const idx n = f.LU.rows();
30 const Matrix& M = f.LU;
31
32 Vector y = b;
33
34 for (idx k = 0; k < n; ++k)
35 if (f.piv[k] != k)
36 std::swap(y[k], y[f.piv[k]]);
37
38 for (idx i = 1; i < n; ++i)
39 for (idx j = 0; j < i; ++j)
40 y[i] -= M(i, j) * y[j];
41
42 for (idx i = n; i-- > 0;) {
43 for (idx j = i + 1; j < n; ++j)
44 y[i] -= M(i, j) * y[j];
45 y[i] /= M(i, i);
46 }
47
48 x = std::move(y);
49}
50
51void lu_solve(const LUResult& f, const Matrix& B, Matrix& X) {
52 const idx nrhs = B.cols();
53 const idx n = B.rows();
54 Vector col(n), xcol(n);
55 for (idx j = 0; j < nrhs; ++j) {
56 for (idx i = 0; i < n; ++i)
57 col[i] = B(i, j);
58 lu_solve(f, col, xcol);
59 for (idx i = 0; i < n; ++i)
60 X(i, j) = xcol[i];
61 }
62}
63
64// lu_det() -- determinant from diagonal of U and pivot parity
65
66real lu_det(const LUResult& f) {
67 const idx n = f.LU.rows();
68 real det = real(1);
69 for (idx i = 0; i < n; ++i)
70 det *= f.LU(i, i);
71 idx swaps = 0;
72 for (idx k = 0; k < n; ++k)
73 if (f.piv[k] != k)
74 ++swaps;
75 return (swaps % 2 == 0) ? det : -det;
76}
77
78// lu_inv() -- A^{-1} by solving A * X = I column by column
79
81 const idx n = f.LU.rows();
82 Matrix inv(n, n, real(0));
83 Vector e(n, real(0)), col(n);
84 for (idx j = 0; j < n; ++j) {
85 e[j] = real(1);
86 lu_solve(f, e, col);
87 for (idx i = 0; i < n; ++i)
88 inv(i, j) = col[i];
89 e[j] = real(0);
90 }
91 return inv;
92}
93
94} // namespace num
Dense row-major matrix with optional GPU storage.
Definition matrix.hpp:12
constexpr idx rows() const noexcept
Definition matrix.hpp:24
constexpr idx cols() const noexcept
Definition matrix.hpp:25
LU factorization with partial pivoting.
LUResult lu(const Matrix &A)
Definition lu.cpp:18
LUResult lu(const Matrix &A)
Definition lu.cpp:9
double real
Definition types.hpp:10
Backend
Selects which backend handles a linalg operation.
Definition policy.hpp:19
@ lapack
LAPACKE – industry-standard factorizations, SVD, eigen.
std::size_t idx
Definition types.hpp:11
real lu_det(const LUResult &f)
Determinant of A from its LU factorization.
Definition lu.cpp:66
constexpr real e
Definition math.hpp:43
Matrix lu_inv(const LUResult &f)
Inverse of A from its LU factorization.
Definition lu.cpp:80
void lu_solve(const LUResult &f, const Vector &b, Vector &x)
Solve A*x = b using a precomputed LU factorization.
Definition lu.cpp:28
LUResult lu(const Matrix &A, Backend backend=lapack_backend)
LU factorization of a square matrix A with partial pivoting.
Definition lu.cpp:17
Result of an LU factorization with partial pivoting (PA = LU)
Definition lu.hpp:19
Matrix LU
Definition lu.hpp:20
std::vector< idx > piv
Definition lu.hpp:21