numerics 0.1.0
Loading...
Searching...
No Matches
sparse.cpp
Go to the documentation of this file.
2#include <algorithm>
3#include <numeric>
4#include <stdexcept>
5
6namespace num {
7
9 idx n_cols,
10 std::vector<real> vals,
11 std::vector<idx> col_idx,
12 std::vector<idx> row_ptr)
13 : n_rows_(n_rows)
14 , n_cols_(n_cols)
15 , vals_(std::move(vals))
16 , col_idx_(std::move(col_idx))
17 , row_ptr_(std::move(row_ptr)) {
18 if (row_ptr_.size() != n_rows_ + 1)
19 throw std::invalid_argument(
20 "SparseMatrix: row_ptr must have length n_rows+1");
21 if (col_idx_.size() != vals_.size())
22 throw std::invalid_argument(
23 "SparseMatrix: col_idx and vals must have equal length");
24}
25
27 idx n_cols,
28 const std::vector<idx>& rows,
29 const std::vector<idx>& cols,
30 const std::vector<real>& vals) {
31 if (rows.size() != cols.size() || rows.size() != vals.size())
32 throw std::invalid_argument(
33 "SparseMatrix::from_triplets: inconsistent input sizes");
34
35 // Count entries per row
36 std::vector<idx> row_count(n_rows, 0);
37 for (idx k = 0; k < rows.size(); ++k) {
38 if (rows[k] >= n_rows || cols[k] >= n_cols)
39 throw std::out_of_range(
40 "SparseMatrix::from_triplets: index out of range");
41 ++row_count[rows[k]];
42 }
43
44 // Build row_ptr
45 std::vector<idx> row_ptr(n_rows + 1, 0);
46 for (idx i = 0; i < n_rows; ++i)
47 row_ptr[i + 1] = row_ptr[i] + row_count[i];
48
50 std::vector<real> out_vals(nnz, 0.0);
51 std::vector<idx> out_col(nnz);
52
53 // Fill entries (stable insertion within each row)
54 std::vector<idx> fill_pos = row_ptr;
55 for (idx k = 0; k < rows.size(); ++k) {
56 idx pos = fill_pos[rows[k]]++;
57 out_col[pos] = cols[k];
58 out_vals[pos] = vals[k];
59 }
60
61 // Sort each row by column and sum duplicates
62 for (idx i = 0; i < n_rows; ++i) {
63 idx start = row_ptr[i], end = row_ptr[i + 1];
64 // Sort by column index
65 std::vector<idx> order(end - start);
66 std::iota(order.begin(), order.end(), 0);
67 std::sort(order.begin(), order.end(), [&](idx a, idx b) {
68 return out_col[start + a] < out_col[start + b];
69 });
70
71 std::vector<real> sv(end - start);
72 std::vector<idx> sc(end - start);
73 for (idx k = 0; k < order.size(); ++k) {
74 sv[k] = out_vals[start + order[k]];
75 sc[k] = out_col[start + order[k]];
76 }
77 for (idx k = 0; k < order.size(); ++k) {
78 out_vals[start + k] = sv[k];
79 out_col[start + k] = sc[k];
80 }
81
82 // Sum duplicates in-place
83 idx write = start;
84 for (idx k = start; k < end;) {
85 idx cur_col = out_col[k];
86 real sum = 0.0;
87 while (k < end && out_col[k] == cur_col)
88 sum += out_vals[k++];
89 out_col[write] = cur_col;
90 out_vals[write++] = sum;
91 }
92 // Compact row_ptr if duplicates were merged
93 row_ptr[i + 1] = write;
94 // Shift remaining rows' data (rare; only matters if duplicates exist)
95 if (write < end) {
96 for (idx k = end; k < nnz; ++k) {
97 out_vals[write + (k - end)] = out_vals[k];
98 out_col[write + (k - end)] = out_col[k];
99 }
100 nnz -= (end - write);
101 out_vals.resize(nnz);
102 out_col.resize(nnz);
103 // Fix subsequent row_ptr entries
104 idx delta = end - write;
105 for (idx r = i + 2; r <= n_rows; ++r)
106 row_ptr[r] -= delta;
107 }
108 }
109
110 return SparseMatrix(n_rows,
111 n_cols,
112 std::move(out_vals),
113 std::move(out_col),
114 std::move(row_ptr));
115}
116
118 for (idx k = row_ptr_[i]; k < row_ptr_[i + 1]; ++k)
119 if (col_idx_[k] == j)
120 return vals_[k];
121 return 0.0;
122}
123
124void sparse_matvec(const SparseMatrix& A, const Vector& x, Vector& y) {
125 if (A.n_cols() != x.size() || A.n_rows() != y.size())
126 throw std::invalid_argument("Dimension mismatch in sparse_matvec");
127 for (idx i = 0; i < A.n_rows(); ++i) {
128 real sum = 0.0;
129 for (idx k = A.row_ptr()[i]; k < A.row_ptr()[i + 1]; ++k)
130 sum += A.values()[k] * x[A.col_idx()[k]];
131 y[i] = sum;
132 }
133}
134
135} // namespace num
constexpr idx size() const noexcept
Definition vector.hpp:80
Sparse matrix in Compressed Sparse Row (CSR) format.
Definition sparse.hpp:15
real operator()(idx i, idx j) const
Element access A(i,j); returns 0 if outside stored pattern – O(stored entries in row i), i.e. O(nnz/n_rows) on average
Definition sparse.cpp:117
idx nnz() const
Definition sparse.hpp:39
SparseMatrix(idx n_rows, idx n_cols, std::vector< real > vals, std::vector< idx > col_idx, std::vector< idx > row_ptr)
Construct from raw CSR arrays (takes ownership)
Definition sparse.cpp:8
static SparseMatrix from_triplets(idx n_rows, idx n_cols, const std::vector< idx > &rows, const std::vector< idx > &cols, const std::vector< real > &vals)
Build from coordinate (COO / triplet) lists.
Definition sparse.cpp:26
idx n_cols() const
Definition sparse.hpp:36
const idx * row_ptr() const
Definition sparse.hpp:53
idx n_rows() const
Definition sparse.hpp:33
const idx * col_idx() const
Definition sparse.hpp:50
const real * values() const
Definition sparse.hpp:47
double real
Definition types.hpp:10
std::size_t idx
Definition types.hpp:11
void sparse_matvec(const SparseMatrix &A, const Vector &x, Vector &y)
y = A * x
Definition sparse.cpp:124
Compressed Sparse Row (CSR) matrix and operations.