numerics 0.1.0
Loading...
Searching...
No Matches
cell_list.hpp
Go to the documentation of this file.
1/// @file cell_list.hpp
2/// @brief Cache-coherent 2D cell list for O(1) amortized neighbour queries
3#pragma once
4
5#include <algorithm>
6#include <cassert>
7#include <cmath>
8#include <utility>
9#include <vector>
10
11namespace num {
12
/// @brief Non-owning view of a contiguous run of particle indices.
///
/// Returned by CellList2D::cell_particles(). The pointers refer into the cell
/// list's internal storage, so a later build() (which resizes that storage)
/// may invalidate them.
struct IntRange {
    const int* first;  ///< First index of the run.
    const int* last;   ///< One past the last index of the run.

    const int* begin() const noexcept { return first; }
    const int* end() const noexcept { return last; }
    int size() const noexcept { return static_cast<int>(last - first); }
    bool empty() const noexcept { return first == last; }
};

/// @brief Uniform-grid ("cell list") spatial index over points in a 2D box.
///
/// The box [xmin, xmax] x [ymin, ymax] is divided into square cells of side
/// `cell_size`, plus one padding cell on every side. build() bins particles with
/// a counting sort so that the indices belonging to each cell are stored
/// contiguously. query() and iterate_pairs() then only look at the 3x3 block of
/// cells around a cell.
///
/// Two positions at most `cell_size` apart along each axis land in the same or
/// adjacent cells (up to floating-point rounding). Both routines therefore
/// report every neighbour within `cell_size`. They can also report farther
/// candidates, so callers must filter by actual distance.
///
/// Positions outside the box (including NaN coordinates, which go to index 0 on
/// that axis) are clamped into the outermost cells. They are never dropped,
/// only reported as candidates more often.
///
/// @tparam Scalar Coordinate type; std::floor/std::ceil are applied to it, so
///                it is expected to be floating-point.
template<typename Scalar>
class CellList2D {
  public:
    /// @brief Set up an empty grid covering [xmin, xmax] x [ymin, ymax].
    /// @param cell_size Cell side length; must be > 0. It should be at least the
    ///                  interaction radius used with query()/iterate_pairs().
    /// @param xmin,xmax Domain extent along x (xmax >= xmin).
    /// @param ymin,ymax Domain extent along y (ymax >= ymin).
    CellList2D(Scalar cell_size, Scalar xmin, Scalar xmax, Scalar ymin, Scalar ymax)
        : cs_(cell_size), xmin_(xmin), ymin_(ymin) {
        assert(cell_size > Scalar(0) && "CellList2D: cell_size must be positive");
        assert(xmax >= xmin && ymax >= ymin && "CellList2D: inverted domain");
        // One padding cell on each side; out-of-box positions are clamped into them.
        nx_ = static_cast<int>(std::ceil((xmax - xmin) / cs_)) + 2;
        ny_ = static_cast<int>(std::ceil((ymax - ymin) / cs_)) + 2;
        const int total = nx_ * ny_;
        start_.assign(total + 1, 0);
        count_.assign(total, 0);
    }

    /// @brief Rebuild the index by counting-sort over cell ids.
    /// @param get_pos Callable; `get_pos(i)` for i in [0, n) must return something
    ///                convertible to `std::pair<Scalar, Scalar>` holding (x, y).
    ///                It is invoked exactly once per particle.
    /// @param n       Number of particles; must be >= 0.
    ///
    /// Within a cell, particle indices keep their ascending order.
    template<typename PosAccessor>
    void build(PosAccessor&& get_pos, int n) {
        assert(n >= 0 && "CellList2D::build: negative particle count");
        const int total = nx_ * ny_;
        sorted_.resize(n);
        cell_of_.resize(n);

        // Pass 1: bin every particle once and histogram the bins.
        std::fill(count_.begin(), count_.end(), 0);
        for (int i = 0; i < n; ++i) {
            const int cid = cell_id_of(get_pos(i));
            cell_of_[i] = cid;
            ++count_[cid];
        }

        // Exclusive prefix sum: cell c owns sorted_[start_[c], start_[c + 1]).
        start_[0] = 0;
        for (int c = 0; c < total; ++c)
            start_[c + 1] = start_[c] + count_[c];

        // Pass 2: scatter, reusing count_ as the per-cell write cursor.
        std::fill(count_.begin(), count_.end(), 0);
        for (int i = 0; i < n; ++i) {
            const int cid = cell_of_[i];
            sorted_[start_[cid] + count_[cid]] = i;
            ++count_[cid];
        }
    }

    /// @brief Call f(j) for every candidate particle j near (px, py).
    ///
    /// Visits all particles stored in the 3x3 block of cells centred on the cell
    /// containing (px, py), in row-major cell order. This includes every particle
    /// within `cell_size` of the point along each axis, and possibly farther ones.
    /// Before the first build() no candidates are reported.
    template<typename F>
    void query(Scalar px, Scalar py, F&& f) const {
        const int cx = cell_x(px);
        const int cy = cell_y(py);
        // Clip the 3x3 window to the grid once instead of testing every cell.
        const int x_lo = std::max(cx - 1, 0);
        const int x_hi = std::min(cx + 1, nx_ - 1);
        const int y_lo = std::max(cy - 1, 0);
        const int y_hi = std::min(cy + 1, ny_ - 1);
        for (int qy = y_lo; qy <= y_hi; ++qy) {
            for (int qx = x_lo; qx <= x_hi; ++qx) {
                const int cid = qy * nx_ + qx;
                for (int k = start_[cid]; k < start_[cid + 1]; ++k)
                    f(sorted_[k]);
            }
        }
    }

    /// @brief Call f(a, b) once for every unordered pair of particles that share
    ///        a cell or lie in adjacent cells.
    ///
    /// The argument order within a pair is unspecified.
    template<typename F>
    void iterate_pairs(F&& f) const {
        // Half-shell stencil. Same-cell pairs are handled separately; these four
        // "forward" offsets reach each of the 8 neighbouring cells from exactly
        // one side, so every adjacent cell pair is processed exactly once.
        static constexpr int FDX[4] = {+1, 0, +1, -1};
        static constexpr int FDY[4] = {0, +1, +1, +1};

        for (int cy = 0; cy < ny_; ++cy) {
            for (int cx = 0; cx < nx_; ++cx) {
                const int cid = cy * nx_ + cx;
                const int beg = start_[cid];
                const int end = start_[cid + 1];
                if (beg == end)
                    continue;

                // Pairs inside this cell.
                for (int a = beg; a < end; ++a)
                    for (int b = a + 1; b < end; ++b)
                        f(sorted_[a], sorted_[b]);

                // Pairs with the forward neighbour cells.
                for (int d = 0; d < 4; ++d) {
                    const int ncx = cx + FDX[d];
                    const int ncy = cy + FDY[d];
                    if (ncx < 0 || ncx >= nx_ || ncy < 0 || ncy >= ny_)
                        continue;
                    const int ncid = ncy * nx_ + ncx;
                    const int nbeg = start_[ncid];
                    const int nend = start_[ncid + 1];
                    for (int a = beg; a < end; ++a)
                        for (int b = nbeg; b < nend; ++b)
                            f(sorted_[a], sorted_[b]);
                }
            }
        }
    }

    /// @brief Indices of the particles in cell (cx, cy), including padding cells.
    /// @pre 0 <= cx < nx() and 0 <= cy < ny().
    IntRange cell_particles(int cx, int cy) const noexcept {
        assert(cx >= 0 && cx < nx_ && cy >= 0 && cy < ny_);
        const int cid = cy * nx_ + cx;
        return {sorted_.data() + start_[cid], sorted_.data() + start_[cid + 1]};
    }

    /// Number of cells along x, including the two padding cells.
    int nx() const noexcept { return nx_; }
    /// Number of cells along y, including the two padding cells.
    int ny() const noexcept { return ny_; }
    /// Particle count passed to the last build() (0 before any build).
    int n_particles() const noexcept { return static_cast<int>(sorted_.size()); }

  private:
    Scalar cs_ = 0, xmin_ = 0, ymin_ = 0;
    int nx_ = 0, ny_ = 0;

    std::vector<int> sorted_;   ///< Particle indices grouped by cell.
    std::vector<int> start_;    ///< Cell c occupies sorted_[start_[c], start_[c + 1]).
    std::vector<int> count_;    ///< Scratch: per-cell counts / write cursors for build().
    std::vector<int> cell_of_;  ///< Scratch: cell id of each particle during build().

    /// Map one coordinate to a cell index clamped to [0, n).
    /// The clamp is done in floating point *before* converting to int, because
    /// converting an out-of-range or NaN floating value to int is undefined
    /// behaviour.
    int axis_cell(Scalar coord, Scalar origin, int n) const noexcept {
        const Scalar t = std::floor((coord - origin) / cs_) + Scalar(1);
        if (!(t > Scalar(0)))  // negative, zero, or NaN -> first cell
            return 0;
        if (t >= static_cast<Scalar>(n - 1))
            return n - 1;
        return static_cast<int>(t);  // integral value strictly inside (0, n - 1)
    }
    int cell_x(Scalar x) const noexcept { return axis_cell(x, xmin_, nx_); }
    int cell_y(Scalar y) const noexcept { return axis_cell(y, ymin_, ny_); }
    int cell_id_of(std::pair<Scalar, Scalar> p) const noexcept {
        return cell_y(p.second) * nx_ + cell_x(p.first);
    }
};
149
150} // namespace num
int nx() const noexcept
void query(Scalar px, Scalar py, F &&f) const
Call f(j) for candidate particles near (px, py).
Definition cell_list.hpp:61
void iterate_pairs(F &&f) const
Visit each candidate pair once.
Definition cell_list.hpp:81
CellList2D(Scalar cell_size, Scalar xmin, Scalar xmax, Scalar ymin, Scalar ymax)
Definition cell_list.hpp:25
int n_particles() const noexcept
IntRange cell_particles(int cx, int cy) const noexcept
int ny() const noexcept
void build(PosAccessor &&get_pos, int n)
Rebuild by counting-sort over cell ids.
Definition cell_list.hpp:39
const int * begin() const noexcept
Definition cell_list.hpp:16
bool empty() const noexcept
Definition cell_list.hpp:19
const int * end() const noexcept
Definition cell_list.hpp:17
const int * last
Definition cell_list.hpp:15
int size() const noexcept
Definition cell_list.hpp:18
const int * first
Definition cell_list.hpp:14