internal/gaby: various search improvements
- Add search-by-ID for documents in the database
- Add basic input validation (and display an error for invalid inputs)
- Add tests
For golang/oscar#32
Change-Id: Ib9f756696d4d4e5259d5a08d93bd0e57addb5b46
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/616856
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/gaby/search.go b/internal/gaby/search.go
index aedb5c3..d44d016 100644
--- a/internal/gaby/search.go
+++ b/internal/gaby/search.go
@@ -6,7 +6,9 @@
import (
"bytes"
+ "context"
"encoding/json"
+ "fmt"
"io"
"net/http"
"strconv"
@@ -16,80 +18,146 @@
"golang.org/x/oscar/internal/search"
)
+// A searchPage holds the fields needed to display the results
+// of a search.
type searchPage struct {
- Query string
- search.Options
- // allowlist and denylist as comma-separated strings (for display)
- AllowStr, DenyStr string
- Results []search.Result
+ searchForm // the raw query and options
+ Results []search.Result // the search results to display
+ SearchError string // if non-empty, the error to display instead of results
}
func (g *Gaby) handleSearch(w http.ResponseWriter, r *http.Request) {
- data, err := g.doSearch(r)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- } else {
- _, _ = w.Write(data)
- }
-}
-
-// doSearch returns the contents of the vector search page.
-func (g *Gaby) doSearch(r *http.Request) ([]byte, error) {
- page := populatePage(r)
- if page.Query != "" {
- var err error
- page.Results, err = search.Query(r.Context(), g.vector, g.docs, g.embed,
- &search.QueryRequest{
- EmbedDoc: llm.EmbedDoc{Text: page.Query},
- Options: page.Options,
- })
- if err != nil {
- return nil, err
- }
- for i := range page.Results {
- page.Results[i].Round()
- }
- }
+ page := g.populatePage(r)
var buf bytes.Buffer
if err := searchPageTmpl.Execute(&buf, page); err != nil {
- return nil, err
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
}
- return buf.Bytes(), nil
+ _, _ = w.Write(buf.Bytes())
}
-// populatePage parses the form values into a searchPage.
-// TODO(tatianabradley): Add error handling for malformed
-// filters and trim spaces in inputs.
-func populatePage(r *http.Request) searchPage {
- threshold, err := strconv.ParseFloat(r.FormValue("threshold"), 64)
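+// populatePage parses the search form values from r, runs the search,
+// and returns the data for the vector search page. An invalid form value
+// or a failed search is reported in the SearchError field instead of Results.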
+func (g *Gaby) populatePage(r *http.Request) searchPage {
+ form := searchForm{
+ Query: r.FormValue("q"),
+ Threshold: r.FormValue("threshold"),
+ Limit: r.FormValue("limit"),
+ Allow: r.FormValue("allow_kind"),
+ Deny: r.FormValue("deny_kind"),
+ }
+ opts, err := form.toOptions()
if err != nil {
- threshold = 0
+ return searchPage{
+ searchForm: form,
+ SearchError: fmt.Errorf("invalid form value: %w", err).Error(),
+ }
}
- limit, err := strconv.Atoi(r.FormValue("limit"))
+ q := strings.TrimSpace(form.Query)
+ results, err := g.search(r.Context(), q, *opts)
if err != nil {
- limit = 20
- }
- var allow, deny []string
- var allowStr, denyStr string
- if allowStr = r.FormValue("allow_kind"); allowStr != "" {
- allow = strings.Split(allowStr, ",")
- }
- if denyStr = r.FormValue("deny_kind"); denyStr != "" {
- deny = strings.Split(denyStr, ",")
+ return searchPage{
+ searchForm: form,
+ SearchError: fmt.Errorf("search: %w", err).Error(),
+ }
}
return searchPage{
- Query: r.FormValue("q"),
- Options: search.Options{
- Limit: limit,
- Threshold: threshold,
- AllowKind: allow,
- DenyKind: deny,
- },
- AllowStr: allowStr,
- DenyStr: denyStr,
+ searchForm: form,
+ Results: results,
}
}
+// search runs a nearest-neighbor search for q, filtered by opts.
+//
+// An empty q produces no results and no error.
+// If q is the ID of a vector stored in g.vector, the neighbors of
+// that stored vector are returned via [search.Vector].
+// Otherwise q is treated as free text and handed to [search.Query],
+// along with the embedder g.embed.
+// Every result is passed through Round before being returned.
+//
+// The error is non-nil only if [search.Query] fails.
+func (g *Gaby) search(ctx context.Context, q string, opts search.Options) (results []search.Result, err error) {
+	if q == "" {
+		return nil, nil
+	}
+
+	storedVec, isID := g.vector.Get(q)
+	if isID {
+		// q names a stored document: search around its stored vector.
+		vreq := &search.VectorRequest{Options: opts, Vector: storedVec}
+		results = search.Vector(g.vector, g.docs, vreq)
+	} else {
+		// q is free text: search.Query embeds it before searching.
+		qreq := &search.QueryRequest{EmbedDoc: llm.EmbedDoc{Text: q}, Options: opts}
+		results, err = search.Query(ctx, g.vector, g.docs, g.embed, qreq)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	for n := range results {
+		results[n].Round()
+	}
+	return results, nil
+}
+
+// searchForm holds the raw, unparsed inputs to the search form.
+// The fields are kept as strings; toOptions parses and validates them.
+type searchForm struct {
+ Query string // a text query, or an ID of a document in the database
+
+ // String representations of the fields of [search.Options].
+ // A field that is empty after trimming leaves the option at its zero value.
+ Threshold string // parsed with strconv.ParseFloat by toOptions
+ Limit string // parsed with strconv.Atoi by toOptions
+ Allow, Deny string // comma-separated lists of document kinds (become AllowKind and DenyKind)
+}
+
+// toOptions builds a [search.Options] from the raw form strings.
+//
+// Leading and trailing spaces are ignored on every field and on each
+// element of the comma-separated Allow and Deny lists. A field that is
+// empty after trimming leaves the corresponding option at its zero value.
+//
+// It returns a nil result and an error if Limit or Threshold cannot be
+// parsed, or if the resulting options are rejected by Validate.
+func (f *searchForm) toOptions() (_ *search.Options, err error) {
+	// kindList splits a comma-separated list and trims every element.
+	kindList := func(list string) []string {
+		parts := strings.Split(list, ",")
+		for i := range parts {
+			parts[i] = strings.TrimSpace(parts[i])
+		}
+		return parts
+	}
+
+	var (
+		limit     = strings.TrimSpace(f.Limit)
+		threshold = strings.TrimSpace(f.Threshold)
+		allow     = strings.TrimSpace(f.Allow)
+		deny      = strings.TrimSpace(f.Deny)
+	)
+	opts := new(search.Options)
+
+	if limit != "" {
+		n, convErr := strconv.Atoi(limit)
+		if convErr != nil {
+			return nil, fmt.Errorf("limit: %w", convErr)
+		}
+		opts.Limit = n
+	}
+
+	if threshold != "" {
+		v, convErr := strconv.ParseFloat(threshold, 64)
+		if convErr != nil {
+			return nil, fmt.Errorf("threshold: %w", convErr)
+		}
+		opts.Threshold = v
+	}
+
+	if allow != "" {
+		opts.AllowKind = kindList(allow)
+	}
+	if deny != "" {
+		opts.DenyKind = kindList(deny)
+	}
+
+	// Range and kind checks live in Options.Validate; surface its error unchanged.
+	if invalid := opts.Validate(); invalid != nil {
+		return nil, invalid
+	}
+	return opts, nil
+}
+
var searchPageTmpl = newTemplate(searchPageTmplFile, nil)
func (g *Gaby) handleSearchAPI(w http.ResponseWriter, r *http.Request) {
diff --git a/internal/gaby/search_test.go b/internal/gaby/search_test.go
index 24480de..c30dd2a 100644
--- a/internal/gaby/search_test.go
+++ b/internal/gaby/search_test.go
@@ -6,48 +6,290 @@
import (
"bytes"
+ "context"
+ "net/http"
+ "reflect"
"strings"
"testing"
+ "golang.org/x/oscar/internal/docs"
+ "golang.org/x/oscar/internal/embeddocs"
+ "golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/search"
"golang.org/x/oscar/internal/storage"
+ "golang.org/x/oscar/internal/testutil"
)
func TestSearchPageTemplate(t *testing.T) {
- page := searchPage{
- Query: "some query",
- Results: []search.Result{
- {
- Kind: "Example",
- Title: "t1",
- VectorResult: storage.VectorResult{
- ID: "https://example.com/x",
- Score: 0.987654321,
+ for _, tc := range []struct {
+ name string
+ page searchPage
+ }{
+ {
+ name: "results",
+ page: searchPage{
+ searchForm: searchForm{
+ Query: "some query",
},
- },
- {
- Kind: "",
- VectorResult: storage.VectorResult{
- ID: "https://example.com/y",
- Score: 0.876,
+ Results: []search.Result{
+ {
+ Kind: "Example",
+ Title: "t1",
+ VectorResult: storage.VectorResult{
+ ID: "https://example.com/x",
+ Score: 0.987654321,
+ },
+ },
+ {
+ Kind: "",
+ VectorResult: storage.VectorResult{
+ ID: "https://example.com/y",
+ Score: 0.876,
+ },
+ },
},
},
},
+ {
+ name: "error",
+ page: searchPage{
+ searchForm: searchForm{
+ Query: "some query",
+ },
+ SearchError: "some error",
+ },
+ },
+ {
+ name: "no results",
+ page: searchPage{
+ searchForm: searchForm{
+ Query: "some query",
+ },
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ var buf bytes.Buffer
+ if err := searchPageTmpl.Execute(&buf, tc.page); err != nil {
+ t.Fatal(err)
+ }
+ got := buf.String()
+
+ if len(tc.page.Results) != 0 {
+ wants := []string{tc.page.Query}
+ for _, sr := range tc.page.Results {
+ wants = append(wants, sr.VectorResult.ID)
+ }
+ t.Logf("%s", got)
+ for _, w := range wants {
+ if !strings.Contains(got, w) {
+ t.Errorf("did not find %q in HTML", w)
+ }
+ }
+ } else if e := tc.page.SearchError; e != "" {
+ if !strings.Contains(got, e) {
+ t.Errorf("did not find error %q in HTML", e)
+ }
+ } else {
+ want := "No results"
+ if !strings.Contains(got, want) {
+ t.Errorf("did not find %q in HTML", want)
+ }
+ }
+ })
+ }
+}
+
+// TestToOptions checks that searchForm.toOptions converts the raw form
+// strings into a search.Options (trimming surrounding spaces), and that
+// unparseable or invalid inputs produce an error.
+func TestToOptions(t *testing.T) {
+ tests := []struct {
+ name string
+ form searchForm
+ want *search.Options // left nil for the error cases
+ wantErr bool
+ }{
+ {
+ name: "basic",
+ form: searchForm{
+ Threshold: ".55",
+ Limit: "10",
+ Allow: "GoBlog,GoDevPage,GitHubIssue",
+ Deny: "GoDevPage,GoWiki",
+ },
+ want: &search.Options{
+ Threshold: .55,
+ Limit: 10,
+ AllowKind: []string{search.KindGoBlog, search.KindGoDevPage, search.KindGitHubIssue},
+ DenyKind: []string{search.KindGoDevPage, search.KindGoWiki},
+ },
+ },
+ {
+ name: "empty",
+ form: searchForm{},
+ // Zero-valued Options: this will cause search to use defaults.
+ want: &search.Options{},
+ },
+ {
+ name: "trim spaces",
+ form: searchForm{
+ Threshold: " .55 ",
+ Limit: " 10 ",
+ Allow: " GoBlog, GoDevPage,GitHubIssue ",
+ Deny: " GoDevPage, GoWiki ",
+ },
+ want: &search.Options{
+ Threshold: .55,
+ Limit: 10,
+ AllowKind: []string{search.KindGoBlog, search.KindGoDevPage, search.KindGitHubIssue},
+ DenyKind: []string{search.KindGoDevPage, search.KindGoWiki},
+ },
+ },
+ {
+ name: "unparseable limit",
+ form: searchForm{
+ Limit: "1.xx",
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid limit",
+ form: searchForm{
+ Limit: "1.33", // a valid number, but not an integer, so strconv.Atoi rejects it
+ },
+ wantErr: true,
+ },
+ {
+ name: "unparseable threshold",
+ form: searchForm{
+ Threshold: "1x",
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid threshold",
+ form: searchForm{
+ Threshold: "-10", // parses as a float, but Validate requires a value >= 0
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid allow",
+ form: searchForm{
+ Allow: "NotAKind, also not a kind",
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid deny",
+ form: searchForm{
+ Deny: "NotAKind, also not a kind",
+ },
+ wantErr: true,
+ },
+ }
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got, err := tc.form.toOptions()
+ if (err != nil) != tc.wantErr {
+ t.Fatalf("searchForm.toOptions() error = %v, wantErr %v", err, tc.wantErr)
+ }
+ // On error toOptions returns a nil *search.Options, which matches the unset tc.want.
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("searchForm.toOptions() = %v, want %v", got, tc.want)
+ }
+ })
+ }
+}
+
+func TestPopulatePage(t *testing.T) {
+ ctx := context.Background()
+ lg := testutil.Slogger(t)
+ db := storage.MemDB()
+ dc := docs.New(lg, db)
+ vector := storage.MemVectorDB(db, lg, "vector")
+ dc.Add("id1", "hello", "hello world")
+ embedder := llm.QuoteEmbedder()
+ embeddocs.Sync(ctx, lg, vector, embedder, dc)
+ g := &Gaby{
+ slog: lg,
+ db: db,
+ vector: vector,
+ docs: dc,
+ embed: embedder,
}
- var buf bytes.Buffer
- if err := searchPageTmpl.Execute(&buf, page); err != nil {
- t.Fatal(err)
- }
- wants := []string{page.Query}
- for _, sr := range page.Results {
- wants = append(wants, sr.VectorResult.ID)
- }
- got := buf.String()
- t.Logf("%s", got)
- for _, w := range wants {
- if !strings.Contains(got, w) {
- t.Errorf("did not find %q in HTML", w)
- }
+ for _, tc := range []struct {
+ name string
+ url string
+ want searchPage
+ }{
+ {
+ name: "query",
+ url: "test/search?q=hello",
+ want: searchPage{
+ searchForm: searchForm{
+ Query: "hello",
+ },
+ Results: []search.Result{
+ {
+ Kind: search.KindUnknown,
+ Title: "hello",
+ VectorResult: storage.VectorResult{
+ ID: "id1",
+ Score: 0.526,
+ },
+ },
+ }},
+ },
+ {
+ name: "id lookup",
+ url: "test/search?q=id1",
+ want: searchPage{
+ searchForm: searchForm{
+ Query: "id1",
+ },
+ Results: []search.Result{{
+ Kind: search.KindUnknown,
+ Title: "hello",
+ VectorResult: storage.VectorResult{
+ ID: "id1",
+ Score: 1, // exact same
+ },
+ }}},
+ },
+ {
+ name: "options",
+ url: "test/search?q=id1&threshold=.5&limit=10&allow_kind=&deny_kind=Unknown,GoBlog",
+ want: searchPage{
+ searchForm: searchForm{
+ Query: "id1",
+ Threshold: ".5",
+ Limit: "10",
+ Allow: "",
+ Deny: "Unknown,GoBlog",
+ },
+ // No results (blocked by DenyKind)
+ },
+ },
+ {
+ name: "error",
+ url: "test/search?q=id1&deny_kind=Invalid",
+ want: searchPage{
+ searchForm: searchForm{
+ Query: "id1",
+ Deny: "Invalid",
+ },
+ SearchError: `invalid form value: unrecognized deny kind "Invalid" (case-sensitive)`,
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ r, err := http.NewRequest(http.MethodGet, tc.url, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := g.populatePage(r)
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("Gaby.search() = %v, want %v", got, tc.want)
+ }
+ })
}
}
diff --git a/internal/gaby/tmpl/searchpage.tmpl b/internal/gaby/tmpl/searchpage.tmpl
index 8973b34..6a80b97 100644
--- a/internal/gaby/tmpl/searchpage.tmpl
+++ b/internal/gaby/tmpl/searchpage.tmpl
@@ -15,8 +15,9 @@
<h1>Gaby search</h1>
<p>Search Gaby's database of GitHub issues and Go documentation.</p>
<div class="filter-tips-box">
- <div class="toggle" onclick="toggle()">[show/hide filter tips]</div>
+ <div class="toggle" onclick="toggle()">[show/hide input tips]</div>
<ul id="filter-tips">
+ <li><b>query</b> (string): the text to search for neighbors of OR the ID (usually a URL) of a document in the vector database</li>
<li><b>min similarity</b> (<code>float64</code> between 0 and 1): similarity cutoff (default: 0, allow all)</li>
<li><b>max results</b> (<code>int</code>): maximum number of results to display (default: 20)</li>
<li><b>include types</b> (comma-separated list): document types to include, e.g <code>GitHubIssue,GoBlog</code> (default: empty, include all)</li>
@@ -30,19 +31,19 @@
</span>
<span>
<label for="threshold">min similarity</label>
- <input id="threshold" type="text" name="threshold" value="{{.Threshold}}" required autofocus />
+ <input id="threshold" type="text" name="threshold" value="{{.Threshold}}" optional autofocus />
</span>
<span>
<label for="limit">max results</label>
- <input id="limit" type="text" name="limit" value="{{.Limit}}" required autofocus />
+ <input id="limit" type="text" name="limit" value="{{.Limit}}" optional autofocus />
</span>
<span>
- <label for="allow_kind">allow types</label>
- <input id="allow_kind" type="text" name="allow_kind" value="{{.AllowStr}}" optional autofocus />
+ <label for="allow_kind">include types</label>
+ <input id="allow_kind" type="text" name="allow_kind" value="{{.Allow}}" optional autofocus />
</span>
<span>
<label for="deny_kind">exclude types</code></label>
- <input id="deny_kind" type="text" name="deny_kind" value="{{.DenyStr}}" optional autofocus />
+ <input id="deny_kind" type="text" name="deny_kind" value="{{.Deny}}" optional autofocus />
</span>
<span class="submit">
<input type="submit" value="search"/>
@@ -67,7 +68,9 @@
<div class="section">
<div id="working"></div>
- {{with .Results -}}
+ {{- with .SearchError -}}
+ <p>Error: {{.}}</p>
+ {{- else with .Results -}}
{{- range . -}}
<div class="result">
{{if .IDIsURL -}}
diff --git a/internal/search/search.go b/internal/search/search.go
index ff6dccd..d76da58 100644
--- a/internal/search/search.go
+++ b/internal/search/search.go
@@ -32,6 +32,8 @@
// Options are the results filters that can be passed to the search
// functions as part of a [QueryRequest] or [VectorRequest].
+//
+// TODO(tatianabradley): Make kinds case insensitive.
type Options struct {
Threshold float64 // lowest score to keep; default 0. Max is 1.
Limit int // max results (fewer if Threshold is set); 0 means use a fixed default
@@ -82,6 +84,27 @@
return vector(vdb, dc, req.Vector, &req.Options)
}
+// Validate returns an error describing the first invalid option, or nil
+// if all options are valid:
+//   - Limit must be >= 0,
+//   - Threshold must lie in [0, 1] (NaN is rejected),
+//   - every AllowKind and DenyKind entry must be a recognized kind
+//     (the comparison is case-sensitive).
+func (o *Options) Validate() error {
+	if o.Limit < 0 {
+		return fmt.Errorf("limit must be >= 0 (got: %d)", o.Limit)
+	}
+	// Written as a negated range check on purpose: every comparison with
+	// NaN is false, so "o.Threshold < 0 || o.Threshold > 1" would let a
+	// NaN threshold (which strconv.ParseFloat accepts as "NaN") through.
+	if !(o.Threshold >= 0 && o.Threshold <= 1) {
+		return fmt.Errorf("threshold must be >= 0 and <= 1 (got: %.3f)", o.Threshold)
+	}
+	for _, allow := range o.AllowKind {
+		if _, ok := kinds[allow]; !ok {
+			return fmt.Errorf("unrecognized allow kind %q (case-sensitive)", allow)
+		}
+	}
+	for _, deny := range o.DenyKind {
+		if _, ok := kinds[deny]; !ok {
+			return fmt.Errorf("unrecognized deny kind %q (case-sensitive)", deny)
+		}
+	}
+	return nil
+}
+
func vector(vdb storage.VectorDB, dc *docs.Corpus, vec llm.Vector, opts *Options) []Result {
limit := defaultLimit
if opts.Limit > 0 {
@@ -159,6 +182,16 @@
KindUnknown = "Unknown"
)
+// kinds is the set of recognized document kinds.
+// [Options.Validate] rejects any AllowKind or DenyKind entry that is
+// not a key of this map; lookups are case-sensitive.
+var kinds = map[string]bool{
+ KindGitHubIssue: true,
+ KindGoWiki: true,
+ KindGoDocumentation: true,
+ KindGoBlog: true,
+ KindGoDevPage: true,
+ KindUnknown: true,
+}
+
// docIDKind determines the kind of document from its ID.
// It returns the empty string if it cannot do so.
//