blob: 883b63c006ecd93b2c4d90fc98206db183ed83d0 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package persistent_test
import (
"fmt"
"strings"
"testing"
"golang.org/x/tools/gopls/internal/persistent"
"golang.org/x/tools/internal/constraints"
)
// TestSet runs table-driven sequences of Add and Remove calls against a
// freshly allocated persistent.Set[int] and compares the final contents
// with the expected elements using diff.
func TestSet(t *testing.T) {
	// action identifies which Set method a step invokes.
	type action int
	const (
		add action = iota
		remove
	)
	// step is a single mutation applied to the set under test.
	type step struct {
		do  action
		val int
	}
	cases := []struct {
		label string
		steps []step
		want  []int
	}{
		{label: "empty"},
		{label: "singleton", steps: []step{{add, 1}}, want: []int{1}},
		{label: "add and remove", steps: []step{
			{add, 1},
			{remove, 1},
		}},
		{label: "interleaved and remove", steps: []step{
			{add, 1},
			{add, 2},
			{remove, 1},
			{add, 3},
		}, want: []int{2, 3}},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			set := new(persistent.Set[int])
			for _, st := range tc.steps {
				if st.do == add {
					set.Add(st.val)
				} else if st.do == remove {
					set.Remove(st.val)
				}
			}
			delta := diff(set, tc.want)
			if delta == "" {
				return
			}
			t.Errorf("unexpected diff:\n%s", delta)
		})
	}
}
// TestSet_Clone checks that elements added to a set or to its clone after
// the Clone call show up only in the set they were added to.
func TestSet_Clone(t *testing.T) {
	var base persistent.Set[int]
	for _, v := range []int{1, 2} {
		base.Add(v)
	}
	copied := base.Clone()

	// Diverge the two sets after cloning.
	base.Add(3)
	copied.Add(4)

	if delta := diff(&base, []int{1, 2, 3}); delta != "" {
		t.Errorf("s1: unexpected diff:\n%s", delta)
	}
	if delta := diff(copied, []int{1, 2, 4}); delta != "" {
		t.Errorf("s2: unexpected diff:\n%s", delta)
	}
}
// TestSet_AddAll merges three sets (two overlapping, one empty) into a
// fourth via AddAll, then checks that the sources are unchanged and that
// the destination holds every element from the sources.
func TestSet_AddAll(t *testing.T) {
	// newSet returns a set containing exactly the given elements.
	newSet := func(elems ...int) *persistent.Set[int] {
		set := new(persistent.Set[int])
		for _, e := range elems {
			set.Add(e)
		}
		return set
	}

	s1 := newSet(1, 2)
	s2 := newSet(2, 3, 4)
	s3 := newSet()

	union := newSet()
	for _, src := range []*persistent.Set[int]{s1, s2, s3} {
		union.AddAll(src)
	}

	// The source sets must be unaffected by being merged.
	if delta := diff(s1, []int{1, 2}); delta != "" {
		t.Errorf("s1: unexpected diff:\n%s", delta)
	}
	if delta := diff(s2, []int{2, 3, 4}); delta != "" {
		t.Errorf("s2: unexpected diff:\n%s", delta)
	}
	if delta := diff(s3, nil); delta != "" {
		t.Errorf("s3: unexpected diff:\n%s", delta)
	}
	// The destination holds the combined elements.
	if delta := diff(union, []int{1, 2, 3, 4}); delta != "" {
		t.Errorf("s: unexpected diff:\n%s", delta)
	}
}
// diff compares the contents of got against the elements listed in want and
// returns a human-readable report, or "" when they hold the same elements.
//
// Each line of the report is terminated by a newline and has one of two forms:
//
//	+k  k is present in got but not listed in want
//	-k  k is listed in want but absent from got
//
// The "+" lines come first, in the order got.Range visits the elements. The
// "-" lines follow in the order the keys first appear in want, so the report
// does not depend on map iteration order. Duplicate entries in want are
// reported at most once.
func diff[K constraints.Ordered](got *persistent.Set[K], want []K) string {
	// pending holds the wanted keys that have not yet been checked against got.
	pending := make(map[K]struct{}, len(want))
	for _, w := range want {
		pending[w] = struct{}{}
	}

	var lines []string
	got.Range(func(key K) {
		if _, ok := pending[key]; !ok {
			lines = append(lines, fmt.Sprintf("+%v", key))
		}
	})

	// Walk the want slice rather than the map: ranging over a map yields an
	// unspecified order, which would make failure output vary between runs.
	for _, key := range want {
		if _, ok := pending[key]; !ok {
			continue // duplicate entry in want; already handled
		}
		delete(pending, key)
		if !got.Contains(key) {
			lines = append(lines, fmt.Sprintf("-%v", key))
		}
	}

	if len(lines) == 0 {
		return ""
	}
	return strings.Join(lines, "\n") + "\n"
}