From 9fa255d8c3ddd87d72b502d0c69d77329cb44884 Mon Sep 17 00:00:00 2001
From: thinkgos
Date: Sat, 27 Apr 2024 11:06:01 +0800
Subject: [PATCH] feat: add sets

---
 sets/README.md   |   3 +
 sets/set.go      | 216 ++++++++++++++++++++++++++++
 sets/set_test.go | 358 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 577 insertions(+)
 create mode 100644 sets/README.md
 create mode 100644 sets/set.go
 create mode 100644 sets/set_test.go

diff --git a/sets/README.md b/sets/README.md
new file mode 100644
index 0000000..d7c1ca4
--- /dev/null
+++ b/sets/README.md
@@ -0,0 +1,3 @@
+# sets
+
+A generic `Set[T]` for any `comparable` type, implemented as `map[T]struct{}`.
diff --git a/sets/set.go b/sets/set.go
new file mode 100644
index 0000000..5a07f6d
--- /dev/null
+++ b/sets/set.go
@@ -0,0 +1,216 @@
+// Package sets provides a generic set type backed by a map.
+package sets
+
+// Set is a generic set of comparable values,
+// implemented via map[T]struct{} for minimal memory consumption.
+//
+// A nil Set can be read but not written: Insert and Merge panic when they
+// add an item to it, so create sets with New, NewFrom or a Set[T]{} literal.
+type Set[T comparable] map[T]struct{}
+
+// New creates a Set[T] containing the given items.
+func New[T comparable](items ...T) Set[T] {
+	ss := make(Set[T], len(items))
+	return ss.Insert(items...)
+}
+
+// NewFrom creates a Set[T] from the keys of the map m.
+// The values of m are ignored.
+func NewFrom[T comparable, V any, M ~map[T]V](m M) Set[T] {
+	ret := make(Set[T], len(m))
+	for k := range m {
+		ret[k] = struct{}{}
+	}
+	return ret
+}
+
+// Insert adds items to the set and returns the set itself.
+func (s Set[T]) Insert(items ...T) Set[T] {
+	for _, item := range items {
+		s[item] = struct{}{}
+	}
+	return s
+}
+
+// Delete removes the given items from the set and returns the set itself.
+func (s Set[T]) Delete(items ...T) Set[T] {
+	for _, item := range items {
+		delete(s, item)
+	}
+	return s
+}
+
+// Contains returns true if and only if item is contained in the set.
+func (s Set[T]) Contains(item T) bool {
+	_, contained := s[item]
+	return contained
+}
+
+// ContainsAll returns true if and only if all items are contained in the set.
+func (s Set[T]) ContainsAll(items ...T) bool {
+	for _, item := range items {
+		if !s.Contains(item) {
+			return false
+		}
+	}
+	return true
+}
+
+// ContainsAny returns true if any items are contained in the set.
+func (s Set[T]) ContainsAny(items ...T) bool {
+	for _, item := range items {
+		if s.Contains(item) {
+			return true
+		}
+	}
+	return false
+}
+
+// Difference returns a new set of the items in s that are not in s2.
+// For example:
+// s1 = {a1, a2, a3}
+// s2 = {a1, a2, a4, a5}
+// s1.Difference(s2) = {a3}
+// s2.Difference(s1) = {a4, a5}.
+func (s Set[T]) Difference(s2 Set[T]) Set[T] {
+	result := New[T]()
+	for key := range s {
+		if !s2.Contains(key) {
+			result[key] = struct{}{}
+		}
+	}
+	return result
+}
+
+// Union returns a new set which includes items in either s1 or s2.
+// For example:
+// s1 = {a1, a2}
+// s2 = {a3, a4}
+// s1.Union(s2) = {a1, a2, a3, a4}
+// s2.Union(s1) = {a1, a2, a3, a4}.
+func (s Set[T]) Union(s2 Set[T]) Set[T] {
+	result := New[T]()
+	for key := range s {
+		result[key] = struct{}{}
+	}
+	for key := range s2 {
+		result[key] = struct{}{}
+	}
+	return result
+}
+
+// Intersection returns a new set which includes the items in BOTH s1 and s2.
+// For example:
+// s1 = {a1, a2}
+// s2 = {a2, a3}
+// s1.Intersection(s2) = {a2}.
+func (s Set[T]) Intersection(s2 Set[T]) Set[T] {
+	var walk, other Set[T]
+	result := New[T]()
+	if s.Len() < s2.Len() {
+		walk = s
+		other = s2
+	} else {
+		walk = s2
+		other = s
+	}
+	for key := range walk {
+		if other.Contains(key) {
+			result[key] = struct{}{}
+		}
+	}
+	return result
+}
+
+// Merge is like Union, however it modifies the current set it's applied on
+// with the given s2 set.
+// For example:
+// s1 = {a1, a2}
+// s2 = {a3, a4}
+// s1.Merge(s2), s1 = {a1, a2, a3, a4}
+// s2.Merge(s1), s2 = {a1, a2, a3, a4}.
+func (s Set[T]) Merge(s2 Set[T]) Set[T] {
+	for item := range s2 {
+		s[item] = struct{}{}
+	}
+	return s
+}
+
+// IsSuperset returns true if and only if s is a superset of s2.
+func (s Set[T]) IsSuperset(s2 Set[T]) bool {
+	// A set with fewer items than s2 cannot contain all of s2.
+	if len(s) < len(s2) {
+		return false
+	}
+	for item := range s2 {
+		if !s.Contains(item) {
+			return false
+		}
+	}
+	return true
+}
+
+// IsSubset returns true if and only if s is a subset of s2.
+func (s Set[T]) IsSubset(s2 Set[T]) bool {
+	// A set with more items than s2 cannot fit inside s2.
+	if len(s) > len(s2) {
+		return false
+	}
+	for item := range s {
+		if !s2.Contains(item) {
+			return false
+		}
+	}
+	return true
+}
+
+// List returns the contents as a slice, in no particular order.
+func (s Set[T]) List() []T {
+	res := make([]T, 0, len(s))
+	for key := range s {
+		res = append(res, key)
+	}
+	return res
+}
+
+// Equal returns true if and only if s1 is equal (as a set) to s2.
+// Two sets are equal if their membership is identical.
+// (In practice, this means same elements, order doesn't matter).
+func (s Set[T]) Equal(s2 Set[T]) bool {
+	return len(s) == len(s2) && s.IsSuperset(s2)
+}
+
+// Pop removes and returns an arbitrary element from the set.
+// ok is false, and v is the zero value of T, if the set is empty.
+func (s Set[T]) Pop() (v T, ok bool) {
+	for key := range s {
+		delete(s, key)
+		return key, true
+	}
+	return v, false
+}
+
+// Len returns the size of the set.
+func (s Set[T]) Len() int {
+	return len(s)
+}
+
+// Each traverses the items in the Set, calling the provided function for each
+// set member. Traversal will continue until all items in the Set have been
+// visited, or if the closure returns false.
+func (s Set[T]) Each(f func(item T) bool) {
+	for item := range s {
+		if !f(item) {
+			break
+		}
+	}
+}
+
+// Clone returns a new Set with a copy of s.
+func (s Set[T]) Clone() Set[T] {
+	ns := make(Set[T], len(s))
+	for item := range s {
+		ns[item] = struct{}{}
+	}
+	return ns
+}
diff --git a/sets/set_test.go b/sets/set_test.go
new file mode 100644
index 0000000..e65e203
--- /dev/null
+++ b/sets/set_test.go
@@ -0,0 +1,358 @@
+/*
+Copyright 2014 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package sets
+
+import (
+	"testing"
+)
+
+func TestSet(t *testing.T) {
+	s := Set[int]{}
+	s2 := Set[int]{}
+	if len(s) != 0 {
+		t.Errorf("Expected len=0: %d", len(s))
+	}
+	s.Insert(1, 2)
+	if len(s) != 2 {
+		t.Errorf("Expected len=2: %d", len(s))
+	}
+	s.Insert(3)
+	if s.Contains(4) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+	if !s.Contains(1) {
+		t.Errorf("Missing contents: %#v", s)
+	}
+	s.Delete(1)
+	if s.Contains(1) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+	s.Insert(1)
+	if s.ContainsAll(1, 2, 4) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+	if !s.ContainsAll(1, 2) {
+		t.Errorf("Missing contents: %#v", s)
+	}
+	s2.Insert(1, 2, 4)
+	if s.IsSuperset(s2) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+	if s2.IsSubset(s) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+	s2.Delete(4)
+	if !s.IsSuperset(s2) {
+		t.Errorf("Missing contents: %#v", s)
+	}
+	if !s2.IsSubset(s) {
+		t.Errorf("Missing contents: %#v", s)
+	}
+	_, ok := s2.Pop()
+	if !ok {
+		t.Errorf("Unexpected status: %#v", ok)
+	}
+	s2 = New[int]()
+	if s2.Len() != 0 {
+		t.Errorf("Expected len=0: %d", len(s2))
+	}
+	v, ok := s2.Pop()
+	if ok {
+		t.Errorf("Unexpected status: %#v", ok)
+	}
+	if v != 0 {
+		t.Errorf("Unexpected value: %#v", v)
+	}
+}
+
+func TestSetDeleteMultiples(t *testing.T) {
+	s := NewFrom(map[int]any{1: "1", 2: "2", 3: "3"})
+	if len(s) != 3 {
+		t.Errorf("Expected len=3: %d", len(s))
+	}
+
+	s.Delete(1, 3)
+	if len(s) != 1 {
+		t.Errorf("Expected len=1: %d", len(s))
+	}
+	if s.Contains(1) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+	if s.Contains(3) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+	if !s.Contains(2) {
+		t.Errorf("Missing contents: %#v", s)
+	}
+}
+
+func TestNewSet(t *testing.T) {
+	s := New(1, 2, 3)
+	if len(s) != 3 {
+		t.Errorf("Expected len=3: %d", len(s))
+	}
+	if !s.Contains(1) || !s.Contains(2) || !s.Contains(3) {
+		t.Errorf("Unexpected contents: %#v", s)
+	}
+}
+
+func TestSetList(t *testing.T) {
+	s := New(13, 12, 1, 11)
+	v := s.List()
+
+	for _, vv := range v {
+		if !s.Contains(vv) {
+			t.Errorf("List gave unexpected result: %#v", v)
+		}
+	}
+}
+
+func TestSetDifference(t *testing.T) {
+	a := New(1, 2, 3)
+	b := New(1, 2, 4, 5)
+	c := a.Difference(b)
+	d := b.Difference(a)
+	if len(c) != 1 {
+		t.Errorf("Expected len=1: %d", len(c))
+	}
+	if !c.Contains(3) {
+		t.Errorf("Unexpected contents: %#v", c.List())
+	}
+	if len(d) != 2 {
+		t.Errorf("Expected len=2: %d", len(d))
+	}
+	if !d.Contains(4) || !d.Contains(5) {
+		t.Errorf("Unexpected contents: %#v", d.List())
+	}
+}
+
+func TestSetHasAny(t *testing.T) {
+	a := New(1, 2, 3)
+
+	if !a.ContainsAny(1, 4) {
+		t.Errorf("expected true, got false")
+	}
+
+	if a.ContainsAny(10, 4) {
+		t.Errorf("expected false, got true")
+	}
+}
+
+func TestSetEquals(t *testing.T) {
+	// Simple case (order doesn't matter)
+	a := New[int](1, 2)
+	b := New[int](2, 1)
+	if !a.Equal(b) {
+		t.Errorf("Expected to be equal: %v vs %v", a, b)
+	}
+
+	// It is a set; duplicates are ignored
+	b = New[int](2, 2, 1)
+	if !a.Equal(b) {
+		t.Errorf("Expected to be equal: %v vs %v", a, b)
+	}
+
+	// Edge cases around empty sets
+	a = New[int]()
+	b = New[int]()
+	if !a.Equal(b) {
+		t.Errorf("Expected to be equal: %v vs %v", a, b)
+	}
+
+	b = New(1, 2, 3)
+	if a.Equal(b) {
+		t.Errorf("Expected to be not-equal: %v vs %v", a, b)
+	}
+
+	b = New(1, 2, 0)
+	if a.Equal(b) {
+		t.Errorf("Expected to be not-equal: %v vs %v", a, b)
+	}
+
+	// Check for equality after mutation
+	a = New[int]()
+	a.Insert(1)
+	if a.Equal(b) {
+		t.Errorf("Expected to be not-equal: %v vs %v", a, b)
+	}
+
+	a.Insert(2)
+	if a.Equal(b) {
+		t.Errorf("Expected to be not-equal: %v vs %v", a, b)
+	}
+
+	a.Insert(0)
+	if !a.Equal(b) {
+		t.Errorf("Expected to be equal: %v vs %v", a, b)
+	}
+
+	a.Delete(0)
+	if a.Equal(b) {
+		t.Errorf("Expected to be not-equal: %v vs %v", a, b)
+	}
+}
+
+func TestUnion(t *testing.T) {
+	tests := []struct {
+		s1       Set[int]
+		s2       Set[int]
+		expected Set[int]
+	}{
+		{
+			New[int](1, 2, 3, 4),
+			New[int](3, 4, 5, 6),
+			New[int](1, 2, 3, 4, 5, 6),
+		},
+		{
+			New[int](1, 2, 3, 4),
+			New[int](),
+			New[int](1, 2, 3, 4),
+		},
+		{
+			New[int](),
+			New[int](1, 2, 3, 4),
+			New[int](1, 2, 3, 4),
+		},
+		{
+			New[int](),
+			New[int](),
+			New[int](),
+		},
+	}
+
+	for _, test := range tests {
+		union := test.s1.Union(test.s2)
+		if union.Len() != test.expected.Len() {
+			t.Errorf("Expected union.Len()=%d but got %d", test.expected.Len(), union.Len())
+		}
+
+		if !union.Equal(test.expected) {
+			t.Errorf("Expected union.Equal(expected) but not true. union:%v expected:%v", union.List(), test.expected.List())
+		}
+	}
+}
+
+func TestIntersection(t *testing.T) {
+	tests := []struct {
+		s1       Set[int]
+		s2       Set[int]
+		expected Set[int]
+	}{
+		{
+			New[int](1, 2, 3, 4),
+			New[int](3, 4, 5, 6),
+			New[int](3, 4),
+		},
+		{
+			New[int](1, 2, 3, 4),
+			New[int](1, 2, 3, 4),
+			New[int](1, 2, 3, 4),
+		},
+		{
+			New[int](1, 2, 3, 4),
+			New[int](),
+			New[int](),
+		},
+		{
+			New[int](),
+			New[int](1, 2, 3, 4),
+			New[int](),
+		},
+		{
+			New[int](),
+			New[int](),
+			New[int](),
+		},
+	}
+
+	for _, test := range tests {
+		intersection := test.s1.Intersection(test.s2)
+		if intersection.Len() != test.expected.Len() {
+			t.Errorf("Expected intersection.Len()=%d but got %d", test.expected.Len(), intersection.Len())
+		}
+
+		if !intersection.Equal(test.expected) {
+			t.Errorf("Expected intersection.Equal(expected) but not true. intersection:%v expected:%v",
+				intersection.List(), test.expected.List())
+		}
+	}
+}
+
+func TestMerge(t *testing.T) {
+	tests := []struct {
+		s1       Set[int]
+		s2       Set[int]
+		expected Set[int]
+	}{
+		{
+			New[int](1, 2, 3, 4),
+			New[int](3, 4, 5, 6),
+			New[int](1, 2, 3, 4, 5, 6),
+		},
+		{
+			New[int](1, 2, 3, 4),
+			New[int](1, 2, 3, 4),
+			New[int](1, 2, 3, 4),
+		},
+		{
+			New[int](1, 2, 3, 4),
+			New[int](),
+			New[int](1, 2, 3, 4),
+		},
+		{
+			New[int](),
+			New[int](1, 2, 3, 4),
+			New[int](1, 2, 3, 4),
+		},
+		{
+			New[int](),
+			New[int](),
+			New[int](),
+		},
+	}
+
+	for _, test := range tests {
+		merged := test.s1.Merge(test.s2)
+		if merged.Len() != test.expected.Len() {
+			t.Errorf("Expected merge.Len()=%d but got %d", test.expected.Len(), merged.Len())
+		}
+
+		if !merged.Equal(test.expected) {
+			t.Errorf("Expected merge.Equal(expected) but not true. merge:%v expected:%v",
+				merged.List(), test.expected.List())
+		}
+	}
+}
+
+func Test_Each(t *testing.T) {
+	expect := New(1, 2, 3, 4)
+	s1 := New(1, 2, 3, 4)
+	s1.Each(func(item int) bool {
+		if got := expect.Contains(item); !got {
+			t.Errorf("Expected Contains(%d)=%v but got %v", item, true, got)
+		}
+		return item != 3
+	})
+}
+
+func Test_Clone(t *testing.T) {
+	s1 := New(1, 2, 3, 4)
+	s2 := s1.Clone()
+	if got := s1.Equal(s2); !got {
+		t.Errorf("Expected Equal()=%v but got %v", true, got)
+	}
+}