cmd/log: move stackdriver logging to its own package

This change breaks out the stackdriver logging to its own package. It
also only does logging of trace ids and labels to the stackdriver
logger, to keep the logging interface simple.

For golang/go#61399

Change-Id: I9a25d0c6391d1667fe476e5fdc30fc057f07c40f
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/515375
TryBot-Result: Gopher Robot <gobot@golang.org>
kokoro-CI: kokoro <noreply+kokoro@google.com>
Run-TryBot: Michael Matloob <matloob@golang.org>
Reviewed-by: Jamal Carvalho <jamal@golang.org>
diff --git a/cmd/internal/cmdconfig/cmdconfig.go b/cmd/internal/cmdconfig/cmdconfig.go
index f26ceb7..e5a7ee2 100644
--- a/cmd/internal/cmdconfig/cmdconfig.go
+++ b/cmd/internal/cmdconfig/cmdconfig.go
@@ -22,6 +22,7 @@
 	"golang.org/x/pkgsite/internal/database"
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/log"
+	"golang.org/x/pkgsite/internal/log/stackdriverlogger"
 	"golang.org/x/pkgsite/internal/middleware"
 	"golang.org/x/pkgsite/internal/postgres"
 )
@@ -36,11 +37,12 @@
 				"k8s-pod/app": cfg.Application(),
 			}))
 		}
-		logger, err := log.UseStackdriver(ctx, logName, cfg.ProjectID, opts)
+		logger, parent, err := stackdriverlogger.New(ctx, logName, cfg.ProjectID, opts)
+		log.Use(logger)
 		if err != nil {
 			log.Fatal(ctx, err)
 		}
-		return logger
+		return parent
 	}
 	return middleware.LocalLogger{}
 }
diff --git a/internal/log/log.go b/internal/log/log.go
index ea38b9b..60c9371 100644
--- a/internal/log/log.go
+++ b/internal/log/log.go
@@ -7,35 +7,57 @@
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"log"
 	"os"
 	"strings"
 	"sync"
 
-	"cloud.google.com/go/logging"
-	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/experiment"
 )
 
+type Severity int
+
+const (
+	SeverityDefault = Severity(iota)
+	SeverityDebug
+	SeverityInfo
+	SeverityWarning
+	SeverityError
+	SeverityCritical
+)
+
+func (s Severity) String() string {
+	switch s {
+	case SeverityDefault:
+		return "Default"
+	case SeverityDebug:
+		return "Debug"
+	case SeverityInfo:
+		return "Info"
+	case SeverityWarning:
+		return "Warning"
+	case SeverityError:
+		return "Error"
+	case SeverityCritical:
+		return "Critical"
+	default:
+		return fmt.Sprint(int(s))
+	}
+}
+
+type Logger interface {
+	Log(ctx context.Context, s Severity, payload any)
+	Flush()
+}
+
 var (
 	mu     sync.Mutex
-	logger interface {
-		log(context.Context, logging.Severity, any)
-	} = stdlibLogger{}
+	logger Logger = stdlibLogger{}
 
 	// currentLevel holds current log level.
 	// No logs will be printed below currentLevel.
-	currentLevel = logging.Default
-)
-
-type (
-	// traceIDKey is the type of the context key for trace IDs.
-	traceIDKey struct{}
-
-	// labelsKey is the type of the context key for labels.
-	labelsKey struct{}
+	currentLevel = SeverityDefault
 )
 
 // Set the log level
@@ -45,79 +67,17 @@
 	currentLevel = toLevel(v)
 }
 
-func getLevel() logging.Severity {
+func getLevel() Severity {
 	mu.Lock()
 	defer mu.Unlock()
 	return currentLevel
 }
 
-// NewContextWithTraceID creates a new context from ctx that adds the trace ID.
-func NewContextWithTraceID(ctx context.Context, traceID string) context.Context {
-	return context.WithValue(ctx, traceIDKey{}, traceID)
-}
-
-// NewContextWithLabel creates anew context from ctx that adds a label that will
-// appear in the log entry.
-func NewContextWithLabel(ctx context.Context, key, value string) context.Context {
-	oldLabels, _ := ctx.Value(labelsKey{}).(map[string]string)
-	// Copy the labels, to preserve immutability of contexts.
-	newLabels := map[string]string{}
-	for k, v := range oldLabels {
-		newLabels[k] = v
-	}
-	newLabels[key] = value
-	return context.WithValue(ctx, labelsKey{}, newLabels)
-}
-
-// stackdriverLogger logs to GCP Stackdriver.
-type stackdriverLogger struct {
-	sdlogger *logging.Logger
-}
-
-func (l *stackdriverLogger) log(ctx context.Context, s logging.Severity, payload any) {
-	// Convert errors to strings, or they may serialize as the empty JSON object.
-	if err, ok := payload.(error); ok {
-		payload = err.Error()
-	}
-	traceID, _ := ctx.Value(traceIDKey{}).(string) // if not present, traceID is "", which is fine
-	labels, _ := ctx.Value(labelsKey{}).(map[string]string)
-	es := experimentString(ctx)
-	if len(es) > 0 {
-		nl := map[string]string{}
-		for k, v := range labels {
-			nl[k] = v
-		}
-		nl["experiments"] = es
-		labels = nl
-	}
-	l.sdlogger.Log(logging.Entry{
-		Severity: s,
-		Labels:   labels,
-		Payload:  payload,
-		Trace:    traceID,
-	})
-}
-
 // stdlibLogger uses the Go standard library logger.
 type stdlibLogger struct{}
 
-func init() {
-	// Log to stdout on GKE so the log messages are severity Info, rather than Error.
-	if os.Getenv("GO_DISCOVERY_ON_GKE") != "" {
-		log.SetOutput(os.Stdout)
-
-	}
-}
-
-func (stdlibLogger) log(ctx context.Context, s logging.Severity, payload any) {
+func (stdlibLogger) Log(ctx context.Context, s Severity, payload any) {
 	var extras []string
-	traceID, _ := ctx.Value(traceIDKey{}).(string) // if not present, traceID is ""
-	if traceID != "" {
-		extras = append(extras, fmt.Sprintf("traceID %s", traceID))
-	}
-	if labels, ok := ctx.Value(labelsKey{}).(map[string]string); ok {
-		extras = append(extras, fmt.Sprint(labels))
-	}
 	es := experimentString(ctx)
 	if len(es) > 0 {
 		extras = append(extras, fmt.Sprintf("experiments %s", es))
@@ -130,99 +90,79 @@
 
 }
 
+func (stdlibLogger) Flush() {}
+
 func experimentString(ctx context.Context) string {
 	return strings.Join(experiment.FromContext(ctx).Active(), ", ")
 }
 
-// UseStackdriver switches from the default stdlib logger to a Stackdriver
-// logger. It assumes config.Init has been called. UseStackdriver returns a
-// "parent" *logging.Logger that should be used to log the start and end of a
-// request. It also creates and remembers internally a "child" logger that will
-// be used to log within a request. The two loggers are necessary to get request-scoped
-// logs in Stackdriver.
-// See https://cloud.google.com/appengine/docs/standard/go/writing-application-logs.
-//
-// UseStackdriver can only be called once. If it is called a second time, it returns an error.
-func UseStackdriver(ctx context.Context, logName, projectID string, opts []logging.LoggerOption) (_ *logging.Logger, err error) {
-	defer derrors.Wrap(&err, "UseStackdriver(ctx, %q)", logName)
-	client, err := logging.NewClient(ctx, projectID)
-	if err != nil {
-		return nil, err
-	}
-	parent := client.Logger(logName, opts...)
-	child := client.Logger(logName+"-child", opts...)
+func Use(l Logger) {
 	mu.Lock()
 	defer mu.Unlock()
-	if _, ok := logger.(*stackdriverLogger); ok {
-		return nil, errors.New("already called once")
-	}
-	logger = &stackdriverLogger{child}
-	return parent, nil
+	logger = l
 }
 
 // Infof logs a formatted string at the Info level.
 func Infof(ctx context.Context, format string, args ...any) {
-	logf(ctx, logging.Info, format, args)
+	logf(ctx, SeverityInfo, format, args)
 }
 
 // Warningf logs a formatted string at the Warning level.
 func Warningf(ctx context.Context, format string, args ...any) {
-	logf(ctx, logging.Warning, format, args)
+	logf(ctx, SeverityWarning, format, args)
 }
 
 // Errorf logs a formatted string at the Error level.
 func Errorf(ctx context.Context, format string, args ...any) {
-	logf(ctx, logging.Error, format, args)
+	logf(ctx, SeverityError, format, args)
 }
 
 // Debugf logs a formatted string at the Debug level.
 func Debugf(ctx context.Context, format string, args ...any) {
-	logf(ctx, logging.Debug, format, args)
+	logf(ctx, SeverityDebug, format, args)
 }
 
 // Fatalf logs formatted string at the Critical level followed by exiting the program.
 func Fatalf(ctx context.Context, format string, args ...any) {
-	logf(ctx, logging.Critical, format, args)
+	logf(ctx, SeverityCritical, format, args)
 	die()
 }
 
-func logf(ctx context.Context, s logging.Severity, format string, args []any) {
+func logf(ctx context.Context, s Severity, format string, args []any) {
 	doLog(ctx, s, fmt.Sprintf(format, args...))
 }
 
 // Info logs arg, which can be a string or a struct, at the Info level.
-func Info(ctx context.Context, arg any) { doLog(ctx, logging.Info, arg) }
+func Info(ctx context.Context, arg any) { doLog(ctx, SeverityInfo, arg) }
 
 // Warning logs arg, which can be a string or a struct, at the Warning level.
-func Warning(ctx context.Context, arg any) { doLog(ctx, logging.Warning, arg) }
+func Warning(ctx context.Context, arg any) { doLog(ctx, SeverityWarning, arg) }
 
 // Error logs arg, which can be a string or a struct, at the Error level.
-func Error(ctx context.Context, arg any) { doLog(ctx, logging.Error, arg) }
+func Error(ctx context.Context, arg any) { doLog(ctx, SeverityError, arg) }
 
 // Debug logs arg, which can be a string or a struct, at the Debug level.
-func Debug(ctx context.Context, arg any) { doLog(ctx, logging.Debug, arg) }
+func Debug(ctx context.Context, arg any) { doLog(ctx, SeverityDebug, arg) }
 
 // Fatal logs arg, which can be a string or a struct, at the Critical level followed by exiting the program.
 func Fatal(ctx context.Context, arg any) {
-	doLog(ctx, logging.Critical, arg)
+	doLog(ctx, SeverityCritical, arg)
 	die()
 }
 
-func doLog(ctx context.Context, s logging.Severity, payload any) {
+func doLog(ctx context.Context, s Severity, payload any) {
 	if getLevel() > s {
 		return
 	}
 	mu.Lock()
 	l := logger
 	mu.Unlock()
-	l.log(ctx, s, payload)
+	l.Log(ctx, s, payload)
 }
 
 func die() {
 	mu.Lock()
-	if sl, ok := logger.(*stackdriverLogger); ok {
-		sl.sdlogger.Flush()
-	}
+	logger.Flush()
 	mu.Unlock()
 	os.Exit(1)
 }
@@ -230,26 +170,26 @@
 // toLevel returns the logging.Severity for a given string.
 // Possible input values are "", "debug", "info", "warning", "error", "fatal".
 // In case of invalid string input, it maps to DefaultLevel.
-func toLevel(v string) logging.Severity {
+func toLevel(v string) Severity {
 	v = strings.ToLower(v)
 
 	switch v {
 	case "":
 		// default log level will print everything.
-		return logging.Default
+		return SeverityDefault
 	case "debug":
-		return logging.Debug
+		return SeverityDebug
 	case "info":
-		return logging.Info
+		return SeverityInfo
 	case "warning":
-		return logging.Warning
+		return SeverityWarning
 	case "error":
-		return logging.Error
+		return SeverityError
 	case "fatal":
-		return logging.Critical
+		return SeverityCritical
 	}
 
 	// Default log level in case of invalid input.
 	log.Printf("Error: %s is invalid LogLevel. Possible values are [debug, info, warning, error, fatal]", v)
-	return logging.Default
+	return SeverityDefault
 }
diff --git a/internal/log/log_test.go b/internal/log/log_test.go
index d655875..791bd77 100644
--- a/internal/log/log_test.go
+++ b/internal/log/log_test.go
@@ -9,8 +9,6 @@
 	"fmt"
 	"strings"
 	"testing"
-
-	"cloud.google.com/go/logging"
 )
 
 const (
@@ -27,15 +25,15 @@
 	tests := []struct {
 		name      string
 		newLevel  string
-		wantLevel logging.Severity
+		wantLevel Severity
 	}{
-		{name: "default level", newLevel: "", wantLevel: logging.Default},
-		{name: "invalid level", newLevel: "xyz", wantLevel: logging.Default},
-		{name: "debug level", newLevel: "debug", wantLevel: logging.Debug},
-		{name: "info level", newLevel: "info", wantLevel: logging.Info},
-		{name: "warning level", newLevel: "warning", wantLevel: logging.Warning},
-		{name: "error level", newLevel: "error", wantLevel: logging.Error},
-		{name: "fatal level", newLevel: "fatal", wantLevel: logging.Critical},
+		{name: "default level", newLevel: "", wantLevel: SeverityDefault},
+		{name: "invalid level", newLevel: "xyz", wantLevel: SeverityDefault},
+		{name: "debug level", newLevel: "debug", wantLevel: SeverityDebug},
+		{name: "info level", newLevel: "info", wantLevel: SeverityInfo},
+		{name: "warning level", newLevel: "warning", wantLevel: SeverityWarning},
+		{name: "error level", newLevel: "error", wantLevel: SeverityError},
+		{name: "fatal level", newLevel: "fatal", wantLevel: SeverityCritical},
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
@@ -118,6 +116,8 @@
 	logs string
 }
 
-func (l *mockLogger) log(ctx context.Context, s logging.Severity, payload any) {
+func (l *mockLogger) Log(ctx context.Context, s Severity, payload any) {
 	l.logs += fmt.Sprintf("%s: %+v", s, payload)
 }
+
+func (l *mockLogger) Flush() {}
diff --git a/internal/log/stackdriverlogger/log.go b/internal/log/stackdriverlogger/log.go
new file mode 100644
index 0000000..184e858
--- /dev/null
+++ b/internal/log/stackdriverlogger/log.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package log supports structured and unstructured logging with levels
+// to GCP stackdriver.
+package stackdriverlogger
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	stdlog "log"
+	"os"
+	"strings"
+	"sync"
+
+	"cloud.google.com/go/logging"
+	"golang.org/x/pkgsite/internal/derrors"
+	"golang.org/x/pkgsite/internal/experiment"
+	"golang.org/x/pkgsite/internal/log"
+)
+
+func init() {
+	// Log to stdout on GKE so the log messages are severity Info, rather than Error.
+	if os.Getenv("GO_DISCOVERY_ON_GKE") != "" {
+		// Question for the reviewer. Was this meant to be done for cmd/pkgsite? This
+		// package won't be depended on by cmd/pkgsite, so that behavior will change, but
+		// we don't want the core to have knowledge of the GO_DISCOVERY_ON_GKE variable.
+		stdlog.SetOutput(os.Stdout)
+	}
+}
+
+type (
+	// traceIDKey is the type of the context key for trace IDs.
+	traceIDKey struct{}
+
+	// labelsKey is the type of the context key for labels.
+	labelsKey struct{}
+)
+
+// NewContextWithTraceID creates a new context from ctx that adds the trace ID.
+func NewContextWithTraceID(ctx context.Context, traceID string) context.Context {
+	return context.WithValue(ctx, traceIDKey{}, traceID)
+}
+
+// NewContextWithLabel creates a new context from ctx that adds a label that will
+// appear in the log entry.
+func NewContextWithLabel(ctx context.Context, key, value string) context.Context {
+	oldLabels, _ := ctx.Value(labelsKey{}).(map[string]string)
+	// Copy the labels, to preserve immutability of contexts.
+	newLabels := map[string]string{}
+	for k, v := range oldLabels {
+		newLabels[k] = v
+	}
+	newLabels[key] = value
+	return context.WithValue(ctx, labelsKey{}, newLabels)
+}
+
+// logger logs to GCP Stackdriver.
+type logger struct {
+	sdlogger *logging.Logger
+}
+
+func experimentString(ctx context.Context) string {
+	return strings.Join(experiment.FromContext(ctx).Active(), ", ")
+}
+
+func stackdriverSeverity(s log.Severity) logging.Severity {
+	switch s {
+	case log.SeverityDefault:
+		return logging.Default
+	case log.SeverityDebug:
+		return logging.Debug
+	case log.SeverityInfo:
+		return logging.Info
+	case log.SeverityWarning:
+		return logging.Warning
+	case log.SeverityError:
+		return logging.Error
+	case log.SeverityCritical:
+		return logging.Critical
+	default:
+		panic(fmt.Errorf("unknown severity: %v", s))
+	}
+}
+
+func (l *logger) Log(ctx context.Context, s log.Severity, payload any) {
+	// Convert errors to strings, or they may serialize as the empty JSON object.
+	if err, ok := payload.(error); ok {
+		payload = err.Error()
+	}
+	traceID, _ := ctx.Value(traceIDKey{}).(string) // if not present, traceID is "", which is fine
+	labels, _ := ctx.Value(labelsKey{}).(map[string]string)
+	es := experimentString(ctx)
+	if len(es) > 0 {
+		nl := map[string]string{}
+		for k, v := range labels {
+			nl[k] = v
+		}
+		nl["experiments"] = es
+		labels = nl
+	}
+	l.sdlogger.Log(logging.Entry{
+		Severity: stackdriverSeverity(s),
+		Labels:   labels,
+		Payload:  payload,
+		Trace:    traceID,
+	})
+}
+
+func (l *logger) Flush() {
+	l.sdlogger.Flush()
+}
+
+var (
+	mu            sync.Mutex
+	alreadyCalled bool
+)
+
+// New creates a new Logger that logs to Stackdriver.
+// It assumes config.Init has been called. New returns a
+// "parent" *logging.Logger that should be used to log the start and end of a
+// request. It also creates and remembers internally a "child" log.Logger that will
+// be used to log within a request. The child logger should be passed to log.Use to
+// forward the log package's logging calls to the child logger.
+// The two loggers are necessary to get request-scoped  logs in Stackdriver.
+// See https://cloud.google.com/appengine/docs/standard/go/writing-application-logs.
+//
+// New can only be called once. If it is called a second time, it returns an error.
+func New(ctx context.Context, logName, projectID string, opts []logging.LoggerOption) (_ log.Logger, _ *logging.Logger, err error) {
+	defer derrors.Wrap(&err, "New(ctx, %q)", logName)
+	client, err := logging.NewClient(ctx, projectID)
+	if err != nil {
+		return nil, nil, err
+	}
+	parent := client.Logger(logName, opts...)
+	child := client.Logger(logName+"-child", opts...)
+	mu.Lock()
+	defer mu.Unlock()
+	if alreadyCalled {
+		return nil, nil, errors.New("already called once")
+	}
+	alreadyCalled = true
+	return &logger{child}, parent, nil
+}
diff --git a/internal/middleware/requestlog.go b/internal/middleware/requestlog.go
index 2aa948b..93b4267 100644
--- a/internal/middleware/requestlog.go
+++ b/internal/middleware/requestlog.go
@@ -14,6 +14,7 @@
 
 	"cloud.google.com/go/logging"
 	"golang.org/x/pkgsite/internal/log"
+	"golang.org/x/pkgsite/internal/log/stackdriverlogger"
 )
 
 // Logger is the interface used to write request logs to GCP.
@@ -72,7 +73,7 @@
 		Trace:    traceID,
 	})
 	w2 := &responseWriter{ResponseWriter: w}
-	h.delegate.ServeHTTP(w2, r.WithContext(log.NewContextWithTraceID(r.Context(), traceID)))
+	h.delegate.ServeHTTP(w2, r.WithContext(stackdriverlogger.NewContextWithTraceID(r.Context(), traceID)))
 	s := severity
 	if w2.status == http.StatusServiceUnavailable {
 		// load shedding is a warning, not an error
diff --git a/internal/worker/fetch.go b/internal/worker/fetch.go
index 14d99ca..28bfcc8 100644
--- a/internal/worker/fetch.go
+++ b/internal/worker/fetch.go
@@ -29,6 +29,7 @@
 	"golang.org/x/pkgsite/internal/experiment"
 	"golang.org/x/pkgsite/internal/fetch"
 	"golang.org/x/pkgsite/internal/log"
+	"golang.org/x/pkgsite/internal/log/stackdriverlogger"
 	"golang.org/x/pkgsite/internal/postgres"
 	"golang.org/x/pkgsite/internal/proxy"
 	"golang.org/x/pkgsite/internal/source"
@@ -110,7 +111,7 @@
 	defer derrors.Wrap(&err, "FetchAndUpdateState(%q, %q, %q)", modulePath, requestedVersion, appVersionLabel)
 	tctx, span := trace.StartSpan(ctx, "FetchAndUpdateState")
 	ctx = experiment.NewContext(tctx, experiment.FromContext(ctx).Active()...)
-	ctx = log.NewContextWithLabel(ctx, "fetch", modulePath+"@"+requestedVersion)
+	ctx = stackdriverlogger.NewContextWithLabel(ctx, "fetch", modulePath+"@"+requestedVersion)
 
 	start := time.Now()
 	var nPackages int64