implement barnes-hut particle repulsion using octree

This commit is contained in:
2026-02-22 23:29:56 +01:00
parent f07f2772c4
commit e43e505110
5 changed files with 302 additions and 8 deletions

View File

@ -5,6 +5,7 @@
#define PRINT_TIMINGS #define PRINT_TIMINGS
// #define WEB // #define WEB
#define BARNES_HUT
// Window // Window
constexpr int INITIAL_WIDTH = 800; constexpr int INITIAL_WIDTH = 800;
@ -35,8 +36,10 @@ constexpr float MASS = 1.0;
constexpr float SPRING_CONSTANT = 5.0; constexpr float SPRING_CONSTANT = 5.0;
constexpr float DAMPENING_CONSTANT = 1.0; constexpr float DAMPENING_CONSTANT = 1.0;
constexpr float REST_LENGTH = 2.0; constexpr float REST_LENGTH = 2.0;
constexpr float REPULSION_FORCE = 0.1; constexpr float REPULSION_FORCE = 2.0;
constexpr float REPULSION_RANGE = 5.0 * REST_LENGTH; constexpr float REPULSION_RANGE = 5.0 * REST_LENGTH;
constexpr float THETA = 1.0; // Barnes-Hut [0.5, ~]
constexpr float SOFTENING = 0.01; // Barnes-Hut [0.01, 1.0]
constexpr int REPULSION_GRID_REFRESH = 5; // Updates between grid rebuilds constexpr int REPULSION_GRID_REFRESH = 5; // Updates between grid rebuilds
constexpr float VERLET_DAMPENING = 0.05; // [0, 1] constexpr float VERLET_DAMPENING = 0.05; // [0, 1]

53
include/octree.hpp Normal file
View File

@ -0,0 +1,53 @@
#ifndef __OCTREE_HPP_
#define __OCTREE_HPP_
#include <raylib.h>
#include <raymath.h>
#include <vector>
class OctreeNode {
public:
Vector3 mass_center;
float mass_total;
Vector3 box_min; // area start
Vector3 box_max; // area end
int children[8];
int mass_id;
bool leaf;
public:
OctreeNode()
: mass_center(Vector3Zero()), mass_total(0.0),
children(-1, -1, -1, -1, -1, -1, -1, -1), mass_id(-1), leaf(true) {}
~OctreeNode() {}
};
class Octree {
public:
std::vector<OctreeNode> nodes;
public:
Octree() {}
Octree(const Octree &copy) = delete;
Octree &operator=(const Octree &copy) = delete;
Octree(Octree &&move) = delete;
Octree &operator=(Octree &&move) = delete;
~Octree() {}
public:
auto CreateNode(const Vector3 &box_min, const Vector3 &box_max) -> int;
auto GetOctant(int node_idx, const Vector3 &pos) -> int;
auto GetChildBounds(int node_idx, int octant) -> std::pair<Vector3, Vector3>;
auto Insert(int node_idx, int mass_id, const Vector3 &pos, float mass)
-> void;
auto CalculateForce(int node_idx, const Vector3 &pos) -> Vector3;
};
#endif

View File

@ -9,6 +9,10 @@
#include "config.hpp" #include "config.hpp"
#include "puzzle.hpp" #include "puzzle.hpp"
#ifdef BARNES_HUT
#include "octree.hpp"
#endif
class Mass { class Mass {
public: public:
const float mass; const float mass;
@ -53,22 +57,32 @@ public:
class MassSpringSystem { class MassSpringSystem {
private: private:
// Uniform grid
std::vector<Mass *> mass_pointers; std::vector<Mass *> mass_pointers;
#ifdef BARNES_HUT
// Barnes-Hut
Octree octree;
#else
// Uniform grid
std::vector<int> mass_indices; std::vector<int> mass_indices;
std::vector<int64_t> cell_ids; std::vector<int64_t> cell_ids;
int last_build; int last_build;
int last_masses_count; int last_masses_count;
int last_springs_count; int last_springs_count;
#endif
public: public:
// This is the main ownership of all the states/masses/springs. // This is the main ownership of all the states/masses/springs.
// Everything is stored multiple times but idc. // TODO: Everything is stored multiple times but idc (currently).
std::unordered_map<State, Mass> masses; std::unordered_map<State, Mass> masses;
std::unordered_map<std::pair<State, State>, Spring> springs; std::unordered_map<std::pair<State, State>, Spring> springs;
public: public:
MassSpringSystem() : last_build(REPULSION_GRID_REFRESH) {}; MassSpringSystem() {
#ifndef BARNES_HUT
last_build = REPULSION_GRID_REFRESH;
#endif
};
MassSpringSystem(const MassSpringSystem &copy) = delete; MassSpringSystem(const MassSpringSystem &copy) = delete;
MassSpringSystem &operator=(const MassSpringSystem &copy) = delete; MassSpringSystem &operator=(const MassSpringSystem &copy) = delete;
@ -78,7 +92,11 @@ public:
~MassSpringSystem() {}; ~MassSpringSystem() {};
private: private:
#ifdef BARNES_HUT
auto BuildOctree() -> void;
#else
auto BuildUniformGrid() -> void; auto BuildUniformGrid() -> void;
#endif
public: public:
auto AddMass(float mass, bool fixed, const State &state) -> void; auto AddMass(float mass, bool fixed, const State &state) -> void;
@ -100,7 +118,9 @@ public:
auto VerletUpdate(float delta_time) -> void; auto VerletUpdate(float delta_time) -> void;
#ifndef BARNES_HUT
auto InvalidateGrid() -> void; auto InvalidateGrid() -> void;
#endif
}; };
#endif #endif

148
src/octree.cpp Normal file
View File

@ -0,0 +1,148 @@
#include "octree.hpp"
#include "config.hpp"
#include <raymath.h>
auto Octree::CreateNode(const Vector3 &box_min, const Vector3 &box_max) -> int {
OctreeNode node;
node.box_min = box_min;
node.box_max = box_max;
nodes.push_back(node);
return nodes.size() - 1;
}
auto Octree::GetOctant(int node_idx, const Vector3 &pos) -> int {
OctreeNode &node = nodes[node_idx];
Vector3 center = Vector3((node.box_min.x + node.box_max.x) / 2.0,
(node.box_min.y + node.box_max.y) / 2.0,
(node.box_min.z + node.box_max.z) / 2.0);
// The octant is encoded as a 3-bit integer "zyx". The node area is split
// along all 3 axes, if a position is right of an axis, this bit is set to 1.
// If a position is right of the x-axis and y-axis and left of the z-axis, the
// encoded octant is "011".
int octant = 0;
if (pos.x >= center.x) {
octant |= 1;
}
if (pos.y >= center.y) {
octant |= 2;
}
if (pos.z >= center.z) {
octant |= 4;
}
return octant;
}
auto Octree::GetChildBounds(int node_idx, int octant)
-> std::pair<Vector3, Vector3> {
OctreeNode &node = nodes[node_idx];
Vector3 center = Vector3((node.box_min.x + node.box_max.x) / 2.0,
(node.box_min.y + node.box_max.y) / 2.0,
(node.box_min.z + node.box_max.z) / 2.0);
Vector3 min;
Vector3 max;
// If (octant & 1), the octant is to the right of the node region's x-axis
// (see GetOctant). This means the left bound is the x-axis and the right
// bound the node's region max.
min.x = (octant & 1) ? center.x : node.box_min.x;
max.x = (octant & 1) ? node.box_max.x : center.x;
min.y = (octant & 2) ? center.y : node.box_min.y;
max.y = (octant & 2) ? node.box_max.y : center.y;
min.z = (octant & 4) ? center.z : node.box_min.z;
max.z = (octant & 4) ? node.box_max.z : center.z;
return std::make_pair(min, max);
}
auto Octree::Insert(int node_idx, int mass_id, const Vector3 &pos, float mass)
-> void {
OctreeNode &node = nodes[node_idx];
if (node.leaf && node.mass_id == -1) {
// We can place the particle in the empty leaf
node.mass_id = mass_id;
node.mass_center = pos;
node.mass_total = mass;
return;
}
if (node.leaf) {
// The leaf is occupied, we need to subdivide
int existing_id = node.mass_id;
Vector3 existing_pos = node.mass_center;
float existing_mass = node.mass_total;
node.mass_id = -1;
node.leaf = false;
// Re-add the existing mass into a new empty leaf (see above)
int oct = GetOctant(node_idx, existing_pos);
if (node.children[oct] == -1) {
auto [min, max] = GetChildBounds(node_idx, oct);
node.children[oct] = CreateNode(min, max);
}
Insert(node.children[oct], existing_id, existing_pos, existing_mass);
}
// Insert the new mass
int oct = GetOctant(node_idx, pos);
if (nodes[node_idx].children[oct] == -1) {
auto [min, max] = GetChildBounds(node_idx, oct);
nodes[node_idx].children[oct] = CreateNode(min, max);
}
Insert(nodes[node_idx].children[oct], mass_id, pos, mass);
// Update the center of mass
node = nodes[node_idx];
float new_mass = node.mass_total + mass;
node.mass_center.x =
(node.mass_center.x * node.mass_total + pos.x) / new_mass;
node.mass_center.y =
(node.mass_center.y * node.mass_total + pos.y) / new_mass;
node.mass_center.z =
(node.mass_center.z * node.mass_total + pos.z) / new_mass;
node.mass_total = new_mass;
}
auto Octree::CalculateForce(int node_idx, const Vector3 &pos) -> Vector3 {
if (node_idx < 0) {
return Vector3Zero();
}
OctreeNode &node = nodes[node_idx];
if (node.mass_total == 0.0f) {
return Vector3Zero();
}
Vector3 diff = Vector3Subtract(pos, node.mass_center);
float dist_sq = diff.x * diff.x + diff.y * diff.y + diff.z * diff.z;
// Softening
dist_sq += SOFTENING;
float size = node.box_max.x - node.box_min.x;
// Barnes-Hut
if (node.leaf || (size * size / dist_sq) < (THETA * THETA)) {
float dist = std::sqrt(dist_sq);
float force_mag = REPULSION_FORCE * node.mass_total / dist_sq;
return Vector3Scale(diff, force_mag / dist);
}
// Collect child forces
Vector3 force = Vector3Zero();
for (int i = 0; i < 8; ++i) {
if (node.children[i] >= 0) {
Vector3 child_force = CalculateForce(node.children[i], pos);
force = Vector3Add(force, child_force);
}
}
return force;
}

View File

@ -2,13 +2,18 @@
#include "config.hpp" #include "config.hpp"
#include <algorithm> #include <algorithm>
#include <numeric> #include <cfloat>
#include <raylib.h> #include <raylib.h>
#include <raymath.h> #include <raymath.h>
#include <tracy/Tracy.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#ifndef BARNES_HUT
#include <numeric>
#endif
auto Mass::ClearForce() -> void { force = Vector3Zero(); } auto Mass::ClearForce() -> void { force = Vector3Zero(); }
auto Mass::CalculateVelocity(const float delta_time) -> void { auto Mass::CalculateVelocity(const float delta_time) -> void {
@ -121,7 +126,9 @@ auto MassSpringSystem::AddSpring(const State &massA, const State &massB,
auto MassSpringSystem::Clear() -> void { auto MassSpringSystem::Clear() -> void {
masses.clear(); masses.clear();
springs.clear(); springs.clear();
#ifndef BARNES_HUT
InvalidateGrid(); InvalidateGrid();
#endif
} }
auto MassSpringSystem::ClearForces() -> void { auto MassSpringSystem::ClearForces() -> void {
@ -136,8 +143,52 @@ auto MassSpringSystem::CalculateSpringForces() -> void {
} }
} }
#ifdef BARNES_HUT
auto MassSpringSystem::BuildOctree() -> void {
octree.nodes.clear();
octree.nodes.reserve(masses.size() * 2);
// Compute bounding box around all masses
Vector3 min = Vector3(FLT_MAX, FLT_MAX, FLT_MAX);
Vector3 max = Vector3(-FLT_MAX, -FLT_MAX, -FLT_MAX);
for (const auto &[state, mass] : masses) {
min.x = std::min(min.x, mass.position.x);
max.x = std::max(max.x, mass.position.x);
min.y = std::min(min.y, mass.position.y);
max.y = std::max(max.y, mass.position.y);
min.z = std::min(min.z, mass.position.z);
max.z = std::max(max.z, mass.position.z);
}
// Pad the bounding box
float pad = 1.0;
min = Vector3Subtract(min, Vector3Scale(Vector3One(), pad));
max = Vector3Add(max, Vector3Scale(Vector3One(), pad));
// Make it cubic (so subdivisions are balanced)
float max_extent = std::max({max.x - min.x, max.y - min.y, max.z - min.z});
max = Vector3Add(min, Vector3Scale(Vector3One(), max_extent));
// Root node spans the entire area
int root = octree.CreateNode(min, max);
// Use a vector of pointers to the masses, because we can't parallelize the
// range-based for loop over the masses unordered_map using OpenMP.
mass_pointers.clear();
mass_pointers.reserve(masses.size());
for (auto &[state, mass] : masses) {
mass_pointers.push_back(&mass);
}
for (int i = 0; i < mass_pointers.size(); ++i) {
octree.Insert(root, i, mass_pointers[i]->position, mass_pointers[i]->mass);
}
}
#else
auto MassSpringSystem::BuildUniformGrid() -> void { auto MassSpringSystem::BuildUniformGrid() -> void {
// Collect pointers to all masses // Use a vector of pointers to masses, because we can't parallelize the
// range-based for loop over the masses unordered_map using OpenMP.
mass_pointers.clear(); mass_pointers.clear();
mass_pointers.reserve(masses.size()); mass_pointers.reserve(masses.size());
for (auto &[state, mass] : masses) { for (auto &[state, mass] : masses) {
@ -173,8 +224,24 @@ auto MassSpringSystem::BuildUniformGrid() -> void {
cell_ids[i] = cell_id(mass_pointers[mass_indices[i]]->position); cell_ids[i] = cell_id(mass_pointers[mass_indices[i]]->position);
} }
} }
#endif
auto MassSpringSystem::CalculateRepulsionForces() -> void { auto MassSpringSystem::CalculateRepulsionForces() -> void {
ZoneScoped;
#ifdef BARNES_HUT
BuildOctree();
// Calculate forces using Barnes-Hut
#pragma omp parallel for schedule(dynamic, 256)
for (int i = 0; i < mass_pointers.size(); ++i) {
int root = 0;
Vector3 force = octree.CalculateForce(root, mass_pointers[i]->position);
mass_pointers[i]->force = Vector3Add(mass_pointers[i]->force, force);
}
#else
// Refresh grid if necessary // Refresh grid if necessary
if (last_build >= REPULSION_GRID_REFRESH || if (last_build >= REPULSION_GRID_REFRESH ||
masses.size() != last_masses_count || masses.size() != last_masses_count ||
@ -186,8 +253,8 @@ auto MassSpringSystem::CalculateRepulsionForces() -> void {
} }
last_build++; last_build++;
// TODO: Use Barnes-Hut + Octree // Calculate forces using uniform grid
#pragma omp parallel for #pragma omp parallel for schedule(dynamic, 256)
// Search the neighboring cells for each mass to calculate repulsion forces // Search the neighboring cells for each mass to calculate repulsion forces
for (int i = 0; i < masses.size(); ++i) { for (int i = 0; i < masses.size(); ++i) {
Mass *mass = mass_pointers[mass_indices[i]]; Mass *mass = mass_pointers[mass_indices[i]];
@ -241,6 +308,7 @@ auto MassSpringSystem::CalculateRepulsionForces() -> void {
mass->force = Vector3Add(mass->force, force); mass->force = Vector3Add(mass->force, force);
} }
#endif
} }
auto MassSpringSystem::VerletUpdate(float delta_time) -> void { auto MassSpringSystem::VerletUpdate(float delta_time) -> void {
@ -249,6 +317,7 @@ auto MassSpringSystem::VerletUpdate(float delta_time) -> void {
} }
} }
#ifndef BARNES_HUT
auto MassSpringSystem::InvalidateGrid() -> void { auto MassSpringSystem::InvalidateGrid() -> void {
mass_pointers.clear(); mass_pointers.clear();
mass_indices.clear(); mass_indices.clear();
@ -257,3 +326,4 @@ auto MassSpringSystem::InvalidateGrid() -> void {
last_masses_count = 0; last_masses_count = 0;
last_springs_count = 0; last_springs_count = 0;
} }
#endif