internal/genai, cmd/vulnreport: remove support for legacy PaLM API
The Gemini API has fully replaced the PaLM API so there is no need for
us to maintain support.
Change-Id: I1e6581313b481a7ce9042d5ef82e99cb45ee48cf
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/559598
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/cmd/vulnreport/suggest.go b/cmd/vulnreport/suggest.go
index 04b7b02..718c699 100644
--- a/cmd/vulnreport/suggest.go
+++ b/cmd/vulnreport/suggest.go
@@ -17,7 +17,6 @@
var (
interactive = flag.Bool("i", false, "for suggest, interactive mode")
numSuggestions = flag.Int("n", 4, "for suggest, the number of suggestions to attempt to generate (max is 8)")
- palm = flag.Bool("palm", false, "use the legacy PaLM API instead of the Gemini API")
)
func suggestCmd(ctx context.Context, filename string) (err error) {
@@ -28,16 +27,10 @@
return err
}
- var c genai.Client
- if *palm {
- infolog.Print("contacting the PaLM API...")
- c = genai.NewDefaultPaLMClient()
- } else {
- infolog.Print("contacting the Gemini API... (set flag -palm to use legacy PaLM API instead)")
- c, err = genai.NewGeminiClient(ctx)
- if err != nil {
- return err
- }
+ infolog.Print("contacting the Gemini API...")
+ c, err := genai.NewGeminiClient(ctx)
+ if err != nil {
+ return err
}
suggestions, err := suggest(ctx, c, r, *numSuggestions)
diff --git a/internal/genai/gen_examples/main.go b/internal/genai/gen_examples/main.go
index 12ae903..2b5cba8 100644
--- a/internal/genai/gen_examples/main.go
+++ b/internal/genai/gen_examples/main.go
@@ -3,7 +3,8 @@
// license that can be found in the LICENSE file.
// Command gen_examples generates and stores examples
-// that can be used to create prompts / training inputs for the PaLM API.
+// that can be used to create prompts / training inputs for Google's
+// Generative AI APIs.
package main
import (
diff --git a/internal/genai/palmclient.go b/internal/genai/palmclient.go
deleted file mode 100644
index c6a53a0..0000000
--- a/internal/genai/palmclient.go
+++ /dev/null
@@ -1,210 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package genai provides a client and utilities for interacting with
-// Google's generative AI libraries.
-package genai
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "os"
-)
-
-type PaLMClient struct {
- c *http.Client
- url string
- getAPIKey func() (string, error)
-}
-
-// NewDefaultPaLMClient returns a new default client for the PaLM API that reads
-// an API key from the environment variable "PALM_API_KEY".
-func NewDefaultPaLMClient() *PaLMClient {
- const (
- defaultURL = `https://generativelanguage.googleapis.com`
- apiKeyEnv = "PALM_API_KEY"
- )
- return NewClient(http.DefaultClient, defaultURL, func() (string, error) {
- key := os.Getenv(apiKeyEnv)
- if key == "" {
- return "", fmt.Errorf("PaLM API key (env var %s) not set. You can get an API key at https://makersuite.google.com/app/apikey", apiKeyEnv)
- }
- return key, nil
- })
-}
-
-func NewClient(httpClient *http.Client, url string, getAPIKey func() (string, error)) *PaLMClient {
- return &PaLMClient{
- c: httpClient,
- url: url,
- getAPIKey: getAPIKey}
-}
-
-const generateTextEndpoint = "generateText"
-const textBisonModel = "/v1beta3/models/text-bison-001"
-
-// generateText is a wrapper for the PaLM API "generateText" endpoint.
-// See https://developers.generativeai.google/api/rest/generativelanguage/models/generateText.
-func (c *PaLMClient) generateText(prompt string) (*GenerateTextResponse, error) {
- reqBody, err := toRequestBody(prompt)
- if err != nil {
- return nil, err
- }
- key, err := c.getAPIKey()
- if err != nil {
- return nil, err
- }
- resp, err := http.Post(fmt.Sprintf("%s%s:%s?key=%s", c.url, textBisonModel, generateTextEndpoint, key), "application/json", bytes.NewBuffer(reqBody))
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- err := fmt.Errorf("PaLM API returned non-OK status %s", resp.Status)
- if msg, err2 := getErrMsg(resp.Body); err2 == nil {
- return nil, fmt.Errorf("%w: %s", err, msg)
- }
- return nil, err
- }
- return parseGenerateTextResponse(resp.Body)
-}
-
-func (c *PaLMClient) GenerateText(_ context.Context, prompt string) ([]string, error) {
- response, err := c.generateText(prompt)
- if err != nil {
- return nil, err
- }
- candidates := make([]string, len(response.Candidates))
- for i, c := range response.Candidates {
- candidates[i] = c.Output
- }
- return candidates, nil
-}
-
-func getErrMsg(r io.Reader) (string, error) {
- b, err := io.ReadAll(r)
- if err != nil {
- return "", err
- }
- var errResponse struct {
- Err struct {
- Message string `json:"message"`
- } `json:"error"`
- }
- if err := json.Unmarshal(b, &errResponse); err != nil {
- return "", err
- }
- return errResponse.Err.Message, nil
-}
-
-// See https://developers.generativeai.google/api/rest/generativelanguage/GenerateTextResponse
-type GenerateTextResponse struct {
- Candidates []TextCompletion `json:"candidates"`
- // Fields "filters" and "safetyFeedback" omitted.
-}
-
-// See https://developers.generativeai.google/api/rest/generativelanguage/GenerateTextResponse#TextCompletion
-type TextCompletion struct {
- Output string `json:"output"`
- // Field "safetyRatings" omitted.
- Citations Citation `json:"citationMetadata,omitempty"`
-}
-
-// See https://developers.generativeai.google/api/rest/generativelanguage/CitationMetadata
-type Citation struct {
- Sources []Source `json:"citationSources,omitempty"`
-}
-
-// See https://developers.generativeai.google/api/rest/generativelanguage/CitationMetadata#CitationSource
-type Source struct {
- StartIndex int `json:"startIndex,omitempty"`
- EndIndex int `json:"endIndex,omitempty"`
- URI string `json:"uri,omitempty"`
- License string `json:"license,omitempty"`
-}
-
-func parseGenerateTextResponse(r io.Reader) (*GenerateTextResponse, error) {
- b, err := io.ReadAll(r)
- if err != nil {
- return nil, err
- }
- var response GenerateTextResponse
- if err := json.Unmarshal(b, &response); err != nil {
- return nil, err
- }
- return &response, nil
-}
-
-// See https://developers.generativeai.google/api/rest/generativelanguage/models/generateText#request-body
-type GenerateTextRequest struct {
- Prompt TextPrompt `json:"prompt"`
- Temperature float32 `json:"temperature,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- TopK int `json:"topK,omitempty"`
- TopP float32 `json:"topP,omitempty"`
- MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
- SafetySettings []SafetySetting `json:"safetySettings,omitempty"`
-}
-
-// See https://developers.generativeai.google/api/rest/generativelanguage/TextPrompt
-type TextPrompt struct {
- Text string `json:"text"`
-}
-
-// See https://developers.generativeai.google/api/rest/generativelanguage/SafetySetting
-type SafetySetting struct {
- Category string `json:"category,omitempty"`
- Threshold int `json:"threshold,omitempty"`
-}
-
-func toRequestBody(promptText string) ([]byte, error) {
- req := GenerateTextRequest{
- Prompt: TextPrompt{
- Text: promptText,
- },
- CandidateCount: 8, // max
- // Use a low temperature (max is 1.0) to allow less creativity.
- Temperature: 0.35,
- SafetySettings: blockNone(),
- }
- b, err := json.Marshal(req)
- if err != nil {
- return nil, err
- }
- return b, nil
-}
-
-func blockNone() []SafetySetting {
- return []SafetySetting{
- {
- Category: "HARM_CATEGORY_DEROGATORY",
- Threshold: 4,
- },
- {
- Category: "HARM_CATEGORY_TOXICITY",
- Threshold: 4,
- },
- {
- Category: "HARM_CATEGORY_VIOLENCE",
- Threshold: 4,
- },
- {
- Category: "HARM_CATEGORY_SEXUAL",
- Threshold: 4,
- },
- {
- Category: "HARM_CATEGORY_MEDICAL",
- Threshold: 4,
- },
- {
- Category: "HARM_CATEGORY_DANGEROUS",
- Threshold: 4,
- },
- }
-}
diff --git a/internal/genai/palmclient_test.go b/internal/genai/palmclient_test.go
deleted file mode 100644
index 2219912..0000000
--- a/internal/genai/palmclient_test.go
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package genai
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "reflect"
- "strings"
- "testing"
-)
-
-func TestGenerateText(t *testing.T) {
- tests := []struct {
- name string
- prompt string
- want *GenerateTextResponse
- }{
- {
- name: "no_response",
- prompt: "say hello",
- want: &GenerateTextResponse{},
- },
- {
- name: "response",
- prompt: "say hello",
- want: &GenerateTextResponse{
- Candidates: []TextCompletion{
- {
- Output: "hi!",
- },
- {
- Output: "hello there",
- Citations: Citation{
- Sources: []Source{
- {
- URI: "https://www.example.com",
- },
- },
- },
- },
- },
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- c, cleanup, err := testClient(generateTextEndpoint, tt.prompt, tt.want)
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(cleanup)
- got, err := c.generateText(tt.prompt)
- if err != nil {
- t.Fatalf("GenerateText() error = %v", err)
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("GenerateText() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestGenerateTextError(t *testing.T) {
- tests := []struct {
- name string
- prompt string
- wantErr string
- }{
- {
- name: "error",
- prompt: "say hello",
- wantErr: "an error message",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- c, cleanup, err := testClientErr(generateTextEndpoint, tt.prompt, tt.wantErr)
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(cleanup)
- _, err = c.generateText(tt.prompt)
- if !strings.Contains(err.Error(), tt.wantErr) {
- t.Fatalf("GenerateText() error = %v; want error containing %q", err, tt.wantErr)
- }
- })
- }
-}
-
-func testClient(endpoint, prompt string, response *GenerateTextResponse) (c *PaLMClient, cleanup func(), err error) {
- rBytes, err := json.Marshal(response)
- if err != nil {
- return nil, nil, err
- }
- handler := func(w http.ResponseWriter, r *http.Request) {
- writeErr := func(err error) {
- w.WriteHeader(http.StatusBadRequest)
- errJSON := fmt.Sprintf(`{"error":{"message":"%s"}}`, err)
- _, _ = w.Write([]byte(errJSON))
- }
-
- body, err := io.ReadAll(r.Body)
- if err != nil {
- writeErr(err)
- return
- }
-
- var req GenerateTextRequest
- if err := json.Unmarshal(body, &req); err != nil {
- writeErr(err)
- return
- }
-
- if r.Method == http.MethodPost &&
- r.URL.Path == textBisonModel+":"+endpoint &&
- req.Prompt.Text == prompt {
- _, _ = w.Write(rBytes)
- return
- }
-
- writeErr(fmt.Errorf("Unrecognized endpoint (%s) or prompt (%s)", endpoint, prompt))
- }
- s := httptest.NewServer(http.HandlerFunc(handler))
- return NewClient(s.Client(), s.URL, getTestAPIKey), func() { s.Close() }, nil
-}
-
-func testClientErr(endpoint, prompt string, errMsg string) (c *PaLMClient, cleanup func(), err error) {
- handler := func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusBadRequest)
- errJSON := fmt.Sprintf(`{"error":{"message":"%s"}}`, errMsg)
- _, _ = w.Write([]byte(errJSON))
- }
- s := httptest.NewServer(http.HandlerFunc(handler))
- return NewClient(s.Client(), s.URL, getTestAPIKey), func() { s.Close() }, nil
-}
-
-const testAPIKey = "TEST-API-KEY"
-
-func getTestAPIKey() (string, error) {
- return testAPIKey, nil
-}