blob: 7d07dc5a43165f6551b29c24b0cf632daebd6977 [file] [log] [blame] [edit]
package llmapp
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
// TestAnalyzeRelated exercises AnalyzeRelated on the value returned by New,
// using the test generator from RelatedTestGenerator and an in-memory
// database. It verifies that the returned analysis carries the expected
// prompt parts, schema, raw response and decoded output.
func TestAnalyzeRelated(t *testing.T) {
	ctx := context.Background()
	lg := testutil.Slogger(t)
	db := storage.MemDB()

	t.Run("basic", func(t *testing.T) {
		client := New(lg, RelatedTestGenerator(t, 1), db)

		got, err := client.AnalyzeRelated(ctx, doc1, []*Doc{doc2})
		if err != nil {
			t.Fatal(err)
		}

		// Expected prompt: the labelled original document, the labelled
		// related document, then the instructions text.
		wantPrompt := []llm.Part{
			llm.Text("original"),
			raw1,
			llm.Text("related"),
			raw2,
			llm.Text(docAndRelated.instructions()),
		}
		wantResponse, wantOutput := relatedTestOutput(t, 1)

		want := &RelatedAnalysis{
			Result: Result{
				Response: wantResponse,
				Prompt:   wantPrompt,
				Schema:   docAndRelated.schema(),
			},
			Output: wantOutput,
		}

		diff := cmp.Diff(want, got)
		if diff != "" {
			t.Errorf("AnalyzeRelated() mismatch (-want +got):\n%s", diff)
		}
	})
}