blob: a944a5940ccff3502db6aa19ce373f0e5a745224 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package genai
import (
"context"
"fmt"
"strings"
"testing"
_ "embed"
"github.com/google/go-cmp/cmp"
)
// TestSuggest checks that Suggest turns each JSON candidate returned by the
// model into a Suggestion, skipping candidates that are not valid JSON.
func TestSuggest(t *testing.T) {
	testCases := []struct {
		name     string
		response []string      // raw candidates the fake client hands back
		want     []*Suggestion // suggestions Suggest is expected to produce
	}{
		{
			name: "basic",
			response: []string{
				`{"Summary":"summary","Description":"new description"}`,
				`{"Summary":"another summary","Description":"another description"}`,
			},
			want: []*Suggestion{
				{Summary: "summary", Description: "new description"},
				{Summary: "another summary", Description: "another description"},
			},
		},
		{
			name: "ignore invalid",
			response: []string{
				`{"Summary":"summary","Description":"new description"}`,
				`invalid JSON ignored`,
			},
			want: []*Suggestion{
				{Summary: "summary", Description: "new description"},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// All cases share one input: the fake client's reply is canned,
			// so the input only matters for computing the expected prompt.
			in := placeholderInput

			// testCli.GenerateText returns an error for any prompt other than
			// the one below, which checks that Suggest builds its prompt with
			// defaultPrompt (or an equivalent). defaultPrompt is tested
			// separately.
			cli := &testCli{
				prompt:   mustGetDefaultPrompt(in),
				response: tc.response,
			}

			got, err := Suggest(context.Background(), cli, in)
			if err != nil {
				t.Fatalf("Suggest() error = %v", err)
			}
			diff := cmp.Diff(tc.want, got)
			if diff != "" {
				t.Errorf("Suggest() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestGenSuggestionsError checks that Suggest fails, with an error message
// containing the expected substring, when the model returns no candidates,
// only unparseable candidates, or a candidate missing required fields.
func TestGenSuggestionsError(t *testing.T) {
	tests := []struct {
		name     string
		response []string // raw candidates returned by the fake client
		wantErr  string   // substring that Suggest's error must contain
	}{
		{
			name:     "no response",
			response: nil,
			wantErr:  "no candidates",
		},
		{
			name:     "unmarshal error",
			response: []string{`Summary:"invalid",`, `more invalid JSON`},
			wantErr:  `unmarshal`,
		},
		{
			name: "missing data",
			response: []string{
				// Valid JSON, but description is missing.
				`{"Summary":"summary"}`,
			},
			wantErr: `empty summary or description`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The input can be the same for each test because
			// the response is hard-coded.
			input := placeholderInput
			// testCli.GenerateText returns an error for any prompt other than
			// this one, so this also checks that Suggest calls defaultPrompt
			// (or its equivalent). A separate test checks that defaultPrompt
			// is correct.
			c := &testCli{
				prompt:   mustGetDefaultPrompt(input),
				response: tt.response,
			}
			_, gotErr := Suggest(context.Background(), c, input)
			if gotErr == nil || !strings.Contains(gotErr.Error(), tt.wantErr) {
				// %q makes the expected substring unambiguous in the failure
				// output (empty strings and surrounding spaces stay visible).
				t.Fatalf("Suggest() error = %v, want err containing %q", gotErr, tt.wantErr)
			}
		})
	}
}
// mustGetDefaultPrompt returns defaultPrompt(in) and panics if it fails.
// Tests use it to compute the exact prompt a testCli should expect.
func mustGetDefaultPrompt(in *Input) string {
	p, err := defaultPrompt(in)
	if err == nil {
		return p
	}
	panic(err)
}
// testCli is a fake text-generation client for tests. It replies with a
// canned response, but only when it receives exactly the expected prompt.
type testCli struct {
	prompt   string   // prompt GenerateText expects to receive
	response []string // candidates GenerateText returns on a matching prompt
}

// GenerateText returns c.response when prompt equals c.prompt. Otherwise it
// returns an error containing a diff of the two prompts. The context is
// ignored.
func (c *testCli) GenerateText(_ context.Context, prompt string) ([]string, error) {
	diff := cmp.Diff(c.prompt, prompt)
	if diff == "" {
		return c.response, nil
	}
	return nil, fmt.Errorf("prompt mismatch (-want, +got):\n%s", diff)
}