blob: 60b4e3d580dbf7ba105023715a730eb1d199eee5 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package palmapi
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
// TestSuggest verifies that suggest turns each candidate of a canned
// GenerateTextResponse into a Suggestion, in order, and that a candidate
// whose Output is not valid JSON is dropped when other candidates are usable.
func TestSuggest(t *testing.T) {
	// The response is hard-coded per case, so the prompt text has no effect
	// on the outcome; one shared value is enough.
	const prompt = "a prompt"

	cases := []struct {
		name     string
		response *GenerateTextResponse
		want     []*Suggestion
	}{
		{
			name: "basic",
			response: &GenerateTextResponse{
				Candidates: []TextCompletion{
					{Output: `{"Summary":"summary","Description":"new description"}`},
					{Output: `{"Summary":"another summary","Description":"another description"}`},
				},
			},
			want: []*Suggestion{
				{Summary: "summary", Description: "new description"},
				{Summary: "another summary", Description: "another description"},
			},
		},
		{
			name: "ignore invalid",
			response: &GenerateTextResponse{
				Candidates: []TextCompletion{
					{Output: `{"Summary":"summary","Description":"new description"}`},
					{Output: `invalid JSON ignored`},
				},
			},
			want: []*Suggestion{
				{Summary: "summary", Description: "new description"},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			client, cleanup, setupErr := testClient(generateTextEndpoint, prompt, tc.response)
			if setupErr != nil {
				t.Fatalf("testClient() error = %v", setupErr)
			}
			t.Cleanup(cleanup)

			suggestions, suggestErr := client.suggest(prompt)
			if suggestErr != nil {
				t.Fatalf("suggest() error = %v", suggestErr)
			}

			diff := cmp.Diff(tc.want, suggestions)
			if diff != "" {
				t.Errorf("suggest() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestSuggestError verifies that suggest returns an error whose message
// contains the expected substring when no usable suggestion can be produced:
// a nil response, candidates that are all invalid JSON, or a candidate whose
// JSON lacks the Description field.
func TestSuggestError(t *testing.T) {
	tests := []struct {
		name     string
		response *GenerateTextResponse
		// wantErr is a substring expected in the error returned by suggest.
		wantErr string
	}{
		{
			name:     "no response",
			response: nil,
			wantErr:  "no candidates",
		},
		{
			name: "unmarshal error",
			response: &GenerateTextResponse{
				Candidates: []TextCompletion{
					{
						Output: `Summary:"invalid",`,
					},
					{
						Output: `more invalid JSON`,
					},
				},
			},
			wantErr: `unmarshal`,
		},
		{
			name: "missing data",
			response: &GenerateTextResponse{
				Candidates: []TextCompletion{
					{
						// Valid JSON, but description is missing.
						Output: `{"Summary":"summary"}`,
					},
				},
			},
			wantErr: `empty summary or description`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			prompt := "a prompt" // prompt doesn't matter since the response is hard-coded
			// Keep the setup error separate from the error under test so the
			// two cannot be confused when reading or extending the test.
			c, cleanup, err := testClient(generateTextEndpoint, prompt, tt.response)
			if err != nil {
				t.Fatalf("testClient() error = %v", err)
			}
			t.Cleanup(cleanup)
			_, gotErr := c.suggest(prompt)
			if gotErr == nil || !strings.Contains(gotErr.Error(), tt.wantErr) {
				// %q makes the expected substring unambiguous in the failure output.
				t.Fatalf("suggest() error = %v, want err containing %q", gotErr, tt.wantErr)
			}
		})
	}
}
// TestNewPrompt verifies the text produced by newPrompt. The expected prompt
// starts with the context line, followed by one "input: <JSON>" /
// "output: <JSON>" pair per example (at most maxExamples of them, taken from
// the front of the list), and ends with the input being asked about and a
// bare "output:" line.
func TestNewPrompt(t *testing.T) {
	type args struct {
		in            *Input
		promptContext string
		examples      Examples
		maxExamples   int
	}
	tests := []struct {
		name string
		args args
		want string
	}{
		{
			name: "basic",
			args: args{
				in: &Input{
					Module:      "input/module",
					Description: "original description of input",
				},
				promptContext: "Context for the prompt.",
				examples: Examples{
					&Example{
						Input: Input{
							Module:      "example/module",
							Description: "original description of example",
						},
						Suggestion: Suggestion{
							Summary:     "summary",
							Description: "new description",
						},
					},
				},
				maxExamples: 2, // no effect since there is only one example
			},
			want: `Context for the prompt.
input: {"Module":"example/module","Description":"original description of example"}
output: {"Summary":"summary","Description":"new description"}
input: {"Module":"input/module","Description":"original description of input"}
output:`,
		},
		{
			name: "trim examples",
			args: args{
				in: &Input{
					Module:      "input/module",
					Description: "original description of input",
				},
				promptContext: "Context",
				examples: Examples{
					&Example{
						Input: Input{
							Module:      "example/module",
							Description: "original description of example",
						},
						Suggestion: Suggestion{
							Summary:     "summary",
							Description: "new description",
						},
					},
					// This example will be ignored because maxExamples = 1.
					&Example{
						Input: Input{
							Module:      "another/example/module",
							Description: "original description of example 2",
						},
						Suggestion: Suggestion{
							Summary:     "summary 2",
							Description: "new description 2",
						},
					},
				},
				maxExamples: 1,
			},
			want: `Context
input: {"Module":"example/module","Description":"original description of example"}
output: {"Summary":"summary","Description":"new description"}
input: {"Module":"input/module","Description":"original description of input"}
output:`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := newPrompt(tt.args.in, tt.args.promptContext, tt.args.examples, tt.args.maxExamples)
			if err != nil {
				t.Fatalf("newPrompt() error = %v", err)
			}
			// The prompts are multi-line, so report a line-oriented diff
			// rather than dumping both full strings.
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("newPrompt() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}