Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python-like set operations to flat_set #2557

Merged
merged 5 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 226 additions & 7 deletions libmamba/include/mamba/util/flat_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace mamba::util
public:

using Base = std::vector<Key, Allocator>;
using Self = flat_set<Key, Compare, Allocator>;
Copy link
Member

@JohanMabille JohanMabille Nov 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpicking: I think we can simply use flat_set in the declaration and implementation, the template parameters will be deduced.

using typename Base::allocator_type;
using typename Base::const_iterator;
using typename Base::const_reverse_iterator;
Expand Down Expand Up @@ -69,7 +70,6 @@ namespace mamba::util
flat_set& operator=(const flat_set&) = default;
flat_set& operator=(flat_set&&) = default;

bool contains(const value_type&) const;
const value_type& front() const noexcept;
const value_type& back() const noexcept;
const value_type& operator[](size_type pos) const;
Expand All @@ -93,6 +93,16 @@ namespace mamba::util
const_iterator erase(const_iterator first, const_iterator last);
size_type erase(const value_type& value);

bool contains(const value_type&) const;
bool is_disjoint_of(const Self& other) const;
bool is_subset_of(const Self& other) const;
bool is_superset_of(const Self& other) const;

static Self union_(const Self& lhs, const Self& rhs);
static Self intersection(const Self& lhs, const Self& rhs);
static Self difference(const Self& lhs, const Self& rhs);
static Self symetric_difference(const Self& lhs, const Self& rhs);

private:

key_compare m_compare;
Expand Down Expand Up @@ -133,6 +143,46 @@ namespace mamba::util
bool
operator!=(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Return whether the first set is a subset of the second. */
template <typename Key, typename Compare, typename Allocator>
bool
operator<=(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Return whether the first set is a strict subset of the second. */
template <typename Key, typename Compare, typename Allocator>
bool
operator<(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Return whether the first set is a superset of the second. */
template <typename Key, typename Compare, typename Allocator>
bool
operator>=(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Return whether the first set is a strict superset of the second. */
template <typename Key, typename Compare, typename Allocator>
bool
operator>(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Compute the set union. */
template <typename Key, typename Compare, typename Allocator>
flat_set<Key, Compare, Allocator>
operator|(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Compute the set intersection. */
template <typename Key, typename Compare, typename Allocator>
flat_set<Key, Compare, Allocator>
operator&(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Compute the set difference. */
template <typename Key, typename Compare, typename Allocator>
flat_set<Key, Compare, Allocator>
operator-(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/** Compute the set symmetric difference. */
template <typename Key, typename Compare, typename Allocator>
flat_set<Key, Compare, Allocator>
operator^(const flat_set<Key, Compare, Allocator>& lhs, const flat_set<Key, Compare, Allocator>& rhs);

/*******************************
* flat_set Implementation *
*******************************/
Expand Down Expand Up @@ -179,12 +229,6 @@ namespace mamba::util
sort_and_remove_duplicates();
}

/**
 * Check whether @p value is an element of the set.
 *
 * The lookup is a binary search over the stored elements. The set's own comparator
 * is forwarded to the search so that the lookup uses the same ordering as the other
 * set operations of this class; without it ``std::binary_search`` falls back to
 * ``operator<``, which is wrong for any ``Compare`` that orders elements differently.
 *
 * @param value The element to look for.
 * @return true if an element equivalent to @p value is stored, false otherwise.
 */
template <typename K, typename C, typename A>
auto flat_set<K, C, A>::contains(const value_type& value) const -> bool
{
    return std::binary_search(begin(), end(), value, m_compare);
}

template <typename K, typename C, typename A>
auto flat_set<K, C, A>::front() const noexcept -> const value_type&
{
Expand Down Expand Up @@ -306,6 +350,51 @@ namespace mamba::util
return 1;
}

/**
 * Check whether @p value is an element of the set.
 *
 * The lookup is a binary search over the stored elements. The set's own comparator
 * is forwarded to the search so that the lookup uses the same ordering as the other
 * set operations of this class; without it ``std::binary_search`` falls back to
 * ``operator<``, which is wrong for any ``Compare`` that orders elements differently.
 *
 * @param value The element to look for.
 * @return true if an element equivalent to @p value is stored, false otherwise.
 */
template <typename K, typename C, typename A>
auto flat_set<K, C, A>::contains(const value_type& value) const -> bool
{
    return std::binary_search(begin(), end(), value, m_compare);
}

namespace detail
{
    /**
     * Tell whether two sorted ranges share no equivalent element.
     *
     * Both ranges are walked in lockstep: whichever front element orders first
     * according to @p comp is skipped, and the walk stops as soon as two elements
     * compare equivalent (neither orders before the other).
     *
     * Edited from https://en.cppreference.com/w/cpp/algorithm/set_intersection
     * Distributed under the terms of the Copyright/CC-BY-SA License.
     * The full license can be found at the address
     * https://en.cppreference.com/w/Cppreference:Copyright/CC-BY-SA
     *
     * @return true when no element of the first range is equivalent to an element
     *         of the second one (in particular when either range is empty).
     */
    template <class InputIt1, class InputIt2, class Compare>
    bool
    set_disjoint(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2, Compare comp)
    {
        for (; (first1 != last1) && (first2 != last2);)
        {
            if (comp(*first1, *first2))
            {
                // The left element orders first: it cannot match anything further right.
                ++first1;
                continue;
            }
            if (comp(*first2, *first1))
            {
                // The right element orders first: advance the second range instead.
                ++first2;
                continue;
            }
            // Neither orders before the other: the elements are equivalent.
            return false;
        }
        // One range was exhausted without meeting a common element.
        return true;
    }
}

/**
 * Tell whether this set and @p other have no element in common.
 *
 * The elements of both sets are compared with this set's comparator.
 */
template <typename K, typename C, typename A>
bool flat_set<K, C, A>::is_disjoint_of(const Self& other) const
{
    const bool no_common_element = detail::set_disjoint(
        cbegin(),
        cend(),
        other.cbegin(),
        other.cend(),
        m_compare
    );
    return no_common_element;
}

template <typename K, typename C, typename A>
bool operator==(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
Expand All @@ -319,5 +408,135 @@ namespace mamba::util
return !(lhs == rhs);
}

/**
 * Tell whether every element of this set is also found in @p other.
 *
 * Note the argument order of ``std::includes``: the candidate superset (@p other)
 * comes first, the candidate subset (this set) second.
 */
template <typename K, typename C, typename A>
bool flat_set<K, C, A>::is_subset_of(const Self& other) const
{
    const bool all_found_in_other = std::includes(
        other.cbegin(),
        other.cend(),
        cbegin(),
        cend(),
        m_compare
    );
    return all_found_in_other;
}

/** Subset comparison: forwards to ``lhs.is_subset_of(rhs)``. */
template <typename K, typename C, typename A>
bool operator<=(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    const bool lhs_within_rhs = lhs.is_subset_of(rhs);
    return lhs_within_rhs;
}

/**
 * Strict subset comparison.
 *
 * The left set must hold fewer elements than the right one and be a subset of it.
 */
template <typename K, typename C, typename A>
bool operator<(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    // Cheap size check first: without fewer elements the inclusion test is skipped.
    if (!(lhs.size() < rhs.size()))
    {
        return false;
    }
    return lhs <= rhs;
}

/**
 * Tell whether every element of @p other is also found in this set.
 *
 * Expressed as the mirrored subset query, performed on @p other.
 */
template <typename K, typename C, typename A>
bool flat_set<K, C, A>::is_superset_of(const Self& other) const
{
    const Self& self = *this;
    return other.is_subset_of(self);
}

/** Superset comparison: forwards to ``lhs.is_superset_of(rhs)``. */
template <typename K, typename C, typename A>
bool operator>=(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    const bool lhs_covers_rhs = lhs.is_superset_of(rhs);
    return lhs_covers_rhs;
}

/** Strict superset comparison, expressed as the strict subset test with swapped operands. */
template <typename K, typename C, typename A>
bool operator>(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    const bool rhs_strictly_within_lhs = (rhs < lhs);
    return rhs_strictly_within_lhs;
}

/**
 * Build the set of the elements found in @p lhs, in @p rhs, or in both.
 *
 * The inputs are merged with ``std::set_union`` using the comparator of @p lhs,
 * and the merged elements are appended straight to the underlying vector of the result.
 */
template <typename K, typename C, typename A>
auto flat_set<K, C, A>::union_(const Self& lhs, const Self& rhs) -> Self
{
    auto merged = Self();
    // The union holds at least as many elements as the larger input: this is
    // only a lower bound on the final size.
    merged.reserve(std::max(lhs.size(), rhs.size()));
    auto& storage = static_cast<Base&>(merged);
    std::set_union(
        lhs.cbegin(),
        lhs.cend(),
        rhs.cbegin(),
        rhs.cend(),
        std::back_inserter(storage),
        lhs.m_compare
    );
    return merged;
}

/** Set union operator: forwards to ``flat_set::union_``. */
template <typename K, typename C, typename A>
flat_set<K, C, A> operator|(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    using Set = flat_set<K, C, A>;
    return Set::union_(lhs, rhs);
}

/**
 * Build the set of the elements found in both @p lhs and @p rhs.
 *
 * Computed with ``std::set_intersection`` using the comparator of @p lhs; the
 * result elements are appended straight to the underlying vector of the result.
 */
template <typename K, typename C, typename A>
auto flat_set<K, C, A>::intersection(const Self& lhs, const Self& rhs) -> Self
{
    auto common = Self();
    auto& storage = static_cast<Base&>(common);
    std::set_intersection(
        lhs.cbegin(),
        lhs.cend(),
        rhs.cbegin(),
        rhs.cend(),
        std::back_inserter(storage),
        lhs.m_compare
    );
    return common;
}

/** Set intersection operator: forwards to ``flat_set::intersection``. */
template <typename K, typename C, typename A>
flat_set<K, C, A> operator&(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    using Set = flat_set<K, C, A>;
    return Set::intersection(lhs, rhs);
}

/**
 * Build the set of the elements of @p lhs that are not found in @p rhs.
 *
 * Computed with ``std::set_difference`` using the comparator of @p lhs; the
 * result elements are appended straight to the underlying vector of the result.
 */
template <typename K, typename C, typename A>
auto flat_set<K, C, A>::difference(const Self& lhs, const Self& rhs) -> Self
{
    auto remaining = Self();
    auto& storage = static_cast<Base&>(remaining);
    std::set_difference(
        lhs.cbegin(),
        lhs.cend(),
        rhs.cbegin(),
        rhs.cend(),
        std::back_inserter(storage),
        lhs.m_compare
    );
    return remaining;
}

/** Set difference operator: forwards to ``flat_set::difference``. */
template <typename K, typename C, typename A>
flat_set<K, C, A> operator-(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    using Set = flat_set<K, C, A>;
    return Set::difference(lhs, rhs);
}

/**
 * Build the set of the elements found in exactly one of @p lhs and @p rhs.
 *
 * Computed with ``std::set_symmetric_difference`` using the comparator of @p lhs;
 * the result elements are appended straight to the underlying vector of the result.
 */
template <typename K, typename C, typename A>
auto flat_set<K, C, A>::symetric_difference(const Self& lhs, const Self& rhs) -> Self
{
    auto exclusive = Self();
    auto& storage = static_cast<Base&>(exclusive);
    std::set_symmetric_difference(
        lhs.cbegin(),
        lhs.cend(),
        rhs.cbegin(),
        rhs.cend(),
        std::back_inserter(storage),
        lhs.m_compare
    );
    return exclusive;
}

/** Set symmetric difference operator: forwards to ``flat_set::symetric_difference``. */
template <typename K, typename C, typename A>
flat_set<K, C, A> operator^(const flat_set<K, C, A>& lhs, const flat_set<K, C, A>& rhs)
{
    using Set = flat_set<K, C, A>;
    return Set::symetric_difference(lhs, rhs);
}

}
#endif
2 changes: 1 addition & 1 deletion libmamba/tests/src/doctest-printer/flat_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace doctest
{
static auto convert(const mamba::util::flat_set<K, C, A>& value) -> String
{
return { fmt::format("std::flat_set{{{}}}", fmt::join(value, ", ")).c_str() };
return { fmt::format("mamba::util::flat_set{{{}}}", fmt::join(value, ", ")).c_str() };
}
};
}
102 changes: 102 additions & 0 deletions libmamba/tests/src/util/test_flat_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,106 @@ TEST_SUITE("util::flat_set")
s.insert(6);
CHECK_EQ(s.front(), 6);
}

// Exercises the set predicates (disjoint / subset / superset) and the set
// operators (|, &, -, ^) of flat_set on three small integer sets.
TEST_CASE("Set operations")
{
// Fixtures: s2 is a strict subset of s1; s3 shares only the element 4 with s1
// and has no element in common with s2.
const auto s1 = flat_set<int>({ 1, 3, 4, 5 });
const auto s2 = flat_set<int>({ 3, 5 });
const auto s3 = flat_set<int>({ 4, 6 });

SUBCASE("Disjoint")
{
// The empty set is expected to be disjoint of s1, while s1 is not disjoint of itself.
CHECK(s1.is_disjoint_of(flat_set<int>{}));
CHECK_FALSE(s1.is_disjoint_of(s1));
CHECK_FALSE(s1.is_disjoint_of(s2));
CHECK_FALSE(s1.is_disjoint_of(s3));
// Disjointness is checked in both directions.
CHECK(s2.is_disjoint_of(s3));
CHECK(s3.is_disjoint_of(s2));
}

SUBCASE("Subset")
{
// <= is the subset test, < the strict subset test.
CHECK_LE(s1, s1);
CHECK_FALSE(s1 < s1);
CHECK_LE(flat_set<int>{}, s1);
CHECK_LT(flat_set<int>{}, s1);
CHECK_FALSE(s1 <= s2);
CHECK_FALSE(s1 <= flat_set<int>{});
CHECK_LE(flat_set<int>{ 1, 4 }, s1);
CHECK_LT(flat_set<int>{ 1, 4 }, s1);
CHECK_LE(s2, s1);
CHECK_LT(s2, s1);
}

SUBCASE("Superset")
{
// Mirror of the "Subset" checks with the operands swapped (>= and >).
CHECK_GE(s1, s1);
CHECK_FALSE(s1 > s1);
CHECK_GE(s1, flat_set<int>{});
CHECK_GT(s1, flat_set<int>{});
CHECK_FALSE(s2 >= s1);
CHECK_FALSE(flat_set<int>{} >= s1);
CHECK_GE(s1, flat_set<int>{ 1, 4 });
CHECK_GT(s1, flat_set<int>{ 1, 4 });
CHECK_GE(s1, s2);
CHECK_GT(s1, s2);
}

SUBCASE("Union")
{
// Each pair is also checked with swapped operands: the union must be commutative.
CHECK_EQ(s1 | s1, s1);
CHECK_EQ(s1 | s2, s1);
CHECK_EQ(s2 | s1, s1 | s2);
CHECK_EQ(s1 | s3, flat_set<int>{ 1, 3, 4, 5, 6 });
CHECK_EQ(s3 | s1, s1 | s3);
CHECK_EQ(s2 | s3, flat_set<int>{ 3, 4, 5, 6 });
CHECK_EQ(s3 | s2, s2 | s3);
}

SUBCASE("Intersection")
{
// Each pair is also checked with swapped operands: the intersection must be commutative.
CHECK_EQ(s1 & s1, s1);
CHECK_EQ(s1 & s2, s2);
CHECK_EQ(s2 & s1, s1 & s2);
CHECK_EQ(s1 & s3, flat_set<int>{ 4 });
CHECK_EQ(s3 & s1, s1 & s3);
CHECK_EQ(s2 & s3, flat_set<int>{});
CHECK_EQ(s3 & s2, s2 & s3);
}

SUBCASE("Difference")
{
// The difference is not commutative, so both operand orders have their own expected value.
CHECK_EQ(s1 - s1, flat_set<int>{});
CHECK_EQ(s1 - s2, flat_set<int>{ 1, 4 });
CHECK_EQ(s2 - s1, flat_set<int>{});
CHECK_EQ(s1 - s3, flat_set<int>{ 1, 3, 5 });
CHECK_EQ(s3 - s1, flat_set<int>{ 6 });
CHECK_EQ(s2 - s3, s2);
CHECK_EQ(s3 - s2, s3);
}

SUBCASE("Symetric difference")
{
// Each pair is also checked with swapped operands: the symmetric difference must be commutative.
CHECK_EQ(s1 ^ s1, flat_set<int>{});
CHECK_EQ(s1 ^ s2, flat_set<int>{ 1, 4 });
CHECK_EQ(s2 ^ s1, s1 ^ s2);
CHECK_EQ(s1 ^ s3, flat_set<int>{ 1, 3, 5, 6 });
CHECK_EQ(s3 ^ s1, s1 ^ s3);
CHECK_EQ(s2 ^ s3, flat_set<int>{ 3, 4, 5, 6 });
CHECK_EQ(s3 ^ s2, s2 ^ s3);
}

SUBCASE("Algebra")
{
// Identities relating the four operators, checked for every ordered pair of
// fixtures (including each set paired with itself).
for (const auto& u : { s1, s2, s3 })
{
for (const auto& v : { s1, s2, s3 })
{
CHECK_EQ((u - v) | (v - u) | (u & v), u | v);
CHECK_EQ((u ^ v) | (u & v), u | v);
CHECK_EQ((u | v) - (u & v), u ^ v);
}
}
}
}
}