blob: a92def8b629d4fa804a6d674d3ce665745f2b79c [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"slices"
"strings"
"testing"
"golang.org/x/oscar/internal/docs"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
func TestSearchPageTemplate(t *testing.T) {
page := searchPage{
Query: "some query",
Results: []searchResult{
{
Kind: "Example",
Title: "t1",
VectorResult: storage.VectorResult{
ID: "https://example.com/x",
Score: 0.987654321,
},
},
{
Kind: "",
VectorResult: storage.VectorResult{
ID: "https://example.com/y",
Score: 0.876,
},
},
},
}
var buf bytes.Buffer
if err := searchPageTmpl.Execute(&buf, page); err != nil {
t.Fatal(err)
}
wants := []string{page.Query}
for _, sr := range page.Results {
wants = append(wants, sr.VectorResult.ID)
}
got := buf.String()
t.Logf("%s", got)
for _, w := range wants {
if !strings.Contains(got, w) {
t.Errorf("did not find %q in HTML", w)
}
}
}
// TestKind checks the kind string that docIDKind reports for a set of
// document IDs, including one ID for which the empty kind is expected.
func TestKind(t *testing.T) {
	cases := []struct {
		id   string // document ID passed to docIDKind
		want string // expected kind
	}{
		{id: "something", want: ""},
		{id: "https://go.dev/x", want: "GoDevPage"},
		{id: "https://go.dev/blog/xxx", want: "GoBlog"},
		{id: "https://go.dev/doc/x", want: "GoDocumentation"},
		{id: "https://go.dev/ref/x", want: "GoReference"},
		{id: "https://go.dev/wiki/x", want: "GoWiki"},
		{id: "https://github.com/golang/go/issues/123", want: "GitHubIssue"},
	}
	for _, tc := range cases {
		if got := docIDKind(tc.id); got != tc.want {
			t.Errorf("%q: got %q, want %q", tc.id, got, tc.want)
		}
	}
}
// TestSearch populates an in-memory corpus and vector DB with ten
// documents, then runs Gaby.search twice: once with no threshold and a
// limit of two, comparing the rounded results exactly, and once with a
// threshold of 0.9, expecting a single result.
func TestSearch(t *testing.T) {
	ctx := context.Background()
	logger := testutil.Slogger(t)
	emb := llm.QuoteEmbedder()
	memDB := storage.MemDB()
	vectors := storage.MemVectorDB(memDB, logger, "")
	corpus := docs.New(memDB)

	// Document n has title "title<n>" and text "text-" followed by n x's.
	for n := 0; n < 10; n++ {
		id := fmt.Sprintf("id%d", n)
		doc := llm.EmbedDoc{
			Title: fmt.Sprintf("title%d", n),
			Text:  "text-" + strings.Repeat("x", n),
		}
		corpus.Add(id, doc.Title, doc.Text)
		vecs, err := emb.EmbedDocs(ctx, []llm.EmbedDoc{doc})
		if err != nil {
			t.Fatal(err)
		}
		vdbVec := vecs[0]
		vectors.Set(id, vdbVec)
	}

	g := &Gaby{
		embed:  emb,
		db:     memDB,
		vector: vectors,
		docs:   corpus,
	}

	// The query is identical to document 3.
	query := llm.EmbedDoc{Title: "title3", Text: "text-xxx"}
	req := &searchRequest{
		Threshold: 0,
		Limit:     2,
		EmbedDoc:  query,
	}

	got, err := g.search(ctx, req)
	if err != nil {
		t.Fatal(err)
	}
	// Round each result in place before the exact comparison below.
	for i := range got {
		got[i].round()
	}
	want := []searchResult{
		{
			Kind:         "",
			Title:        "title3",
			VectorResult: storage.VectorResult{ID: "id3", Score: 1.0},
		},
		{
			Kind:         "",
			Title:        "title4",
			VectorResult: storage.VectorResult{ID: "id4", Score: 0.56},
		},
	}
	if !slices.Equal(got, want) {
		t.Errorf("got %v\nwant %v", got, want)
	}

	// With a threshold of 0.9 only one result is expected.
	req.Threshold = 0.9
	got, err = g.search(ctx, req)
	if err != nil {
		t.Fatal(err)
	}
	if n := len(got); n != 1 {
		t.Errorf("got %d results, want 1", n)
	}
}
// TestSearchJSON confirms that a search request can be unmarshaled from
// a JSON POST body and that a slice of search results marshals to the
// expected JSON.
func TestSearchJSON(t *testing.T) {
	// Title and Text sit at the top level of the JSON object; they are
	// expected to land in the request's EmbedDoc.
	postBody := `{"Limit": 10, "Threshold": 0.8, "Title": "t", "Text": "text"}`
	var gotReq searchRequest
	if err := json.Unmarshal([]byte(postBody), &gotReq); err != nil {
		t.Fatal(err)
	}
	wantReq := searchRequest{Limit: 10, Threshold: 0.8, EmbedDoc: llm.EmbedDoc{Title: "t", Text: "text"}}
	if gotReq != wantReq {
		t.Errorf("got %+v, want %+v", gotReq, wantReq)
	}

	res := []searchResult{
		{Kind: "K", Title: "t", VectorResult: storage.VectorResult{ID: "id", Score: .5}},
	}
	// Named data rather than bytes so the local does not shadow the
	// imported bytes package used elsewhere in this file.
	data, err := json.Marshal(res)
	if err != nil {
		t.Fatal(err)
	}
	got := string(data)
	// ID and Score are expected at the top level of each result object.
	want := `[{"Kind":"K","Title":"t","ID":"id","Score":0.5}]`
	if got != want {
		t.Errorf("\ngot %s\nwant %s", got, want)
	}
}