internal/workflow: add persistence and logging support

Add workflow.Resume, which takes saved state and resumes a workflow.
The host can pass a Listener to capture task states to pass to it. There
is no offical API for modifying the state of the workflow before
resuming, e.g. to implement retries.

Add per-task logging support. Task functions can accept a *TaskContext
instead of a context.Context, which adds a logging interface that's
implemented by the host to do logging however it wants.

Change-Id: Iae9411ca55ac55718025e1e62f3010cba9b8363a
Reviewed-on: https://go-review.googlesource.com/c/build/+/348131
Trust: Heschi Kreinick <heschi@google.com>
Trust: Alexander Rakoczy <alex@golang.org>
Run-TryBot: Heschi Kreinick <heschi@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Alexander Rakoczy <alex@golang.org>
diff --git a/internal/workflow/workflow.go b/internal/workflow/workflow.go
index 781437f..36dbf6a 100644
--- a/internal/workflow/workflow.go
+++ b/internal/workflow/workflow.go
@@ -18,8 +18,9 @@
 //
 // Each task has a set of input Values, and returns a single output Value.
 // Calling Task defines a task that will run a Go function when it runs. That
-// function must take a context.Context, followed by arguments corresponding to
-// the dynamic type of the Values passed to it.
+// function must take a *TaskContext or context.Context, followed by arguments
+// corresponding to the dynamic type of the Values passed to it. The TaskContext
+// can be used as a normal Context, and also supports unstructured logging.
 //
 // Once a Definition is complete, call Start to set its parameters and
 // instantiate it into a Workflow. Call Run to execute the workflow until
@@ -163,8 +164,8 @@
 	if ftyp.NumIn()-1 != len(args) {
 		panic(fmt.Errorf("%v takes %v non-Context arguments, but was passed %v", f, ftyp.NumIn()-1, len(args)))
 	}
-	if ftyp.In(0) != reflect.TypeOf((*context.Context)(nil)).Elem() {
-		panic(fmt.Errorf("the first argument of %v must be a context.Context, is %v", f, ftyp.In(0)))
+	if !reflect.TypeOf((*TaskContext)(nil)).AssignableTo(ftyp.In(0)) {
+		panic(fmt.Errorf("the first argument of %v must be a context.Context or *TaskContext, is %v", f, ftyp.In(0)))
 	}
 	for i, arg := range args {
 		if !arg.typ().AssignableTo(ftyp.In(i + 1)) {
@@ -182,6 +183,42 @@
 	return &taskResult{task: td}
 }
 
+// A TaskContext is a context.Context, plus workflow-related features.
+type TaskContext struct {
+	context.Context
+	Logger
+}
+
+// A Listener is used to notify the workflow host of state changes, for display
+// and persistence.
+type Listener interface {
+	// TaskStateChanged is called when the state of a task changes.
+	// state is safe to store or modify.
+	TaskStateChanged(workflowID, taskID string, state *TaskState) error
+	// Logger is called to obtain a Logger for a particular task.
+	Logger(workflowID, taskID string) Logger
+}
+
+// TaskState contains the state of a task in a running workflow. Once Finished
+// is true, either Result or Error will be populated.
+type TaskState struct {
+	Name     string
+	Finished bool
+	Result   interface{}
+	Error    error
+}
+
+// WorkflowState contains the shallow state of a running workflow.
+type WorkflowState struct {
+	ID     string
+	Params map[string]string
+}
+
+// A Logger is a debug logger passed to a task implementation.
+type Logger interface {
+	Printf(format string, v ...interface{})
+}
+
 type taskDefinition struct {
 	name string
 	args []Value
@@ -206,7 +243,7 @@
 
 // A Workflow is an instantiated workflow instance, ready to run.
 type Workflow struct {
-	id     string
+	ID     string
 	def    *Definition
 	params map[string]string
 
@@ -244,47 +281,83 @@
 	}
 }
 
-// TaskState contains the state of a task in a running workflow. Once finished
-// is true, either Result or Error will be populated.
-type TaskState struct {
-	Name     string
-	Finished bool
-	Result   interface{}
-	Error    error
-}
-
 // Start instantiates a workflow with the given parameters.
 func Start(def *Definition, params map[string]string) (*Workflow, error) {
 	w := &Workflow{
-		id:     uuid.New().String(),
+		ID:     uuid.New().String(),
 		def:    def,
 		params: params,
 		tasks:  map[*taskDefinition]*taskState{},
 	}
-	used := map[*taskDefinition]bool{}
+	if err := w.validate(); err != nil {
+		return nil, err
+	}
 	for _, taskDef := range def.tasks {
 		w.tasks[taskDef] = &taskState{def: taskDef, w: w}
+	}
+	return w, nil
+}
+
+func (w *Workflow) validate() error {
+	used := map[*taskDefinition]bool{}
+	for _, taskDef := range w.def.tasks {
 		for _, arg := range taskDef.args {
 			for _, argDep := range arg.deps() {
 				used[argDep] = true
 			}
 		}
 	}
-	for _, output := range def.outputs {
+	for _, output := range w.def.outputs {
 		used[output.task] = true
 	}
-	for _, task := range def.tasks {
+	for _, task := range w.def.tasks {
 		if !used[task] {
-			return nil, fmt.Errorf("task %v is not referenced and should be deleted", task.name)
+			return fmt.Errorf("task %v is not referenced and should be deleted", task.name)
 		}
 	}
+	return nil
+}
+
+// Resume restores a workflow from stored state. The WorkflowState can be
+// constructed by the host. TaskStates should be saved from Listener calls.
+// Tasks that had not finished will be restarted, but tasks that finished in
+// errors will not be retried.
+func Resume(def *Definition, state *WorkflowState, taskStates map[string]*TaskState) (*Workflow, error) {
+	w := &Workflow{
+		ID:     state.ID,
+		def:    def,
+		params: state.Params,
+		tasks:  map[*taskDefinition]*taskState{},
+	}
+	if err := w.validate(); err != nil {
+		return nil, err
+	}
+	for _, taskDef := range def.tasks {
+		tState, ok := taskStates[taskDef.name]
+		if !ok {
+			return nil, fmt.Errorf("task state for %q not found", taskDef.name)
+		}
+		w.tasks[taskDef] = &taskState{
+			def:      taskDef,
+			w:        w,
+			started:  tState.Finished, // Can't resume tasks, so either it's new or done.
+			finished: tState.Finished,
+			result:   tState.Result,
+			err:      tState.Error,
+		}
+	}
+
 	return w, nil
 }
 
 // Run runs a workflow to successful completion and returns its outputs.
-// statusFunc will be called when each task starts and finishes. It should be
-// used only for monitoring purposes - to read task results, register Outputs.
-func (w *Workflow) Run(ctx context.Context, stateFunc func(*TaskState)) (map[string]interface{}, error) {
+// listener.TaskStateChanged will be called when each task starts and
+// finishes. It should be used only for monitoring and persistence purposes -
+// to read task results, register Outputs.
+func (w *Workflow) Run(ctx context.Context, listener Listener) (map[string]interface{}, error) {
+	if listener == nil {
+		listener = &defaultListener{}
+	}
 	var running sync.WaitGroup
 	defer running.Wait()
 
@@ -311,10 +384,10 @@
 				continue
 			}
 			task.started = true
-			stateFunc(task.toExported())
+			listener.TaskStateChanged(w.ID, task.def.name, task.toExported())
 			running.Add(1)
 			go func(task taskState) {
-				stateChan <- w.runTask(ctx, task, in)
+				stateChan <- w.runTask(ctx, listener, task, in)
 				running.Done()
 			}(*task)
 		}
@@ -324,13 +397,14 @@
 			return nil, ctx.Err()
 		case state := <-stateChan:
 			w.tasks[state.def] = &state
-			stateFunc(state.toExported())
+			listener.TaskStateChanged(w.ID, state.def.name, state.toExported())
 		}
 	}
 }
 
-func (w *Workflow) runTask(ctx context.Context, state taskState, args []reflect.Value) taskState {
-	in := append([]reflect.Value{reflect.ValueOf(ctx)}, args...)
+func (w *Workflow) runTask(ctx context.Context, listener Listener, state taskState, args []reflect.Value) taskState {
+	tctx := &TaskContext{Context: ctx, Logger: listener.Logger(w.ID, state.def.name)}
+	in := append([]reflect.Value{reflect.ValueOf(tctx)}, args...)
 	out := reflect.ValueOf(state.def.f).Call(in)
 	var err error
 	if !out[1].IsNil() {
@@ -340,3 +414,17 @@
 	state.result, state.err = out[0].Interface(), err
 	return state
 }
+
+type defaultListener struct{}
+
+func (s *defaultListener) TaskStateChanged(_, _ string, _ *TaskState) error {
+	return nil
+}
+
+func (s *defaultListener) Logger(_, task string) Logger {
+	return &defaultLogger{}
+}
+
+type defaultLogger struct{}
+
+func (l *defaultLogger) Printf(format string, v ...interface{}) {}
diff --git a/internal/workflow/workflow_test.go b/internal/workflow/workflow_test.go
index 4cff893..a4379bc 100644
--- a/internal/workflow/workflow_test.go
+++ b/internal/workflow/workflow_test.go
@@ -6,34 +6,45 @@
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"reflect"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
+	"github.com/google/go-cmp/cmp"
 	"golang.org/x/build/internal/workflow"
 )
 
 func TestTrivial(t *testing.T) {
+	echo := func(ctx context.Context, arg string) (string, error) {
+		return arg, nil
+	}
+
 	wd := workflow.New()
 	greeting := wd.Task("echo", echo, wd.Constant("hello world"))
 	wd.Output("greeting", greeting)
 
-	w, err := workflow.Start(wd, map[string]string{})
-	if err != nil {
-		t.Fatal(err)
-	}
-	outputs, err := w.Run(context.Background(), loggingListener(t))
-	if err != nil {
-		t.Fatal(err)
-	}
+	w := startWorkflow(t, wd, nil)
+	outputs := runWorkflow(t, w, nil)
 	if got, want := outputs["greeting"], "hello world"; got != want {
 		t.Errorf("greeting = %q, want %q", got, want)
 	}
 }
 
 func TestSplitJoin(t *testing.T) {
+	echo := func(ctx context.Context, arg string) (string, error) {
+		return arg, nil
+	}
+	appendInt := func(ctx context.Context, s string, i int) (string, error) {
+		return fmt.Sprintf("%v%v", s, i), nil
+	}
+	join := func(ctx context.Context, s []string) (string, error) {
+		return strings.Join(s, ","), nil
+	}
+
 	wd := workflow.New()
 	in := wd.Task("echo", echo, wd.Constant("string #"))
 	add1 := wd.Task("add 1", appendInt, in, wd.Constant(1))
@@ -42,14 +53,8 @@
 	out := wd.Task("join", join, both)
 	wd.Output("strings", out)
 
-	w, err := workflow.Start(wd, map[string]string{})
-	if err != nil {
-		t.Fatal(err)
-	}
-	outputs, err := w.Run(context.Background(), loggingListener(t))
-	if err != nil {
-		t.Fatal(err)
-	}
+	w := startWorkflow(t, wd, nil)
+	outputs := runWorkflow(t, w, nil)
 	if got, want := outputs["strings"], "string #1,string #2"; got != want {
 		t.Errorf("joined output = %q, want %q", got, want)
 	}
@@ -80,19 +85,19 @@
 	wd.Output("out1", out1)
 	wd.Output("out2", out2)
 
-	w, err := workflow.Start(wd, map[string]string{})
-	if err != nil {
-		t.Fatal(err)
-	}
+	w := startWorkflow(t, wd, nil)
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
-	_, err = w.Run(ctx, loggingListener(t))
-	if err != nil {
+	if _, err := w.Run(ctx, &verboseListener{t}); err != nil {
 		t.Fatal(err)
 	}
 }
 
 func TestParameters(t *testing.T) {
+	echo := func(ctx context.Context, arg string) (string, error) {
+		return arg, nil
+	}
+
 	wd := workflow.New()
 	param1 := wd.Parameter("param1")
 	param2 := wd.Parameter("param2")
@@ -101,40 +106,177 @@
 	wd.Output("out1", out1)
 	wd.Output("out2", out2)
 
-	w, err := workflow.Start(wd, map[string]string{"param1": "#1", "param2": "#2"})
-	if err != nil {
-		t.Fatal(err)
-	}
-	outputs, err := w.Run(context.Background(), loggingListener(t))
-	if err != nil {
-		t.Fatal(err)
-	}
+	w := startWorkflow(t, wd, map[string]string{"param1": "#1", "param2": "#2"})
+	outputs := runWorkflow(t, w, nil)
 	if want := map[string]interface{}{"out1": "#1", "out2": "#2"}; !reflect.DeepEqual(outputs, want) {
 		t.Errorf("outputs = %#v, want %#v", outputs, want)
 	}
 }
 
-func appendInt(ctx context.Context, s string, i int) (string, error) {
-	return fmt.Sprintf("%v%v", s, i), nil
-}
-
-func join(ctx context.Context, s []string) (string, error) {
-	return strings.Join(s, ","), nil
-}
-
-func echo(ctx context.Context, arg string) (string, error) {
-	return arg, nil
-}
-
-func loggingListener(t *testing.T) func(*workflow.TaskState) {
-	return func(st *workflow.TaskState) {
-		switch {
-		case !st.Finished:
-			t.Logf("task %-10v: started", st.Name)
-		case st.Error != nil:
-			t.Logf("task %-10v: error: %v", st.Name, st.Error)
-		default:
-			t.Logf("task %-10v: done: %v", st.Name, st.Result)
-		}
+func TestLogging(t *testing.T) {
+	log := func(ctx *workflow.TaskContext, arg string) (string, error) {
+		ctx.Printf("logging argument: %v", arg)
+		return arg, nil
 	}
+
+	wd := workflow.New()
+	out := wd.Task("log", log, wd.Constant("hey there"))
+	wd.Output("out", out)
+
+	logger := &capturingLogger{}
+	listener := &logTestListener{
+		Listener: &verboseListener{t},
+		logger:   logger,
+	}
+	w := startWorkflow(t, wd, nil)
+	runWorkflow(t, w, listener)
+	if want := []string{"logging argument: hey there"}; !reflect.DeepEqual(logger.lines, want) {
+		t.Errorf("unexpected logging result: got %v, want %v", logger.lines, want)
+	}
+}
+
+type logTestListener struct {
+	workflow.Listener
+	logger workflow.Logger
+}
+
+func (l *logTestListener) Logger(_, _ string) workflow.Logger {
+	return l.logger
+}
+
+type capturingLogger struct {
+	lines []string
+}
+
+func (l *capturingLogger) Printf(format string, v ...interface{}) {
+	l.lines = append(l.lines, fmt.Sprintf(format, v...))
+}
+
+func TestResume(t *testing.T) {
+	// We expect runOnlyOnce to only run once.
+	var runs int64
+	runOnlyOnce := func(ctx context.Context) (string, error) {
+		atomic.AddInt64(&runs, 1)
+		return "ran", nil
+	}
+	// blockOnce blocks the first time it's called, so that the workflow can be
+	// canceled at its step.
+	block := true
+	blocked := make(chan bool, 1)
+	maybeBlock := func(ctx context.Context, _ string) (string, error) {
+		if block {
+			blocked <- true
+			<-ctx.Done()
+			return "blocked", ctx.Err()
+		}
+		return "not blocked", nil
+	}
+	wd := workflow.New()
+	v1 := wd.Task("run once", runOnlyOnce)
+	v2 := wd.Task("block", maybeBlock, v1)
+	wd.Output("output", v2)
+
+	// Cancel the workflow once we've entered maybeBlock.
+	ctx, cancel := context.WithCancel(context.Background())
+	go func() {
+		<-blocked
+		cancel()
+	}()
+	w, err := workflow.Start(wd, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	storage := &mapListener{Listener: &verboseListener{t}}
+	_, err = w.Run(ctx, storage)
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("canceled workflow returned error %v, wanted Canceled", err)
+	}
+	storage.assertState(t, w, map[string]*workflow.TaskState{
+		"run once": {Name: "run once", Finished: true, Result: "ran"},
+		"block":    {Name: "block"}, // We cancelled the workflow before it could save its state.
+	})
+
+	block = false
+	wfState := &workflow.WorkflowState{ID: w.ID, Params: nil}
+	taskStates := storage.states[w.ID]
+	w2, err := workflow.Resume(wd, wfState, taskStates)
+	out := runWorkflow(t, w2, storage)
+	if got, want := out["output"], "not blocked"; got != want {
+		t.Errorf("output from maybeBlock was %q, wanted %q", got, want)
+	}
+	if runs != 1 {
+		t.Errorf("runOnlyOnce ran %v times, wanted 1", runs)
+	}
+	storage.assertState(t, w, map[string]*workflow.TaskState{
+		"run once": {Name: "run once", Finished: true, Result: "ran"},
+		"block":    {Name: "block", Finished: true, Result: "not blocked"},
+	})
+}
+
+type mapListener struct {
+	workflow.Listener
+	states map[string]map[string]*workflow.TaskState
+}
+
+func (l *mapListener) TaskStateChanged(workflowID, taskID string, state *workflow.TaskState) error {
+	if l.states == nil {
+		l.states = map[string]map[string]*workflow.TaskState{}
+	}
+	if l.states[workflowID] == nil {
+		l.states[workflowID] = map[string]*workflow.TaskState{}
+	}
+	l.states[workflowID][taskID] = state
+	return l.Listener.TaskStateChanged(workflowID, taskID, state)
+}
+
+func (l *mapListener) assertState(t *testing.T, w *workflow.Workflow, want map[string]*workflow.TaskState) {
+	if diff := cmp.Diff(l.states[w.ID], want); diff != "" {
+		t.Errorf("task state didn't match expections: %v", diff)
+	}
+}
+
+func startWorkflow(t *testing.T, wd *workflow.Definition, params map[string]string) *workflow.Workflow {
+	w, err := workflow.Start(wd, params)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return w
+}
+
+func runWorkflow(t *testing.T, w *workflow.Workflow, listener workflow.Listener) map[string]interface{} {
+	if listener == nil {
+		listener = &verboseListener{t}
+	}
+	outputs, err := w.Run(context.Background(), listener)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return outputs
+}
+
+type verboseListener struct{ t *testing.T }
+
+func (l *verboseListener) TaskStateChanged(_, _ string, st *workflow.TaskState) error {
+	switch {
+	case !st.Finished:
+		l.t.Logf("task %-10v: started", st.Name)
+	case st.Error != nil:
+		l.t.Logf("task %-10v: error: %v", st.Name, st.Error)
+	default:
+		l.t.Logf("task %-10v: done: %v", st.Name, st.Result)
+	}
+	return nil
+}
+
+func (l *verboseListener) Logger(_, task string) workflow.Logger {
+	return &testLogger{t: l.t, task: task}
+}
+
+type testLogger struct {
+	t    *testing.T
+	task string
+}
+
+func (l *testLogger) Printf(format string, v ...interface{}) {
+	l.t.Logf("task %-10v: LOG: %s", l.task, fmt.Sprintf(format, v...))
 }