internal/rules: use labeling classification logic

Change-Id: If0febf465bf96399079c4aebf23ff71122fcc902
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/635456
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/rules/rules.go b/internal/rules/rules.go
index 5c50e0f..ec210a9 100644
--- a/internal/rules/rules.go
+++ b/internal/rules/rules.go
@@ -15,6 +15,7 @@
 	"text/template"
 
 	"golang.org/x/oscar/internal/github"
+	"golang.org/x/oscar/internal/labels"
 	"golang.org/x/oscar/internal/llm"
 )
 
@@ -38,9 +39,17 @@
 		return &result, nil
 	}
 
+	kind, reasoning, err := Classify(ctx, cgen, i)
+	if err != nil {
+		return nil, err
+	}
+
+	// For now, report the classification. We won't do this in the final version.
+	result.Response += fmt.Sprintf("## Classification\n**%s**\n\n> %s\n\n", kind.Name, reasoning)
+
 	// Extract issue text into a string.
 	var issueText bytes.Buffer
-	err := template.Must(template.New("prompt").Parse(body)).Execute(&issueText, bodyArgs{
+	err = template.Must(template.New("prompt").Parse(body)).Execute(&issueText, bodyArgs{
 		Title: i.Title,
 		Body:  i.Body,
 	})
@@ -48,42 +57,8 @@
 		return nil, err
 	}
 
-	// Build system prompt to ask about the issue kind.
-	var systemPrompt bytes.Buffer
-	systemPrompt.WriteString(kindPrompt)
-	for _, kind := range rulesConfig.IssueKinds {
-		fmt.Fprintf(&systemPrompt, "%s: %s", kind.Name, kind.Text)
-		if kind.Details != "" {
-			fmt.Fprintf(&systemPrompt, " (%s)", kind.Details)
-		}
-		systemPrompt.WriteString("\n")
-	}
-
-	// Ask about the kind of issue.
-	res, err := cgen.GenerateContent(ctx, nil, []llm.Part{llm.Text(systemPrompt.String()), llm.Text(issueText.String())})
-	if err != nil {
-		return nil, fmt.Errorf("llm request failed: %w\n", err)
-	}
-	firstLine, remainingLines, _ := strings.Cut(res, "\n")
-
-	// Parse the result.
-	var kind IssueKind
-	for _, k := range rulesConfig.IssueKinds {
-		if firstLine == k.Name {
-			kind = k
-			break
-		}
-	}
-	if kind.Name == "" {
-		log.Printf("kind %q response not valid", firstLine)
-		return nil, fmt.Errorf("llm returned invalid kind: %s", firstLine)
-		// TODO: just return Response=="" if LLM isn't obeying the prompt?
-	}
-
-	// For now, report the classification. We won't do this in the final version.
-	result.Response += fmt.Sprintf("## Classification\n**%s**\n\n> %s\n\n", kind.Name, remainingLines)
-
 	// Now that we know the kind, ask about each of the rules for the kind.
+	var systemPrompt bytes.Buffer
 	var failed []Rule
 	var failedReason []string
 	for _, rule := range kind.Rules {
@@ -127,19 +102,36 @@
 	return &result, nil
 }
 
+// Classify returns the kind of issue we're dealing with.
+// Returns a description of the classification and a string describing
+// the llm's reasoning.
+func Classify(ctx context.Context, cgen llm.ContentGenerator, i *github.Issue) (IssueKind, string, error) {
+	// TODO: use the default github label categories, and adjust
+	// the rule file to match.
+	var cats []labels.Category
+	for _, kind := range rulesConfig.IssueKinds {
+		cats = append(cats, labels.Category{
+			Name:        kind.Name,
+			Description: kind.Text,
+			Extra:       kind.Details,
+		})
+	}
+	cat, explanation, err := labels.IssueCategoryFromList(ctx, cgen, i, cats)
+	if err != nil {
+		return IssueKind{}, "", err
+	}
+	for _, kind := range rulesConfig.IssueKinds {
+		if kind.Name == cat.Name {
+			return kind, explanation, nil
+		}
+	}
+	return IssueKind{}, "", fmt.Errorf("unexpected category %s", cat.Name)
+}
+
 //go:embed static/*
 var staticFS embed.FS
 
 // TODO: put some of these in the staticFS
-const kindPrompt = `
-Your job is to categorize Go issues.
-The issue is described by a title and a body.
-The issue body is encoded in markdown.
-Report the category of the issue on a line by itself, followed by an explanation of your decision.
-Each category and its description are listed below.
-
-`
-
 const rulePrompt = `
 Your job is to decide whether a Go issue follows this rule: %s (%s)
 The issue is described by a title and a body.
diff --git a/internal/rules/rules_test.go b/internal/rules/rules_test.go
index e37adb4..1f2d7b6 100644
--- a/internal/rules/rules_test.go
+++ b/internal/rules/rules_test.go
@@ -6,7 +6,6 @@
 
 import (
 	"context"
-	"fmt"
 	"strings"
 	"testing"
 
@@ -40,13 +39,34 @@
 	}
 }
 
+func TestClassify(t *testing.T) {
+	ctx := context.Background()
+	llm := ruleTestGenerator()
+
+	// Construct a test issue.
+	i := new(github.Issue)
+	i.Number = 999
+	i.User = github.User{Login: "user"}
+	i.Title = "title"
+	i.Body = "body"
+
+	// Run classifier.
+	r, _, err := Classify(ctx, llm, i)
+	if err != nil {
+		t.Fatalf("Classify failed with %v", err)
+	}
+
+	// Check result.
+	want := "bug"
+	if r.Name != want {
+		t.Errorf("Classify got %q, want %q", r.Name, want)
+	}
+}
+
 func ruleTestGenerator() llm.ContentGenerator {
 	return llm.TestContentGenerator(
 		"ruleTestGenerator",
 		func(_ context.Context, schema *llm.Schema, promptParts []llm.Part) (string, error) {
-			if schema != nil {
-				return "", fmt.Errorf("not implemented")
-			}
 			var strs []string
 			for _, p := range promptParts {
 				strs = append(strs, string(p.(llm.Text)))
@@ -54,7 +74,7 @@
 			req := strings.Join(strs, " ")
 			if strings.Contains(req, "Your job is to categorize") {
 				// categorize request. Always report it as a "bug".
-				return "bug\nI think this is a bug.", nil
+				return `{"CategoryName":"bug","Explanation":"I think this is a bug."}`, nil
 			}
 			if strings.Contains(req, "Your job is to decide") {
 				// rule request. Report that the title rule failed and the others succeeded.