internal/gocommand: use semaphores to ensure correct serialization

Piped go commands cannot run concurrently, as a user may be looking at
the stdout/stderr, and we don't want to show any load concurrency
errors. We use a semaphore with a maximum number of in-flight go
commands to make sure that all running go commands have completed before
starting a piped command.

Change-Id: Ie027d9ff704fb7dd9640da06569345e89ed7a012
Reviewed-on: https://go-review.googlesource.com/c/tools/+/238059
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/gocommand/invoke.go b/internal/gocommand/invoke.go
index ff4f9d5..83b654b 100644
--- a/internal/gocommand/invoke.go
+++ b/internal/gocommand/invoke.go
@@ -23,9 +23,24 @@
 // An Runner will run go command invocations and serialize
 // them if it sees a concurrency error.
 type Runner struct {
-	// LoadMu guards packages.Load calls and associated state.
-	loadMu         sync.Mutex
-	serializeLoads int
+	// once guards the runner initialization.
+	once sync.Once
+
+	// inFlight tracks available workers.
+	inFlight chan struct{}
+
+	// serialized guards the ability to run a go command serially,
+	// to avoid deadlocks when claiming workers.
+	serialized chan struct{}
+}
+
+const maxInFlight = 10
+
+func (runner *Runner) initialize() {
+	runner.once.Do(func() {
+		runner.inFlight = make(chan struct{}, maxInFlight)
+		runner.serialized = make(chan struct{}, 1)
+	})
 }
 
 // 1.13: go: updates to go.mod needed, but contents have changed
@@ -35,7 +50,7 @@
 // Run is a convenience wrapper around RunRaw.
 // It returns only stdout and a "friendly" error.
 func (runner *Runner) Run(ctx context.Context, inv Invocation) (*bytes.Buffer, error) {
-	stdout, _, friendly, _ := runner.runRaw(ctx, inv)
+	stdout, _, friendly, _ := runner.RunRaw(ctx, inv)
 	return stdout, friendly
 }
 
@@ -49,57 +64,67 @@
 // RunRaw runs the invocation, serializing requests only if they fight over
 // go.mod changes.
 func (runner *Runner) RunRaw(ctx context.Context, inv Invocation) (*bytes.Buffer, *bytes.Buffer, error, error) {
-	return runner.runRaw(ctx, inv)
+	// Make sure the runner is always initialized.
+	runner.initialize()
+
+	// First, try to run the go command concurrently.
+	stdout, stderr, friendlyErr, err := runner.runConcurrent(ctx, inv)
+
+	// If we encounter a load concurrency error, we need to retry serially.
+	if friendlyErr == nil || !modConcurrencyError.MatchString(friendlyErr.Error()) {
+		return stdout, stderr, friendlyErr, err
+	}
+	event.Error(ctx, "Load concurrency error, will retry serially", err)
+
+	// Run serially by calling runPiped.
+	stdout.Reset()
+	stderr.Reset()
+	friendlyErr, err = runner.runPiped(ctx, inv, stdout, stderr)
+	return stdout, stderr, friendlyErr, err
+}
+
+func (runner *Runner) runConcurrent(ctx context.Context, inv Invocation) (*bytes.Buffer, *bytes.Buffer, error, error) {
+	// Wait for 1 worker to become available.
+	select {
+	case <-ctx.Done():
+		return nil, nil, nil, ctx.Err()
+	case runner.inFlight <- struct{}{}:
+		defer func() { <-runner.inFlight }()
+	}
+
+	stdout, stderr := &bytes.Buffer{}, &bytes.Buffer{}
+	friendlyErr, err := inv.runWithFriendlyError(ctx, stdout, stderr)
+	return stdout, stderr, friendlyErr, err
 }
 
 func (runner *Runner) runPiped(ctx context.Context, inv Invocation, stdout, stderr io.Writer) (error, error) {
-	runner.loadMu.Lock()
-	runner.serializeLoads++
+	// Make sure the runner is always initialized.
+	runner.initialize()
 
-	defer func() {
-		runner.serializeLoads--
-		runner.loadMu.Unlock()
-	}()
+	// Acquire the serialization lock. This avoids deadlocks between two
+	// runPiped commands.
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case runner.serialized <- struct{}{}:
+		defer func() { <-runner.serialized }()
+	}
+
+	// Wait for all in-progress go commands to return before proceeding,
+	// to avoid load concurrency errors.
+	for i := 0; i < maxInFlight; i++ {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case runner.inFlight <- struct{}{}:
+			// Make sure we always "return" any workers we took.
+			defer func() { <-runner.inFlight }()
+		}
+	}
 
 	return inv.runWithFriendlyError(ctx, stdout, stderr)
 }
 
-func (runner *Runner) runRaw(ctx context.Context, inv Invocation) (*bytes.Buffer, *bytes.Buffer, error, error) {
-	// We want to run invocations concurrently as much as possible. However,
-	// if go.mod updates are needed, only one can make them and the others will
-	// fail. We need to retry in those cases, but we don't want to thrash so
-	// badly we never recover. To avoid that, once we've seen one concurrency
-	// error, start serializing everything until the backlog has cleared out.
-	runner.loadMu.Lock()
-	var locked bool // If true, we hold the mutex and have incremented.
-	if runner.serializeLoads == 0 {
-		runner.loadMu.Unlock()
-	} else {
-		locked = true
-		runner.serializeLoads++
-	}
-	defer func() {
-		if locked {
-			locked = false
-			runner.serializeLoads--
-			runner.loadMu.Unlock()
-		}
-	}()
-
-	for {
-		stdout, stderr := &bytes.Buffer{}, &bytes.Buffer{}
-		friendlyErr, err := inv.runWithFriendlyError(ctx, stdout, stderr)
-		if friendlyErr == nil || !modConcurrencyError.MatchString(friendlyErr.Error()) {
-			return stdout, stderr, friendlyErr, err
-		}
-		event.Error(ctx, "Load concurrency error, will retry serially", err)
-		if !locked {
-			runner.loadMu.Lock()
-			runner.serializeLoads++
-		}
-	}
-}
-
 // An Invocation represents a call to the go command.
 type Invocation struct {
 	Verb       string