internal/rules: use labeling classification logic
Change-Id: If0febf465bf96399079c4aebf23ff71122fcc902
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/635456
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/rules/rules.go b/internal/rules/rules.go
index 5c50e0f..ec210a9 100644
--- a/internal/rules/rules.go
+++ b/internal/rules/rules.go
@@ -15,6 +15,7 @@
"text/template"
"golang.org/x/oscar/internal/github"
+ "golang.org/x/oscar/internal/labels"
"golang.org/x/oscar/internal/llm"
)
@@ -38,9 +39,17 @@
return &result, nil
}
+ kind, reasoning, err := Classify(ctx, cgen, i)
+ if err != nil {
+ return nil, err
+ }
+
+ // For now, report the classification. We won't do this in the final version.
+ result.Response += fmt.Sprintf("## Classification\n**%s**\n\n> %s\n\n", kind.Name, reasoning)
+
// Extract issue text into a string.
var issueText bytes.Buffer
- err := template.Must(template.New("prompt").Parse(body)).Execute(&issueText, bodyArgs{
+ err = template.Must(template.New("prompt").Parse(body)).Execute(&issueText, bodyArgs{
Title: i.Title,
Body: i.Body,
})
@@ -48,42 +57,8 @@
return nil, err
}
- // Build system prompt to ask about the issue kind.
- var systemPrompt bytes.Buffer
- systemPrompt.WriteString(kindPrompt)
- for _, kind := range rulesConfig.IssueKinds {
- fmt.Fprintf(&systemPrompt, "%s: %s", kind.Name, kind.Text)
- if kind.Details != "" {
- fmt.Fprintf(&systemPrompt, " (%s)", kind.Details)
- }
- systemPrompt.WriteString("\n")
- }
-
- // Ask about the kind of issue.
- res, err := cgen.GenerateContent(ctx, nil, []llm.Part{llm.Text(systemPrompt.String()), llm.Text(issueText.String())})
- if err != nil {
- return nil, fmt.Errorf("llm request failed: %w\n", err)
- }
- firstLine, remainingLines, _ := strings.Cut(res, "\n")
-
- // Parse the result.
- var kind IssueKind
- for _, k := range rulesConfig.IssueKinds {
- if firstLine == k.Name {
- kind = k
- break
- }
- }
- if kind.Name == "" {
- log.Printf("kind %q response not valid", firstLine)
- return nil, fmt.Errorf("llm returned invalid kind: %s", firstLine)
- // TODO: just return Response=="" if LLM isn't obeying the prompt?
- }
-
- // For now, report the classification. We won't do this in the final version.
- result.Response += fmt.Sprintf("## Classification\n**%s**\n\n> %s\n\n", kind.Name, remainingLines)
-
// Now that we know the kind, ask about each of the rules for the kind.
+ var systemPrompt bytes.Buffer
var failed []Rule
var failedReason []string
for _, rule := range kind.Rules {
@@ -127,19 +102,36 @@
return &result, nil
}
+// Classify determines which of the configured issue kinds applies to issue i.
+// It converts each entry of rulesConfig.IssueKinds into a labels.Category
+// (Name, Text and Details become Name, Description and Extra) and hands the
+// list to labels.IssueCategoryFromList together with cgen.
+//
+// It returns the IssueKind whose Name equals the chosen category's name,
+// along with the explanation string produced by the labels package
+// (presumably the LLM's reasoning; confirm against internal/labels).
+// An error is returned if categorization fails or if the chosen category
+// does not match any configured issue kind.
+func Classify(ctx context.Context, cgen llm.ContentGenerator, i *github.Issue) (IssueKind, string, error) {
+ // TODO: use the default github label categories, and adjust
+ // the rule file to match.
+ var cats []labels.Category
+ for _, kind := range rulesConfig.IssueKinds {
+ cats = append(cats, labels.Category{
+ Name: kind.Name,
+ Description: kind.Text,
+ Extra: kind.Details,
+ })
+ }
+ cat, explanation, err := labels.IssueCategoryFromList(ctx, cgen, i, cats)
+ if err != nil {
+ return IssueKind{}, "", err
+ }
+ // Map the chosen category name back to the full IssueKind entry.
+ for _, kind := range rulesConfig.IssueKinds {
+ if kind.Name == cat.Name {
+ return kind, explanation, nil
+ }
+ }
+ // The chosen category is not one of the configured kinds.
+ return IssueKind{}, "", fmt.Errorf("unexpected category %s", cat.Name)
+}
+
//go:embed static/*
var staticFS embed.FS
// TODO: put some of these in the staticFS
-const kindPrompt = `
-Your job is to categorize Go issues.
-The issue is described by a title and a body.
-The issue body is encoded in markdown.
-Report the category of the issue on a line by itself, followed by an explanation of your decision.
-Each category and its description are listed below.
-
-`
-
const rulePrompt = `
Your job is to decide whether a Go issue follows this rule: %s (%s)
The issue is described by a title and a body.
diff --git a/internal/rules/rules_test.go b/internal/rules/rules_test.go
index e37adb4..1f2d7b6 100644
--- a/internal/rules/rules_test.go
+++ b/internal/rules/rules_test.go
@@ -6,7 +6,6 @@
import (
"context"
- "fmt"
"strings"
"testing"
@@ -40,13 +39,34 @@
}
}
+// TestClassify checks that Classify reports the "bug" kind for a minimal
+// issue when driven by the canned generator from ruleTestGenerator.
+func TestClassify(t *testing.T) {
+ ctx := context.Background()
+ // NOTE: this local shadows the imported llm package for the rest of
+ // the function.
+ llm := ruleTestGenerator()
+
+ // Construct a test issue.
+ i := new(github.Issue)
+ i.Number = 999
+ i.User = github.User{Login: "user"}
+ i.Title = "title"
+ i.Body = "body"
+
+ // Run classifier. The explanation string is not checked here.
+ r, _, err := Classify(ctx, llm, i)
+ if err != nil {
+ t.Fatalf("Classify failed with %v", err)
+ }
+
+ // Check result.
+ want := "bug"
+ if r.Name != want {
+ t.Errorf("Classify got %q, want %q", r.Name, want)
+ }
+}
+
func ruleTestGenerator() llm.ContentGenerator {
return llm.TestContentGenerator(
"ruleTestGenerator",
func(_ context.Context, schema *llm.Schema, promptParts []llm.Part) (string, error) {
- if schema != nil {
- return "", fmt.Errorf("not implemented")
- }
var strs []string
for _, p := range promptParts {
strs = append(strs, string(p.(llm.Text)))
@@ -54,7 +74,7 @@
req := strings.Join(strs, " ")
if strings.Contains(req, "Your job is to categorize") {
// categorize request. Always report it as a "bug".
- return "bug\nI think this is a bug.", nil
+ return `{"CategoryName":"bug","Explanation":"I think this is a bug."}`, nil
}
if strings.Contains(req, "Your job is to decide") {
// rule request. Report that the title rule failed and the others succeeded.