blob: 84dc71097a9ff2b5b514b20be59de6a759e6f778 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package llmapp
import (
"context"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
func TestWithChecker(t *testing.T) {
lg := testutil.Slogger(t)
g := llm.EchoContentGenerator()
db := storage.MemDB()
checker := badChecker{}
c := NewWithChecker(lg, g, checker, db)
// With violation.
doc1 := &Doc{URL: "https://example.com", Author: "rsc", Title: "title", Text: "some bad text"}
doc2 := &Doc{Text: "some good text 2"}
r, err := c.Overview(context.Background(), doc1, doc2)
if err != nil {
t.Fatal(err)
}
if !r.HasPolicyViolation() {
t.Errorf("c.Overview.HasPolicyViolation = false, want true")
}
want := &PolicyEvaluation{
Violative: true,
PromptResults: []*PolicyResult{
// doc1
{
Results: []*llm.PolicyResult{violationResult},
Violations: []*llm.PolicyResult{violationResult},
},
// doc2
{Results: []*llm.PolicyResult{okResult}},
// instructions
{Results: []*llm.PolicyResult{okResult}},
},
OutputResults: &PolicyResult{
Results: []*llm.PolicyResult{violationResult},
Violations: []*llm.PolicyResult{violationResult},
},
}
if diff := cmp.Diff(want, r.PolicyEvaluation, cmpopts.IgnoreFields(PolicyResult{}, "Text")); diff != "" {
t.Errorf("c.Overview.PolicyEvaluation mismatch (-want,+got):\n%v", diff)
}
// Without violation.
r, err = c.Overview(context.Background(), doc2)
if err != nil {
t.Fatal(err)
}
if r.HasPolicyViolation() {
t.Errorf("c.Overview.HasPolicyViolation = true, want false")
}
want = &PolicyEvaluation{
Violative: false,
PromptResults: []*PolicyResult{
// doc2
{Results: []*llm.PolicyResult{okResult}},
// instructions
{Results: []*llm.PolicyResult{okResult}},
},
OutputResults: &PolicyResult{
Results: []*llm.PolicyResult{okResult},
},
}
if diff := cmp.Diff(want, r.PolicyEvaluation, cmpopts.IgnoreFields(PolicyResult{}, "Text")); diff != "" {
t.Errorf("c.Overview.PolicyEvaluation mismatch (-want,+got):\n%v", diff)
}
}
// badChecker is a test implementation of [llm.PolicyChecker] that
// always returns a policy violation for text containing the string "bad",
// and no violations otherwise.
type badChecker struct{}
// no-op
func (badChecker) SetPolicies(_ []*llm.PolicyConfig) {}
var (
violationResult = &llm.PolicyResult{
PolicyType: llm.PolicyTypeDangerousContent,
ViolationResult: llm.ViolationResultViolative,
Score: 1,
}
okResult = &llm.PolicyResult{
PolicyType: llm.PolicyTypeDangerousContent,
ViolationResult: llm.ViolationResultNonViolative,
Score: 0,
}
)
// return violation for text containing "bad" and no violation for any other text.
func (badChecker) CheckText(_ context.Context, text string, prompts ...llm.Part) ([]*llm.PolicyResult, error) {
if strings.Contains(text, "bad") {
return []*llm.PolicyResult{
violationResult,
}, nil
}
return []*llm.PolicyResult{
okResult,
}, nil
}