os/signal: add NotifyContext to cancel context using system signals

Fixes #37255

Change-Id: Ic0fde3498afefed6e4447f8476e4da7c1faa7145
Reviewed-on: https://go-review.googlesource.com/c/go/+/219640
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Giovanni Bajo <rasky@develer.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/src/os/signal/example_unix_test.go b/src/os/signal/example_unix_test.go
new file mode 100644
index 0000000..a0af37a
--- /dev/null
+++ b/src/os/signal/example_unix_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package signal_test
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+	"os/signal"
+	"time"
+)
+
+// This example passes a context with a signal to tell a blocking function that
+// it should abandon its work after a signal is received.
+func ExampleNotifyContext() {
+	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
+	defer stop()
+
+	p, err := os.FindProcess(os.Getpid())
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// On a Unix-like system, pressing Ctrl+C on a keyboard sends a
+	// SIGINT signal to the process of the program in execution.
+	//
+	// This example simulates that by sending a SIGINT signal to itself.
+	if err := p.Signal(os.Interrupt); err != nil {
+		log.Fatal(err)
+	}
+
+	select {
+	case <-time.After(time.Second):
+		fmt.Println("missed signal")
+	case <-ctx.Done():
+		fmt.Println(ctx.Err()) // prints "context canceled"
+		stop()                 // stop receiving signal notifications as soon as possible.
+	}
+
+	// Output:
+	// context canceled
+}
diff --git a/src/os/signal/signal.go b/src/os/signal/signal.go
index 8e31aa2..4250a7e 100644
--- a/src/os/signal/signal.go
+++ b/src/os/signal/signal.go
@@ -5,6 +5,7 @@
 package signal
 
 import (
+	"context"
 	"os"
 	"sync"
 )
@@ -257,3 +258,77 @@
 		}
 	}
 }
+
+// NotifyContext returns a copy of the parent context that is marked done
+// (its Done channel is closed) when one of the listed signals arrives,
+// when the returned stop function is called, or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// The stop function unregisters the signal behavior, which, like signal.Reset,
+// may restore the default behavior for a given signal. For example, the default
+// behavior of a Go program receiving os.Interrupt is to exit. Calling
+// NotifyContext(parent, os.Interrupt) will change the behavior to cancel
+// the returned context. Future interrupts received will not trigger the default
+// (exit) behavior until the returned stop function is called.
+//
+// The stop function releases resources associated with it, so code should
+// call stop as soon as the operations running in this Context complete and
+// signals no longer need to be diverted to the context.
+func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
+	ctx, cancel := context.WithCancel(parent)
+	c := &signalCtx{
+		Context: ctx,
+		cancel:  cancel,
+		signals: signals,
+	}
+	c.ch = make(chan os.Signal, 1)
+	Notify(c.ch, c.signals...)
+	if ctx.Err() == nil {
+		go func() {
+			select {
+			case <-c.ch:
+				c.cancel()
+			case <-c.Done():
+			}
+		}()
+	}
+	return c, c.stop
+}
+
+type signalCtx struct {
+	context.Context
+
+	cancel  context.CancelFunc
+	signals []os.Signal
+	ch      chan os.Signal
+}
+
+func (c *signalCtx) stop() {
+	c.cancel()
+	Stop(c.ch)
+}
+
+type stringer interface {
+	String() string
+}
+
+func (c *signalCtx) String() string {
+	var buf []byte
+	// We know that the type of c.Context is context.cancelCtx, and we know that the
+	// String method of cancelCtx returns a string that ends with ".WithCancel".
+	name := c.Context.(stringer).String()
+	name = name[:len(name)-len(".WithCancel")]
+	buf = append(buf, "signal.NotifyContext("+name...)
+	if len(c.signals) != 0 {
+		buf = append(buf, ", ["...)
+		for i, s := range c.signals {
+			buf = append(buf, s.String()...)
+			if i != len(c.signals)-1 {
+				buf = append(buf, ' ')
+			}
+		}
+		buf = append(buf, ']')
+	}
+	buf = append(buf, ')')
+	return string(buf)
+}
diff --git a/src/os/signal/signal_test.go b/src/os/signal/signal_test.go
index f0e06b8..23e33fe 100644
--- a/src/os/signal/signal_test.go
+++ b/src/os/signal/signal_test.go
@@ -8,6 +8,7 @@
 
 import (
 	"bytes"
+	"context"
 	"flag"
 	"fmt"
 	"internal/testenv"
@@ -674,3 +675,164 @@
 	close(stop)
 	<-done
 }
+
+func TestNotifyContext(t *testing.T) {
+	c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+	defer stop()
+
+	if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+		t.Errorf("c.String() = %q, want %q", got, want)
+	}
+
+	syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+	select {
+	case <-c.Done():
+		if got := c.Err(); got != context.Canceled {
+			t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+		}
+	case <-time.After(time.Second):
+		t.Errorf("timed out waiting for context to be done after SIGINT")
+	}
+}
+
+func TestNotifyContextStop(t *testing.T) {
+	Ignore(syscall.SIGHUP)
+	if !Ignored(syscall.SIGHUP) {
+		t.Errorf("expected SIGHUP to be ignored when explicitly ignoring it.")
+	}
+
+	parent, cancelParent := context.WithCancel(context.Background())
+	defer cancelParent()
+	c, stop := NotifyContext(parent, syscall.SIGHUP)
+	defer stop()
+
+	// If we're being notified, then the signal should not be ignored.
+	if Ignored(syscall.SIGHUP) {
+		t.Errorf("expected SIGHUP to not be ignored.")
+	}
+
+	if want, got := "signal.NotifyContext(context.Background.WithCancel, [hangup])", fmt.Sprint(c); want != got {
+		t.Errorf("c.String() = %q, wanted %q", got, want)
+	}
+
+	stop()
+	select {
+	case <-c.Done():
+		if got := c.Err(); got != context.Canceled {
+			t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+		}
+	case <-time.After(time.Second):
+		t.Errorf("timed out waiting for context to be done after calling stop")
+	}
+}
+
+func TestNotifyContextCancelParent(t *testing.T) {
+	parent, cancelParent := context.WithCancel(context.Background())
+	defer cancelParent()
+	c, stop := NotifyContext(parent, syscall.SIGINT)
+	defer stop()
+
+	if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
+		t.Errorf("c.String() = %q, want %q", got, want)
+	}
+
+	cancelParent()
+	select {
+	case <-c.Done():
+		if got := c.Err(); got != context.Canceled {
+			t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+		}
+	case <-time.After(time.Second):
+		t.Errorf("timed out waiting for parent context to be canceled")
+	}
+}
+
+func TestNotifyContextPrematureCancelParent(t *testing.T) {
+	parent, cancelParent := context.WithCancel(context.Background())
+	defer cancelParent()
+
+	cancelParent() // Prematurely cancel context before calling NotifyContext.
+	c, stop := NotifyContext(parent, syscall.SIGINT)
+	defer stop()
+
+	if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
+		t.Errorf("c.String() = %q, want %q", got, want)
+	}
+
+	select {
+	case <-c.Done():
+		if got := c.Err(); got != context.Canceled {
+			t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+		}
+	case <-time.After(time.Second):
+		t.Errorf("timed out waiting for parent context to be canceled")
+	}
+}
+
+func TestNotifyContextSimultaneousNotifications(t *testing.T) {
+	c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+	defer stop()
+
+	if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+		t.Errorf("c.String() = %q, want %q", got, want)
+	}
+
+	var wg sync.WaitGroup
+	n := 10
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go func() {
+			syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	select {
+	case <-c.Done():
+		if got := c.Err(); got != context.Canceled {
+			t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+		}
+	case <-time.After(time.Second):
+		t.Errorf("expected context to be canceled")
+	}
+}
+
+func TestNotifyContextSimultaneousStop(t *testing.T) {
+	c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+	defer stop()
+
+	if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+		t.Errorf("c.String() = %q, want %q", got, want)
+	}
+
+	var wg sync.WaitGroup
+	n := 10
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go func() {
+			stop()
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	select {
+	case <-c.Done():
+		if got := c.Err(); got != context.Canceled {
+			t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+		}
+	case <-time.After(time.Second):
+		t.Errorf("expected context to be canceled")
+	}
+}
+
+func TestNotifyContextStringer(t *testing.T) {
+	parent, cancelParent := context.WithCancel(context.Background())
+	defer cancelParent()
+	c, stop := NotifyContext(parent, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
+	defer stop()
+
+	want := `signal.NotifyContext(context.Background.WithCancel, [hangup interrupt terminated])`
+	if got := fmt.Sprint(c); got != want {
+		t.Errorf("c.String() = %q, want %q", got, want)
+	}
+}