numerics 0.1.0
Loading...
Searching...
No Matches
cell_list_3d.hpp
Go to the documentation of this file.
1/// @file cell_list_3d.hpp
2/// @brief Cache-coherent 3D cell list for O(1) amortized neighbour queries
3#pragma once
4
5#include <algorithm>
6#include <cassert>
7#include <cmath>
8#include <tuple>
9#include <vector>
10
11namespace num {
12
/// @brief Uniform-grid cell list over an axis-aligned 3D box.
///
/// The box is cut into cubic cells of edge `cell_size`, plus one extra layer
/// of cells on every side, so a position that lies slightly outside the box
/// still maps to a valid cell. Positions further out (and NaN coordinates)
/// are clamped to the outermost layer instead of being rejected.
///
/// Particle indices are stored grouped by cell: the particles of cell `c`
/// are `sorted_[start_[c]] .. sorted_[start_[c + 1] - 1]`, and cell ids grow
/// fastest along x, then y, then z.
///
/// @tparam Scalar Coordinate type; expected to be a floating-point type.
template<typename Scalar>
class CellList3D {
  public:
    /// @brief Set up an empty grid covering the given box.
    ///
    /// @param cell_size Edge length of one cell; must be positive. Choose it
    ///                  no smaller than the interaction cutoff so that the
    ///                  3x3x3 neighbourhood used by query() and
    ///                  iterate_pairs() contains every partner in range.
    /// @param xmin,xmax Extent of the box along x.
    /// @param ymin,ymax Extent of the box along y.
    /// @param zmin,zmax Extent of the box along z.
    ///
    /// The total number of cells must fit in an `int`.
    CellList3D(Scalar cell_size,
               Scalar xmin,
               Scalar xmax,
               Scalar ymin,
               Scalar ymax,
               Scalar zmin,
               Scalar zmax)
        : cs_(cell_size),
          xmin_(xmin),
          ymin_(ymin),
          zmin_(zmin) {
        // A zero, negative or NaN cell size makes every index computation
        // below meaningless (division by zero, out-of-range casts).
        assert(cell_size > Scalar(0) && "cell_size must be positive");

        // +2: one padding layer below the minimum and one above the maximum.
        const auto cells_along = [this](Scalar lo, Scalar hi) {
            return static_cast<int>(std::ceil((hi - lo) / cs_)) + 2;
        };
        nx_ = cells_along(xmin, xmax);
        ny_ = cells_along(ymin, ymax);
        nz_ = cells_along(zmin, zmax);
        assert(nx_ > 0 && ny_ > 0 && nz_ > 0 && "box must not be inverted");
        // Cell ids and the start_ table are int-indexed; the bound assumes a
        // 32-bit int and is merely conservative for wider ones.
        assert(static_cast<double>(nx_) * ny_ * nz_ < 2147483647.0
               && "too many cells for int indexing");

        const int total = nx_ * ny_ * nz_;
        start_.assign(total + 1, 0);
        count_.assign(total, 0);
    }

    /// @brief Rebuild the list from scratch with a counting sort on cell id.
    ///
    /// @param get_pos Callable; `get_pos(i)` must return something
    ///                convertible to `std::tuple<Scalar, Scalar, Scalar>`
    ///                holding the (x, y, z) position of particle `i`. It is
    ///                invoked twice per particle and must return the same
    ///                position both times.
    /// @param n       Number of particles (indices `0 .. n-1`); must be >= 0.
    ///
    /// Within a cell, particles keep ascending index order.
    template<typename PosAccessor>
    void build(PosAccessor&& get_pos, int n) {
        assert(n >= 0 && "particle count must be non-negative");
        sorted_.resize(n);
        const int total = nx_ * ny_ * nz_;

        // Pass 1: histogram of particles per cell.
        std::fill(count_.begin(), count_.end(), 0);
        for (int i = 0; i < n; ++i)
            ++count_[cell_id_of(get_pos(i))];

        // Exclusive prefix sum turns the histogram into cell start offsets.
        start_[0] = 0;
        for (int c = 0; c < total; ++c)
            start_[c + 1] = start_[c] + count_[c];

        // Pass 2: scatter indices; count_ is reused as the per-cell write
        // cursor and ends up holding the per-cell counts again.
        std::fill(count_.begin(), count_.end(), 0);
        for (int i = 0; i < n; ++i) {
            const int cid = cell_id_of(get_pos(i));
            sorted_[start_[cid] + count_[cid]] = i;
            ++count_[cid];
        }
    }

    /// @brief Call `f(j)` for every particle in the 3x3x3 block of cells
    ///        centred on the cell containing (px, py, pz).
    ///
    /// These are candidates only: no distance test is applied, and a
    /// particle located at the query point itself is reported as well.
    /// Neighbour cells that would fall outside the grid are skipped.
    template<typename F>
    void query(Scalar px, Scalar py, Scalar pz, F&& f) const {
        const int cx = cell_x(px);
        const int cy = cell_y(py);
        const int cz = cell_z(pz);

        // Clip the neighbourhood to the grid once instead of testing every
        // offset inside the loops.
        const int x_lo = std::max(cx - 1, 0);
        const int x_hi = std::min(cx + 1, nx_ - 1);
        const int y_lo = std::max(cy - 1, 0);
        const int y_hi = std::min(cy + 1, ny_ - 1);
        const int z_lo = std::max(cz - 1, 0);
        const int z_hi = std::min(cz + 1, nz_ - 1);

        for (int qz = z_lo; qz <= z_hi; ++qz) {
            for (int qy = y_lo; qy <= y_hi; ++qy) {
                // Cells adjacent along x have consecutive ids, so their
                // particles form one contiguous run of sorted_.
                const int row = (qz * ny_ + qy) * nx_;
                const int first = start_[row + x_lo];
                const int last = start_[row + x_hi + 1];
                for (int k = first; k < last; ++k)
                    f(sorted_[k]);
            }
        }
    }

    /// @brief Call `f(i, j)` once for every unordered pair of particles that
    ///        share a cell or sit in adjacent cells.
    ///
    /// The order of the two arguments is unspecified (`i < j` is not
    /// guaranteed). No distance test is applied.
    template<typename F>
    void iterate_pairs(F&& f) const {
        // Half of the 26 neighbour offsets: all nine with dz = +1, then
        // (-1,1,0), (0,1,0), (1,1,0) and (1,0,0). No offset's negation is in
        // the list, so each pair of adjacent cells is visited from exactly
        // one of its two cells.
        static constexpr int kDx[13] = {-1, 0, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1, 1};
        static constexpr int kDy[13] = {-1, -1, -1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0};
        static constexpr int kDz[13] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0};

        for (int cz = 0; cz < nz_; ++cz) {
            for (int cy = 0; cy < ny_; ++cy) {
                for (int cx = 0; cx < nx_; ++cx) {
                    const int cid = (cz * ny_ + cy) * nx_ + cx;
                    const int beg = start_[cid];
                    const int end = start_[cid + 1];
                    if (beg == end)
                        continue;

                    // Pairs inside this cell.
                    for (int a = beg; a < end; ++a)
                        for (int b = a + 1; b < end; ++b)
                            f(sorted_[a], sorted_[b]);

                    // Pairs between this cell and its 13 forward neighbours.
                    for (int d = 0; d < 13; ++d) {
                        const int ncx = cx + kDx[d];
                        const int ncy = cy + kDy[d];
                        const int ncz = cz + kDz[d];
                        const bool inside = ncx >= 0 && ncx < nx_ && ncy >= 0
                                            && ncy < ny_ && ncz >= 0 && ncz < nz_;
                        if (!inside)
                            continue;

                        const int ncid = (ncz * ny_ + ncy) * nx_ + ncx;
                        const int nbeg = start_[ncid];
                        const int nend = start_[ncid + 1];
                        for (int a = beg; a < end; ++a)
                            for (int b = nbeg; b < nend; ++b)
                                f(sorted_[a], sorted_[b]);
                    }
                }
            }
        }
    }

    /// @brief Number of cells along x, padding layers included.
    int nx() const noexcept { return nx_; }
    /// @brief Number of cells along y, padding layers included.
    int ny() const noexcept { return ny_; }
    /// @brief Number of cells along z, padding layers included.
    int nz() const noexcept { return nz_; }
    /// @brief Number of particles stored by the most recent build().
    int n_particles() const noexcept { return static_cast<int>(sorted_.size()); }

  private:
    Scalar cs_ = 0, xmin_ = 0, ymin_ = 0, zmin_ = 0;
    int nx_ = 0, ny_ = 0, nz_ = 0;

    // sorted_: particle indices grouped by cell id.
    // start_:  start_[c] is the offset of cell c in sorted_ (size cells + 1).
    // count_:  per-cell particle counts; also the scatter cursor in build().
    std::vector<int> sorted_, start_, count_;

    /// Map one coordinate to a cell index in [0, n - 1].
    ///
    /// The clamp is done in floating point *before* converting to int:
    /// casting a NaN or a value outside the range of int is undefined
    /// behaviour, so far-away or invalid coordinates must never reach the
    /// cast. NaN maps to cell 0.
    int axis_cell(Scalar v, Scalar lo, int n) const noexcept {
        const auto f = std::floor((v - lo) / cs_);
        if (!(f >= -1))  // below the lower padding layer, or NaN
            return 0;
        if (f >= n - 1)  // at or beyond the upper padding layer
            return n - 1;
        return static_cast<int>(f) + 1;  // +1 skips the lower padding layer
    }

    int cell_x(Scalar x) const noexcept { return axis_cell(x, xmin_, nx_); }
    int cell_y(Scalar y) const noexcept { return axis_cell(y, ymin_, ny_); }
    int cell_z(Scalar z) const noexcept { return axis_cell(z, zmin_, nz_); }

    /// Flattened cell id of a position; x varies fastest, then y, then z.
    int cell_id_of(std::tuple<Scalar, Scalar, Scalar> p) const noexcept {
        const auto [x, y, z] = p;
        return (cell_z(z) * ny_ + cell_y(y)) * nx_ + cell_x(x);
    }
};
155
156} // namespace num
void iterate_pairs(F &&f) const
Visit each candidate pair once.
int n_particles() const noexcept
int nx() const noexcept
void build(PosAccessor &&get_pos, int n)
Rebuild by counting-sort over cell ids.
CellList3D(Scalar cell_size, Scalar xmin, Scalar xmax, Scalar ymin, Scalar ymax, Scalar zmin, Scalar zmax)
int nz() const noexcept
int ny() const noexcept
void query(Scalar px, Scalar py, Scalar pz, F &&f) const
Call f(j) for candidate particles near (px, py, pz).