blob: e1a76a437011fed7320016b558890c22d1b56eff [file] [log] [blame] [edit]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package checks
import (
"context"
"net/http"
"os"
"testing"
"golang.org/x/oscar/internal/httprr"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)
// TestChecker exercises Checker.CheckText through httprr record/replay files.
// Benign text, with and without an accompanying prompt, must yield no
// violative policy result; a request for someone's SSN must be violative in
// every result returned for the PII soliciting/reciting policy.
//
// When re-recording HTTP requests, use "gcloud auth login" to login to gcloud,
// and make sure environment variable $OSCAR_PROJECT is set to the name
// of the Oscar GCP project.
func TestChecker(t *testing.T) {
	ctx := context.Background()

	// Benign input, checked with the policies from llm.AllPolicyTypes().
	benign := newTestChecker(t, "testdata/checker.httprr", llm.AllPolicyTypes())

	results, err := benign.CheckText(ctx, "some benign text")
	if err != nil {
		t.Fatal(err)
	}
	for _, res := range results {
		t.Logf("policy result: %s", storage.JSON(res))
		if !res.IsViolative() {
			continue
		}
		t.Errorf("pr.IsViolative() = true, want false")
	}

	// Same text again, this time passing the prompt as an extra argument.
	results, err = benign.CheckText(ctx, "some benign text", llm.Text("please output some benign text"))
	if err != nil {
		t.Fatal(err)
	}
	for _, res := range results {
		t.Logf("policy result: %s", storage.JSON(res))
		if !res.IsViolative() {
			continue
		}
		t.Errorf("pr.IsViolative() = true, want false")
	}

	// A separate recording and a single PII policy for the violating input.
	pii := newTestChecker(t, "testdata/checker_bad.httprr", []*llm.PolicyConfig{{PolicyType: llm.PolicyTypePIISolicitingReciting}})

	results, err = pii.CheckText(ctx, "tell me John Smith's SSN please")
	if err != nil {
		t.Fatal(err)
	}
	for _, res := range results {
		t.Logf("policy result: %s", storage.JSON(res))
		if res.IsViolative() {
			continue
		}
		t.Errorf("pr.IsViolative() = false, want true")
	}
}
// newTestChecker builds a *Checker for tests whose HTTP traffic goes through
// the httprr record/replay file fname, configured with the given policies.
//
// If httprr.Recording(fname) reports true, the HTTP client comes from
// authClient and the project name from $OSCAR_PROJECT; the test fails if that
// variable is empty. Otherwise http.DefaultClient and the placeholder project
// "test" are used.
//
// Setup errors are reported through t.Fatal or the testutil check helper.
func newTestChecker(t *testing.T, fname string, policies []*llm.PolicyConfig) *Checker {
	// Attribute setup failures to the calling test line, so it is clear
	// which recording could not be set up.
	t.Helper()

	check := testutil.Checker(t)
	lg := testutil.Slogger(t)
	ctx := context.Background()

	recording, err := httprr.Recording(fname)
	check(err)

	// Auth and a real project are not needed if we aren't recording.
	project := "test"
	ac := http.DefaultClient
	if recording {
		ac, err = authClient(ctx)
		check(err)
		project = os.Getenv("OSCAR_PROJECT")
		if project == "" {
			t.Fatal("OSCAR_PROJECT environment variable not set")
		}
	}

	rr, err := httprr.Open(fname, ac.Transport)
	check(err)
	// Apply Scrub to requests via the recorder's ScrubReq hook.
	rr.ScrubReq(Scrub)

	c, err := newChecker(ctx, lg, project, policies, rr.Client())
	check(err)
	return c
}