implement barnes-hut particle repulsion using octree

This commit is contained in:
2026-02-22 23:29:56 +01:00
parent f07f2772c4
commit e43e505110
5 changed files with 302 additions and 8 deletions

View File

@ -2,13 +2,18 @@
#include "config.hpp"
#include <algorithm>
#include <numeric>
#include <cfloat>
#include <raylib.h>
#include <raymath.h>
#include <tracy/Tracy.hpp>
#include <unordered_map>
#include <utility>
#include <vector>
#ifndef BARNES_HUT
#include <numeric>
#endif
auto Mass::ClearForce() -> void { force = Vector3Zero(); }
auto Mass::CalculateVelocity(const float delta_time) -> void {
@ -121,7 +126,9 @@ auto MassSpringSystem::AddSpring(const State &massA, const State &massB,
auto MassSpringSystem::Clear() -> void {
masses.clear();
springs.clear();
#ifndef BARNES_HUT
InvalidateGrid();
#endif
}
auto MassSpringSystem::ClearForces() -> void {
@ -136,8 +143,52 @@ auto MassSpringSystem::CalculateSpringForces() -> void {
}
}
#ifdef BARNES_HUT
auto MassSpringSystem::BuildOctree() -> void {
octree.nodes.clear();
octree.nodes.reserve(masses.size() * 2);
// Compute bounding box around all masses
Vector3 min = Vector3(FLT_MAX, FLT_MAX, FLT_MAX);
Vector3 max = Vector3(-FLT_MAX, -FLT_MAX, -FLT_MAX);
for (const auto &[state, mass] : masses) {
min.x = std::min(min.x, mass.position.x);
max.x = std::max(max.x, mass.position.x);
min.y = std::min(min.y, mass.position.y);
max.y = std::max(max.y, mass.position.y);
min.z = std::min(min.z, mass.position.z);
max.z = std::max(max.z, mass.position.z);
}
// Pad the bounding box
float pad = 1.0;
min = Vector3Subtract(min, Vector3Scale(Vector3One(), pad));
max = Vector3Add(max, Vector3Scale(Vector3One(), pad));
// Make it cubic (so subdivisions are balanced)
float max_extent = std::max({max.x - min.x, max.y - min.y, max.z - min.z});
max = Vector3Add(min, Vector3Scale(Vector3One(), max_extent));
// Root node spans the entire area
int root = octree.CreateNode(min, max);
// Use a vector of pointers to the masses, because we can't parallelize the
// range-based for loop over the masses unordered_map using OpenMP.
mass_pointers.clear();
mass_pointers.reserve(masses.size());
for (auto &[state, mass] : masses) {
mass_pointers.push_back(&mass);
}
for (int i = 0; i < mass_pointers.size(); ++i) {
octree.Insert(root, i, mass_pointers[i]->position, mass_pointers[i]->mass);
}
}
#else
auto MassSpringSystem::BuildUniformGrid() -> void {
// Collect pointers to all masses
// Use a vector of pointers to masses, because we can't parallelize the
// range-based for loop over the masses unordered_map using OpenMP.
mass_pointers.clear();
mass_pointers.reserve(masses.size());
for (auto &[state, mass] : masses) {
@ -173,8 +224,24 @@ auto MassSpringSystem::BuildUniformGrid() -> void {
cell_ids[i] = cell_id(mass_pointers[mass_indices[i]]->position);
}
}
#endif
auto MassSpringSystem::CalculateRepulsionForces() -> void {
ZoneScoped;
#ifdef BARNES_HUT
BuildOctree();
// Calculate forces using Barnes-Hut
#pragma omp parallel for schedule(dynamic, 256)
for (int i = 0; i < mass_pointers.size(); ++i) {
int root = 0;
Vector3 force = octree.CalculateForce(root, mass_pointers[i]->position);
mass_pointers[i]->force = Vector3Add(mass_pointers[i]->force, force);
}
#else
// Refresh grid if necessary
if (last_build >= REPULSION_GRID_REFRESH ||
masses.size() != last_masses_count ||
@ -186,8 +253,8 @@ auto MassSpringSystem::CalculateRepulsionForces() -> void {
}
last_build++;
// TODO: Use Barnes-Hut + Octree
#pragma omp parallel for
// Calculate forces using uniform grid
#pragma omp parallel for schedule(dynamic, 256)
// Search the neighboring cells for each mass to calculate repulsion forces
for (int i = 0; i < masses.size(); ++i) {
Mass *mass = mass_pointers[mass_indices[i]];
@ -241,6 +308,7 @@ auto MassSpringSystem::CalculateRepulsionForces() -> void {
mass->force = Vector3Add(mass->force, force);
}
#endif
}
auto MassSpringSystem::VerletUpdate(float delta_time) -> void {
@ -249,6 +317,7 @@ auto MassSpringSystem::VerletUpdate(float delta_time) -> void {
}
}
#ifndef BARNES_HUT
auto MassSpringSystem::InvalidateGrid() -> void {
mass_pointers.clear();
mass_indices.clear();
@ -257,3 +326,4 @@ auto MassSpringSystem::InvalidateGrid() -> void {
last_masses_count = 0;
last_springs_count = 0;
}
#endif