Add RemoveAll method to the Set interface and both implementations
diff --git a/set.go b/set.go
index 7ca6ca3..0bb8d7c 100644
--- a/set.go
+++ b/set.go
@@ -141,6 +141,9 @@
// Remove a single element from the set.
Remove(i interface{})
+ // RemoveAll removes each of the given elements from the set; elements not in the set are ignored.
+ RemoveAll(i ...interface{})
+
// Provides a convenient string representation
// of the current state of the set.
String() string
diff --git a/set_test.go b/set_test.go
index 6cdf583..ae242a6 100644
--- a/set_test.go
+++ b/set_test.go
@@ -131,6 +131,26 @@
}
}
+// Test_RemoveAllSet verifies that RemoveAll deletes exactly the requested
+// elements from a set built by makeSet, and that it can empty the set.
+func Test_RemoveAllSet(t *testing.T) {
+	a := makeSet([]int{6, 3, 1, 8, 9})
+	a.RemoveAll(3, 1)
+
+	if a.Cardinality() != 3 {
+		t.Errorf("RemoveAll should leave 3 items in the set, got %d", a.Cardinality())
+	}
+
+	if !a.Contains(6, 8, 9) || a.Contains(3) || a.Contains(1) {
+		t.Error("RemoveAll should have only items (6,8,9) in the set")
+	}
+
+	a.RemoveAll(6, 8, 9)
+	if a.Cardinality() != 0 {
+		t.Errorf("RemoveAll should leave an empty set after removing 6, 8 and 9, got %d items", a.Cardinality())
+	}
+}
+
func Test_RemoveUnsafeSet(t *testing.T) {
a := makeUnsafeSet([]int{6, 3, 1})
@@ -152,6 +172,26 @@
}
}
+// Test_RemoveAllUnsafeSet verifies that RemoveAll deletes exactly the requested
+// elements from a set built by makeUnsafeSet, and that it can empty the set.
+func Test_RemoveAllUnsafeSet(t *testing.T) {
+	a := makeUnsafeSet([]int{6, 3, 1, 8, 9})
+	a.RemoveAll(3, 1)
+
+	if a.Cardinality() != 3 {
+		t.Errorf("RemoveAll should leave 3 items in the set, got %d", a.Cardinality())
+	}
+
+	if !a.Contains(6, 8, 9) || a.Contains(3) || a.Contains(1) {
+		t.Error("RemoveAll should have only items (6,8,9) in the set")
+	}
+
+	a.RemoveAll(6, 8, 9)
+	if a.Cardinality() != 0 {
+		t.Errorf("RemoveAll should leave an empty set after removing 6, 8 and 9, got %d items", a.Cardinality())
+	}
+}
+
func Test_ContainsSet(t *testing.T) {
a := NewSet()
diff --git a/threadsafe.go b/threadsafe.go
index 8ef31f6..04bf181 100644
--- a/threadsafe.go
+++ b/threadsafe.go
@@ -149,6 +149,12 @@
set.Unlock()
}
+func (set *threadSafeSet) RemoveAll(i ...interface{}) {
+	set.Lock()
+	defer set.Unlock() // released even if delete panics on an unhashable element
+	set.s.RemoveAll(i...)
+}
+
func (set *threadSafeSet) Cardinality() int {
set.RLock()
defer set.RUnlock()
diff --git a/threadunsafe.go b/threadunsafe.go
index 4d4e69f..2cd1f98 100644
--- a/threadunsafe.go
+++ b/threadunsafe.go
@@ -163,6 +163,12 @@
delete(*set, i)
}
+func (set *threadUnsafeSet) RemoveAll(i ...interface{}) {
+	for idx := range i { // delete is a no-op for elements not in the set
+		delete(*set, i[idx])
+	}
+}
+
func (set *threadUnsafeSet) Cardinality() int {
return len(*set)
}