blob: d38446454b09219f074b603cb9a5f68fb4863554 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package embeddocs
import (
"context"
"fmt"
"testing"
"golang.org/x/oscar/internal/docs"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
// texts is the small fixed corpus used by TestSync. Each entry's slice index
// is baked into its document IDs ("URL%d" and "rot13%d"), so reordering or
// inserting entries changes which ID maps to which text.
var texts = []string{
"for loops",
"for all time, always",
"break statements",
"breakdancing",
"forever could never be long enough for me",
"the macarena",
}
// checker returns a callback that aborts the test via t.Fatal when handed a
// non-nil error and is a no-op for nil. The callback marks itself as a test
// helper so failures are attributed to the caller's line.
//
// NOTE(review): the tests visible here use testutil.Checker instead; this
// helper may be unused in this file.
func checker(t *testing.T) func(error) {
	return func(err error) {
		if err == nil {
			return
		}
		t.Helper()
		t.Fatal(err)
	}
}
// ctx is the background context passed to Sync (and to the fake embedders)
// by the tests in this file.
var ctx = context.Background()
// TestSync checks incremental embedding.
//
// The first Sync must store a vector for every document built from texts.
// Rot13-encoded copies are then added to the corpus. A second Sync writes
// into a new vector DB namespace ("step2") that shares the same underlying
// db. That Sync must embed only the new rot13 documents and must not
// re-embed the URL documents handled by the first pass.
func TestSync(t *testing.T) {
	check := testutil.Checker(t)
	lg := testutil.Slogger(t)
	db := storage.MemDB()
	vdb := storage.MemVectorDB(db, lg, "step1")
	dc := docs.New(lg, db)
	for i, text := range texts {
		dc.Add(fmt.Sprintf("URL%d", i), "", text)
	}
	check(Sync(ctx, lg, vdb, llm.QuoteEmbedder(), dc))

	// The assertions rely on llm.UnquoteVector recovering the text that
	// llm.QuoteEmbedder encoded into the vector.
	for i, text := range texts {
		vec, ok := vdb.Get(fmt.Sprintf("URL%d", i))
		if !ok {
			t.Errorf("URL%d missing from vdb", i)
			continue
		}
		if vtext := llm.UnquoteVector(vec); vtext != text {
			t.Errorf("URL%d decoded to %q, want %q", i, vtext, text)
		}
	}

	for i, text := range texts {
		dc.Add(fmt.Sprintf("rot13%d", i), "", testutil.Rot13(text))
	}
	vdb2 := storage.MemVectorDB(db, lg, "step2")
	check(Sync(ctx, lg, vdb2, llm.QuoteEmbedder(), dc))
	for i, text := range texts {
		if vec, ok := vdb2.Get(fmt.Sprintf("URL%d", i)); ok {
			t.Errorf("URL%d written during second sync: %q", i, llm.UnquoteVector(vec))
			continue
		}
		vec, ok := vdb2.Get(fmt.Sprintf("rot13%d", i))
		if !ok {
			// Without this check a missing entry falls through to decoding
			// a zero-value vector. It would then be misreported as a decode
			// mismatch instead of a missing document.
			t.Errorf("rot13%d missing from vdb2", i)
			continue
		}
		want := testutil.Rot13(text)
		if vtext := llm.UnquoteVector(vec); vtext != want {
			t.Errorf("rot13%d decoded to %q, want %q", i, vtext, want)
		}
	}
}
// TestBigSync syncs a corpus of N generated documents ("URL<n>" -> "Text<n>")
// and verifies that every one of them ends up in the vector DB with a vector
// that decodes back to its original text.
func TestBigSync(t *testing.T) {
	const N = 10000
	check := testutil.Checker(t)
	lg := testutil.Slogger(t)
	db := storage.MemDB()
	vdb := storage.MemVectorDB(db, lg, "vdb")
	dc := docs.New(lg, db)

	for n := 0; n < N; n++ {
		dc.Add(fmt.Sprintf("URL%d", n), "", fmt.Sprintf("Text%d", n))
	}
	check(Sync(ctx, lg, vdb, llm.QuoteEmbedder(), dc))

	for n := 0; n < N; n++ {
		want := fmt.Sprintf("Text%d", n)
		vec, found := vdb.Get(fmt.Sprintf("URL%d", n))
		if !found {
			t.Errorf("URL%d missing from vdb", n)
			continue
		}
		if got := llm.UnquoteVector(vec); got != want {
			t.Errorf("URL%d decoded to %q, want %q", n, got, want)
		}
	}
}
// TestBadEmbedders runs Sync against embedders that misbehave and checks that
// each misbehavior is reported as an error. Each case syncs into a fresh,
// empty vector DB. For the embedErr and embedHalf cases the test also
// requires that document "URL001" was still written despite the error.
func TestBadEmbedders(t *testing.T) {
	const N = 150
	lg := testutil.Slogger(t)
	dc := docs.New(lg, storage.MemDB())
	for i := range N {
		dc.Add(fmt.Sprintf("URL%03d", i), "", fmt.Sprintf("Text%d", i))
	}

	cases := []struct {
		embed      llm.Embedder
		noErrMsg   string // reported when Sync unexpectedly succeeds
		missingMsg string // reported when URL001 is absent; "" skips the check
	}{
		{tooManyEmbed{}, "tooManyEmbed did not report error", ""},
		{embedErr{}, "embedErr did not report error", "Sync did not write URL001 after embedErr"},
		{embedHalf{}, "embedHalf did not report error", "Sync did not write URL001 after embedHalf"},
	}
	for _, tc := range cases {
		vdb := storage.MemVectorDB(storage.MemDB(), lg, "vdb")
		if err := Sync(ctx, lg, vdb, tc.embed, dc); err == nil {
			t.Error(tc.noErrMsg)
		}
		if tc.missingMsg == "" {
			continue
		}
		if _, ok := vdb.Get("URL001"); !ok {
			t.Error(tc.missingMsg)
		}
	}
}
// tooManyEmbed is a fake embedder that returns twice as many vectors as
// documents it was given, with no error.
type tooManyEmbed struct{}

func (tooManyEmbed) EmbedDocs(ctx context.Context, docs []llm.EmbedDoc) ([]llm.Vector, error) {
	// The QuoteEmbedder error is deliberately ignored: this fake only needs
	// some vectors to duplicate.
	quoted, _ := llm.QuoteEmbedder().EmbedDocs(ctx, docs)
	return append(quoted, quoted...), nil
}
// embedErr is a fake embedder that returns a full set of vectors together
// with a non-nil error.
type embedErr struct{}

func (embedErr) EmbedDocs(ctx context.Context, docs []llm.EmbedDoc) ([]llm.Vector, error) {
	// The QuoteEmbedder error is deliberately ignored; this fake always
	// reports its own error instead.
	quoted, _ := llm.QuoteEmbedder().EmbedDocs(ctx, docs)
	failure := fmt.Errorf("EMBED ERROR")
	return quoted, failure
}
// embedHalf is a fake embedder that returns vectors for only the first half
// of the documents (rounded down), with no error.
type embedHalf struct{}

func (embedHalf) EmbedDocs(ctx context.Context, docs []llm.EmbedDoc) ([]llm.Vector, error) {
	// The QuoteEmbedder error is deliberately ignored: this fake only needs
	// some vectors to truncate.
	quoted, _ := llm.QuoteEmbedder().EmbedDocs(ctx, docs)
	half := len(quoted) / 2
	return quoted[:half], nil
}