internal/lsp/cache: use persistent map for storing gofiles in the snapshot

Use a treap (https://en.wikipedia.org/wiki/Treap) as a persistent map to avoid copying s.goFiles across generations.
Maintain an additional s.parseKeysByURI map to avoid scanning all of s.goFiles when an individual file's content is invalidated.

On average, this reduces didChange latency on an internal codebase from 160ms to 150ms.

In a follow-up, the same approach can be used to avoid copying s.files, s.packages, and s.knownSubdirs.
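
As an illustration, a minimal sketch of the clone-then-invalidate pattern this
enables (identifiers abbreviated, and keysForChangedURI is hypothetical; the
real call sites are in snapshot.go below):

    // Cloning shares the whole treap; subsequent edits copy only the
    // O(log n) path from the root to the affected key.
    goFiles := oldSnapshot.goFiles.Clone() // O(1)
    for _, key := range keysForChangedURI { // looked up via parseKeysByURI
        goFiles.Delete(key) // copies just the path to key
    }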

Updates golang/go#45686

Change-Id: Ic4a9b3c8fb2b66256f224adf9896ddcaaa6865b1
GitHub-Last-Rev: 0abd2570ae9b20ea7126ff31bee69aa0dc3f40aa
GitHub-Pull-Request: golang/tools#382
Reviewed-on: https://go-review.googlesource.com/c/tools/+/411554
Reviewed-by: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/internal/lsp/cache/maps.go b/internal/lsp/cache/maps.go
new file mode 100644
index 0000000..70f8039
--- /dev/null
+++ b/internal/lsp/cache/maps.go
@@ -0,0 +1,112 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cache
+
+import (
+	"golang.org/x/tools/internal/persistent"
+	"golang.org/x/tools/internal/span"
+)
+
+// TODO(euroelessar): Use generics once support for go1.17 is dropped.
+
+type goFilesMap struct {
+	impl *persistent.Map
+}
+
+func newGoFilesMap() *goFilesMap {
+	return &goFilesMap{
+		impl: persistent.NewMap(func(a, b interface{}) bool {
+			return parseKeyLess(a.(parseKey), b.(parseKey))
+		}),
+	}
+}
+
+func parseKeyLess(a, b parseKey) bool {
+	if a.mode != b.mode {
+		return a.mode < b.mode
+	}
+	if a.file.Hash != b.file.Hash {
+		return a.file.Hash.Less(b.file.Hash)
+	}
+	return a.file.URI < b.file.URI
+}
+
+func (m *goFilesMap) Clone() *goFilesMap {
+	return &goFilesMap{
+		impl: m.impl.Clone(),
+	}
+}
+
+func (m *goFilesMap) Destroy() {
+	m.impl.Destroy()
+}
+
+func (m *goFilesMap) Load(key parseKey) (*parseGoHandle, bool) {
+	value, ok := m.impl.Load(key)
+	if !ok {
+		return nil, false
+	}
+	return value.(*parseGoHandle), true
+}
+
+func (m *goFilesMap) Range(do func(key parseKey, value *parseGoHandle)) {
+	m.impl.Range(func(key, value interface{}) {
+		do(key.(parseKey), value.(*parseGoHandle))
+	})
+}
+
+func (m *goFilesMap) Store(key parseKey, value *parseGoHandle, release func()) {
+	m.impl.Store(key, value, func(key, value interface{}) {
+		release()
+	})
+}
+
+func (m *goFilesMap) Delete(key parseKey) {
+	m.impl.Delete(key)
+}
+
+type parseKeysByURIMap struct {
+	impl *persistent.Map
+}
+
+func newParseKeysByURIMap() *parseKeysByURIMap {
+	return &parseKeysByURIMap{
+		impl: persistent.NewMap(func(a, b interface{}) bool {
+			return a.(span.URI) < b.(span.URI)
+		}),
+	}
+}
+
+func (m *parseKeysByURIMap) Clone() *parseKeysByURIMap {
+	return &parseKeysByURIMap{
+		impl: m.impl.Clone(),
+	}
+}
+
+func (m *parseKeysByURIMap) Destroy() {
+	m.impl.Destroy()
+}
+
+func (m *parseKeysByURIMap) Load(key span.URI) ([]parseKey, bool) {
+	value, ok := m.impl.Load(key)
+	if !ok {
+		return nil, false
+	}
+	return value.([]parseKey), true
+}
+
+func (m *parseKeysByURIMap) Range(do func(key span.URI, value []parseKey)) {
+	m.impl.Range(func(key, value interface{}) {
+		do(key.(span.URI), value.([]parseKey))
+	})
+}
+
+func (m *parseKeysByURIMap) Store(key span.URI, value []parseKey) {
+	m.impl.Store(key, value, nil)
+}
+
+func (m *parseKeysByURIMap) Delete(key span.URI) {
+	m.impl.Delete(key)
+}
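
Per the TODO above, once go1.18 can be assumed, the two hand-written wrappers
could collapse into one generic type. A hypothetical sketch (not part of this
change; persistentMap and newPersistentMap are invented names):

    package cache

    import "golang.org/x/tools/internal/persistent"

    // persistentMap is a hypothetical generic wrapper over persistent.Map;
    // one definition would replace both goFilesMap and parseKeysByURIMap.
    type persistentMap[K, V any] struct {
        impl *persistent.Map
    }

    func newPersistentMap[K, V any](less func(a, b K) bool) *persistentMap[K, V] {
        return &persistentMap[K, V]{
            impl: persistent.NewMap(func(a, b interface{}) bool {
                return less(a.(K), b.(K))
            }),
        }
    }

    // Load returns the value for key, if present.
    func (m *persistentMap[K, V]) Load(key K) (V, bool) {
        var zero V
        if value, ok := m.impl.Load(key); ok {
            return value.(V), true
        }
        return zero, false
    }
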
diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go
index ab55743..376524b 100644
--- a/internal/lsp/cache/parse.go
+++ b/internal/lsp/cache/parse.go
@@ -58,7 +58,7 @@
 	if pgh := s.getGoFile(key); pgh != nil {
 		return pgh
 	}
-	parseHandle := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
+	parseHandle, release := s.generation.GetHandle(key, func(ctx context.Context, arg memoize.Arg) interface{} {
 		snapshot := arg.(*snapshot)
 		return parseGo(ctx, snapshot.FileSet(), fh, mode)
 	}, nil)
@@ -68,7 +68,7 @@
 		file:   fh,
 		mode:   mode,
 	}
-	return s.addGoFile(key, pgh)
+	return s.addGoFile(key, pgh, release)
 }
 
 func (pgh *parseGoHandle) String() string {
diff --git a/internal/lsp/cache/session.go b/internal/lsp/cache/session.go
index 0d3e944..7dbccf7 100644
--- a/internal/lsp/cache/session.go
+++ b/internal/lsp/cache/session.go
@@ -234,7 +234,8 @@
 		packages:          make(map[packageKey]*packageHandle),
 		meta:              &metadataGraph{},
 		files:             make(map[span.URI]source.VersionedFileHandle),
-		goFiles:           newGoFileMap(),
+		goFiles:           newGoFilesMap(),
+		parseKeysByURI:    newParseKeysByURIMap(),
 		symbols:           make(map[span.URI]*symbolHandle),
 		actions:           make(map[actionKey]*actionHandle),
 		workspacePackages: make(map[PackageID]PackagePath),
diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go
index 3c46648..b2ac782 100644
--- a/internal/lsp/cache/snapshot.go
+++ b/internal/lsp/cache/snapshot.go
@@ -77,7 +77,8 @@
 	files map[span.URI]source.VersionedFileHandle
 
 	// goFiles maps a parseKey to its parseGoHandle.
-	goFiles *goFileMap
+	goFiles        *goFilesMap
+	parseKeysByURI *parseKeysByURIMap
 
 	// TODO(rfindley): consider merging this with files to reduce burden on clone.
 	symbols map[span.URI]*symbolHandle
@@ -133,6 +134,12 @@
 	analyzer *analysis.Analyzer
 }
 
+func (s *snapshot) Destroy(destroyedBy string) {
+	s.generation.Destroy(destroyedBy)
+	s.goFiles.Destroy()
+	s.parseKeysByURI.Destroy()
+}
+
 func (s *snapshot) ID() uint64 {
 	return s.id
 }
@@ -665,17 +672,23 @@
 func (s *snapshot) getGoFile(key parseKey) *parseGoHandle {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	return s.goFiles.get(key)
+	if result, ok := s.goFiles.Load(key); ok {
+		return result
+	}
+	return nil
 }
 
-func (s *snapshot) addGoFile(key parseKey, pgh *parseGoHandle) *parseGoHandle {
+func (s *snapshot) addGoFile(key parseKey, pgh *parseGoHandle, release func()) *parseGoHandle {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
-	if prev := s.goFiles.get(key); prev != nil {
-		return prev
+	if result, ok := s.goFiles.Load(key); ok {
+		release()
+		return result
 	}
-	s.goFiles.set(key, pgh)
+	s.goFiles.Store(key, pgh, release)
+	keys, _ := s.parseKeysByURI.Load(key.file.URI)
+	keys = append([]parseKey{key}, keys...)
+	s.parseKeysByURI.Store(key.file.URI, keys)
 	return pgh
 }
 
@@ -1663,6 +1676,9 @@
 }
 
 func (s *snapshot) clone(ctx, bgCtx context.Context, changes map[span.URI]*fileChange, forceReloadMetadata bool) *snapshot {
+	ctx, done := event.Start(ctx, "snapshot.clone")
+	defer done()
+
 	var vendorChanged bool
 	newWorkspace, workspaceChanged, workspaceReload := s.workspace.invalidate(ctx, changes, &unappliedChanges{
 		originalSnapshot: s,
@@ -1686,7 +1702,8 @@
 		packages:          make(map[packageKey]*packageHandle, len(s.packages)),
 		actions:           make(map[actionKey]*actionHandle, len(s.actions)),
 		files:             make(map[span.URI]source.VersionedFileHandle, len(s.files)),
-		goFiles:           s.goFiles.clone(),
+		goFiles:           s.goFiles.Clone(),
+		parseKeysByURI:    s.parseKeysByURI.Clone(),
 		symbols:           make(map[span.URI]*symbolHandle, len(s.symbols)),
 		workspacePackages: make(map[PackageID]PackagePath, len(s.workspacePackages)),
 		unloadableFiles:   make(map[span.URI]struct{}, len(s.unloadableFiles)),
@@ -1731,27 +1748,14 @@
 		result.parseWorkHandles[k] = v
 	}
 
-	// Copy the handles of all Go source files.
-	// There may be tens of thousands of files,
-	// but changes are typically few, so we
-	// use a striped map optimized for this case
-	// and visit its stripes in parallel.
-	var (
-		toDeleteMu sync.Mutex
-		toDelete   []parseKey
-	)
-	s.goFiles.forEachConcurrent(func(k parseKey, v *parseGoHandle) {
-		if changes[k.file.URI] == nil {
-			// no change (common case)
-			newGen.Inherit(v.handle)
-		} else {
-			toDeleteMu.Lock()
-			toDelete = append(toDelete, k)
-			toDeleteMu.Unlock()
+	for uri := range changes {
+		keys, ok := result.parseKeysByURI.Load(uri)
+		if ok {
+			for _, key := range keys {
+				result.goFiles.Delete(key)
+			}
+			result.parseKeysByURI.Delete(uri)
 		}
-	})
-	for _, k := range toDelete {
-		result.goFiles.delete(k)
 	}
 
 	// Copy all of the go.mod-related handles. They may be invalidated later,
@@ -2194,7 +2198,7 @@
 // lockedSnapshot must be locked.
 func peekOrParse(ctx context.Context, lockedSnapshot *snapshot, fh source.FileHandle, mode source.ParseMode) (*source.ParsedGoFile, error) {
 	key := parseKey{file: fh.FileIdentity(), mode: mode}
-	if pgh := lockedSnapshot.goFiles.get(key); pgh != nil {
+	if pgh, ok := lockedSnapshot.goFiles.Load(key); ok {
 		cached := pgh.handle.Cached(lockedSnapshot.generation)
 		if cached != nil {
 			cached := cached.(*parseGoData)
@@ -2482,89 +2486,3 @@
 	}
 	return nil
 }
-
-// -- goFileMap --
-
-// A goFileMap is conceptually a map[parseKey]*parseGoHandle,
-// optimized for cloning all or nearly all entries.
-type goFileMap struct {
-	// The map is represented as a map of 256 stripes, one per
-	// distinct value of the top 8 bits of key.file.Hash.
-	// Each stripe has an associated boolean indicating whether it
-	// is shared, and thus immutable, and thus must be copied before any update.
-	// (The bits could be packed but it hasn't been worth it yet.)
-	stripes   [256]map[parseKey]*parseGoHandle
-	exclusive [256]bool // exclusive[i] means stripe[i] is not shared and may be safely mutated
-}
-
-// newGoFileMap returns a new empty goFileMap.
-func newGoFileMap() *goFileMap {
-	return new(goFileMap) // all stripes are shared (non-exclusive) nil maps
-}
-
-// clone returns a copy of m.
-// For concurrency, it counts as an update to m.
-func (m *goFileMap) clone() *goFileMap {
-	m.exclusive = [256]bool{} // original and copy are now nonexclusive
-	copy := *m
-	return &copy
-}
-
-// get returns the value for key k.
-func (m *goFileMap) get(k parseKey) *parseGoHandle {
-	return m.stripes[m.hash(k)][k]
-}
-
-// set updates the value for key k to v.
-func (m *goFileMap) set(k parseKey, v *parseGoHandle) {
-	m.unshare(k)[k] = v
-}
-
-// delete deletes the value for key k, if any.
-func (m *goFileMap) delete(k parseKey) {
-	// TODO(adonovan): opt?: skip unshare if k isn't present.
-	delete(m.unshare(k), k)
-}
-
-// forEachConcurrent calls f for each entry in the map.
-// Calls may be concurrent.
-// f must not modify m.
-func (m *goFileMap) forEachConcurrent(f func(parseKey, *parseGoHandle)) {
-	// Visit stripes in parallel chunks.
-	const p = 16 // concurrency level
-	var wg sync.WaitGroup
-	wg.Add(p)
-	for i := 0; i < p; i++ {
-		chunk := m.stripes[i*p : (i+1)*p]
-		go func() {
-			for _, stripe := range chunk {
-				for k, v := range stripe {
-					f(k, v)
-				}
-			}
-			wg.Done()
-		}()
-	}
-	wg.Wait()
-}
-
-// -- internal--
-
-// hash returns 8 bits from the key's file digest.
-func (*goFileMap) hash(k parseKey) byte { return k.file.Hash[0] }
-
-// unshare makes k's stripe exclusive, allocating a copy if needed, and returns it.
-func (m *goFileMap) unshare(k parseKey) map[parseKey]*parseGoHandle {
-	i := m.hash(k)
-	if !m.exclusive[i] {
-		m.exclusive[i] = true
-
-		// Copy the map.
-		copy := make(map[parseKey]*parseGoHandle, len(m.stripes[i]))
-		for k, v := range m.stripes[i] {
-			copy[k] = v
-		}
-		m.stripes[i] = copy
-	}
-	return m.stripes[i]
-}
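
With the reverse index, clone cost above is proportional to the number of
changed files (times O(log n) per deletion) rather than to the total number of
Go files. A self-contained sketch of the idea, with plain strings standing in
for parseKey and span.URI (a hypothetical example, not code from this change):

    package main

    import (
        "fmt"

        "golang.org/x/tools/internal/persistent"
    )

    func main() {
        less := func(a, b interface{}) bool { return a.(string) < b.(string) }
        byKey := persistent.NewMap(less) // stands in for goFiles
        byURI := persistent.NewMap(less) // stands in for parseKeysByURI

        byKey.Store("file.go/ParseFull", "handle", nil)
        byURI.Store("file.go", []string{"file.go/ParseFull"}, nil)

        // On a change to file.go, delete only its keys instead of
        // scanning every entry in byKey.
        if keys, ok := byURI.Load("file.go"); ok {
            for _, k := range keys.([]string) {
                byKey.Delete(k)
            }
            byURI.Delete("file.go")
        }

        fmt.Println(byKey.Load("file.go/ParseFull")) // <nil> false
        byKey.Destroy()
        byURI.Destroy()
    }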
diff --git a/internal/lsp/cache/view.go b/internal/lsp/cache/view.go
index 620efd8..1810f6e 100644
--- a/internal/lsp/cache/view.go
+++ b/internal/lsp/cache/view.go
@@ -535,7 +535,7 @@
 	v.mu.Unlock()
 	v.snapshotMu.Lock()
 	if v.snapshot != nil {
-		go v.snapshot.generation.Destroy("View.shutdown")
+		go v.snapshot.Destroy("View.shutdown")
 		v.snapshot = nil
 	}
 	v.snapshotMu.Unlock()
@@ -732,7 +732,7 @@
 	oldSnapshot := v.snapshot
 
 	v.snapshot = oldSnapshot.clone(ctx, v.baseCtx, changes, forceReloadMetadata)
-	go oldSnapshot.generation.Destroy("View.invalidateContent")
+	go oldSnapshot.Destroy("View.invalidateContent")
 
 	return v.snapshot, v.snapshot.generation.Acquire()
 }
diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go
index 0d8d661..73e1b7f 100644
--- a/internal/lsp/source/view.go
+++ b/internal/lsp/source/view.go
@@ -551,6 +551,11 @@
 	return fmt.Sprintf("%64x", [sha256.Size]byte(h))
 }
 
+// Less reports whether the hash h is less than other.
+func (h Hash) Less(other Hash) bool {
+	return bytes.Compare(h[:], other[:]) < 0
+}
+
 // FileIdentity uniquely identifies a file at a version from a FileSystem.
 type FileIdentity struct {
 	URI  span.URI
diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go
index 480b87f..48a642c 100644
--- a/internal/memoize/memoize.go
+++ b/internal/memoize/memoize.go
@@ -83,16 +83,15 @@
 
 	g.store.mu.Lock()
 	defer g.store.mu.Unlock()
-	for k, e := range g.store.handles {
+	for _, e := range g.store.handles {
+		if !e.trackGenerations {
+			continue
+		}
 		e.mu.Lock()
 		if _, ok := e.generations[g]; ok {
 			delete(e.generations, g) // delete even if it's dead, in case of dangling references to the entry.
 			if len(e.generations) == 0 {
-				delete(g.store.handles, k)
-				e.state = stateDestroyed
-				if e.cleanup != nil && e.value != nil {
-					e.cleanup(e.value)
-				}
+				e.destroy(g.store)
 			}
 		}
 		e.mu.Unlock()
@@ -161,6 +160,12 @@
 	// cleanup, if non-nil, is used to perform any necessary clean-up on values
 	// produced by function.
 	cleanup func(interface{})
+
+	// If trackGenerations is set, this handle tracks generations in which it
+	// is valid, via the generations field. Otherwise, it is explicitly reference
+	// counted via the refCounter field.
+	trackGenerations bool
+	refCounter       int32
 }
 
 // Bind returns a handle for the given key and function.
@@ -173,7 +178,34 @@
 //
 // If cleanup is non-nil, it will be called on any non-nil values produced by
 // function when they are no longer referenced.
+//
+// It is the responsibility of the caller to call Inherit on the handle
+// whenever it should remain accessible to a subsequent generation.
 func (g *Generation) Bind(key interface{}, function Function, cleanup func(interface{})) *Handle {
+	return g.getHandle(key, function, cleanup, true)
+}
+
+// GetHandle returns a handle for the given key and function with similar
+// properties and behavior to those of Bind.
+//
+// Unlike Bind, it returns a release callback that must be called once this
+// reference to the handle is no longer needed.
+func (g *Generation) GetHandle(key interface{}, function Function, cleanup func(interface{})) (*Handle, func()) {
+	handle := g.getHandle(key, function, cleanup, false)
+	store := g.store
+	release := func() {
+		store.mu.Lock()
+		defer store.mu.Unlock()
+
+		handle.refCounter--
+		if handle.refCounter == 0 {
+			handle.destroy(store)
+		}
+	}
+	return handle, release
+}
+
+func (g *Generation) getHandle(key interface{}, function Function, cleanup func(interface{}), trackGenerations bool) *Handle {
 	// panic early if the function is nil
 	// it would panic later anyway, but in a way that was much harder to debug
 	if function == nil {
@@ -186,20 +218,19 @@
 	defer g.store.mu.Unlock()
 	h, ok := g.store.handles[key]
 	if !ok {
-		h := &Handle{
-			key:         key,
-			function:    function,
-			generations: map[*Generation]struct{}{g: {}},
-			cleanup:     cleanup,
+		h = &Handle{
+			key:              key,
+			function:         function,
+			cleanup:          cleanup,
+			trackGenerations: trackGenerations,
+		}
+		if trackGenerations {
+			h.generations = make(map[*Generation]struct{}, 1)
 		}
 		g.store.handles[key] = h
-		return h
 	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	if _, ok := h.generations[g]; !ok {
-		h.generations[g] = struct{}{}
-	}
+
+	h.incrementRef(g)
 	return h
 }
 
@@ -240,13 +271,44 @@
 	if atomic.LoadUint32(&g.destroyed) != 0 {
 		panic("inherit on generation " + g.name + " destroyed by " + g.destroyedBy)
 	}
+	if !h.trackGenerations {
+		panic("called Inherit on handle not created by Generation.Bind")
+	}
 
+	h.incrementRef(g)
+}
+
+func (h *Handle) destroy(store *Store) {
+	h.state = stateDestroyed
+	if h.cleanup != nil && h.value != nil {
+		h.cleanup(h.value)
+	}
+	delete(store.handles, h.key)
+}
+
+func (h *Handle) incrementRef(g *Generation) {
 	h.mu.Lock()
+	defer h.mu.Unlock()
+
 	if h.state == stateDestroyed {
 		panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
 	}
-	h.generations[g] = struct{}{}
-	h.mu.Unlock()
+
+	if h.trackGenerations {
+		h.generations[g] = struct{}{}
+	} else {
+		h.refCounter++
+	}
+}
+
+// hasRefLocked reports whether h is valid in generation g. h.mu must be held.
+func (h *Handle) hasRefLocked(g *Generation) bool {
+	if !h.trackGenerations {
+		return true
+	}
+
+	_, ok := h.generations[g]
+	return ok
 }
 
 // Cached returns the value associated with a handle.
@@ -256,7 +318,7 @@
 func (h *Handle) Cached(g *Generation) interface{} {
 	h.mu.Lock()
 	defer h.mu.Unlock()
-	if _, ok := h.generations[g]; !ok {
+	if !h.hasRefLocked(g) {
 		return nil
 	}
 	if h.state == stateCompleted {
@@ -277,7 +339,7 @@
 		return nil, ctx.Err()
 	}
 	h.mu.Lock()
-	if _, ok := h.generations[g]; !ok {
+	if !h.hasRefLocked(g) {
 		h.mu.Unlock()
 
 		err := fmt.Errorf("reading key %#v: generation %v is not known", h.key, g.name)
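
A minimal sketch of the new reference-counted lifetime, mirroring the call
pattern exercised by the test below (this assumes Handle.Get keeps its
(ctx, generation, arg) signature; the key and values are invented):

    package main

    import (
        "context"
        "fmt"

        "golang.org/x/tools/internal/memoize"
    )

    func main() {
        s := &memoize.Store{}
        g := s.Generation("g1")

        // A reference-counted handle: its lifetime is governed by release,
        // not by generation membership, so no Inherit call is needed.
        h, release := g.GetHandle("key", func(context.Context, memoize.Arg) interface{} {
            return 42
        }, func(v interface{}) { fmt.Println("cleanup:", v) })

        v, err := h.Get(context.Background(), g, nil)
        fmt.Println(v, err) // 42 <nil>

        // Dropping the last reference destroys the handle and runs cleanup.
        release()
    }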
diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go
index ee0fd23..bffbfc2 100644
--- a/internal/memoize/memoize_test.go
+++ b/internal/memoize/memoize_test.go
@@ -106,3 +106,58 @@
 		t.Error("after destroying g2, v2 is not cleaned up")
 	}
 }
+
+func TestHandleRefCounting(t *testing.T) {
+	s := &memoize.Store{}
+	g1 := s.Generation("g1")
+	v1 := false
+	v2 := false
+	cleanup := func(v interface{}) {
+		*(v.(*bool)) = true
+	}
+	h1, release1 := g1.GetHandle("key1", func(context.Context, memoize.Arg) interface{} {
+		return &v1
+	}, nil)
+	h2, release2 := g1.GetHandle("key2", func(context.Context, memoize.Arg) interface{} {
+		return &v2
+	}, cleanup)
+	expectGet(t, h1, g1, &v1)
+	expectGet(t, h2, g1, &v2)
+
+	g2 := s.Generation("g2")
+	expectGet(t, h1, g2, &v1)
+	g1.Destroy("by test")
+	expectGet(t, h2, g2, &v2)
+
+	h2Copy, release2Copy := g2.GetHandle("key2", func(context.Context, memoize.Arg) interface{} {
+		return &v1
+	}, nil)
+	if h2 != h2Copy {
+		t.Error("GetHandle returned a new handle while the old one was not yet destroyed")
+	}
+	expectGet(t, h2Copy, g2, &v2)
+	g2.Destroy("by test")
+
+	release2()
+	if got, want := v2, false; got != want {
+		t.Error("after destroying first v2 ref, v2 is cleaned up")
+	}
+	release2Copy()
+	if got, want := v2, true; got != want {
+		t.Error("after destroying second v2 ref, v2 is not cleaned up")
+	}
+	if got, want := v1, false; got != want {
+		t.Error("after destroying v2, v1 is cleaned up")
+	}
+	release1()
+
+	g3 := s.Generation("g3")
+	h2Copy, release2Copy = g3.GetHandle("key2", func(context.Context, memoize.Arg) interface{} {
+		return &v2
+	}, cleanup)
+	if h2 == h2Copy {
+		t.Error("GetHandle returned a previously destroyed handle")
+	}
+	release2Copy()
+	g3.Destroy("by test")
+}
diff --git a/internal/persistent/map.go b/internal/persistent/map.go
new file mode 100644
index 0000000..bbcb72b
--- /dev/null
+++ b/internal/persistent/map.go
@@ -0,0 +1,268 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package persistent defines various persistent data structures; that is,
+// data structures that can be efficiently copied and modified in sublinear
+// time.
+package persistent
+
+import (
+	"math/rand"
+	"sync/atomic"
+)
+
+// Implementation details:
+// * Each value is reference counted by nodes which hold it.
+// * Each node is reference counted by its parent nodes.
+// * Each map is considered a top-level parent node from a reference-counting perspective.
+// * Each change effectively produces a new top-level node.
+//
+// Functions that operate directly on nodes use a notation of the form
+// `foo(arg1:+n1, arg2:+n2) (ret1:+n3)`.
+// Each argument is followed by the delta applied to its reference counter.
+// If no change is expected, the delta is `-0`.
+
+// Map is an associative mapping from keys to values, both represented as
+// interface{}. Key comparison and iteration order is defined by a
+// client-provided function that implements a strict weak order.
+//
+// Maps can be Cloned in constant time.
+//
+// Values are reference counted, and a client-supplied release function
+// is called when a value is no longer referenced by a map or any clone.
+//
+// Internally the implementation is based on a randomized persistent treap:
+// https://en.wikipedia.org/wiki/Treap.
+type Map struct {
+	less func(a, b interface{}) bool
+	root *mapNode
+}
+
+type mapNode struct {
+	key         interface{}
+	value       *refValue
+	weight      uint64
+	refCount    int32
+	left, right *mapNode
+}
+
+type refValue struct {
+	refCount int32
+	value    interface{}
+	release  func(key, value interface{})
+}
+
+func newNodeWithRef(key, value interface{}, release func(key, value interface{})) *mapNode {
+	return &mapNode{
+		key: key,
+		value: &refValue{
+			value:    value,
+			release:  release,
+			refCount: 1,
+		},
+		refCount: 1,
+		weight:   rand.Uint64(),
+	}
+}
+
+func (node *mapNode) shallowCloneWithRef() *mapNode {
+	atomic.AddInt32(&node.value.refCount, 1)
+	return &mapNode{
+		key:      node.key,
+		value:    node.value,
+		weight:   node.weight,
+		refCount: 1,
+	}
+}
+
+func (node *mapNode) incref() *mapNode {
+	if node != nil {
+		atomic.AddInt32(&node.refCount, 1)
+	}
+	return node
+}
+
+func (node *mapNode) decref() {
+	if node == nil {
+		return
+	}
+	if atomic.AddInt32(&node.refCount, -1) == 0 {
+		if atomic.AddInt32(&node.value.refCount, -1) == 0 {
+			if node.value.release != nil {
+				node.value.release(node.key, node.value.value)
+			}
+			node.value.value = nil
+			node.value.release = nil
+		}
+		node.left.decref()
+		node.right.decref()
+	}
+}
+
+// NewMap returns a new map whose keys are ordered by the given comparison
+// function (a strict weak order). It is the responsibility of the caller to
+// Destroy it at a later time.
+func NewMap(less func(a, b interface{}) bool) *Map {
+	return &Map{
+		less: less,
+	}
+}
+
+// Clone returns a copy of the given map. It is the responsibility of the
+// caller to Destroy it at a later time.
+func (pm *Map) Clone() *Map {
+	return &Map{
+		less: pm.less,
+		root: pm.root.incref(),
+	}
+}
+
+// Destroy the persistent map.
+//
+// After Destroy, the Map should not be used again.
+func (pm *Map) Destroy() {
+	pm.root.decref()
+	pm.root = nil
+}
+
+// Range calls f sequentially in ascending key order for all entries in the map.
+func (pm *Map) Range(f func(key, value interface{})) {
+	pm.root.forEach(f)
+}
+
+func (node *mapNode) forEach(f func(key, value interface{})) {
+	if node == nil {
+		return
+	}
+	node.left.forEach(f)
+	f(node.key, node.value.value)
+	node.right.forEach(f)
+}
+
+// Load returns the value stored in the map for a key, or nil if no entry is
+// present. The ok result indicates whether an entry was found in the map.
+func (pm *Map) Load(key interface{}) (interface{}, bool) {
+	node := pm.root
+	for node != nil {
+		if pm.less(key, node.key) {
+			node = node.left
+		} else if pm.less(node.key, key) {
+			node = node.right
+		} else {
+			return node.value.value, true
+		}
+	}
+	return nil, false
+}
+
+// Store sets the value for a key.
+// If release is non-nil, it will be called with the entry's key and value once the
+// key is no longer contained in the map or any clone.
+func (pm *Map) Store(key, value interface{}, release func(key, value interface{})) {
+	first := pm.root
+	second := newNodeWithRef(key, value, release)
+	pm.root = union(first, second, pm.less, true)
+	first.decref()
+	second.decref()
+}
+
+// union returns a new tree that is the union of the first and second trees.
+// If overwrite is true, the value from the second tree wins for any duplicate keys.
+//
+// union(first:-0, second:-0) (result:+1)
+// Union borrows both subtrees without affecting their refcount and returns a
+// new reference that the caller is expected to decref.
+func union(first, second *mapNode, less func(a, b interface{}) bool, overwrite bool) *mapNode {
+	if first == nil {
+		return second.incref()
+	}
+	if second == nil {
+		return first.incref()
+	}
+
+	if first.weight < second.weight {
+		second, first, overwrite = first, second, !overwrite
+	}
+
+	left, mid, right := split(second, first.key, less)
+	var result *mapNode
+	if overwrite && mid != nil {
+		result = mid.shallowCloneWithRef()
+	} else {
+		result = first.shallowCloneWithRef()
+	}
+	result.weight = first.weight
+	result.left = union(first.left, left, less, overwrite)
+	result.right = union(first.right, right, less, overwrite)
+	left.decref()
+	mid.decref()
+	right.decref()
+	return result
+}
+
+// split the tree midway by the key into three different ones.
+// Return three new trees: left with all nodes smaller than key, mid with
+// the node matching the key, and right with all nodes larger than key.
+// If one of the trees has no nodes, nil is returned instead.
+//
+// split(n:-0) (left:+1, mid:+1, right:+1)
+// Split borrows n without affecting its refcount, and returns three
+// new references that the caller is expected to decref.
+func split(n *mapNode, key interface{}, less func(a, b interface{}) bool) (left, mid, right *mapNode) {
+	if n == nil {
+		return nil, nil, nil
+	}
+
+	if less(n.key, key) {
+		left, mid, right := split(n.right, key, less)
+		newN := n.shallowCloneWithRef()
+		newN.left = n.left.incref()
+		newN.right = left
+		return newN, mid, right
+	} else if less(key, n.key) {
+		left, mid, right := split(n.left, key, less)
+		newN := n.shallowCloneWithRef()
+		newN.left = right
+		newN.right = n.right.incref()
+		return left, mid, newN
+	}
+	mid = n.shallowCloneWithRef()
+	return n.left.incref(), mid, n.right.incref()
+}
+
+// Delete deletes the value for a key.
+func (pm *Map) Delete(key interface{}) {
+	root := pm.root
+	left, mid, right := split(root, key, pm.less)
+	pm.root = merge(left, right)
+	left.decref()
+	mid.decref()
+	right.decref()
+	root.decref()
+}
+
+// merge two trees while preserving the weight invariant.
+// All nodes in left must have smaller keys than any node in right.
+//
+// merge(left:-0, right:-0) (result:+1)
+// Merge borrows its arguments without affecting their refcount
+// and returns a new reference that the caller is expected to decref.
+func merge(left, right *mapNode) *mapNode {
+	switch {
+	case left == nil:
+		return right.incref()
+	case right == nil:
+		return left.incref()
+	case left.weight > right.weight:
+		root := left.shallowCloneWithRef()
+		root.left = left.left.incref()
+		root.right = merge(left.right, right)
+		return root
+	default:
+		root := right.shallowCloneWithRef()
+		root.left = merge(left, right.left)
+		root.right = right.right.incref()
+		return root
+	}
+}
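
To make the release semantics concrete, a runnable test-style sketch
(placement and package name are assumptions):

    package persistent_test

    import (
        "fmt"

        "golang.org/x/tools/internal/persistent"
    )

    // Two maps share the treap after Clone; the release callback fires only
    // once no map (original or clone) references the value anymore.
    func Example() {
        m := persistent.NewMap(func(a, b interface{}) bool {
            return a.(int) < b.(int)
        })
        m.Store(1, "one", func(key, value interface{}) {
            fmt.Printf("released %v=%v\n", key, value)
        })

        clone := m.Clone() // O(1): shares the root node
        clone.Delete(1)    // copies only the path to key 1

        if v, ok := m.Load(1); ok {
            fmt.Println("original still has", v) // the original is unaffected
        }

        clone.Destroy()
        m.Destroy() // last reference gone: the release callback runs here
        // Output:
        // original still has one
        // released 1=one
    }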
diff --git a/internal/persistent/map_test.go b/internal/persistent/map_test.go
new file mode 100644
index 0000000..9585956
--- /dev/null
+++ b/internal/persistent/map_test.go
@@ -0,0 +1,316 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package persistent
+
+import (
+	"fmt"
+	"math/rand"
+	"reflect"
+	"sync/atomic"
+	"testing"
+)
+
+type mapEntry struct {
+	key   int
+	value int
+}
+
+type validatedMap struct {
+	impl     *Map
+	expected map[int]int
+	deleted  map[mapEntry]struct{}
+	seen     map[mapEntry]struct{}
+}
+
+func TestSimpleMap(t *testing.T) {
+	deletedEntries := make(map[mapEntry]struct{})
+	seenEntries := make(map[mapEntry]struct{})
+
+	m1 := &validatedMap{
+		impl: NewMap(func(a, b interface{}) bool {
+			return a.(int) < b.(int)
+		}),
+		expected: make(map[int]int),
+		deleted:  deletedEntries,
+		seen:     seenEntries,
+	}
+
+	m3 := m1.clone()
+	validateRef(t, m1, m3)
+	m3.insert(t, 8, 8)
+	validateRef(t, m1, m3)
+	m3.destroy()
+
+	assertSameMap(t, deletedEntries, map[mapEntry]struct{}{
+		{key: 8, value: 8}: {},
+	})
+
+	validateRef(t, m1)
+	m1.insert(t, 1, 1)
+	validateRef(t, m1)
+	m1.insert(t, 2, 2)
+	validateRef(t, m1)
+	m1.insert(t, 3, 3)
+	validateRef(t, m1)
+	m1.remove(t, 2)
+	validateRef(t, m1)
+	m1.insert(t, 6, 6)
+	validateRef(t, m1)
+
+	assertSameMap(t, deletedEntries, map[mapEntry]struct{}{
+		{key: 2, value: 2}: {},
+		{key: 8, value: 8}: {},
+	})
+
+	m2 := m1.clone()
+	validateRef(t, m1, m2)
+	m1.insert(t, 6, 60)
+	validateRef(t, m1, m2)
+	m1.remove(t, 1)
+	validateRef(t, m1, m2)
+
+	for i := 10; i < 14; i++ {
+		m1.insert(t, i, i)
+		validateRef(t, m1, m2)
+	}
+
+	m1.insert(t, 10, 100)
+	validateRef(t, m1, m2)
+
+	m1.remove(t, 12)
+	validateRef(t, m1, m2)
+
+	m2.insert(t, 4, 4)
+	validateRef(t, m1, m2)
+	m2.insert(t, 5, 5)
+	validateRef(t, m1, m2)
+
+	m1.destroy()
+
+	assertSameMap(t, deletedEntries, map[mapEntry]struct{}{
+		{key: 2, value: 2}:    {},
+		{key: 6, value: 60}:   {},
+		{key: 8, value: 8}:    {},
+		{key: 10, value: 10}:  {},
+		{key: 10, value: 100}: {},
+		{key: 11, value: 11}:  {},
+		{key: 12, value: 12}:  {},
+		{key: 13, value: 13}:  {},
+	})
+
+	m2.insert(t, 7, 7)
+	validateRef(t, m2)
+
+	m2.destroy()
+
+	assertSameMap(t, seenEntries, deletedEntries)
+}
+
+func TestRandomMap(t *testing.T) {
+	deletedEntries := make(map[mapEntry]struct{})
+	seenEntries := make(map[mapEntry]struct{})
+
+	m := &validatedMap{
+		impl: NewMap(func(a, b interface{}) bool {
+			return a.(int) < b.(int)
+		}),
+		expected: make(map[int]int),
+		deleted:  deletedEntries,
+		seen:     seenEntries,
+	}
+
+	keys := make([]int, 0, 1000)
+	for i := 0; i < 1000; i++ {
+		key := rand.Int()
+		m.insert(t, key, key)
+		keys = append(keys, key)
+
+		if i%10 == 1 {
+			index := rand.Intn(len(keys))
+			last := len(keys) - 1
+			key = keys[index]
+			keys[index], keys[last] = keys[last], keys[index]
+			keys = keys[:last]
+
+			m.remove(t, key)
+		}
+	}
+
+	m.destroy()
+	assertSameMap(t, seenEntries, deletedEntries)
+}
+
+func (vm *validatedMap) onDelete(t *testing.T, key, value int) {
+	entry := mapEntry{key: key, value: value}
+	if _, ok := vm.deleted[entry]; ok {
+		t.Fatalf("tried to delete entry twice, key: %d, value: %d", key, value)
+	}
+	vm.deleted[entry] = struct{}{}
+}
+
+func validateRef(t *testing.T, maps ...*validatedMap) {
+	t.Helper()
+
+	actualCountByEntry := make(map[mapEntry]int32)
+	nodesByEntry := make(map[mapEntry]map[*mapNode]struct{})
+	expectedCountByEntry := make(map[mapEntry]int32)
+	for i, m := range maps {
+		dfsRef(m.impl.root, actualCountByEntry, nodesByEntry)
+		dumpMap(t, fmt.Sprintf("%d:", i), m.impl.root)
+	}
+	for entry, nodes := range nodesByEntry {
+		expectedCountByEntry[entry] = int32(len(nodes))
+	}
+	assertSameMap(t, expectedCountByEntry, actualCountByEntry)
+}
+
+func dfsRef(node *mapNode, countByEntry map[mapEntry]int32, nodesByEntry map[mapEntry]map[*mapNode]struct{}) {
+	if node == nil {
+		return
+	}
+
+	entry := mapEntry{key: node.key.(int), value: node.value.value.(int)}
+	countByEntry[entry] = atomic.LoadInt32(&node.value.refCount)
+
+	nodes, ok := nodesByEntry[entry]
+	if !ok {
+		nodes = make(map[*mapNode]struct{})
+		nodesByEntry[entry] = nodes
+	}
+	nodes[node] = struct{}{}
+
+	dfsRef(node.left, countByEntry, nodesByEntry)
+	dfsRef(node.right, countByEntry, nodesByEntry)
+}
+
+func dumpMap(t *testing.T, prefix string, n *mapNode) {
+	if n == nil {
+		t.Logf("%s nil", prefix)
+		return
+	}
+	t.Logf("%s {key: %v, value: %v (ref: %v), ref: %v, weight: %v}", prefix, n.key, n.value.value, n.value.refCount, n.refCount, n.weight)
+	dumpMap(t, prefix+"l", n.left)
+	dumpMap(t, prefix+"r", n.right)
+}
+
+func (vm *validatedMap) validate(t *testing.T) {
+	t.Helper()
+
+	validateNode(t, vm.impl.root, vm.impl.less)
+
+	for key, value := range vm.expected {
+		entry := mapEntry{key: key, value: value}
+		if _, ok := vm.deleted[entry]; ok {
+			t.Fatalf("entry is deleted prematurely, key: %d, value: %d", key, value)
+		}
+	}
+
+	actualMap := make(map[int]int, len(vm.expected))
+	vm.impl.Range(func(key, value interface{}) {
+		if other, ok := actualMap[key.(int)]; ok {
+			t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, other, value)
+		}
+		actualMap[key.(int)] = value.(int)
+	})
+
+	assertSameMap(t, actualMap, vm.expected)
+}
+
+func validateNode(t *testing.T, node *mapNode, less func(a, b interface{}) bool) {
+	if node == nil {
+		return
+	}
+
+	if node.left != nil {
+		if less(node.key, node.left.key) {
+			t.Fatalf("left child has larger key: %v vs %v", node.left.key, node.key)
+		}
+		if node.left.weight > node.weight {
+			t.Fatalf("left child has larger weight: %v vs %v", node.left.weight, node.weight)
+		}
+	}
+
+	if node.right != nil {
+		if less(node.right.key, node.key) {
+			t.Fatalf("right child has smaller key: %v vs %v", node.right.key, node.key)
+		}
+		if node.right.weight > node.weight {
+			t.Fatalf("right child has larger weight: %v vs %v", node.right.weight, node.weight)
+		}
+	}
+
+	validateNode(t, node.left, less)
+	validateNode(t, node.right, less)
+}
+
+func (vm *validatedMap) insert(t *testing.T, key, value int) {
+	vm.seen[mapEntry{key: key, value: value}] = struct{}{}
+	vm.impl.Store(key, value, func(deletedKey, deletedValue interface{}) {
+		if deletedKey != key || deletedValue != value {
+			t.Fatalf("unexpected passed in deleted entry: %v/%v, expected: %v/%v", deletedKey, deletedValue, key, value)
+		}
+		vm.onDelete(t, key, value)
+	})
+	vm.expected[key] = value
+	vm.validate(t)
+
+	loadValue, ok := vm.impl.Load(key)
+	if !ok || loadValue != value {
+		t.Fatalf("unexpected load result after insertion, key: %v, expected: %v, got: %v (%v)", key, value, loadValue, ok)
+	}
+}
+
+func (vm *validatedMap) remove(t *testing.T, key int) {
+	vm.impl.Delete(key)
+	delete(vm.expected, key)
+	vm.validate(t)
+
+	loadValue, ok := vm.impl.Load(key)
+	if ok {
+		t.Fatalf("unexpected load result after removal, key: %v, got: %v", key, loadValue)
+	}
+}
+
+func (vm *validatedMap) clone() *validatedMap {
+	expected := make(map[int]int, len(vm.expected))
+	for key, value := range vm.expected {
+		expected[key] = value
+	}
+
+	return &validatedMap{
+		impl:     vm.impl.Clone(),
+		expected: expected,
+		deleted:  vm.deleted,
+		seen:     vm.seen,
+	}
+}
+
+func (vm *validatedMap) destroy() {
+	vm.impl.Destroy()
+}
+
+func assertSameMap(t *testing.T, map1, map2 interface{}) {
+	t.Helper()
+
+	if !reflect.DeepEqual(map1, map2) {
+		t.Fatalf("different maps:\n%v\nvs\n%v", map1, map2)
+	}
+}
+
+func isSameMap(map1, map2 reflect.Value) bool {
+	if map1.Len() != map2.Len() {
+		return false
+	}
+	iter := map1.MapRange()
+	for iter.Next() {
+		key := iter.Key()
+		value1 := iter.Value()
+		value2 := map2.MapIndex(key)
+		if value2.IsZero() || !reflect.DeepEqual(value1.Interface(), value2.Interface()) {
+			return false
+		}
+	}
+	return true
+}