numerics
Loading...
Searching...
No Matches
lu.cpp
Go to the documentation of this file.
2#include <cmath>
3#include <algorithm>
4
5namespace num {
6
7// lu() -- Doolittle algorithm with partial pivoting
8//
9// Gaussian elimination transforms A into U by subtracting multiples of pivot
10// rows from rows below. The multipliers l_{ik} = A(i,k)/A(k,k) are stored
11// in the lower triangle of the working matrix (they are the entries of L).
12//
13// Partial pivoting: at step k, scan column k from row k downward and swap
14// the largest-magnitude element into the pivot position. This guarantees
15// |L(i,j)| <= 1 for all i > j, bounding the growth factor to 2^{n-1} in
16// the worst case (random matrices grow far less in practice).
17//
18// After n steps the working matrix holds L (below diagonal) and U (diagonal
19// and above) packed together. The diagonal of L is implicitly 1.
20
22 constexpr real singular_tol = 1e-14;
23 const idx n = A.rows();
24 LUResult f;
25 f.LU = A; // working copy
26 f.piv.resize(n);
27 f.singular = false;
28
29 Matrix& M = f.LU;
30
31 for (idx k = 0; k < n; ++k) {
32
33 idx pivot_row = k;
34 real pivot_val = std::abs(M(k, k));
35 for (idx i = k + 1; i < n; ++i) {
36 real v = std::abs(M(i, k));
37 if (v > pivot_val) { pivot_val = v; pivot_row = i; }
38 }
39 f.piv[k] = pivot_row;
40
41 if (pivot_row != k)
42 for (idx j = 0; j < n; ++j)
43 std::swap(M(k, j), M(pivot_row, j));
44
45 if (std::abs(M(k, k)) < singular_tol) {
46 f.singular = true;
47 continue;
48 }
49
50 const real inv_ukk = real(1) / M(k, k);
51 for (idx i = k + 1; i < n; ++i)
52 M(i, k) *= inv_ukk;
53
54 // M[i,j] -= L[i,k] * U[k,j] for i,j > k
55 for (idx i = k + 1; i < n; ++i) {
56 const real lik = M(i, k);
57 for (idx j = k + 1; j < n; ++j)
58 M(i, j) -= lik * M(k, j);
59 }
60 }
61
62 return f;
63}
64
65// lu_solve() -- apply P, then forward/backward substitution
66
67void lu_solve(const LUResult& f, const Vector& b, Vector& x) {
68 const idx n = f.LU.rows();
69 const Matrix& M = f.LU;
70
71 Vector y = b; // working copy; will become x
72
73 for (idx k = 0; k < n; ++k)
74 if (f.piv[k] != k)
75 std::swap(y[k], y[f.piv[k]]);
76
77 for (idx i = 1; i < n; ++i)
78 for (idx j = 0; j < i; ++j)
79 y[i] -= M(i, j) * y[j];
80
81 for (idx i = n; i-- > 0; ) {
82 for (idx j = i + 1; j < n; ++j)
83 y[i] -= M(i, j) * y[j];
84 y[i] /= M(i, i);
85 }
86
87 x = std::move(y);
88}
89
90void lu_solve(const LUResult& f, const Matrix& B, Matrix& X) {
91 const idx nrhs = B.cols();
92 const idx n = B.rows();
93 Vector col(n), xcol(n);
94 for (idx j = 0; j < nrhs; ++j) {
95 for (idx i = 0; i < n; ++i) col[i] = B(i, j);
96 lu_solve(f, col, xcol);
97 for (idx i = 0; i < n; ++i) X(i, j) = xcol[i];
98 }
99}
100
101// lu_det() -- determinant from the diagonal of U and pivot parity
102//
103// det(A) = det(P^{-1}) * det(L) * det(U)
104// = (-1)^{swaps} * 1 * prod(U[i,i])
105//
106// Each transposition recorded in piv contributes a factor of -1.
107// Count only entries where piv[k] != k (an identity "swap" contributes +1).
108
110 const idx n = f.LU.rows();
111 real det = real(1);
112 for (idx i = 0; i < n; ++i)
113 det *= f.LU(i, i);
114 idx swaps = 0;
115 for (idx k = 0; k < n; ++k)
116 if (f.piv[k] != k) ++swaps;
117 return (swaps % 2 == 0) ? det : -det;
118}
119
120// lu_inv() -- A^{-1} by solving A * X = I column by column
121
123 const idx n = f.LU.rows();
124 Matrix inv(n, n, real(0));
125 Vector e(n, real(0)), col(n);
126 for (idx j = 0; j < n; ++j) {
127 e[j] = real(1);
128 lu_solve(f, e, col);
129 for (idx i = 0; i < n; ++i) inv(i, j) = col[i];
130 e[j] = real(0);
131 }
132 return inv;
133}
134
135} // namespace num
Dense row-major matrix with optional GPU storage.
Definition matrix.hpp:12
constexpr idx rows() const noexcept
Definition matrix.hpp:24
LU factorization with partial pivoting.
double real
Definition types.hpp:10
LUResult lu(const Matrix &A)
LU factorization of a square matrix A with partial pivoting.
Definition lu.cpp:21
constexpr T ipow(T x) noexcept
Compute x^N at compile time via repeated squaring.
std::size_t idx
Definition types.hpp:11
real lu_det(const LUResult &f)
Determinant of A from its LU factorization.
Definition lu.cpp:109
constexpr real e
Definition math.hpp:41
Matrix lu_inv(const LUResult &f)
Inverse of A from its LU factorization.
Definition lu.cpp:122
void lu_solve(const LUResult &f, const Vector &b, Vector &x)
Solve A*x = b using a precomputed LU factorization.
Definition lu.cpp:67
Result of an LU factorization with partial pivoting (PA = LU)
Definition lu.hpp:19
Matrix LU
Definition lu.hpp:20