internal/storage: storage abstractions

Package storage defines DB, a key-value database,
and VectorDB, a vector database.

Copied from https://github.com/rsc/gaby/commit/3f1bdd4.

Change-Id: I057ad7086c7cd53826c8022f674ea3943c9bcba5
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/597138
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
diff --git a/internal/storage/db.go b/internal/storage/db.go
new file mode 100644
index 0000000..980f5c5
--- /dev/null
+++ b/internal/storage/db.go
@@ -0,0 +1,188 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package storage defines the storage abstractions needed for Oscar:
+// [DB], a basic key-value store, and [VectorDB], a vector database.
+// The storage needs are intentionally minimal (avoiding, for example,
+// a requirement on SQL), to admit as many implementations as possible.
+package storage
+
+import (
+	"encoding/json"
+	"fmt"
+	"iter"
+	"log/slog"
+	"strconv"
+	"strings"
+
+	"rsc.io/ordered"
+)
+
+// A DB is a key-value database.
+//
+// DB operations are assumed not to fail.
+// They panic, intending to take down the program,
+// if there is an error accessing the database.
+// The assumption is that the program cannot possibly
+// continue without the database, since that's where all the state is stored.
+// Similarly, clients of DB conventionally panic
+// using [DB.Panic] if the database returns corrupted data.
+// Code that runs multiple database operations in parallel
+// can recover from these panics at its outermost calls.
+type DB interface {
+	// Lock acquires a lock on the given name, which need not exist in the database.
+	// After a successful Lock(name),
+	// any other call to Lock(name) from any other client of the database
+	// (including in another process, for shared databases)
+	// must block until Unlock(name) has been called.
+	// In a shared database, a lock may also unlock
+	// when the client disconnects or times out.
+	Lock(name string)
+
+	// Unlock releases the lock with the given name,
+	// which the caller must have locked.
+	Unlock(name string)
+
+	// Set sets the value associated with key to val,
+	// replacing any previous value for key.
+	Set(key, val []byte)
+
+	// Get looks up the value associated with key.
+	// If there is no entry for key in the database, Get returns nil, false.
+	// Otherwise it returns val, true.
+	Get(key []byte) (val []byte, ok bool)
+
+	// Scan returns an iterator over all key-value pairs with start ≤ key ≤ end.
+	// The second value in each iteration pair is a function returning the value,
+	// not the value itself:
+	//
+	//	for key, getVal := range db.Scan([]byte("aaa"), []byte("zzz")) {
+	//		val := getVal()
+	//		fmt.Printf("%q: %q\n", key, val)
+	//	}
+	//
+	// In iterations that only need the keys or only need the values for a subset of keys,
+	// some DB implementations may avoid work when the value function is not called.
+	// Callers may stop the iteration early by breaking out of the loop.
+	//
+	// A Scan may or may not observe concurrent modifications made
+	// using Set, Delete, and DeleteRange.
+	Scan(start, end []byte) iter.Seq2[[]byte, func() []byte]
+
+	// Delete deletes any value associated with key.
+	// Delete of an unset key is a no-op.
+	Delete(key []byte)
+
+	// DeleteRange deletes all key-value pairs with start ≤ key ≤ end.
+	DeleteRange(start, end []byte)
+
+	// Batch returns a new [Batch] that accumulates database mutations
+	// to apply in an atomic operation. In addition to the atomicity, using a
+	// Batch for bulk operations is more efficient than making each
+	// change using repeated calls to DB's Set, Delete, and DeleteRange methods.
+	Batch() Batch
+
+	// Flush flushes DB changes to permanent storage.
+	// Flush must be called before the process crashes or exits,
+	// or else any changes since the previous Flush may be lost.
+	Flush()
+
+	// Close flushes and then closes the database.
+	// Like the other routines, it panics if an error happens,
+	// so there is no error result.
+	Close()
+
+	// Panic logs the error message and args using the database's slog.Logger
+	// and then panics with the text formatting of its arguments.
+	// As with [slog.Logger.Error], args are alternating keys and values.
+	// It is meant to be called when database corruption or other
+	// database-related “can't happen” conditions have been detected.
+	Panic(msg string, args ...any)
+}
+
+// A Batch accumulates database mutations that are applied to a [DB]
+// as a single atomic operation. Applying bulk operations in a batch
+// is also more efficient than making individual [DB] method calls.
+// Mutations recorded in a Batch do not affect the DB until the batch
+// is applied, by Apply or by a MaybeApply that reports true.
+// The batched operations apply in the order they are made.
+// For example, Set("a", "b") followed by Delete("a") is the same as
+// Delete("a"), while Delete("a") followed by Set("a", "b") is the same
+// as Set("a", "b").
+type Batch interface {
+	// Delete deletes any value associated with key.
+	// Delete of an unset key is a no-op.
+	//
+	// Delete does not retain any reference to key after returning.
+	Delete(key []byte)
+
+	// DeleteRange deletes all key-value pairs with start ≤ key ≤ end.
+	//
+	// DeleteRange does not retain any reference to start or end after returning.
+	DeleteRange(start, end []byte)
+
+	// Set sets the value associated with key to val.
+	//
+	// Set does not retain any reference to key or val after returning.
+	Set(key, val []byte)
+
+	// MaybeApply calls Apply if the batch is getting close to full.
+	// Every Batch has a limit to how many operations can be batched,
+	// so in a bulk operation where atomicity of the entire batch is not a concern,
+	// calling MaybeApply gives the Batch implementation
+	// permission to flush the batch at specific “safe points”.
+	// A typical limit for a batch is about 100MB worth of logged operations.
+	// MaybeApply reports whether it called Apply.
+	MaybeApply() bool
+
+	// Apply applies all the batched operations to the underlying DB
+	// as a single atomic unit.
+	// When Apply returns, the Batch is an empty batch ready for
+	// more operations; previously applied operations are not replayed.
+	Apply()
+}
+
+// Panic panics with a text rendering of msg and args: the output a
+// [slog.TextHandler] would produce for an Error record, minus the
+// leading time and level fields. As with slog, args are alternating
+// keys and values (or [slog.Attr] values).
+// It is meant to be called for database errors or corruption,
+// which have been defined to be impossible.
+// (See the [DB] documentation.)
+//
+// Panic is expected to be used by DB implementations.
+// DB clients should use the [DB.Panic] method instead.
+func Panic(msg string, args ...any) {
+	// Everything up to and including this marker is the time and level
+	// prefix that TextHandler writes before the message.
+	const marker = " level=ERROR msg="
+
+	var buf strings.Builder
+	logger := slog.New(slog.NewTextHandler(&buf, nil))
+	logger.Error(msg, args...)
+
+	text := buf.String()
+	if i := strings.Index(text, marker); i >= 0 {
+		text = text[i+len(marker):]
+	}
+	// TrimSpace drops the newline that ends each logged record.
+	panic(strings.TrimSpace(text))
+}
+
+// JSON returns the JSON encoding of x, as produced by [json.Marshal].
+// A marshaling failure depends almost entirely on the type of x,
+// so it indicates a bug at the call site; JSON therefore panics
+// instead of returning an error.
+//
+// (The exception is certain malformed UTF-8 and floating-point
+// infinity and NaN. Code must be careful not to use JSON with those.)
+func JSON(x any) []byte {
+	data, err := json.Marshal(x)
+	if err == nil {
+		return data
+	}
+	panic(fmt.Sprintf("json.Marshal: %v", err))
+}
+
+// Fmt formats data for printing,
+// first trying [ordered.DecodeFmt] in case data is an [ordered encoding],
+// then trying a backquoted string if possible
+// (handling simple JSON data),
+// and finally resorting to [strconv.QuoteToASCII].
+//
+// [ordered encoding]: https://pkg.go.dev/rsc.io/ordered
+func Fmt(data []byte) string {
+	if s, err := ordered.DecodeFmt(data); err == nil {
+		return s
+	}
+	s := string(data)
+	// CanBackquote rejects strings containing backquotes and most control
+	// characters, which a raw string literal could not represent.
+	if strconv.CanBackquote(s) {
+		return "`" + s + "`"
+	}
+	return strconv.QuoteToASCII(s)
+}
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
new file mode 100644
index 0000000..d0df4df
--- /dev/null
+++ b/internal/storage/db_test.go
@@ -0,0 +1,67 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+	"math"
+	"testing"
+
+	"rsc.io/ordered"
+)
+
+// TestPanic checks that Panic panics with "msg key=val" for
+// msg "msg" and args "key", "val".
+func TestPanic(t *testing.T) {
+	func() {
+		defer func() {
+			// Use a checked type assertion: if Panic did not panic,
+			// t.Fatalf below exits via Goexit, recover returns nil,
+			// and an unchecked assertion would crash the test binary.
+			r := recover()
+			if s, ok := r.(string); !ok || s != "msg key=val" {
+				t.Errorf("panic value is not msg key=val:\n%v", r)
+			}
+		}()
+		Panic("msg", "key", "val")
+		t.Fatalf("did not panic")
+	}()
+}
+
+func TestJSON(t *testing.T) {
+	x := map[string]string{"a": "b"}
+	js := JSON(x)
+	want := `{"a":"b"}`
+	if string(js) != want {
+		t.Errorf("JSON(%v) = %#q, want %#q", x, js, want)
+	}
+
+	func() {
+		defer func() {
+			recover()
+		}()
+		JSON(math.NaN())
+		t.Errorf("JSON(NaN) did not panic")
+	}()
+}
+
+// fmtTests lists inputs to Fmt covering each of its three output forms:
+// an ordered encoding, a backquotable string, and a quoted string.
+var fmtTests = []struct {
+	data []byte
+	out  string
+}{
+	{ordered.Encode(1, 2, 3), "(1, 2, 3)"},
+	{[]byte(`"hello"`), "`\"hello\"`"},
+	{[]byte("`hello`"), "\"`hello`\""},
+}
+
+func TestFmt(t *testing.T) {
+	for i := range fmtTests {
+		tc := &fmtTests[i]
+		if got := Fmt(tc.data); got != tc.out {
+			t.Errorf("Fmt(%q) = %q, want %q", tc.data, got, tc.out)
+		}
+	}
+}
+
+func TestMemLocker(t *testing.T) {
+	m := new(MemLocker)
+
+	testDBLock(t, m)
+}
diff --git a/internal/storage/mem.go b/internal/storage/mem.go
new file mode 100644
index 0000000..ce9a8f4
--- /dev/null
+++ b/internal/storage/mem.go
@@ -0,0 +1,342 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+	"bytes"
+	"fmt"
+	"iter"
+	"log/slog"
+	"slices"
+	"sync"
+
+	"golang.org/x/oscar/internal/llm"
+	"rsc.io/omap"
+	"rsc.io/ordered"
+	"rsc.io/top"
+)
+
+// A MemLocker is a single-process implementation
+// of the database Lock and Unlock methods,
+// suitable if there is only one process accessing the
+// database at a time.
+//
+// Each name is backed by its own [sync.Mutex], created on the
+// first Lock of that name and kept for the life of the MemLocker.
+//
+// The zero value for a MemLocker
+// is a valid MemLocker with no locks held.
+// It must not be copied after first use.
+type MemLocker struct {
+	mu    sync.Mutex             // guards locks
+	locks map[string]*sync.Mutex // per-name mutexes, allocated lazily
+}
+
+// Lock locks the mutex with the given name,
+// blocking until it is available.
+func (l *MemLocker) Lock(name string) {
+	l.mu.Lock()
+	m, ok := l.locks[name]
+	if !ok {
+		if l.locks == nil {
+			l.locks = make(map[string]*sync.Mutex)
+		}
+		m = new(sync.Mutex)
+		l.locks[name] = m
+	}
+	l.mu.Unlock()
+
+	// Wait outside l.mu so that blocking on one name
+	// does not stop other names from being locked or unlocked.
+	m.Lock()
+}
+
+// Unlock unlocks the mutex with the given name.
+// It panics if name has never been locked.
+// As with [sync.Mutex], unlocking a name that was locked before
+// but is not locked now is an unrecoverable run-time error.
+func (l *MemLocker) Unlock(name string) {
+	l.mu.Lock()
+	m, ok := l.locks[name]
+	l.mu.Unlock()
+	if !ok {
+		panic("Unlock of never locked key")
+	}
+	m.Unlock()
+}
+
+// MemDB returns an in-memory DB implementation.
+// Its contents are never persisted: Flush and Close do nothing.
+func MemDB() DB {
+	return new(memDB)
+}
+
+// A memDB is an in-memory DB implementation.
+// The embedded MemLocker implements Lock and Unlock.
+// mu guards data, an ordered map from the string form of each key
+// to its value.
+type memDB struct {
+	MemLocker
+	mu   sync.RWMutex
+	data omap.Map[string, []byte]
+}
+
+// Close does nothing: a memDB has nothing to flush or release.
+func (*memDB) Close() {}
+
+// Panic panics with the formatting of msg and args produced by the
+// package-level [Panic]. A memDB has no slog.Logger of its own.
+func (*memDB) Panic(msg string, args ...any) {
+	Panic(msg, args...)
+}
+
+// Get returns a copy of the value associated with the key,
+// so callers may modify the result without affecting the database.
+func (db *memDB) Get(key []byte) (val []byte, ok bool) {
+	db.mu.RLock()
+	v, ok := db.data.Get(string(key))
+	db.mu.RUnlock()
+	if ok {
+		v = bytes.Clone(v)
+	}
+	return v, ok
+}
+
+// Scan returns an iterator over all key-value pairs
+// in the range start ≤ key ≤ end.
+//
+// The read lock is held while the underlying map is advanced but
+// released while the loop body runs, so the body may itself modify db;
+// the iteration may or may not observe such modifications.
+// Each call to a value function returns a fresh copy of the value.
+func (db *memDB) Scan(start, end []byte) iter.Seq2[[]byte, func() []byte] {
+	lo := string(start)
+	hi := string(end)
+	return func(yield func(key []byte, val func() []byte) bool) {
+		db.mu.RLock()
+		// locked records whether the read lock is currently held, so the
+		// deferred cleanup releases it only when needed (for example,
+		// not when yield panics while the lock is released).
+		locked := true
+		defer func() {
+			if locked {
+				db.mu.RUnlock()
+			}
+		}()
+		for k, v := range db.data.Scan(lo, hi) {
+			key := []byte(k)
+			val := func() []byte { return bytes.Clone(v) }
+			db.mu.RUnlock()
+			locked = false
+			if !yield(key, val) {
+				return
+			}
+			db.mu.RLock()
+			locked = true
+		}
+	}
+}
+
+// Delete deletes any entry with the given key.
+func (db *memDB) Delete(key []byte) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+
+	db.data.Delete(string(key))
+}
+
+// DeleteRange deletes all entries with start ≤ key ≤ end.
+func (db *memDB) DeleteRange(start, end []byte) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+
+	db.data.DeleteRange(string(start), string(end))
+}
+
+// Set sets the value associated with key to a copy of val,
+// so the caller may reuse val's storage afterward.
+func (db *memDB) Set(key, val []byte) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+
+	db.data.Set(string(key), bytes.Clone(val))
+}
+
+// Batch returns a new, empty batch for db.
+func (db *memDB) Batch() Batch {
+	return &memBatch{db: db}
+}
+
+// Flush flushes everything to persistent storage.
+// Since this is an in-memory database, the memory is as persistent as it gets.
+func (db *memDB) Flush() {
+}
+
+// A memBatch is a Batch for a memDB.
+// Each mutation is recorded as a closure over private copies of its
+// arguments; Apply runs the closures in order while holding the
+// database's write lock, which makes the batch atomic.
+type memBatch struct {
+	db  *memDB   // underlying database
+	ops []func() // operations to apply, in order
+}
+
+// Set records setting key to val. Both are copied.
+func (b *memBatch) Set(key, val []byte) {
+	k := string(key)
+	v := bytes.Clone(val)
+	b.ops = append(b.ops, func() { b.db.data.Set(k, v) })
+}
+
+// Delete records deleting key. The key is copied.
+func (b *memBatch) Delete(key []byte) {
+	k := string(key)
+	b.ops = append(b.ops, func() { b.db.data.Delete(k) })
+}
+
+// DeleteRange records deleting all keys in start ≤ key ≤ end.
+// The bounds are copied.
+func (b *memBatch) DeleteRange(start, end []byte) {
+	s := string(start)
+	e := string(end)
+	b.ops = append(b.ops, func() { b.db.data.DeleteRange(s, e) })
+}
+
+// MaybeApply never applies: memBatch imposes no size limit,
+// so it always reports false.
+func (b *memBatch) MaybeApply() bool {
+	return false
+}
+
+// Apply applies all recorded operations to the memDB as a single
+// atomic unit and then empties the batch so it can be reused,
+// as the [Batch] interface requires.
+func (b *memBatch) Apply() {
+	b.db.mu.Lock()
+	defer b.db.mu.Unlock()
+
+	for _, op := range b.ops {
+		op()
+	}
+	// Without this reset, a later Apply would replay these operations,
+	// undoing any changes made to the database in between.
+	b.ops = nil
+}
+
+// A memVectorDB is a VectorDB implementing in-memory search
+// but storing its vectors in an underlying DB.
+type memVectorDB struct {
+	storage   DB           // underlying DB holding the encoded vectors
+	slog      *slog.Logger // logger for informational messages
+	namespace string       // second component of every storage key
+
+	mu    sync.RWMutex
+	cache map[string][]float32 // in-memory cache of all vectors, indexed by id
+}
+
+// MemVectorDB returns a VectorDB that stores its vectors in db
+// but uses a cached, in-memory copy to implement Search using
+// a brute-force scan.
+//
+// The namespace is incorporated into the keys used in the underlying db,
+// to allow multiple vector databases to be stored in a single [DB].
+//
+// When MemVectorDB is called, it reads all previously stored vectors
+// from db; after that, changes must be made using the MemVectorDB
+// Set method.
+//
+// A MemVectorDB requires approximately 3kB of memory per stored vector.
+//
+// The db keys used by a MemVectorDB have the form
+//
+//	ordered.Encode("llm.Vector", namespace, id)
+//
+// where id is the document ID passed to Set.
+func MemVectorDB(db DB, lg *slog.Logger, namespace string) VectorDB {
+	// NOTE: We could cut the memory per stored vector in half by quantizing to int16.
+	//
+	// The worst case score error in a dot product over 768 entries
+	// caused by quantization error of at most e per entry is approximately 55.4 e:
+	//
+	// For a unit vector v of length N, the way to maximize Σ |v[i]| is to make
+	// all the vector entries the same magnitude x, such that sqrt(N x²) = 1,
+	// so x = 1/sqrt(N). The maximum of Σ |v[i]| is therefore N/sqrt(N) = sqrt(N).
+	//
+	// Looking at the dot product error for v₁ · v₂ caused by adding
+	// quantization error vectors e₁ and e₂ (with |e₁[i]|, |e₂[i]| ≤ e):
+	//
+	//	|Σ v₁[i]*v₂[i] - Σ (v₁[i]+e₁[i])*(v₂[i]+e₂[i])| =
+	//	|Σ v₁[i]*v₂[i] - Σ (v₁[i]*v₂[i] + e₁[i]*v₂[i] + e₂[i]*v₁[i] + e₁[i]*e₂[i])| =
+	//	|Σ (e₁[i]*v₂[i] + e₂[i]*v₁[i] + e₁[i]*e₂[i])| ≤
+	//	Σ |e₁[i]*v₂[i]| + Σ |e₂[i]*v₁[i]| + Σ |e₁[i]*e₂[i]| ≤
+	//	e × (Σ |v₁[i]| + Σ |v₂[i]|) + N e² ≤
+	//	e × 2 × N/sqrt(N) + N e² =
+	//	e × (2 × sqrt(N) + N e) ~= 55.4 e for N=768 (N e is negligible).
+	//
+	// Quantizing the float32 range [-1,+1] to int16 range [-32768,32767]
+	// would introduce a maximum quantization error e of
+	// ½ × (+1 - -1) / (32767 - -32768) = 1/65535 = 0.000015259,
+	// resulting in a maximum dot product error of approximately 0.000846,
+	// which would not change the result order significantly.
+
+	vdb := &memVectorDB{
+		storage:   db,
+		slog:      lg,
+		namespace: namespace,
+		cache:     make(map[string][]float32),
+	}
+
+	// Load all the previously-stored vectors.
+	for key, getVal := range db.Scan(
+		ordered.Encode("llm.Vector", namespace),
+		ordered.Encode("llm.Vector", namespace, ordered.Inf)) {
+		var id string
+		// Skip the "llm.Vector" and namespace components; keep the id.
+		if err := ordered.Decode(key, nil, nil, &id); err != nil {
+			// unreachable except data corruption
+			panic(fmt.Errorf("MemVectorDB decode key=%v: %w", Fmt(key), err))
+		}
+		val := getVal()
+		// The encoding is assumed to use 4 bytes per float32 element.
+		if len(val)%4 != 0 {
+			// unreachable except data corruption
+			panic(fmt.Errorf("MemVectorDB decode key=%v bad len(val)=%d", Fmt(key), len(val)))
+		}
+		var vec llm.Vector
+		vec.Decode(val)
+		vdb.cache[id] = vec
+	}
+
+	vdb.slog.Info("loaded vectordb", "n", len(vdb.cache), "namespace", namespace)
+	return vdb
+}
+
+// Set sets the vector for document id to vec, writing it both to
+// the underlying DB and to the in-memory cache used by Search.
+// The cache holds its own copy of vec.
+func (db *memVectorDB) Set(id string, vec llm.Vector) {
+	db.storage.Set(ordered.Encode("llm.Vector", db.namespace, id), vec.Encode())
+
+	db.mu.Lock()
+	db.cache[id] = slices.Clone(vec)
+	db.mu.Unlock()
+}
+
+// Get returns a copy of the vector for document name.
+// Returning a copy (mirroring the copy Set makes) keeps callers
+// from modifying the cached vector that Search relies on.
+// If there is no such document, Get returns nil, false.
+func (db *memVectorDB) Get(name string) (llm.Vector, bool) {
+	db.mu.RLock()
+	vec, ok := db.cache[name]
+	db.mu.RUnlock()
+	if !ok {
+		return nil, false
+	}
+	return slices.Clone(vec), true
+}
+
+// Search returns the n cached vectors scoring highest under target.Dot,
+// found by a brute-force scan of the cache.
+// Cached vectors whose length differs from target's are skipped.
+func (db *memVectorDB) Search(target llm.Vector, n int) []VectorResult {
+	db.mu.RLock()
+	defer db.mu.RUnlock()
+	best := top.New(n, VectorResult.cmp)
+	for name, vec := range db.cache {
+		if len(vec) != len(target) {
+			continue
+		}
+		best.Add(VectorResult{name, target.Dot(vec)})
+	}
+	return best.Take()
+}
+
+// Flush flushes the underlying DB.
+func (db *memVectorDB) Flush() {
+	db.storage.Flush()
+}
+
+// A memVectorBatch is the VectorBatch for a memVectorDB.
+// Encoded vectors are written to a Batch on the underlying DB,
+// and the vectors themselves are staged in w until the batch is
+// applied, at which point they are published to the in-memory cache.
+type memVectorBatch struct {
+	db *memVectorDB          // vector database being updated
+	sb Batch                 // pending writes to the underlying DB
+	w  map[string]llm.Vector // pending cache updates, keyed by document ID
+}
+
+// Batch returns a new, empty batch for db.
+func (db *memVectorDB) Batch() VectorBatch {
+	return &memVectorBatch{
+		db: db,
+		sb: db.storage.Batch(),
+		w:  make(map[string]llm.Vector),
+	}
+}
+
+// Set records that the vector for name should become vec.
+// The batch keeps its own copy of vec.
+func (b *memVectorBatch) Set(name string, vec llm.Vector) {
+	key := ordered.Encode("llm.Vector", b.db.namespace, name)
+	b.sb.Set(key, vec.Encode())
+	b.w[name] = slices.Clone(vec)
+}
+
+// MaybeApply applies the batch exactly when the underlying DB
+// batch reports that it applied itself, and reports whether it did.
+func (b *memVectorBatch) MaybeApply() bool {
+	applied := b.sb.MaybeApply()
+	if applied {
+		b.Apply()
+	}
+	return applied
+}
+
+// Apply writes the batch to the underlying DB and then copies
+// the staged vectors into the in-memory cache, leaving the
+// batch empty.
+func (b *memVectorBatch) Apply() {
+	b.sb.Apply()
+
+	b.db.mu.Lock()
+	defer b.db.mu.Unlock()
+	for id, v := range b.w {
+		b.db.cache[id] = v
+	}
+	clear(b.w)
+}
diff --git a/internal/storage/mem_test.go b/internal/storage/mem_test.go
new file mode 100644
index 0000000..b276051
--- /dev/null
+++ b/internal/storage/mem_test.go
@@ -0,0 +1,59 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+	"testing"
+
+	"golang.org/x/oscar/internal/testutil"
+)
+
+// TestMemDB runs the generic DB tests on a fresh in-memory DB.
+func TestMemDB(t *testing.T) {
+	db := MemDB()
+	TestDB(t, db)
+}
+
+// TestMemVectorDB runs the generic VectorDB tests. Every VectorDB it
+// opens shares one MemDB, so a reopened VectorDB sees earlier writes.
+func TestMemVectorDB(t *testing.T) {
+	db := MemDB()
+	open := func() VectorDB {
+		return MemVectorDB(db, testutil.Slogger(t), "")
+	}
+	TestVectorDB(t, open)
+}
+
+// A maybeDB wraps a DB so that a test can control what
+// MaybeApply reports for the batches it creates.
+type maybeDB struct {
+	DB
+	maybe bool // result of MaybeApply on batches from this DB
+}
+
+// A maybeBatch is a Batch whose MaybeApply reports db.maybe
+// without applying anything itself.
+type maybeBatch struct {
+	db *maybeDB
+	Batch
+}
+
+func (db *maybeDB) Batch() Batch {
+	inner := db.DB.Batch()
+	return &maybeBatch{Batch: inner, db: db}
+}
+
+func (b *maybeBatch) MaybeApply() bool { return b.db.maybe }
+
+// Test that when db.Batch.MaybeApply returns true,
+// the memvector Batch MaybeApply applies the memvector ops,
+// and that MaybeApply's result reports whether it applied.
+func TestMemVectorBatchMaybeApply(t *testing.T) {
+	db := &maybeDB{DB: MemDB()}
+	vdb := MemVectorDB(db, testutil.Slogger(t), "")
+	b := vdb.Batch()
+	b.Set("apple3", embed("apple3"))
+	if _, ok := vdb.Get("apple3"); ok {
+		t.Errorf("Get(apple3) succeeded before batch apply")
+	}
+	// The db Batch does not apply, so neither should the vector batch.
+	if b.MaybeApply() {
+		t.Errorf("MaybeApply() = true, want false when the db Batch does not apply")
+	}
+	if _, ok := vdb.Get("apple3"); ok {
+		t.Errorf("Get(apple3) succeeded after MaybeApply that didn't apply")
+	}
+	db.maybe = true
+	if !b.MaybeApply() {
+		t.Errorf("MaybeApply() = false, want true when the db Batch applies")
+	}
+	if _, ok := vdb.Get("apple3"); !ok {
+		t.Errorf("Get(apple3) failed after MaybeApply that did apply")
+	}
+}
diff --git a/internal/storage/test.go b/internal/storage/test.go
new file mode 100644
index 0000000..9a193a6
--- /dev/null
+++ b/internal/storage/test.go
@@ -0,0 +1,126 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+	"fmt"
+	"slices"
+	"sync"
+	"testing"
+
+	"rsc.io/ordered"
+)
+
+// TestDB runs basic tests on db.
+// It should be empty when TestDB is called.
+//
+// TestDB exercises Set, Get, Delete, Batch (including MaybeApply),
+// Scan (including stopping an iteration early), DeleteRange, and Flush,
+// and then checks Lock and Unlock using testDBLock.
+func TestDB(t *testing.T, db DB) {
+	db.Set([]byte("key"), []byte("value"))
+	if val, ok := db.Get([]byte("key")); string(val) != "value" || !ok {
+		// unreachable except for bad db
+		t.Fatalf("Get(key) = %q, %v, want %q, true", val, ok, "value")
+	}
+	if val, ok := db.Get([]byte("missing")); val != nil || ok {
+		// unreachable except for bad db
+		t.Fatalf("Get(missing) = %v, %v, want nil, false", val, ok)
+	}
+
+	db.Delete([]byte("key"))
+	if val, ok := db.Get([]byte("key")); val != nil || ok {
+		// unreachable except for bad db
+		t.Fatalf("Get(key) after delete = %v, %v, want nil, false", val, ok)
+	}
+
+	b := db.Batch()
+	for i := range 10 {
+		b.Set(ordered.Encode(i), []byte(fmt.Sprint(i)))
+		b.MaybeApply()
+	}
+	b.Apply()
+
+	// collect returns the integer keys that db.Scan yields for lo ≤ key ≤ hi,
+	// checking that each value is the decimal form of its key.
+	// It stops the iteration early right after yielding stop;
+	// a stop value that is never a key (such as -1 here) scans the whole range.
+	collect := func(lo, hi, stop int) []int {
+		t.Helper()
+		var list []int
+		for key, val := range db.Scan(ordered.Encode(lo), ordered.Encode(hi)) {
+			var i int
+			if err := ordered.Decode(key, &i); err != nil {
+				// unreachable except for bad db
+				t.Fatalf("db.Scan malformed key %v", Fmt(key))
+			}
+			if sv, want := string(val()), fmt.Sprint(i); sv != want {
+				// unreachable except for bad db
+				t.Fatalf("db.Scan key %v val=%q, want %q", i, sv, want)
+			}
+			list = append(list, i)
+			if i == stop {
+				break
+			}
+		}
+		return list
+	}
+
+	if scan, want := collect(3, 6, -1), []int{3, 4, 5, 6}; !slices.Equal(scan, want) {
+		// unreachable except for bad db
+		t.Fatalf("Scan(3, 6) = %v, want %v", scan, want)
+	}
+
+	if scan, want := collect(3, 6, 5), []int{3, 4, 5}; !slices.Equal(scan, want) {
+		// unreachable except for bad db
+		t.Fatalf("Scan(3, 6) with break at 5 = %v, want %v", scan, want)
+	}
+
+	db.DeleteRange(ordered.Encode(4), ordered.Encode(7))
+	if scan, want := collect(-1, 11, -1), []int{0, 1, 2, 3, 8, 9}; !slices.Equal(scan, want) {
+		// unreachable except for bad db
+		t.Fatalf("Scan(-1, 11) after Delete(4, 7) = %v, want %v", scan, want)
+	}
+
+	// Batched operations must apply in order: for i = 1..4 the Delete of i
+	// happens after the earlier Set of 2*(i/2) and before this Set of 2*i.
+	b = db.Batch()
+	for i := range 5 {
+		b.Delete(ordered.Encode(i))
+		b.Set(ordered.Encode(2*i), []byte(fmt.Sprint(2*i)))
+	}
+	b.DeleteRange(ordered.Encode(0), ordered.Encode(0))
+	b.Apply()
+	if scan, want := collect(-1, 11, -1), []int{6, 8, 9}; !slices.Equal(scan, want) {
+		// unreachable except for bad db
+		t.Fatalf("Scan(-1, 11) after batch Delete+Set = %v, want %v", scan, want)
+	}
+
+	// Can't test much, but check that it doesn't crash.
+	db.Flush()
+
+	testDBLock(t, db)
+}
+
+// A locker is the subset of [DB] that testDBLock exercises.
+type locker interface {
+	Lock(string)
+	Unlock(string)
+}
+
+// testDBLock checks that db.Lock excludes a second locker of the same
+// name until db.Unlock, and that db.Unlock panics for a name that was
+// never locked. Under the race detector, the unsynchronized writes to
+// shared are reported if Lock fails to order them.
+func testDBLock(t *testing.T, db locker) {
+	shared := 0
+	var wg sync.WaitGroup
+
+	db.Lock("abc")
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		db.Lock("abc")
+		shared = 2 // cause race if not synchronized
+		db.Unlock("abc")
+	}()
+	shared = 1 // cause race if not synchronized
+	db.Unlock("abc")
+	wg.Wait()
+	_ = shared
+
+	func() {
+		defer func() { recover() }()
+		db.Unlock("def")
+		t.Errorf("Unlock never-locked key did not panic")
+	}()
+}
diff --git a/internal/storage/vectordb.go b/internal/storage/vectordb.go
new file mode 100644
index 0000000..b7d2b7a
--- /dev/null
+++ b/internal/storage/vectordb.go
@@ -0,0 +1,83 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+	"cmp"
+
+	"golang.org/x/oscar/internal/llm"
+)
+
+// A VectorDB is a vector database that implements
+// nearest-neighbor search over embedding vectors
+// corresponding to documents.
+type VectorDB interface {
+	// Set sets the vector associated with the given document ID to vec,
+	// replacing any previous vector for that ID.
+	Set(id string, vec llm.Vector)
+
+	// TODO: Add Delete.
+
+	// Get gets the vector associated with the given document ID.
+	// If no such document exists, Get returns nil, false.
+	// If a document exists, Get returns vec, true.
+	Get(id string) (llm.Vector, bool)
+
+	// Batch returns a new [VectorBatch] that accumulates
+	// vector database mutations to apply in an atomic operation.
+	// It is more efficient than repeated calls to Set.
+	Batch() VectorBatch
+
+	// Search searches the database for the n vectors
+	// most similar to vec, returning the document IDs
+	// and similarity scores, most similar first.
+	//
+	// Normally a VectorDB is used entirely with vectors of a single length.
+	// Search ignores stored vectors with a different length than vec.
+	Search(vec llm.Vector, n int) []VectorResult
+
+	// Flush flushes storage to disk.
+	Flush()
+}
+
+// A VectorBatch accumulates vector database mutations
+// that are applied to a [VectorDB] in a single atomic operation.
+// Applying bulk operations in a batch is also more efficient than
+// making individual [VectorDB] method calls.
+// Batched mutations are not visible through the VectorDB
+// until the batch is applied.
+// The batched operations apply in the order they are made.
+type VectorBatch interface {
+	// Set sets the vector associated with the given document ID to vec.
+	Set(id string, vec llm.Vector)
+
+	// TODO: Add Delete.
+
+	// MaybeApply calls Apply if the VectorBatch is getting close to full.
+	// Every VectorBatch has a limit to how many operations can be batched,
+	// so in a bulk operation where atomicity of the entire batch is not a concern,
+	// calling MaybeApply gives the VectorBatch implementation
+	// permission to flush the batch at specific “safe points”.
+	// A typical limit for a batch is about 100MB worth of logged operations.
+	//
+	// MaybeApply reports whether it called Apply.
+	MaybeApply() bool
+
+	// Apply applies all the batched operations to the underlying VectorDB
+	// as a single atomic unit.
+	// When Apply returns, the VectorBatch is an empty batch ready for
+	// more operations.
+	Apply()
+}
+
+// A VectorResult is a single document returned by a VectorDB search.
+type VectorResult struct {
+	ID string // document ID
+	// Score is the similarity score; 1 is an exact match.
+	// Scores are not confined to [0, 1]: a dot product of unit
+	// vectors lies in [-1, 1], so dissimilar documents can score below 0.
+	Score float64
+}
+
+// cmp compares x and y by Score and then by ID, returning -1, 0, or +1
+// as [cmp.Compare] does. Breaking ties by ID makes the order total,
+// so results with equal scores (including two NaN scores) have a
+// deterministic order.
+func (x VectorResult) cmp(y VectorResult) int {
+	return cmp.Or(cmp.Compare(x.Score, y.Score), cmp.Compare(x.ID, y.ID))
+}
diff --git a/internal/storage/vectordb_test.go b/internal/storage/vectordb_test.go
new file mode 100644
index 0000000..2348120
--- /dev/null
+++ b/internal/storage/vectordb_test.go
@@ -0,0 +1,30 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import "testing"
+
+// TestVectorResultCompare checks that VectorResult.cmp orders by
+// score, breaks ties by ID, is reflexive, and is antisymmetric.
+func TestVectorResultCompare(t *testing.T) {
+	type R = VectorResult
+	tests := []struct {
+		x, y VectorResult
+		cmp  int
+	}{
+		{R{"b", 0.5}, R{"c", 0.5}, -1},
+		{R{"b", 0.4}, R{"a", 0.5}, -1},
+	}
+
+	check := func(a, b VectorResult, want int) {
+		if got := a.cmp(b); got != want {
+			t.Errorf("Compare(%v, %v) = %d, want %d", a, b, got, want)
+		}
+	}
+	for _, tc := range tests {
+		// Each value compares equal to itself.
+		check(tc.x, tc.x, 0)
+		check(tc.y, tc.y, 0)
+		// Swapping the operands negates the result.
+		check(tc.x, tc.y, tc.cmp)
+		check(tc.y, tc.x, -tc.cmp)
+	}
+}
diff --git a/internal/storage/vtest.go b/internal/storage/vtest.go
new file mode 100644
index 0000000..a19f6ad
--- /dev/null
+++ b/internal/storage/vtest.go
@@ -0,0 +1,74 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+	"math"
+	"reflect"
+	"slices"
+	"testing"
+
+	"golang.org/x/oscar/internal/llm"
+)
+
+// TestVectorDB runs basic tests on the VectorDB returned by newdb.
+//
+// It exercises Set, Batch, Get, Search, and Flush. The database
+// should start out empty, because the expected Search results assume
+// no other vectors of the same length are present. TestVectorDB calls
+// newdb twice, and the second VectorDB must see the vectors written
+// (and flushed) through the first, so newdb should reopen the same
+// underlying storage. Vectors are generated by embed.
+func TestVectorDB(t *testing.T, newdb func() VectorDB) {
+	vdb := newdb()
+
+	vdb.Set("orange2", embed("orange2"))
+	vdb.Set("orange1", embed("orange1"))
+	b := vdb.Batch()
+	b.Set("apple3", embed("apple3"))
+	b.Set("apple4", embed("apple4"))
+	// A vector of a different length, which Search must ignore.
+	b.Set("ignore", embed("bad")[:4])
+	b.Apply()
+
+	v, ok := vdb.Get("apple3")
+	if !ok || !slices.Equal(v, embed("apple3")) {
+		// unreachable except bad vectordb
+		t.Errorf("Get(apple3) = %v, %v, want %v, true", v, ok, embed("apple3"))
+	}
+
+	want := []VectorResult{
+		{"apple4", 0.9999961187341375},
+		{"apple3", 0.9999843342970269},
+		{"orange1", 0.38062230442542155},
+		{"orange2", 0.3785152783773009},
+	}
+	have := vdb.Search(embed("apple5"), 5)
+	if !reflect.DeepEqual(have, want) {
+		// unreachable except bad vectordb
+		t.Fatalf("Search(apple5, 5):\nhave %v\nwant %v", have, want)
+	}
+
+	vdb.Flush()
+
+	// Reopen and check that the stored vectors were persisted.
+	vdb = newdb()
+	have = vdb.Search(embed("apple5"), 3)
+	want = want[:3]
+	if !reflect.DeepEqual(have, want) {
+		// unreachable except bad vectordb
+		t.Errorf("Search(apple5, 3) in fresh database:\nhave %v\nwant %v", have, want)
+	}
+}
+
+// embed returns a deterministic 16-element unit vector for text, for tests.
+// Byte i of text becomes element i (scaled by 1/256). If text is shorter
+// than the vector, the element just past it is -1, and the rest stay 0.
+// The whole vector is then normalized to unit length.
+// text must be at most 16 bytes long; longer text panics with an index error.
+func embed(text string) llm.Vector {
+	const vectorLen = 16
+	vec := make(llm.Vector, vectorLen)
+	var sumSq float32
+	for i := 0; i < len(text); i++ {
+		vec[i] = float32(text[i]) / 256
+		sumSq += float32(vec[i] * vec[i]) // float32() to avoid FMA
+	}
+	if n := len(text); n < vectorLen {
+		vec[n] = -1
+		sumSq += 1
+	}
+	scale := float32(1 / math.Sqrt(float64(sumSq)))
+	for i := range vec {
+		vec[i] *= scale
+	}
+	return vec
+}