numerics 0.1.0
Loading...
Searching...
No Matches
thomas.cpp
Go to the documentation of this file.
1/// @file linalg/factorization/thomas.cpp
2/// @brief Thomas tridiagonal solver dispatcher.
3
5#include "backends/lapack/impl.hpp"
6#include "backends/seq/impl.hpp"
8#include <stdexcept>
9
10namespace num {
11
12void thomas(const Vector& a,
13 const Vector& b,
14 const Vector& c,
15 const Vector& d,
16 Vector& x,
17 Backend backend) {
18 idx n = b.size();
19 if (a.size() != n - 1 || c.size() != n - 1 || d.size() != n || x.size() != n)
20 throw std::invalid_argument("Dimension mismatch in Thomas solver");
21
22 switch (backend) {
23 case Backend::lapack:
24 backends::lapack::thomas(a, b, c, d, x);
25 return;
26 case Backend::gpu:
27#ifdef NUMERICS_HAS_CUDA
28 {
29 Vector ag = a;
30 ag.to_gpu();
31 Vector bg = b;
32 bg.to_gpu();
33 Vector cg = c;
34 cg.to_gpu();
35 Vector dg = d;
36 dg.to_gpu();
37 x = Vector(n);
38 x.to_gpu();
40 bg.gpu_data(),
41 cg.gpu_data(),
42 dg.gpu_data(),
43 x.gpu_data(),
44 n,
45 1);
46 x.to_cpu();
47 return;
48 }
49#endif
50 [[fallthrough]];
51 default:
52 backends::seq::thomas(a, b, c, d, x);
53 return;
54 }
55}
56
57} // namespace num
real * gpu_data()
Definition vector.hpp:118
constexpr idx size() const noexcept
Definition vector.hpp:83
CUDA kernel wrappers.
void thomas(const Vector &a, const Vector &b, const Vector &c, const Vector &d, Vector &x)
Definition thomas.cpp:15
void thomas(const Vector &a, const Vector &b, const Vector &c, const Vector &d, Vector &x)
Definition thomas.cpp:8
void thomas_batched(const real *a, const real *b, const real *c, const real *d, real *x, idx n, idx batch_size)
Batched Thomas algorithm for tridiagonal systems.
Backend
Definition policy.hpp:7
std::size_t idx
Definition types.hpp:11
void thomas(const Vector &a, const Vector &b, const Vector &c, const Vector &d, Vector &x, Backend backend=lapack_backend)
Definition thomas.cpp:12
BasicVector< real > Vector
Real-valued dense vector with full backend dispatch (CPU + GPU)
Definition vector.hpp:129
SolverResult cg(const Matrix &A, const Vector &b, Vector &x, real tol=1e-10, idx max_iter=1000, Backend backend=default_backend)
Definition cg.cpp:8
Thomas algorithm for tridiagonal systems.