internal/genai: add support for Gemini
Adds a wrapper client for the Gemini API.
Change-Id: I2a48ec83fd003f258eb3f835eccc94a090f2baad
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/552038
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/genai/gemini.go b/internal/genai/gemini.go
new file mode 100644
index 0000000..2a0bdf7
--- /dev/null
+++ b/internal/genai/gemini.go
@@ -0,0 +1,68 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package genai
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+
+ gemini "github.com/google/generative-ai-go/genai"
+ "google.golang.org/api/option"
+)
+
+type GeminiClient struct {
+ model
+ closer
+}
+
+type model interface {
+ GenerateContent(ctx context.Context, parts ...gemini.Part) (*gemini.GenerateContentResponse, error)
+}
+
+type closer interface {
+ Close() error
+}
+
+const (
+ geminiAPIKeyEnv = "GEMINI_API_KEY"
+ geminiModel = "gemini-pro"
+)
+
+func NewGeminiClient(ctx context.Context) (*GeminiClient, error) {
+ key := os.Getenv(geminiAPIKeyEnv)
+ if key == "" {
+ return nil, fmt.Errorf("%s must be set", geminiAPIKeyEnv)
+ }
+ client, err := gemini.NewClient(ctx, option.WithAPIKey(key))
+ if err != nil {
+ return nil, err
+ }
+ return &GeminiClient{
+ model: client.GenerativeModel(geminiModel),
+ closer: client,
+ }, nil
+}
+
+func (c *GeminiClient) GenerateText(ctx context.Context, prompt string) ([]string, error) {
+ response, err := c.model.GenerateContent(ctx, gemini.Text(prompt))
+ if err != nil {
+ return nil, err
+ }
+ b, err := json.Marshal(response)
+ if err == nil {
+ fmt.Println(string(b))
+ }
+ var candidates []string
+ for _, c := range response.Candidates {
+ if c.Content != nil {
+ for _, p := range c.Content.Parts {
+ candidates = append(candidates, fmt.Sprintf("%s", p))
+ }
+ }
+ }
+ return candidates, nil
+}
diff --git a/internal/genai/gemini_test.go b/internal/genai/gemini_test.go
new file mode 100644
index 0000000..5f3f40d
--- /dev/null
+++ b/internal/genai/gemini_test.go
@@ -0,0 +1,52 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package genai
+
+import (
+ "context"
+ "testing"
+
+ gemini "github.com/google/generative-ai-go/genai"
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestGemini(t *testing.T) {
+ c := testGeminiClient()
+
+ got, err := c.GenerateText(context.Background(), "say hello")
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := []string{"Hello there! How can I assist you today?"}
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("GenerateText mismatch (-want, +got):\n%s", diff)
+ }
+}
+
+func testGeminiClient() *GeminiClient {
+ return &GeminiClient{
+ model: testModel{},
+ closer: testCloser{},
+ }
+}
+
+type testModel struct{}
+
+func (_ testModel) GenerateContent(ctx context.Context, parts ...gemini.Part) (*gemini.GenerateContentResponse, error) {
+ // TODO(tatianabradley): Improve testing by replaying a real API response.
+ return &gemini.GenerateContentResponse{
+ Candidates: []*gemini.Candidate{{
+ Content: &gemini.Content{
+ Parts: []gemini.Part{
+ gemini.Text("Hello there! How can I assist you today?"),
+ },
+ },
+ }},
+ }, nil
+}
+
+type testCloser struct{}
+
+func (_ testCloser) Close() error { return nil }