blob: 5eb4f188b36c03dfedfdf7b37439d6afdd5228a3 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package search
import (
"context"
"encoding/json"
"fmt"
"reflect"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/oscar/internal/docs"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
// TestKind checks the kind string that docIDKind reports for a set of
// document IDs, including an ID that is not a recognized URL.
func TestKind(t *testing.T) {
	cases := []struct {
		id   string
		want string
	}{
		{id: "something", want: ""},
		{id: "https://go.dev/x", want: "GoDevPage"},
		{id: "https://go.dev/blog/xxx", want: "GoBlog"},
		{id: "https://go.dev/doc/x", want: "GoDocumentation"},
		{id: "https://go.dev/ref/x", want: "GoReference"},
		{id: "https://go.dev/wiki/x", want: "GoWiki"},
		{id: "https://github.com/golang/go/issues/123", want: "GitHubIssue"},
	}
	for _, tc := range cases {
		if got := docIDKind(tc.id); got != tc.want {
			t.Errorf("%q: got %q, want %q", tc.id, got, tc.want)
		}
	}
}
// TestSearch verifies that Query (which embeds the document itself) and
// Vector (which is handed a precomputed vector) return the same ranked
// results for the same document, and that raising the threshold narrows
// both result sets.
func TestSearch(t *testing.T) {
	ctx := context.Background()
	lg := testutil.Slogger(t)
	embedder := llm.QuoteEmbedder()
	db := storage.MemDB()
	vdb := storage.MemVectorDB(db, lg, "")
	corpus := docs.New(lg, db)

	// Populate the corpus and the vector DB with ten documents, id0..id9.
	for i := 0; i < 10; i++ {
		id := fmt.Sprintf("id%d", i)
		doc := llm.EmbedDoc{Title: fmt.Sprintf("title%d", i), Text: fmt.Sprintf("text-%s", strings.Repeat("x", i))}
		corpus.Add(id, doc.Title, doc.Text)
		vec := mustEmbed(t, embedder, doc)
		vdb.Set(id, vec)
	}

	opts := Options{
		Threshold: 0,
		Limit:     2,
	}
	// The search document is identical to the one stored under id3.
	doc := llm.EmbedDoc{Title: "title3", Text: "text-xxx"}

	qreq := &QueryRequest{
		Options:  opts,
		EmbedDoc: doc,
	}
	gotQ, err := Query(ctx, vdb, corpus, embedder, qreq)
	if err != nil {
		t.Fatal(err)
	}
	round(gotQ)

	vreq := &VectorRequest{
		Options: opts,
		Vector:  mustEmbed(t, embedder, doc),
	}
	gotV := Vector(vdb, corpus, vreq)
	round(gotV)

	want := []Result{
		{
			Kind:         "",
			Title:        "title3",
			VectorResult: storage.VectorResult{ID: "id3", Score: 1.0},
		},
		{
			Kind:         "",
			Title:        "title4",
			VectorResult: storage.VectorResult{ID: "id4", Score: 0.56},
		},
	}

	if !slices.Equal(gotQ, want) {
		t.Errorf("Query: got %v\nwant %v", gotQ, want)
	}
	// Report the Vector results here, not the Query results.
	if !slices.Equal(gotV, want) {
		t.Errorf("Vector: got %v\nwant %v", gotV, want)
	}

	// With a threshold of 0.9 only the exact match (score 1.0) should remain.
	qreq.Threshold = 0.9
	gotQ, err = Query(ctx, vdb, corpus, embedder, qreq)
	if err != nil {
		t.Fatal(err)
	}
	if len(gotQ) != 1 {
		t.Errorf("Query: got %d results, want 1", len(gotQ))
	}

	vreq.Threshold = 0.9
	gotV = Vector(vdb, corpus, vreq)
	if len(gotV) != 1 {
		t.Errorf("Vector: got %d results, want 1", len(gotV))
	}
}
// round calls Round on every element of rs, updating the slice in place.
func round(rs []Result) {
	for i := 0; i < len(rs); i++ {
		r := &rs[i]
		r.Round()
	}
}
// TestOptions runs Query against a ten-document corpus under different
// combinations of Threshold, Limit, AllowKind and DenyKind, and compares
// the (rounded) results with the expected slice of the full ranking.
func TestOptions(t *testing.T) {
	ctx := context.Background()
	lg := testutil.Slogger(t)
	embedder := llm.QuoteEmbedder()
	db := storage.MemDB()
	vdb := storage.MemVectorDB(db, lg, "")
	corpus := docs.New(lg, db)

	// Document IDs; the position of each ID determines its title and text.
	docIDs := []string{
		0: "https://go.dev/blog/topic",
		1: "https://github.com/golang/go/issues/11",
		2: "not-a-url",
		3: "https://go.dev/doc/something",
		4: "https://go.dev/ref/something",
		5: "https://go.dev/page",
		6: "https://go.dev/blog/another/topic",
		7: "https://github.com/golang/go/issues/42",
		8: "https://go.dev/wiki/something",
		9: "https://pkg.go.dev/",
	}
	for n, id := range docIDs {
		title := fmt.Sprintf("title%d", n)
		text := fmt.Sprintf("text-%s", strings.Repeat("x", n))
		corpus.Add(id, title, text)
		vdb.Set(id, mustEmbed(t, embedder, llm.EmbedDoc{Title: title, Text: text}))
	}

	// The search document matches the one stored under docIDs[3].
	query := llm.EmbedDoc{Title: "title3", Text: "text-xxx"}

	// ranked is the complete expected result list, best score first.
	ranked := []Result{
		0: {
			Kind:         KindGoDocumentation,
			Title:        "title3",
			VectorResult: storage.VectorResult{ID: docIDs[3], Score: 1.0},
		},
		1: {
			Kind:         KindGoReference,
			Title:        "title4",
			VectorResult: storage.VectorResult{ID: docIDs[4], Score: 0.56},
		},
		2: {
			Kind:         KindGoDevPage,
			Title:        "title5",
			VectorResult: storage.VectorResult{ID: docIDs[5], Score: 0.544},
		},
		3: {
			Kind:         "", // unknown
			Title:        "title2",
			VectorResult: storage.VectorResult{ID: docIDs[2], Score: 0.531},
		},
		4: {
			Kind:         KindGoBlog,
			Title:        "title6",
			VectorResult: storage.VectorResult{ID: docIDs[6], Score: 0.529},
		},
		5: {
			Kind:         KindGitHubIssue,
			Title:        "title7",
			VectorResult: storage.VectorResult{ID: docIDs[7], Score: 0.516},
		},
		6: {
			Kind:         KindGoWiki,
			Title:        "title8",
			VectorResult: storage.VectorResult{ID: docIDs[8], Score: 0.503},
		},
		7: {
			Kind:         "", // unknown
			Title:        "title9",
			VectorResult: storage.VectorResult{ID: docIDs[9], Score: 0.492},
		},
		8: {
			Kind:         KindGitHubIssue,
			Title:        "title1",
			VectorResult: storage.VectorResult{ID: docIDs[1], Score: 0.483},
		},
		9: {
			Kind:         KindGoBlog,
			Title:        "title0",
			VectorResult: storage.VectorResult{ID: docIDs[0], Score: 0.431},
		},
	}

	cases := []struct {
		name    string
		options Options
		want    []Result
	}{
		{
			name: "no options",
			want: ranked,
		},
		{
			name:    "threshold",
			options: Options{Threshold: .5},
			want:    ranked[:7],
		},
		{
			name:    "limit",
			options: Options{Limit: 5},
			want:    ranked[:5],
		},
		{
			// The limit is the tighter bound here.
			name:    "limit-threshold",
			options: Options{Threshold: .5, Limit: 5},
			want:    ranked[:5],
		},
		{
			// The threshold is the tighter bound here.
			name:    "threshold-limit",
			options: Options{Threshold: .5, Limit: 10},
			want:    ranked[:7],
		},
		{
			name: "allow",
			options: Options{
				AllowKind: []string{KindGoWiki, KindGitHubIssue},
			},
			want: []Result{
				ranked[5], // issue
				ranked[6], // wiki
				ranked[8], // issue
			},
		},
		{
			// Of the first six ranked results only ranked[5] is an issue
			// or a wiki page, so the expectation implies that the limit
			// is applied before the kind filter.
			name: "allow-limit",
			options: Options{
				AllowKind: []string{KindGoWiki, KindGitHubIssue},
				Limit:     6,
			},
			want: []Result{ranked[5]},
		},
		{
			name: "allow-threshold",
			options: Options{
				AllowKind: []string{KindGoWiki, KindGitHubIssue},
				Threshold: .5,
			},
			want: []Result{ranked[5], ranked[6]},
		},
		{
			name: "deny",
			options: Options{
				DenyKind: []string{KindGoWiki, KindGitHubIssue},
			},
			want: []Result{
				ranked[0], ranked[1], ranked[2], ranked[3], ranked[4],
				// 5 (issue) and 6 (wiki) are denied
				ranked[7],
				// 8 (issue) is denied
				ranked[9],
			},
		},
		{
			// Issues are both allowed and denied, so only wikis remain.
			name: "allow-deny",
			options: Options{
				AllowKind: []string{KindGoWiki, KindGitHubIssue},
				DenyKind:  []string{KindGitHubIssue},
			},
			want: []Result{ranked[6]},
		},
	}

	for _, test := range cases {
		t.Run(test.name, func(t *testing.T) {
			req := &QueryRequest{
				Options:  test.options,
				EmbedDoc: query,
			}
			got, err := Query(ctx, vdb, corpus, embedder, req)
			if err != nil {
				t.Fatal(err)
			}
			round(got)
			diff := cmp.Diff(test.want, got)
			if diff == "" {
				return
			}
			t.Errorf("Query() mismatch (-want +got):\n%s", diff)
		})
	}
}
// mustEmbed embeds doc with embedder and returns the first vector of the
// reply. Any error from EmbedDocs fails the test immediately.
func mustEmbed(t *testing.T, embedder llm.Embedder, doc llm.EmbedDoc) llm.Vector {
	t.Helper()
	batch := []llm.EmbedDoc{doc}
	vecs, err := embedder.EmbedDocs(context.Background(), batch)
	if err != nil {
		t.Fatal(err)
	}
	return vecs[0]
}
// TestSearchJSON checks the JSON forms used by search: a flat request body
// must decode into a QueryRequest, and a slice of Result must encode to the
// expected flat objects.
func TestSearchJSON(t *testing.T) {
	// Decoding a request.
	const body = `{"Limit": 10, "Threshold": 0.8, "AllowKind": ["GoWiki"], "DenyKind": [""], "Title": "t", "Text": "text"}`
	var req QueryRequest
	err := json.Unmarshal([]byte(body), &req)
	if err != nil {
		t.Fatal(err)
	}
	wantReq := QueryRequest{
		Options: Options{
			Limit:     10,
			Threshold: 0.8,
			AllowKind: []string{"GoWiki"},
			DenyKind:  []string{""},
		},
		EmbedDoc: llm.EmbedDoc{Title: "t", Text: "text"},
	}
	if !reflect.DeepEqual(req, wantReq) {
		t.Errorf("got %+v, want %+v", req, wantReq)
	}

	// Encoding a response.
	results := []Result{
		{
			Kind:         "K",
			Title:        "t",
			VectorResult: storage.VectorResult{ID: "id", Score: .5},
		},
	}
	data, err := json.Marshal(results)
	if err != nil {
		t.Fatal(err)
	}
	const want = `[{"Kind":"K","Title":"t","ID":"id","Score":0.5}]`
	if got := string(data); got != want {
		t.Errorf("\ngot %s\nwant %s", got, want)
	}
}