| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package storage |
| |
| import ( |
| "math" |
| "reflect" |
| "slices" |
| "testing" |
| |
| "golang.org/x/oscar/internal/llm" |
| "golang.org/x/oscar/internal/testutil" |
| ) |
| |
| // TestVectorDB verifies that implementations of [VectorDB] |
| // conform to its specification. |
| // The opendb function should create a new connection to the same underlying storage. |
| func TestVectorDB(t *testing.T, opendb func() VectorDB) { |
| vdb := opendb() |
| |
| vdb.Set("orange2", embed("orange2")) |
| vdb.Set("orange1", embed("orange1")) |
| vdb.Set("orange1alias", embed("orange1")) |
| vdb.Delete("orange1alias") |
| vdb.Delete("orange1") |
| vdb.Set("orange1", embed("orange1")) |
| |
| testutil.StopPanic(func() { |
| vdb.Set("", llm.Vector{}) |
| t.Fatalf("Set with empty key did not panic") |
| }) |
| |
| haveAll := allIDs(vdb) |
| wantAll := []string{"orange1", "orange2"} |
| if !reflect.DeepEqual(haveAll, wantAll) { |
| t.Fatalf("All(): have %v;\nwant %v", haveAll, wantAll) |
| } |
| |
| vdb.Set("apple1", embed("apple1")) |
| b := vdb.Batch() |
| b.Delete("apple1") |
| b.Set("apple3", embed("apple3")) |
| b.Set("apple4", embed("apple4")) |
| b.Set("apple4alias", embed("apple4")) |
| b.Delete("apple4alias") |
| b.Set("ignore", embed("bad")[:4]) |
| b.Set("orange3", embed("orange3")) |
| b.Delete("orange3") |
| b.Delete("orange4") |
| b.Set("orange4", embed("orange4")) |
| b.Apply() |
| |
| testutil.StopPanic(func() { |
| b.Set("", llm.Vector{}) |
| t.Fatalf("Batch.Set with empty key did not panic") |
| }) |
| |
| // Check that batch.Apply clears the batch. |
| b = vdb.Batch() |
| b.Set("apple5", embed("apple5")) |
| b.Apply() |
| vdb.Delete("apple5") |
| b.Apply() // should be a no-op |
| if _, ok := vdb.Get("apple5"); ok { |
| t.Fatalf("empty Apply should be no-op, but got previous value") |
| } |
| |
| haveAll = allIDs(vdb) |
| wantAll = []string{"apple3", "apple4", "ignore", "orange1", "orange2", "orange4"} |
| if !reflect.DeepEqual(haveAll, wantAll) { |
| t.Fatalf("All(): have %v;\nwant %v", haveAll, wantAll) |
| } |
| |
| v, ok := vdb.Get("apple3") |
| if !ok || !slices.Equal(v, embed("apple3")) { |
| // unreachable except bad vectordb |
| t.Errorf("Get(apple3) = %v, %v, want %v, true", v, ok, embed("apple3")) |
| } |
| |
| want := []VectorResult{ |
| {"apple4", 0.9999961187341375}, |
| {"apple3", 0.9999843342970269}, |
| {"orange1", 0.38062230442542155}, |
| {"orange2", 0.3785152783773009}, |
| {"orange4", 0.37429777504303363}, |
| } |
| have := vdb.Search(embed("apple5"), 5) |
| if !reflect.DeepEqual(have, want) { |
| // unreachable except bad vectordb |
| t.Fatalf("Search(apple5, 5):\nhave %v\nwant %v", have, want) |
| } |
| |
| vdb.Flush() |
| |
| vdb = opendb() |
| have = vdb.Search(embed("apple5"), 3) |
| want = want[:3] |
| if !reflect.DeepEqual(have, want) { |
| // unreachable except bad vectordb |
| t.Errorf("Search(apple5, 3) in fresh database:\nhave %v\nwant %v", have, want) |
| } |
| } |
| |
| func allIDs(vdb VectorDB) []string { |
| var all []string |
| for k := range vdb.All() { |
| all = append(all, k) |
| } |
| return all |
| } |
| |
| func embed(text string) llm.Vector { |
| const vectorLen = 16 |
| v := make(llm.Vector, vectorLen) |
| d := float32(0) |
| for i := range len(text) { |
| v[i] = float32(byte(text[i])) / 256 |
| d += float32(v[i] * v[i]) // float32() to avoid FMA |
| } |
| if len(text) < len(v) { |
| v[len(text)] = -1 |
| d += 1 |
| } |
| d = float32(1 / math.Sqrt(float64(d))) |
| for i, x := range v { |
| v[i] = x * d |
| } |
| return v |
| } |