internal/persistent: use generics

Now that we're on 1.18+, make internal/persistent.Map generic.

Change-Id: I3403241fe22e28f969d7feb09a752b52f0d2ee4d
Reviewed-on: https://go-review.googlesource.com/c/tools/+/524759
gopls-CI: kokoro <noreply+kokoro@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/gopls/internal/lsp/cache/check.go b/gopls/internal/lsp/cache/check.go
index 74404af..b7267e9 100644
--- a/gopls/internal/lsp/cache/check.go
+++ b/gopls/internal/lsp/cache/check.go
@@ -849,7 +849,7 @@
 				unfinishedSuccs: int32(len(m.DepsByPkgPath)),
 			}
 			if entry, hit := b.s.packages.Get(m.ID); hit {
-				n.ph = entry.(*packageHandle)
+				n.ph = entry
 			}
 			if n.unfinishedSuccs == 0 {
 				leaves = append(leaves, n)
@@ -1118,12 +1118,11 @@
 	}
 
 	// Check the packages map again in case another goroutine got there first.
-	if alt, ok := b.s.packages.Get(n.m.ID); ok && alt.(*packageHandle).validated {
-		altPH := alt.(*packageHandle)
-		if altPH.m != n.m {
+	if alt, ok := b.s.packages.Get(n.m.ID); ok && alt.validated {
+		if alt.m != n.m {
 			bug.Reportf("existing package handle does not match for %s", n.m.ID)
 		}
-		n.ph = altPH
+		n.ph = alt
 	} else {
 		b.s.packages.Set(n.m.ID, n.ph, nil)
 	}
diff --git a/gopls/internal/lsp/cache/load.go b/gopls/internal/lsp/cache/load.go
index 05d4432..03db2a3 100644
--- a/gopls/internal/lsp/cache/load.go
+++ b/gopls/internal/lsp/cache/load.go
@@ -217,8 +217,7 @@
 	s.mu.Lock()
 
 	// Assert the invariant s.packages.Get(id).m == s.meta.metadata[id].
-	s.packages.Range(func(k, v interface{}) {
-		id, ph := k.(PackageID), v.(*packageHandle)
+	s.packages.Range(func(id PackageID, ph *packageHandle) {
 		if s.meta.metadata[id] != ph.m {
 			panic("inconsistent metadata")
 		}
diff --git a/gopls/internal/lsp/cache/maps.go b/gopls/internal/lsp/cache/maps.go
index de6187d..3fa866c 100644
--- a/gopls/internal/lsp/cache/maps.go
+++ b/gopls/internal/lsp/cache/maps.go
@@ -10,21 +10,14 @@
 	"golang.org/x/tools/internal/persistent"
 )
 
-// TODO(euroelessar): Use generics once support for go1.17 is dropped.
-
 type filesMap struct {
-	impl       *persistent.Map
+	impl       *persistent.Map[span.URI, source.FileHandle]
 	overlayMap map[span.URI]*Overlay // the subset that are overlays
 }
 
-// uriLessInterface is the < relation for "any" values containing span.URIs.
-func uriLessInterface(a, b interface{}) bool {
-	return a.(span.URI) < b.(span.URI)
-}
-
 func newFilesMap() filesMap {
 	return filesMap{
-		impl:       persistent.NewMap(uriLessInterface),
+		impl:       new(persistent.Map[span.URI, source.FileHandle]),
 		overlayMap: make(map[span.URI]*Overlay),
 	}
 }
@@ -53,9 +46,7 @@
 }
 
 func (m filesMap) Range(do func(key span.URI, value source.FileHandle)) {
-	m.impl.Range(func(key, value interface{}) {
-		do(key.(span.URI), value.(source.FileHandle))
-	})
+	m.impl.Range(do)
 }
 
 func (m filesMap) Set(key span.URI, value source.FileHandle) {
@@ -86,19 +77,13 @@
 	return overlays
 }
 
-func packageIDLessInterface(x, y interface{}) bool {
-	return x.(PackageID) < y.(PackageID)
-}
-
 type knownDirsSet struct {
-	impl *persistent.Map
+	impl *persistent.Map[span.URI, struct{}]
 }
 
 func newKnownDirsSet() knownDirsSet {
 	return knownDirsSet{
-		impl: persistent.NewMap(func(a, b interface{}) bool {
-			return a.(span.URI) < b.(span.URI)
-		}),
+		impl: new(persistent.Map[span.URI, struct{}]),
 	}
 }
 
@@ -118,8 +103,8 @@
 }
 
 func (s knownDirsSet) Range(do func(key span.URI)) {
-	s.impl.Range(func(key, value interface{}) {
-		do(key.(span.URI))
+	s.impl.Range(func(key span.URI, value struct{}) {
+		do(key)
 	})
 }
 
@@ -128,7 +113,7 @@
 }
 
 func (s knownDirsSet) Insert(key span.URI) {
-	s.impl.Set(key, nil, nil)
+	s.impl.Set(key, struct{}{}, nil)
 }
 
 func (s knownDirsSet) Remove(key span.URI) {
diff --git a/gopls/internal/lsp/cache/mod.go b/gopls/internal/lsp/cache/mod.go
index db0ab0a..8a452ab 100644
--- a/gopls/internal/lsp/cache/mod.go
+++ b/gopls/internal/lsp/cache/mod.go
@@ -52,7 +52,7 @@
 	}
 
 	// Await result.
-	v, err := s.awaitPromise(ctx, entry.(*memoize.Promise))
+	v, err := s.awaitPromise(ctx, entry)
 	if err != nil {
 		return nil, err
 	}
@@ -130,7 +130,7 @@
 	}
 
 	// Await result.
-	v, err := s.awaitPromise(ctx, entry.(*memoize.Promise))
+	v, err := s.awaitPromise(ctx, entry)
 	if err != nil {
 		return nil, err
 	}
@@ -240,7 +240,7 @@
 	}
 
 	// Await result.
-	v, err := s.awaitPromise(ctx, entry.(*memoize.Promise))
+	v, err := s.awaitPromise(ctx, entry)
 	if err != nil {
 		return nil, err
 	}
diff --git a/gopls/internal/lsp/cache/mod_tidy.go b/gopls/internal/lsp/cache/mod_tidy.go
index 64e02d1..b806edb 100644
--- a/gopls/internal/lsp/cache/mod_tidy.go
+++ b/gopls/internal/lsp/cache/mod_tidy.go
@@ -85,7 +85,7 @@
 	}
 
 	// Await result.
-	v, err := s.awaitPromise(ctx, entry.(*memoize.Promise))
+	v, err := s.awaitPromise(ctx, entry)
 	if err != nil {
 		return nil, err
 	}
diff --git a/gopls/internal/lsp/cache/mod_vuln.go b/gopls/internal/lsp/cache/mod_vuln.go
index 942ca52..dcd58bf 100644
--- a/gopls/internal/lsp/cache/mod_vuln.go
+++ b/gopls/internal/lsp/cache/mod_vuln.go
@@ -55,7 +55,7 @@
 	}
 
 	// Await result.
-	v, err := s.awaitPromise(ctx, entry.(*memoize.Promise))
+	v, err := s.awaitPromise(ctx, entry)
 	if err != nil {
 		return nil, err
 	}
diff --git a/gopls/internal/lsp/cache/session.go b/gopls/internal/lsp/cache/session.go
index 6b75f10..cd51e6d 100644
--- a/gopls/internal/lsp/cache/session.go
+++ b/gopls/internal/lsp/cache/session.go
@@ -20,6 +20,7 @@
 	"golang.org/x/tools/internal/event"
 	"golang.org/x/tools/internal/gocommand"
 	"golang.org/x/tools/internal/imports"
+	"golang.org/x/tools/internal/memoize"
 	"golang.org/x/tools/internal/persistent"
 	"golang.org/x/tools/internal/xcontext"
 )
@@ -169,18 +170,18 @@
 		backgroundCtx:        backgroundCtx,
 		cancel:               cancel,
 		store:                s.cache.store,
-		packages:             persistent.NewMap(packageIDLessInterface),
+		packages:             new(persistent.Map[PackageID, *packageHandle]),
 		meta:                 new(metadataGraph),
 		files:                newFilesMap(),
-		activePackages:       persistent.NewMap(packageIDLessInterface),
-		symbolizeHandles:     persistent.NewMap(uriLessInterface),
+		activePackages:       new(persistent.Map[PackageID, *Package]),
+		symbolizeHandles:     new(persistent.Map[span.URI, *memoize.Promise]),
 		workspacePackages:    make(map[PackageID]PackagePath),
 		unloadableFiles:      make(map[span.URI]struct{}),
-		parseModHandles:      persistent.NewMap(uriLessInterface),
-		parseWorkHandles:     persistent.NewMap(uriLessInterface),
-		modTidyHandles:       persistent.NewMap(uriLessInterface),
-		modVulnHandles:       persistent.NewMap(uriLessInterface),
-		modWhyHandles:        persistent.NewMap(uriLessInterface),
+		parseModHandles:      new(persistent.Map[span.URI, *memoize.Promise]),
+		parseWorkHandles:     new(persistent.Map[span.URI, *memoize.Promise]),
+		modTidyHandles:       new(persistent.Map[span.URI, *memoize.Promise]),
+		modVulnHandles:       new(persistent.Map[span.URI, *memoize.Promise]),
+		modWhyHandles:        new(persistent.Map[span.URI, *memoize.Promise]),
 		knownSubdirs:         newKnownDirsSet(),
 		workspaceModFiles:    wsModFiles,
 		workspaceModFilesErr: wsModFilesErr,
diff --git a/gopls/internal/lsp/cache/snapshot.go b/gopls/internal/lsp/cache/snapshot.go
index a1fe475..a914880 100644
--- a/gopls/internal/lsp/cache/snapshot.go
+++ b/gopls/internal/lsp/cache/snapshot.go
@@ -101,7 +101,7 @@
 
 	// symbolizeHandles maps each file URI to a handle for the future
 	// result of computing the symbols declared in that file.
-	symbolizeHandles *persistent.Map // from span.URI to *memoize.Promise[symbolizeResult]
+	symbolizeHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[symbolizeResult]
 
 	// packages maps a packageKey to a *packageHandle.
 	// It may be invalidated when a file's content changes.
@@ -110,13 +110,13 @@
 	//  - packages.Get(id).meta == meta.metadata[id] for all ids
 	//  - if a package is in packages, then all of its dependencies should also
 	//    be in packages, unless there is a missing import
-	packages *persistent.Map // from packageID to *packageHandle
+	packages *persistent.Map[PackageID, *packageHandle]
 
 	// activePackages maps a package ID to a memoized active package, or nil if
 	// the package is known not to be open.
 	//
 	// IDs not contained in the map are not known to be open or not open.
-	activePackages *persistent.Map // from packageID to *Package
+	activePackages *persistent.Map[PackageID, *Package]
 
 	// workspacePackages contains the workspace's packages, which are loaded
 	// when the view is created. It contains no intermediate test variants.
@@ -137,18 +137,18 @@
 
 	// parseModHandles keeps track of any parseModHandles for the snapshot.
 	// The handles need not refer to only the view's go.mod file.
-	parseModHandles *persistent.Map // from span.URI to *memoize.Promise[parseModResult]
+	parseModHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[parseModResult]
 
 	// parseWorkHandles keeps track of any parseWorkHandles for the snapshot.
 	// The handles need not refer to only the view's go.work file.
-	parseWorkHandles *persistent.Map // from span.URI to *memoize.Promise[parseWorkResult]
+	parseWorkHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[parseWorkResult]
 
 	// Preserve go.mod-related handles to avoid garbage-collecting the results
 	// of various calls to the go command. The handles need not refer to only
 	// the view's go.mod file.
-	modTidyHandles *persistent.Map // from span.URI to *memoize.Promise[modTidyResult]
-	modWhyHandles  *persistent.Map // from span.URI to *memoize.Promise[modWhyResult]
-	modVulnHandles *persistent.Map // from span.URI to *memoize.Promise[modVulnResult]
+	modTidyHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[modTidyResult]
+	modWhyHandles  *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[modWhyResult]
+	modVulnHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[modVulnResult]
 
 	// knownSubdirs is the set of subdirectory URIs in the workspace,
 	// used to create glob patterns for file watching.
@@ -871,7 +871,7 @@
 	defer s.mu.Unlock()
 
 	if value, ok := s.activePackages.Get(id); ok {
-		return value.(*Package) // possibly nil, if we have already checked this id.
+		return value
 	}
 	return nil
 }
@@ -895,7 +895,7 @@
 
 func (s *snapshot) resetActivePackagesLocked() {
 	s.activePackages.Destroy()
-	s.activePackages = persistent.NewMap(packageIDLessInterface)
+	s.activePackages = new(persistent.Map[PackageID, *Package])
 }
 
 const fileExtensions = "go,mod,sum,work"
@@ -2189,7 +2189,7 @@
 			result.packages.Delete(id)
 		} else {
 			if entry, hit := result.packages.Get(id); hit {
-				ph := entry.(*packageHandle).clone(false)
+				ph := entry.clone(false)
 				result.packages.Set(id, ph, nil)
 			}
 		}
@@ -2291,12 +2291,11 @@
 // changed that happens not to be present in the map, but that's OK: the goal
 // of this function is to guarantee that IF the nearest mod file is present in
 // the map, it is invalidated.
-func deleteMostRelevantModFile(m *persistent.Map, changed span.URI) {
+func deleteMostRelevantModFile(m *persistent.Map[span.URI, *memoize.Promise], changed span.URI) {
 	var mostRelevant span.URI
 	changedFile := changed.Filename()
 
-	m.Range(func(key, value interface{}) {
-		modURI := key.(span.URI)
+	m.Range(func(modURI span.URI, _ *memoize.Promise) {
 		if len(modURI) > len(mostRelevant) {
 			if source.InDir(filepath.Dir(modURI.Filename()), changedFile) {
 				mostRelevant = modURI
diff --git a/gopls/internal/lsp/cache/symbols.go b/gopls/internal/lsp/cache/symbols.go
index 466d9dc..3ecd794 100644
--- a/gopls/internal/lsp/cache/symbols.go
+++ b/gopls/internal/lsp/cache/symbols.go
@@ -15,7 +15,6 @@
 	"golang.org/x/tools/gopls/internal/lsp/protocol"
 	"golang.org/x/tools/gopls/internal/lsp/source"
 	"golang.org/x/tools/gopls/internal/span"
-	"golang.org/x/tools/internal/memoize"
 )
 
 // symbolize returns the result of symbolizing the file identified by uri, using a cache.
@@ -51,7 +50,7 @@
 	}
 
 	// Await result.
-	v, err := s.awaitPromise(ctx, entry.(*memoize.Promise))
+	v, err := s.awaitPromise(ctx, entry)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/constraints/constraint.go b/internal/constraints/constraint.go
new file mode 100644
index 0000000..4e6ab61
--- /dev/null
+++ b/internal/constraints/constraint.go
@@ -0,0 +1,52 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package constraints defines a set of useful constraints to be used
+// with type parameters.
+package constraints
+
+// Copied from x/exp/constraints.
+
+// Signed is a constraint that permits any signed integer type.
+// If future releases of Go add new predeclared signed integer types,
+// this constraint will be modified to include them.
+type Signed interface {
+	~int | ~int8 | ~int16 | ~int32 | ~int64
+}
+
+// Unsigned is a constraint that permits any unsigned integer type.
+// If future releases of Go add new predeclared unsigned integer types,
+// this constraint will be modified to include them.
+type Unsigned interface {
+	~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
+}
+
+// Integer is a constraint that permits any integer type.
+// If future releases of Go add new predeclared integer types,
+// this constraint will be modified to include them.
+type Integer interface {
+	Signed | Unsigned
+}
+
+// Float is a constraint that permits any floating-point type.
+// If future releases of Go add new predeclared floating-point types,
+// this constraint will be modified to include them.
+type Float interface {
+	~float32 | ~float64
+}
+
+// Complex is a constraint that permits any complex numeric type.
+// If future releases of Go add new predeclared complex numeric types,
+// this constraint will be modified to include them.
+type Complex interface {
+	~complex64 | ~complex128
+}
+
+// Ordered is a constraint that permits any ordered type: any type
+// that supports the operators < <= >= >.
+// If future releases of Go add new ordered types,
+// this constraint will be modified to include them.
+type Ordered interface {
+	Integer | Float | ~string
+}
diff --git a/internal/persistent/map.go b/internal/persistent/map.go
index a9d878f..02389f8 100644
--- a/internal/persistent/map.go
+++ b/internal/persistent/map.go
@@ -12,6 +12,8 @@
 	"math/rand"
 	"strings"
 	"sync/atomic"
+
+	"golang.org/x/tools/internal/constraints"
 )
 
 // Implementation details:
@@ -25,9 +27,7 @@
 // Each argument is followed by a delta change to its reference counter.
 // In case if no change is expected, the delta will be `-0`.
 
-// Map is an associative mapping from keys to values, both represented as
-// interface{}. Key comparison and iteration order is defined by a
-// client-provided function that implements a strict weak order.
+// Map is an associative mapping from keys to values.
 //
 // Maps can be Cloned in constant time.
 // Get, Store, and Delete operations are done on average in logarithmic time.
@@ -38,16 +38,23 @@
 //
 // Internally the implementation is based on a randomized persistent treap:
 // https://en.wikipedia.org/wiki/Treap.
-type Map struct {
-	less func(a, b interface{}) bool
+//
+// The zero value is ready to use.
+type Map[K constraints.Ordered, V any] struct {
+	// Map is a generic wrapper around a non-generic implementation to avoid a
+	// significant increase in the size of the executable.
 	root *mapNode
 }
 
-func (m *Map) String() string {
+func (*Map[K, V]) less(l, r any) bool {
+	return l.(K) < r.(K)
+}
+
+func (m *Map[K, V]) String() string {
 	var buf strings.Builder
 	buf.WriteByte('{')
 	var sep string
-	m.Range(func(k, v interface{}) {
+	m.Range(func(k K, v V) {
 		fmt.Fprintf(&buf, "%s%v: %v", sep, k, v)
 		sep = ", "
 	})
@@ -56,7 +63,7 @@
 }
 
 type mapNode struct {
-	key         interface{}
+	key         any
 	value       *refValue
 	weight      uint64
 	refCount    int32
@@ -65,11 +72,11 @@
 
 type refValue struct {
 	refCount int32
-	value    interface{}
-	release  func(key, value interface{})
+	value    any
+	release  func(key, value any)
 }
 
-func newNodeWithRef(key, value interface{}, release func(key, value interface{})) *mapNode {
+func newNodeWithRef[K constraints.Ordered, V any](key K, value V, release func(key, value any)) *mapNode {
 	return &mapNode{
 		key: key,
 		value: &refValue{
@@ -116,20 +123,10 @@
 	}
 }
 
-// NewMap returns a new map whose keys are ordered by the given comparison
-// function (a strict weak order). It is the responsibility of the caller to
-// Destroy it at later time.
-func NewMap(less func(a, b interface{}) bool) *Map {
-	return &Map{
-		less: less,
-	}
-}
-
 // Clone returns a copy of the given map. It is a responsibility of the caller
 // to Destroy it at later time.
-func (pm *Map) Clone() *Map {
-	return &Map{
-		less: pm.less,
+func (pm *Map[K, V]) Clone() *Map[K, V] {
+	return &Map[K, V]{
 		root: pm.root.incref(),
 	}
 }
@@ -137,24 +134,26 @@
 // Destroy destroys the map.
 //
 // After Destroy, the Map should not be used again.
-func (pm *Map) Destroy() {
+func (pm *Map[K, V]) Destroy() {
 	// The implementation of these two functions is the same,
 	// but their intent is different.
 	pm.Clear()
 }
 
 // Clear removes all entries from the map.
-func (pm *Map) Clear() {
+func (pm *Map[K, V]) Clear() {
 	pm.root.decref()
 	pm.root = nil
 }
 
 // Range calls f sequentially in ascending key order for all entries in the map.
-func (pm *Map) Range(f func(key, value interface{})) {
-	pm.root.forEach(f)
+func (pm *Map[K, V]) Range(f func(key K, value V)) {
+	pm.root.forEach(func(k, v any) {
+		f(k.(K), v.(V))
+	})
 }
 
-func (node *mapNode) forEach(f func(key, value interface{})) {
+func (node *mapNode) forEach(f func(key, value any)) {
 	if node == nil {
 		return
 	}
@@ -163,26 +162,26 @@
 	node.right.forEach(f)
 }
 
-// Get returns the map value associated with the specified key, or nil if no entry
-// is present. The ok result indicates whether an entry was found in the map.
-func (pm *Map) Get(key interface{}) (interface{}, bool) {
+// Get returns the map value associated with the specified key.
+// The ok result indicates whether an entry was found in the map.
+func (pm *Map[K, V]) Get(key K) (V, bool) {
 	node := pm.root
 	for node != nil {
-		if pm.less(key, node.key) {
+		if key < node.key.(K) {
 			node = node.left
-		} else if pm.less(node.key, key) {
+		} else if node.key.(K) < key {
 			node = node.right
 		} else {
-			return node.value.value, true
+			return node.value.value.(V), true
 		}
 	}
-	return nil, false
+	var zero V
+	return zero, false
 }
 
 // SetAll updates the map with key/value pairs from the other map, overwriting existing keys.
 // It is equivalent to calling Set for each entry in the other map but is more efficient.
-// Both maps must have the same comparison function, otherwise behavior is undefined.
-func (pm *Map) SetAll(other *Map) {
+func (pm *Map[K, V]) SetAll(other *Map[K, V]) {
 	root := pm.root
 	pm.root = union(root, other.root, pm.less, true)
 	root.decref()
@@ -191,7 +190,7 @@
 // Set updates the value associated with the specified key.
 // If release is non-nil, it will be called with entry's key and value once the
 // key is no longer contained in the map or any clone.
-func (pm *Map) Set(key, value interface{}, release func(key, value interface{})) {
+func (pm *Map[K, V]) Set(key K, value V, release func(key, value any)) {
 	first := pm.root
 	second := newNodeWithRef(key, value, release)
 	pm.root = union(first, second, pm.less, true)
@@ -205,7 +204,7 @@
 // union(first:-0, second:-0) (result:+1)
 // Union borrows both subtrees without affecting their refcount and returns a
 // new reference that the caller is expected to call decref.
-func union(first, second *mapNode, less func(a, b interface{}) bool, overwrite bool) *mapNode {
+func union(first, second *mapNode, less func(any, any) bool, overwrite bool) *mapNode {
 	if first == nil {
 		return second.incref()
 	}
@@ -243,7 +242,7 @@
 // split(n:-0) (left:+1, mid:+1, right:+1)
 // Split borrows n without affecting its refcount, and returns three
 // new references that the caller is expected to call decref.
-func split(n *mapNode, key interface{}, less func(a, b interface{}) bool, requireMid bool) (left, mid, right *mapNode) {
+func split(n *mapNode, key any, less func(any, any) bool, requireMid bool) (left, mid, right *mapNode) {
 	if n == nil {
 		return nil, nil, nil
 	}
@@ -272,7 +271,7 @@
 }
 
 // Delete deletes the value for a key.
-func (pm *Map) Delete(key interface{}) {
+func (pm *Map[K, V]) Delete(key K) {
 	root := pm.root
 	left, mid, right := split(root, key, pm.less, true)
 	if mid == nil {
diff --git a/internal/persistent/map_test.go b/internal/persistent/map_test.go
index 9f89a1d..c73e566 100644
--- a/internal/persistent/map_test.go
+++ b/internal/persistent/map_test.go
@@ -18,7 +18,7 @@
 }
 
 type validatedMap struct {
-	impl     *Map
+	impl     *Map[int, int]
 	expected map[int]int      // current key-value mapping.
 	deleted  map[mapEntry]int // maps deleted entries to their clock time of last deletion
 	seen     map[mapEntry]int // maps seen entries to their clock time of last insertion
@@ -30,9 +30,7 @@
 	seenEntries := make(map[mapEntry]int)
 
 	m1 := &validatedMap{
-		impl: NewMap(func(a, b interface{}) bool {
-			return a.(int) < b.(int)
-		}),
+		impl:     new(Map[int, int]),
 		expected: make(map[int]int),
 		deleted:  deletedEntries,
 		seen:     seenEntries,
@@ -123,9 +121,7 @@
 	seenEntries := make(map[mapEntry]int)
 
 	m := &validatedMap{
-		impl: NewMap(func(a, b interface{}) bool {
-			return a.(int) < b.(int)
-		}),
+		impl:     new(Map[int, int]),
 		expected: make(map[int]int),
 		deleted:  deletedEntries,
 		seen:     seenEntries,
@@ -165,9 +161,7 @@
 	seenEntries := make(map[mapEntry]int)
 
 	m1 := &validatedMap{
-		impl: NewMap(func(a, b interface{}) bool {
-			return a.(int) < b.(int)
-		}),
+		impl:     new(Map[int, int]),
 		expected: make(map[int]int),
 		deleted:  deletedEntries,
 		seen:     seenEntries,
@@ -233,7 +227,7 @@
 func (vm *validatedMap) validate(t *testing.T) {
 	t.Helper()
 
-	validateNode(t, vm.impl.root, vm.impl.less)
+	validateNode(t, vm.impl.root)
 
 	// Note: this validation may not make sense if maps were constructed using
 	// SetAll operations. If this proves to be problematic, remove the clock,
@@ -246,23 +240,23 @@
 	}
 
 	actualMap := make(map[int]int, len(vm.expected))
-	vm.impl.Range(func(key, value interface{}) {
-		if other, ok := actualMap[key.(int)]; ok {
+	vm.impl.Range(func(key, value int) {
+		if other, ok := actualMap[key]; ok {
 			t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, value, other)
 		}
-		actualMap[key.(int)] = value.(int)
+		actualMap[key] = value
 	})
 
 	assertSameMap(t, actualMap, vm.expected)
 }
 
-func validateNode(t *testing.T, node *mapNode, less func(a, b interface{}) bool) {
+func validateNode(t *testing.T, node *mapNode) {
 	if node == nil {
 		return
 	}
 
 	if node.left != nil {
-		if less(node.key, node.left.key) {
+		if node.key.(int) < node.left.key.(int) {
 			t.Fatalf("left child has larger key: %v vs %v", node.left.key, node.key)
 		}
 		if node.left.weight > node.weight {
@@ -271,7 +265,7 @@
 	}
 
 	if node.right != nil {
-		if less(node.right.key, node.key) {
+		if node.right.key.(int) < node.key.(int) {
 			t.Fatalf("right child has smaller key: %v vs %v", node.right.key, node.key)
 		}
 		if node.right.weight > node.weight {
@@ -279,8 +273,8 @@
 		}
 	}
 
-	validateNode(t, node.left, less)
-	validateNode(t, node.right, less)
+	validateNode(t, node.left)
+	validateNode(t, node.right)
 }
 
 func (vm *validatedMap) setAll(t *testing.T, other *validatedMap) {
@@ -300,7 +294,7 @@
 	vm.clock++
 	vm.seen[entry] = vm.clock
 
-	vm.impl.Set(key, value, func(deletedKey, deletedValue interface{}) {
+	vm.impl.Set(key, value, func(deletedKey, deletedValue any) {
 		if deletedKey != key || deletedValue != value {
 			t.Fatalf("unexpected passed in deleted entry: %v/%v, expected: %v/%v", deletedKey, deletedValue, key, value)
 		}
@@ -346,7 +340,7 @@
 	vm.impl.Destroy()
 }
 
-func assertSameMap(t *testing.T, map1, map2 interface{}) {
+func assertSameMap(t *testing.T, map1, map2 any) {
 	t.Helper()
 
 	if !reflect.DeepEqual(map1, map2) {