diff --git a/libmamba/include/mamba/util/flat_set.hpp b/libmamba/include/mamba/util/flat_set.hpp index c899e82548..06dfd48c08 100644 --- a/libmamba/include/mamba/util/flat_set.hpp +++ b/libmamba/include/mamba/util/flat_set.hpp @@ -15,6 +15,14 @@ namespace mamba::util { + + struct sorted_unique_t + { + explicit sorted_unique_t() = default; + }; + + inline constexpr sorted_unique_t sorted_unique{}; + /** * A sorted vector behaving like a set. * @@ -61,49 +69,73 @@ namespace mamba::util key_compare compare = key_compare(), const allocator_type& alloc = Allocator() ); + template + flat_set( + sorted_unique_t, + InputIterator first, + InputIterator last, + key_compare compare = key_compare(), + const allocator_type& alloc = Allocator() + ); flat_set(const flat_set&) = default; flat_set(flat_set&&) = default; explicit flat_set(std::vector&& other, key_compare compare = key_compare()); explicit flat_set(const std::vector& other, key_compare compare = key_compare()); - flat_set& operator=(const flat_set&) = default; - flat_set& operator=(flat_set&&) = default; + auto operator=(const flat_set&) -> flat_set& = default; + auto operator=(flat_set&&) -> flat_set& = default; + + auto key_comp() const -> const key_compare&; - bool contains(const value_type&) const; - const value_type& front() const noexcept; - const value_type& back() const noexcept; - const value_type& operator[](size_type pos) const; - const value_type& at(size_type pos) const; + auto front() const noexcept -> const value_type&; + auto back() const noexcept -> const value_type&; + auto operator[](size_type pos) const -> const value_type&; + auto at(size_type pos) const -> const value_type&; - const_iterator begin() const noexcept; - const_iterator end() const noexcept; - const_reverse_iterator rbegin() const noexcept; - const_reverse_iterator rend() const noexcept; + auto begin() const noexcept -> const_iterator; + auto end() const noexcept -> const_iterator; + auto rbegin() const noexcept -> const_reverse_iterator; + auto rend() const noexcept -> const_reverse_iterator; /** Insert an element in the set. * * Like std::vector and unlike std::set, inserting an element invalidates iterators. */ - std::pair insert(value_type&& value); - std::pair insert(const value_type& value); + auto insert(value_type&& value) -> std::pair; + auto insert(const value_type& value) -> std::pair; template void insert(InputIterator first, InputIterator last); - const_iterator erase(const_iterator pos); - const_iterator erase(const_iterator first, const_iterator last); - size_type erase(const value_type& value); + auto erase(const_iterator pos) -> const_iterator; + auto erase(const_iterator first, const_iterator last) -> const_iterator; + auto erase(const value_type& value) -> size_type; + + auto contains(const value_type&) const -> bool; private: key_compare m_compare; - bool key_eq(const value_type& a, const value_type& b) const; + auto key_eq(const value_type& a, const value_type& b) const -> bool; template - std::pair insert_impl(U&& value); + auto insert_impl(U&& value) -> std::pair; void sort_and_remove_duplicates(); template - friend bool operator==(const flat_set& lhs, const flat_set& rhs); + friend auto operator==(const flat_set& lhs, const flat_set& rhs) -> bool; + + template + friend auto set_union(const flat_set&, const flat_set&) + -> flat_set; + template + friend auto set_intersection(const flat_set&, const flat_set&) + -> flat_set; + template + friend auto set_difference(const flat_set&, const flat_set&) + -> flat_set; + template + friend auto set_symetric_difference(const flat_set&, const flat_set&) + -> flat_set; }; template , class Allocator = std::allocator> @@ -126,12 +158,68 @@ namespace mamba::util -> flat_set; template - bool - operator==(const flat_set& lhs, const flat_set& rhs); + auto + operator==(const flat_set& lhs, const flat_set& rhs) + -> bool; + + template + auto + operator!=(const flat_set& lhs, const flat_set& rhs) + -> bool; + + template + auto set_is_disjoint_of( + const flat_set& lhs, + const flat_set& rhs + ) -> bool; + + template + auto is_subset_of( + const flat_set& lhs, + const flat_set& rhs + ) -> bool; + + template + auto is_strict_subset_of( + const flat_set& lhs, + const flat_set& rhs + ) -> bool; + + template + auto is_superset_of( + const flat_set& lhs, + const flat_set& rhs + ) -> bool; + + template + auto is_strict_superset_of( + const flat_set& lhs, + const flat_set& rhs + ) -> bool; + + template + auto set_union( // + const flat_set& lhs, + const flat_set& rhs + ) -> flat_set; template - bool - operator!=(const flat_set& lhs, const flat_set& rhs); + auto set_intersection( + const flat_set& lhs, + const flat_set& rhs + ) -> flat_set; + + template + auto set_difference( + const flat_set& lhs, + const flat_set& rhs + ) -> flat_set; + + template + auto set_symetric_difference( + const flat_set& lhs, + const flat_set& rhs + ) -> flat_set; /******************************* * vector_set Implementation * @@ -163,6 +251,20 @@ namespace mamba::util sort_and_remove_duplicates(); } + template + template + flat_set::flat_set( + sorted_unique_t, + InputIterator first, + InputIterator last, + key_compare compare, + const allocator_type& alloc + ) + : Base(first, last, alloc) + , m_compare(std::move(compare)) + { + } + template flat_set::flat_set(std::vector&& other, C compare) : Base(std::move(other)) @@ -180,9 +282,9 @@ namespace mamba::util } template - auto flat_set::contains(const value_type& value) const -> bool + auto flat_set::key_comp() const -> const key_compare& { - return std::binary_search(begin(), end(), value); + return m_compare; } template @@ -246,7 +348,7 @@ namespace mamba::util } template - bool flat_set::key_eq(const value_type& a, const value_type& b) const + auto flat_set::key_eq(const value_type& a, const value_type& b) const -> bool { return !m_compare(a, b) && !m_compare(b, a); } @@ -307,17 +409,151 @@ namespace mamba::util } template - bool operator==(const flat_set& lhs, const flat_set& rhs) + auto flat_set::contains(const value_type& value) const -> bool + { + return std::binary_search(begin(), end(), value); + } + + namespace detail + { + /** + * Check if two sorted range have an empty intersection. + * + * Edited from https://en.cppreference.com/w/cpp/algorithm/set_intersection + * Distributed under the terms of the Copyright/CC-BY-SA License. + * The full license can be found at the address + * https://en.cppreference.com/w/Cppreference:Copyright/CC-BY-SA + */ + template + auto + set_disjoint(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2, Compare comp) + -> bool + { + while (first1 != last1 && first2 != last2) + { + if (comp(*first1, *first2)) + { + ++first1; + } + else + { + if (!comp(*first2, *first1)) + { + return false; // *first1 and *first2 are equivalent. + } + ++first2; + } + } + return true; + } + } + + template + auto operator==(const flat_set& lhs, const flat_set& rhs) -> bool { auto is_eq = [&lhs](const auto& a, const auto& b) { return lhs.key_eq(a, b); }; return std::equal(lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend(), is_eq); } template - bool operator!=(const flat_set& lhs, const flat_set& rhs) + auto operator!=(const flat_set& lhs, const flat_set& rhs) -> bool { return !(lhs == rhs); } + template + auto set_is_disjoint_of(const flat_set& lhs, const flat_set& rhs) -> bool + { + return detail::set_disjoint(lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend(), lhs.key_comp()); + } + + template + auto set_is_subset_of(const flat_set& lhs, const flat_set& rhs) -> bool + { + return (lhs.size() <= rhs.size()) // For perf + && std::includes(rhs.cbegin(), rhs.cend(), lhs.cbegin(), lhs.cend(), lhs.key_comp()); + } + + template + auto set_is_strict_subset_of(const flat_set& lhs, const flat_set& rhs) -> bool + { + return (lhs.size() < rhs.size()) && set_is_subset_of(lhs, rhs); + } + + template + auto set_is_superset_of(const flat_set& lhs, const flat_set& rhs) -> bool + { + return set_is_subset_of(rhs, lhs); + } + + template + auto set_is_strict_superset_of(const flat_set& lhs, const flat_set& rhs) -> bool + { + return set_is_strict_subset_of(rhs, lhs); + } + + template + auto set_union(const flat_set& lhs, const flat_set& rhs) -> flat_set + { + auto out = flat_set(); + out.reserve(std::max(lhs.size(), rhs.size())); // lower bound + std::set_union( + lhs.cbegin(), + lhs.cend(), + rhs.cbegin(), + rhs.cend(), + std::back_inserter(static_cast::Base&>(out)), + lhs.m_compare + ); + return out; + } + + template + auto set_intersection(const flat_set& lhs, const flat_set& rhs) + -> flat_set + { + auto out = flat_set(); + std::set_intersection( + lhs.cbegin(), + lhs.cend(), + rhs.cbegin(), + rhs.cend(), + std::back_inserter(static_cast::Base&>(out)), + lhs.m_compare + ); + return out; + } + + template + auto set_difference(const flat_set& lhs, const flat_set& rhs) + -> flat_set + { + auto out = flat_set(); + std::set_difference( + lhs.cbegin(), + lhs.cend(), + rhs.cbegin(), + rhs.cend(), + std::back_inserter(static_cast::Base&>(out)), + lhs.m_compare + ); + return out; + } + + template + auto set_symetric_difference(const flat_set& lhs, const flat_set& rhs) + -> flat_set + { + auto out = flat_set(); + std::set_symmetric_difference( + lhs.cbegin(), + lhs.cend(), + rhs.cbegin(), + rhs.cend(), + std::back_inserter(static_cast::Base&>(out)), + lhs.m_compare + ); + return out; + } } #endif diff --git a/libmamba/tests/src/doctest-printer/flat_set.hpp b/libmamba/tests/src/doctest-printer/flat_set.hpp index 696c37f677..4d028df987 100644 --- a/libmamba/tests/src/doctest-printer/flat_set.hpp +++ b/libmamba/tests/src/doctest-printer/flat_set.hpp @@ -16,7 +16,7 @@ namespace doctest { static auto convert(const mamba::util::flat_set& value) -> String { - return { fmt::format("std::flat_set{{{}}}", fmt::join(value, ", ")).c_str() }; + return { fmt::format("mamba::util::flat_set{{{}}}", fmt::join(value, ", ")).c_str() }; } }; } diff --git a/libmamba/tests/src/util/test_flat_set.cpp b/libmamba/tests/src/util/test_flat_set.cpp index 3033010234..32725c26de 100644 --- a/libmamba/tests/src/util/test_flat_set.cpp +++ b/libmamba/tests/src/util/test_flat_set.cpp @@ -107,4 +107,118 @@ TEST_SUITE("util::flat_set") s.insert(6); CHECK_EQ(s.front(), 6); } + + TEST_CASE("Set operations") + { + const auto s1 = flat_set({ 1, 3, 4, 5 }); + const auto s2 = flat_set({ 3, 5 }); + const auto s3 = flat_set({ 4, 6 }); + + SUBCASE("Disjoint") + { + CHECK(set_is_disjoint_of(s1, flat_set{})); + CHECK_FALSE(set_is_disjoint_of(s1, s1)); + CHECK_FALSE(set_is_disjoint_of(s1, s2)); + CHECK_FALSE(set_is_disjoint_of(s1, s3)); + CHECK(set_is_disjoint_of(s2, s3)); + CHECK(set_is_disjoint_of(s3, s2)); + } + + SUBCASE("Subset") + { + CHECK(set_is_subset_of(s1, s1)); + CHECK_FALSE(set_is_strict_subset_of(s1, s1)); + CHECK(set_is_subset_of(flat_set{}, s1)); + CHECK(set_is_strict_subset_of(flat_set{}, s1)); + CHECK_FALSE(set_is_subset_of(s1, s2)); + CHECK_FALSE(set_is_subset_of(s1, flat_set{})); + CHECK(set_is_subset_of(flat_set{ 1, 4 }, s1)); + CHECK(set_is_strict_subset_of(flat_set{ 1, 4 }, s1)); + CHECK(set_is_subset_of(s2, s1)); + CHECK(set_is_strict_subset_of(s2, s1)); + } + + SUBCASE("Superset") + { + CHECK(set_is_superset_of(s1, s1)); + CHECK_FALSE(set_is_strict_superset_of(s1, s1)); + CHECK(set_is_superset_of(s1, flat_set{})); + CHECK(set_is_strict_superset_of(s1, flat_set{})); + CHECK_FALSE(set_is_superset_of(s2, s1)); + CHECK_FALSE(set_is_superset_of(flat_set{}, s1)); + CHECK(set_is_superset_of(s1, flat_set{ 1, 4 })); + CHECK(set_is_strict_superset_of(s1, flat_set{ 1, 4 })); + CHECK(set_is_superset_of(s1, s2)); + CHECK(set_is_strict_superset_of(s1, s2)); + } + + SUBCASE("Union") + { + CHECK_EQ(set_union(s1, s1), s1); + CHECK_EQ(set_union(s1, s2), s1); + CHECK_EQ(set_union(s2, s1), set_union(s1, s2)); + CHECK_EQ(set_union(s1, s3), flat_set{ 1, 3, 4, 5, 6 }); + CHECK_EQ(set_union(s3, s1), set_union(s1, s3)); + CHECK_EQ(set_union(s2, s3), flat_set{ 3, 4, 5, 6 }); + CHECK_EQ(set_union(s3, s2), set_union(s2, s3)); + } + + SUBCASE("Intersection") + { + CHECK_EQ(set_intersection(s1, s1), s1); + CHECK_EQ(set_intersection(s1, s2), s2); + CHECK_EQ(set_intersection(s2, s1), set_intersection(s1, s2)); + CHECK_EQ(set_intersection(s1, s3), flat_set{ 4 }); + CHECK_EQ(set_intersection(s3, s1), set_intersection(s1, s3)); + CHECK_EQ(set_intersection(s2, s3), flat_set{}); + CHECK_EQ(set_intersection(s3, s2), set_intersection(s2, s3)); + } + + SUBCASE("Difference") + { + CHECK_EQ(set_difference(s1, s1), flat_set{}); + CHECK_EQ(set_difference(s1, s2), flat_set{ 1, 4 }); + CHECK_EQ(set_difference(s2, s1), flat_set{}); + CHECK_EQ(set_difference(s1, s3), flat_set{ 1, 3, 5 }); + CHECK_EQ(set_difference(s3, s1), flat_set{ 6 }); + CHECK_EQ(set_difference(s2, s3), s2); + CHECK_EQ(set_difference(s3, s2), s3); + } + + SUBCASE("Symetric difference") + { + CHECK_EQ(set_symetric_difference(s1, s1), flat_set{}); + CHECK_EQ(set_symetric_difference(s1, s2), flat_set{ 1, 4 }); + CHECK_EQ(set_symetric_difference(s2, s1), set_symetric_difference(s1, s2)); + CHECK_EQ(set_symetric_difference(s1, s3), flat_set{ 1, 3, 5, 6 }); + CHECK_EQ(set_symetric_difference(s3, s1), set_symetric_difference(s1, s3)); + CHECK_EQ(set_symetric_difference(s2, s3), flat_set{ 3, 4, 5, 6 }); + CHECK_EQ(set_symetric_difference(s3, s2), set_symetric_difference(s2, s3)); + } + + SUBCASE("Algebra") + { + for (const auto& u : { s1, s2, s3 }) + { + for (const auto& v : { s1, s2, s3 }) + { + CHECK_EQ( + set_union( + set_difference(u, v), + set_union(set_difference(v, u), set_intersection(u, v)) + ), + set_union(u, v) + ); + CHECK_EQ( + set_union(set_symetric_difference(u, v), set_intersection(u, v)), + set_union(u, v) + ); + CHECK_EQ( + set_difference(set_union(u, v), set_intersection(u, v)), + set_symetric_difference(u, v) + ); + } + } + } + } }