internal/storage: storage abstractions
Package storage defines DB, a key-value database,
and VectorDB, a vector database.
Copied from https://github.com/rsc/gaby/commit/3f1bdd4.
Change-Id: I057ad7086c7cd53826c8022f674ea3943c9bcba5
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/597138
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
diff --git a/internal/storage/db.go b/internal/storage/db.go
new file mode 100644
index 0000000..980f5c5
--- /dev/null
+++ b/internal/storage/db.go
@@ -0,0 +1,188 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package storage defines the storage abstractions needed for Oscar:
+// [DB], a basic key-value store, and [VectorDB], a vector database.
+// The storage needs are intentionally minimal (avoiding, for example,
+// a requirement on SQL), to admit as many implementations as possible.
+package storage
+
+import (
+ "encoding/json"
+ "fmt"
+ "iter"
+ "log/slog"
+ "strconv"
+ "strings"
+
+ "rsc.io/ordered"
+)
+
+// A DB is a key-value database.
+//
+// DB operations are assumed not to fail.
+// They panic, intending to take down the program,
+// if there is an error accessing the database.
+// The assumption is that the program cannot possibly
+// continue without the database, since that's where all the state is stored.
+// Similarly, clients of DB conventionally panic
+// using [DB.Panic] if the database returns corrupted data.
+// Code that runs multiple database operations in parallel
+// can recover from the panic at the outermost calls.
+type DB interface {
+ // Lock acquires a lock on the given name, which need not exist in the database.
+ // After a successful Lock(name),
+ // any other call to Lock(name) from any other client of the database
+ // (including in another process, for shared databases)
+ // must block until Unlock(name) has been called.
+ // In a shared database, a lock may also unlock
+ // when the client disconnects or times out.
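+ //
+ // A typical cross-process critical section looks like this
+ // (a sketch; the lock name, key, and increment helper are
+ // illustrative, not part of this package):
+ //
+ // db.Lock("counters")
+ // val, _ := db.Get([]byte("counters.requests"))
+ // db.Set([]byte("counters.requests"), increment(val))
+ // db.Unlock("counters")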
+ Lock(name string)
+
+ // Unlock releases the lock with the given name,
+ // which the caller must have locked.
+ Unlock(name string)
+
+ // Set sets the value associated with key to val.
+ Set(key, val []byte)
+
+ // Get looks up the value associated with key.
+ // If there is no entry for key in the database, Get returns nil, false.
+ // Otherwise it returns val, true.
+ Get(key []byte) (val []byte, ok bool)
+
+ // Scan returns an iterator over all key-value pairs with start ≤ key ≤ end.
+ // The second value in each iteration pair is a function returning the value,
+ // not the value itself:
+ //
+ // for key, getVal := range db.Scan([]byte("aaa"), []byte("zzz")) {
+ // val := getVal()
+ // fmt.Printf("%q: %q\n", key, val)
+ // }
+ //
+ // In iterations that only need the keys or only need the values for a subset of keys,
+ // some DB implementations may avoid work when the value function is not called.
+ //
+ // A Scan may or may not observe concurrent modifications made
+ // using Set, Delete, and DeleteRange.
+ Scan(start, end []byte) iter.Seq2[[]byte, func() []byte]
+
+ // Delete deletes any value associated with key.
+ // Delete of an unset key is a no-op.
+ Delete(key []byte)
+
+ // DeleteRange deletes all key-value pairs with start ≤ key ≤ end.
+ DeleteRange(start, end []byte)
+
+ // Batch returns a new [Batch] that accumulates database mutations
+ // to apply in an atomic operation. In addition to the atomicity, using a
+ // Batch for bulk operations is more efficient than making each
+ // change using repeated calls to DB's Set, Delete, and DeleteRange methods.
+ Batch() Batch
+
+ // Flush flushes DB changes to permanent storage.
+ // Flush must be called before the process crashes or exits,
+ // or else any changes since the previous Flush may be lost.
+ Flush()
+
+ // Close flushes and then closes the database.
+ // Like the other routines, it panics if an error happens,
+ // so there is no error result.
+ Close()
+
+ // Panic logs the error message and args using the database's slog.Logger
+ // and then panics with the text formatting of its arguments.
+ // It is meant to be called when database corruption or other
+ // database-related “can't happen” conditions have been detected.
+ Panic(msg string, args ...any)
+}
+
+// A Batch accumulates database mutations that are applied to a [DB]
+// as a single atomic operation. Applying bulk operations in a batch
+// is also more efficient than making individual [DB] method calls.
+// The batched operations apply in the order they are made.
+// For example, Set("a", "b") followed by Delete("a") is the same as
+// Delete("a"), while Delete("a") followed by Set("a", "b") is the same
+// as Set("a", "b").
+type Batch interface {
+ // Delete deletes any value associated with key.
+ // Delete of an unset key is a no-op.
+ //
+ // Delete does not retain any reference to key after returning.
+ Delete(key []byte)
+
+ // DeleteRange deletes all key-value pairs with start ≤ key ≤ end.
+ //
+ // DeleteRange does not retain any reference to start or end after returning.
+ DeleteRange(start, end []byte)
+
+ // Set sets the value associated with key to val.
+ //
+ // Set does not retain any reference to key or val after returning.
+ Set(key, val []byte)
+
+ // MaybeApply calls Apply if the batch is getting close to full.
+ // Every Batch has a limit to how many operations can be batched,
+ // so in a bulk operation where atomicity of the entire batch is not a concern,
+ // calling MaybeApply gives the Batch implementation
+ // permission to flush the batch at specific “safe points”.
+ // A typical limit for a batch is about 100MB worth of logged operations.
+ // MaybeApply reports whether it called Apply.
+ MaybeApply() bool
+
+ // Apply applies all the batched operations to the underlying DB
+ // as a single atomic unit.
+ // When Apply returns, the Batch is an empty batch ready for
+ // more operations.
+ Apply()
+}
+
+// Panic panics with the text formatting of its arguments.
+// It is meant to be called for database errors or corruption,
+// which have been defined to be impossible.
+// (See the [DB] documentation.)
+//
+// Panic is expected to be used by DB implementations.
+// DB clients should use the [DB.Panic] method instead.
+func Panic(msg string, args ...any) {
+ var b strings.Builder
+ slog.New(slog.NewTextHandler(&b, nil)).Error(msg, args...)
+ s := b.String()
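+ // s now has the slog text-handler form
+ // "time=... level=ERROR msg=msg key=val";
+ // keep only the message and attributes so the panic value is readable.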
+ if _, rest, ok := strings.Cut(s, " level=ERROR msg="); ok {
+ s = rest
+ }
+ panic(strings.TrimSpace(s))
+}
+
+// JSON converts x to JSON and returns the result.
+// It panics if there is any error converting x to JSON.
+// Since whether x can be converted to JSON depends
+// almost entirely on its type, a marshaling error indicates a
+// bug at the call site.
+//
+// (The exceptions are certain malformed UTF-8 and the floating-point
+// values infinity and NaN. Code must be careful not to use JSON with those.)
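+//
+// A typical use is storing a JSON-encoded value under a key, from a
+// client's point of view (a sketch; the key components and item value
+// are illustrative):
+//
+// db.Set(ordered.Encode("issue", 12345), storage.JSON(item))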
+func JSON(x any) []byte {
+ js, err := json.Marshal(x)
+ if err != nil {
+ panic(fmt.Sprintf("json.Marshal: %v", err))
+ }
+ return js
+}
+
+// Fmt formats data for printing,
+// first trying [ordered.DecodeFmt] in case data is an [ordered encoding],
+// then trying a backquoted string if possible
+// (handling simple JSON data),
+// and finally resorting to [strconv.QuoteToASCII].
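+//
+// For example, Fmt(ordered.Encode(1, 2, 3)) returns "(1, 2, 3)",
+// and Fmt([]byte(`{"a":"b"}`)) returns the backquoted form of that JSON.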
+func Fmt(data []byte) string {
+ if s, err := ordered.DecodeFmt(data); err == nil {
+ return s
+ }
+ s := string(data)
+ if strconv.CanBackquote(s) {
+ return "`" + s + "`"
+ }
+ return strconv.QuoteToASCII(s)
+}
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
new file mode 100644
index 0000000..d0df4df
--- /dev/null
+++ b/internal/storage/db_test.go
@@ -0,0 +1,67 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "math"
+ "testing"
+
+ "rsc.io/ordered"
+)
+
+func TestPanic(t *testing.T) {
+ func() {
+ defer func() {
+ r := recover()
+ if r.(string) != "msg key=val" {
+ t.Errorf("panic value is not msg key=val:\n%s", r)
+ }
+ }()
+ Panic("msg", "key", "val")
+ t.Fatalf("did not panic")
+ }()
+
+}
+
+func TestJSON(t *testing.T) {
+ x := map[string]string{"a": "b"}
+ js := JSON(x)
+ want := `{"a":"b"}`
+ if string(js) != want {
+ t.Errorf("JSON(%v) = %#q, want %#q", x, js, want)
+ }
+
+ func() {
+ defer func() {
+ recover()
+ }()
+ JSON(math.NaN())
+ t.Errorf("JSON(NaN) did not panic")
+ }()
+}
+
+var fmtTests = []struct {
+ data []byte
+ out string
+}{
+ {ordered.Encode(1, 2, 3), "(1, 2, 3)"},
+ {[]byte(`"hello"`), "`\"hello\"`"},
+ {[]byte("`hello`"), "\"`hello`\""},
+}
+
+func TestFmt(t *testing.T) {
+ for _, tt := range fmtTests {
+ out := Fmt(tt.data)
+ if out != tt.out {
+ t.Errorf("Fmt(%q) = %q, want %q", tt.data, out, tt.out)
+ }
+ }
+}
+
+func TestMemLocker(t *testing.T) {
+ m := new(MemLocker)
+
+ testDBLock(t, m)
+}
diff --git a/internal/storage/mem.go b/internal/storage/mem.go
new file mode 100644
index 0000000..ce9a8f4
--- /dev/null
+++ b/internal/storage/mem.go
@@ -0,0 +1,342 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "bytes"
+ "fmt"
+ "iter"
+ "log/slog"
+ "slices"
+ "sync"
+
+ "golang.org/x/oscar/internal/llm"
+ "rsc.io/omap"
+ "rsc.io/ordered"
+ "rsc.io/top"
+)
+
+// A MemLocker is a single-process implementation
+// of the database Lock and Unlock methods,
+// suitable if there is only one process accessing the
+// database at a time.
+//
+// The zero value for a MemLocker
+// is a valid MemLocker with no locks held.
+// It must not be copied after first use.
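+//
+// A DB implementation that is only accessed from a single process can
+// embed a MemLocker to provide its Lock and Unlock methods, as the
+// in-memory DB returned by [MemDB] does (myDB below is hypothetical):
+//
+// type myDB struct {
+// 	MemLocker
+// 	// other fields
+// }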
+type MemLocker struct {
+ mu sync.Mutex
+ locks map[string]*sync.Mutex
+}
+
+// Lock locks the mutex with the given name.
+func (l *MemLocker) Lock(name string) {
+ l.mu.Lock()
+ if l.locks == nil {
+ l.locks = make(map[string]*sync.Mutex)
+ }
+ mu := l.locks[name]
+ if mu == nil {
+ mu = new(sync.Mutex)
+ l.locks[name] = mu
+ }
+ l.mu.Unlock()
+
+ mu.Lock()
+}
+
+// Unlock unlocks the mutex with the given name.
+func (l *MemLocker) Unlock(name string) {
+ l.mu.Lock()
+ mu := l.locks[name]
+ l.mu.Unlock()
+ if mu == nil {
+ panic("Unlock of never locked key")
+ }
+ mu.Unlock()
+}
+
+// MemDB returns an in-memory DB implementation.
+func MemDB() DB {
+ return new(memDB)
+}
+
+// A memDB is an in-memory DB implementation.
+type memDB struct {
+ MemLocker
+ mu sync.RWMutex
+ data omap.Map[string, []byte]
+}
+
+func (*memDB) Close() {}
+
+func (*memDB) Panic(msg string, args ...any) {
+ Panic(msg, args...)
+}
+
+// Get returns the value associated with the key.
+func (db *memDB) Get(key []byte) (val []byte, ok bool) {
+ db.mu.RLock()
+ v, ok := db.data.Get(string(key))
+ db.mu.RUnlock()
+ if ok {
+ v = bytes.Clone(v)
+ }
+ return v, ok
+}
+
+// Scan returns an iterator over all key-value pairs
+// in the range start ≤ key ≤ end.
+func (db *memDB) Scan(start, end []byte) iter.Seq2[[]byte, func() []byte] {
+ lo := string(start)
+ hi := string(end)
+ return func(yield func(key []byte, val func() []byte) bool) {
+ db.mu.RLock()
+ locked := true
+ defer func() {
+ if locked {
+ db.mu.RUnlock()
+ }
+ }()
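+ // Release the read lock while yielding each pair so the caller's
+ // loop body can call back into the DB (for example, Set or Delete)
+ // without deadlocking on db.mu; reacquire it before advancing the
+ // underlying iterator.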
+ for k, v := range db.data.Scan(lo, hi) {
+ key := []byte(k)
+ val := func() []byte { return bytes.Clone(v) }
+ db.mu.RUnlock()
+ locked = false
+ if !yield(key, val) {
+ return
+ }
+ db.mu.RLock()
+ locked = true
+ }
+ }
+}
+
+// Delete deletes any entry with the given key.
+func (db *memDB) Delete(key []byte) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+
+ db.data.Delete(string(key))
+}
+
+// DeleteRange deletes all entries with start ≤ key ≤ end.
+func (db *memDB) DeleteRange(start, end []byte) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+
+ db.data.DeleteRange(string(start), string(end))
+}
+
+// Set sets the value associated with key to val.
+func (db *memDB) Set(key, val []byte) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+
+ db.data.Set(string(key), bytes.Clone(val))
+}
+
+// Batch returns a new batch.
+func (db *memDB) Batch() Batch {
+ return &memBatch{db: db}
+}
+
+// Flush flushes everything to persistent storage.
+// Since this is an in-memory database, the memory is as persistent as it gets.
+func (db *memDB) Flush() {
+}
+
+// A memBatch is a Batch for a memDB.
+type memBatch struct {
+ db *memDB // underlying database
+ ops []func() // operations to apply
+}
+
+func (b *memBatch) Set(key, val []byte) {
+ k := string(key)
+ v := bytes.Clone(val)
+ b.ops = append(b.ops, func() { b.db.data.Set(k, v) })
+}
+
+func (b *memBatch) Delete(key []byte) {
+ k := string(key)
+ b.ops = append(b.ops, func() { b.db.data.Delete(k) })
+}
+
+func (b *memBatch) DeleteRange(start, end []byte) {
+ s := string(start)
+ e := string(end)
+ b.ops = append(b.ops, func() { b.db.data.DeleteRange(s, e) })
+}
+
+func (b *memBatch) MaybeApply() bool {
+ return false
+}
+
+func (b *memBatch) Apply() {
+ b.db.mu.Lock()
+ defer b.db.mu.Unlock()
+
+ for _, op := range b.ops {
+ op()
+ }
+}
+
+// A memVectorDB is a VectorDB implementing in-memory search
+// but storing its vectors in an underlying DB.
+type memVectorDB struct {
+ storage DB
+ slog *slog.Logger
+ namespace string
+
+ mu sync.RWMutex
+ cache map[string][]float32 // in-memory cache of all vectors, indexed by id
+}
+
+// MemVectorDB returns a VectorDB that stores its vectors in db
+// but uses a cached, in-memory copy to implement Search using
+// a brute-force scan.
+//
+// The namespace is incorporated into the keys used in the underlying db,
+// to allow multiple vector databases to be stored in a single [DB].
+//
+// When MemVectorDB is called, it reads all previously stored vectors
+// from db; after that, changes must be made using the MemVectorDB
+// Set method.
+//
+// A MemVectorDB requires approximately 3kB of memory per stored vector.
+//
+// The db keys used by a MemVectorDB have the form
+//
+// ordered.Encode("llm.Vector", namespace, id)
+//
+// where id is the document ID passed to Set.
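+//
+// A typical setup looks like this (a sketch; db, lg, vec, and query
+// come from the caller, and "doc1" is an arbitrary document ID):
+//
+// vdb := storage.MemVectorDB(db, lg, "docs")
+// vdb.Set("doc1", vec)
+// best := vdb.Search(query, 10)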
+func MemVectorDB(db DB, lg *slog.Logger, namespace string) VectorDB {
+ // NOTE: We could cut the memory per stored vector in half by quantizing to int16.
+ //
+ // The worst case score error in a dot product over 768 entries
+ // caused by quantization error of e is approximately 55.4 e:
+ //
+ // For a unit vector v of length N, the way to maximize Σ v[i] is to make
+ // all the vector entries the same value x, such that sqrt(N x²) = 1,
+ // so x = 1/sqrt(N). The maximum of Σ v[i] is therefore N/sqrt(N).
+ //
+ // Looking at the dot product error for v₁ · v₂ caused by adding
+ // quantization error vectors e₁ and e₂:
+ //
+ // |Σ v₁[i]*v₂[i] - Σ (v₁[i]+e₁[i])*(v₂[i]+e₂[i])| =
+ // |Σ v₁[i]*v₂[i] - Σ (v₁[i]*v₂[i] + e₁[i]*v₂[i] + e₂[i]*v₁[i] + e₁[i]*e₂[i])| =
+ // |Σ (e₁[i]*v₂[i] + e₂[i]*v₁[i] + e₁[i]*e₂[i])| ≤
+ // Σ |e₁[i]*v₂[i]| + Σ |e₂[i]*v₁[i]| + Σ |e₁[i]*e₂[i]| ≤
+ // e × (Σ v₁[i] + Σ v₂[i]) + N e² ≤
+ // e × 2 × N/sqrt(N) + N e² =
+ // e × (2 × N/sqrt(N) + N × e) ~= 55.4 e for N=768.
+ //
+ // Quantizing the float32 range [-1,+1] to int16 range [-32768,32767]
+ // would introduce a maximum quantization error e of
+ // ½ × (+1 - -1) / (32767 - -32768) = 1/65535 = 0.000015259,
+ // resulting in a maximum dot product error of approximately 0.000846,
+ // which would not change the result order significantly.
+
+ vdb := &memVectorDB{
+ storage: db,
+ slog: lg,
+ namespace: namespace,
+ cache: make(map[string][]float32),
+ }
+
+ // Load all the previously-stored vectors.
+ vdb.cache = make(map[string][]float32)
+ for key, getVal := range vdb.storage.Scan(
+ ordered.Encode("llm.Vector", namespace),
+ ordered.Encode("llm.Vector", namespace, ordered.Inf)) {
+
+ var id string
+ if err := ordered.Decode(key, nil, nil, &id); err != nil {
+ // unreachable except data corruption
+ panic(fmt.Errorf("MemVectorDB decode key=%v: %v", Fmt(key), err))
+ }
+ val := getVal()
+ if len(val)%4 != 0 {
+ // unreachable except data corruption
+ panic(fmt.Errorf("MemVectorDB decode key=%v bad len(val)=%d", Fmt(key), len(val)))
+ }
+ var vec llm.Vector
+ vec.Decode(val)
+ vdb.cache[id] = vec
+ }
+
+ vdb.slog.Info("loaded vectordb", "n", len(vdb.cache), "namespace", namespace)
+ return vdb
+}
+
+func (db *memVectorDB) Set(id string, vec llm.Vector) {
+ db.storage.Set(ordered.Encode("llm.Vector", db.namespace, id), vec.Encode())
+
+ db.mu.Lock()
+ db.cache[id] = slices.Clone(vec)
+ db.mu.Unlock()
+}
+
+func (db *memVectorDB) Get(name string) (llm.Vector, bool) {
+ db.mu.RLock()
+ vec, ok := db.cache[name]
+ db.mu.RUnlock()
+ return vec, ok
+}
+
+func (db *memVectorDB) Search(target llm.Vector, n int) []VectorResult {
+ db.mu.RLock()
+ defer db.mu.RUnlock()
+ best := top.New(n, VectorResult.cmp)
+ for name, vec := range db.cache {
+ if len(vec) != len(target) {
+ continue
+ }
+ best.Add(VectorResult{name, target.Dot(vec)})
+ }
+ return best.Take()
+}
+
+func (db *memVectorDB) Flush() {
+ db.storage.Flush()
+}
+
+// memVectorBatch implements VectorBatch for a memVectorDB.
+type memVectorBatch struct {
+ db *memVectorDB // underlying memVectorDB
+ sb Batch // batch for underlying DB
+ w map[string]llm.Vector // vectors to write
+}
+
+func (db *memVectorDB) Batch() VectorBatch {
+ return &memVectorBatch{db, db.storage.Batch(), make(map[string]llm.Vector)}
+}
+
+func (b *memVectorBatch) Set(name string, vec llm.Vector) {
+ b.sb.Set(ordered.Encode("llm.Vector", b.db.namespace, name), vec.Encode())
+
+ b.w[name] = slices.Clone(vec)
+}
+
+func (b *memVectorBatch) MaybeApply() bool {
+ if !b.sb.MaybeApply() {
+ return false
+ }
+ b.Apply()
+ return true
+}
+
+func (b *memVectorBatch) Apply() {
+ b.sb.Apply()
+
+ b.db.mu.Lock()
+ defer b.db.mu.Unlock()
+
+ for name, vec := range b.w {
+ b.db.cache[name] = vec
+ }
+ clear(b.w)
+}
diff --git a/internal/storage/mem_test.go b/internal/storage/mem_test.go
new file mode 100644
index 0000000..b276051
--- /dev/null
+++ b/internal/storage/mem_test.go
@@ -0,0 +1,59 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "testing"
+
+ "golang.org/x/oscar/internal/testutil"
+)
+
+func TestMemDB(t *testing.T) {
+ TestDB(t, MemDB())
+}
+
+func TestMemVectorDB(t *testing.T) {
+ db := MemDB()
+ TestVectorDB(t, func() VectorDB { return MemVectorDB(db, testutil.Slogger(t), "") })
+}
+
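+// maybeDB wraps a DB so that its batches' MaybeApply reports the value
+// of the maybe field without applying anything itself, letting the test
+// below control what a memVectorBatch observes.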
+type maybeDB struct {
+ DB
+ maybe bool
+}
+
+type maybeBatch struct {
+ db *maybeDB
+ Batch
+}
+
+func (db *maybeDB) Batch() Batch {
+ return &maybeBatch{db: db, Batch: db.DB.Batch()}
+}
+
+func (b *maybeBatch) MaybeApply() bool {
+ return b.db.maybe
+}
+
+// Test that when the underlying db Batch's MaybeApply reports true,
+// the memVectorBatch MaybeApply applies the batched vector ops.
+func TestMemVectorBatchMaybeApply(t *testing.T) {
+ db := &maybeDB{DB: MemDB()}
+ vdb := MemVectorDB(db, testutil.Slogger(t), "")
+ b := vdb.Batch()
+ b.Set("apple3", embed("apple3"))
+ if _, ok := vdb.Get("apple3"); ok {
+ t.Errorf("Get(apple3) succeeded before batch apply")
+ }
+ b.MaybeApply() // should not apply because the db Batch does not apply
+ if _, ok := vdb.Get("apple3"); ok {
+ t.Errorf("Get(apple3) succeeded after MaybeApply that didn't apply")
+ }
+ db.maybe = true
+ b.MaybeApply() // now should apply
+ if _, ok := vdb.Get("apple3"); !ok {
+ t.Errorf("Get(apple3) failed after MaybeApply that did apply")
+ }
+}
diff --git a/internal/storage/test.go b/internal/storage/test.go
new file mode 100644
index 0000000..9a193a6
--- /dev/null
+++ b/internal/storage/test.go
@@ -0,0 +1,126 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "fmt"
+ "slices"
+ "sync"
+ "testing"
+
+ "rsc.io/ordered"
+)
+
+// TestDB runs basic tests on db.
+// The database should be empty when TestDB is called.
+func TestDB(t *testing.T, db DB) {
+ db.Set([]byte("key"), []byte("value"))
+ if val, ok := db.Get([]byte("key")); string(val) != "value" || ok != true {
+ // unreachable except for bad db
+ t.Fatalf("Get(key) = %q, %v, want %q, true", val, ok, "value")
+ }
+ if val, ok := db.Get([]byte("missing")); val != nil || ok != false {
+ // unreachable except for bad db
+ t.Fatalf("Get(missing) = %v, %v, want nil, false", val, ok)
+ }
+
+ db.Delete([]byte("key"))
+ if val, ok := db.Get([]byte("key")); val != nil || ok != false {
+ // unreachable except for bad db
+ t.Fatalf("Get(key) after delete = %v, %v, want nil, false", val, ok)
+ }
+
+ b := db.Batch()
+ for i := range 10 {
+ b.Set(ordered.Encode(i), []byte(fmt.Sprint(i)))
+ b.MaybeApply()
+ }
+ b.Apply()
+
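+ // collect scans keys min through max, decodes each key back into its
+ // integer, checks the stored value, and stops early after seeing stop
+ // (stop = -1 means scan everything).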
+ collect := func(min, max, stop int) []int {
+ t.Helper()
+ var list []int
+ for key, val := range db.Scan(ordered.Encode(min), ordered.Encode(max)) {
+ var i int
+ if err := ordered.Decode(key, &i); err != nil {
+ // unreachable except for bad db
+ t.Fatalf("db.Scan malformed key %v", Fmt(key))
+ }
+ if sv, want := string(val()), fmt.Sprint(i); sv != want {
+ // unreachable except for bad db
+ t.Fatalf("db.Scan key %v val=%q, want %q", i, sv, want)
+ }
+ list = append(list, i)
+ if i == stop {
+ break
+ }
+ }
+ return list
+ }
+
+ if scan, want := collect(3, 6, -1), []int{3, 4, 5, 6}; !slices.Equal(scan, want) {
+ // unreachable except for bad db
+ t.Fatalf("Scan(3, 6) = %v, want %v", scan, want)
+ }
+
+ if scan, want := collect(3, 6, 5), []int{3, 4, 5}; !slices.Equal(scan, want) {
+ // unreachable except for bad db
+ t.Fatalf("Scan(3, 6) with break at 5 = %v, want %v", scan, want)
+ }
+
+ db.DeleteRange(ordered.Encode(4), ordered.Encode(7))
+ if scan, want := collect(-1, 11, -1), []int{0, 1, 2, 3, 8, 9}; !slices.Equal(scan, want) {
+ // unreachable except for bad db
+ t.Fatalf("Scan(-1, 11) after Delete(4, 7) = %v, want %v", scan, want)
+ }
+
+ b = db.Batch()
+ for i := range 5 {
+ b.Delete(ordered.Encode(i))
+ b.Set(ordered.Encode(2*i), []byte(fmt.Sprint(2*i)))
+ }
+ b.DeleteRange(ordered.Encode(0), ordered.Encode(0))
+ b.Apply()
+ if scan, want := collect(-1, 11, -1), []int{6, 8, 9}; !slices.Equal(scan, want) {
+ // unreachable except for bad db
+ t.Fatalf("Scan(-1, 11) after batch Delete+Set = %v, want %v", scan, want)
+ }
+
+ // Can't test much, but check that it doesn't crash.
+ db.Flush()
+
+ testDBLock(t, db)
+}
+
+type locker interface {
+ Lock(string)
+ Unlock(string)
+}
+
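+// testDBLock checks that Lock provides mutual exclusion (the two writes
+// to x race unless Lock serializes them, so the race detector flags a
+// broken implementation) and that unlocking a never-locked name panics.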
+func testDBLock(t *testing.T, db locker) {
+ var x int
+ db.Lock("abc")
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ db.Lock("abc")
+ x = 2 // cause race if not synchronized
+ db.Unlock("abc")
+ wg.Done()
+ }()
+ x = 1 // cause race if not synchronized
+ db.Unlock("abc")
+ wg.Wait()
+ _ = x
+
+ func() {
+ defer func() {
+ recover()
+ }()
+ db.Unlock("def")
+ t.Errorf("Unlock never-locked key did not panic")
+ }()
+
+}
diff --git a/internal/storage/vectordb.go b/internal/storage/vectordb.go
new file mode 100644
index 0000000..b7d2b7a
--- /dev/null
+++ b/internal/storage/vectordb.go
@@ -0,0 +1,83 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "cmp"
+
+ "golang.org/x/oscar/internal/llm"
+)
+
+// A VectorDB is a vector database that implements
+// nearest-neighbor search over embedding vectors
+// corresponding to documents.
+type VectorDB interface {
+ // Set sets the vector associated with the given document ID to vec.
+ Set(id string, vec llm.Vector)
+
+ // TODO: Add Delete.
+
+ // Get gets the vector associated with the given document ID.
+ // If no such document exists, Get returns nil, false.
+ // If a document exists, Get returns vec, true.
+ Get(id string) (llm.Vector, bool)
+
+ // Batch returns a new [VectorBatch] that accumulates
+ // vector database mutations to apply in an atomic operation.
+ // It is more efficient than repeated calls to Set.
+ Batch() VectorBatch
+
+ // Search searches the database for the n vectors
+ // most similar to vec, returning the document IDs
+ // and similarity scores.
+ //
+ // Normally a VectorDB is used entirely with vectors of a single length.
+ // Search ignores stored vectors with a different length than vec.
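+ //
+ // For example (a sketch; query is an [llm.Vector] produced by the
+ // same embedder as the stored vectors):
+ //
+ // for _, r := range vdb.Search(query, 3) {
+ // 	fmt.Printf("%s %.3f\n", r.ID, r.Score)
+ // }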
+ Search(vec llm.Vector, n int) []VectorResult
+
+ // Flush flushes storage to disk.
+ Flush()
+}
+
+// A VectorBatch accumulates vector database mutations
+// that are applied to a [VectorDB] in a single atomic operation.
+// Applying bulk operations in a batch is also more efficient than
+// making individual [VectorDB] method calls.
+// The batched operations apply in the order they are made.
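+//
+// A bulk insert looks like this (a sketch; ids and vecs stand for the
+// caller's data):
+//
+// vb := vdb.Batch()
+// for i, id := range ids {
+// 	vb.Set(id, vecs[i])
+// 	vb.MaybeApply()
+// }
+// vb.Apply()
+// vdb.Flush()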
+type VectorBatch interface {
+ // Set sets the vector associated with the given document ID to vec.
+ Set(id string, vec llm.Vector)
+
+ // TODO: Add Delete.
+
+ // MaybeApply calls Apply if the VectorBatch is getting close to full.
+ // Every VectorBatch has a limit to how many operations can be batched,
+ // so in a bulk operation where atomicity of the entire batch is not a concern,
+ // calling MaybeApply gives the VectorBatch implementation
+ // permission to flush the batch at specific “safe points”.
+ // A typical limit for a batch is about 100MB worth of logged operations.
+ //
+ // MaybeApply reports whether it called Apply.
+ MaybeApply() bool
+
+ // Apply applies all the batched operations to the underlying VectorDB
+ // as a single atomic unit.
+ // When Apply returns, the VectorBatch is an empty batch ready for
+ // more operations.
+ Apply()
+}
+
+// A VectorResult is a single document returned by a VectorDB search.
+type VectorResult struct {
+ ID string // document ID
+ Score float64 // similarity score in range [0, 1]; 1 is exact match
+}
+
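+// cmp compares x and y, ordering first by Score and then by ID.
+// Search implementations can use it with a package such as [rsc.io/top]
+// to keep the highest-scoring results, as the in-memory implementation does.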
+func (x VectorResult) cmp(y VectorResult) int {
+ if x.Score != y.Score {
+ return cmp.Compare(x.Score, y.Score)
+ }
+ return cmp.Compare(x.ID, y.ID)
+}
diff --git a/internal/storage/vectordb_test.go b/internal/storage/vectordb_test.go
new file mode 100644
index 0000000..2348120
--- /dev/null
+++ b/internal/storage/vectordb_test.go
@@ -0,0 +1,30 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import "testing"
+
+func TestVectorResultCompare(t *testing.T) {
+ type R = VectorResult
+ var tests = []struct {
+ x, y VectorResult
+ cmp int
+ }{
+ {R{"b", 0.5}, R{"c", 0.5}, -1},
+ {R{"b", 0.4}, R{"a", 0.5}, -1},
+ }
+
+ try := func(x, y VectorResult, cmp int) {
+ if c := x.cmp(y); c != cmp {
+ t.Errorf("Compare(%v, %v) = %d, want %d", x, y, c, cmp)
+ }
+ }
+ for _, tt := range tests {
+ try(tt.x, tt.x, 0)
+ try(tt.y, tt.y, 0)
+ try(tt.x, tt.y, tt.cmp)
+ try(tt.y, tt.x, -tt.cmp)
+ }
+}
diff --git a/internal/storage/vtest.go b/internal/storage/vtest.go
new file mode 100644
index 0000000..a19f6ad
--- /dev/null
+++ b/internal/storage/vtest.go
@@ -0,0 +1,74 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "math"
+ "reflect"
+ "slices"
+ "testing"
+
+ "golang.org/x/oscar/internal/llm"
+)
+
+func TestVectorDB(t *testing.T, newdb func() VectorDB) {
+ vdb := newdb()
+
+ vdb.Set("orange2", embed("orange2"))
+ vdb.Set("orange1", embed("orange1"))
+ b := vdb.Batch()
+ b.Set("apple3", embed("apple3"))
+ b.Set("apple4", embed("apple4"))
+ b.Set("ignore", embed("bad")[:4])
+ b.Apply()
+
+ v, ok := vdb.Get("apple3")
+ if !ok || !slices.Equal(v, embed("apple3")) {
+ // unreachable except bad vectordb
+ t.Errorf("Get(apple3) = %v, %v, want %v, true", v, ok, embed("apple3"))
+ }
+
+ want := []VectorResult{
+ {"apple4", 0.9999961187341375},
+ {"apple3", 0.9999843342970269},
+ {"orange1", 0.38062230442542155},
+ {"orange2", 0.3785152783773009},
+ }
+ have := vdb.Search(embed("apple5"), 5)
+ if !reflect.DeepEqual(have, want) {
+ // unreachable except bad vectordb
+ t.Fatalf("Search(apple5, 5):\nhave %v\nwant %v", have, want)
+ }
+
+ vdb.Flush()
+
+ vdb = newdb()
+ have = vdb.Search(embed("apple5"), 3)
+ want = want[:3]
+ if !reflect.DeepEqual(have, want) {
+ // unreachable except bad vectordb
+ t.Errorf("Search(apple5, 3) in fresh database:\nhave %v\nwant %v", have, want)
+ }
+
+}
+
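+// embed returns a deterministic unit vector derived from text,
+// standing in for a real embedding model in tests.
+// Texts that share a prefix (such as "apple3" and "apple4") map to
+// nearby vectors and therefore score as similar in Search.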
+func embed(text string) llm.Vector {
+ const vectorLen = 16
+ v := make(llm.Vector, vectorLen)
+ d := float32(0)
+ for i := range len(text) {
+ v[i] = float32(byte(text[i])) / 256
+ d += float32(v[i] * v[i]) // float32() to avoid FMA
+ }
+ if len(text) < len(v) {
+ v[len(text)] = -1
+ d += 1
+ }
+ d = float32(1 / math.Sqrt(float64(d)))
+ for i, x := range v {
+ v[i] = x * d
+ }
+ return v
+}