2 Commits
ros2 ... super

22 changed files with 7289 additions and 257 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
Log/
build/
install/
log/

View File

@@ -1,11 +1,7 @@
cmake_minimum_required(VERSION 3.5)
project(point_lio)
SET(CMAKE_BUILD_TYPE "Debug")
ADD_COMPILE_OPTIONS(-std=c++17 )
ADD_COMPILE_OPTIONS(-std=c++17 )
set( CMAKE_CXX_FLAGS "-std=c++17 -O3" )
SET(CMAKE_BUILD_TYPE "Release")
add_definitions(-DROOT_DIR=\"${CMAKE_CURRENT_SOURCE_DIR}/\")
@@ -13,13 +9,13 @@ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fexceptions" )
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -pthread -std=c++0x -std=c++17 -fexceptions")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -march=native -pthread -fexceptions")
#add_compile_definitions(BOOST_BIND_GLOBAL_PLACEHOLDERS)
find_package(OpenMP QUIET)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
# find_package(OpenMP QUIET)
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
find_package(PythonLibs REQUIRED)
find_path(MATPLOTLIB_CPP_INCLUDE_DIRS "matplotlibcpp.h")
@@ -41,6 +37,7 @@ find_package(visualization_msgs REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(PCL REQUIRED)
find_package(TBB REQUIRED)
message(Eigen: ${EIGEN3_INCLUDE_DIR})
@@ -71,7 +68,7 @@ ament_target_dependencies(pointlio_mapping
livox_ros_driver2
visualization_msgs
)
target_link_libraries(pointlio_mapping ${PYTHON_LIBRARIES})
target_link_libraries(pointlio_mapping ${PYTHON_LIBRARIES} TBB::tbb)
target_include_directories(pointlio_mapping PRIVATE ${PYTHON_INCLUDE_DIRS})
# Install the executable

View File

View File

View File

View File

@@ -4,12 +4,12 @@
prop_at_freq_of_imu: True
check_satu: True
init_map_size: 1000
point_filter_num: 3 # Options: 1, 3
point_filter_num: 1 # Options: 1, 3
space_down_sample: True
filter_size_surf: 0.5 # Options: 0.5, 0.3, 0.2, 0.15, 0.1
filter_size_map: 0.5 # Options: 0.5, 0.3, 0.15, 0.1
cube_side_length: 1000.0 # Option: 1000
runtime_pos_log_enable: false # Option: True
runtime_pos_log_enable: true # Option: True
common:
lid_topic: "/livox/lidar"
@@ -45,7 +45,7 @@
acc_cov_input: 0.1 # for IMU as input model
plane_thr: 0.1 # 0.05, the threshold for plane criteria, the smaller, the flatter a plane
match_s: 81.0
ivox_grid_resolution: 2.0
ivox_grid_resolution: 0.5
gravity: [0.0, 0.0, -9.810] # [0.0, 9.810, 0.0] # # [0.0, 0.0, -9.787561] # gvins #
gravity_init: [0.0, 0.0, -9.810] # preknown gravity in the initial IMU frame for unstationary start or in the initial LiDAR frame for using without IMU
extrinsic_T: [ 0.04165, 0.02326, -0.0284 ] # avia # [0.011, 0.02329, -0.04412] # mid360
@@ -54,7 +54,7 @@
0.0, 0.0, 1.0 ]
odometry:
publish_odometry_without_downsample: false
publish_odometry_without_downsample: true
publish:
path_en: true # false: close the path output

View File

@@ -0,0 +1,193 @@
#pragma once
#include <array>
#include <vector>
#include <Eigen/Core>
/// The last number is useless
// Per-group distance bounds used by OctVoxMap::getTopK's early exit: once 5
// neighbours are found whose worst squared distance is below
// orders_min_dis2[group], later groups cannot improve the result.
// NOTE(review): these look like absolute values precomputed for a 0.5 m
// coarse voxel (0.25 m sub-voxel) grid -- confirm they are rescaled if the
// map resolution changes.
static const std::array<float, 6> orders_min_dis = {
0.250000, 0.353553, 0.500000, 0.559017, 0.612372, 10
};
// Squares of the entries above (0.25^2 = 0.0625, ...); 100 is the sentinel.
static const std::array<float, 6> orders_min_dis2 = {
0.062500, 0.125000, 0.250000, 0.312500, 0.375000, 100
};
// 60 coarse-voxel offsets around the query voxel, ordered by (roughly)
// increasing distance. The table is expressed for a query point lying in the
// (-,-,-) octant of its voxel; getTopK mirrors each offset per axis
// (mirror_axis) so one table serves all eight octants.
alignas(16) static const std::array<Eigen::Vector3i, 60> HKNN_neighbor_voxel =
{
Eigen::Vector3i(0, 0, 0), Eigen::Vector3i(0, -1, 0), Eigen::Vector3i(0, 0, -1), Eigen::Vector3i(-1, 0, 0), Eigen::Vector3i(-1, 0, -1), Eigen::Vector3i(0, -1, -1),
Eigen::Vector3i(-1, -1, 0), Eigen::Vector3i(-1, -1, -1), Eigen::Vector3i(1, 0, 0), Eigen::Vector3i(0, 1, 0), Eigen::Vector3i(0, 0, 1), Eigen::Vector3i(0, -1, 1),
Eigen::Vector3i(1, -1, 0), Eigen::Vector3i(1, 0, -1), Eigen::Vector3i(0, 1, -1), Eigen::Vector3i(-1, 1, 0), Eigen::Vector3i(-1, 0, 1), Eigen::Vector3i(-1, -1, 1),
Eigen::Vector3i(-1, 1, -1), Eigen::Vector3i(1, -1, -1), Eigen::Vector3i(0, 0, -2), Eigen::Vector3i(0, 1, 1), Eigen::Vector3i(1, 1, 0), Eigen::Vector3i(1, 0, 1),
Eigen::Vector3i(-2, 0, 0), Eigen::Vector3i(0, -2, 0), Eigen::Vector3i(-2, 0, -1), Eigen::Vector3i(-1, 0, -2), Eigen::Vector3i(-1, -2, 0), Eigen::Vector3i(-2, -1, 0),
Eigen::Vector3i(0, -2, -1), Eigen::Vector3i(1, 1, -1), Eigen::Vector3i(-1, 1, 1), Eigen::Vector3i(1, -1, 1), Eigen::Vector3i(0, -1, -2), Eigen::Vector3i(-2, -1, -1),
Eigen::Vector3i(-1, -1, -2), Eigen::Vector3i(-1, -2, -1), Eigen::Vector3i(-2, 1, 0), Eigen::Vector3i(1, -2, 0), Eigen::Vector3i(0, -2, 1), Eigen::Vector3i(0, 1, -2),
Eigen::Vector3i(1, 0, -2), Eigen::Vector3i(1, 1, 1), Eigen::Vector3i(-2, 0, 1), Eigen::Vector3i(1, -1, -2), Eigen::Vector3i(1, -2, -1), Eigen::Vector3i(-2, 1, -1),
Eigen::Vector3i(-2, -1, 1), Eigen::Vector3i(-1, 1, -2), Eigen::Vector3i(-1, -2, 1), Eigen::Vector3i(0, -2, -2), Eigen::Vector3i(-2, 0, -2), Eigen::Vector3i(-2, 1, 1),
Eigen::Vector3i(1, 1, -2), Eigen::Vector3i(1, -2, 1), Eigen::Vector3i(-2, -2, 0), Eigen::Vector3i(-2, -1, -2), Eigen::Vector3i(-2, -2, -1), Eigen::Vector3i(-1, -2, -2),
};
// Start offset of each distance group inside flat_search_order; the final
// entry (593) is the total length, so group g spans
// [flat_search_order_offsets[g], flat_search_order_offsets[g+1]).
static constexpr std::array<uint16_t, 7> flat_search_order_offsets = {{
0, 43, 135, 219, 321, 465, 593,
}};
// Flattened sub-voxel visitation schedule consumed by OctVoxMap::getTopK.
// It is a sequence of variable-length records:
//   [neighbor_idx, n, code_0, ..., code_{n-1}]
// where neighbor_idx indexes HKNN_neighbor_voxel, n is the number of
// sub-voxel slots to test inside that voxel, and each code is a 3-bit octant
// index. getTopK XORs every code with the query point's own octant before
// use -- the sub-voxel-level counterpart of the per-axis mirroring applied
// to HKNN_neighbor_voxel. Records are concatenated per distance group;
// group boundaries come from flat_search_order_offsets above.
static constexpr std::array<uint8_t, 593> flat_search_order = {{
// Group0
0, 8, 0, 1, 2, 3, 4, 5, 6, 7,
1, 4, 2, 3, 6, 7,
2, 4, 4, 5, 6, 7,
3, 4, 1, 3, 5, 7,
4, 2, 5, 7,
5, 2, 6, 7,
6, 2, 3, 7,
7, 1, 7,
// Group1
1, 4, 0, 1, 4, 5,
2, 4, 0, 1, 2, 3,
3, 4, 0, 2, 4, 6,
4, 4, 1, 3, 4, 6,
5, 4, 2, 3, 4, 5,
6, 4, 1, 2, 5, 6,
7, 3, 3, 5, 6,
8, 4, 0, 2, 4, 6,
9, 4, 0, 1, 4, 5,
10, 4, 0, 1, 2, 3,
11, 2, 2, 3,
12, 2, 2, 6,
13, 2, 4, 6,
14, 2, 4, 5,
15, 2, 1, 5,
16, 2, 1, 3,
17, 1, 3,
18, 1, 5,
19, 1, 6,
// Group2
4, 2, 0, 2,
5, 2, 0, 1,
6, 2, 0, 4,
7, 4, 0, 1, 2, 4,
11, 2, 0, 1,
12, 2, 0, 4,
13, 2, 0, 2,
14, 2, 0, 1,
15, 2, 0, 4,
16, 2, 0, 2,
17, 3, 0, 1, 2,
18, 3, 0, 1, 4,
19, 3, 0, 2, 4,
21, 2, 0, 1,
22, 2, 0, 4,
23, 2, 0, 2,
31, 2, 0, 4,
32, 2, 0, 1,
33, 2, 0, 2,
43, 1, 0,
// Group3
8, 4, 1, 3, 5, 7,
9, 4, 2, 3, 6, 7,
10, 4, 4, 5, 6, 7,
11, 2, 6, 7,
12, 2, 3, 7,
13, 2, 5, 7,
14, 2, 6, 7,
15, 2, 3, 7,
16, 2, 5, 7,
17, 1, 7,
18, 1, 7,
19, 1, 7,
20, 4, 4, 5, 6, 7,
24, 4, 1, 3, 5, 7,
25, 4, 2, 3, 6, 7,
26, 2, 5, 7,
27, 2, 5, 7,
28, 2, 3, 7,
29, 2, 3, 7,
30, 2, 6, 7,
34, 2, 6, 7,
35, 1, 7,
36, 1, 7,
37, 1, 7,
// Group4
11, 2, 4, 5,
12, 2, 1, 5,
13, 2, 1, 3,
14, 2, 2, 3,
15, 2, 2, 6,
16, 2, 4, 6,
17, 2, 5, 6,
18, 2, 3, 6,
19, 2, 3, 5,
21, 4, 2, 3, 4, 5,
22, 4, 1, 2, 5, 6,
23, 4, 1, 3, 4, 6,
26, 2, 1, 3,
27, 2, 4, 6,
28, 2, 2, 6,
29, 2, 1, 5,
30, 2, 2, 3,
31, 2, 5, 6,
32, 2, 3, 5,
33, 2, 3, 6,
34, 2, 4, 5,
35, 2, 3, 5,
36, 2, 5, 6,
37, 2, 3, 6,
38, 2, 1, 5,
39, 2, 2, 6,
40, 2, 2, 3,
41, 2, 4, 5,
42, 2, 4, 6,
44, 2, 1, 3,
45, 1, 6,
46, 1, 6,
47, 1, 5,
48, 1, 3,
49, 1, 5,
50, 1, 3,
// Group5
17, 1, 4,
18, 1, 2,
19, 1, 1,
21, 2, 6, 7,
22, 2, 3, 7,
23, 2, 5, 7,
31, 3, 1, 2, 7,
32, 3, 2, 4, 7,
33, 3, 1, 4, 7,
35, 1, 1,
36, 1, 4,
37, 1, 2,
38, 2, 3, 7,
39, 2, 3, 7,
40, 2, 6, 7,
41, 2, 6, 7,
42, 2, 5, 7,
43, 3, 1, 2, 4,
44, 2, 5, 7,
45, 2, 4, 7,
46, 2, 2, 7,
47, 2, 1, 7,
48, 2, 1, 7,
49, 2, 4, 7,
50, 2, 2, 7,
51, 2, 6, 7,
52, 2, 5, 7,
53, 1, 1,
54, 1, 4,
55, 1, 2,
56, 2, 3, 7,
57, 1, 7,
58, 1, 7,
59, 1, 7,
}};

View File

@@ -0,0 +1,569 @@
#ifndef OctVoxMap_HPP_
#define OctVoxMap_HPP_
#include <set>
#include <list>
#include <queue>
#include <vector>
#include <memory>
#include <cstring>
#include <iostream>
#include <execution>
#include <filesystem>
#include <unordered_map>
#include <unordered_set>
#include <Eigen/Core>
#include "tsl/robin_map.h"
#include "HKNN_list60_gem.h"
namespace LI2Sup{
// Fixed-capacity container of the K nearest candidates seen so far.
// Tracks the index of the current worst (largest-dist2) entry so a full
// candidate set can be updated in O(1), with an O(K) rescan only when a
// replacement actually happens.
template<int K, typename Point>
class KNNHeap {
public:
    KNNHeap() : count(0), worst_(0), max_dist2_(0.0f) {
        memset(dist2_, 0, sizeof(dist2_));
    }
    // Return to the empty state (reused between queries).
    void reset() {
        count = 0;
        worst_ = 0;
        max_dist2_ = 0.0f;
        memset(dist2_, 0, sizeof(dist2_));
    }
    uint8_t count;        // number of valid entries, in [0, K]
    uint8_t worst_;       // index of the largest-dist2 entry (meaningful once full)
    float max_dist2_;     // dist2 of the worst entry
    float dist2_[K];
    std::array<Point, K> points_;
    // Insert candidate `pt` at squared distance `dist2` when the set is not
    // yet full, or when it beats the current worst entry.
    inline void try_insert(float dist2, const Point& pt) {
        const bool not_full = (count < K);
        if (!not_full && dist2 >= max_dist2_) return;  // full and not better
        const uint8_t insert_idx = not_full ? count : worst_;
        dist2_[insert_idx] = dist2;
        points_[insert_idx] = pt;
        if (not_full) {
            count++;
            if (dist2 > max_dist2_) {
                max_dist2_ = dist2;
                worst_ = insert_idx;
            }
        } else {
            update_worst();
        }
    }
private:
    // Rescan all entries for the new worst one.
    // BUGFIX: the previous "unrolled" version unconditionally read
    // dist2_[0..4], so it was only correct for K == 5 (out-of-bounds read for
    // K < 5, wrong worst index for K > 5). This scan works for any K and,
    // like the old pairwise reduction, keeps the LAST index among tied maxima
    // (hence >=), so K == 5 behavior is unchanged.
    inline void update_worst() {
        uint8_t idx = 0;
        float best = dist2_[0];
        for (uint8_t i = 1; i < K; ++i) {
            if (dist2_[i] >= best) {
                best = dist2_[i];
                idx = i;
            }
        }
        worst_ = idx;
        max_dist2_ = best;
    }
public:
    inline float max_dist2() const { return max_dist2_; }
};
// One coarse voxel split into 8 sub-voxels (octants). Each octant keeps a
// single representative point, maintained as the running mean of the samples
// merged into it; a count of UNINIT_MASK (0) marks an empty octant.
template<typename Point>
class OctVox{
public:
    // Create the voxel with its first sample placed in octant `local_idx`.
    OctVox(const Point& pt, uint8_t local_idx)
    {
        counts_.fill(UNINIT_MASK);
        points_[local_idx] = pt;
        counts_[local_idx] = 1;
    }
    ~OctVox() {}
    // Merge `pt` into octant `local_idx`: initialise an empty slot, otherwise
    // fold the sample into the running mean. Samples are dropped once the
    // octant is saturated or when they lie too far from the stored mean.
    void AddPoint(const Point& pt, uint8_t local_idx) {
        uint8_t& n = counts_[local_idx];
        Point& mean = points_[local_idx];
        if (n == UNINIT_MASK) {
            mean = pt;
            n = 1;
        } else if (n < MAX_POINTS_PER_SUBVOXEL) {
            if ((pt - mean).squaredNorm() > DISTANCE_THRESHOLD_SQ) return;
            mean = (mean * n + pt) / (n + 1);
            ++n;
        }
    }
    // Copy octant `local_idx`'s representative into `pt`; false if empty.
    bool getPoint(const uint8_t local_idx, Point& pt) const {
        if (counts_[local_idx] == UNINIT_MASK) return false;
        pt = points_[local_idx];
        return true;
    }
    // Pointer to octant `local_idx`'s representative, or nullptr if empty.
    const Point* getPointPtr(const uint8_t local_idx) const {
        return counts_[local_idx] == UNINIT_MASK ? nullptr : &points_[local_idx];
    }
    static constexpr uint8_t UNINIT_MASK = 0x00;
    static constexpr uint8_t MAX_POINTS_PER_SUBVOXEL = 20;
    static constexpr double DISTANCE_THRESHOLD_SQ = 0.1 * 0.1;
    std::array<uint8_t, 8> counts_;
    std::array<Point, 8> points_;
};
// OctVoxMap: a bounded, hash-indexed voxel map for incremental mapping.
// Space is cut into coarse voxels (edge `resolution_`), each split into 8
// sub-voxels (OctVox). Voxels live in a std::list kept in LRU order and are
// indexed by a tsl::robin_map; when `capacity_` is exceeded the least
// recently touched voxel is evicted. KNN queries walk the precomputed
// neighbor tables from HKNN_list60_gem.h.
// NOTE(review): the `Scalar` template parameter is unused in this class.
template<typename Point, typename Scalar>
class OctVoxMap {
public:
using Ptr = std::shared_ptr<OctVoxMap>;
using KEY = Eigen::Vector3i;
using Points = std::vector<Point, Eigen::aligned_allocator<Point>>;
// NOTE(review): K is hard-coded to 5 here AND in getTopK's early-exit tests.
using KNNHeapType = KNNHeap<5, Point>;
using OctVoxType = OctVox<Point>;
struct Options {
float resolution = 0.5;
std::size_t capacity = 1000000;
// NOTE(review): identifiers beginning with a double underscore are reserved
// to the implementation; consider renaming these parameters.
Options(float __resolution, std::size_t __capacity) {
resolution = __resolution;
capacity = __capacity;
}
};
// Default construction keeps the member defaults (no banner printed).
OctVoxMap() {
flat_search_ptrs_.reserve(flat_search_order_offsets.size());
for(std::size_t i = 0; i < flat_search_order_offsets.size(); i++){
uint16_t start = flat_search_order_offsets[i];
// NOTE(review): const_cast on a constexpr table -- safe only because the
// pointers are never written through; prefer storing const uint8_t*.
flat_search_ptrs_.push_back(const_cast<uint8_t*>(flat_search_order.data() + start));
}
// Search every distance group by default (offsets carries one sentinel).
group_idx_max_ = flat_search_order_offsets.size() - 1;
}
~OctVoxMap() {
grids_.clear();
data_.clear();
}
OctVoxMap(Options options){
SetOptions(options);
std::cout << " ---> OctVoxMap init. Resolution: " << resolution_
<< " Capacity: " << capacity_ << std::endl;
flat_search_ptrs_.reserve(flat_search_order_offsets.size());
for(std::size_t i = 0; i < flat_search_order_offsets.size(); i++){
uint16_t start = flat_search_order_offsets[i];
flat_search_ptrs_.push_back(const_cast<uint8_t*>(flat_search_order.data() + start));
}
group_idx_max_ = flat_search_order_offsets.size() - 1;
}
// Derive the cached coarse/sub-voxel sizes and their reciprocals.
void SetOptions(const Options& options)
{
resolution_ = options.resolution;
capacity_ = options.capacity;
inv_resolution_ = 1.0 / resolution_;
sub_resolution_ = resolution_ / 2.0;
sub_inv_resolution_ = 1.0 / sub_resolution_;
}
void insert(const Points& cloud_world);   // add world-frame points, LRU-evicting
void printInfo() const;                   // log current size/capacity
void getMap(std::vector<float>&) const;   // flatten to x,y,z float triplets
void saveMap() const; // TODO:
void resetMap(const std::vector<float>&); // rebuild from a flat x,y,z buffer
void getTopK(const Point& point, KNNHeapType& top_K) const;
// Per-query memoisation for getTopK: while consecutive queries fall into the
// same fine (sub-voxel) cell, already-resolved neighbor-voxel pointers are
// reused. Bit i of lookup_mask marks voxel_ptrs[i] as resolved (may be null).
struct SearchCache {
KEY last_fine_key = KEY::Zero();
bool valid = false;
std::array<OctVoxType*, 60> voxel_ptrs;
uint64_t lookup_mask = 0;
void reset() {
valid = false;
lookup_mask = 0;
}
};
void getTopK(const Point& point, KNNHeapType& top_K, SearchCache* cache) const;
void getTopK_VN(const Point& point, KNNHeapType& top_K) const;  // simple 19-voxel baseline
// Restore the full search depth (all distance groups).
void reset_max_group(){
group_idx_max_ = flat_search_order_offsets.size() - 1;
}
// Trade recall for speed by searching one fewer distance group (min 4).
void decrease_max_group(){
if(group_idx_max_ > 4) group_idx_max_--;
}
// size_t getMemoryUsageBytes() const {
// size_t bytes = 0;
// bytes += sizeof(*this);
// bytes += data_.size() * (sizeof(KEY) + sizeof(OctVoxType)
// + sizeof(void*) * 2); // list node pointers
// bytes += grids_.size() * (sizeof(KEY) + sizeof(DATA_ITER)
// + sizeof(size_t)); // hash & pair overhead
// bytes += grids_.bucket_count() * sizeof(void*); // bucket array
// bytes += flat_search_ptrs_.capacity() * sizeof(uint8_t*);
// return bytes;
// }
private:
float resolution_ = 0.5;       // coarse voxel edge length
// NOTE(review): default inconsistent with resolution_ (1/0.5 = 2, not 1);
// harmless when SetOptions is called, wrong if the default ctor is used as-is.
float inv_resolution_ = 1.0;
float sub_resolution_ = 0.25;
float sub_inv_resolution_ = 4.0;
std::size_t capacity_ = 1000000;
bool reset_map_ = false;       // true while insert() is suspended after resetMap()
int reset_map_count_ = 0;      // remaining insert() calls to skip
// Neighborhood used only by getTopK_VN (center + 18 near neighbors).
const KEY nearby_grids_[19] = {
KEY(0, 0, 0),
KEY(-1, -1, 0), KEY(-1, 0, 0), KEY(-1, 1, 0),
KEY(0, -1, 0), KEY(0, 1, 0),
KEY(1, -1, 0), KEY(1, 0, 0), KEY(1, 1, 0),
KEY(0, 0, -1), KEY(1, 0, -1), KEY(-1, 0, -1),
KEY(0, 1, -1), KEY(0, -1, -1),
KEY(0, 0, 1), KEY(1, 0, 1), KEY(-1, 0, 1),
KEY(0, 1, 1), KEY(0, -1, 1)
};
/// HashShiftMix
// Mixes the three voxel coordinates into one hash for the robin_map.
struct HASH_VEC {
std::size_t operator()(const KEY &v) const {
size_t h = static_cast<size_t>(v[0]);
h ^= v[1] * 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= v[2] * 0x85ebca6b + (h << 6) + (h >> 2);
return h;
}
};
// LRU list of (coarse key, voxel); most recently used node at the front.
using DATA_LIST = std::list<std::pair<KEY, OctVoxType>>;
using DATA_ITER = typename DATA_LIST::iterator;
DATA_LIST data_;
tsl::robin_map<KEY, DATA_ITER, HASH_VEC> grids_;  // coarse key -> list node
std::vector<uint8_t*> flat_search_ptrs_;          // per-group cursors into flat_search_order
int group_idx_max_;                               // number of distance groups getTopK visits
};
// Insert a world-frame cloud. Each point is merged into the octant of its
// coarse voxel; voxels are kept in LRU order and the least recently touched
// one is evicted once `capacity_` is reached.
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::insert(const Points& cloud_world){
// After resetMap(), skip the next few insert() calls (see resetMap).
if(reset_map_){
reset_map_count_--;
if(reset_map_count_ > 0){
std::cout << "OctVoxMap::insert skip: reset_map_count_ = " << reset_map_count_ << std::endl;
return;
}
reset_map_ = false;
}
for(auto& pt : cloud_world){
// Sub-voxel index of the point; >>1 yields the coarse key (arithmetic
// shift = floor division, also for negatives), the low bits the octant.
KEY fine_key = (pt * sub_inv_resolution_).array().floor().template cast<int>();
KEY key;
key[0] = fine_key[0] >> 1;
key[1] = fine_key[1] >> 1;
key[2] = fine_key[2] >> 1;
uint8_t dx = fine_key[0] & 1;
uint8_t dy = fine_key[1] & 1;
uint8_t dz = fine_key[2] & 1;
uint8_t local_idx = (dz << 2) | (dy << 1) | dx;
auto iter = grids_.find(key);
if (iter == grids_.end()) {
// New voxel: construct it at the front (most recently used position).
data_.emplace_front(std::piecewise_construct,
std::forward_as_tuple(key),
std::forward_as_tuple(pt, local_idx));
grids_.insert(std::make_pair(key, data_.begin()));
// Evict the least recently used voxel when over capacity.
if (data_.size() >= capacity_) {
grids_.erase(data_.back().first);
data_.pop_back();
}
} else {
// Existing voxel: merge the point and move the node to the front.
iter->second->second.AddPoint(pt, local_idx);
data_.splice(data_.begin(), data_, iter->second);
// NOTE(review): list::splice does not invalidate iterators, so the mapped
// iterator still refers to the (now front) node -- this erase + re-insert
// of the same key looks redundant; verify and simplify.
grids_.erase(iter);
grids_.insert(std::make_pair(key, data_.begin()));
}
}
}
// Convenience overload: KNN search without a per-query lookup cache.
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::getTopK(const Point& point, KNNHeapType& top_K) const {
getTopK(point, top_K, nullptr);
}
// K-nearest-neighbor search (K = 5) around `point`, accumulating into top_K.
// Strategy: locate the query's sub-voxel octant, mirror the precomputed
// neighbor tables (HKNN_list60_gem.h) into that octant, then visit candidate
// sub-voxels in groups of non-decreasing minimum distance, stopping early
// once the 5th-best distance beats the current group's lower bound.
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::getTopK(const Point& point, KNNHeapType& top_K, SearchCache* cache) const {
// Fine key: index of the query's sub-voxel on the half-resolution grid.
const KEY fine_key = (point * sub_inv_resolution_).array().floor().template cast<int>();
bool use_cache = false;
if (cache) {
if (cache->valid && cache->last_fine_key == fine_key) {
use_cache = true;
} else {
// Query moved to a different fine cell: drop all memoised lookups.
cache->valid = true;
cache->last_fine_key = fine_key;
cache->lookup_mask = 0;
}
}
// Coarse key = fine key / 2 (arithmetic shift floors for negatives too).
KEY key;
key[0] = fine_key[0] >> 1;
key[1] = fine_key[1] >> 1;
key[2] = fine_key[2] >> 1;
// Low bits give the query's octant inside its coarse voxel.
const int dx = fine_key[0] & 1;
const int dy = fine_key[1] & 1;
const int dz = fine_key[2] & 1;
const int local_idx = (dz << 2) | (dy << 1) | dx;
// The tables are generated for one reference octant; mirror_axis (+1/-1 per
// axis) reflects the voxel offsets into the query's octant, and XOR-ing the
// sub-voxel codes with local_idx applies the same reflection one level down.
const KEY mirror_axis = KEY(1 - (dx << 1), 1 - (dy << 1), 1 - (dz << 1));
// Resolve the 8 nearest coarse voxels up front; they are hit by every group.
const int pre_voxel_ptr_size = 8;
OctVoxType* top_voxels_2_search[pre_voxel_ptr_size];
for(uint8_t i = 0; i < pre_voxel_ptr_size; ++i)
{
if (use_cache && (cache->lookup_mask & (1ULL << i))) {
top_voxels_2_search[i] = cache->voxel_ptrs[i];
} else {
KEY delta_key = mirror_axis.cwiseProduct(HKNN_neighbor_voxel[i]);
KEY n_key = key + delta_key;
if (auto iter = grids_.find(n_key); iter != grids_.end()) {
top_voxels_2_search[i] = &iter->second->second;
} else {
top_voxels_2_search[i] = nullptr;
}
if (cache) {
cache->voxel_ptrs[i] = top_voxels_2_search[i];
cache->lookup_mask |= (1ULL << i);
}
}
}
// NOTE(review): __sub_point is unused (and uses a reserved identifier).
Point __sub_point;
// Each group is a run of records [neighbor_idx, n, codes...] in
// flat_search_order; group bounds come from flat_search_ptrs_.
for (int group_idx = 0; group_idx < group_idx_max_; ++group_idx) {
const uint8_t* group_it = flat_search_ptrs_[group_idx];
const uint8_t* group_end = flat_search_ptrs_[group_idx + 1];
while(group_it < group_end){
const uint8_t neighbor_idx = *group_it++;
uint8_t data_size = *group_it++;
OctVoxType* voxel_ptr = nullptr;
if(neighbor_idx < pre_voxel_ptr_size)
{
voxel_ptr = top_voxels_2_search[neighbor_idx];
}
else
{
// Lazily resolve (and memoise) voxels beyond the first 8.
if (use_cache && (cache->lookup_mask & (1ULL << neighbor_idx))) {
voxel_ptr = cache->voxel_ptrs[neighbor_idx];
} else {
KEY delta_key = mirror_axis.cwiseProduct(HKNN_neighbor_voxel[neighbor_idx]);
const KEY n_key = key + delta_key;
if (auto iter = grids_.find(n_key); iter != grids_.end()){
voxel_ptr = &iter->second->second;
}
if (cache) {
cache->voxel_ptrs[neighbor_idx] = voxel_ptr;
cache->lookup_mask |= (1ULL << neighbor_idx);
}
}
}
if (voxel_ptr) {
const auto& counts = voxel_ptr->counts_;
const auto& points = voxel_ptr->points_;
while (data_size--) {
// XOR mirrors the stored sub-voxel code into the query's octant.
uint8_t _local_idx = (*group_it++)^local_idx;
if (counts[_local_idx] != OctVoxType::UNINIT_MASK) {
const Point& pt = points[_local_idx];
const float dist2 = (pt - point).squaredNorm();
// Cheap pre-check skips the call when the heap is full and pt is worse.
if (top_K.count < 5 || dist2 < top_K.max_dist2_) {
top_K.try_insert(dist2, pt);
}
}
}
}
else group_it+=data_size;  // voxel absent: skip its sub-voxel codes
}
// Early exit: with 5 hits whose worst squared distance is below this
// group's bound, later groups cannot improve the answer.
// NOTE(review): the literal 5 here (and above) duplicates KNNHeapType's K.
if (top_K.count == 5)
if (top_K.max_dist2_ < orders_min_dis2[group_idx]){
break;
}
}
}
// Baseline KNN search: brute-force over the 19-voxel neighborhood
// (center + 18 near neighbors) at coarse resolution. Kept as a simple
// reference alternative to the table-driven getTopK.
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::getTopK_VN(const Point& point, KNNHeapType& top_K) const{
    const KEY center = (point * inv_resolution_).array().floor().template cast<int>();
    for (std::size_t n = 0; n < 19; ++n) {
        const auto iter = grids_.find(center + nearby_grids_[n]);
        if (iter == grids_.end()) continue;
        const OctVoxType& voxel = iter->second->second;
        for (uint8_t slot = 0; slot < 8; ++slot) {
            const Point* candidate = voxel.getPointPtr(slot);
            if (candidate == nullptr) continue;   // empty octant
            top_K.try_insert((*candidate - point).squaredNorm(), *candidate);
        }
    }
}
// Flatten every stored representative point into `output` as consecutive
// x, y, z float triplets.
// BUGFIX: the previous version reserved with a `total_points` counter that
// was initialised to 0 and never updated, making the reserve a no-op; we now
// reserve the true upper bound (8 octant slots per voxel, 3 floats each).
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::getMap(std::vector<float>& output) const{
    output.clear();
    output.reserve(data_.size() * 8 * 3);
    Point point;
    for (const auto& voxel_pair : data_) {
        const OctVoxType& voxel = voxel_pair.second;
        for (uint8_t i = 0; i < 8; ++i) {
            if (!voxel.getPoint(i, point)) continue;   // skip empty octants
            output.push_back(static_cast<float>(point.x()));
            output.push_back(static_cast<float>(point.y()));
            output.push_back(static_cast<float>(point.z()));
        }
    }
}
// Rebuild the map from a flat x,y,z buffer (as produced by getMap), then
// suspend the next few insert() calls (via reset_map_/reset_map_count_).
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::resetMap(const std::vector<float>& input){
    if (input.empty()) return;
    clear();
    const size_t num_points = input.size() / 3;
    Points restored;
    restored.reserve(num_points);
    for (size_t i = 0; i < num_points; ++i) {
        const size_t base = i * 3;
        restored.push_back(Point(input[base], input[base + 1], input[base + 2]));
    }
    insert(restored);
    reset_map_ = true;       // set AFTER insert so the rebuild itself is not skipped
    reset_map_count_ = 10;   // number of subsequent insert() calls to skip
}
// Write every stored point to <ROOT_DIR>/map/OctVoxMap.pcd (binary PCD),
// replacing any existing file. Throws std::runtime_error when PCL reports a
// save failure.
// NOTE(review): this header does not itself include <pcl/point_cloud.h> or
// <pcl/io/pcd_io.h> -- it relies on the including translation unit to have
// done so; ROOT_DIR is injected by CMake via add_definitions.
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::saveMap() const {
const std::string g_root_dir = std::string(ROOT_DIR);
std::string filename = g_root_dir + "map/OctVoxMap.pcd";
if (std::filesystem::exists(filename)) {
std::filesystem::remove(filename);
std::cout << "Removed existing file: " << filename << std::endl;
}
pcl::PointCloud<pcl::PointXYZ>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZ>);
// Upper bound: 8 octant slots per voxel (empty slots skipped below).
size_t total_points = data_.size() * 8;
cloud->points.reserve(total_points);
for (const auto& voxel_pair : data_) {
const OctVoxType& voxel = voxel_pair.second;
for(uint8_t i = 0; i < 8; ++i) {
pcl::PointXYZ pcl_point;
Point point;
if (!voxel.getPoint(i, point)) continue;
pcl_point.x = static_cast<float>(point.x());
pcl_point.y = static_cast<float>(point.y());
pcl_point.z = static_cast<float>(point.z());
cloud->points.push_back(pcl_point);
}
}
cloud->width = cloud->points.size();
cloud->height = 1;
cloud->is_dense = true;
int result = pcl::io::savePCDFileBinary(filename, *cloud);
if (result == 0) {
std::cout << "Successfully saved " << cloud->points.size()
<< " points to " << filename << " (binary format)" << std::endl;
} else {
std::cerr << "Error saving point cloud to " << filename << std::endl;
throw std::runtime_error("Failed to save PCD file: " + filename);
}
}
// Drop every voxel: release the storage list and the key index.
// The two containers are independent, so either clearing order is safe.
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::clear() {
    data_.clear();
    grids_.clear();
}
// Log the current number of stored voxels and the configured capacity.
template<typename Point, typename Scalar>
void OctVoxMap<Point, Scalar>::printInfo() const {
    const std::size_t voxel_count = data_.size();
    std::cout << " ---> OctVoxMap info. Size: " << voxel_count
              << " Capacity: " << capacity_ << std::endl;
}
}
#endif

View File

@@ -0,0 +1,117 @@
#ifndef OCTVOXMAP_ADAPTER_HPP
#define OCTVOXMAP_ADAPTER_HPP
#include "OctVoxMap.hpp"

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include <Eigen/Core>
#include <pcl/point_types.h>
// Adapter to make OctVoxMap look like IVox
template <int dim = 3, int node_type = 0, typename PointType = pcl::PointXYZ>
class OctVoxMapAdapter {
public:
using KeyType = Eigen::Matrix<int, dim, 1>;
using PtType = Eigen::Matrix<float, dim, 1>;
using PointVector = std::vector<PointType, Eigen::aligned_allocator<PointType>>;
// Use Eigen::Vector3f internally for OctVoxMap to avoid operator overloading issues with PCL types
using InternalPointType = Eigen::Matrix<float, dim, 1>;
using OctVoxMapType = LI2Sup::OctVoxMap<InternalPointType, float>;
using KNNHeapType = typename OctVoxMapType::KNNHeapType;
using SearchCache = typename OctVoxMapType::SearchCache;
enum class NearbyType {
CENTER,
NEARBY6,
NEARBY18,
NEARBY26
};
struct Options {
float resolution_ = 0.2;
float inv_resolution_ = 10.0;
NearbyType nearby_type_ = NearbyType::NEARBY6; // Not used by OctVoxMap
std::size_t capacity_ = 1000000;
};
explicit OctVoxMapAdapter(Options options) {
typename OctVoxMapType::Options oct_options(options.resolution_, options.capacity_);
oct_vox_map_.reset(new OctVoxMapType(oct_options));
}
void AddPoints(const PointVector& points_to_add) {
// Convert PCL points to Eigen points
std::vector<InternalPointType, Eigen::aligned_allocator<InternalPointType>> internal_points;
internal_points.reserve(points_to_add.size());
for (const auto& pt : points_to_add) {
internal_points.emplace_back(pt.x, pt.y, pt.z);
}
// OctVoxMap handles downsampling internally
oct_vox_map_->insert(internal_points);
}
bool GetClosestPoint(const PointType& pt, PointVector& closest_pt, int max_num = 5, double max_range = 5.0, SearchCache* cache = nullptr) {
// OctVoxMap hardcodes K=5 in KNNHeap<5, Point>
if (max_num > 5) max_num = 5;
InternalPointType internal_pt(pt.x, pt.y, pt.z);
KNNHeapType top_K;
// OctVoxMap::getTopK takes (Point, KNNHeap&)
oct_vox_map_->getTopK(internal_pt, top_K, cache);
closest_pt.clear();
closest_pt.reserve(top_K.count);
float max_range_sq = max_range * max_range;
for (int i = 0; i < top_K.count; ++i) {
if (top_K.dist2_[i] <= max_range_sq) {
PointType p;
p.x = top_K.points_[i].x();
p.y = top_K.points_[i].y();
p.z = top_K.points_[i].z();
// Intensity/Curvature are lost, but not needed for geometric ICP
closest_pt.emplace_back(p);
}
}
return !closest_pt.empty();
}
// Overload for single point NN if needed
bool GetClosestPoint(const PointType& pt, PointType& closest_pt) {
InternalPointType internal_pt(pt.x, pt.y, pt.z);
KNNHeapType top_K;
oct_vox_map_->getTopK(internal_pt, top_K);
if (top_K.count > 0) {
// Find the closest one in the heap
int best_idx = -1;
float min_dist = std::numeric_limits<float>::max();
for(int i=0; i<top_K.count; ++i) {
if(top_K.dist2_[i] < min_dist) {
min_dist = top_K.dist2_[i];
best_idx = i;
}
}
if (best_idx != -1) {
closest_pt.x = top_K.points_[best_idx].x();
closest_pt.y = top_K.points_[best_idx].y();
closest_pt.z = top_K.points_[best_idx].z();
return true;
}
}
return false;
}
size_t NumValidGrids() const { return 0; }
size_t NumPoints() const { return 0; }
std::vector<float> StatGridPoints() const { return {}; }
private:
std::shared_ptr<OctVoxMapType> oct_vox_map_;
};
#endif // OCTVOXMAP_ADAPTER_HPP

View File

@@ -0,0 +1,83 @@
#ifndef VOXEL_GRID_CLOSEST_H
#define VOXEL_GRID_CLOSEST_H
#include <pcl/point_cloud.h>
#include <Eigen/Core>
#include "tsl/robin_hood.h"
namespace LI2Sup {
// Voxel-grid downsampler that keeps, for each occupied voxel, the input
// point closest to the voxel center (rather than a centroid, as
// pcl::VoxelGrid computes), so every output point is a real measurement.
template<typename PointType>
class VoxelGridClosest {
private:
    using Point = PointType;
    using PointCloud = pcl::PointCloud<Point>;
    using CloudPtr = typename PointCloud::Ptr;
    CloudPtr cloud_;                 // input cloud; must be set before filter()
    float voxel_size_ = 0.5f;
    float inv_voxel_size_ = 2.0f;
    // voxel key -> index into points_/dist2_
    robin_hood::unordered_flat_map<std::size_t, std::size_t> voxel_map_;
    std::vector<Point, Eigen::aligned_allocator<Point>> points_;
    std::vector<float> dist2_;       // best (smallest) squared distance to center per kept point
    // Keys pack three voxel indices into one 64-bit value (21 bits each).
    static_assert(sizeof(std::size_t) >= 8, "VoxelGridClosest needs a 64-bit size_t for its voxel keys");
    // BUGFIX: the old key used 15-bit fields with an offset of 1000, so any
    // voxel index outside [-1000, 31767] overflowed into the neighbouring
    // axis' bits and silently merged unrelated voxels. 21-bit fields with a
    // 2^20 offset cover indices in [-2^20, 2^20) per axis; behavior inside
    // the old valid range is unchanged.
    static constexpr int kIndexOffset = 1 << 20;
    static constexpr std::size_t kFieldMask = (std::size_t(1) << 21) - 1;
    const Eigen::Vector3i offset_ = Eigen::Vector3i(kIndexOffset, kIndexOffset, kIndexOffset);
public:
    VoxelGridClosest() {
        dist2_.reserve(10000);
        points_.reserve(10000);
        voxel_map_.reserve(10000);
    }
    // Set the cubic leaf (voxel) edge length.
    void setLeafSize(float lx) {
        voxel_size_ = lx;
        inv_voxel_size_ = 1.0f / lx;
    }
    void setInputCloud(const CloudPtr& cloud) {
        cloud_ = cloud;
    }
    // Downsample cloud_ into `output`: one representative per occupied voxel.
    void filter(CloudPtr& output) {
        voxel_map_.clear();
        dist2_.clear();
        points_.clear();
        for (const auto& pt : cloud_->points) {
            Eigen::Vector3f pf = pt.getVector3fMap();
            // round(): voxels are centred on integer multiples of voxel_size_.
            Eigen::Vector3i idx = (pf * inv_voxel_size_).array().round().cast<int>();
            Eigen::Vector3f center = voxel_size_ * idx.cast<float>();
            float d2 = (pf - center).squaredNorm();
            idx += offset_;  // shift indices to be non-negative before packing
            const std::size_t key = ((std::size_t(idx[2]) & kFieldMask) << 42) |
                                    ((std::size_t(idx[1]) & kFieldMask) << 21) |
                                    ( std::size_t(idx[0]) & kFieldMask);
            auto it = voxel_map_.find(key);
            if (it == voxel_map_.end()) {
                voxel_map_.emplace(key, points_.size());
                points_.push_back(pt);
                dist2_.push_back(d2);
            } else if (d2 < dist2_[it->second]) {
                // Strictly better representative for this voxel: replace it.
                points_[it->second] = pt;
                dist2_[it->second] = d2;
            }
        }
        // Hand the buffer to the output cloud; internal buffers are reused
        // (and cleared) on the next filter() call.
        output->points.swap(points_);
        output->width = output->points.size();
        output->height = 1;
        output->is_dense = true;
        output->header = cloud_->header;
    }
};
}
#endif // VOXEL_GRID_CLOSEST_H

View File

@@ -0,0 +1,415 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
// A change of the major version indicates an API and/or ABI break (change of
// in-memory layout of the data structure)
#define TSL_RH_VERSION_MAJOR 1
// A change of the minor version indicates the addition of a feature without
// impact on the API/ABI
#define TSL_RH_VERSION_MINOR 4
// A change of the patch version indicates a bugfix without additional
// functionality
#define TSL_RH_VERSION_PATCH 0
#ifdef TSL_DEBUG
#define tsl_rh_assert(expr) assert(expr)
#else
#define tsl_rh_assert(expr) (static_cast<void>(0))
#endif
/**
* If exceptions are enabled, throw the exception passed in parameter, otherwise
* call std::terminate.
*/
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || \
(defined(_MSC_VER) && defined(_CPPUNWIND))) && \
!defined(TSL_NO_EXCEPTIONS)
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
#else
#define TSL_RH_NO_EXCEPTIONS
#ifdef TSL_DEBUG
#include <iostream>
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) \
do { \
std::cerr << msg << std::endl; \
std::terminate(); \
} while (0)
#else
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
#endif
#endif
#if defined(__GNUC__) || defined(__clang__)
#define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#else
#define TSL_RH_LIKELY(exp) (exp)
#endif
#define TSL_RH_UNUSED(x) static_cast<void>(x)
namespace tsl {
namespace rh {
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
* power of two. It allows the table to use a mask operation instead of a modulo
* operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
// NOTE(review): vendored third-party code (Tessil robin-map, MIT licence;
// version macros above indicate 1.4.0). Keep local edits minimal and sync
// bug fixes with the upstream repository rather than patching here.
template <std::size_t GrowthFactor>
class power_of_two_growth_policy {
 public:
  /**
   * Called on the hash table creation and on rehash. The number of buckets for
   * the table is passed in parameter. This number is a minimum, the policy may
   * update this value with a higher value if needed (but not lower).
   *
   * If 0 is given, min_bucket_count_in_out must still be 0 after the policy
   * creation and bucket_for_hash must always return 0 in this case.
   */
  explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
    if (min_bucket_count_in_out > max_bucket_count()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maximum size.");
    }
    if (min_bucket_count_in_out > 0) {
      min_bucket_count_in_out =
          round_up_to_power_of_two(min_bucket_count_in_out);
      m_mask = min_bucket_count_in_out - 1;
    } else {
      m_mask = 0;
    }
  }
  /**
   * Return the bucket [0, bucket_count()) to which the hash belongs.
   * If bucket_count() is 0, it must always return 0.
   */
  std::size_t bucket_for_hash(std::size_t hash) const noexcept {
    return hash & m_mask;
  }
  /**
   * Return the number of buckets that should be used on next growth.
   */
  std::size_t next_bucket_count() const {
    if ((m_mask + 1) > max_bucket_count() / GrowthFactor) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maximum size.");
    }
    return (m_mask + 1) * GrowthFactor;
  }
  /**
   * Return the maximum number of buckets supported by the policy.
   */
  std::size_t max_bucket_count() const {
    // Largest power of two.
    return (std::numeric_limits<std::size_t>::max() / 2) + 1;
  }
  /**
   * Reset the growth policy as if it was created with a bucket count of 0.
   * After a clear, the policy must always return 0 when bucket_for_hash is
   * called.
   */
  void clear() noexcept { m_mask = 0; }
 private:
  static std::size_t round_up_to_power_of_two(std::size_t value) {
    if (is_power_of_two(value)) {
      return value;
    }
    if (value == 0) {
      return 1;
    }
    --value;
    // Smear the highest set bit into every lower bit, then +1 gives the
    // next power of two.
    for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
      value |= value >> i;
    }
    return value + 1;
  }
  static constexpr bool is_power_of_two(std::size_t value) {
    return value != 0 && (value & (value - 1)) == 0;
  }
 protected:
  static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2,
                "GrowthFactor must be a power of two >= 2.");
  std::size_t m_mask;
};
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to map a hash to a bucket. Slower but it can be useful if you want a slower
* growth.
*/
// NOTE(review): vendored third-party code (Tessil robin-map, MIT licence).
// Keep in sync with upstream rather than patching locally.
template <class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
 public:
  explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
    if (min_bucket_count_in_out > max_bucket_count()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maximum size.");
    }
    if (min_bucket_count_in_out > 0) {
      m_mod = min_bucket_count_in_out;
    } else {
      m_mod = 1;
    }
  }
  std::size_t bucket_for_hash(std::size_t hash) const noexcept {
    return hash % m_mod;
  }
  std::size_t next_bucket_count() const {
    if (m_mod == max_bucket_count()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maximum size.");
    }
    const double next_bucket_count =
        std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
    // isnormal() also rejects inf/NaN from the multiplication above.
    if (!std::isnormal(next_bucket_count)) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maximum size.");
    }
    if (next_bucket_count > double(max_bucket_count())) {
      return max_bucket_count();
    } else {
      return std::size_t(next_bucket_count);
    }
  }
  std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; }
  void clear() noexcept { m_mod = 1; }
 private:
  static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR =
      1.0 * GrowthFactor::num / GrowthFactor::den;
  static const std::size_t MAX_BUCKET_COUNT =
      std::size_t(double(std::numeric_limits<std::size_t>::max() /
                         REHASH_SIZE_MULTIPLICATION_FACTOR));
  static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1,
                "Growth factor should be >= 1.1.");
  std::size_t m_mod;
};
namespace detail {
#if SIZE_MAX >= ULLONG_MAX
#define TSL_RH_NB_PRIMES 51
#elif SIZE_MAX >= ULONG_MAX
#define TSL_RH_NB_PRIMES 40
#else
#define TSL_RH_NB_PRIMES 23
#endif
// Growth sequence of prime bucket counts, roughly doubling at each step. The
// number of usable entries depends on the width of std::size_t (guarded by
// the SIZE_MAX checks above and in the literal suffixes below).
inline constexpr std::array<std::size_t, TSL_RH_NB_PRIMES> PRIMES = {{
    1u,
    5u,
    17u,
    29u,
    37u,
    53u,
    67u,
    79u,
    97u,
    131u,
    193u,
    257u,
    389u,
    521u,
    769u,
    1031u,
    1543u,
    2053u,
    3079u,
    6151u,
    12289u,
    24593u,
    49157u,
#if SIZE_MAX >= ULONG_MAX
    98317ul,
    196613ul,
    393241ul,
    786433ul,
    1572869ul,
    3145739ul,
    6291469ul,
    12582917ul,
    25165843ul,
    50331653ul,
    100663319ul,
    201326611ul,
    402653189ul,
    805306457ul,
    1610612741ul,
    3221225473ul,
    4294967291ul,
#endif
#if SIZE_MAX >= ULLONG_MAX
    6442450939ull,
    12884901893ull,
    25769803751ull,
    51539607551ull,
    103079215111ull,
    206158430209ull,
    412316860441ull,
    824633720831ull,
    1649267441651ull,
    3298534883309ull,
    6597069766657ull,
#endif
}};
// hash % PRIMES[IPrime] with the prime fixed at compile time, so the
// compiler can lower the division to cheaper arithmetic.
template <unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) {
  return hash % PRIMES[IPrime];
}
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for
// faster modulo as the compiler can optimize the modulo code better with a
// constant known at the compilation.
inline constexpr std::array<std::size_t (*)(std::size_t), TSL_RH_NB_PRIMES>
    MOD_PRIME = {{
        &mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>,
        &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>,
        &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>,
        &mod<18>, &mod<19>, &mod<20>, &mod<21>, &mod<22>,
#if SIZE_MAX >= ULONG_MAX
        &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>,
        &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>,
        &mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>,
#endif
#if SIZE_MAX >= ULLONG_MAX
        &mod<40>, &mod<41>, &mod<42>, &mod<43>, &mod<44>, &mod<45>,
        &mod<46>, &mod<47>, &mod<48>, &mod<49>, &mod<50>,
#endif
    }};
}  // namespace detail
/**
* Grow the hash table by using prime numbers as bucket count. Slower than
* tsl::rh::power_of_two_growth_policy in general but will probably distribute
* the values around better in the buckets with a poor hash function.
*
 * To allow the compiler to optimize the modulo operation, a lookup table is
 * used with constant prime numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
 * Due to the constant variable in the modulo the compiler is able to optimize
 * the operation by a series of multiplications, subtractions and shifts.
 *
 * The 'hash % 5' could become something like
 * 'hash - ((hash * 0xCCCCCCCD) >> 34) * 5' in a 64-bit environment.
*/
class prime_growth_policy {
 public:
  /**
   * Pick the smallest prime in detail::PRIMES that is >= the requested
   * bucket count and report it back through `min_bucket_count_in_out` (a
   * request of 0 stays 0). Throws (or terminates) when no large-enough
   * prime exists.
   */
  explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
    auto it_prime = std::lower_bound(
        detail::PRIMES.begin(), detail::PRIMES.end(), min_bucket_count_in_out);
    if (it_prime == detail::PRIMES.end()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maximum size.");
    }
    m_iprime = static_cast<unsigned int>(
        std::distance(detail::PRIMES.begin(), it_prime));
    min_bucket_count_in_out =
        (min_bucket_count_in_out > 0) ? *it_prime : std::size_t(0);
  }
  // Map `hash` to a bucket through the pre-compiled modulo of the current
  // prime (see detail::MOD_PRIME).
  std::size_t bucket_for_hash(std::size_t hash) const noexcept {
    return detail::MOD_PRIME[m_iprime](hash);
  }
  // Next prime in the growth sequence; throws (or terminates) once the
  // largest prime is already in use.
  std::size_t next_bucket_count() const {
    if (m_iprime + 1 >= detail::PRIMES.size()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maximum size.");
    }
    return detail::PRIMES[m_iprime + 1];
  }
  std::size_t max_bucket_count() const { return detail::PRIMES.back(); }
  void clear() noexcept { m_iprime = 0; }
 private:
  // Index into detail::PRIMES of the prime currently used as bucket count.
  unsigned int m_iprime;
  static_assert(std::numeric_limits<decltype(m_iprime)>::max() >=
                    detail::PRIMES.size(),
                "The type of m_iprime is not big enough.");
};
} // namespace rh
} // namespace tsl
#endif

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,815 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_MAP_H
#define TSL_ROBIN_MAP_H
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include "robin_hash.h"
namespace tsl {
/**
* Implementation of a hash map using open-addressing and the robin hood hashing
* algorithm with backward shift deletion.
*
* For operations modifying the hash map (insert, erase, rehash, ...), the
* strong exception guarantee is only guaranteed when the expression
* `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
* std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true,
* otherwise if an exception is thrown during the swap or the move, the hash map
 * may end up in an undefined state. Per the standard a `Key` or `T` with a
* noexcept copy constructor and no move constructor also satisfies the
* `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and
* will thus guarantee the strong exception for the map).
*
* When `StoreHash` is true, 32 bits of the hash are stored alongside the
* values. It can improve the performance during lookups if the `KeyEqual`
* function takes time (if it engenders a cache-miss for example) as we then
* compare the stored hashes before comparing the keys. When
* `tsl::rh::power_of_two_growth_policy` is used as `GrowthPolicy`, it may also
* speed-up the rehash process as we can avoid to recalculate the hash. When it
* is detected that storing the hash will not incur any memory penalty due to
* alignment (i.e. `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType,
* true>) == sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`)
* and `tsl::rh::power_of_two_growth_policy` is used, the hash will be stored
* even if `StoreHash` is false so that we can speed-up the rehash (but it will
* not be used on lookups unless `StoreHash` is true).
*
* `GrowthPolicy` defines how the map grows and consequently how a hash value is
* mapped to a bucket. By default the map uses
* `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of
* buckets to a power of two and uses a mask to map the hash to a bucket instead
* of the slow modulo. Other growth policies are available and you may define
* your own growth policy, check `tsl::rh::power_of_two_growth_policy` for the
* interface.
*
* `std::pair<Key, T>` must be swappable.
*
* `Key` and `T` must be copy and/or move constructible.
*
* If the destructor of `Key` or `T` throws an exception, the behaviour of the
* class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective
* insert, invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template <class Key, class T, class Hash = std::hash<Key>,
          class KeyEqual = std::equal_to<Key>,
          class Allocator = std::allocator<std::pair<Key, T>>,
          bool StoreHash = false,
          class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
class robin_map {
 private:
  template <typename U>
  using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
  // Functor extracting the key part of the stored std::pair.
  class KeySelect {
   public:
    using key_type = Key;
    const key_type& operator()(
        const std::pair<Key, T>& key_value) const noexcept {
      return key_value.first;
    }
    key_type& operator()(std::pair<Key, T>& key_value) noexcept {
      return key_value.first;
    }
  };
  // Functor extracting the mapped-value part of the stored std::pair.
  class ValueSelect {
   public:
    using value_type = T;
    const value_type& operator()(
        const std::pair<Key, T>& key_value) const noexcept {
      return key_value.second;
    }
    value_type& operator()(std::pair<Key, T>& key_value) noexcept {
      return key_value.second;
    }
  };
  // Underlying robin hood hash table; every public method below forwards to
  // an instance of this type (m_ht).
  using ht = detail_robin_hash::robin_hash<std::pair<Key, T>, KeySelect,
                                           ValueSelect, Hash, KeyEqual,
                                           Allocator, StoreHash, GrowthPolicy>;
 public:
  using key_type = typename ht::key_type;
  using mapped_type = T;
  using value_type = typename ht::value_type;
  using size_type = typename ht::size_type;
  using difference_type = typename ht::difference_type;
  using hasher = typename ht::hasher;
  using key_equal = typename ht::key_equal;
  using allocator_type = typename ht::allocator_type;
  using reference = typename ht::reference;
  using const_reference = typename ht::const_reference;
  using pointer = typename ht::pointer;
  using const_pointer = typename ht::const_pointer;
  using iterator = typename ht::iterator;
  using const_iterator = typename ht::const_iterator;
 public:
  /*
   * Constructors
   */
  robin_map() : robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {}
  explicit robin_map(size_type bucket_count, const Hash& hash = Hash(),
                     const KeyEqual& equal = KeyEqual(),
                     const Allocator& alloc = Allocator())
      : m_ht(bucket_count, hash, equal, alloc) {}
  robin_map(size_type bucket_count, const Allocator& alloc)
      : robin_map(bucket_count, Hash(), KeyEqual(), alloc) {}
  robin_map(size_type bucket_count, const Hash& hash, const Allocator& alloc)
      : robin_map(bucket_count, hash, KeyEqual(), alloc) {}
  explicit robin_map(const Allocator& alloc)
      : robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {}
  template <class InputIt>
  robin_map(InputIt first, InputIt last,
            size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
            const Hash& hash = Hash(), const KeyEqual& equal = KeyEqual(),
            const Allocator& alloc = Allocator())
      : robin_map(bucket_count, hash, equal, alloc) {
    insert(first, last);
  }
  template <class InputIt>
  robin_map(InputIt first, InputIt last, size_type bucket_count,
            const Allocator& alloc)
      : robin_map(first, last, bucket_count, Hash(), KeyEqual(), alloc) {}
  template <class InputIt>
  robin_map(InputIt first, InputIt last, size_type bucket_count,
            const Hash& hash, const Allocator& alloc)
      : robin_map(first, last, bucket_count, hash, KeyEqual(), alloc) {}
  robin_map(std::initializer_list<value_type> init,
            size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
            const Hash& hash = Hash(), const KeyEqual& equal = KeyEqual(),
            const Allocator& alloc = Allocator())
      : robin_map(init.begin(), init.end(), bucket_count, hash, equal, alloc) {}
  robin_map(std::initializer_list<value_type> init, size_type bucket_count,
            const Allocator& alloc)
      : robin_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(),
                  alloc) {}
  robin_map(std::initializer_list<value_type> init, size_type bucket_count,
            const Hash& hash, const Allocator& alloc)
      : robin_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(),
                  alloc) {}
  robin_map& operator=(std::initializer_list<value_type> ilist) {
    m_ht.clear();
    m_ht.reserve(ilist.size());
    m_ht.insert(ilist.begin(), ilist.end());
    return *this;
  }
  allocator_type get_allocator() const { return m_ht.get_allocator(); }
  /*
   * Iterators
   */
  iterator begin() noexcept { return m_ht.begin(); }
  const_iterator begin() const noexcept { return m_ht.begin(); }
  const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
  iterator end() noexcept { return m_ht.end(); }
  const_iterator end() const noexcept { return m_ht.end(); }
  const_iterator cend() const noexcept { return m_ht.cend(); }
  /*
   * Capacity
   */
  bool empty() const noexcept { return m_ht.empty(); }
  size_type size() const noexcept { return m_ht.size(); }
  size_type max_size() const noexcept { return m_ht.max_size(); }
  /*
   * Modifiers
   */
  void clear() noexcept { m_ht.clear(); }
  std::pair<iterator, bool> insert(const value_type& value) {
    return m_ht.insert(value);
  }
  template <class P, typename std::enable_if<std::is_constructible<
                         value_type, P&&>::value>::type* = nullptr>
  std::pair<iterator, bool> insert(P&& value) {
    return m_ht.emplace(std::forward<P>(value));
  }
  std::pair<iterator, bool> insert(value_type&& value) {
    return m_ht.insert(std::move(value));
  }
  iterator insert(const_iterator hint, const value_type& value) {
    return m_ht.insert_hint(hint, value);
  }
  template <class P, typename std::enable_if<std::is_constructible<
                         value_type, P&&>::value>::type* = nullptr>
  iterator insert(const_iterator hint, P&& value) {
    return m_ht.emplace_hint(hint, std::forward<P>(value));
  }
  iterator insert(const_iterator hint, value_type&& value) {
    return m_ht.insert_hint(hint, std::move(value));
  }
  template <class InputIt>
  void insert(InputIt first, InputIt last) {
    m_ht.insert(first, last);
  }
  void insert(std::initializer_list<value_type> ilist) {
    m_ht.insert(ilist.begin(), ilist.end());
  }
  template <class M>
  std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
    return m_ht.insert_or_assign(k, std::forward<M>(obj));
  }
  template <class M>
  std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
    return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
  }
  template <class M>
  iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
    return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
  }
  template <class M>
  iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
    return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
  }
  /**
   * Due to the way elements are stored, emplace will need to move or copy the
   * key-value once. The method is equivalent to
   * insert(value_type(std::forward<Args>(args)...));
   *
   * Mainly here for compatibility with the std::unordered_map interface.
   */
  template <class... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    return m_ht.emplace(std::forward<Args>(args)...);
  }
  /**
   * Due to the way elements are stored, emplace_hint will need to move or copy
   * the key-value once. The method is equivalent to insert(hint,
   * value_type(std::forward<Args>(args)...));
   *
   * Mainly here for compatibility with the std::unordered_map interface.
   */
  template <class... Args>
  iterator emplace_hint(const_iterator hint, Args&&... args) {
    return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
  }
  template <class... Args>
  std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
    return m_ht.try_emplace(k, std::forward<Args>(args)...);
  }
  template <class... Args>
  std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
    return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
  }
  template <class... Args>
  iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
    return m_ht.try_emplace_hint(hint, k, std::forward<Args>(args)...);
  }
  template <class... Args>
  iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
    return m_ht.try_emplace_hint(hint, std::move(k),
                                 std::forward<Args>(args)...);
  }
  iterator erase(iterator pos) { return m_ht.erase(pos); }
  iterator erase(const_iterator pos) { return m_ht.erase(pos); }
  iterator erase(const_iterator first, const_iterator last) {
    return m_ht.erase(first, last);
  }
  size_type erase(const key_type& key) { return m_ht.erase(key); }
  /**
   * Erase the element at position 'pos'. In contrast to the regular erase()
   * function, erase_fast() does not return an iterator. This allows it to be
   * faster especially in hash tables with a low load factor, where finding the
   * next nonempty bucket would be costly.
   */
  void erase_fast(iterator pos) { return m_ht.erase_fast(pos); }
  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup to the value if you already have the hash.
   */
  size_type erase(const key_type& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
  }
  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type erase(const K& key) {
    return m_ht.erase(key);
  }
  /**
   * @copydoc erase(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup to the value if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type erase(const K& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
  }
  void swap(robin_map& other) { other.m_ht.swap(m_ht); }
  /*
   * Lookup
   */
  T& at(const Key& key) { return m_ht.at(key); }
  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  T& at(const Key& key, std::size_t precalculated_hash) {
    return m_ht.at(key, precalculated_hash);
  }
  const T& at(const Key& key) const { return m_ht.at(key); }
  /**
   * @copydoc at(const Key& key, std::size_t precalculated_hash)
   */
  const T& at(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.at(key, precalculated_hash);
  }
  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  T& at(const K& key) {
    return m_ht.at(key);
  }
  /**
   * @copydoc at(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  T& at(const K& key, std::size_t precalculated_hash) {
    return m_ht.at(key, precalculated_hash);
  }
  /**
   * @copydoc at(const K& key)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  const T& at(const K& key) const {
    return m_ht.at(key);
  }
  /**
   * @copydoc at(const K& key, std::size_t precalculated_hash)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  const T& at(const K& key, std::size_t precalculated_hash) const {
    return m_ht.at(key, precalculated_hash);
  }
  T& operator[](const Key& key) { return m_ht[key]; }
  T& operator[](Key&& key) { return m_ht[std::move(key)]; }
  size_type count(const Key& key) const { return m_ht.count(key); }
  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  size_type count(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.count(key, precalculated_hash);
  }
  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type count(const K& key) const {
    return m_ht.count(key);
  }
  /**
   * @copydoc count(const K& key) const
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type count(const K& key, std::size_t precalculated_hash) const {
    return m_ht.count(key, precalculated_hash);
  }
  iterator find(const Key& key) { return m_ht.find(key); }
  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  iterator find(const Key& key, std::size_t precalculated_hash) {
    return m_ht.find(key, precalculated_hash);
  }
  const_iterator find(const Key& key) const { return m_ht.find(key); }
  /**
   * @copydoc find(const Key& key, std::size_t precalculated_hash)
   */
  const_iterator find(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
  }
  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  iterator find(const K& key) {
    return m_ht.find(key);
  }
  /**
   * @copydoc find(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  iterator find(const K& key, std::size_t precalculated_hash) {
    return m_ht.find(key, precalculated_hash);
  }
  /**
   * @copydoc find(const K& key)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  const_iterator find(const K& key) const {
    return m_ht.find(key);
  }
  /**
   * @copydoc find(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  const_iterator find(const K& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
  }
  bool contains(const Key& key) const { return m_ht.contains(key); }
  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  bool contains(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.contains(key, precalculated_hash);
  }
  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  bool contains(const K& key) const {
    return m_ht.contains(key);
  }
  /**
   * @copydoc contains(const K& key) const
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  bool contains(const K& key, std::size_t precalculated_hash) const {
    return m_ht.contains(key, precalculated_hash);
  }
  std::pair<iterator, iterator> equal_range(const Key& key) {
    return m_ht.equal_range(key);
  }
  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  std::pair<iterator, iterator> equal_range(const Key& key,
                                            std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
  }
  std::pair<const_iterator, const_iterator> equal_range(const Key& key) const {
    return m_ht.equal_range(key);
  }
  /**
   * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
   */
  std::pair<const_iterator, const_iterator> equal_range(
      const Key& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
  }
  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<iterator, iterator> equal_range(const K& key) {
    return m_ht.equal_range(key);
  }
  /**
   * @copydoc equal_range(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<iterator, iterator> equal_range(const K& key,
                                            std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
  }
  /**
   * @copydoc equal_range(const K& key)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<const_iterator, const_iterator> equal_range(const K& key) const {
    return m_ht.equal_range(key);
  }
  /**
   * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<const_iterator, const_iterator> equal_range(
      const K& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
  }
  /*
   * Bucket interface
   */
  size_type bucket_count() const { return m_ht.bucket_count(); }
  size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
  /*
   * Hash policy
   */
  float load_factor() const { return m_ht.load_factor(); }
  float min_load_factor() const { return m_ht.min_load_factor(); }
  float max_load_factor() const { return m_ht.max_load_factor(); }
  /**
   * Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes
   * below `min_load_factor` after some erase operations, the map will be
   * shrunk when an insertion occurs. The erase method itself never shrinks
   * the map.
   *
   * The default value of `min_load_factor` is 0.0f, the map never shrinks by
   * default.
   */
  void min_load_factor(float ml) { m_ht.min_load_factor(ml); }
  void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
  void rehash(size_type count_) { m_ht.rehash(count_); }
  void reserve(size_type count_) { m_ht.reserve(count_); }
  /*
   * Observers
   */
  hasher hash_function() const { return m_ht.hash_function(); }
  key_equal key_eq() const { return m_ht.key_eq(); }
  /*
   * Other
   */
  /**
   * Convert a const_iterator to an iterator.
   */
  iterator mutable_iterator(const_iterator pos) {
    return m_ht.mutable_iterator(pos);
  }
  /**
   * Serialize the map through the `serializer` parameter.
   *
   * The `serializer` parameter must be a function object that supports the
   * following call:
   *  - `template<typename U> void operator()(const U& value);` where the types
   * `std::int16_t`, `std::uint32_t`, `std::uint64_t`, `float` and
   * `std::pair<Key, T>` must be supported for U.
   *
   * The implementation leaves binary compatibility (endianness, IEEE 754 for
   * floats, ...) of the types it serializes in the hands of the `Serializer`
   * function object if compatibility is required.
   */
  template <class Serializer>
  void serialize(Serializer& serializer) const {
    m_ht.serialize(serializer);
  }
  /**
   * Deserialize a previously serialized map through the `deserializer`
   * parameter.
   *
   * The `deserializer` parameter must be a function object that supports the
   * following call:
   *  - `template<typename U> U operator()();` where the types `std::int16_t`,
   * `std::uint32_t`, `std::uint64_t`, `float` and `std::pair<Key, T>` must be
   * supported for U.
   *
   * If the deserialized hash map type is hash compatible with the serialized
   * map, the deserialization process can be sped up by setting
   * `hash_compatible` to true. To be hash compatible, the Hash, KeyEqual and
   * GrowthPolicy must behave the same way than the ones used on the serialized
   * map and the StoreHash must have the same value. The `std::size_t` must also
   * be of the same size as the one on the platform used to serialize the map.
   * If these criteria are not met, the behaviour is undefined with
   * `hash_compatible` sets to true.
   *
   * The behaviour is undefined if the type `Key` and `T` of the `robin_map` are
   * not the same as the types used during serialization.
   *
   * The implementation leaves binary compatibility (endianness, IEEE 754 for
   * floats, size of int, ...) of the types it deserializes in the hands of the
   * `Deserializer` function object if compatibility is required.
   */
  template <class Deserializer>
  static robin_map deserialize(Deserializer& deserializer,
                               bool hash_compatible = false) {
    robin_map map(0);
    map.m_ht.deserialize(deserializer, hash_compatible);
    return map;
  }
  friend bool operator==(const robin_map& lhs, const robin_map& rhs) {
    if (lhs.size() != rhs.size()) {
      return false;
    }
    for (const auto& element_lhs : lhs) {
      const auto it_element_rhs = rhs.find(element_lhs.first);
      if (it_element_rhs == rhs.cend() ||
          element_lhs.second != it_element_rhs->second) {
        return false;
      }
    }
    return true;
  }
  friend bool operator!=(const robin_map& lhs, const robin_map& rhs) {
    return !operator==(lhs, rhs);
  }
  friend void swap(robin_map& lhs, robin_map& rhs) { lhs.swap(rhs); }
 private:
  // Underlying robin hood hash table instance; holds all elements and state.
  ht m_ht;
};
/**
 * Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash,
 * tsl::rh::prime_growth_policy>`.
 *
 * Convenience alias selecting the prime growth policy instead of the default
 * power-of-two policy.
 */
template <class Key, class T, class Hash = std::hash<Key>,
          class KeyEqual = std::equal_to<Key>,
          class Allocator = std::allocator<std::pair<Key, T>>,
          bool StoreHash = false>
using robin_pg_map = robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash,
                               tsl::rh::prime_growth_policy>;
} // end namespace tsl
#endif

View File

@@ -0,0 +1,668 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_SET_H
#define TSL_ROBIN_SET_H
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include "robin_hash.h"
namespace tsl {
/**
* Implementation of a hash set using open-addressing and the robin hood hashing
* algorithm with backward shift deletion.
*
* For operations modifying the hash set (insert, erase, rehash, ...), the
* strong exception guarantee is only guaranteed when the expression
* `std::is_nothrow_swappable<Key>::value &&
* std::is_nothrow_move_constructible<Key>::value` is true, otherwise if an
* exception is thrown during the swap or the move, the hash set may end up in a
* undefined state. Per the standard a `Key` with a noexcept copy constructor
* and no move constructor also satisfies the
* `std::is_nothrow_move_constructible<Key>::value` criterion (and will thus
* guarantee the strong exception for the set).
*
* When `StoreHash` is true, 32 bits of the hash are stored alongside the
* values. It can improve the performance during lookups if the `KeyEqual`
* function takes time (or engenders a cache-miss for example) as we then
* compare the stored hashes before comparing the keys. When
* `tsl::rh::power_of_two_growth_policy` is used as `GrowthPolicy`, it may also
* speed-up the rehash process as we can avoid to recalculate the hash. When it
* is detected that storing the hash will not incur any memory penalty due to
* alignment (i.e. `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType,
* true>) == sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`)
* and `tsl::rh::power_of_two_growth_policy` is used, the hash will be stored
* even if `StoreHash` is false so that we can speed-up the rehash (but it will
* not be used on lookups unless `StoreHash` is true).
*
* `GrowthPolicy` defines how the set grows and consequently how a hash value is
* mapped to a bucket. By default the set uses
* `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of
* buckets to a power of two and uses a mask to set the hash to a bucket instead
* of the slow modulo. Other growth policies are available and you may define
* your own growth policy, check `tsl::rh::power_of_two_growth_policy` for the
* interface.
*
* `Key` must be swappable.
*
* `Key` must be copy and/or move constructible.
*
* If the destructor of `Key` throws an exception, the behaviour of the class is
* undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective
* insert, invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template <class Key, class Hash = std::hash<Key>,
          class KeyEqual = std::equal_to<Key>,
          class Allocator = std::allocator<Key>, bool StoreHash = false,
          class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
class robin_set {
 private:
  template <typename U>
  using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;

  // Extracts the key from a stored value. For a set the stored value *is*
  // the key, so both overloads simply return their argument.
  class KeySelect {
   public:
    using key_type = Key;

    const key_type& operator()(const Key& key) const noexcept { return key; }
    key_type& operator()(Key& key) noexcept { return key; }
  };

  // Underlying generic robin-hood hash table. The mapped type `void` tells
  // the shared implementation that this container is a set, not a map.
  using ht = detail_robin_hash::robin_hash<Key, KeySelect, void, Hash, KeyEqual,
                                           Allocator, StoreHash, GrowthPolicy>;

 public:
  using key_type = typename ht::key_type;
  using value_type = typename ht::value_type;
  using size_type = typename ht::size_type;
  using difference_type = typename ht::difference_type;
  using hasher = typename ht::hasher;
  using key_equal = typename ht::key_equal;
  using allocator_type = typename ht::allocator_type;
  using reference = typename ht::reference;
  using const_reference = typename ht::const_reference;
  using pointer = typename ht::pointer;
  using const_pointer = typename ht::const_pointer;
  using iterator = typename ht::iterator;
  using const_iterator = typename ht::const_iterator;

  /*
   * Constructors
   */
  robin_set() : robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {}

  explicit robin_set(size_type bucket_count, const Hash& hash = Hash(),
                     const KeyEqual& equal = KeyEqual(),
                     const Allocator& alloc = Allocator())
      : m_ht(bucket_count, hash, equal, alloc) {}

  robin_set(size_type bucket_count, const Allocator& alloc)
      : robin_set(bucket_count, Hash(), KeyEqual(), alloc) {}

  robin_set(size_type bucket_count, const Hash& hash, const Allocator& alloc)
      : robin_set(bucket_count, hash, KeyEqual(), alloc) {}

  explicit robin_set(const Allocator& alloc)
      : robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {}

  template <class InputIt>
  robin_set(InputIt first, InputIt last,
            size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
            const Hash& hash = Hash(), const KeyEqual& equal = KeyEqual(),
            const Allocator& alloc = Allocator())
      : robin_set(bucket_count, hash, equal, alloc) {
    insert(first, last);
  }

  template <class InputIt>
  robin_set(InputIt first, InputIt last, size_type bucket_count,
            const Allocator& alloc)
      : robin_set(first, last, bucket_count, Hash(), KeyEqual(), alloc) {}

  template <class InputIt>
  robin_set(InputIt first, InputIt last, size_type bucket_count,
            const Hash& hash, const Allocator& alloc)
      : robin_set(first, last, bucket_count, hash, KeyEqual(), alloc) {}

  robin_set(std::initializer_list<value_type> init,
            size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
            const Hash& hash = Hash(), const KeyEqual& equal = KeyEqual(),
            const Allocator& alloc = Allocator())
      : robin_set(init.begin(), init.end(), bucket_count, hash, equal, alloc) {}

  robin_set(std::initializer_list<value_type> init, size_type bucket_count,
            const Allocator& alloc)
      : robin_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(),
                  alloc) {}

  robin_set(std::initializer_list<value_type> init, size_type bucket_count,
            const Hash& hash, const Allocator& alloc)
      : robin_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(),
                  alloc) {}

  robin_set& operator=(std::initializer_list<value_type> ilist) {
    m_ht.clear();

    // Reserve up-front so the single bulk insert below does at most one rehash.
    m_ht.reserve(ilist.size());
    m_ht.insert(ilist.begin(), ilist.end());

    return *this;
  }

  allocator_type get_allocator() const { return m_ht.get_allocator(); }

  /*
   * Iterators
   */
  iterator begin() noexcept { return m_ht.begin(); }
  const_iterator begin() const noexcept { return m_ht.begin(); }
  const_iterator cbegin() const noexcept { return m_ht.cbegin(); }

  iterator end() noexcept { return m_ht.end(); }
  const_iterator end() const noexcept { return m_ht.end(); }
  const_iterator cend() const noexcept { return m_ht.cend(); }

  /*
   * Capacity
   */
  bool empty() const noexcept { return m_ht.empty(); }
  size_type size() const noexcept { return m_ht.size(); }
  size_type max_size() const noexcept { return m_ht.max_size(); }

  /*
   * Modifiers
   */
  void clear() noexcept { m_ht.clear(); }

  std::pair<iterator, bool> insert(const value_type& value) {
    return m_ht.insert(value);
  }

  std::pair<iterator, bool> insert(value_type&& value) {
    return m_ht.insert(std::move(value));
  }

  iterator insert(const_iterator hint, const value_type& value) {
    return m_ht.insert_hint(hint, value);
  }

  iterator insert(const_iterator hint, value_type&& value) {
    return m_ht.insert_hint(hint, std::move(value));
  }

  template <class InputIt>
  void insert(InputIt first, InputIt last) {
    m_ht.insert(first, last);
  }

  void insert(std::initializer_list<value_type> ilist) {
    m_ht.insert(ilist.begin(), ilist.end());
  }

  /**
   * Due to the way elements are stored, emplace will need to move or copy the
   * key-value once. The method is equivalent to
   * insert(value_type(std::forward<Args>(args)...));
   *
   * Mainly here for compatibility with the std::unordered_map interface.
   */
  template <class... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    return m_ht.emplace(std::forward<Args>(args)...);
  }

  /**
   * Due to the way elements are stored, emplace_hint will need to move or copy
   * the key-value once. The method is equivalent to insert(hint,
   * value_type(std::forward<Args>(args)...));
   *
   * Mainly here for compatibility with the std::unordered_map interface.
   */
  template <class... Args>
  iterator emplace_hint(const_iterator hint, Args&&... args) {
    return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
  }

  iterator erase(iterator pos) { return m_ht.erase(pos); }
  iterator erase(const_iterator pos) { return m_ht.erase(pos); }
  iterator erase(const_iterator first, const_iterator last) {
    return m_ht.erase(first, last);
  }
  size_type erase(const key_type& key) { return m_ht.erase(key); }

  /**
   * Erase the element at position 'pos'. In contrast to the regular erase()
   * function, erase_fast() does not return an iterator. This allows it to be
   * faster especially in hash sets with a low load factor, where finding the
   * next nonempty bucket would be costly.
   */
  void erase_fast(iterator pos) { return m_ht.erase_fast(pos); }

  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup to the value if you already have the hash.
   */
  size_type erase(const key_type& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
  }

  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type erase(const K& key) {
    return m_ht.erase(key);
  }

  /**
   * @copydoc erase(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup to the value if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type erase(const K& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
  }

  void swap(robin_set& other) { other.m_ht.swap(m_ht); }

  /*
   * Lookup
   */
  size_type count(const Key& key) const { return m_ht.count(key); }

  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  size_type count(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.count(key, precalculated_hash);
  }

  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type count(const K& key) const {
    return m_ht.count(key);
  }

  /**
   * @copydoc count(const K& key) const
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  size_type count(const K& key, std::size_t precalculated_hash) const {
    return m_ht.count(key, precalculated_hash);
  }

  iterator find(const Key& key) { return m_ht.find(key); }

  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  iterator find(const Key& key, std::size_t precalculated_hash) {
    return m_ht.find(key, precalculated_hash);
  }

  const_iterator find(const Key& key) const { return m_ht.find(key); }

  /**
   * @copydoc find(const Key& key, std::size_t precalculated_hash)
   */
  const_iterator find(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
  }

  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  iterator find(const K& key) {
    return m_ht.find(key);
  }

  /**
   * @copydoc find(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  iterator find(const K& key, std::size_t precalculated_hash) {
    return m_ht.find(key, precalculated_hash);
  }

  /**
   * @copydoc find(const K& key)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  const_iterator find(const K& key) const {
    return m_ht.find(key);
  }

  /**
   * @copydoc find(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  const_iterator find(const K& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
  }

  bool contains(const Key& key) const { return m_ht.contains(key); }

  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  bool contains(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.contains(key, precalculated_hash);
  }

  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  bool contains(const K& key) const {
    return m_ht.contains(key);
  }

  /**
   * @copydoc contains(const K& key) const
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  bool contains(const K& key, std::size_t precalculated_hash) const {
    return m_ht.contains(key, precalculated_hash);
  }

  std::pair<iterator, iterator> equal_range(const Key& key) {
    return m_ht.equal_range(key);
  }

  /**
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  std::pair<iterator, iterator> equal_range(const Key& key,
                                            std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
  }

  std::pair<const_iterator, const_iterator> equal_range(const Key& key) const {
    return m_ht.equal_range(key);
  }

  /**
   * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
   */
  std::pair<const_iterator, const_iterator> equal_range(
      const Key& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
  }

  /**
   * This overload only participates in the overload resolution if the typedef
   * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
   * to Key.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<iterator, iterator> equal_range(const K& key) {
    return m_ht.equal_range(key);
  }

  /**
   * @copydoc equal_range(const K& key)
   *
   * Use the hash value 'precalculated_hash' instead of hashing the key. The
   * hash value should be the same as hash_function()(key). Useful to speed-up
   * the lookup if you already have the hash.
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<iterator, iterator> equal_range(const K& key,
                                            std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
  }

  /**
   * @copydoc equal_range(const K& key)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<const_iterator, const_iterator> equal_range(const K& key) const {
    return m_ht.equal_range(key);
  }

  /**
   * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
   */
  template <
      class K, class KE = KeyEqual,
      typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
  std::pair<const_iterator, const_iterator> equal_range(
      const K& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
  }

  /*
   * Bucket interface
   */
  size_type bucket_count() const { return m_ht.bucket_count(); }
  size_type max_bucket_count() const { return m_ht.max_bucket_count(); }

  /*
   * Hash policy
   */
  float load_factor() const { return m_ht.load_factor(); }

  float min_load_factor() const { return m_ht.min_load_factor(); }
  float max_load_factor() const { return m_ht.max_load_factor(); }

  /**
   * Set the `min_load_factor` to `ml`. When the `load_factor` of the set goes
   * below `min_load_factor` after some erase operations, the set will be
   * shrunk when an insertion occurs. The erase method itself never shrinks
   * the set.
   *
   * The default value of `min_load_factor` is 0.0f, the set never shrinks by
   * default.
   */
  void min_load_factor(float ml) { m_ht.min_load_factor(ml); }
  void max_load_factor(float ml) { m_ht.max_load_factor(ml); }

  void rehash(size_type count_) { m_ht.rehash(count_); }
  void reserve(size_type count_) { m_ht.reserve(count_); }

  /*
   * Observers
   */
  hasher hash_function() const { return m_ht.hash_function(); }
  key_equal key_eq() const { return m_ht.key_eq(); }

  /*
   * Other
   */

  /**
   * Convert a const_iterator to an iterator.
   */
  iterator mutable_iterator(const_iterator pos) {
    return m_ht.mutable_iterator(pos);
  }

  // Equal sizes plus "every lhs element is in rhs" implies set equality;
  // no per-element value comparison is needed since the key is the value.
  friend bool operator==(const robin_set& lhs, const robin_set& rhs) {
    if (lhs.size() != rhs.size()) {
      return false;
    }

    for (const auto& element_lhs : lhs) {
      const auto it_element_rhs = rhs.find(element_lhs);
      if (it_element_rhs == rhs.cend()) {
        return false;
      }
    }

    return true;
  }

  /**
   * Serialize the set through the `serializer` parameter.
   *
   * The `serializer` parameter must be a function object that supports the
   * following call:
   *  - `template<typename U> void operator()(const U& value);` where the types
   * `std::int16_t`, `std::uint32_t`, `std::uint64_t`, `float` and `Key` must be
   * supported for U.
   *
   * The implementation leaves binary compatibility (endianness, IEEE 754 for
   * floats, ...) of the types it serializes in the hands of the `Serializer`
   * function object if compatibility is required.
   */
  template <class Serializer>
  void serialize(Serializer& serializer) const {
    m_ht.serialize(serializer);
  }

  /**
   * Deserialize a previously serialized set through the `deserializer`
   * parameter.
   *
   * The `deserializer` parameter must be a function object that supports the
   * following call:
   *  - `template<typename U> U operator()();` where the types `std::int16_t`,
   * `std::uint32_t`, `std::uint64_t`, `float` and `Key` must be supported for
   * U.
   *
   * If the deserialized hash set type is hash compatible with the serialized
   * set, the deserialization process can be sped up by setting
   * `hash_compatible` to true. To be hash compatible, the Hash, KeyEqual and
   * GrowthPolicy must behave the same way as the ones used on the serialized
   * set and the StoreHash must have the same value. The `std::size_t` must also
   * be of the same size as the one on the platform used to serialize the set.
   * If these criteria are not met, the behaviour is undefined with
   * `hash_compatible` set to true.
   *
   * The behaviour is undefined if the type `Key` of the `robin_set` is not the
   * same as the type used during serialization.
   *
   * The implementation leaves binary compatibility (endianness, IEEE 754 for
   * floats, size of int, ...) of the types it deserializes in the hands of the
   * `Deserializer` function object if compatibility is required.
   */
  template <class Deserializer>
  static robin_set deserialize(Deserializer& deserializer,
                               bool hash_compatible = false) {
    robin_set set(0);
    set.m_ht.deserialize(deserializer, hash_compatible);

    return set;
  }

  friend bool operator!=(const robin_set& lhs, const robin_set& rhs) {
    return !operator==(lhs, rhs);
  }

  // ADL-visible swap; delegates to the member swap of the underlying table.
  friend void swap(robin_set& lhs, robin_set& rhs) { lhs.swap(rhs); }

 private:
  // The underlying hash table; every public operation above forwards to it.
  ht m_ht;
};
/**
* Same as `tsl::robin_set<Key, Hash, KeyEqual, Allocator, StoreHash,
* tsl::rh::prime_growth_policy>`.
*/
// Identical interface to robin_set; only the growth policy differs
// (prime bucket counts via tsl::rh::prime_growth_policy instead of the
// default power-of-two policy).
template <class Key, class Hash = std::hash<Key>,
          class KeyEqual = std::equal_to<Key>,
          class Allocator = std::allocator<Key>, bool StoreHash = false>
using robin_pg_set = robin_set<Key, Hash, KeyEqual, Allocator, StoreHash,
                               tsl::rh::prime_growth_policy>;
} // end namespace tsl
#endif

View File

@@ -11,10 +11,14 @@
#include <tf2_eigen/tf2_eigen.hpp>
#include <../include/IKFoM/IKFoM_toolkit/esekfom/esekfom.hpp>
#include <queue>
#include <chrono>
using namespace std;
using namespace Eigen;
inline double get_wtime() {
return std::chrono::duration<double>(std::chrono::high_resolution_clock::now().time_since_epoch()).count();
}
typedef MTK::vect<3, double> vect3;
typedef MTK::SO3<double> SO3;

View File

@@ -29,6 +29,7 @@
<depend>pcl_ros</depend> <!-- For PCL support -->
<depend>pcl_conversions</depend> <!-- For PCL support -->
<depend>livox_ros_driver2</depend>
<depend>libtbb-dev</depend>
<!-- test_depend remains the same, but make sure the testing tools you use are compatible with ROS2 -->
<test_depend>ament_lint_auto</test_depend>

View File

@@ -1,5 +1,8 @@
// #include <../include/IKFoM/IKFoM_toolkit/esekfom/esekfom.hpp>
#include "Estimator.h"
#include <tbb/parallel_for.h>
#include <tbb/blocked_range.h>
#include <tbb/enumerable_thread_specific.h>
PointCloudXYZI::Ptr normvec(new PointCloudXYZI(100000, 1));
std::vector<int> time_seq;
@@ -112,12 +115,19 @@ Eigen::Matrix<double, 30, 30> df_dx_output(state_output &s, const input_ikfom &i
void h_model_input(state_input &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_R, esekfom::dyn_share_modified<double> &ekfom_data)
{
bool match_in_map = false;
normvec->resize(time_seq[k]);
std::atomic<int> effect_num_k(0);
// Grainsize set to 256 to reduce scheduler overhead
tbb::parallel_for(tbb::blocked_range<int>(0, time_seq[k], 256), [&](const tbb::blocked_range<int>& r) {
int local_effect_num = 0; // Local accumulator
#ifdef IVOX_NODE_TYPE_OCTVOXMAP
IVoxType::SearchCache cache;
#endif
for (int j = r.begin(); j != r.end(); ++j)
{
VF(4) pabcd;
pabcd.setZero();
normvec->resize(time_seq[k]);
int effect_num_k = 0;
for (int j = 0; j < time_seq[k]; j++)
{
PointType &point_body_j = feats_down_body->points[idx+j+1];
PointType &point_world_j = feats_down_world->points[idx+j+1];
pointBodyToWorld(&point_body_j, &point_world_j);
@@ -127,7 +137,11 @@ void h_model_input(state_input &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_R,
p_world << point_world_j.x, point_world_j.y, point_world_j.z;
{
auto &points_near = Nearest_Points[idx+j+1];
ivox_->GetClosestPoint(point_world_j, points_near, NUM_MATCH_POINTS); //
#ifdef IVOX_NODE_TYPE_OCTVOXMAP
ivox_->GetClosestPoint(point_world_j, points_near, NUM_MATCH_POINTS, 5.0, &cache);
#else
ivox_->GetClosestPoint(point_world_j, points_near, NUM_MATCH_POINTS);
#endif
if ((points_near.size() < NUM_MATCH_POINTS)) // || pointSearchSqDis[NUM_MATCH_POINTS - 1] > 5) // 5)
{
point_selected_surf[idx+j+1] = false;
@@ -138,25 +152,6 @@ void h_model_input(state_input &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_R,
if (esti_plane(pabcd, points_near, plane_thr)) //(planeValid)
{
float pd2 = fabs(pabcd(0) * point_world_j.x + pabcd(1) * point_world_j.y + pabcd(2) * point_world_j.z + pabcd(3));
// V3D norm_vec;
// M3D Rpf, pf;
// pf = crossmat_list[idx+j+1];
// // pf << SKEW_SYM_MATRX(p_body);
// Rpf = s.rot * pf;
// norm_vec << pabcd(0), pabcd(1), pabcd(2);
// double noise_state = norm_vec.transpose() * (cov_p+Rpf*cov_R*Rpf.transpose()) * norm_vec + sqrt(p_norm) * 0.001;
// // if (p_norm > match_s * pd2 * pd2)
// double epsilon = pd2 / sqrt(noise_state);
// // cout << "check epsilon:" << epsilon << endl;
// double weight = 1.0; // epsilon / sqrt(epsilon * epsilon+1);
// if (epsilon > 1.0)
// {
// weight = sqrt(2 * epsilon - 1) / epsilon;
// pabcd(0) = weight * pabcd(0);
// pabcd(1) = weight * pabcd(1);
// pabcd(2) = weight * pabcd(2);
// pabcd(3) = weight * pabcd(3);
// }
if (p_norm > match_s * pd2 * pd2)
{
point_selected_surf[idx+j+1] = true;
@@ -164,12 +159,15 @@ void h_model_input(state_input &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_R,
normvec->points[j].y = pabcd(1);
normvec->points[j].z = pabcd(2);
normvec->points[j].intensity = pabcd(3);
effect_num_k ++;
local_effect_num++;
}
}
}
}
}
effect_num_k += local_effect_num; // Atomic update once per chunk
});
if (effect_num_k == 0)
{
ekfom_data.valid = false;
@@ -218,12 +216,19 @@ void h_model_input(state_input &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_R,
void h_model_output(state_output &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_R, esekfom::dyn_share_modified<double> &ekfom_data)
{
bool match_in_map = false;
normvec->resize(time_seq[k]);
std::atomic<int> effect_num_k(0);
// Grainsize set to 256 to reduce scheduler overhead
tbb::parallel_for(tbb::blocked_range<int>(0, time_seq[k], 256), [&](const tbb::blocked_range<int>& r) {
int local_effect_num = 0; // Local accumulator
#ifdef IVOX_NODE_TYPE_OCTVOXMAP
IVoxType::SearchCache cache;
#endif
for (int j = r.begin(); j != r.end(); ++j)
{
VF(4) pabcd;
pabcd.setZero();
normvec->resize(time_seq[k]);
int effect_num_k = 0;
for (int j = 0; j < time_seq[k]; j++)
{
PointType &point_body_j = feats_down_body->points[idx+j+1];
PointType &point_world_j = feats_down_world->points[idx+j+1];
pointBodyToWorld(&point_body_j, &point_world_j);
@@ -234,7 +239,11 @@ void h_model_output(state_output &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_
{
auto &points_near = Nearest_Points[idx+j+1];
ivox_->GetClosestPoint(point_world_j, points_near, NUM_MATCH_POINTS); //
#ifdef IVOX_NODE_TYPE_OCTVOXMAP
ivox_->GetClosestPoint(point_world_j, points_near, NUM_MATCH_POINTS, 5.0, &cache);
#else
ivox_->GetClosestPoint(point_world_j, points_near, NUM_MATCH_POINTS);
#endif
if ((points_near.size() < NUM_MATCH_POINTS)) // || pointSearchSqDis[NUM_MATCH_POINTS - 1] > 5)
{
@@ -246,24 +255,6 @@ void h_model_output(state_output &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_
if (esti_plane(pabcd, points_near, plane_thr)) //(planeValid)
{
float pd2 = fabs(pabcd(0) * point_world_j.x + pabcd(1) * point_world_j.y + pabcd(2) * point_world_j.z + pabcd(3));
// V3D norm_vec;
// M3D Rpf, pf;
// pf = crossmat_list[idx+j+1];
// // pf << SKEW_SYM_MATRX(p_body);
// Rpf = s.rot * pf;
// norm_vec << pabcd(0), pabcd(1), pabcd(2);
// double noise_state = norm_vec.transpose() * (cov_p+Rpf*cov_R*Rpf.transpose()) * norm_vec + sqrt(p_norm) * 0.001;
// // if (p_norm > match_s * pd2 * pd2)
// double epsilon = pd2 / sqrt(noise_state);
// double weight = 1.0; // epsilon / sqrt(epsilon * epsilon+1);
// if (epsilon > 1.0)
// {
// weight = sqrt(2 * epsilon - 1) / epsilon;
// pabcd(0) = weight * pabcd(0);
// pabcd(1) = weight * pabcd(1);
// pabcd(2) = weight * pabcd(2);
// pabcd(3) = weight * pabcd(3);
// }
if (p_norm > match_s * pd2 * pd2)
{
// point_selected_surf[i] = true;
@@ -272,12 +263,15 @@ void h_model_output(state_output &s, Eigen::Matrix3d cov_p, Eigen::Matrix3d cov_
normvec->points[j].y = pabcd(1);
normvec->points[j].z = pabcd(2);
normvec->points[j].intensity = pabcd(3);
effect_num_k ++;
local_effect_num++;
}
}
}
}
}
effect_num_k += local_effect_num; // Atomic update once per chunk
});
if (effect_num_k == 0)
{
ekfom_data.valid = false;

View File

@@ -12,6 +12,11 @@
#include <tf2_ros/transform_broadcaster.h>
#include "li_initialization.h"
#include <malloc.h>
#include <tbb/parallel_for.h>
#include <tbb/blocked_range.h>
#include <tbb/enumerable_thread_specific.h>
#include <chrono>
// #include <cv_bridge/cv_bridge.h>
// #include "matplotlibcpp.h"
@@ -119,11 +124,13 @@ void pointBodyLidarToIMU(PointType const * const pi, PointType * const po)
}
void MapIncremental() {
PointVector points_to_add;
int cur_pts = feats_down_world->size();
points_to_add.reserve(cur_pts);
tbb::enumerable_thread_specific<PointVector> points_to_add_local;
for (size_t i = 0; i < cur_pts; ++i) {
// Grainsize set to 256
tbb::parallel_for(tbb::blocked_range<int>(0, cur_pts, 256), [&](const tbb::blocked_range<int>& r) {
PointVector& local_points = points_to_add_local.local();
for (int i = r.begin(); i != r.end(); ++i) {
/* decide if need add to map */
PointType &point_world = feats_down_world->points[i];
if (!Nearest_Points[i].empty()) {
@@ -142,12 +149,19 @@ void MapIncremental() {
}
}
if (need_add) {
points_to_add.emplace_back(point_world);
local_points.emplace_back(point_world);
}
} else {
points_to_add.emplace_back(point_world);
local_points.emplace_back(point_world);
}
}
});
PointVector points_to_add;
points_to_add.reserve(cur_pts);
for (const auto& local_vec : points_to_add_local) {
points_to_add.insert(points_to_add.end(), local_vec.begin(), local_vec.end());
}
ivox_->AddPoints(points_to_add);
}
@@ -229,11 +243,14 @@ void publish_frame_body(const rclcpp::Publisher<sensor_msgs::msg::PointCloud2>::
int size = feats_undistort->points.size();
PointCloudXYZI::Ptr laserCloudIMUBody(new PointCloudXYZI(size, 1));
for (int i = 0; i < size; i++)
// Grainsize set to 256
tbb::parallel_for(tbb::blocked_range<int>(0, size, 256), [&](const tbb::blocked_range<int>& r) {
for (int i = r.begin(); i != r.end(); ++i)
{
pointBodyLidarToIMU(&feats_undistort->points[i], \
&laserCloudIMUBody->points[i]);
}
});
sensor_msgs::msg::PointCloud2 laserCloudmsg;
pcl::toROSMsg(*laserCloudIMUBody, laserCloudmsg);
@@ -392,12 +409,13 @@ int main(int argc, char** argv)
signal(SIGINT, SigHandle);
rclcpp::Rate rate(500);
rclcpp::executors::SingleThreadedExecutor executor;
executor.add_node(nh);
while (rclcpp::ok())
{
if (flg_exit) break;
rclcpp::executors::SingleThreadedExecutor executor;
executor.add_node(nh);
executor.spin_some();
if(sync_packages(Measures))
@@ -478,10 +496,9 @@ int main(int argc, char** argv)
solve_time = 0;
propag_time = 0;
update_time = 0;
t0 = omp_get_wtime();
t0 = get_wtime();
/*** downsample the feature points in a scan ***/
t1 = omp_get_wtime();
p_imu->Process(Measures, feats_undistort);
if(space_down_sample)
{
@@ -498,7 +515,7 @@ int main(int argc, char** argv)
time_seq = time_compressing<int>(feats_down_body);
feats_down_size = feats_down_body->points.size();
}
t1 = get_wtime();
if (!p_imu->after_imu_init_) // !p_imu->UseLIInit &&
{
if (!p_imu->imu_need_init_)
@@ -555,14 +572,16 @@ int main(int argc, char** argv)
Nearest_Points.resize(feats_down_size);
t2 = omp_get_wtime();
t2 = get_wtime();
/*** iterated state estimation ***/
crossmat_list.reserve(feats_down_size);
pbody_list.reserve(feats_down_size);
crossmat_list.resize(feats_down_size);
pbody_list.resize(feats_down_size);
// pbody_ext_list.reserve(feats_down_size);
for (size_t i = 0; i < feats_down_body->size(); i++)
// Grainsize set to 256
tbb::parallel_for(tbb::blocked_range<size_t>(0, feats_down_body->size(), 256), [&](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i)
{
V3D point_this(feats_down_body->points[i].x,
feats_down_body->points[i].y,
@@ -587,6 +606,7 @@ int main(int argc, char** argv)
crossmat_list[i]=point_crossmat;
}
}
});
if (!use_imu_as_input)
{
bool imu_upda_cov = false;
@@ -658,14 +678,14 @@ int main(int argc, char** argv)
if (dt_cov > 0.0)
{
time_update_last = get_time_sec(imu_next.header.stamp);
double propag_imu_start = omp_get_wtime();
double propag_imu_start = get_wtime();
kf_output.predict(dt_cov, Q_output, input_in, false, true);
propag_time += omp_get_wtime() - propag_imu_start;
double solve_imu_start = omp_get_wtime();
propag_time += get_wtime() - propag_imu_start;
double solve_imu_start = get_wtime();
kf_output.update_iterated_dyn_share_IMU();
solve_time += omp_get_wtime() - solve_imu_start;
solve_time += get_wtime() - solve_imu_start;
}
}
imu_deque.pop_front();
@@ -681,7 +701,7 @@ int main(int argc, char** argv)
}
double dt = time_current - time_predict_last_const;
double propag_state_start = omp_get_wtime();
double propag_state_start = get_wtime();
if(!prop_at_freq_of_imu)
{
double dt_cov = time_current - time_update_last;
@@ -692,9 +712,9 @@ int main(int argc, char** argv)
}
}
kf_output.predict(dt, Q_output, input_in, true, false);
propag_time += omp_get_wtime() - propag_state_start;
propag_time += get_wtime() - propag_state_start;
time_predict_last_const = time_current;
double t_update_start = omp_get_wtime();
double t_update_start = get_wtime();
if (feats_down_size < 1)
{
@@ -707,7 +727,7 @@ int main(int argc, char** argv)
idx = idx+time_seq[k];
continue;
}
solve_start = omp_get_wtime();
solve_start = get_wtime();
if (publish_odometry_without_downsample)
{
@@ -729,9 +749,9 @@ int main(int argc, char** argv)
pointBodyToWorld(&point_body_j, &point_world_j);
}
solve_time += omp_get_wtime() - solve_start;
solve_time += get_wtime() - solve_start;
update_time += omp_get_wtime() - t_update_start;
update_time += get_wtime() - t_update_start;
idx += time_seq[k];
// cout << "pbp output effect feat num:" << effct_feat_num << endl;
}
@@ -870,7 +890,7 @@ int main(int argc, char** argv)
}
double dt = time_current - t_last;
t_last = time_current;
double propag_start = omp_get_wtime();
double propag_start = get_wtime();
if(!prop_at_freq_of_imu)
{
@@ -883,9 +903,9 @@ int main(int argc, char** argv)
}
kf_input.predict(dt, Q_input, input_in, true, false);
propag_time += omp_get_wtime() - propag_start;
propag_time += get_wtime() - propag_start;
double t_update_start = omp_get_wtime();
double t_update_start = get_wtime();
if (feats_down_size < 1)
{
@@ -900,7 +920,7 @@ int main(int argc, char** argv)
continue;
}
solve_start = omp_get_wtime();
solve_start = get_wtime();
if (publish_odometry_without_downsample)
{
@@ -921,9 +941,9 @@ int main(int argc, char** argv)
PointType &point_world_j = feats_down_world->points[idx+j+1];
pointBodyToWorld(&point_body_j, &point_world_j);
}
solve_time += omp_get_wtime() - solve_start;
solve_time += get_wtime() - solve_start;
update_time += omp_get_wtime() - t_update_start;
update_time += get_wtime() - t_update_start;
idx = idx + time_seq[k];
}
}
@@ -1010,14 +1030,14 @@ int main(int argc, char** argv)
}
/*** add the feature points to map ***/
t3 = omp_get_wtime();
t3 = get_wtime();
if(feats_down_size > 4)
{
MapIncremental();
}
t5 = omp_get_wtime();
t5 = get_wtime();
/******* Publish points *******/
if (path_en) publish_path(pubPath);
if (scan_pub_en || pcd_save_en) publish_frame_world(pubLaserCloudFullRes);

View File

@@ -26,7 +26,7 @@ void standard_pcl_cbk(const sensor_msgs::msg::PointCloud2::ConstSharedPtr &msg)
{
// mtx_buffer.lock();
scan_count ++;
double preprocess_start_time = omp_get_wtime();
double preprocess_start_time = get_wtime();
double msg_time = get_time_sec(msg->header.stamp);
if (msg_time < last_timestamp_lidar)
{
@@ -97,7 +97,7 @@ void standard_pcl_cbk(const sensor_msgs::msg::PointCloud2::ConstSharedPtr &msg)
}
}
}
s_plot11[scan_count] = omp_get_wtime() - preprocess_start_time;
s_plot11[scan_count] = get_wtime() - preprocess_start_time;
// mtx_buffer.unlock();
// sig_buffer.notify_all();
}
@@ -105,7 +105,7 @@ void standard_pcl_cbk(const sensor_msgs::msg::PointCloud2::ConstSharedPtr &msg)
void livox_pcl_cbk(const livox_ros_driver2::msg::CustomMsg::ConstSharedPtr &msg)
{
// mtx_buffer.lock();
double preprocess_start_time = omp_get_wtime();
double preprocess_start_time = get_wtime();
scan_count ++;
double msg_time = get_time_sec(msg->header.stamp);
if (msg_time < last_timestamp_lidar)
@@ -177,7 +177,7 @@ void livox_pcl_cbk(const livox_ros_driver2::msg::CustomMsg::ConstSharedPtr &msg)
}
}
}
s_plot11[scan_count] = omp_get_wtime() - preprocess_start_time;
s_plot11[scan_count] = get_wtime() - preprocess_start_time;
// mtx_buffer.unlock();
// sig_buffer.notify_all();
}

View File

@@ -23,14 +23,20 @@
#include <sensor_msgs/msg/imu.hpp>
#include <pcl/common/transforms.h>
#include <geometry_msgs/msg/vector3.hpp>
#include <OctVoxMap/OctVoxMapAdapter.hpp>
// #define IVOX_NODE_TYPE_PHC
#define IVOX_NODE_TYPE_OCTVOXMAP
#ifdef IVOX_NODE_TYPE_OCTVOXMAP
using IVoxType = OctVoxMapAdapter<3, 0, PointType>;
#else
#ifdef IVOX_NODE_TYPE_PHC
using IVoxType = faster_lio::IVox<3, faster_lio::IVoxNodeType::PHC, PointType>;
#else
using IVoxType = faster_lio::IVox<3, faster_lio::IVoxNodeType::DEFAULT, PointType>;
#endif
#endif
extern bool is_first_frame;
extern double lidar_end_time, first_lidar_time, time_con;

View File

@@ -319,12 +319,16 @@ void Preprocess::avia_handler(const livox_ros_driver2::msg::CustomMsg::ConstShar
pl_surf.clear();
pl_corn.clear();
pl_full.clear();
double t1 = omp_get_wtime();
double t1 = get_wtime();
int plsize = msg->point_num;
pl_corn.reserve(plsize);
pl_surf.reserve(plsize);
// Avoid resizing pl_full if possible, or just use it as a buffer
if (pl_full.size() < plsize) {
pl_full.resize(plsize);
}
for(int i=0; i<N_SCANS; i++)
{
@@ -333,26 +337,35 @@ void Preprocess::avia_handler(const livox_ros_driver2::msg::CustomMsg::ConstShar
}
uint valid_num = 0;
// Pre-calculate constants
const float time_scale = 1.0f / 1000000.0f;
const double blind_sq = blind * blind;
const double det_range_sq = det_range * det_range;
for(uint i=1; i<plsize; i++)
{
if((msg->points[i].line < N_SCANS) && ((msg->points[i].tag & 0x30) == 0x10 || (msg->points[i].tag & 0x30) == 0x00))
const auto& pt = msg->points[i];
if((pt.line < N_SCANS) && ((pt.tag & 0x30) == 0x10 || (pt.tag & 0x30) == 0x00))
{
valid_num ++;
if (valid_num % point_filter_num == 0)
{
pl_full[i].x = msg->points[i].x;
pl_full[i].y = msg->points[i].y;
pl_full[i].z = msg->points[i].z;
pl_full[i].intensity = msg->points[i].reflectivity; // z; //
pl_full[i].curvature = msg->points[i].offset_time / float(1000000); // use curvature as time of each laser points, curvature unit: ms
double dist = pl_full[i].x * pl_full[i].x + pl_full[i].y * pl_full[i].y + pl_full[i].z * pl_full[i].z;
if (dist < blind * blind || dist > det_range * det_range) continue;
if(((abs(pl_full[i].x - pl_full[i-1].x) > 1e-7)
|| (abs(pl_full[i].y - pl_full[i-1].y) > 1e-7)
|| (abs(pl_full[i].z - pl_full[i-1].z) > 1e-7)))
auto& full_pt = pl_full[i];
full_pt.x = pt.x;
full_pt.y = pt.y;
full_pt.z = pt.z;
full_pt.intensity = pt.reflectivity;
full_pt.curvature = pt.offset_time * time_scale; // use curvature as time of each laser points, curvature unit: ms
double dist = full_pt.x * full_pt.x + full_pt.y * full_pt.y + full_pt.z * full_pt.z;
if (dist < blind_sq || dist > det_range_sq) continue;
const auto& prev_pt = pl_full[i-1];
if(((std::abs(full_pt.x - prev_pt.x) > 1e-7)
|| (std::abs(full_pt.y - prev_pt.y) > 1e-7)
|| (std::abs(full_pt.z - prev_pt.z) > 1e-7)))
{
pl_surf.push_back(pl_full[i]);
pl_surf.push_back(full_pt);
}
}
}