internal/secret: add JSON flag support

Some of our secrets are in JSON format so that we don't have to deal
with too many. Make a little convenience flag function for those.

For golang/go#51122.

Change-Id: Ie34828443adb9acb16249339a760d28a81ddbd20
Reviewed-on: https://go-review.googlesource.com/c/build/+/386054
Trust: Heschi Kreinick <heschi@google.com>
Run-TryBot: Heschi Kreinick <heschi@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Auto-Submit: Heschi Kreinick <heschi@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/internal/secret/flag.go b/internal/secret/flag.go
index 6a860ca..e20e5ac 100644
--- a/internal/secret/flag.go
+++ b/internal/secret/flag.go
@@ -2,6 +2,7 @@
 
 import (
 	"context"
+	"encoding/json"
 	"flag"
 	"fmt"
 	"strings"
@@ -18,39 +19,58 @@
 	DefaultProjectID string
 }
 
+const secretSuffix = "[ specify `secret:[project name/]<secret name>` to read from Secret Manager ]"
+
 // Flag declares a string flag on set that will be resolved using r.
 func (r *FlagResolver) Flag(set *flag.FlagSet, name string, usage string) *string {
 	var value string
-	suffixedUsage := usage + " [ specify `secret:[project name/]<secret name>` to read from Secret Manager ]"
+	suffixedUsage := usage + "\n" + secretSuffix
 	set.Func(name, suffixedUsage, func(flagValue string) error {
-		if r.Client == nil || r.Context == nil {
-			return fmt.Errorf("secret resolver was not initialized")
-		}
-		if !strings.HasPrefix(flagValue, "secret:") {
-			value = flagValue
-			return nil
-		}
-
-		secretName := strings.TrimPrefix(flagValue, "secret:")
-		projectID := r.DefaultProjectID
-		if parts := strings.SplitN(secretName, "/", 2); len(parts) == 2 {
-			projectID, secretName = parts[0], parts[1]
-		}
-		if projectID == "" {
-			return fmt.Errorf("missing project ID: none specified in %q, and no default set (not on GCP?)", secretName)
-		}
-		r, err := r.Client.AccessSecretVersion(r.Context, &secretmanagerpb.AccessSecretVersionRequest{
-			Name: buildNamePath(projectID, secretName, "latest"),
-		})
-		if err != nil {
-			return fmt.Errorf("reading secret %q from project %v failed: %v", secretName, projectID, err)
-		}
-		value = string(r.Payload.GetData())
-		return nil
+		var err error
+		value, err = r.resolveSecret(flagValue)
+		return err
 	})
 	return &value
 }
 
+func (r *FlagResolver) resolveSecret(flagValue string) (string, error) {
+	if r.Client == nil || r.Context == nil {
+		return "", fmt.Errorf("secret resolver was not initialized")
+	}
+	if !strings.HasPrefix(flagValue, "secret:") {
+		return flagValue, nil
+	}
+
+	secretName := strings.TrimPrefix(flagValue, "secret:")
+	projectID := r.DefaultProjectID
+	if parts := strings.SplitN(secretName, "/", 2); len(parts) == 2 {
+		projectID, secretName = parts[0], parts[1]
+	}
+	if projectID == "" {
+		return "", fmt.Errorf("missing project ID: none specified in %q, and no default set (not on GCP?)", secretName)
+	}
+	result, err := r.Client.AccessSecretVersion(r.Context, &secretmanagerpb.AccessSecretVersionRequest{
+		Name: buildNamePath(projectID, secretName, "latest"),
+	})
+	if err != nil {
+		return "", fmt.Errorf("reading secret %q from project %v failed: %v", secretName, projectID, err)
+	}
+	return string(result.Payload.GetData()), nil
+}
+
+// JSONVarFlag declares a flag on set that behaves like Flag and then
+// json.Unmarshals the resulting string into value.
+func (r *FlagResolver) JSONVarFlag(set *flag.FlagSet, value interface{}, name string, usage string) {
+	suffixedUsage := usage + "\n" + fmt.Sprintf("A JSON representation of a %T.", value) + "\n" + secretSuffix
+	set.Func(name, suffixedUsage, func(flagValue string) error {
+		stringValue, err := r.resolveSecret(flagValue)
+		if err != nil {
+			return err
+		}
+		return json.Unmarshal([]byte(stringValue), value)
+	})
+}
+
 // DefaultResolver is the FlagResolver used by the convenience functions.
 var DefaultResolver FlagResolver
 
@@ -61,6 +81,12 @@
 	return DefaultResolver.Flag(flag.CommandLine, name, usage)
 }
 
+// JSONVarFlag declares a flag on flag.CommandLine that behaves like Flag
+// and then json.Unmarshals the resulting string into value.
+func JSONVarFlag(value interface{}, name string, usage string) {
+	DefaultResolver.JSONVarFlag(flag.CommandLine, value, name, usage)
+}
+
 // InitFlagSupport initializes the dependencies for flags declared with Flag.
 func InitFlagSupport(ctx context.Context) error {
 	client, err := secretmanager.NewClient(ctx)
diff --git a/internal/secret/gcp_secret_manager_test.go b/internal/secret/gcp_secret_manager_test.go
index b4671c2..5cf66f7 100644
--- a/internal/secret/gcp_secret_manager_test.go
+++ b/internal/secret/gcp_secret_manager_test.go
@@ -8,6 +8,8 @@
 	"context"
 	"flag"
 	"fmt"
+	"io/ioutil"
+	"reflect"
 	"testing"
 
 	gax "github.com/googleapis/gax-go/v2"
@@ -165,7 +167,7 @@
 	for _, tt := range tests {
 		t.Run(tt.flagVal, func(t *testing.T) {
 			fs := flag.NewFlagSet("", flag.ContinueOnError)
-			fs.Usage = func() {} // Minimize console spam; can't prevent it entirely.
+			fs.SetOutput(ioutil.Discard)
 			flagVal := r.Flag(fs, "testflag", "usage")
 			err := fs.Parse([]string{"--testflag", tt.flagVal})
 			if tt.wantErr {
@@ -178,8 +180,57 @@
 				t.Fatalf("flag parsing failed: %v", err)
 			}
 			if *flagVal != tt.wantVal {
-				t.Errorf("flag value = %q, want %q", *flagVal, "hey")
+				t.Errorf("flag value = %q, want %q", *flagVal, tt.wantVal)
 			}
 		})
 	}
 }
+
+type jsonValue struct {
+	Foo, Bar int
+}
+
+func TestJSONFlag(t *testing.T) {
+	r := &FlagResolver{
+		Context: context.Background(),
+		Client: &fakeSecretClient{
+			accessSecretMap: map[string]string{
+				buildNamePath("project1", "secret1", "latest"): `{"Foo": 1, "Bar": 2}`,
+				buildNamePath("project1", "secret2", "latest"): `i am not json`,
+			},
+		},
+		DefaultProjectID: "project1",
+	}
+	tests := []struct {
+		flagVal   string
+		wantValue *jsonValue
+		wantErr   bool
+	}{
+		{"secret:secret1", &jsonValue{Foo: 1, Bar: 2}, false},
+		{"secret:secret2", nil, true},
+		{`{"Foo":0, "Bar":1}`, &jsonValue{Foo: 0, Bar: 1}, false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.flagVal, func(t *testing.T) {
+			fs := flag.NewFlagSet("", flag.ContinueOnError)
+			fs.SetOutput(ioutil.Discard)
+			value := &jsonValue{}
+			r.JSONVarFlag(fs, value, "testflag", "usage")
+			err := fs.Parse([]string{"--testflag", tt.flagVal})
+			if tt.wantErr {
+				if err == nil {
+					t.Fatalf("flag parsing succeeded, should have failed")
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("flag parsing failed: %v", err)
+			}
+			if !reflect.DeepEqual(value, tt.wantValue) {
+				t.Errorf("flag value = %q, want %q", value, tt.wantValue)
+			}
+		})
+	}
+
+}