gopls/internal/util/lru: make lru.Cache generic

The lru.Cache type was written before we could use generics in gopls.
Make it generic, to benefit from a bit more ergonomic APIs.

Change-Id: I8475613580156c644b170eaa473f927f8bd37e67
Reviewed-on: https://go-review.googlesource.com/c/tools/+/608795
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/gopls/internal/filecache/filecache.go b/gopls/internal/filecache/filecache.go
index 31a76ef..243e954 100644
--- a/gopls/internal/filecache/filecache.go
+++ b/gopls/internal/filecache/filecache.go
@@ -53,7 +53,7 @@
 // As an optimization, use a 100MB in-memory LRU cache in front of filecache
 // operations. This reduces I/O for operations such as diagnostics or
 // implementations that repeatedly access the same cache entries.
-var memCache = lru.New(100 * 1e6)
+var memCache = lru.New[memKey, []byte](100 * 1e6)
 
 type memKey struct {
 	kind string
@@ -69,8 +69,8 @@
 	// First consult the read-through memory cache.
 	// Note that memory cache hits do not update the times
 	// used for LRU eviction of the file-based cache.
-	if value := memCache.Get(memKey{kind, key}); value != nil {
-		return value.([]byte), nil
+	if value, ok := memCache.Get(memKey{kind, key}); ok {
+		return value, nil
 	}
 
 	iolimit <- struct{}{}        // acquire a token
diff --git a/gopls/internal/util/lru/lru.go b/gopls/internal/util/lru/lru.go
index b75fc85..4ed8eaf 100644
--- a/gopls/internal/util/lru/lru.go
+++ b/gopls/internal/util/lru/lru.go
@@ -11,8 +11,61 @@
 	"sync"
 )
 
-// A Cache is a fixed-size in-memory LRU cache.
-type Cache struct {
+// A Cache is a fixed-size in-memory LRU cache, storing values of type V keyed
+// by keys of type K.
+type Cache[K comparable, V any] struct {
+	impl *cache
+}
+
+// Get retrieves the value for the specified key.
+// If the key is found, its access time is updated.
+//
+// The second result reports whether the key was found.
+func (c *Cache[K, V]) Get(key K) (V, bool) {
+	v, ok := c.impl.get(key)
+	if !ok {
+		var zero V
+		return zero, false
+	}
+	// Handle untyped nil explicitly to avoid a panic in the type assertion
+	// below.
+	if v == nil {
+		var zero V
+		return zero, true
+	}
+	return v.(V), true
+}
+
+// Set stores a value for the specified key, using its given size to update the
+// current cache size, evicting old entries as necessary to fit in the cache
+// capacity.
+//
+// Size must be a non-negative value. If size is larger than the cache
+// capacity, the value is not stored and the cache is not modified.
+func (c *Cache[K, V]) Set(key K, value V, size int) {
+	c.impl.set(key, value, size)
+}
+
+// New creates a new Cache with the given capacity, which must be positive.
+//
+// The cache capacity uses arbitrary units, which are specified during the Set
+// operation.
+func New[K comparable, V any](capacity int) *Cache[K, V] {
+	if capacity == 0 {
+		panic("zero capacity")
+	}
+
+	return &Cache[K, V]{&cache{
+		capacity: capacity,
+		m:        make(map[any]*entry),
+	}}
+}
+
+// cache is the non-generic implementation of [Cache].
+//
+// (Using a generic wrapper around a non-generic impl avoids unnecessary
+// "stenciling" or code duplication.)
+type cache struct {
 	capacity int
 
 	mu    sync.Mutex
@@ -30,26 +83,7 @@
 	index int   // index of entry in the heap slice
 }
 
-// New creates a new Cache with the given capacity, which must be positive.
-//
-// The cache capacity uses arbitrary units, which are specified during the Set
-// operation.
-func New(capacity int) *Cache {
-	if capacity == 0 {
-		panic("zero capacity")
-	}
-
-	return &Cache{
-		capacity: capacity,
-		m:        make(map[any]*entry),
-	}
-}
-
-// Get retrieves the value for the specified key, or nil if the key is not
-// found.
-//
-// If the key is found, its access time is updated.
-func (c *Cache) Get(key any) any {
+func (c *cache) get(key any) (any, bool) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
@@ -58,19 +92,13 @@
 	if e, ok := c.m[key]; ok { // cache hit
 		e.atime = c.clock
 		heap.Fix(&c.lru, e.index)
-		return e.value
+		return e.value, true
 	}
 
-	return nil
+	return nil, false
 }
 
-// Set stores a value for the specified key, using its given size to update the
-// current cache size, evicting old entries as necessary to fit in the cache
-// capacity.
-//
-// Size must be a non-negative value. If size is larger than the cache
-// capacity, the value is not stored and the cache is not modified.
-func (c *Cache) Set(key, value any, size int) {
+func (c *cache) set(key, value any, size int) {
 	if size < 0 {
 		panic(fmt.Sprintf("size must be non-negative, got %d", size))
 	}
diff --git a/gopls/internal/util/lru/lru_fuzz_test.go b/gopls/internal/util/lru/lru_fuzz_test.go
index b82776b..2f5f43c 100644
--- a/gopls/internal/util/lru/lru_fuzz_test.go
+++ b/gopls/internal/util/lru/lru_fuzz_test.go
@@ -22,14 +22,14 @@
 			ops = append(ops, op{data[0]%2 == 0, data[1], data[2]})
 			data = data[3:]
 		}
-		cache := lru.New(100)
+		cache := lru.New[byte, byte](100)
 		var reference [256]byte
 		for _, op := range ops {
 			if op.set {
 				reference[op.key] = op.value
 				cache.Set(op.key, op.value, 1)
 			} else {
-				if v := cache.Get(op.key); v != nil && v != reference[op.key] {
+				if v, ok := cache.Get(op.key); ok && v != reference[op.key] {
 					t.Fatalf("cache.Get(%d) = %d, want %d", op.key, v, reference[op.key])
 				}
 			}
diff --git a/gopls/internal/util/lru/lru_nil_test.go b/gopls/internal/util/lru/lru_nil_test.go
new file mode 100644
index 0000000..08ce910
--- /dev/null
+++ b/gopls/internal/util/lru/lru_nil_test.go
@@ -0,0 +1,25 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lru_test
+
+// TODO(rfindley): uncomment once -lang is at least go1.20.
+// Prior to that language version, interfaces did not satisfy comparable.
+// Note that we can't simply use //go:build go1.20, because we need at least Go
+// 1.21 in the go.mod file for file language versions support!
+/*
+import (
+	"testing"
+
+	"golang.org/x/tools/gopls/internal/util/lru"
+)
+
+func TestSetUntypedNil(t *testing.T) {
+	cache := lru.New[any, any](100 * 1e6)
+	cache.Set(nil, nil, 1)
+	if got, ok := cache.Get(nil); !ok || got != nil {
+		t.Errorf("cache.Get(nil) = %v, %v, want nil, true", got, ok)
+	}
+}
+*/
diff --git a/gopls/internal/util/lru/lru_test.go b/gopls/internal/util/lru/lru_test.go
index 9ffe346..bf96e8d 100644
--- a/gopls/internal/util/lru/lru_test.go
+++ b/gopls/internal/util/lru/lru_test.go
@@ -20,7 +20,7 @@
 func TestCache(t *testing.T) {
 	type get struct {
 		key  string
-		want any
+		want string
 	}
 	type set struct {
 		key, value string
@@ -31,8 +31,8 @@
 		steps []any
 	}{
 		{"empty cache", []any{
-			get{"a", nil},
-			get{"b", nil},
+			get{"a", ""},
+			get{"b", ""},
 		}},
 		{"zero-length string", []any{
 			set{"a", ""},
@@ -48,7 +48,7 @@
 			set{"a", "123"},
 			set{"b", "456"},
 			set{"c", "78901"},
-			get{"a", nil},
+			get{"a", ""},
 			get{"b", "456"},
 			get{"c", "78901"},
 		}},
@@ -58,18 +58,18 @@
 			get{"a", "123"},
 			set{"c", "78901"},
 			get{"a", "123"},
-			get{"b", nil},
+			get{"b", ""},
 			get{"c", "78901"},
 		}},
 	}
 
 	for _, test := range tests {
 		t.Run(test.label, func(t *testing.T) {
-			c := lru.New(10)
+			c := lru.New[string, string](10)
 			for i, step := range test.steps {
 				switch step := step.(type) {
 				case get:
-					if got := c.Get(step.key); got != step.want {
+					if got, _ := c.Get(step.key); got != step.want {
 						t.Errorf("#%d: c.Get(%q) = %q, want %q", i, step.key, got, step.want)
 					}
 				case set:
@@ -96,21 +96,20 @@
 		}
 	}
 
-	cache := lru.New(100 * 1e6) // 100MB cache
+	cache := lru.New[[32]byte, []byte](100 * 1e6) // 100MB cache
 
 	// get calls Get and verifies that the cache entry
 	// matches one of the values passed to Set.
 	get := func(mustBeFound bool) error {
-		got := cache.Get(key)
-		if got == nil {
+		got, ok := cache.Get(key)
+		if !ok {
 			if !mustBeFound {
 				return nil
 			}
 			return fmt.Errorf("Get did not return a value")
 		}
-		gotBytes := got.([]byte)
 		for _, want := range values {
-			if bytes.Equal(want[:], gotBytes) {
+			if bytes.Equal(want[:], got) {
 				return nil // a match
 			}
 		}