cmd/ejobs: better support for subcommand flags

Intead of global flag sets, each subcommand gets its own flag set
created and parsed for it.

The flag variables themselves are still global, though.

Change-Id: Id587b62d124651a08d4c7d953445de0726bf9ca0
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/511815
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Maceo Thompson <maceothompson@google.com>
diff --git a/cmd/ejobs/main.go b/cmd/ejobs/main.go
index 7277cfa..519b7bf 100644
--- a/cmd/ejobs/main.go
+++ b/cmd/ejobs/main.go
@@ -45,15 +45,10 @@
 )
 
 var (
-	startFlagSet = flag.NewFlagSet("start", flag.ContinueOnError)
-	minImporters = startFlagSet.Int("min", -1, "run on modules with at least this many importers (<0: use server default of 10)")
-
-	waitFlagSet  = flag.NewFlagSet("wait", flag.ContinueOnError)
-	waitInterval = waitFlagSet.Duration("i", 0, "display updates at this interval")
-
-	resultsFlagSet = flag.NewFlagSet("results", flag.ContinueOnError)
-	force          = resultsFlagSet.Bool("f", false, "download even if unfinished")
-	outfile        = resultsFlagSet.String("o", "", "output filename")
+	minImporters int           // for start
+	waitInterval time.Duration // for wait
+	force        bool          // for results
+	outfile      string        // for results
 )
 
 var commands = []command{
@@ -68,30 +63,50 @@
 		doCancel, nil},
 	{"start", "[-min MIN_IMPORTERS] BINARY ARGS...",
 		"start a job",
-		doStart, startFlagSet},
+		doStart,
+		func(fs *flag.FlagSet) {
+			fs.IntVar(&minImporters, "min", -1,
+				"run on modules with at least this many importers (<0: use server default of 10)")
+		},
+	},
 	{"wait", "JOBID",
 		"do not exit until JOBID is done",
-		doWait, nil},
+		doWait,
+		func(fs *flag.FlagSet) {
+			fs.DurationVar(&waitInterval, "i", 0, "display updates at this interval")
+		},
+	},
 	{"results", "[-f] [-o FILE.json] JOBID",
 		"download results as JSON",
-		doResults, nil},
+		doResults,
+		func(fs *flag.FlagSet) {
+			fs.BoolVar(&force, "f", false, "download even if unfinished")
+			fs.StringVar(&outfile, "o", "", "output filename")
+		},
+	},
 }
 
 type command struct {
-	name   string
-	argdoc string
-	desc   string
-	run    func(context.Context, []string) error
-	flags  *flag.FlagSet
+	name     string
+	argdoc   string
+	desc     string
+	run      func(context.Context, []string) error
+	flagdefs func(*flag.FlagSet)
 }
 
 func main() {
 	flag.Usage = func() {
 		out := flag.CommandLine.Output()
-		fmt.Fprintln(out, "usage:")
+		fmt.Fprintln(out, "Usage:")
 		for _, cmd := range commands {
-			fmt.Fprintf(out, "  ejobs %s %s\n", cmd.name, cmd.argdoc)
+			fmt.Println()
+			fmt.Fprintf(out, "ejobs %s %s\n", cmd.name, cmd.argdoc)
 			fmt.Fprintf(out, "\t%s\n", cmd.desc)
+			if cmd.flagdefs != nil {
+				fs := flag.NewFlagSet(cmd.name, flag.ContinueOnError)
+				cmd.flagdefs(fs)
+				fs.Usage()
+			}
 		}
 		fmt.Fprintln(out, "\ncommon flags:")
 		flag.PrintDefaults()
@@ -116,7 +131,16 @@
 	name := flag.Arg(0)
 	for _, cmd := range commands {
 		if cmd.name == name {
-			return cmd.run(ctx, flag.Args()[1:])
+			args := flag.Args()[1:]
+			if cmd.flagdefs != nil {
+				fs := flag.NewFlagSet(cmd.name, flag.ContinueOnError)
+				cmd.flagdefs(fs)
+				if err := fs.Parse(args); err != nil {
+					return err
+				}
+				args = fs.Args()
+			}
+			return cmd.run(ctx, args)
 		}
 	}
 	return fmt.Errorf("unknown command %q", name)
@@ -204,14 +228,11 @@
 }
 
 func doWait(ctx context.Context, args []string) error {
-	if err := waitFlagSet.Parse(args); err != nil {
-		return err
-	}
-	if waitFlagSet.NArg() != 1 {
+	if len(args) != 1 {
 		return errors.New("wrong number of args: want [-i DURATION] JOB_ID")
 	}
-	jobID := waitFlagSet.Arg(0)
-	sleepInterval := *waitInterval
+	jobID := args[0]
+	sleepInterval := waitInterval
 	displayUpdates := sleepInterval != 0
 	if sleepInterval < time.Second {
 		sleepInterval = time.Second
@@ -242,13 +263,10 @@
 
 func doStart(ctx context.Context, args []string) error {
 	// Validate arguments.
-	if err := startFlagSet.Parse(args); err != nil {
-		return err
-	}
-	if startFlagSet.NArg() == 0 {
+	if len(args) == 0 {
 		return errors.New("wrong number of args: want [-min N] BINARY [ARG1 ARG2 ...]")
 	}
-	binaryFile := startFlagSet.Arg(0)
+	binaryFile := args[0]
 	if fi, err := os.Stat(binaryFile); err != nil {
 		if errors.Is(err, os.ErrNotExist) {
 			return fmt.Errorf("%s does not exist", binaryFile)
@@ -260,8 +278,7 @@
 		return err
 	}
 	// Check args to binary for whitespace, which we don't support.
-
-	binaryArgs := startFlagSet.Args()[1:]
+	binaryArgs := args[1:]
 	for _, arg := range binaryArgs {
 		if strings.IndexFunc(arg, unicode.IsSpace) >= 0 {
 			return fmt.Errorf("arg %q contains whitespace: not supported", arg)
@@ -282,8 +299,8 @@
 	if len(binaryArgs) > 0 {
 		u += fmt.Sprintf("&args=%s", url.QueryEscape(strings.Join(binaryArgs, " ")))
 	}
-	if *minImporters >= 0 {
-		u += fmt.Sprintf("&min=%d", *minImporters)
+	if minImporters >= 0 {
+		u += fmt.Sprintf("&min=%d", minImporters)
 	}
 	if *dryRun {
 		fmt.Printf("dryrun: GET %s\n", u)
@@ -434,14 +451,10 @@
 }
 
 func doResults(ctx context.Context, args []string) (err error) {
-	fs := resultsFlagSet
-	if err := fs.Parse(args); err != nil {
-		return err
-	}
-	if fs.NArg() == 0 {
+	if len(args) == 0 {
 		return errors.New("wrong number of args: want [-f] [-o FILE.json] JOB_ID")
 	}
-	jobID := fs.Arg(0)
+	jobID := args[0]
 	ts, err := identityTokenSource(ctx)
 	if err != nil {
 		return err
@@ -451,7 +464,7 @@
 		return err
 	}
 	done := job.NumFinished()
-	if !*force && done < job.NumEnqueued {
+	if !force && done < job.NumEnqueued {
 		return fmt.Errorf("job not finished (%d/%d completed); use -f for partial results", done, job.NumEnqueued)
 	}
 	results, err := requestJSON[jobs.Results](ctx, "jobs/results?jobid="+jobID, ts)
@@ -459,8 +472,8 @@
 		return err
 	}
 	out := os.Stdout
-	if *outfile != "" {
-		out, err = os.Create(*outfile)
+	if outfile != "" {
+		out, err = os.Create(outfile)
 		if err != nil {
 			return err
 		}