cmd/vulnreport: no-op refactor to move more code out of main

The main function was getting pretty large, so this change
breaks some of it up into functions.

Change-Id: I08d3796904886109ad821a58f8c7e15881485d4d
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/432595
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/cmd/vulnreport/main.go b/cmd/vulnreport/main.go
index 3031475..fa42fa1 100644
--- a/cmd/vulnreport/main.go
+++ b/cmd/vulnreport/main.go
@@ -62,89 +62,39 @@
 	}
 
 	flag.Parse()
-	if flag.NArg() < 2 {
+	if flag.NArg() < 3 {
 		flag.Usage()
 		os.Exit(1)
 	}
 
 	cmd := flag.Arg(0)
-	names := flag.Args()[1:]
-	switch cmd {
-	case "create":
-		if *githubToken == "" {
-			flag.Usage()
-			log.Fatalf("githubToken must be provided")
-		}
-		if len(names) < 1 {
-			log.Fatal("need at least one ID")
-		}
-		var githubIDs []int
-		parseGithubID := func(s string) int {
-			id, err := strconv.Atoi(s)
-			if err != nil {
-				log.Fatalf("invalid GitHub issue ID: %q", s)
-			}
-			return id
-		}
-		existingByIssue, existingByFile, err := existingReports()
+	args := flag.Args()[1:]
+
+	// Create operates on github issue IDs instead of filenames, so it is
+	// separated from the other commands.
+	if cmd == "create" {
+		githubIDs, cfg, err := setupCreate(args)
 		if err != nil {
 			log.Fatal(err)
 		}
-		for _, name := range names {
-			if !strings.Contains(name, "-") {
-				githubIDs = append(githubIDs, parseGithubID(name))
-				continue
-			}
-			from, to, _ := strings.Cut(name, "-")
-			fromID := parseGithubID(from)
-			toID := parseGithubID(to)
-			if fromID > toID {
-				log.Fatalf("%v > %v", fromID, toID)
-			}
-			for id := fromID; id <= toID; id++ {
-				if existingByIssue[id] != nil {
-					continue
-				}
-				githubIDs = append(githubIDs, id)
-			}
-		}
-		repoPath := cvelistrepo.URL
-		if *localRepoPath != "" {
-			repoPath = *localRepoPath
-		} else if len(githubIDs) > 1 {
-			// Maybe we should automatically maintain a local clone of the
-			// cvelist repo, but for now we can avoid repeatedly fetching it
-			// when iterating over a list of reports.
-			log.Fatalf("git clone %v to a local directory, and set -local-cve-repo to that path", cvelistrepo.URL)
-		}
-		owner, repoName, err := gitrepo.ParseGitHubRepo(*issueRepo)
-		if err != nil {
-			log.Fatal(err)
-		}
-		c := issues.NewGitHubClient(owner, repoName, *githubToken)
 		for _, githubID := range githubIDs {
-			if err := create(ctx, githubID, *githubToken, repoPath, c, existingByFile); err != nil {
-				log.Print(err)
+			if err := create(ctx, githubID, cfg); err != nil {
+				log.Fatal(err)
 			}
 		}
+		return
+	}
+
+	var cmdFunc func(string) error
+	switch cmd {
 	case "lint":
-		if err := multi(lint, names); err != nil {
-			log.Fatal(err)
-		}
+		cmdFunc = lint
 	case "commit":
-		f := func(name string) error { return commit(ctx, name, *githubToken) }
-		if err := multi(f, names); err != nil {
-			log.Fatal(err)
-		}
+		cmdFunc = func(name string) error { return commit(ctx, name, *githubToken) }
 	case "newcve":
-		if err := multi(newCVE, names); err != nil {
-			log.Fatal(err)
-		}
+		cmdFunc = newCVE
 	case "fix":
-		f := func(name string) error { return fix(ctx, name, *githubToken) }
-		if err := multi(f, names); err != nil {
-			log.Fatal(err)
-		}
+		cmdFunc = func(name string) error { return fix(ctx, name, *githubToken) }
 	case "set-dates":
 		repo, err := gitrepo.Open(ctx, ".")
 		if err != nil {
@@ -154,16 +104,13 @@
 		if err != nil {
 			log.Fatal(err)
 		}
-		f := func(name string) error { return setDates(name, commitDates) }
-		if err := multi(f, names); err != nil {
-			log.Fatal(err)
-		}
+		cmdFunc = func(name string) error { return setDates(name, commitDates) }
 	case "xref":
 		_, existingByFile, err := existingReports()
 		if err != nil {
 			log.Fatal(err)
 		}
-		f := func(name string) error {
+		cmdFunc = func(name string) error {
 			r, err := report.Read(name)
 			if err != nil {
 				return err
@@ -172,33 +119,36 @@
 			fmt.Print(xref(name, r, existingByFile))
 			return nil
 		}
-		if err := multi(f, names); err != nil {
-			log.Fatal(err)
-		}
 	default:
 		flag.Usage()
 		log.Fatalf("unsupported command: %q", cmd)
 	}
-}
 
-func multi(f func(string) error, args []string) error {
+	// Run the command on each argument.
 	for _, arg := range args {
-		if _, err := os.Stat(arg); err != nil {
-			// If arg isn't a file, see if it might be an issue ID
-			// with an existing report.
-			for _, padding := range []string{"", "0", "00", "000"} {
-				m, _ := filepath.Glob("data/*/GO-*-" + padding + arg + ".yaml")
-				if len(m) == 1 {
-					arg = m[0]
-					break
-				}
-			}
+		arg, err := argToFilename(arg)
+		if err != nil {
+			log.Fatal(err)
 		}
-		if err := f(arg); err != nil {
-			return err
+		if err := cmdFunc(arg); err != nil {
+			log.Fatal(err)
 		}
 	}
-	return nil
+}
+
+func argToFilename(arg string) (string, error) {
+	if _, err := os.Stat(arg); err != nil {
+		// If arg isn't a file, see if it might be an issue ID
+		// with an existing report.
+		for _, padding := range []string{"", "0", "00", "000"} {
+			m, _ := filepath.Glob("data/*/GO-*-" + padding + arg + ".yaml")
+			if len(m) == 1 {
+				return m[0], nil
+			}
+		}
+		return "", fmt.Errorf("%s is not a valid filename or issue ID with existing report", arg)
+	}
+	return arg, nil
 }
 
 func existingReports() (byIssue map[int]*report.Report, byFile map[string]*report.Report, err error) {
@@ -237,10 +187,92 @@
 	return byIssue, byFile, nil
 }
 
-func create(ctx context.Context, issueNumber int, ghToken, repoPath string, c issues.Client, existingByFile map[string]*report.Report) (err error) {
+func parseArgsToGithubIDs(args []string, existingByIssue map[int]*report.Report) ([]int, error) {
+	var githubIDs []int
+	parseGithubID := func(s string) (int, error) {
+		id, err := strconv.Atoi(s)
+		if err != nil {
+			return 0, fmt.Errorf("invalid GitHub issue ID: %q", s)
+		}
+		return id, nil
+	}
+	for _, arg := range args {
+		if !strings.Contains(arg, "-") {
+			id, err := parseGithubID(arg)
+			if err != nil {
+				return nil, err
+			}
+			githubIDs = append(githubIDs, id)
+			continue
+		}
+		from, to, _ := strings.Cut(arg, "-")
+		fromID, err := parseGithubID(from)
+		if err != nil {
+			return nil, err
+		}
+		toID, err := parseGithubID(to)
+		if err != nil {
+			return nil, err
+		}
+		if fromID > toID {
+			return nil, fmt.Errorf("%v > %v", fromID, toID)
+		}
+		for id := fromID; id <= toID; id++ {
+			if existingByIssue[id] != nil {
+				continue
+			}
+			githubIDs = append(githubIDs, id)
+		}
+	}
+	return githubIDs, nil
+}
+
+type createCfg struct {
+	ghToken        string
+	repoPath       string
+	issuesClient   issues.Client
+	existingByFile map[string]*report.Report
+}
+
+func setupCreate(args []string) ([]int, *createCfg, error) {
+	if *githubToken == "" {
+		flag.Usage()
+		log.Fatalf("githubToken must be provided")
+	}
+	existingByIssue, existingByFile, err := existingReports()
+	if err != nil {
+		log.Fatal(err)
+	}
+	githubIDs, err := parseArgsToGithubIDs(args, existingByIssue)
+	if err != nil {
+		log.Fatal(err)
+	}
+	if len(githubIDs) > 1 {
+		// Maybe we should automatically maintain a local clone of the
+		// cvelist repo, but for now we can avoid repeatedly fetching it
+		// when iterating over a list of reports.
+		return nil, nil, fmt.Errorf("git clone %v to a local directory, and set -local-cve-repo to that path", cvelistrepo.URL)
+	}
+	repoPath := cvelistrepo.URL
+	if *localRepoPath != "" {
+		repoPath = *localRepoPath
+	}
+	owner, repoName, err := gitrepo.ParseGitHubRepo(*issueRepo)
+	if err != nil {
+		return nil, nil, err
+	}
+	return githubIDs, &createCfg{
+		ghToken:        *githubToken,
+		repoPath:       repoPath,
+		issuesClient:   issues.NewGitHubClient(owner, repoName, *githubToken),
+		existingByFile: existingByFile,
+	}, nil
+}
+
+func create(ctx context.Context, issueNumber int, cfg *createCfg) (err error) {
 	defer derrors.Wrap(&err, "create(%d)", issueNumber)
 	// Get GitHub issue.
-	iss, err := c.GetIssue(ctx, issueNumber, issues.GetIssueOptions{GetLabels: true})
+	iss, err := cfg.issuesClient.GetIssue(ctx, issueNumber, issues.GetIssueOptions{GetLabels: true})
 	if err != nil {
 		return err
 	}
@@ -275,7 +307,7 @@
 	}
 	if len(ghsas) == 0 && len(cves) > 0 {
 		for _, cve := range cves {
-			sas, err := ghsa.ListForCVE(ctx, ghToken, cve)
+			sas, err := ghsa.ListForCVE(ctx, cfg.ghToken, cve)
 			if err != nil {
 				return err
 			}
@@ -290,13 +322,13 @@
 	var r *report.Report
 	switch {
 	case len(ghsas) > 0:
-		ghsa, err := ghsa.FetchGHSA(ctx, ghToken, ghsas[0])
+		ghsa, err := ghsa.FetchGHSA(ctx, cfg.ghToken, ghsas[0])
 		if err != nil {
 			return err
 		}
 		r = report.GHSAToReport(ghsa, modulePath)
 	case len(cves) > 0:
-		cve, err := cvelistrepo.FetchCVE(ctx, repoPath, cves[0])
+		cve, err := cvelistrepo.FetchCVE(ctx, cfg.repoPath, cves[0])
 		if err != nil {
 			return err
 		}
@@ -335,7 +367,7 @@
 		return err
 	}
 	fmt.Println(filename)
-	fmt.Print(xref(filename, r, existingByFile))
+	fmt.Print(xref(filename, r, cfg.existingByFile))
 	return nil
 }