internal/llm: add support for multiple prompt parts

Modify the TextGenerator interface and its implementations
to allow multiple prompt "parts". This matches the Gemini API
and allows us to more easily provide distinct pieces of information
without needing to decide how to concatenate them.

GenerateText now also returns a single string instead of a
slice of strings.

Change-Id: I2a14e9a0dcd22650aecc40c8d58d4cd74b188f4e
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/621857
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/gcp/gemini/gemini.go b/internal/gcp/gemini/gemini.go
index b69b296..3b609a6 100644
--- a/internal/gcp/gemini/gemini.go
+++ b/internal/gcp/gemini/gemini.go
@@ -11,6 +11,7 @@
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"log/slog"
 	"maps"
@@ -119,6 +120,8 @@
 
 const maxBatch = 100 // empirical limit
 
+var _ llm.Embedder = (*Client)(nil)
+
 // EmbedDocs returns the vector embeddings for the docs,
 // implementing [llm.Embedder].
 func (c *Client) EmbedDocs(ctx context.Context, docs []llm.EmbedDoc) ([]llm.Vector, error) {
@@ -140,21 +143,39 @@
 	return vecs, nil
 }
 
-// GenerateText returns model's text responses for the prompt,
+var _ llm.TextGenerator = (*Client)(nil)
+
+// GenerateText returns the model's text response for the prompt parts,
 // implementing [llm.TextGenerator].
-func (c *Client) GenerateText(ctx context.Context, prompt string) ([]string, error) {
+//
+// The parts are sent together in a single request that asks for one
+// candidate. The result is the text of the first candidate that has
+// content, with that candidate's parts joined by newlines. An error is
+// returned if the request fails or if no candidate has content.
+func (c *Client) GenerateText(ctx context.Context, promptParts ...string) (string, error) {
 	model := c.genai.GenerativeModel(c.generativeModel)
-	resp, err := model.GenerateContent(ctx, genai.Text(prompt))
-	if err != nil {
-		return nil, err
+	model.SetCandidateCount(1)
+
+	parts := make([]genai.Part, len(promptParts))
+	for i, p := range promptParts {
+		parts[i] = genai.Text(p)
 	}
-	var candidates []string
+
+	resp, err := model.GenerateContent(ctx, parts...)
+	if err != nil {
+		return "", fmt.Errorf("gemini.GenerateText: %w", err)
+	}
+
+	// Return the first candidate that has content.
 	for _, c := range resp.Candidates {
 		if c.Content != nil {
-			for _, p := range c.Content.Parts {
-				candidates = append(candidates, fmt.Sprintf("%s", p))
+			texts := make([]string, len(c.Content.Parts))
+			for i, p := range c.Content.Parts {
+				texts[i] = fmt.Sprintf("%s", p)
 			}
+			return strings.Join(texts, "\n"), nil
 		}
 	}
-	return candidates, nil
+
+	return "", errors.New("gemini.GenerateText: no content")
 }
diff --git a/internal/gcp/gemini/gemini_test.go b/internal/gcp/gemini/gemini_test.go
index 05211f0..18a367f 100644
--- a/internal/gcp/gemini/gemini_test.go
+++ b/internal/gcp/gemini/gemini_test.go
@@ -98,7 +98,7 @@
 	ctx := context.Background()
 	check := testutil.Checker(t)
 	c := newTestClient(t, "testdata/generatetext.httprr")
-	responses, err := c.GenerateText(ctx, "What is the Go programming language?")
+	responses, err := c.GenerateText(ctx, "CanonicalHeaderKey returns the canonical format of the header key s. The canonicalization converts the first letter and any letter following a hyphen to upper case; the rest are converted to lowercase. For example, the canonical key for 'accept-encoding' is 'Accept-Encoding'. If s contains a space or invalid header field bytes, it is returned without modifications.", "When should I use CanonicalHeaderKey?")
 	check(err)
 	if len(responses) == 0 {
 		t.Fatal("no responses")
diff --git a/internal/gcp/gemini/testdata/generatetext.httprr b/internal/gcp/gemini/testdata/generatetext.httprr
index b0353a9..18e4232 100644
--- a/internal/gcp/gemini/testdata/generatetext.httprr
+++ b/internal/gcp/gemini/testdata/generatetext.httprr
@@ -1,36 +1,36 @@
 httprr trace v1
-460 4712
+858 3622
 POST https://generativelanguage.googleapis.com/v1beta/models/gemini-1.0-pro:generateContent?%24alt=json%3Benum-encoding%3Dint HTTP/1.1

 Host: generativelanguage.googleapis.com

 User-Agent: Go-http-client/1.1

-Content-Length: 142

+Content-Length: 540

 Content-Type: application/json

 x-goog-request-params: model=models%2Fgemini-1.0-pro

 

-{"model":"models/gemini-1.0-pro","contents":[{"parts":[{"text":"What is the Go programming language?"}],"role":"user"}],"generationConfig":{}}HTTP/2.0 200 OK

+{"model":"models/gemini-1.0-pro","contents":[{"parts":[{"text":"CanonicalHeaderKey returns the canonical format of the header key s. The canonicalization converts the first letter and any letter following a hyphen to upper case; the rest are converted to lowercase. For example, the canonical key for 'accept-encoding' is 'Accept-Encoding'. If s contains a space or invalid header field bytes, it is returned without modifications."},{"text":"When should I use CanonicalHeaderKey?"}],"role":"user"}],"generationConfig":{"candidateCount":1}}HTTP/2.0 200 OK

 Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000

 Cache-Control: private

 Content-Type: application/json; charset=UTF-8

-Date: Tue, 22 Oct 2024 17:01:52 GMT

+Date: Thu, 24 Oct 2024 15:16:59 GMT

 Server: scaffolding on HTTPServer2

-Server-Timing: gfet4t7; dur=5262

+Server-Timing: gfet4t7; dur=2345

 Vary: Origin

 Vary: X-Origin

 Vary: Referer

 X-Content-Type-Options: nosniff

 X-Frame-Options: SAMEORIGIN

-X-Google-Backends: unix:/tmp/esfbackend.1729546380.45727.784683,/bns/ma/borg/ma/bns/genai-api/prod.genai-api/9,/bns/lclgaa/borg/lclgaa/bns/blue-layer1-gfe-prod-edge/prod.blue-layer1-gfe.lga34s37/54

+X-Google-Backends: unix:/tmp/esfbackend.1729722122.514958.2643309,/bns/ma/borg/ma/bns/genai-api/prod.genai-api/7,/bns/lclgaa/borg/lclgaa/bns/blue-layer1-gfe-prod-edge/prod.blue-layer1-gfe.lga34s38/40

 X-Google-Dos-Service-Trace: main:genai-api-api-prod,main:GLOBAL_all_non_cloud

 X-Google-Esf-Cloud-Client-Params: backend_service_name: "generativelanguage.googleapis.com" backend_fully_qualified_method: "google.ai.generativelanguage.v1beta.GenerativeService.GenerateContent"

-X-Google-Gfe-Handshake-Trace: GFE: /bns/lclgaa/borg/lclgaa/bns/blue-layer1-gfe-prod-edge/prod.blue-layer1-gfe.lga34s37/54,Mentat oracle: [2002:a05:6664:18d:b0:31:791e:bc06]:9801

-X-Google-Gfe-Request-Trace: aclgaee7:443,/bns/ma/borg/ma/bns/genai-api/prod.genai-api/9,aclgaee7:443

+X-Google-Gfe-Handshake-Trace: GFE: /bns/lclgaa/borg/lclgaa/bns/blue-layer1-gfe-prod-edge/prod.blue-layer1-gfe.lga34s38/40,Mentat oracle: [2002:a05:6692:824:b0:89:f617:8002]:9801

+X-Google-Gfe-Request-Trace: aclgaff4:443,/bns/ma/borg/ma/bns/genai-api/prod.genai-api/7,aclgaff4:443

 X-Google-Gfe-Response-Body-Transformations: chunked

 X-Google-Gfe-Response-Code-Details-Trace: response_code_set_by_backend

 X-Google-Gfe-Service-Trace: genai-api-api-prod/gfespec_googleapis-generativelanguage_generativelanguage-url-map-global_generativelanguage-genai-api-api-prod

-X-Google-Gfe-Version: 2.898.1

-X-Google-Netmon-Label: /bns/ma/borg/ma/bns/genai-api/prod.genai-api/9

+X-Google-Gfe-Version: 2.900.2

+X-Google-Netmon-Label: /bns/ma/borg/ma/bns/genai-api/prod.genai-api/7

 X-Google-Security-Signals: FRAMEWORK=ONE_PLATFORM,ENV=borg,ENV_DEBUG=borg_user:genai-api;borg_job:prod.genai-api

-X-Google-Security-Signals: FRAMEWORK=HTTPSERVER2,BUILD=GOOGLE3,BUILD_DEBUG=cl:687926139,ENV=borg,ENV_DEBUG=borg_user:genai-api;borg_job:prod.genai-api

+X-Google-Security-Signals: FRAMEWORK=HTTPSERVER2,BUILD=GOOGLE3,BUILD_DEBUG=cl:688683199,ENV=borg,ENV_DEBUG=borg_user:genai-api;borg_job:prod.genai-api

 X-Google-Service: genai-api-api-prod/gfespec_googleapis-generativelanguage_generativelanguage-url-map-global_generativelanguage-genai-api-api-prod

 X-Google-Session-Info: GgQYECgLIAE6IxIhZ2VuZXJhdGl2ZWxhbmd1YWdlLmdvb2dsZWFwaXMuY29t

 X-Google-Shellfish-Status: CA0gBEBG

@@ -42,7 +42,7 @@
       "content": {
         "parts": [
           {
-            "text": "Go is a high-level programming language developed at Google. It is designed to be fast, reliable, and easy to use. Go is a statically typed language with a simple syntax. It is similar to C in many ways, but it has a number of features that make it more modern and easier to use.\n\nSome of the key features of Go include:\n\n* **Concurrency:** Go has built-in support for concurrency, which makes it easy to write programs that can take advantage of multiple processors.\n* **Garbage collection:** Go has an automatic garbage collector, which frees the programmer from having to worry about memory management.\n* **Simplicity:** Go is a very simple language to learn and use. The syntax is straightforward and the standard library is well-documented.\n\nGo is used to develop a wide variety of applications, including web applications, mobile applications, and distributed systems. It is also popular for writing cloud-based applications.\n\nHere are some of the benefits of using Go:\n\n* **Fast:** Go is a very fast language, and it can be used to develop high-performance applications.\n* **Reliable:** Go is a very reliable language, and it is used to develop mission-critical applications.\n* **Easy to use:** Go is a very easy language to learn and use, and it is suitable for developers of all levels.\n* **Versatile:** Go can be used to develop a wide variety of applications, including web applications, mobile applications, and distributed systems.\n\nGo is a powerful and versatile programming language that is well-suited for developing a wide variety of applications. If you are looking for a language that is fast, reliable, and easy to use, then Go is a good option for you."
+            "text": "Use CanonicalHeaderKey when you need to canonicalize the header key s according to HTTP/2 requirements. Canonicalizing a header key means converting the first letter and any letter following a hyphen to upper case; the rest are converted to lowercase. For example, the canonical key for 'accept-encoding' is 'Accept-Encoding'. This function is useful in the context of HTTP/2, where header names are case-insensitive and must be canonicalized before being sent over the wire. By using CanonicalHeaderKey, you can ensure that header keys are properly formatted and can be easily compared and manipulated."
           }
         ],
         "role": "model"
@@ -70,9 +70,9 @@
       "citationMetadata": {
         "citationSources": [
           {
-            "startIndex": 344,
-            "endIndex": 465,
-            "uri": "https://miguelnorberto.com/coding/go/golang/an-introduction-to-go.html",
+            "startIndex": 149,
+            "endIndex": 319,
+            "uri": "https://tachingchen.com/blog/pitfall-of-golang-header-operation/",
             "license": ""
           }
         ]
@@ -80,9 +80,9 @@
     }
   ],
   "usageMetadata": {
-    "promptTokenCount": 8,
-    "candidatesTokenCount": 354,
-    "totalTokenCount": 362
+    "promptTokenCount": 81,
+    "candidatesTokenCount": 120,
+    "totalTokenCount": 201
   },
   "modelVersion": "gemini-1.0-pro"
 }
diff --git a/internal/llm/llm.go b/internal/llm/llm.go
index 632e479..cb46f9d 100644
--- a/internal/llm/llm.go
+++ b/internal/llm/llm.go
@@ -74,12 +74,14 @@
 	}
 }
 
-// A TextGenerator generates text responses given a text prompt.
+// A TextGenerator generates a text response given one or more text
+// prompts.
 //
 // See [EchoTextGenerator] for a generator, useful for testing, that
-// always responds with a deterministic message derived from the prompt.
+// always responds with a deterministic message derived from the prompts.
 //
 // See [golang.org/x/oscar/internal/gcp/gemini] for a real implementation.
 type TextGenerator interface {
-	GenerateText(ctx context.Context, prompt string) ([]string, error)
+	// GenerateText returns a single text response to the given prompt parts.
+	GenerateText(ctx context.Context, parts ...string) (string, error)
 }
diff --git a/internal/llm/testing.go b/internal/llm/testing.go
index fd3351f..86c7866 100644
--- a/internal/llm/testing.go
+++ b/internal/llm/testing.go
@@ -7,6 +7,7 @@
 import (
 	"context"
 	"math"
+	"strings"
 )
 
 const quoteLen = 123
@@ -90,8 +91,17 @@
 
 type echo struct{}
 
-// GenerateText echoes the prompt (for testing).
+// GenerateText echoes the prompts (for testing).
+// The response is [EchoResponse] applied to the prompt parts; the error is always nil.
 // Implements [TextGenerator].
-func (echo) GenerateText(ctx context.Context, prompt string) ([]string, error) {
-	return []string{prompt}, nil
+func (echo) GenerateText(ctx context.Context, promptParts ...string) (string, error) {
+	return EchoResponse(promptParts...), nil
+}
+
+// EchoResponse returns the concatenation of the prompt parts,
+// in order and with no separator between them.
+// It is the response that echo's GenerateText gives for the same parts.
+// For testing.
+func EchoResponse(promptParts ...string) string {
+	return strings.Join(promptParts, "")
 }
diff --git a/internal/llm/testing_test.go b/internal/llm/testing_test.go
index 5389d51..ec01541 100644
--- a/internal/llm/testing_test.go
+++ b/internal/llm/testing_test.go
@@ -43,14 +43,12 @@
 func TestEcho(t *testing.T) {
 	ctx := context.Background()
 	gen := EchoTextGenerator()
-	resp, err := gen.GenerateText(ctx, "abc")
+	resp, err := gen.GenerateText(ctx, "abc", "123")
 	if err != nil {
 		t.Fatal(err)
 	}
-	if len(resp) != 1 {
-		t.Fatalf("len(resp) = %v, want 1", len(resp))
-	}
-	if resp[0] != "abc" {
-		t.Errorf("resp[0] = %q, want %q", resp[0], "abc")
+	// The response is a single string; compare and report it whole.
+	if resp != "abc123" {
+		t.Errorf("resp  = %q, want %q", resp, "abc123")
 	}
 }