internal/secret: add JSON flag support
Some of our secrets are in JSON format so that we don't have to deal
with too many. Make a little convenience flag function for those.
For golang/go#51122.
Change-Id: Ie34828443adb9acb16249339a760d28a81ddbd20
Reviewed-on: https://go-review.googlesource.com/c/build/+/386054
Trust: Heschi Kreinick <heschi@google.com>
Run-TryBot: Heschi Kreinick <heschi@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Auto-Submit: Heschi Kreinick <heschi@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/internal/secret/flag.go b/internal/secret/flag.go
index 6a860ca..e20e5ac 100644
--- a/internal/secret/flag.go
+++ b/internal/secret/flag.go
@@ -2,6 +2,7 @@
import (
"context"
+ "encoding/json"
"flag"
"fmt"
"strings"
@@ -18,39 +19,58 @@
DefaultProjectID string
}
+const secretSuffix = "[ specify `secret:[project name/]<secret name>` to read from Secret Manager ]"
+
// Flag declares a string flag on set that will be resolved using r.
func (r *FlagResolver) Flag(set *flag.FlagSet, name string, usage string) *string {
var value string
- suffixedUsage := usage + " [ specify `secret:[project name/]<secret name>` to read from Secret Manager ]"
+ suffixedUsage := usage + "\n" + secretSuffix
set.Func(name, suffixedUsage, func(flagValue string) error {
- if r.Client == nil || r.Context == nil {
- return fmt.Errorf("secret resolver was not initialized")
- }
- if !strings.HasPrefix(flagValue, "secret:") {
- value = flagValue
- return nil
- }
-
- secretName := strings.TrimPrefix(flagValue, "secret:")
- projectID := r.DefaultProjectID
- if parts := strings.SplitN(secretName, "/", 2); len(parts) == 2 {
- projectID, secretName = parts[0], parts[1]
- }
- if projectID == "" {
- return fmt.Errorf("missing project ID: none specified in %q, and no default set (not on GCP?)", secretName)
- }
- r, err := r.Client.AccessSecretVersion(r.Context, &secretmanagerpb.AccessSecretVersionRequest{
- Name: buildNamePath(projectID, secretName, "latest"),
- })
- if err != nil {
- return fmt.Errorf("reading secret %q from project %v failed: %v", secretName, projectID, err)
- }
- value = string(r.Payload.GetData())
- return nil
+ var err error
+ value, err = r.resolveSecret(flagValue)
+ return err
})
return &value
}
+func (r *FlagResolver) resolveSecret(flagValue string) (string, error) {
+ if r.Client == nil || r.Context == nil {
+ return "", fmt.Errorf("secret resolver was not initialized")
+ }
+ if !strings.HasPrefix(flagValue, "secret:") {
+ return flagValue, nil
+ }
+
+ secretName := strings.TrimPrefix(flagValue, "secret:")
+ projectID := r.DefaultProjectID
+ if parts := strings.SplitN(secretName, "/", 2); len(parts) == 2 {
+ projectID, secretName = parts[0], parts[1]
+ }
+ if projectID == "" {
+ return "", fmt.Errorf("missing project ID: none specified in %q, and no default set (not on GCP?)", secretName)
+ }
+ result, err := r.Client.AccessSecretVersion(r.Context, &secretmanagerpb.AccessSecretVersionRequest{
+ Name: buildNamePath(projectID, secretName, "latest"),
+ })
+ if err != nil {
+ return "", fmt.Errorf("reading secret %q from project %v failed: %v", secretName, projectID, err)
+ }
+ return string(result.Payload.GetData()), nil
+}
+
+// JSONVarFlag declares a flag on set that behaves like Flag and then
+// json.Unmarshals the resulting string into value.
+func (r *FlagResolver) JSONVarFlag(set *flag.FlagSet, value interface{}, name string, usage string) {
+ suffixedUsage := usage + "\n" + fmt.Sprintf("A JSON representation of a %T.", value) + "\n" + secretSuffix
+ set.Func(name, suffixedUsage, func(flagValue string) error {
+ stringValue, err := r.resolveSecret(flagValue)
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal([]byte(stringValue), value)
+ })
+}
+
// DefaultResolver is the FlagResolver used by the convenience functions.
var DefaultResolver FlagResolver
@@ -61,6 +81,12 @@
return DefaultResolver.Flag(flag.CommandLine, name, usage)
}
+// JSONVarFlag declares a flag on flag.CommandLine that behaves like Flag
+// and then json.Unmarshals the resulting string into value.
+func JSONVarFlag(value interface{}, name string, usage string) {
+ DefaultResolver.JSONVarFlag(flag.CommandLine, value, name, usage)
+}
+
// InitFlagSupport initializes the dependencies for flags declared with Flag.
func InitFlagSupport(ctx context.Context) error {
client, err := secretmanager.NewClient(ctx)
diff --git a/internal/secret/gcp_secret_manager_test.go b/internal/secret/gcp_secret_manager_test.go
index b4671c2..5cf66f7 100644
--- a/internal/secret/gcp_secret_manager_test.go
+++ b/internal/secret/gcp_secret_manager_test.go
@@ -8,6 +8,8 @@
"context"
"flag"
"fmt"
+ "io/ioutil"
+ "reflect"
"testing"
gax "github.com/googleapis/gax-go/v2"
@@ -165,7 +167,7 @@
for _, tt := range tests {
t.Run(tt.flagVal, func(t *testing.T) {
fs := flag.NewFlagSet("", flag.ContinueOnError)
- fs.Usage = func() {} // Minimize console spam; can't prevent it entirely.
+ fs.SetOutput(ioutil.Discard)
flagVal := r.Flag(fs, "testflag", "usage")
err := fs.Parse([]string{"--testflag", tt.flagVal})
if tt.wantErr {
@@ -178,8 +180,57 @@
t.Fatalf("flag parsing failed: %v", err)
}
if *flagVal != tt.wantVal {
- t.Errorf("flag value = %q, want %q", *flagVal, "hey")
+ t.Errorf("flag value = %q, want %q", *flagVal, tt.wantVal)
}
})
}
}
+
+type jsonValue struct {
+ Foo, Bar int
+}
+
+func TestJSONFlag(t *testing.T) {
+ r := &FlagResolver{
+ Context: context.Background(),
+ Client: &fakeSecretClient{
+ accessSecretMap: map[string]string{
+ buildNamePath("project1", "secret1", "latest"): `{"Foo": 1, "Bar": 2}`,
+ buildNamePath("project1", "secret2", "latest"): `i am not json`,
+ },
+ },
+ DefaultProjectID: "project1",
+ }
+ tests := []struct {
+ flagVal string
+ wantValue *jsonValue
+ wantErr bool
+ }{
+ {"secret:secret1", &jsonValue{Foo: 1, Bar: 2}, false},
+ {"secret:secret2", nil, true},
+ {`{"Foo":0, "Bar":1}`, &jsonValue{Foo: 0, Bar: 1}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.flagVal, func(t *testing.T) {
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.SetOutput(ioutil.Discard)
+ value := &jsonValue{}
+ r.JSONVarFlag(fs, value, "testflag", "usage")
+ err := fs.Parse([]string{"--testflag", tt.flagVal})
+ if tt.wantErr {
+ if err == nil {
+ t.Fatalf("flag parsing succeeded, should have failed")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("flag parsing failed: %v", err)
+ }
+ if !reflect.DeepEqual(value, tt.wantValue) {
+ t.Errorf("flag value = %q, want %q", value, tt.wantValue)
+ }
+ })
+ }
+
+}