google: add Credentials.UniverseDomainProvider

* move metadata server (MDS) universe domain retrieval into the Compute credentials

Change-Id: I847d2075ca11bde998a06220307626e902230c23
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/575936
Reviewed-by: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/google/default.go b/google/default.go
index 4b55b3f..df95835 100644
--- a/google/default.go
+++ b/google/default.go
@@ -42,6 +42,17 @@
 	// running on Google Cloud Platform.
 	JSON []byte
 
+	// UniverseDomainProvider returns the default service domain for a given
+	// Cloud universe. Optional.
+	//
+	// On Google Compute Engine (GCE), UniverseDomainProvider should return the
+	// universe domain value from the GCE metadata server. See also [The attached service
+	// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
+	// If the GCE metadata server returns a 404 error, the default universe
+	// domain value should be returned. If the GCE metadata server returns an
+	// error other than 404, the error should be returned.
+	UniverseDomainProvider func() (string, error)
+
 	udMu sync.Mutex // guards universeDomain
 	// universeDomain is the default service domain for a given Cloud universe.
 	universeDomain string
@@ -64,54 +75,32 @@
 }
 
 // GetUniverseDomain returns the default service domain for a given Cloud
-// universe.
+// universe. If no universe domain is set yet and UniverseDomainProvider is
+// present, it will be invoked and its return value will be cached.
 //
 // The default value is "googleapis.com".
-//
-// It obtains the universe domain from the attached service account on GCE when
-// authenticating via the GCE metadata server. See also [The attached service
-// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
-// If the GCE metadata server returns a 404 error, the default value is
-// returned. If the GCE metadata server returns an error other than 404, the
-// error is returned.
 func (c *Credentials) GetUniverseDomain() (string, error) {
 	c.udMu.Lock()
 	defer c.udMu.Unlock()
-	if c.universeDomain == "" && metadata.OnGCE() {
-		// If we're on Google Compute Engine, an App Engine standard second
-		// generation runtime, or App Engine flexible, use the metadata server.
-		err := c.computeUniverseDomain()
+	if c.universeDomain == "" && c.UniverseDomainProvider != nil {
+		// On Google Compute Engine, an App Engine standard second generation
+		// runtime, or App Engine flexible, use an externally provided function
+		// to request the universe domain from the metadata server.
+		ud, err := c.UniverseDomainProvider()
 		if err != nil {
 			return "", err
 		}
+		c.universeDomain = ud
 	}
-	// If not on Google Compute Engine, or in case of any non-error path in
-	// computeUniverseDomain that did not set universeDomain, set the default
-	// universe domain.
+	// If no UniverseDomainProvider is set (e.g. when not on Google Compute
+	// Engine), or if UniverseDomainProvider returned an empty value without
+	// an error, fall back to the default universe domain.
 	if c.universeDomain == "" {
 		c.universeDomain = defaultUniverseDomain
 	}
 	return c.universeDomain, nil
 }
 
-// computeUniverseDomain fetches the default service domain for a given Cloud
-// universe from Google Compute Engine (GCE)'s metadata server. It's only valid
-// to use this method if your program is running on a GCE instance.
-func (c *Credentials) computeUniverseDomain() error {
-	var err error
-	c.universeDomain, err = metadata.Get("universe/universe_domain")
-	if err != nil {
-		if _, ok := err.(metadata.NotDefinedError); ok {
-			// http.StatusNotFound (404)
-			c.universeDomain = defaultUniverseDomain
-			return nil
-		} else {
-			return err
-		}
-	}
-	return nil
-}
-
 // DefaultCredentials is the old name of Credentials.
 //
 // Deprecated: use Credentials instead.
@@ -226,10 +215,23 @@
 	// or App Engine flexible, use the metadata server.
 	if metadata.OnGCE() {
 		id, _ := metadata.ProjectID()
+		universeDomainProvider := func() (string, error) {
+			universeDomain, err := metadata.Get("universe/universe_domain")
+			if err != nil {
+				if _, ok := err.(metadata.NotDefinedError); ok {
+					// http.StatusNotFound (404)
+					return defaultUniverseDomain, nil
+				} else {
+					return "", err
+				}
+			}
+			return universeDomain, nil
+		}
 		return &Credentials{
-			ProjectID:      id,
-			TokenSource:    computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
-			universeDomain: params.UniverseDomain,
+			ProjectID:              id,
+			TokenSource:            computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
+			UniverseDomainProvider: universeDomainProvider,
+			universeDomain:         params.UniverseDomain,
 		}, nil
 	}
 
diff --git a/google/default_test.go b/google/default_test.go
index 7352ffc..c8465e9 100644
--- a/google/default_test.go
+++ b/google/default_test.go
@@ -10,6 +10,8 @@
 	"net/http/httptest"
 	"strings"
 	"testing"
+
+	"cloud.google.com/go/compute/metadata"
 )
 
 var saJSONJWT = []byte(`{
@@ -255,9 +257,14 @@
 func TestComputeUniverseDomain(t *testing.T) {
 	universeDomainPath := "/computeMetadata/v1/universe/universe_domain"
 	universeDomainResponseBody := "example.com"
+	var requests int
 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		requests++
 		if r.URL.Path != universeDomainPath {
-			t.Errorf("got %s, want %s", r.URL.Path, universeDomainPath)
+			t.Errorf("bad path, got %s, want %s", r.URL.Path, universeDomainPath)
+		}
+		if requests > 1 {
+			t.Errorf("too many requests, got %d, want 1", requests)
 		}
 		w.Write([]byte(universeDomainResponseBody))
 	}))
@@ -268,11 +275,19 @@
 	params := CredentialsParams{
 		Scopes: []string{scope},
 	}
+	universeDomainProvider := func() (string, error) {
+		universeDomain, err := metadata.Get("universe/universe_domain")
+		if err != nil {
+			return "", err
+		}
+		return universeDomain, nil
+	}
 	// Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block
 	creds := &Credentials{
-		ProjectID:      "fake_project",
-		TokenSource:    computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
-		universeDomain: params.UniverseDomain, // empty
+		ProjectID:              "fake_project",
+		TokenSource:            computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
+		UniverseDomainProvider: universeDomainProvider,
+		universeDomain:         params.UniverseDomain, // empty
 	}
 	c := make(chan bool)
 	go func() {
@@ -285,7 +300,7 @@
 		}
 		c <- true
 	}()
-	got, err := creds.GetUniverseDomain() // Second conflicting access.
+	got, err := creds.GetUniverseDomain() // Second conflicting (and potentially uncached) access.
 	<-c
 	if err != nil {
 		t.Error(err)