internal/{llmapp,search,gaby}: add related entities overview
Add a feature to the Gaby overviews page that allows the user
to search for, and generate a summary of, documents related
to the query.
To support this, add functionality to the llmapp and search packages
that make the underlying queries.
Change-Id: I885e27c1943c8b656ac89bdc5dbbad46db5a8115
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/623356
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
diff --git a/internal/gaby/overview.go b/internal/gaby/overview.go
index e0c6fb0..aa6be0e 100644
--- a/internal/gaby/overview.go
+++ b/internal/gaby/overview.go
@@ -5,34 +5,65 @@
package main
import (
+ "context"
"fmt"
"net/http"
"strconv"
"strings"
"time"
- "github.com/google/safehtml"
"github.com/google/safehtml/template"
"golang.org/x/oscar/internal/github"
"golang.org/x/oscar/internal/htmlutil"
+ "golang.org/x/oscar/internal/search"
)
// overviewPage holds the fields needed to display the results
// of a search.
type overviewPage struct {
- overviewForm // the raw form inputs
- Result *overviewResult
- Error string // if non-empty, the error to display instead of the result
+ Form overviewForm // the raw form inputs
+ Result *overviewResult
+ Error error // if non-nil, the error to display instead of the result
}
type overviewResult struct {
- github.IssueOverviewResult // the raw result
- OverviewHTML safehtml.HTML // the overview as HTML
+ github.IssueOverviewResult // the raw result
+ Type string // the type of overview
}
// overviewForm holds the raw inputs to the overview form.
type overviewForm struct {
- Query string // the issue ID to lookup
+ Query string // the issue ID to lookup
+ OverviewType string // the type of overview to generate
+}
+
+// the possible overview types
+const (
+ issueOverviewType = "issue"
+ relatedOverviewType = "related"
+)
+
+// IsIssueOverview reports whether this overview result
+// is of type [issueOverviewType].
+func (r *overviewResult) IsIssueOverview() bool {
+ return r.Type == issueOverviewType
+}
+
+// CheckRadio reports whether radio button with the given id
+// should be checked.
+func (p overviewPage) CheckRadio(id string) bool {
+ // checked returns the id of the radio button that should be checked.
+ checked := func() string {
+ // If there is no result yet, the default option
+ // (issue overview) should be checked.
+ if p.Result == nil {
+ return issueOverviewType
+ }
+ // Otherwise, the button corresponding to the result
+ // type should be checked.
+ return p.Result.Type
+ }
+ return id == checked()
}
func (g *Gaby) handleOverview(w http.ResponseWriter, r *http.Request) {
@@ -40,7 +71,8 @@
}
var overviewPageTmpl = newTemplate(overviewPageTmplFile, template.FuncMap{
- "fmttime": fmtTimeString,
+ "fmttime": fmtTimeString,
+ "safehtml": htmlutil.MarkdownToSafeHTML,
})
// fmtTimeString formats an [time.RFC3339]-encoded time string
@@ -58,35 +90,76 @@
// populateOverviewPage returns the contents of the overview page.
func (g *Gaby) populateOverviewPage(r *http.Request) overviewPage {
- form := overviewForm{
- Query: r.FormValue("q"),
- }
- if form.Query == "" {
- return overviewPage{
- overviewForm: form,
- }
- }
- issue, err := strconv.Atoi(strings.TrimSpace(form.Query))
- if err != nil {
- return overviewPage{
- overviewForm: form,
- Error: fmt.Errorf("invalid form value %q: %w", form.Query, err).Error(),
- }
- }
- overview, err := github.IssueOverview(r.Context(), g.llm, g.db, g.githubProject, int64(issue))
- if err != nil {
- return overviewPage{
- overviewForm: form,
- Error: fmt.Errorf("overview: %w", err).Error(),
- }
- }
- return overviewPage{
- overviewForm: form,
- Result: &overviewResult{
- IssueOverviewResult: *overview,
- OverviewHTML: htmlutil.MarkdownToSafeHTML(overview.Overview.Overview),
+ p := overviewPage{
+ Form: overviewForm{
+ Query: r.FormValue("q"),
+ OverviewType: r.FormValue("t"),
},
}
+ q := strings.TrimSpace(p.Form.Query)
+ if q == "" {
+ return p
+ }
+ issue, err := strconv.ParseInt(q, 10, 64)
+ if err != nil {
+ p.Error = fmt.Errorf("invalid form value %q: %w", q, err)
+ return p
+ }
+ if issue < 0 {
+ p.Error = fmt.Errorf("invalid form value %q", q)
+ return p
+ }
+ overview, err := g.overview(r.Context(), issue, p.Form.OverviewType)
+ if err != nil {
+ p.Error = err
+ return p
+ }
+ p.Result = overview
+ return p
+}
+
+// overview generates an overview of the issue of the given type.
+func (g *Gaby) overview(ctx context.Context, issue int64, overviewType string) (*overviewResult, error) {
+ switch overviewType {
+ case "", issueOverviewType:
+ return g.issueOverview(ctx, issue)
+ case relatedOverviewType:
+ return g.relatedOverview(ctx, issue)
+ default:
+ return nil, fmt.Errorf("unknown overview type %q", overviewType)
+ }
+}
+
+// issueOverview generates an overview of the issue and its comments.
+func (g *Gaby) issueOverview(ctx context.Context, issue int64) (*overviewResult, error) {
+ overview, err := github.IssueOverview(ctx, g.llm, g.db, g.githubProject, issue)
+ if err != nil {
+ return nil, err
+ }
+ return &overviewResult{
+ IssueOverviewResult: *overview,
+ Type: issueOverviewType,
+ }, nil
+}
+
+// relatedOverview generates an overview of the issue and its related documents.
+func (g *Gaby) relatedOverview(ctx context.Context, issue int64) (*overviewResult, error) {
+ iss, err := github.LookupIssue(g.db, g.githubProject, issue)
+ if err != nil {
+ return nil, err
+ }
+ overview, err := search.Overview(ctx, g.llm, g.vector, g.docs, iss.DocID())
+ if err != nil {
+ return nil, err
+ }
+ return &overviewResult{
+ IssueOverviewResult: github.IssueOverviewResult{
+ Issue: iss,
+ // number of comments not displayed for related type
+ Overview: overview.OverviewResult,
+ },
+ Type: relatedOverviewType,
+ }, nil
}
// Related returns the relative URL of the related-entity search
diff --git a/internal/gaby/overview_test.go b/internal/gaby/overview_test.go
new file mode 100644
index 0000000..5e74b9a
--- /dev/null
+++ b/internal/gaby/overview_test.go
@@ -0,0 +1,190 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "context"
+ "net/http"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "golang.org/x/oscar/internal/docs"
+ "golang.org/x/oscar/internal/embeddocs"
+ "golang.org/x/oscar/internal/github"
+ "golang.org/x/oscar/internal/llmapp"
+)
+
+func TestPopulateOverviewPage(t *testing.T) {
+ g := newTestGaby(t)
+
+ // Add test data relevant to this test.
+ project := "hello/world"
+ g.githubProject = project
+ g.github.Add(project)
+
+ iss1 := &github.Issue{
+ URL: "https://api.github.com/repos/hello/world/issues/1",
+ HTMLURL: "https://github.com/hello/world/issues/1",
+ Number: 1,
+ Title: "hello",
+ Body: "hello world",
+ }
+ iss2 := &github.Issue{
+ URL: "https://api.github.com/repos/hello/world/issues/2",
+ HTMLURL: "https://github.com/hello/world/issues/2",
+ Number: 2,
+ Title: "hello 2",
+ Body: "hello world 2",
+ }
+ g.github.Testing().AddIssue(project, iss1)
+ comment := &github.IssueComment{
+ Body: "a comment",
+ }
+ g.github.Testing().AddIssueComment(project, 1, comment)
+ g.github.Testing().AddIssue(project, iss2)
+
+ ctx := context.Background()
+ docs.Sync(g.docs, g.github)
+ embeddocs.Sync(ctx, g.slog, g.vector, g.embed, g.docs)
+
+ // Generate expected overviews.
+ wantIssueOverview, err := g.llm.PostOverview(ctx, &llmapp.Doc{
+ Type: "issue",
+ URL: iss1.HTMLURL,
+ Title: iss1.Title,
+ Text: iss1.Body,
+ }, []*llmapp.Doc{
+ {
+ Type: "issue comment",
+ URL: comment.HTMLURL,
+ Text: comment.Body,
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantRelatedOverview, err := g.llm.RelatedOverview(ctx, &llmapp.Doc{
+ Type: "main",
+ URL: iss1.HTMLURL,
+ Title: iss1.Title,
+ Text: iss1.Body,
+ }, []*llmapp.Doc{
+ {
+ Type: "related",
+ URL: iss2.HTMLURL,
+ Title: iss2.Title,
+ Text: iss2.Body,
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ r *http.Request
+ want overviewPage
+ }{
+ {
+ name: "empty",
+ r: &http.Request{},
+ want: overviewPage{},
+ },
+ {
+ name: "issue overview (default)",
+ r: &http.Request{
+ Form: map[string][]string{
+ "q": {"1"},
+ },
+ },
+ want: overviewPage{
+ Form: overviewForm{
+ Query: "1",
+ OverviewType: "",
+ },
+ Result: &overviewResult{
+ IssueOverviewResult: github.IssueOverviewResult{
+ Issue: iss1,
+ NumComments: 1,
+ Overview: wantIssueOverview,
+ },
+ Type: issueOverviewType,
+ },
+ },
+ },
+ {
+ name: "issue overview (explicit)",
+ r: &http.Request{
+ Form: map[string][]string{
+ "q": {"1"},
+ "t": {issueOverviewType},
+ },
+ },
+ want: overviewPage{
+ Form: overviewForm{
+ Query: "1",
+ OverviewType: issueOverviewType,
+ },
+ Result: &overviewResult{
+ IssueOverviewResult: github.IssueOverviewResult{
+ Issue: iss1,
+ NumComments: 1,
+ Overview: wantIssueOverview,
+ },
+ Type: issueOverviewType,
+ },
+ },
+ },
+ {
+ name: "related overview",
+ r: &http.Request{
+ Form: map[string][]string{
+ "q": {"1"},
+ "t": {relatedOverviewType},
+ },
+ },
+ want: overviewPage{
+ Form: overviewForm{
+ Query: "1",
+ OverviewType: relatedOverviewType,
+ },
+ Result: &overviewResult{
+ IssueOverviewResult: github.IssueOverviewResult{
+ Issue: iss1,
+ Overview: wantRelatedOverview,
+ },
+ Type: relatedOverviewType,
+ },
+ },
+ },
+ {
+ name: "error",
+ r: &http.Request{
+ Form: map[string][]string{
+ "q": {"3"}, // not in DB
+ "t": {relatedOverviewType},
+ },
+ },
+ want: overviewPage{
+ Form: overviewForm{
+ Query: "3",
+ OverviewType: relatedOverviewType,
+ },
+ Error: cmpopts.AnyError,
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ got := g.populateOverviewPage(tc.r)
+ if diff := cmp.Diff(got, tc.want,
+ cmpopts.IgnoreFields(llmapp.OverviewResult{}, "Cached"),
+ cmpopts.EquateErrors()); diff != "" {
+ t.Errorf("Gaby.populateOverviewPage() mismatch (-got +want):\n%s", diff)
+ }
+ })
+ }
+
+}
diff --git a/internal/gaby/search_test.go b/internal/gaby/search_test.go
index c30dd2a..6abf88e 100644
--- a/internal/gaby/search_test.go
+++ b/internal/gaby/search_test.go
@@ -13,9 +13,11 @@
"testing"
"golang.org/x/oscar/internal/docs"
- "golang.org/x/oscar/internal/embeddocs"
+ "golang.org/x/oscar/internal/github"
"golang.org/x/oscar/internal/llm"
+ "golang.org/x/oscar/internal/llmapp"
"golang.org/x/oscar/internal/search"
+ "golang.org/x/oscar/internal/secret"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
@@ -200,21 +202,11 @@
}
func TestPopulatePage(t *testing.T) {
- ctx := context.Background()
- lg := testutil.Slogger(t)
- db := storage.MemDB()
- dc := docs.New(lg, db)
- vector := storage.MemVectorDB(db, lg, "vector")
- dc.Add("id1", "hello", "hello world")
- embedder := llm.QuoteEmbedder()
- embeddocs.Sync(ctx, lg, vector, embedder, dc)
- g := &Gaby{
- slog: lg,
- db: db,
- vector: vector,
- docs: dc,
- embed: embedder,
- }
+ g := newTestGaby(t)
+
+ // Add test data relevant for this test.
+ g.docs.Add("id1", "hello", "hello world")
+ g.embedAll(context.Background())
for _, tc := range []struct {
name string
@@ -293,3 +285,22 @@
})
}
}
+
+func newTestGaby(t *testing.T) *Gaby {
+ t.Helper()
+
+ lg := testutil.Slogger(t)
+ db := storage.MemDB()
+
+ g := &Gaby{
+ slog: lg,
+ db: db,
+ vector: storage.MemVectorDB(db, lg, "vector"),
+ github: github.New(lg, db, secret.Empty(), nil),
+ llm: llmapp.New(lg, llm.EchoTextGenerator(), db),
+ docs: docs.New(lg, db),
+ embed: llm.QuoteEmbedder(),
+ }
+
+ return g
+}
diff --git a/internal/gaby/static/overview.css b/internal/gaby/static/overview.css
new file mode 100644
index 0000000..c12b568
--- /dev/null
+++ b/internal/gaby/static/overview.css
@@ -0,0 +1,28 @@
+/*
+Copyright 2024 The Go Authors. All rights reserved.
+Use of this source code is governed by a BSD-style
+license that can be found in the LICENSE file.
+*/
+label {
+ min-width: fit-content;
+ width: 10em;
+}
+input {
+ min-width: fit-content;
+ width: 1em;
+}
+#prompt {
+ display: none;
+}
+#prompt ul {
+ list-style-type: none;
+ margin: 0;
+ padding: 0;
+}
+#overview {
+ padding: 0em 1em 1em 1em;
+ width: 75%;
+ margin: 1em;
+ border: solid black .1em;
+ border-radius: .25em;
+}
\ No newline at end of file
diff --git a/internal/gaby/static/search.css b/internal/gaby/static/search.css
index fde5273..0b23562 100644
--- a/internal/gaby/static/search.css
+++ b/internal/gaby/static/search.css
@@ -10,14 +10,12 @@
display: block;
padding-bottom: .2em
}
-label {
+label,input {
display: inline-block;
width: 20%;
- margin-right: .1em;
}
-input {
- display: inline-block;
- width: 20%;
+label {
+ margin-right: .1em;
}
input.submit {
width: 10%;
@@ -59,21 +57,6 @@
#filter-tips {
display: none;
}
-#prompt {
- display: none;
-}
-#prompt ul {
- list-style-type: none;
- margin: 0;
- padding: 0;
-}
-#overview {
- padding: 0em 1em 1em 1em;
- width: 75%;
- margin: 1em;
- border: solid black .1em;
- border-radius: .25em;
-}
.submit {
padding-top: .5em;
}
\ No newline at end of file
diff --git a/internal/gaby/templates_test.go b/internal/gaby/templates_test.go
index 3b4236d..3500ed9 100644
--- a/internal/gaby/templates_test.go
+++ b/internal/gaby/templates_test.go
@@ -13,7 +13,6 @@
"strings"
"testing"
- "github.com/google/safehtml"
"github.com/google/safehtml/template"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
@@ -36,7 +35,7 @@
}},
{"overview-initial", overviewPageTmpl, overviewPage{}},
{"overview", overviewPageTmpl, overviewPage{
- overviewForm: overviewForm{Query: "12"},
+ Form: overviewForm{Query: "12"},
Result: &overviewResult{
IssueOverviewResult: github.IssueOverviewResult{
Issue: &github.Issue{
@@ -51,8 +50,12 @@
Prompt: []string{"a prompt"},
},
},
- OverviewHTML: safehtml.HTMLEscaped("an overview"),
+ Type: issueOverviewType,
}}},
+ {"overview-error", overviewPageTmpl, overviewPage{
+ Form: overviewForm{Query: "12"},
+ Error: fmt.Errorf("an error"),
+ }},
} {
t.Run(test.name, func(t *testing.T) {
var buf bytes.Buffer
diff --git a/internal/gaby/tmpl/overviewpage.tmpl b/internal/gaby/tmpl/overviewpage.tmpl
index 07470e8..e8406ac 100644
--- a/internal/gaby/tmpl/overviewpage.tmpl
+++ b/internal/gaby/tmpl/overviewpage.tmpl
@@ -9,21 +9,37 @@
<title>Oscar Overviews</title>
<link rel="stylesheet" href="static/style.css"/>
<link rel="stylesheet" href="static/search.css"/>
+ <link rel="stylesheet" href="static/overview.css"/>
</head>
<body>
<div class="section" class="header">
<h1>Oscar Overviews</h1>
- <p>Generate summaries of posts and their comments. This is a first draft (and currently limited to golang/go GitHub issues). Feedback welcome!</p>
+ <p>This tool can:</p>
+ <ul>
+ <li>Summarize a golang/go issue and its comments</li>
+ <li>Summarize the relationship between a golang/go issue and its "related documents"</li>
+ </ul>
+ <p>This is a first draft. Feedback welcome!</p>
<div class="filter-tips-box">
<div class="toggle" onclick="toggleTips()">[show/hide input tips]</div>
<ul id="filter-tips">
<li><b>issue</b> (<code>int</code>): the issue ID (in the github.com/golang/go repo) of the issue to summarize</li>
+ <li><b>overview type</b> (choice): "issue and comments" generates an overview of the issue and its comments; "related documents" searches for related documents and summarizes them</li>
</ul>
</div>
<form id="form" action="/overview" method="GET">
<span>
<label for="query"><b>issue</b></label>
- <input id="query" type="text" name="q" value="{{.Query}}" required autofocus />
+ <input id="query" type="text" name="q" value="{{.Form.Query}}" required autofocus />
+ </span>
+ <span><label><b>overview type:</b></label></span>
+ <span>
+ <label for="issue-overview">issue and comments</label>
+ <input type="radio" id="issue" name="t" value="issue" {{if .CheckRadio "issue"}}checked="checked"{{end}} required autofocus />
+ </span>
+ <span>
+ <label for="related">related documents</label>
+ <input type="radio" id="related" name="t" value="related" {{if .CheckRadio "related"}}checked="checked"{{end}} required autofocus />
</span>
<span class="submit">
<input type="submit" value="generate"/>
@@ -56,19 +72,15 @@
<div class="section">
<div id="working"></div>
{{- with .Error -}}
- <p>Error: {{.}}</p>
+ <p>Error: {{.Error}}</p>
{{- else with .Result -}}
<div class="result">
<p><a href="{{.HTMLURL}}" target="_blank">{{.HTMLURL}}</a></p>
<p><strong>{{.Title}}</strong></p>
- <p>Author: {{.User.Login}}</p>
- <p>State: {{.State}}</p>
- <p>Created: {{fmttime .CreatedAt}}</p>
- <p>Updated: {{fmttime .UpdatedAt}}</p>
- <p>Number of comments: {{.NumComments}}</p>
+ <p>author: {{.User.Login}} | state: {{.State}} | created: {{fmttime .CreatedAt}} | updated: {{fmttime .UpdatedAt}}{{if .IsIssueOverview}} | comments: {{.NumComments}}{{end}}</p>
<p><a href="{{.Related}}" target="_blank">[Search for related issues]</a></p>
<p>AI-generated overview{{if .Overview.Cached}} (cached){{end}}:</p>
- <div id="overview">{{.OverviewHTML}}</div>
+ <div id="overview">{{safehtml .Overview.Overview}}</div>
</div>
<div class="toggle" onclick="togglePrompt()">[show prompt]</div>
<div id="prompt">
@@ -81,7 +93,7 @@
</ul>
</div>
{{- else }}
- {{if .Query}}<p>No result.</p>{{end}}
+ {{if .Form.Query}}<p>No result.</p>{{end}}
{{- end}}
</div>
</body>
diff --git a/internal/github/data.go b/internal/github/data.go
index 2f6a699..89f05a7 100644
--- a/internal/github/data.go
+++ b/internal/github/data.go
@@ -45,12 +45,19 @@
return bad()
}
- for e := range c.Events(proj, n, n) {
+ return LookupIssue(c.db, proj, n)
+}
+
+// LookupIssue looks up an issue by project and issue number
+// (for example "golang/go", 12345), only consulting the database
+// (not actual GitHub).
+func LookupIssue(db storage.DB, project string, issue int64) (*Issue, error) {
+ for e := range events(db, project, issue, issue) {
if e.API == "/issues" {
return e.Typed.(*Issue), nil
}
}
- return nil, fmt.Errorf("%s#%d not in database", proj, n)
+ return nil, fmt.Errorf("github.LookupIssue: issue %s#%d not in database", project, issue)
}
// An Event is a single GitHub issue event stored in the database.
@@ -329,6 +336,12 @@
return urlToProject(x.URL)
}
+// DocID returns the ID of this issue for storage in a docs.Corpus
+// or a storage.VectorDB.
+func (i *Issue) DocID() string {
+ return i.HTMLURL
+}
+
// Methods implementing model.Post.
func (x *Issue) ID() string { return x.URL }
func (x *Issue) Title_() string { return x.Title }
diff --git a/internal/github/overview.go b/internal/github/overview.go
index ee4716d..427df87 100644
--- a/internal/github/overview.go
+++ b/internal/github/overview.go
@@ -6,6 +6,7 @@
import (
"context"
+ "fmt"
"golang.org/x/oscar/internal/llmapp"
"golang.org/x/oscar/internal/storage"
@@ -40,6 +41,9 @@
}
comments = append(comments, doc)
}
+ if post == nil {
+ return nil, fmt.Errorf("github.IssueOverview: issue %d not in db", issue)
+ }
overview, err := lc.PostOverview(ctx, post, comments)
if err != nil {
return nil, err
diff --git a/internal/github/sync.go b/internal/github/sync.go
index 1a4fefb..395bf8a 100644
--- a/internal/github/sync.go
+++ b/internal/github/sync.go
@@ -125,7 +125,7 @@
}
return slices.Values([]*docs.Doc{
{
- ID: fmt.Sprintf("https://github.com/%s/issues/%d", e.Project, e.Issue),
+ ID: issue.DocID(),
Title: CleanTitle(issue.Title),
Text: CleanBody(issue.Body),
},
diff --git a/internal/llmapp/overview.go b/internal/llmapp/overview.go
index 0f03629..7251f4c 100644
--- a/internal/llmapp/overview.go
+++ b/internal/llmapp/overview.go
@@ -85,6 +85,20 @@
return c.overview(ctx, postAndComments, append([]*Doc{post}, comments...))
}
+// RelatedOverview returns an LLM-generated overview of the given document and
+// related documents, styled with markdown.
+// RelatedOverview returns an error if no initial document is provided, no related docs are
+// provided, or the LLM is unable to generate a response.
+func (c *Client) RelatedOverview(ctx context.Context, doc *Doc, related []*Doc) (*OverviewResult, error) {
+ if doc == nil {
+ return nil, errors.New("llmapp RelatedOverview: no doc")
+ }
+ if len(related) == 0 {
+ return nil, errors.New("llmapp RelatedOverview: no related docs")
+ }
+ return c.overview(ctx, docAndRelated, append([]*Doc{doc}, related...))
+}
+
// overview returns an LLM-generated overview of the given documents,
// styled with markdown.
// The kind argument is a descriptor for the given documents, used to add
@@ -127,6 +141,9 @@
// The documents represent a post and comments/replies
// on that post. For example, a GitHub issue and its comments.
postAndComments docsKind = "post_and_comments"
+ // The documents represent a document followed by documents
+ // that are related to it in some way.
+ docAndRelated docsKind = "doc_and_related"
)
//go:embed prompts/*.tmpl
diff --git a/internal/llmapp/overview_test.go b/internal/llmapp/overview_test.go
index 9a1f180..ce85176 100644
--- a/internal/llmapp/overview_test.go
+++ b/internal/llmapp/overview_test.go
@@ -51,6 +51,21 @@
t.Errorf("PostOverview() mismatch (-got +want):\n%s", diff)
}
})
+
+ t.Run("RelatedOverview", func(t *testing.T) {
+ got, err := c.RelatedOverview(ctx, doc1, []*Doc{doc2})
+ if err != nil {
+ t.Fatal(err)
+ }
+ promptParts := []string{raw1, raw2, docAndRelated.instructions()}
+ want := &OverviewResult{
+ Overview: llm.EchoResponse(promptParts...),
+ Prompt: promptParts,
+ }
+ if diff := cmp.Diff(got, want); diff != "" {
+ t.Errorf("RelatedOverview() mismatch (-got +want):\n%s", diff)
+ }
+ })
}
var (
@@ -149,26 +164,37 @@
}
func TestInstructions(t *testing.T) {
- wantAll := "markdown" // in all instructions
- wantPost := "post" // only in PostAndComments
+ wantAll := "markdown" // in all instructions
+ wantPost := "post" // only in postAndComments
+ wantRelated := "related" // only in docAndRelated
- t.Run("Documents", func(t *testing.T) {
+ t.Run("documents", func(t *testing.T) {
di := documents.instructions()
if !strings.Contains(di, wantAll) {
- t.Errorf("Documents.instructions(): does not contain %q", wantAll)
+ t.Errorf("documents.instructions(): does not contain %q", wantAll)
}
if strings.Contains(di, wantPost) {
- t.Errorf("Documents.instructions(): incorrectly contains %q", wantPost)
+ t.Errorf("documents.instructions(): incorrectly contains %q", wantPost)
}
})
- t.Run("PostAndComments", func(t *testing.T) {
+ t.Run("postAndComments", func(t *testing.T) {
pi := postAndComments.instructions()
if !strings.Contains(pi, wantAll) {
- t.Fatalf("PostAndComments.instructions(): does not contain %q", wantAll)
+ t.Fatalf("postAndComments.instructions(): does not contain %q", wantAll)
}
if !strings.Contains(pi, wantPost) {
- t.Fatalf("PostAndComments.instructions(): does not contain %q", wantPost)
+ t.Fatalf("postAndComments.instructions(): does not contain %q", wantPost)
+ }
+ })
+
+ t.Run("DocAndRelated", func(t *testing.T) {
+ pi := docAndRelated.instructions()
+ if !strings.Contains(pi, wantAll) {
+ t.Fatalf("docAndRelated.instructions(): does not contain %q", wantAll)
+ }
+ if !strings.Contains(pi, wantRelated) {
+ t.Fatalf("docAndRelated.instructions(): does not contain %q", wantPost)
}
})
}
diff --git a/internal/llmapp/prompts/common.tmpl b/internal/llmapp/prompts/common.tmpl
index 54f4653..6bfa7b7 100644
--- a/internal/llmapp/prompts/common.tmpl
+++ b/internal/llmapp/prompts/common.tmpl
@@ -9,5 +9,6 @@
Citation Requirements:
Every summary point, whether paraphrased or quoted, MUST be cited appropriately.
Cite sources using this format: (author, [Type](URL)). For example: (oscar, [issue](github.com/issue/19)).
+If no author, use this citation format: ([Type](URL)).
Do not fabricate any information or citations. If no comments are present, state that explicitly.
{{end}}
\ No newline at end of file
diff --git a/internal/llmapp/prompts/doc_and_related.tmpl b/internal/llmapp/prompts/doc_and_related.tmpl
new file mode 100644
index 0000000..4685ea0
--- /dev/null
+++ b/internal/llmapp/prompts/doc_and_related.tmpl
@@ -0,0 +1,13 @@
+{{define "doc_and_related"}}
+The documents represent a document followed by related documents.
+
+Steps:
+
+1. (Heading ## Original Document) Summarize the main points of the original document.
+2. (Heading ## Related Documents) For each related document use this format:
+ ### Title ([URL](URL)) for example: ### the document title ([example.com](example.com))
+ * **Summary**: Summarize the document in one sentence.
+ * **Relationship**: Explain how the document is related to the original document.
+
+{{template "requirements"}}
+{{end}}
\ No newline at end of file
diff --git a/internal/search/overview.go b/internal/search/overview.go
new file mode 100644
index 0000000..4c7d90a
--- /dev/null
+++ b/internal/search/overview.go
@@ -0,0 +1,90 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package search
+
+import (
+ "context"
+ "fmt"
+
+ "golang.org/x/oscar/internal/docs"
+ "golang.org/x/oscar/internal/llmapp"
+ "golang.org/x/oscar/internal/storage"
+)
+
+// OverviewResult is the result of [Overview].
+type OverviewResult struct {
+ *llmapp.OverviewResult // the LLM-generated overview
+}
+
+// Overview returns an LLM-generated overview of a document and its related documents.
+// id is the ID of the main document, which must be present in both the docs corpus and the vector db.
+// Overview finds related documents using vector search (see [Vector]) with fixed options.
+func Overview(ctx context.Context, lc *llmapp.Client, vdb storage.VectorDB, dc *docs.Corpus, id string) (*OverviewResult, error) {
+ doc, ok := llmDoc(dc, "main", id)
+ if !ok {
+ return nil, fmt.Errorf("search.Overview: main doc %q not in docs corpus", id)
+ }
+ rs, err := searchRelated(vdb, dc, id)
+ if err != nil {
+ return nil, err
+ }
+ var related []*llmapp.Doc
+ for _, r := range rs {
+ d, ok := llmDoc(dc, "related", r.ID)
+ if !ok {
+ return nil, fmt.Errorf("search.Overview: related doc %s not in docs corpus", id)
+ }
+ related = append(related, d)
+ }
+ overview, err := lc.RelatedOverview(ctx, doc, related)
+ if err != nil {
+ return nil, err
+ }
+ return &OverviewResult{overview}, nil
+}
+
+var maxResults = 5
+
+// searchRelated finds up to 5 documents related to the document
+// identified by id in vdb.
+func searchRelated(vdb storage.VectorDB, dc *docs.Corpus, id string) ([]Result, error) {
+ v, ok := vdb.Get(id)
+ if !ok {
+ return nil, fmt.Errorf("search: main doc %q not in vector db", id)
+ }
+ rs := Vector(vdb, dc, &VectorRequest{
+ Options: Options{
+ Limit: maxResults + 1, // buffer for self
+ },
+ Vector: v,
+ })
+ // Remove the query itself if present.
+ if len(rs) > 0 && rs[0].ID == id {
+ rs = rs[1:]
+ }
+ // Trim length.
+ if len(rs) > maxResults {
+ rs = rs[:maxResults]
+ }
+ return rs, nil
+}
+
+// llmDoc converts the document in dc identified by id into
+// an [*llmapp.Doc].
+func llmDoc(dc *docs.Corpus, t string, id string) (*llmapp.Doc, bool) {
+ d, ok := dc.Get(id)
+ if !ok {
+ return nil, false
+ }
+ doc := &llmapp.Doc{
+ Type: t,
+ Title: d.Title,
+ Text: d.Text,
+ }
+ if isURL(d.ID) {
+ doc.URL = d.ID
+ }
+ return doc, true
+}
diff --git a/internal/search/overview_test.go b/internal/search/overview_test.go
new file mode 100644
index 0000000..96c2932
--- /dev/null
+++ b/internal/search/overview_test.go
@@ -0,0 +1,79 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package search
+
+import (
+ "context"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/oscar/internal/docs"
+ "golang.org/x/oscar/internal/embeddocs"
+ "golang.org/x/oscar/internal/llm"
+ "golang.org/x/oscar/internal/llmapp"
+ "golang.org/x/oscar/internal/storage"
+ "golang.org/x/oscar/internal/testutil"
+)
+
+func TestOverview(t *testing.T) {
+ ctx := context.Background()
+ lg := testutil.Slogger(t)
+ g := llm.EchoTextGenerator()
+ db := storage.MemDB()
+ lc := llmapp.New(lg, g, db)
+ vdb := storage.MemVectorDB(db, lg, "test")
+ dc := docs.New(lg, db)
+
+ mr := maxResults
+ maxResults = 1
+ t.Cleanup(func() {
+ maxResults = mr
+ })
+
+ id := "https://example.com/123"
+ dc.Add(id, "title", "text")
+ dc.Add("456", "title2", "text2")
+ dc.Add("3", "title3", "text3")
+
+ // Add the documents to vdb.
+ testutil.Check(t, embeddocs.Sync(ctx, lg, vdb, llm.QuoteEmbedder(), dc))
+
+ got, err := Overview(ctx, lc, vdb, dc, id)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ doc1 := &llmapp.Doc{
+ Type: "main",
+ URL: id,
+ Title: "title",
+ Text: "text",
+ }
+ doc2 := &llmapp.Doc{
+ Type: "related",
+ // id "456" is not a URL, so it is omitted
+ Title: "title2",
+ Text: "text2",
+ }
+
+ // This checks that the expected call to
+ // [llmapp.Client.RelatedOverview] is made by [Overview].
+ ro, err := lc.RelatedOverview(ctx, doc1, []*llmapp.Doc{doc2})
+ if err != nil {
+ t.Fatal(err)
+ }
+ prompt := ro.Prompt
+
+ want := &OverviewResult{
+ &llmapp.OverviewResult{
+ Overview: llm.EchoResponse(prompt...),
+ Prompt: prompt,
+ },
+ }
+
+ if cmp.Diff(got, want) != "" {
+ t.Errorf("Overview() mismatch (-got +want):\n%s", cmp.Diff(got, want))
+ }
+}
diff --git a/internal/search/search.go b/internal/search/search.go
index 7c6a3a4..dc35feb 100644
--- a/internal/search/search.go
+++ b/internal/search/search.go
@@ -161,9 +161,16 @@
r.Score = math.Round(r.Score*1e3) / 1e3
}
-// IDIsURL reports whether the Result's ID is a valid URL.
+// IDIsURL reports whether the Result's ID is a valid absolute URL.
func (r *Result) IDIsURL() bool {
- _, err := url.Parse(r.ID)
+ return isURL(r.ID)
+}
+
+// isURL reports whether the string is a valid absolute URL.
+func isURL(s string) bool {
+ // Use [url.ParseRequestURI] as it only accepts absolute URLs
+ // ([url.Parse] accepts relative URLs too).
+ _, err := url.ParseRequestURI(s)
return err == nil
}
diff --git a/internal/search/search_test.go b/internal/search/search_test.go
index 88a2b94..412b8a0 100644
--- a/internal/search/search_test.go
+++ b/internal/search/search_test.go
@@ -363,3 +363,25 @@
t.Errorf("\ngot %s\nwant %s", got, want)
}
}
+
+func TestIsURL(t *testing.T) {
+ for _, tc := range []struct {
+ s string
+ want bool
+ }{
+ {"", false},
+ {"435", false},
+ {"example.com/hello", false},
+ {"http://example.com", true},
+ {"https://example.com", true},
+ {"https://example.com/path", true},
+ {"https://example.com/path?query=string", true},
+ {"https://example.com/path#fragment", true},
+ {"https://example.com/path?query=string#fragment", true},
+ } {
+ got := isURL(tc.s)
+ if got != tc.want {
+ t.Errorf("%q: got %t, want %t", tc.s, got, tc.want)
+ }
+ }
+}