blob: b0ff2ee1ccc96e2f363e6b18947f67044b96bd96 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package llmapp
import (
"context"
"encoding/json"
"math/rand/v2"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
// TestOverview checks each overview entry point of the Client. With the
// echoing test generator, the result must carry the exact prompt that was
// sent — the raw form of each document followed by the instructions for
// that kind of overview — and an overview equal to the echo of that prompt.
func TestOverview(t *testing.T) {
	ctx := context.Background()
	c := newTestClient(t)

	cases := []struct {
		name         string
		overview     func() (*OverviewResult, error)
		instructions string
	}{
		{
			name:         "Overview",
			overview:     func() (*OverviewResult, error) { return c.Overview(ctx, doc1, doc2) },
			instructions: documents.instructions(),
		},
		{
			name:         "PostOverview",
			overview:     func() (*OverviewResult, error) { return c.PostOverview(ctx, doc1, []*Doc{doc2}) },
			instructions: postAndComments.instructions(),
		},
		{
			name:         "RelatedOverview",
			overview:     func() (*OverviewResult, error) { return c.RelatedOverview(ctx, doc1, []*Doc{doc2}) },
			instructions: docAndRelated.instructions(),
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := tc.overview()
			if err != nil {
				t.Fatal(err)
			}
			prompt := []string{raw1, raw2, tc.instructions}
			want := &OverviewResult{
				Overview: llm.EchoResponse(prompt...),
				Prompt:   prompt,
			}
			if diff := cmp.Diff(want, got); diff != "" {
				// Renders as e.g. "Overview() mismatch (-want +got):\n...".
				t.Errorf("%s() mismatch (-want +got):\n%s", tc.name, diff)
			}
		})
	}
}
// Fixture documents used by TestOverview, each paired with the raw prompt
// text it is expected to contribute (its JSON form, with empty fields
// expected to be left out).
var (
	// doc1 has every field set; raw1 is how it should appear in a prompt.
	doc1 = &Doc{
		URL:    "https://example.com",
		Author: "rsc",
		Title:  "title",
		Text:   "some text",
	}
	raw1 = `{"url":"https://example.com","author":"rsc","title":"title","text":"some text"}`

	// doc2 sets only Text, so raw2 contains only the "text" key.
	doc2 = &Doc{Text: "some text 2"}
	raw2 = `{"text":"some text 2"}`
)
// newTestClient builds a Client for tests: it logs through t, answers
// prompts with the echoing text generator, and stores data in a fresh
// in-memory database.
func newTestClient(t *testing.T) *Client {
	t.Helper()
	logger := testutil.Slogger(t)
	gen := llm.EchoTextGenerator()
	db := storage.MemDB()
	return New(logger, gen, db)
}
// TestGenerateText checks that generateText returns the generator's output
// and that a repeated prompt is answered from the cache on the second call.
func TestGenerateText(t *testing.T) {
	ctx := context.Background()
	lg := testutil.Slogger(t)
	// NOTE(review): both subtests share this db. "random" still expects its
	// first call to miss the cache even though "echo" has already answered
	// the same prompt; that presumably relies on cache entries being keyed
	// by model as well as prompt — confirm.
	db := storage.MemDB()

	t.Run("echo", func(t *testing.T) {
		c := New(lg, llm.EchoTextGenerator(), db)
		want := llm.EchoResponse("a", "b", "c")
		// The first call must reach the generator and the second must be a
		// cache hit; both must return the echoed prompt.
		for _, wantCached := range []bool{false, true} {
			text, hit, err := c.generateText(ctx, []string{"a", "b", "c"})
			if err != nil {
				t.Fatal(err)
			}
			if text != want {
				t.Errorf("generateText() = %q, want %q", text, want)
			}
			switch {
			case hit && !wantCached:
				t.Error("generateText() = cached, want not cached")
			case !hit && wantCached:
				t.Error("generateText() = not cached, want cached")
			}
		}
	})

	// The random generator almost never repeats itself, so getting the same
	// text twice shows the second answer really came from the cache.
	t.Run("random", func(t *testing.T) {
		c := New(lg, random{}, db)

		first, hit, err := c.generateText(ctx, []string{"a", "b", "c"})
		if err != nil {
			t.Fatal(err)
		}
		if hit {
			t.Error("generateText() = cached, want not cached")
		}

		second, hit, err := c.generateText(ctx, []string{"a", "b", "c"})
		if err != nil {
			t.Fatal(err)
		}
		if second != first {
			t.Errorf("generateText() = %s, want %s", second, first)
		}
		if !hit {
			t.Error("generateText() = not cached, want cached")
		}
	})
}
// random is an [llm.TextGenerator] that ignores its prompt and
// returns a random integer.
type random struct{}
func (random) Model() string {
return "random"
}
func (random) GenerateText(_ context.Context, s ...string) (string, error) {
return strconv.Itoa(rand.IntN(1000)), nil
}
// TestResponseUnmarshal checks that a fixed JSON encoding of a cached
// [response] (presumably the form earlier versions stored — confirm) still
// decodes into the current struct.
func TestResponseUnmarshal(t *testing.T) {
	// Do not remove or edit this test case without a good reason.
	// It ensures that no backwards incompatible changes are made to the [response] struct.
	raw := `{"Model":"model","PromptHash":"Qb+qD8ZuYR26qktIqPIbbHTaWm0SaoBaWhwObKH8INg=","Response":"response"}`
	var r response
	// Fails if a field's Go type can no longer hold the stored JSON value.
	if err := json.Unmarshal([]byte(raw), &r); err != nil {
		t.Fatal(err)
	}
	// NOTE(review): json.Unmarshal silently ignores keys with no matching
	// field, and only Response is compared below, so renaming or dropping
	// Model or PromptHash would not be caught here.
	if r.Response != "response" {
		t.Errorf("r.Response = %s, want %s", r.Response, "response")
	}
}
// TestInstructions checks that the instructions for each prompt kind
// mention the keywords expected of that kind: every kind mentions
// "markdown", postAndComments mentions "post" (which documents must not),
// and docAndRelated mentions "related".
func TestInstructions(t *testing.T) {
	wantAll := "markdown"    // in all instructions
	wantPost := "post"       // only in postAndComments
	wantRelated := "related" // only in docAndRelated

	// The checks within each subtest are independent, so Errorf is used
	// throughout to report every missing keyword rather than only the first.
	t.Run("documents", func(t *testing.T) {
		di := documents.instructions()
		if !strings.Contains(di, wantAll) {
			t.Errorf("documents.instructions(): does not contain %q", wantAll)
		}
		if strings.Contains(di, wantPost) {
			t.Errorf("documents.instructions(): incorrectly contains %q", wantPost)
		}
	})
	t.Run("postAndComments", func(t *testing.T) {
		pi := postAndComments.instructions()
		if !strings.Contains(pi, wantAll) {
			t.Errorf("postAndComments.instructions(): does not contain %q", wantAll)
		}
		if !strings.Contains(pi, wantPost) {
			t.Errorf("postAndComments.instructions(): does not contain %q", wantPost)
		}
	})
	t.Run("docAndRelated", func(t *testing.T) {
		ri := docAndRelated.instructions()
		if !strings.Contains(ri, wantAll) {
			t.Errorf("docAndRelated.instructions(): does not contain %q", wantAll)
		}
		if !strings.Contains(ri, wantRelated) {
			// Report the keyword that was actually checked.
			t.Errorf("docAndRelated.instructions(): does not contain %q", wantRelated)
		}
	})
}