internal/persistent: add Set
Add a simple Set wrapper around persistent.Map, with a new test.
Follow-up CLs will replace ad-hoc sets in gopls with a persistent.Set.
Change-Id: Idd5fc5389719d3f59d658d8d9cb8fc0206e35797
Reviewed-on: https://go-review.googlesource.com/c/tools/+/524761
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
Auto-Submit: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
diff --git a/internal/persistent/set.go b/internal/persistent/set.go
new file mode 100644
index 0000000..348de5a
--- /dev/null
+++ b/internal/persistent/set.go
@@ -0,0 +1,78 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package persistent
+
+import "golang.org/x/tools/internal/constraints"
+
+// Set is a collection of elements of type K.
+//
+// It uses immutable data structures internally, so that sets can be cloned in
+// constant time.
+//
+// The zero value is a valid empty set.
+type Set[K constraints.Ordered] struct {
+ impl *Map[K, struct{}]
+}
+
+// Clone creates a copy of the receiver.
+func (s *Set[K]) Clone() *Set[K] {
+ clone := new(Set[K])
+ if s.impl != nil {
+ clone.impl = s.impl.Clone()
+ }
+ return clone
+}
+
+// Destroy destroys the set.
+//
+// After Destroy, the Set should not be used again.
+func (s *Set[K]) Destroy() {
+ if s.impl != nil {
+ s.impl.Destroy()
+ }
+}
+
+// Contains reports whether s contains the given key.
+func (s *Set[K]) Contains(key K) bool {
+ if s.impl == nil {
+ return false
+ }
+ _, ok := s.impl.Get(key)
+ return ok
+}
+
+// Range calls f sequentially in ascending key order for all entries in the set.
+func (s *Set[K]) Range(f func(key K)) {
+ if s.impl != nil {
+ s.impl.Range(func(key K, _ struct{}) {
+ f(key)
+ })
+ }
+}
+
+// AddAll adds all elements from other to the receiver set.
+func (s *Set[K]) AddAll(other *Set[K]) {
+ if other.impl != nil {
+ if s.impl == nil {
+ s.impl = new(Map[K, struct{}])
+ }
+ s.impl.SetAll(other.impl)
+ }
+}
+
+// Add adds an element to the set.
+func (s *Set[K]) Add(key K) {
+ if s.impl == nil {
+ s.impl = new(Map[K, struct{}])
+ }
+ s.impl.Set(key, struct{}{}, nil)
+}
+
+// Remove removes an element from the set.
+func (s *Set[K]) Remove(key K) {
+ if s.impl != nil {
+ s.impl.Delete(key)
+ }
+}
diff --git a/internal/persistent/set_test.go b/internal/persistent/set_test.go
new file mode 100644
index 0000000..5902514
--- /dev/null
+++ b/internal/persistent/set_test.go
@@ -0,0 +1,132 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package persistent_test
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+
+ "golang.org/x/tools/internal/constraints"
+ "golang.org/x/tools/internal/persistent"
+)
+
+func TestSet(t *testing.T) {
+ const (
+ add = iota
+ remove
+ )
+ type op struct {
+ op int
+ v int
+ }
+
+ tests := []struct {
+ label string
+ ops []op
+ want []int
+ }{
+ {"empty", nil, nil},
+ {"singleton", []op{{add, 1}}, []int{1}},
+ {"add and remove", []op{
+ {add, 1},
+ {remove, 1},
+ }, nil},
+ {"interleaved and remove", []op{
+ {add, 1},
+ {add, 2},
+ {remove, 1},
+ {add, 3},
+ }, []int{2, 3}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.label, func(t *testing.T) {
+ var s persistent.Set[int]
+ for _, op := range test.ops {
+ switch op.op {
+ case add:
+ s.Add(op.v)
+ case remove:
+ s.Remove(op.v)
+ }
+ }
+
+ if d := diff(&s, test.want); d != "" {
+ t.Errorf("unexpected diff:\n%s", d)
+ }
+ })
+ }
+}
+
+func TestSet_Clone(t *testing.T) {
+ s1 := new(persistent.Set[int])
+ s1.Add(1)
+ s1.Add(2)
+ s2 := s1.Clone()
+ s1.Add(3)
+ s2.Add(4)
+ if d := diff(s1, []int{1, 2, 3}); d != "" {
+ t.Errorf("s1: unexpected diff:\n%s", d)
+ }
+ if d := diff(s2, []int{1, 2, 4}); d != "" {
+ t.Errorf("s2: unexpected diff:\n%s", d)
+ }
+}
+
+func TestSet_AddAll(t *testing.T) {
+ s1 := new(persistent.Set[int])
+ s1.Add(1)
+ s1.Add(2)
+ s2 := new(persistent.Set[int])
+ s2.Add(2)
+ s2.Add(3)
+ s2.Add(4)
+ s3 := new(persistent.Set[int])
+
+ s := new(persistent.Set[int])
+ s.AddAll(s1)
+ s.AddAll(s2)
+ s.AddAll(s3)
+
+ if d := diff(s1, []int{1, 2}); d != "" {
+ t.Errorf("s1: unexpected diff:\n%s", d)
+ }
+ if d := diff(s2, []int{2, 3, 4}); d != "" {
+ t.Errorf("s2: unexpected diff:\n%s", d)
+ }
+ if d := diff(s3, nil); d != "" {
+ t.Errorf("s3: unexpected diff:\n%s", d)
+ }
+ if d := diff(s, []int{1, 2, 3, 4}); d != "" {
+ t.Errorf("s: unexpected diff:\n%s", d)
+ }
+}
+
+func diff[K constraints.Ordered](got *persistent.Set[K], want []K) string {
+ wantSet := make(map[K]struct{})
+ for _, w := range want {
+ wantSet[w] = struct{}{}
+ }
+ var diff []string
+ got.Range(func(key K) {
+ if _, ok := wantSet[key]; !ok {
+ diff = append(diff, fmt.Sprintf("+%v", key))
+ }
+ })
+ for key := range wantSet {
+ if !got.Contains(key) {
+ diff = append(diff, fmt.Sprintf("-%v", key))
+ }
+ }
+ if len(diff) > 0 {
+ d := new(strings.Builder)
+ for _, l := range diff {
+ fmt.Fprintln(d, l)
+ }
+ return d.String()
+ }
+ return ""
+}