cmd/vulnreport: no-op refactor to move more code out of main
The main function was getting pretty large, so this change
breaks some of it up into functions.
Change-Id: I08d3796904886109ad821a58f8c7e15881485d4d
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/432595
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/cmd/vulnreport/main.go b/cmd/vulnreport/main.go
index 3031475..fa42fa1 100644
--- a/cmd/vulnreport/main.go
+++ b/cmd/vulnreport/main.go
@@ -62,89 +62,39 @@
}
flag.Parse()
- if flag.NArg() < 2 {
+ if flag.NArg() < 3 {
flag.Usage()
os.Exit(1)
}
cmd := flag.Arg(0)
- names := flag.Args()[1:]
- switch cmd {
- case "create":
- if *githubToken == "" {
- flag.Usage()
- log.Fatalf("githubToken must be provided")
- }
- if len(names) < 1 {
- log.Fatal("need at least one ID")
- }
- var githubIDs []int
- parseGithubID := func(s string) int {
- id, err := strconv.Atoi(s)
- if err != nil {
- log.Fatalf("invalid GitHub issue ID: %q", s)
- }
- return id
- }
- existingByIssue, existingByFile, err := existingReports()
+ args := flag.Args()[1:]
+
+ // Create operates on github issue IDs instead of filenames, so it is
+ // separated from the other commands.
+ if cmd == "create" {
+ githubIDs, cfg, err := setupCreate(args)
if err != nil {
log.Fatal(err)
}
- for _, name := range names {
- if !strings.Contains(name, "-") {
- githubIDs = append(githubIDs, parseGithubID(name))
- continue
- }
- from, to, _ := strings.Cut(name, "-")
- fromID := parseGithubID(from)
- toID := parseGithubID(to)
- if fromID > toID {
- log.Fatalf("%v > %v", fromID, toID)
- }
- for id := fromID; id <= toID; id++ {
- if existingByIssue[id] != nil {
- continue
- }
- githubIDs = append(githubIDs, id)
- }
- }
- repoPath := cvelistrepo.URL
- if *localRepoPath != "" {
- repoPath = *localRepoPath
- } else if len(githubIDs) > 1 {
- // Maybe we should automatically maintain a local clone of the
- // cvelist repo, but for now we can avoid repeatedly fetching it
- // when iterating over a list of reports.
- log.Fatalf("git clone %v to a local directory, and set -local-cve-repo to that path", cvelistrepo.URL)
- }
- owner, repoName, err := gitrepo.ParseGitHubRepo(*issueRepo)
- if err != nil {
- log.Fatal(err)
- }
- c := issues.NewGitHubClient(owner, repoName, *githubToken)
for _, githubID := range githubIDs {
- if err := create(ctx, githubID, *githubToken, repoPath, c, existingByFile); err != nil {
- log.Print(err)
+ if err := create(ctx, githubID, cfg); err != nil {
+ log.Fatal(err)
}
}
+ return
+ }
+
+ var cmdFunc func(string) error
+ switch cmd {
case "lint":
- if err := multi(lint, names); err != nil {
- log.Fatal(err)
- }
+ cmdFunc = lint
case "commit":
- f := func(name string) error { return commit(ctx, name, *githubToken) }
- if err := multi(f, names); err != nil {
- log.Fatal(err)
- }
+ cmdFunc = func(name string) error { return commit(ctx, name, *githubToken) }
case "newcve":
- if err := multi(newCVE, names); err != nil {
- log.Fatal(err)
- }
+ cmdFunc = newCVE
case "fix":
- f := func(name string) error { return fix(ctx, name, *githubToken) }
- if err := multi(f, names); err != nil {
- log.Fatal(err)
- }
+ cmdFunc = func(name string) error { return fix(ctx, name, *githubToken) }
case "set-dates":
repo, err := gitrepo.Open(ctx, ".")
if err != nil {
@@ -154,16 +104,13 @@
if err != nil {
log.Fatal(err)
}
- f := func(name string) error { return setDates(name, commitDates) }
- if err := multi(f, names); err != nil {
- log.Fatal(err)
- }
+ cmdFunc = func(name string) error { return setDates(name, commitDates) }
case "xref":
_, existingByFile, err := existingReports()
if err != nil {
log.Fatal(err)
}
- f := func(name string) error {
+ cmdFunc = func(name string) error {
r, err := report.Read(name)
if err != nil {
return err
@@ -172,33 +119,36 @@
fmt.Print(xref(name, r, existingByFile))
return nil
}
- if err := multi(f, names); err != nil {
- log.Fatal(err)
- }
default:
flag.Usage()
log.Fatalf("unsupported command: %q", cmd)
}
-}
-func multi(f func(string) error, args []string) error {
+ // Run the command on each argument.
for _, arg := range args {
- if _, err := os.Stat(arg); err != nil {
- // If arg isn't a file, see if it might be an issue ID
- // with an existing report.
- for _, padding := range []string{"", "0", "00", "000"} {
- m, _ := filepath.Glob("data/*/GO-*-" + padding + arg + ".yaml")
- if len(m) == 1 {
- arg = m[0]
- break
- }
- }
+ arg, err := argToFilename(arg)
+ if err != nil {
+ log.Fatal(err)
}
- if err := f(arg); err != nil {
- return err
+ if err := cmdFunc(arg); err != nil {
+ log.Fatal(err)
}
}
- return nil
+}
+
+func argToFilename(arg string) (string, error) {
+ if _, err := os.Stat(arg); err != nil {
+ // If arg isn't a file, see if it might be an issue ID
+ // with an existing report.
+ for _, padding := range []string{"", "0", "00", "000"} {
+ m, _ := filepath.Glob("data/*/GO-*-" + padding + arg + ".yaml")
+ if len(m) == 1 {
+ return m[0], nil
+ }
+ }
+ return "", fmt.Errorf("%s is not a valid filename or issue ID with existing report", arg)
+ }
+ return arg, nil
}
func existingReports() (byIssue map[int]*report.Report, byFile map[string]*report.Report, err error) {
@@ -237,10 +187,92 @@
return byIssue, byFile, nil
}
-func create(ctx context.Context, issueNumber int, ghToken, repoPath string, c issues.Client, existingByFile map[string]*report.Report) (err error) {
+func parseArgsToGithubIDs(args []string, existingByIssue map[int]*report.Report) ([]int, error) {
+ var githubIDs []int
+ parseGithubID := func(s string) (int, error) {
+ id, err := strconv.Atoi(s)
+ if err != nil {
+ return 0, fmt.Errorf("invalid GitHub issue ID: %q", s)
+ }
+ return id, nil
+ }
+ for _, arg := range args {
+ if !strings.Contains(arg, "-") {
+ id, err := parseGithubID(arg)
+ if err != nil {
+ return nil, err
+ }
+ githubIDs = append(githubIDs, id)
+ continue
+ }
+ from, to, _ := strings.Cut(arg, "-")
+ fromID, err := parseGithubID(from)
+ if err != nil {
+ return nil, err
+ }
+ toID, err := parseGithubID(to)
+ if err != nil {
+ return nil, err
+ }
+ if fromID > toID {
+ return nil, fmt.Errorf("%v > %v", fromID, toID)
+ }
+ for id := fromID; id <= toID; id++ {
+ if existingByIssue[id] != nil {
+ continue
+ }
+ githubIDs = append(githubIDs, id)
+ }
+ }
+ return githubIDs, nil
+}
+
+type createCfg struct {
+ ghToken string
+ repoPath string
+ issuesClient issues.Client
+ existingByFile map[string]*report.Report
+}
+
+func setupCreate(args []string) ([]int, *createCfg, error) {
+ if *githubToken == "" {
+ flag.Usage()
+ log.Fatalf("githubToken must be provided")
+ }
+ existingByIssue, existingByFile, err := existingReports()
+ if err != nil {
+ log.Fatal(err)
+ }
+ githubIDs, err := parseArgsToGithubIDs(args, existingByIssue)
+ if err != nil {
+ log.Fatal(err)
+ }
+ if len(githubIDs) > 1 {
+ // Maybe we should automatically maintain a local clone of the
+ // cvelist repo, but for now we can avoid repeatedly fetching it
+ // when iterating over a list of reports.
+ return nil, nil, fmt.Errorf("git clone %v to a local directory, and set -local-cve-repo to that path", cvelistrepo.URL)
+ }
+ repoPath := cvelistrepo.URL
+ if *localRepoPath != "" {
+ repoPath = *localRepoPath
+ }
+ owner, repoName, err := gitrepo.ParseGitHubRepo(*issueRepo)
+ if err != nil {
+ return nil, nil, err
+ }
+ return githubIDs, &createCfg{
+ ghToken: *githubToken,
+ repoPath: repoPath,
+ issuesClient: issues.NewGitHubClient(owner, repoName, *githubToken),
+ existingByFile: existingByFile,
+ }, nil
+}
+
+func create(ctx context.Context, issueNumber int, cfg *createCfg) (err error) {
defer derrors.Wrap(&err, "create(%d)", issueNumber)
// Get GitHub issue.
- iss, err := c.GetIssue(ctx, issueNumber, issues.GetIssueOptions{GetLabels: true})
+ iss, err := cfg.issuesClient.GetIssue(ctx, issueNumber, issues.GetIssueOptions{GetLabels: true})
if err != nil {
return err
}
@@ -275,7 +307,7 @@
}
if len(ghsas) == 0 && len(cves) > 0 {
for _, cve := range cves {
- sas, err := ghsa.ListForCVE(ctx, ghToken, cve)
+ sas, err := ghsa.ListForCVE(ctx, cfg.ghToken, cve)
if err != nil {
return err
}
@@ -290,13 +322,13 @@
var r *report.Report
switch {
case len(ghsas) > 0:
- ghsa, err := ghsa.FetchGHSA(ctx, ghToken, ghsas[0])
+ ghsa, err := ghsa.FetchGHSA(ctx, cfg.ghToken, ghsas[0])
if err != nil {
return err
}
r = report.GHSAToReport(ghsa, modulePath)
case len(cves) > 0:
- cve, err := cvelistrepo.FetchCVE(ctx, repoPath, cves[0])
+ cve, err := cvelistrepo.FetchCVE(ctx, cfg.repoPath, cves[0])
if err != nil {
return err
}
@@ -335,7 +367,7 @@
return err
}
fmt.Println(filename)
- fmt.Print(xref(filename, r, existingByFile))
+ fmt.Print(xref(filename, r, cfg.existingByFile))
return nil
}