flag: allow a FlagSet to not write to os.Stderr
Fixes #2747
R=golang-dev, gri, r, rogpeppe, r
CC=golang-dev
https://golang.org/cl/5564065
diff --git a/src/pkg/flag/flag.go b/src/pkg/flag/flag.go
index 964f554..1719af8 100644
--- a/src/pkg/flag/flag.go
+++ b/src/pkg/flag/flag.go
@@ -62,6 +62,7 @@
import (
"errors"
"fmt"
+ "io"
"os"
"sort"
"strconv"
@@ -228,6 +229,7 @@
args []string // arguments after flags
exitOnError bool // does the program exit if there's an error?
errorHandling ErrorHandling
+ output io.Writer // nil means stderr; use out() accessor
}
// A Flag represents the state of a flag.
@@ -254,6 +256,19 @@
return result
}
+func (f *FlagSet) out() io.Writer {
+ if f.output == nil {
+ return os.Stderr
+ }
+ return f.output
+}
+
+// SetOutput sets the destination for usage and error messages.
+// If output is nil, os.Stderr is used.
+func (f *FlagSet) SetOutput(output io.Writer) {
+ f.output = output
+}
+
// VisitAll visits the flags in lexicographical order, calling fn for each.
// It visits all flags, even those not set.
func (f *FlagSet) VisitAll(fn func(*Flag)) {
@@ -315,15 +330,16 @@
return commandLine.Set(name, value)
}
-// PrintDefaults prints to standard error the default values of all defined flags in the set.
+// PrintDefaults prints, to standard error unless configured
+// otherwise, the default values of all defined flags in the set.
func (f *FlagSet) PrintDefaults() {
- f.VisitAll(func(f *Flag) {
+ f.VisitAll(func(flag *Flag) {
format := " -%s=%s: %s\n"
- if _, ok := f.Value.(*stringValue); ok {
+ if _, ok := flag.Value.(*stringValue); ok {
// put quotes on the value
format = " -%s=%q: %s\n"
}
- fmt.Fprintf(os.Stderr, format, f.Name, f.DefValue, f.Usage)
+ fmt.Fprintf(f.out(), format, flag.Name, flag.DefValue, flag.Usage)
})
}
@@ -334,7 +350,7 @@
// defaultUsage is the default function to print a usage message.
func defaultUsage(f *FlagSet) {
- fmt.Fprintf(os.Stderr, "Usage of %s:\n", f.name)
+ fmt.Fprintf(f.out(), "Usage of %s:\n", f.name)
f.PrintDefaults()
}
@@ -601,7 +617,7 @@
flag := &Flag{name, usage, value, value.String()}
_, alreadythere := f.formal[name]
if alreadythere {
- fmt.Fprintf(os.Stderr, "%s flag redefined: %s\n", f.name, name)
+ fmt.Fprintf(f.out(), "%s flag redefined: %s\n", f.name, name)
panic("flag redefinition") // Happens only if flags are declared with identical names
}
if f.formal == nil {
@@ -624,7 +640,7 @@
// returns the error.
func (f *FlagSet) failf(format string, a ...interface{}) error {
err := fmt.Errorf(format, a...)
- fmt.Fprintln(os.Stderr, err)
+ fmt.Fprintln(f.out(), err)
f.usage()
return err
}
diff --git a/src/pkg/flag/flag_test.go b/src/pkg/flag/flag_test.go
index 698c15f..a9561f2 100644
--- a/src/pkg/flag/flag_test.go
+++ b/src/pkg/flag/flag_test.go
@@ -5,10 +5,12 @@
package flag_test
import (
+ "bytes"
. "flag"
"fmt"
"os"
"sort"
+ "strings"
"testing"
"time"
)
@@ -206,6 +208,17 @@
}
}
+func TestSetOutput(t *testing.T) {
+ var flags FlagSet
+ var buf bytes.Buffer
+ flags.SetOutput(&buf)
+ flags.Init("test", ContinueOnError)
+ flags.Parse([]string{"-unknown"})
+ if out := buf.String(); !strings.Contains(out, "-unknown") {
+ t.Logf("expected output mentioning unknown; got %q", out)
+ }
+}
+
// This tests that one can reset the flags. This still works but not well, and is
// superseded by FlagSet.
func TestChangingArgs(t *testing.T) {