errgroup: add package

Package errgroup provides synchronization, error propagation, and
Context cancellation for groups of goroutines working on subtasks of a
common task.

Change-Id: Ic9e51f6f846124076bbff9d53b0f09dc7fc5f2f0
Reviewed-on: https://go-review.googlesource.com/24894
Reviewed-by: Sameer Ajmani <sameer@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go
new file mode 100644
index 0000000..533438d
--- /dev/null
+++ b/errgroup/errgroup.go
@@ -0,0 +1,67 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package errgroup provides synchronization, error propagation, and Context
+// cancelation for groups of goroutines working on subtasks of a common task.
+package errgroup
+
+import (
+	"sync"
+
+	"golang.org/x/net/context"
+)
+
+// A Group is a collection of goroutines working on subtasks that are part of
+// the same overall task.
+//
+// A zero Group is valid and does not cancel on error.
+type Group struct {
+	cancel func()
+
+	wg sync.WaitGroup
+
+	errOnce sync.Once
+	err     error
+}
+
+// WithContext returns a new Group and an associated Context derived from ctx.
+//
+// The derived Context is canceled the first time a function passed to Go
+// returns a non-nil error or the first time Wait returns, whichever occurs
+// first.
+func WithContext(ctx context.Context) (*Group, context.Context) {
+	ctx, cancel := context.WithCancel(ctx)
+	return &Group{cancel: cancel}, ctx
+}
+
+// Wait blocks until all function calls from the Go method have returned, then
+// returns the first non-nil error (if any) from them.
+func (g *Group) Wait() error {
+	g.wg.Wait()
+	if g.cancel != nil {
+		g.cancel()
+	}
+	return g.err
+}
+
+// Go calls the given function in a new goroutine.
+//
+// The first call to return a non-nil error cancels the group; its error will be
+// returned by Wait.
+func (g *Group) Go(f func() error) {
+	g.wg.Add(1)
+
+	go func() {
+		defer g.wg.Done()
+
+		if err := f(); err != nil {
+			g.errOnce.Do(func() {
+				g.err = err
+				if g.cancel != nil {
+					g.cancel()
+				}
+			})
+		}
+	}()
+}
diff --git a/errgroup/errgroup_example_md5all_test.go b/errgroup/errgroup_example_md5all_test.go
new file mode 100644
index 0000000..a6cfc8e
--- /dev/null
+++ b/errgroup/errgroup_example_md5all_test.go
@@ -0,0 +1,101 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package errgroup_test
+
+import (
+	"crypto/md5"
+	"fmt"
+	"io/ioutil"
+	"log"
+	"os"
+	"path/filepath"
+
+	"golang.org/x/net/context"
+	"golang.org/x/sync/errgroup"
+)
+
+// Pipeline demonstrates the use of a Group to implement a multi-stage
+// pipeline: a version of the MD5All function with bounded parallelism from
+// https://blog.golang.org/pipelines.
+func ExampleGroup_pipeline() {
+	m, err := MD5All(context.Background(), ".")
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	for k, sum := range m {
+		fmt.Printf("%s:\t%x\n", k, sum)
+	}
+}
+
+type result struct {
+	path string
+	sum  [md5.Size]byte
+}
+
+// MD5All reads all the files in the file tree rooted at root and returns a map
+// from file path to the MD5 sum of the file's contents. If the directory walk
+// fails or any read operation fails, MD5All returns an error.
+func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) {
+	// ctx is canceled when MD5All calls g.Wait(). When this version of MD5All
+	// returns - even in case of error! - we know that all of the goroutines have
+	// finished and the memory they were using can be garbage-collected.
+	g, ctx := errgroup.WithContext(ctx)
+	paths := make(chan string)
+
+	g.Go(func() error {
+		defer close(paths)
+		return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+			if err != nil {
+				return err
+			}
+			if !info.Mode().IsRegular() {
+				return nil
+			}
+			select {
+			case paths <- path:
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+			return nil
+		})
+	})
+
+	// Start a fixed number of goroutines to read and digest files.
+	c := make(chan result)
+	const numDigesters = 20
+	for i := 0; i < numDigesters; i++ {
+		g.Go(func() error {
+			for path := range paths {
+				data, err := ioutil.ReadFile(path)
+				if err != nil {
+					return err
+				}
+				select {
+				case c <- result{path, md5.Sum(data)}:
+				case <-ctx.Done():
+					return ctx.Err()
+				}
+			}
+			return nil
+		})
+	}
+	go func() {
+		g.Wait()
+		close(c)
+	}()
+
+	m := make(map[string][md5.Size]byte)
+	for r := range c {
+		m[r.path] = r.sum
+	}
+	// Check whether any of the goroutines failed. Since g is accumulating the
+	// errors, we don't need to send them (or check for them) in the individual
+	// results sent on the channel.
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+	return m, nil
+}
diff --git a/errgroup/errgroup_test.go b/errgroup/errgroup_test.go
new file mode 100644
index 0000000..661a070
--- /dev/null
+++ b/errgroup/errgroup_test.go
@@ -0,0 +1,176 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package errgroup_test
+
+import (
+	"errors"
+	"fmt"
+	"net/http"
+	"os"
+	"testing"
+
+	"golang.org/x/net/context"
+	"golang.org/x/sync/errgroup"
+)
+
+var (
+	Web   = fakeSearch("web")
+	Image = fakeSearch("image")
+	Video = fakeSearch("video")
+)
+
+type Result string
+type Search func(ctx context.Context, query string) (Result, error)
+
+func fakeSearch(kind string) Search {
+	return func(_ context.Context, query string) (Result, error) {
+		return Result(fmt.Sprintf("%s result for %q", kind, query)), nil
+	}
+}
+
+// JustErrors illustrates the use of a Group in place of a sync.WaitGroup to
+// simplify goroutine counting and error handling. This example is derived from
+// the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup.
+func ExampleGroup_justErrors() {
+	var g errgroup.Group
+	var urls = []string{
+		"http://www.golang.org/",
+		"http://www.google.com/",
+		"http://www.somestupidname.com/",
+	}
+	for _, url := range urls {
+		// Launch a goroutine to fetch the URL.
+		url := url // https://golang.org/doc/faq#closures_and_goroutines
+		g.Go(func(url string) error {
+			// Fetch the URL.
+			resp, err := http.Get(url)
+			if err == nil {
+				resp.Body.Close()
+			}
+			return err
+		})
+	}
+	// Wait for all HTTP fetches to complete.
+	if err := wg.Wait(); err == nil {
+		fmt.Println("Successfully fetched all URLs.")
+	}
+}
+
+// Parallel illustrates the use of a Group for synchronizing a simple parallel
+// task: the "Google Search 2.0" function from
+// https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context
+// and error-handling.
+func ExampleGroup_parallel() {
+	Google := func(ctx context.Context, query string) ([]Result, error) {
+		g, ctx := errgroup.WithContext(ctx)
+
+		searches := []Search{Web, Image, Video}
+		results := make([]Result, len(searches))
+		for i, search := range searches {
+			i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines
+			g.Go(func() error {
+				result, err := search(ctx, query)
+				if err == nil {
+					results[i] = result
+				}
+				return err
+			})
+		}
+		if err := g.Wait(); err != nil {
+			return nil, err
+		}
+		return results, nil
+	}
+
+	results, err := Google(context.Background(), "golang")
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err)
+		return
+	}
+	for _, result := range results {
+		fmt.Println(result)
+	}
+
+	// Output:
+	// web result for "golang"
+	// image result for "golang"
+	// video result for "golang"
+}
+
+func TestZeroGroup(t *testing.T) {
+	err1 := errors.New("errgroup_test: 1")
+	err2 := errors.New("errgroup_test: 2")
+
+	cases := []struct {
+		errs []error
+	}{
+		{errs: []error{}},
+		{errs: []error{nil}},
+		{errs: []error{err1}},
+		{errs: []error{err1, nil}},
+		{errs: []error{err1, nil, err2}},
+	}
+
+	for _, tc := range cases {
+		var g errgroup.Group
+
+		var firstErr error
+		for i, err := range tc.errs {
+			err := err
+			g.Go(func() error { return err })
+
+			if firstErr == nil && err != nil {
+				firstErr = err
+			}
+
+			if gErr := g.Wait(); gErr != firstErr {
+				t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
+					"g.Wait() = %v; want %v",
+					g, tc.errs[:i+1], err, firstErr)
+			}
+		}
+	}
+}
+
+func TestWithContext(t *testing.T) {
+	errDoom := errors.New("group_test: doomed")
+
+	cases := []struct {
+		errs []error
+		want error
+	}{
+		{want: nil},
+		{errs: []error{nil}, want: nil},
+		{errs: []error{errDoom}, want: errDoom},
+		{errs: []error{errDoom, nil}, want: errDoom},
+	}
+
+	for _, tc := range cases {
+		g, ctx := errgroup.WithContext(context.Background())
+
+		for _, err := range tc.errs {
+			err := err
+			g.Go(func() error { return err })
+		}
+
+		if err := g.Wait(); err != tc.want {
+			t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
+				"g.Wait() = %v; want %v",
+				g, tc.errs, err, tc.want)
+		}
+
+		canceled := false
+		select {
+		case <-ctx.Done():
+			canceled = true
+		default:
+		}
+		if !canceled {
+			t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
+				"ctx.Done() was not closed",
+				g, tc.errs)
+		}
+	}
+}