| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package gemini |
| |
| import ( |
| "bytes" |
| "context" |
| "fmt" |
| "net/http" |
| "testing" |
| |
| "golang.org/x/oscar/internal/httprr" |
| "golang.org/x/oscar/internal/llm" |
| "golang.org/x/oscar/internal/secret" |
| "golang.org/x/oscar/internal/testutil" |
| ) |
| |
| var docs = []llm.EmbedDoc{ |
| {Text: "for loops"}, |
| {Text: "for all time, always"}, |
| {Text: "break statements"}, |
| {Text: "breakdancing"}, |
| {Text: "forever could never be long enough for me"}, |
| {Text: "the macarena"}, |
| } |
| |
// matches maps a document text to the text expected to be its closest
// neighbour. Only one direction of each pair is spelled out here; init
// adds the reverse direction so the table can be queried from either side.
var matches = map[string]string{
	"for loops":            "break statements",
	"for all time, always": "forever could never be long enough for me",
	"breakdancing":         "the macarena",
}

// init completes matches by adding, for every key→value entry listed
// above, the mirrored value→key entry.
func init() {
	// Snapshot the forward pairs first so the map is not being grown
	// while it is ranged over.
	type pair struct{ from, to string }
	forward := make([]pair, 0, len(matches))
	for from, to := range matches {
		forward = append(forward, pair{from, to})
	}
	for _, p := range forward {
		matches[p.to] = p.from
	}
}
| |
// ctx is the background context that newTestClient passes to NewClient.
var ctx = context.Background()
| |
| func newTestClient(t *testing.T, rrfile string) *Client { |
| check := testutil.Checker(t) |
| lg := testutil.Slogger(t) |
| |
| rr, err := httprr.Open(rrfile, http.DefaultTransport) |
| check(err) |
| rr.ScrubReq(Scrub) |
| sdb := secret.ReadOnlyMap{"ai.google.dev": "nokey"} |
| if rr.Recording() { |
| sdb = secret.Netrc() |
| } |
| |
| c, err := NewClient(ctx, lg, sdb, rr.Client(), DefaultEmbeddingModel, DefaultGenerativeModel) |
| check(err) |
| |
| return c |
| } |
| |
| func TestEmbedBatch(t *testing.T) { |
| ctx := context.Background() |
| check := testutil.Checker(t) |
| c := newTestClient(t, "testdata/embedbatch.httprr") |
| vecs, err := c.EmbedDocs(ctx, docs) |
| check(err) |
| if len(vecs) != len(docs) { |
| t.Fatalf("len(vecs) = %d, but len(docs) = %d", len(vecs), len(docs)) |
| } |
| |
| var buf bytes.Buffer |
| for i := range docs { |
| for j := range docs { |
| fmt.Fprintf(&buf, " %.4f", vecs[i].Dot(vecs[j])) |
| } |
| fmt.Fprintf(&buf, "\n") |
| } |
| |
| for i, d := range docs { |
| best := "" |
| bestDot := 0.0 |
| for j := range docs { |
| if dot := vecs[i].Dot(vecs[j]); i != j && dot > bestDot { |
| best, bestDot = docs[j].Text, dot |
| } |
| } |
| if best != matches[d.Text] { |
| if buf.Len() > 0 { |
| t.Errorf("dot matrix:\n%s", buf.String()) |
| buf.Reset() |
| } |
| t.Errorf("%q: best=%q, want %q", d.Text, best, matches[d.Text]) |
| } |
| } |
| } |
| |
| func TestGenerateContentText(t *testing.T) { |
| ctx := context.Background() |
| check := testutil.Checker(t) |
| c := newTestClient(t, "testdata/generatetext.httprr") |
| responses, err := c.GenerateContent(ctx, nil, []llm.Part{llm.Text("CanonicalHeaderKey returns the canonical format of the header key s. The canonicalization converts the first letter and any letter following a hyphen to upper case; the rest are converted to lowercase. For example, the canonical key for 'accept-encoding' is 'Accept-Encoding'. If s contains a space or invalid header field bytes, it is returned without modifications."), llm.Text("When should I use CanonicalHeaderKey?")}) |
| check(err) |
| if len(responses) == 0 { |
| t.Fatal("no responses") |
| } |
| } |
| |
| func TestGenerateContentJSON(t *testing.T) { |
| ctx := context.Background() |
| check := testutil.Checker(t) |
| c := newTestClient(t, "testdata/generatejson.httprr") |
| responses, err := c.GenerateContent(ctx, |
| &llm.Schema{ |
| Type: llm.TypeObject, |
| Properties: map[string]*llm.Schema{ |
| "answer": { |
| Type: llm.TypeString, |
| }, |
| "confidence": { |
| Type: llm.TypeInteger, |
| }, |
| }, |
| }, |
| []llm.Part{ |
| llm.Text("(confidence is between 0 and 100)"), |
| llm.Text("What is the tallest mountain in the world?"), |
| }) |
| check(err) |
| if len(responses) == 0 { |
| t.Fatal("no responses") |
| } |
| } |
| |
| func TestBigBatch(t *testing.T) { |
| ctx := context.Background() |
| check := testutil.Checker(t) |
| c := newTestClient(t, "testdata/bigbatch.httprr") |
| var docs []llm.EmbedDoc |
| |
| for i := range 251 { |
| docs = append(docs, llm.EmbedDoc{Text: fmt.Sprintf("word%d", i)}) |
| } |
| docs = docs[:251] |
| vecs, err := c.EmbedDocs(ctx, docs) |
| check(err) |
| if len(vecs) != len(docs) { |
| t.Fatalf("len(vecs) = %d, but len(docs) = %d", len(vecs), len(docs)) |
| } |
| } |