flag: add Func

Fixes #39557

Change-Id: Ida578f7484335e8c6bf927255f75377eda63b563
GitHub-Last-Rev: b97294f7669c24011e5b093179d65636512a84cd
GitHub-Pull-Request: golang/go#39880
Reviewed-on: https://go-review.googlesource.com/c/go/+/240014
Reviewed-by: Russ Cox <rsc@golang.org>
Trust: Ian Lance Taylor <iant@golang.org>
diff --git a/src/flag/example_func_test.go b/src/flag/example_func_test.go
new file mode 100644
index 0000000..7c30c5e
--- /dev/null
+++ b/src/flag/example_func_test.go
@@ -0,0 +1,41 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package flag_test
+
+import (
+	"errors"
+	"flag"
+	"fmt"
+	"net"
+	"os"
+)
+
+func ExampleFunc() {
+	fs := flag.NewFlagSet("ExampleFunc", flag.ContinueOnError)
+	fs.SetOutput(os.Stdout)
+	var ip net.IP
+	fs.Func("ip", "`IP address` to parse", func(s string) error {
+		ip = net.ParseIP(s)
+		if ip == nil {
+			return errors.New("could not parse IP")
+		}
+		return nil
+	})
+	fs.Parse([]string{"-ip", "127.0.0.1"})
+	fmt.Printf("{ip: %v, loopback: %t}\n\n", ip, ip.IsLoopback())
+
+	// 256 is not a valid IPv4 component
+	fs.Parse([]string{"-ip", "256.0.0.1"})
+	fmt.Printf("{ip: %v, loopback: %t}\n\n", ip, ip.IsLoopback())
+
+	// Output:
+	// {ip: 127.0.0.1, loopback: true}
+	//
+	// invalid value "256.0.0.1" for flag -ip: could not parse IP
+	// Usage of ExampleFunc:
+	//   -ip IP address
+	//     	IP address to parse
+	// {ip: <nil>, loopback: false}
+}
diff --git a/src/flag/flag.go b/src/flag/flag.go
index 286bba68..a8485f0 100644
--- a/src/flag/flag.go
+++ b/src/flag/flag.go
@@ -278,6 +278,12 @@
 
 func (d *durationValue) String() string { return (*time.Duration)(d).String() }
 
+type funcValue func(string) error
+
+func (f funcValue) Set(s string) error { return f(s) }
+
+func (f funcValue) String() string { return "" }
+
 // Value is the interface to the dynamic value stored in a flag.
 // (The default value is represented as a string.)
 //
@@ -296,7 +302,7 @@
 // Getter is an interface that allows the contents of a Value to be retrieved.
 // It wraps the Value interface, rather than being part of it, because it
 // appeared after Go 1 and its compatibility rules. All Value types provided
-// by this package satisfy the Getter interface.
+// by this package satisfy the Getter interface, except the type used by Func.
 type Getter interface {
 	Value
 	Get() interface{}
@@ -830,6 +836,20 @@
 	return CommandLine.Duration(name, value, usage)
 }
 
+// Func defines a flag with the specified name and usage string.
+// Each time the flag is seen, fn is called with the value of the flag.
+// If fn returns a non-nil error, it will be treated as a flag value parsing error.
+func (f *FlagSet) Func(name, usage string, fn func(string) error) {
+	f.Var(funcValue(fn), name, usage)
+}
+
+// Func defines a flag with the specified name and usage string.
+// Each time the flag is seen, fn is called with the value of the flag.
+// If fn returns a non-nil error, it will be treated as a flag value parsing error.
+func Func(name, usage string, fn func(string) error) {
+	CommandLine.Func(name, usage, fn)
+}
+
 // Var defines a flag with the specified name and usage string. The type and
 // value of the flag are represented by the first argument, of type Value, which
 // typically holds a user-defined implementation of Value. For instance, the
diff --git a/src/flag/flag_test.go b/src/flag/flag_test.go
index a01a5e4..2793064 100644
--- a/src/flag/flag_test.go
+++ b/src/flag/flag_test.go
@@ -38,6 +38,7 @@
 	String("test_string", "0", "string value")
 	Float64("test_float64", 0, "float64 value")
 	Duration("test_duration", 0, "time.Duration value")
+	Func("test_func", "func value", func(string) error { return nil })
 
 	m := make(map[string]*Flag)
 	desired := "0"
@@ -52,6 +53,8 @@
 				ok = true
 			case f.Name == "test_duration" && f.Value.String() == desired+"s":
 				ok = true
+			case f.Name == "test_func" && f.Value.String() == "":
+				ok = true
 			}
 			if !ok {
 				t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
@@ -59,7 +62,7 @@
 		}
 	}
 	VisitAll(visitor)
-	if len(m) != 8 {
+	if len(m) != 9 {
 		t.Error("VisitAll misses some flags")
 		for k, v := range m {
 			t.Log(k, *v)
@@ -82,9 +85,10 @@
 	Set("test_string", "1")
 	Set("test_float64", "1")
 	Set("test_duration", "1s")
+	Set("test_func", "1")
 	desired = "1"
 	Visit(visitor)
-	if len(m) != 8 {
+	if len(m) != 9 {
 		t.Error("Visit fails after set")
 		for k, v := range m {
 			t.Log(k, *v)
@@ -257,6 +261,48 @@
 	}
 }
 
+func TestUserDefinedFunc(t *testing.T) {
+	var flags FlagSet
+	flags.Init("test", ContinueOnError)
+	var ss []string
+	flags.Func("v", "usage", func(s string) error {
+		ss = append(ss, s)
+		return nil
+	})
+	if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil {
+		t.Error(err)
+	}
+	if len(ss) != 3 {
+		t.Fatal("expected 3 args; got ", len(ss))
+	}
+	expect := "[1 2 3]"
+	if got := fmt.Sprint(ss); got != expect {
+		t.Errorf("expected value %q got %q", expect, got)
+	}
+	// test usage
+	var buf strings.Builder
+	flags.SetOutput(&buf)
+	flags.Parse([]string{"-h"})
+	if usage := buf.String(); !strings.Contains(usage, "usage") {
+		t.Errorf("usage string not included: %q", usage)
+	}
+	// test Func error
+	flags = *NewFlagSet("test", ContinueOnError)
+	flags.Func("v", "usage", func(s string) error {
+		return fmt.Errorf("test error")
+	})
+	// flag not set, so no error
+	if err := flags.Parse(nil); err != nil {
+		t.Error(err)
+	}
+	// flag set, expect error
+	if err := flags.Parse([]string{"-v", "1"}); err == nil {
+		t.Error("expected error; got none")
+	} else if errMsg := err.Error(); !strings.Contains(errMsg, "test error") {
+		t.Errorf(`error should contain "test error"; got %q`, errMsg)
+	}
+}
+
 func TestUserDefinedForCommandLine(t *testing.T) {
 	const help = "HELP"
 	var result string