blob: 7599b556297e735cb082601abccd0e812b6cdacd [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Labeleval is a program for evaluating issue categorization.
It applies the internal/labels package to a selected set of issues
and compares the results with expected values.
Usage:
labeleval categoryconfig issueconfig
Categoryconfig defines the list of categories to use.
It is a YAML file that matches the type
struct {
Categories []labels.Category
}
Issueconfig is a list of issue numbers to evaluate, along with their expected category.
The issues must all be in the production DB under the golang/go project.
By default, all the issues in the issueconfig file are evaluated.
The -run flag provides a way to run a subset of them.
There are four results of evaluating an issue:
PASS: the LLM chose a category that matches the desired one
FAIL: the LLM chose a different category
NEW: the expected category is not in the category config
ERROR: there was a failure trying to run the LLM
A typical run will use the category config from the internal/labels package and the list
of issues in this package. From repo root:
go run ./internal/devtools/cmd/labeleval internal/labels/static/categories.yaml internal/devtools/cmd/labeleval/issues.txt
*/
package main
import (
"context"
"errors"
"flag"
"fmt"
"log"
"log/slog"
"maps"
"net/http"
"os"
"slices"
"strconv"
"strings"
"time"
"golang.org/x/oscar/internal/gcp/firestore"
"golang.org/x/oscar/internal/gcp/gemini"
"golang.org/x/oscar/internal/github"
"golang.org/x/oscar/internal/labels"
"golang.org/x/oscar/internal/secret"
"golang.org/x/oscar/internal/storage"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
)
var (
runIssueNumbers = flag.String("run", "", "comma-separated issue numbers to run")
count = flag.Int("count", 1, "repeats of the evaluation")
)
func usage() {
fmt.Fprintf(os.Stderr, "usage: labeleval categoryconfig issueconfig\n")
flag.PrintDefaults()
os.Exit(2)
}
func main() {
log.SetFlags(0)
log.SetPrefix("labeleval: ")
flag.Usage = usage
flag.Parse()
if flag.NArg() != 2 {
usage()
}
if err := run(context.Background(), flag.Arg(0), flag.Arg(1)); err != nil {
log.Fatal(err)
}
}
// An issueConfig refers to an issue in the golang/go project, with its expected category.
type issueConfig struct {
Number int
WantCategory string
issue *github.Issue
}
// A reponse holds the return values from [labels.IssueCategoryFromList].
type response struct {
cat labels.Category
explanation string
err error
}
func run(ctx context.Context, categoryconfigFile, issueConfigFile string) error {
if *count <= 0 {
return fmt.Errorf("bad count: %d", *count)
}
var categoryConfig struct {
Categories []labels.Category
}
if err := readYAMLFile(categoryconfigFile, &categoryConfig); err != nil {
return err
}
knownCategories := map[string]bool{}
for _, c := range categoryConfig.Categories {
knownCategories[c.Name] = true
}
allIssueConfigs, err := readIssueFile(issueConfigFile)
if err != nil {
return err
}
var issueConfigs []*issueConfig
if len(*runIssueNumbers) == 0 {
issueConfigs = allIssueConfigs
} else {
// Filter by the provided issue numbers.
rns := strings.Split(*runIssueNumbers, ",")
runNums := map[int]bool{}
for _, rn := range rns {
n, err := strconv.Atoi(strings.TrimSpace(rn))
if err != nil {
return err
}
runNums[n] = true
}
for _, ic := range allIssueConfigs {
if runNums[ic.Number] {
issueConfigs = append(issueConfigs, ic)
}
}
}
if len(issueConfigs) == 0 {
return errors.New("no issues to evaluate")
}
lg := slog.New(slog.NewTextHandler(os.Stderr, nil))
db, err := firestore.NewDB(ctx, lg, "oscar-go-1", "prod")
if err != nil {
return err
}
cgen, err := newGeminiClient(ctx, lg)
if err != nil {
return err
}
start := time.Now()
if err := lookupIssues(db, issueConfigs); err != nil {
return err
}
log.Printf("looked up %d issues in %.1fs", len(issueConfigs), time.Since(start).Seconds())
responseLists := make([][]response, len(issueConfigs))
for i := range responseLists {
responseLists[i] = make([]response, *count)
}
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(10)
start = time.Now()
for c := range *count {
for i, ic := range issueConfigs {
g.Go(func() error {
got, exp, err := labels.IssueCategoryFromList(gctx, cgen, ic.issue, categoryConfig.Categories)
responseLists[i][c] = response{got, exp, err}
return nil
})
}
}
if err := g.Wait(); err != nil {
return err
}
log.Printf("evaluated %d times with %s in %.1f seconds",
*count, cgen.Model(), time.Since(start).Seconds())
printResults(issueConfigs, responseLists, knownCategories)
return nil
}
func lookupIssues(db storage.DB, ics []*issueConfig) error {
for _, ic := range ics {
issue, err := github.LookupIssue(db, "golang/go", int64(ic.Number))
if err != nil {
return err
}
ic.issue = issue
}
return nil
}
func printResults(issueConfigs []*issueConfig, responseLists [][]response, knownCategories map[string]bool) {
Issues:
for i, ic := range issueConfigs {
responseList := responseLists[i]
fmt.Printf("%-6d ", ic.Number)
if len(responseList) == 1 {
res := responseList[0]
if res.err != nil {
fmt.Printf("ERROR %v\n", res.err)
} else if !knownCategories[ic.WantCategory] {
fmt.Printf("NEW got %s; want %s is not in list of known categories\n",
res.cat.Name, ic.WantCategory)
} else if res.cat.Name != ic.WantCategory {
fmt.Printf("FAIL got %s want %s\n", res.cat.Name, ic.WantCategory)
fmt.Printf("%11s\n", res.explanation)
} else {
fmt.Printf("PASS\n")
}
} else {
// If any response is an error or NEW, stop there.
for _, res := range responseList {
if res.err != nil {
fmt.Printf("ERROR %v\n", res.err)
continue Issues
}
if !knownCategories[ic.WantCategory] {
fmt.Printf("NEW got %s; want %s is not in list of known categories\n",
res.cat.Name, ic.WantCategory)
continue Issues
}
}
// Group by category.
counts := map[string]int{}
for _, res := range responseList {
counts[res.cat.Name]++
}
p := counts[ic.WantCategory]
fmt.Printf("PASS:%d FAIL:%d ", p, len(responseList)-p)
// Sort passes first, then alphabetically.
cats := slices.Collect(maps.Keys(counts))
slices.SortFunc(cats, func(c1, c2 string) int {
if c1 == c2 {
return 0
}
if c1 == ic.WantCategory {
return -1
}
if c2 == ic.WantCategory {
return 1
}
return strings.Compare(c1, c2)
})
for _, c := range cats {
fmt.Printf(" %s:%d", c, counts[c])
}
fmt.Println()
}
}
}
func newGeminiClient(ctx context.Context, lg *slog.Logger) (*gemini.Client, error) {
sdb := secret.Netrc()
c, err := gemini.NewClient(ctx, lg, sdb, http.DefaultClient,
gemini.DefaultEmbeddingModel, gemini.DefaultGenerativeModel)
if err != nil {
return nil, err
}
c.SetTemperature(0)
return c, nil
}
func readYAMLFile(filename string, p any) error {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
dec := yaml.NewDecoder(f)
dec.KnownFields(true)
return dec.Decode(p)
}
func readIssueFile(filename string) ([]*issueConfig, error) {
data, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
var issues []*issueConfig
for i, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
// Trim comment from end of line.
if i := strings.LastIndexByte(line, '#'); i > 0 {
line = line[:i]
}
num, rest, found := strings.Cut(line, " ")
if !found {
return nil, fmt.Errorf("%s:%d: missing want", filename, i+1)
}
n, err := strconv.Atoi(num)
if err != nil {
return nil, fmt.Errorf("%s:%d: %v", filename, i+1, err)
}
issues = append(issues, &issueConfig{Number: n, WantCategory: strings.TrimSpace(rest)})
}
return issues, nil
}