Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++20 Spaceship operator #2401

Merged
merged 16 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test-matrix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
name: "CMake Matrix Test."
runs-on: ${{ matrix.config.os }}
strategy:
fail-fast: false
matrix:
config:
- {
Expand All @@ -37,7 +38,7 @@ jobs:
}
- {
name: "MacOS Min",
os: "macos-12",
os: "macos-13",
cc: "clang",
cxx: "clang++",
py: "3.9",
Expand Down
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ include("CheckCompilerXLC")

include("CompilerOptions")
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${CXXOPT_WALL}>")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

Expand Down Expand Up @@ -320,7 +320,7 @@ if (ARB_WITH_PYTHON)
CPMAddPackage(NAME pybind11
GITHUB_REPOSITORY pybind/pybind11
VERSION 2.10.1
OPTIONS "PYBIND11_CPP_STANDARD -std=c++17")
OPTIONS "PYBIND11_CPP_STANDARD -std=c++20")
# required for find_python_module
include(FindPythonModule)
endif()
Expand Down
2 changes: 2 additions & 0 deletions arbor/backends/gpu/fine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ struct branch {
unsigned parent_idx; //
unsigned start_idx; // the index of the first node in the input parent index
unsigned length; // the number of nodes in the branch

bool operator==(const branch&) const = default;
};

// order branches by:
Expand Down
2 changes: 0 additions & 2 deletions arbor/backends/gpu/matrix_fine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

#include "fine.hpp"

#include <ostream>

namespace arb {
namespace gpu {
ARB_ARBOR_API void assemble_matrix_fine(
Expand Down
15 changes: 7 additions & 8 deletions arbor/connection.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#pragma once

#include <cstdint>

#include <arbor/common_types.hpp>
#include <arbor/spike.hpp>
#include <arbor/spike_event.hpp>
Expand All @@ -14,19 +12,20 @@ struct connection {
float weight = 0.0f;
float delay = 0.0f;
cell_size_type index_on_domain = cell_gid_type(-1);

bool operator==(const connection&) const = default;

// connections are sorted by source id
// these operators make for easy interopability with STL algorithms
auto operator<=>(const connection& rhs) const { return source <=> rhs.source; }
auto operator<=>(const cell_member_type& rhs) const { return source <=> rhs; }
};

inline
spike_event make_event(const connection& c, const spike& s) {
return {c.target, s.time + c.delay, c.weight};
}

// connections are sorted by source id
// these operators make for easy interopability with STL algorithms
static inline bool operator<(const connection& lhs, const connection& rhs) { return lhs.source < rhs.source; }
static inline bool operator<(const connection& lhs, cell_member_type rhs) { return lhs.source < rhs; }
static inline bool operator<(cell_member_type lhs, const connection& rhs) { return lhs < rhs.source; }

} // namespace arb

static inline std::ostream& operator<<(std::ostream& o, arb::connection const& con) {
Expand Down
22 changes: 9 additions & 13 deletions arbor/include/arbor/common_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <string>
#include <type_traits>

#include <arbor/util/lexcmp_def.hpp>
#include <arbor/util/hash_def.hpp>
#include <arbor/export.hpp>

Expand Down Expand Up @@ -58,6 +57,7 @@ using cell_local_size_type = std::make_unsigned_t<cell_lid_type>;
struct cell_member_type {
cell_gid_type gid;
cell_lid_type index;
auto operator<=>(const cell_member_type&) const = default;
};

// Pair of indexes that describe range of local indices.
Expand All @@ -66,8 +66,8 @@ struct lid_range {
cell_lid_type begin = 0;
cell_lid_type end = 0;
lid_range() = default;
lid_range(cell_lid_type b, cell_lid_type e):
begin(b), end(e) {}
lid_range(cell_lid_type b, cell_lid_type e): begin(b), end(e) {}
auto operator<=>(const lid_range&) const = default;
};

// Global range of indices with given step size.
Expand All @@ -77,10 +77,8 @@ struct gid_range {
cell_gid_type end = 0;
cell_gid_type step = 1;
gid_range() = default;
gid_range(cell_gid_type b, cell_gid_type e):
begin(b), end(e), step(1) {}
gid_range(cell_gid_type b, cell_gid_type e, cell_gid_type s):
begin(b), end(e), step(s) {}
gid_range(cell_gid_type b, cell_gid_type e): begin(b), end(e), step(1) {}
gid_range(cell_gid_type b, cell_gid_type e, cell_gid_type s): begin(b), end(e), step(s) {}
};

// Policy for selecting a cell_lid_type from a range of possible values.
Expand Down Expand Up @@ -132,18 +130,16 @@ struct cell_address_type {

cell_address_type& operator=(const cell_address_type&) = default;
cell_address_type& operator=(cell_address_type&&) = default;
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(cell_address_type, (a.gid, a.tag), (b.gid, b.tag))
auto operator<=>(const cell_address_type&) const = default;
};

struct cell_remote_label_type {
cell_gid_type rid; // remote id
cell_lid_type index = 0; // index on remote id
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(cell_remote_label_type,(a.rid,a.index),(b.rid,b.index))
ARB_DEFINE_LEXICOGRAPHIC_ORDERING(cell_member_type,(a.gid,a.index),(b.gid,b.index))
ARB_DEFINE_LEXICOGRAPHIC_ORDERING(lid_range,(a.begin, a.end),(b.begin,b.end))
auto operator<=>(const cell_remote_label_type&) const = default;
};

// For storing time values [ms]

Expand Down
6 changes: 5 additions & 1 deletion arbor/include/arbor/fvm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ struct fvm_gap_junction {
arb_size_type local_cv; // CV index of the local gap junction site.
arb_size_type peer_cv; // CV index of the peer gap junction site.
arb_value_type weight; // unit-less local weight of the connection.

constexpr bool operator==(const fvm_gap_junction&) const = default;
constexpr auto operator<=>(const fvm_gap_junction& o) const {
return std::tie(local_cv, peer_cv, local_idx, weight) <=> std::tie(o.local_cv, o.peer_cv, o.local_idx, o.weight);
}
};
ARB_DEFINE_LEXICOGRAPHIC_ORDERING(fvm_gap_junction, (a.local_cv, a.peer_cv, a.local_idx, a.weight), (b.local_cv, b.peer_cv, b.local_idx, b.weight))

} // namespace arb
1 change: 0 additions & 1 deletion arbor/include/arbor/morph/morphology.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <arbor/export.hpp>
#include <arbor/morph/primitives.hpp>
#include <arbor/morph/segment_tree.hpp>
#include <arbor/util/lexcmp_def.hpp>

namespace arb {

Expand Down
27 changes: 7 additions & 20 deletions arbor/include/arbor/morph/primitives.hpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
#pragma once

#include <algorithm>
#include <cstdlib>
#include <cstdint>
#include <ostream>
#include <vector>

#include <arbor/export.hpp>
#include <arbor/util/hash_def.hpp>
#include <arbor/util/lexcmp_def.hpp>

//
// Types used to identify concrete locations.
//

namespace arb {

using msize_t = std::uint32_t;
Expand All @@ -24,10 +19,9 @@ struct ARB_SYMBOL_VISIBLE mpoint {
double x, y, z; // [µm]
double radius; // [μm]
friend std::ostream& operator<<(std::ostream&, const mpoint&);
auto operator<=>(const mpoint&) const = default;
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(mpoint, (a.x,a.y,a.z,a.radius), (b.x,b.y,b.z,b.radius));

ARB_ARBOR_API mpoint lerp(const mpoint& a, const mpoint& b, double u);
ARB_ARBOR_API bool is_collocated(const mpoint& a, const mpoint& b);
ARB_ARBOR_API double distance(const mpoint& a, const mpoint& b);
Expand All @@ -46,25 +40,22 @@ struct ARB_SYMBOL_VISIBLE msegment {
mpoint prox;
mpoint dist;
int tag;

auto operator<=>(const msegment&) const = default;
friend std::ostream& operator<<(std::ostream&, const msegment&);
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(msegment, (a.id,a.prox,a.dist,a.tag), (b.id,b.prox,b.dist,b.tag));

// Describe a specific location on a morpholology.
struct ARB_SYMBOL_VISIBLE mlocation {
// The id of the branch.
msize_t branch = 0;
// The relative position on the branch ∈ [0,1].
double pos = 0.0;

auto operator<=>(const mlocation&) const = default;
friend std::ostream& operator<<(std::ostream&, const mlocation&);
};

// branch ≠ npos and 0 ≤ pos ≤ 1
ARB_ARBOR_API bool test_invariants(const mlocation&);
ARB_DEFINE_LEXICOGRAPHIC_ORDERING(mlocation, (a.branch,a.pos), (b.branch,b.pos));

using mlocation_list = std::vector<mlocation>;
ARB_ARBOR_API std::ostream& operator<<(std::ostream& o, const mlocation_list& l);
Expand Down Expand Up @@ -94,20 +85,16 @@ struct ARB_SYMBOL_VISIBLE mcable {
double prox_pos; // ∈ [0,1]
double dist_pos; // ∈ [0,1]

friend mlocation prox_loc(const mcable& c) {
return {c.branch, c.prox_pos};
}
friend mlocation dist_loc(const mcable& c) {
return {c.branch, c.dist_pos};
}
auto operator<=>(const mcable&) const = default;

friend mlocation prox_loc(const mcable& c) { return {c.branch, c.prox_pos}; }
friend mlocation dist_loc(const mcable& c) { return {c.branch, c.dist_pos}; }

// branch ≠ npos, and 0 ≤ prox_pos ≤ dist_pos ≤ 1
friend bool test_invariants(const mcable&);
friend std::ostream& operator<<(std::ostream&, const mcable&);
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(mcable, (a.branch,a.prox_pos,a.dist_pos), (b.branch,b.prox_pos,b.dist_pos));

using mcable_list = std::vector<mcable>;
ARB_ARBOR_API std::ostream& operator<<(std::ostream& o, const mcable_list& c);
// Tests whether each cable in the list satisfies the invariants for a cable,
Expand Down
19 changes: 3 additions & 16 deletions arbor/include/arbor/network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@
#include <arbor/common_types.hpp>
#include <arbor/export.hpp>
#include <arbor/morph/primitives.hpp>
#include <arbor/util/lexcmp_def.hpp>

#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <ostream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
Expand All @@ -41,14 +37,10 @@ struct ARB_SYMBOL_VISIBLE network_site_info {
hash_type label;
mlocation location;
mpoint global_location;

auto operator<=>(const network_site_info&) const = default;
ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_site_info& s);
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_site_info,
(a.gid, a.kind, a.label, a.location, a.global_location),
(b.gid, a.kind, b.label, b.location, b.global_location))

struct ARB_SYMBOL_VISIBLE network_connection_info {
network_site_info source, target;
double weight, delay;
Expand All @@ -62,14 +54,9 @@ struct ARB_SYMBOL_VISIBLE network_connection_info {
weight(weight),
delay(delay) {}

ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os,
const network_connection_info& s);
auto operator<=>(const network_connection_info&) const = default;
ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_connection_info& s);
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_connection_info,
(a.source, a.target, a.weight, a.delay),
(b.source, b.target, b.weight, b.delay))

struct network_selection_impl;

struct network_value_impl;
Expand Down
4 changes: 1 addition & 3 deletions arbor/include/arbor/spike.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct basic_spike {
basic_spike(id_type s, time_type t):
source(std::move(s)), time(t)
{}

auto operator<=>(const basic_spike&) const = default;
ARB_SERDES_ENABLE(basic_spike<I>, source, time);
};

Expand All @@ -30,8 +30,6 @@ using spike = basic_spike<cell_member_type>;

using spike_predicate = std::function<bool(const spike&)>;

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(spike, (a.source, a.time), (b.source, b.time));

// Custom stream operator for printing arb::spike<> values.
template <typename I>
std::ostream& operator<<(std::ostream& o, basic_spike<I> const& s) {
Expand Down
6 changes: 3 additions & 3 deletions arbor/include/arbor/spike_event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <arbor/export.hpp>
#include <arbor/serdes.hpp>
#include <arbor/common_types.hpp>
#include <arbor/util/lexcmp_def.hpp>

namespace arb {

Expand All @@ -22,11 +21,12 @@ struct spike_event {
spike_event() = default;
constexpr spike_event(cell_lid_type tgt, time_type t, arb_weight_type w) noexcept: target(tgt), weight(w), time(t) {}

bool operator==(const spike_event&) const = default;
constexpr auto operator<=>(const spike_event& o) const { return std::tie(time, target, weight) <=> std::tie(o.time, o.target, o.weight); }

ARB_SERDES_ENABLE(spike_event, target, time, weight);
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(spike_event,(a.time,a.target,a.weight),(b.time,b.target,b.weight))

using pse_vector = std::vector<spike_event>;

ARB_ARBOR_API std::ostream& operator<<(std::ostream&, const spike_event&);
Expand Down
12 changes: 5 additions & 7 deletions arbor/include/arbor/util/any_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@

#include <arbor/export.hpp>
#include <arbor/util/any_cast.hpp>
#include <arbor/util/lexcmp_def.hpp>

namespace arb {
namespace util {

Expand Down Expand Up @@ -62,8 +60,8 @@ struct ARB_SYMBOL_VISIBLE any_ptr {
}

template <typename T, typename = std::enable_if_t<std::is_pointer<T>::value>>
T as() const noexcept {
if (std::is_same<T, void*>::value) {
constexpr T as() const noexcept {
if constexpr (std::is_same_v<T, void*>) {
return (T)ptr_;
}
else {
Expand All @@ -88,14 +86,14 @@ struct ARB_SYMBOL_VISIBLE any_ptr {
return *this;
}

constexpr auto operator<=>(const any_ptr& o) const { return this->as<void*>() <=> o.as<void*>(); }
constexpr auto operator==(const any_ptr& o) const { return this->as<void*>() == o.as<void*>(); }

private:
void* ptr_ = nullptr;
const std::type_info* type_ptr_ = &typeid(void);
};

// Order, compare by pointer value:
ARB_DEFINE_LEXICOGRAPHIC_ORDERING_BY_VALUE(any_ptr, (a.as<void*>()), (b.as<void*>()))

// Overload `util::any_cast` for these pointers.
template <typename T>
T any_cast(any_ptr p) noexcept { return p.as<T>(); }
Expand Down
Loading