Skip to content

Commit

Permalink
Add option to multi-thread crop_map and decay_map operations
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Nov 29, 2024
1 parent 6fb3d81 commit f46ea6c
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

#include <memory>
#include <string>
#include <utility>

#include <wavemap/core/config/config_base.h>
#include <wavemap/core/map/map_base.h>
#include <wavemap/core/utils/thread_pool.h>
#include <wavemap/core/utils/time/stopwatch.h>
#include <wavemap/pipeline/map_operations/map_operation_base.h>

Expand Down Expand Up @@ -48,6 +48,7 @@ class CropMapOperation : public MapOperationBase {
public:
CropMapOperation(const CropMapOperationConfig& config,
MapBase::Ptr occupancy_map,
std::shared_ptr<ThreadPool> thread_pool,
std::shared_ptr<TfTransformer> transformer,
std::string world_frame);

Expand All @@ -57,6 +58,7 @@ class CropMapOperation : public MapOperationBase {

private:
const CropMapOperationConfig config_;
const std::shared_ptr<ThreadPool> thread_pool_;
const std::shared_ptr<TfTransformer> transformer_;
const std::string world_frame_;
ros::Time last_run_timestamp_;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#ifndef WAVEMAP_ROS_MAP_OPERATIONS_DECAY_MAP_OPERATION_H_
#define WAVEMAP_ROS_MAP_OPERATIONS_DECAY_MAP_OPERATION_H_

#include <utility>
#include <memory>

#include <ros/ros.h>
#include <wavemap/core/config/config_base.h>
#include <wavemap/core/map/map_base.h>
#include <wavemap/core/utils/thread_pool.h>
#include <wavemap/core/utils/time/stopwatch.h>
#include <wavemap/pipeline/map_operations/map_operation_base.h>

Expand All @@ -29,16 +30,16 @@ struct DecayMapOperationConfig : public ConfigBase<DecayMapOperationConfig, 2> {
class DecayMapOperation : public MapOperationBase {
public:
DecayMapOperation(const DecayMapOperationConfig& config,
MapBase::Ptr occupancy_map)
: MapOperationBase(std::move(occupancy_map)),
config_(config.checkValid()) {}
MapBase::Ptr occupancy_map,
std::shared_ptr<ThreadPool> thread_pool);

bool shouldRun(const ros::Time& current_time = ros::Time::now());

void run(bool force_run) override;

private:
const DecayMapOperationConfig config_;
const std::shared_ptr<ThreadPool> thread_pool_;
ros::Time last_run_timestamp_;
Stopwatch timer_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ bool CropMapOperationConfig::isValid(bool verbose) const {

CropMapOperation::CropMapOperation(const CropMapOperationConfig& config,
MapBase::Ptr occupancy_map,
std::shared_ptr<ThreadPool> thread_pool,
std::shared_ptr<TfTransformer> transformer,
std::string world_frame)
: MapOperationBase(std::move(occupancy_map)),
config_(config.checkValid()),
thread_pool_(std::move(thread_pool)),
transformer_(std::move(transformer)),
world_frame_(std::move(world_frame)) {}

Expand Down Expand Up @@ -85,13 +87,14 @@ void CropMapOperation::run(bool force_run) {
dynamic_cast<HashedWaveletOctree*>(occupancy_map_.get());
hashed_wavelet_octree) {
crop_to_sphere(T_W_B->getPosition(), config_.radius, *hashed_wavelet_octree,
termination_height_);
termination_height_, thread_pool_);
} else if (auto* hashed_chunked_wavelet_octree =
dynamic_cast<HashedChunkedWaveletOctree*>(
occupancy_map_.get());
hashed_chunked_wavelet_octree) {
crop_to_sphere(T_W_B->getPosition(), config_.radius,
*hashed_chunked_wavelet_octree, termination_height_);
*hashed_chunked_wavelet_octree, termination_height_,
thread_pool_);
} else {
ROS_WARN(
"Map cropping is only supported for hash-based map data structures.");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "wavemap_ros/map_operations/decay_map_operation.h"

#include <memory>
#include <utility>

#include <wavemap/core/map/hashed_chunked_wavelet_octree.h>
#include <wavemap/core/map/hashed_wavelet_octree.h>

Expand All @@ -20,6 +23,13 @@ bool DecayMapOperationConfig::isValid(bool verbose) const {
return all_valid;
}

DecayMapOperation::DecayMapOperation(const DecayMapOperationConfig& config,
MapBase::Ptr occupancy_map,
std::shared_ptr<ThreadPool> thread_pool)
: MapOperationBase(std::move(occupancy_map)),
config_(config.checkValid()),
thread_pool_(std::move(thread_pool)) {}

bool DecayMapOperation::shouldRun(const ros::Time& current_time) {
return config_.once_every < (current_time - last_run_timestamp_).toSec();
}
Expand All @@ -41,12 +51,12 @@ void DecayMapOperation::run(bool force_run) {
if (auto* hashed_wavelet_octree =
dynamic_cast<HashedWaveletOctree*>(occupancy_map_.get());
hashed_wavelet_octree) {
multiply(*hashed_wavelet_octree, config_.decay_rate);
multiply(*hashed_wavelet_octree, config_.decay_rate, thread_pool_);
} else if (auto* hashed_chunked_wavelet_octree =
dynamic_cast<HashedChunkedWaveletOctree*>(
occupancy_map_.get());
hashed_chunked_wavelet_octree) {
multiply(*hashed_chunked_wavelet_octree, config_.decay_rate);
multiply(*hashed_chunked_wavelet_octree, config_.decay_rate, thread_pool_);
} else {
ROS_WARN("Map decay is only supported for hash-based map data structures.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ std::unique_ptr<MapOperationBase> MapRosOperationFactory::create(
case MapRosOperationType::kCropMap:
if (const auto config = CropMapOperationConfig::from(params); config) {
return std::make_unique<CropMapOperation>(
config.value(), std::move(occupancy_map), std::move(transformer),
std::move(world_frame));
config.value(), std::move(occupancy_map), std::move(thread_pool),
std::move(transformer), std::move(world_frame));
} else {
ROS_ERROR("Crop map operation config could not be loaded.");
return nullptr;
}
case MapRosOperationType::kDecayMap:
if (const auto config = DecayMapOperationConfig::from(params); config) {
return std::make_unique<DecayMapOperation>(config.value(),
std::move(occupancy_map));
return std::make_unique<DecayMapOperation>(
config.value(), std::move(occupancy_map), std::move(thread_pool));
} else {
ROS_ERROR("Decay map operation config could not be loaded.");
return nullptr;
Expand Down
36 changes: 31 additions & 5 deletions library/cpp/include/wavemap/core/utils/edit/crop.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#ifndef WAVEMAP_CORE_UTILS_EDIT_CROP_H_
#define WAVEMAP_CORE_UTILS_EDIT_CROP_H_

#include <memory>

#include "wavemap/core/common.h"
#include "wavemap/core/utils/thread_pool.h"

namespace wavemap {
template <typename MapType>
Expand Down Expand Up @@ -88,10 +91,13 @@ void cropNodeRecursive(typename MapType::Block::OctreeType::NodeRefType node,

template <typename MapType>
void crop_to_sphere(const Point3D& t_W_center, FloatingPoint radius,
MapType& map, IndexElement termination_height) {
MapType& map, IndexElement termination_height,
const std::shared_ptr<ThreadPool>& thread_pool = nullptr) {
using NodePtrType = typename MapType::Block::OctreeType::NodePtrType;
const IndexElement tree_height = map.getTreeHeight();
const FloatingPoint min_cell_width = map.getMinCellWidth();

// Check all blocks
for (auto it = map.getHashMap().begin(); it != map.getHashMap().end();) {
// Start by testing at the block level
const Index3D& block_index = it->first;
Expand All @@ -112,13 +118,33 @@ void crop_to_sphere(const Point3D& t_W_center, FloatingPoint radius,
// Since the block overlaps with the sphere's boundary, we need to process
// it at a higher resolution by recursing over its cells
auto& block = it->second;
cropNodeRecursive<MapType>(block.getRootNode(), block_node_index,
block.getRootScale(), t_W_center, radius,
min_cell_width, termination_height);
// Indicate that the block has changed
block.setLastUpdatedStamp();

// Get pointers to the root value and node, which contain the wavelet
// scale and detail coefficients, respectively
FloatingPoint* root_value_ptr = &block.getRootScale();
NodePtrType root_node_ptr = &block.getRootNode();
// Recursively crop all nodes
if (thread_pool) {
thread_pool->add_task([root_node_ptr, root_value_ptr, block_node_index,
t_W_center, radius, min_cell_width,
termination_height]() {
cropNodeRecursive<MapType>(*root_node_ptr, block_node_index,
*root_value_ptr, t_W_center, radius,
min_cell_width, termination_height);
});
} else {
cropNodeRecursive<MapType>(*root_node_ptr, block_node_index,
*root_value_ptr, t_W_center, radius,
min_cell_width, termination_height);
}
// Advance to the next block
++it;
}
// Wait for all parallel jobs to finish
if (thread_pool) {
thread_pool->wait_all();
}
}
} // namespace wavemap

Expand Down
32 changes: 27 additions & 5 deletions library/cpp/include/wavemap/core/utils/edit/multiply.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#ifndef WAVEMAP_CORE_UTILS_EDIT_MULTIPLY_H_
#define WAVEMAP_CORE_UTILS_EDIT_MULTIPLY_H_

#include <memory>

#include "wavemap/core/common.h"
#include "wavemap/core/utils/thread_pool.h"

namespace wavemap {

Expand All @@ -21,12 +24,31 @@ void multiplyNodeRecursive(
}

template <typename MapType>
void multiply(MapType& map, FloatingPoint multiplier) {
map.forEachBlock([multiplier](const Index3D& /*block_index*/, auto& block) {
block.getRootScale() *= multiplier;
multiplyNodeRecursive<MapType>(block.getRootNode(), multiplier);
void multiply(MapType& map, FloatingPoint multiplier,
const std::shared_ptr<ThreadPool>& thread_pool = nullptr) {
using NodePtrType = typename MapType::Block::OctreeType::NodePtrType;

// Process all blocks
for (auto& [block_index, block] : map.getHashMap()) {
// Indicate that the block has changed
block.setLastUpdatedStamp();
});
// Multiply the block's average value (wavelet scale coefficient)
FloatingPoint& root_value = block.getRootScale();
root_value *= multiplier;
// Recursively multiply all node values (wavelet detail coefficients)
NodePtrType root_node_ptr = &block.getRootNode();
if (thread_pool) {
thread_pool->add_task([root_node_ptr, multiplier]() {
multiplyNodeRecursive<MapType>(*root_node_ptr, multiplier);
});
} else {
multiplyNodeRecursive<MapType>(*root_node_ptr, multiplier);
}
}
// Wait for all parallel jobs to finish
if (thread_pool) {
thread_pool->wait_all();
}
}
} // namespace wavemap

Expand Down

0 comments on commit f46ea6c

Please sign in to comment.