Skip to content

Commit

Permalink
Add Python-like set operations to flat_set (#2557)
Browse files Browse the repository at this point in the history
* Add flat-set doctest printer

* Add Python-like set operations to flat_set

* Fix merge error

* Remove Self

* Refactor operator and member fuctions
  • Loading branch information
AntoinePrv authored Nov 17, 2023
1 parent 1fda529 commit 2c8ec3c
Show file tree
Hide file tree
Showing 3 changed files with 379 additions and 29 deletions.
292 changes: 264 additions & 28 deletions libmamba/include/mamba/util/flat_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
namespace mamba::util
{


struct sorted_unique_t
{
explicit sorted_unique_t() = default;
};

inline constexpr sorted_unique_t sorted_unique{};

/**
* A sorted vector behaving like a set.
*
Expand Down Expand Up @@ -61,49 +69,73 @@ namespace mamba::util
key_compare compare = key_compare(),
const allocator_type& alloc = Allocator()
);
template <typename InputIterator>
flat_set(
sorted_unique_t,
InputIterator first,
InputIterator last,
key_compare compare = key_compare(),
const allocator_type& alloc = Allocator()
);
flat_set(const flat_set&) = default;
flat_set(flat_set&&) = default;
explicit flat_set(std::vector<Key, Allocator>&& other, key_compare compare = key_compare());
explicit flat_set(const std::vector<Key, Allocator>& other, key_compare compare = key_compare());

flat_set& operator=(const flat_set&) = default;
flat_set& operator=(flat_set&&) = default;
auto operator=(const flat_set&) -> flat_set& = default;
auto operator=(flat_set&&) -> flat_set& = default;

auto key_comp() const -> const key_compare&;

bool contains(const value_type&) const;
const value_type& front() const noexcept;
const value_type& back() const noexcept;
const value_type& operator[](size_type pos) const;
const value_type& at(size_type pos) const;
auto front() const noexcept -> const value_type&;
auto back() const noexcept -> const value_type&;
auto operator[](size_type pos) const -> const value_type&;
auto at(size_type pos) const -> const value_type&;

const_iterator begin() const noexcept;
const_iterator end() const noexcept;
const_reverse_iterator rbegin() const noexcept;
const_reverse_iterator rend() const noexcept;
auto begin() const noexcept -> const_iterator;
auto end() const noexcept -> const_iterator;
auto rbegin() const noexcept -> const_reverse_iterator;
auto rend() const noexcept -> const_reverse_iterator;

/** Insert an element in the set.
*
* Like std::vector and unlike std::set, inserting an element invalidates iterators.
*/
std::pair<const_iterator, bool> insert(value_type&& value);
std::pair<const_iterator, bool> insert(const value_type& value);
auto insert(value_type&& value) -> std::pair<const_iterator, bool>;
auto insert(const value_type& value) -> std::pair<const_iterator, bool>;
template <typename InputIterator>
void insert(InputIterator first, InputIterator last);

const_iterator erase(const_iterator pos);
const_iterator erase(const_iterator first, const_iterator last);
size_type erase(const value_type& value);
auto erase(const_iterator pos) -> const_iterator;
auto erase(const_iterator first, const_iterator last) -> const_iterator;
auto erase(const value_type& value) -> size_type;

auto contains(const value_type&) const -> bool;

private:

key_compare m_compare;

bool key_eq(const value_type& a, const value_type& b) const;
auto key_eq(const value_type& a, const value_type& b) const -> bool;
template <typename U>
std::pair<const_iterator, bool> insert_impl(U&& value);
auto insert_impl(U&& value) -> std::pair<const_iterator, bool>;
void sort_and_remove_duplicates();

template <typename K, typename C, typename A>
friend bool operator==(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs);
friend auto operator==(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool;

template <typename K, typename C, typename A>
friend auto set_union(const flat_set<K, C, A>&, const flat_set<K, C, A>&)
-> flat_set<K, C, A>;
template <typename K, typename C, typename A>
friend auto set_intersection(const flat_set<K, C, A>&, const flat_set<K, C, A>&)
-> flat_set<K, C, A>;
template <typename K, typename C, typename A>
friend auto set_difference(const flat_set<K, C, A>&, const flat_set<K, C, A>&)
-> flat_set<K, C, A>;
template <typename K, typename C, typename A>
friend auto set_symetric_difference(const flat_set<K, C, A>&, const flat_set<K, C, A>&)
-> flat_set<K, C, A>;
};

template <class Key, class Compare = std::less<Key>, class Allocator = std::allocator<Key>>
Expand All @@ -126,12 +158,68 @@ namespace mamba::util
-> flat_set<Key, Compare, Allocator>;

template <typename Key, typename Compare, typename Allocator>
bool
operator==(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);
auto
operator==(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs)
-> bool;

template <typename Key, typename Compare, typename Allocator>
auto
operator!=(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs)
-> bool;

template <typename Key, typename Compare, typename Allocator>
auto set_is_disjoint_of(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> bool;

template <typename Key, typename Compare, typename Allocator>
auto is_subset_of(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> bool;

template <typename Key, typename Compare, typename Allocator>
auto is_strict_subset_of(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> bool;

template <typename Key, typename Compare, typename Allocator>
auto is_superset_of(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> bool;

template <typename Key, typename Compare, typename Allocator>
auto is_strict_superset_of(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> bool;

template <typename Key, typename Compare, typename Allocator>
auto set_union( //
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> flat_set<Key, Compare, Allocator>;

template <typename Key, typename Compare, typename Allocator>
bool
operator!=(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);
auto set_intersection(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> flat_set<Key, Compare, Allocator>;

template <typename Key, typename Compare, typename Allocator>
auto set_difference(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> flat_set<Key, Compare, Allocator>;

template <typename Key, typename Compare, typename Allocator>
auto set_symetric_difference(
const flat_set<Key, Compare, Allocator>& lhs,
const flat_set<Key, Compare, Allocator>& rhs
) -> flat_set<Key, Compare, Allocator>;

/*******************************
* vector_set Implementation *
Expand Down Expand Up @@ -163,6 +251,20 @@ namespace mamba::util
sort_and_remove_duplicates();
}

template <typename K, typename C, typename A>
template <typename InputIterator>
flat_set<K, C, A>::flat_set(
sorted_unique_t,
InputIterator first,
InputIterator last,
key_compare compare,
const allocator_type& alloc
)
: Base(first, last, alloc)
, m_compare(std::move(compare))
{
}

template <typename K, typename C, typename A>
flat_set<K, C, A>::flat_set(std::vector<K, A>&& other, C compare)
: Base(std::move(other))
Expand All @@ -180,9 +282,9 @@ namespace mamba::util
}

template <typename K, typename C, typename A>
auto flat_set<K, C, A>::contains(const value_type& value) const -> bool
auto flat_set<K, C, A>::key_comp() const -> const key_compare&
{
return std::binary_search(begin(), end(), value);
return m_compare;
}

template <typename K, typename C, typename A>
Expand Down Expand Up @@ -246,7 +348,7 @@ namespace mamba::util
}

template <typename K, typename C, typename A>
bool flat_set<K, C, A>::key_eq(const value_type& a, const value_type& b) const
auto flat_set<K, C, A>::key_eq(const value_type& a, const value_type& b) const -> bool
{
return !m_compare(a, b) && !m_compare(b, a);
}
Expand Down Expand Up @@ -307,17 +409,151 @@ namespace mamba::util
}

template <typename K, typename C, typename A>
bool operator==(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
auto flat_set<K, C, A>::contains(const value_type& value) const -> bool
{
return std::binary_search(begin(), end(), value);
}

namespace detail
{
/**
* Check if two sorted range have an empty intersection.
*
* Edited from https://en.cppreference.com/w/cpp/algorithm/set_intersection
* Distributed under the terms of the Copyright/CC-BY-SA License.
* The full license can be found at the address
* https://en.cppreference.com/w/Cppreference:Copyright/CC-BY-SA
*/
template <class InputIt1, class InputIt2, class Compare>
auto
set_disjoint(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2, Compare comp)
-> bool
{
while (first1 != last1 && first2 != last2)
{
if (comp(*first1, *first2))
{
++first1;
}
else
{
if (!comp(*first2, *first1))
{
return false; // *first1 and *first2 are equivalent.
}
++first2;
}
}
return true;
}
}

template <typename K, typename C, typename A>
auto operator==(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool
{
auto is_eq = [&lhs](const auto& a, const auto& b) { return lhs.key_eq(a, b); };
return std::equal(lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend(), is_eq);
}

template <typename K, typename C, typename A>
bool operator!=(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
auto operator!=(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool
{
return !(lhs == rhs);
}

template <typename K, typename C, typename A>
auto set_is_disjoint_of(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool
{
return detail::set_disjoint(lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend(), lhs.key_comp());
}

template <typename K, typename C, typename A>
auto set_is_subset_of(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool
{
return (lhs.size() <= rhs.size()) // For perf
&& std::includes(rhs.cbegin(), rhs.cend(), lhs.cbegin(), lhs.cend(), lhs.key_comp());
}

template <typename K, typename C, typename A>
auto set_is_strict_subset_of(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool
{
return (lhs.size() < rhs.size()) && set_is_subset_of(lhs, rhs);
}

template <typename K, typename C, typename A>
auto set_is_superset_of(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool
{
return set_is_subset_of(rhs, lhs);
}

template <typename K, typename C, typename A>
auto set_is_strict_superset_of(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> bool
{
return set_is_strict_subset_of(rhs, lhs);
}

template <typename K, typename C, typename A>
auto set_union(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs) -> flat_set<K, C, A>
{
auto out = flat_set<K, C, A>();
out.reserve(std::max(lhs.size(), rhs.size())); // lower bound
std::set_union(
lhs.cbegin(),
lhs.cend(),
rhs.cbegin(),
rhs.cend(),
std::back_inserter(static_cast<typename flat_set<K, C, A>::Base&>(out)),
lhs.m_compare
);
return out;
}

template <typename K, typename C, typename A>
auto set_intersection(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
-> flat_set<K, C, A>
{
auto out = flat_set<K, C, A>();
std::set_intersection(
lhs.cbegin(),
lhs.cend(),
rhs.cbegin(),
rhs.cend(),
std::back_inserter(static_cast<typename flat_set<K, C, A>::Base&>(out)),
lhs.m_compare
);
return out;
}

template <typename K, typename C, typename A>
auto set_difference(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
-> flat_set<K, C, A>
{
auto out = flat_set<K, C, A>();
std::set_difference(
lhs.cbegin(),
lhs.cend(),
rhs.cbegin(),
rhs.cend(),
std::back_inserter(static_cast<typename flat_set<K, C, A>::Base&>(out)),
lhs.m_compare
);
return out;
}

template <typename K, typename C, typename A>
auto set_symetric_difference(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
-> flat_set<K, C, A>
{
auto out = flat_set<K, C, A>();
std::set_symmetric_difference(
lhs.cbegin(),
lhs.cend(),
rhs.cbegin(),
rhs.cend(),
std::back_inserter(static_cast<typename flat_set<K, C, A>::Base&>(out)),
lhs.m_compare
);
return out;
}
}
#endif
2 changes: 1 addition & 1 deletion libmamba/tests/src/doctest-printer/flat_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace doctest
{
static auto convert(const mamba::util::flat_set<K, C, A>& value) -> String
{
return { fmt::format("std::flat_set{{{}}}", fmt::join(value, ", ")).c_str() };
return { fmt::format("mamba::util::flat_set{{{}}}", fmt::join(value, ", ")).c_str() };
}
};
}
Loading

0 comments on commit 2c8ec3c

Please sign in to comment.