{cmd,internal}/screentest: add support for request headers

Change-Id: I79ecc613ac12cc0f42edcaee258f3addb5857d69
Reviewed-on: https://go-review.googlesource.com/c/website/+/373654
Trust: Jamal Carvalho <jamalcarvalho@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/cmd/screentest/main.go b/cmd/screentest/main.go
index cd55ca0..0174aa4 100644
--- a/cmd/screentest/main.go
+++ b/cmd/screentest/main.go
@@ -8,19 +8,41 @@
 
 import (
 	"flag"
+	"fmt"
 	"log"
+	"os"
+	"strings"
 
 	"golang.org/x/website/internal/screentest"
 )
 
 var (
-	testdata = flag.String("testdata", "cmd/screentest/testdata/*.txt", "directory to look for testdata")
-	update   = flag.Bool("update", false, "use this flag to update cached screenshots")
+	update  = flag.Bool("update", false, "update cached screenshots")
+	headers = flag.String("H", "", "set request headers")
 )
 
 func main() {
+	flag.Usage = func() {
+		fmt.Printf("Usage: screentest [OPTIONS] glob\n")
+		flag.PrintDefaults()
+	}
 	flag.Parse()
-	if err := screentest.CheckHandler(*testdata, *update); err != nil {
+	args := flag.Args()
+	if len(args) != 1 {
+		flag.Usage()
+		os.Exit(1)
+	}
+	hdr := make(map[string]interface{})
+	if *headers != "" {
+		for _, h := range strings.Split(*headers, ",") {
+			parts := strings.Split(h, ":")
+			if len(parts) != 2 {
+				log.Fatalf("invalid header %s", h)
+			}
+			hdr[parts[0]] = parts[1]
+		}
+	}
+	if err := screentest.CheckHandler(args[0], *update, hdr); err != nil {
 		log.Fatal(err)
 	}
 }
diff --git a/internal/screentest/screentest.go b/internal/screentest/screentest.go
index 56484db..5b7c387 100644
--- a/internal/screentest/screentest.go
+++ b/internal/screentest/screentest.go
@@ -107,6 +107,7 @@
 	"testing"
 	"time"
 
+	"github.com/chromedp/cdproto/network"
 	"github.com/chromedp/cdproto/page"
 	"github.com/chromedp/chromedp"
 	"github.com/n7olkachev/imgdiff/pkg/imgdiff"
@@ -115,7 +116,7 @@
 
 // CheckHandler runs the test scripts matched by glob. If any errors are
 // encountered, CheckHandler returns an error listing the problems.
-func CheckHandler(glob string, update bool) error {
+func CheckHandler(glob string, update bool, headers map[string]interface{}) error {
 	ctx := context.Background()
 	files, err := filepath.Glob(glob)
 	if err != nil {
@@ -142,7 +143,7 @@
 		defer cancel()
 		var hdr bool
 		for _, test := range tests {
-			if err := runDiff(ctx, test, update); err != nil {
+			if err := runDiff(ctx, test, update, headers); err != nil {
 				if !hdr {
 					fmt.Fprintf(&buf, "%s\n\n", file)
 					hdr = true
@@ -159,7 +160,7 @@
 }
 
 // TestHandler runs the test script files matched by glob.
-func TestHandler(t *testing.T, glob string, update bool) error {
+func TestHandler(t *testing.T, glob string, update bool, headers map[string]interface{}) error {
 	ctx := context.Background()
 	files, err := filepath.Glob(glob)
 	if err != nil {
@@ -182,7 +183,7 @@
 		defer cancel()
 		for _, test := range tests {
 			t.Run(test.name, func(t *testing.T) {
-				if err := runDiff(ctx, test, update); err != nil {
+				if err := runDiff(ctx, test, update, headers); err != nil {
 					t.Fatal(err)
 				}
 			})
@@ -426,13 +427,13 @@
 
 // runDiff generates screenshots for a given test case and
 // a diff if the screenshots do not match.
-func runDiff(ctx context.Context, test *testcase, update bool) error {
+func runDiff(ctx context.Context, test *testcase, update bool, headers map[string]interface{}) error {
 	fmt.Printf("test %s\n", test.name)
-	screenA, err := screenshot(ctx, test, test.urlA, test.outImgA, test.cacheA, update)
+	screenA, err := screenshot(ctx, test, test.urlA, test.outImgA, test.cacheA, update, headers)
 	if err != nil {
 		return fmt.Errorf("screenshot(ctx, %q, %q, %q, %v): %w", test, test.urlA, test.outImgA, test.cacheA, err)
 	}
-	screenB, err := screenshot(ctx, test, test.urlB, test.outImgB, test.cacheB, update)
+	screenB, err := screenshot(ctx, test, test.urlB, test.outImgB, test.cacheB, update, headers)
 	if err != nil {
 		return fmt.Errorf("screenshot(ctx, %q, %q, %q, %v): %w", test, test.urlB, test.outImgB, test.cacheB, err)
 	}
@@ -470,7 +471,9 @@
 // screenshot gets a screenshot for a testcase url. When cache is true it will
 // attempt to read the screenshot from a cache or capture a new screenshot
 // and write it to the cache if it does not exist.
-func screenshot(ctx context.Context, test *testcase, url, file string, cache, update bool) (_ *image.Image, err error) {
+func screenshot(ctx context.Context, test *testcase,
+	url, file string, cache, update bool, headers map[string]interface{},
+) (_ *image.Image, err error) {
 	var data []byte
 	// If cache is enabled, try to read the file from the cache.
 	if cache {
@@ -486,7 +489,7 @@
 	// If cache is false, this is the first test run, or an update is requested
 	// we capture a new screenshot from a live URL.
 	if !cache || update {
-		data, err = captureScreenshot(ctx, test, url)
+		data, err = captureScreenshot(ctx, test, url, headers)
 		if err != nil {
 			return nil, fmt.Errorf("captureScreenshot(ctx, %q, %q): %w", url, test, err)
 		}
@@ -507,18 +510,24 @@
 
 // captureScreenshot runs a series of browser actions and takes a screenshot
 // of the resulting webpage in an instance of headless chrome.
-func captureScreenshot(ctx context.Context, test *testcase, url string) ([]byte, error) {
+func captureScreenshot(ctx context.Context, test *testcase,
+	url string, headers map[string]interface{},
+) ([]byte, error) {
 	var buf []byte
 	ctx, cancel := chromedp.NewContext(ctx)
 	defer cancel()
 	ctx, cancel = context.WithTimeout(ctx, time.Minute)
 	defer cancel()
-	tasks := chromedp.Tasks{
+	var tasks chromedp.Tasks
+	if headers != nil {
+		tasks = append(tasks, network.SetExtraHTTPHeaders(headers))
+	}
+	tasks = append(tasks,
 		chromedp.EmulateViewport(int64(test.viewportWidth), int64(test.viewportHeight)),
 		chromedp.Navigate(url),
 		waitForEvent("networkIdle"),
 		test.tasks,
-	}
+	)
 	switch test.screenshotType {
 	case fullScreenshot:
 		tasks = append(tasks, chromedp.FullScreenshot(&buf, 100))
diff --git a/internal/screentest/screentest_test.go b/internal/screentest/screentest_test.go
index 84bbbfa..6791564 100644
--- a/internal/screentest/screentest_test.go
+++ b/internal/screentest/screentest_test.go
@@ -5,7 +5,9 @@
 package screentest
 
 import (
+	"context"
 	"fmt"
+	"net/http"
 	"os"
 	"os/exec"
 	"path/filepath"
@@ -194,17 +196,17 @@
 		{
 			name: "cached",
 			args: args{
-				output: "testdata/screenshots",
+				output: "testdata/screenshots/cached",
 				glob:   "testdata/cached.txt",
 			},
 			wantFiles: []string{
-				filepath.Join("testdata", "screenshots", "homepage.go-dev.png"),
+				filepath.Join("testdata", "screenshots", "cached", "homepage.go-dev.png"),
 			},
 		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			if err := CheckHandler(tt.args.glob, false); (err != nil) != tt.wantErr {
+			if err := CheckHandler(tt.args.glob, false, nil); (err != nil) != tt.wantErr {
 				t.Fatalf("CheckHandler() error = %v, wantErr %v", err, tt.wantErr)
 			}
 			if len(tt.wantFiles) != 0 {
@@ -227,5 +229,42 @@
 	if err != nil {
 		t.Skip()
 	}
-	TestHandler(t, "testdata/pass.txt", false)
+	TestHandler(t, "testdata/pass.txt", false, nil)
+}
+
+func TestHeaders(t *testing.T) {
+	// Skip this test if Google Chrome is not installed.
+	_, err := exec.LookPath("google-chrome")
+	if err != nil {
+		t.Skip()
+	}
+	go headerServer()
+	if err := runDiff(context.Background(), &testcase{
+		name:              "go.dev homepage",
+		urlA:              "http://localhost:6061",
+		cacheA:            true,
+		urlB:              "http://localhost:6061",
+		outImgA:           filepath.Join("testdata", "screenshots", "headers", "headers-test.localhost-6061.png"),
+		outImgB:           filepath.Join("testdata", "screenshots", "headers", "headers-test.localhost-6061.png"),
+		outDiff:           filepath.Join("testdata", "screenshots", "headers", "headers-test.diff.png"),
+		viewportWidth:     1536,
+		viewportHeight:    960,
+		screenshotType:    elementScreenshot,
+		screenshotElement: "#result",
+	}, false, map[string]interface{}{"Authorization": "Bearer token"}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func headerServer() error {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", func(res http.ResponseWriter, req *http.Request) {
+		fmt.Fprintf(res, `<!doctype html>
+		<html>
+		<body>
+		  <span id="result">%s</span>
+		</body>
+		</html>`, req.Header.Get("Authorization"))
+	})
+	return http.ListenAndServe(fmt.Sprintf(":%d", 6061), mux)
 }
diff --git a/internal/screentest/testdata/cached.txt b/internal/screentest/testdata/cached.txt
index 7357dfc..39e3158 100644
--- a/internal/screentest/testdata/cached.txt
+++ b/internal/screentest/testdata/cached.txt
@@ -1,6 +1,6 @@
 windowsize 1536x960
 compare https://go.dev::cache https://go.dev::cache
-output testdata/screenshots
+output testdata/screenshots/cached
 
 test homepage
 pathname /
diff --git a/internal/screentest/testdata/screenshots/homepage.go-dev.png b/internal/screentest/testdata/screenshots/cached/homepage.go-dev.png
similarity index 100%
rename from internal/screentest/testdata/screenshots/homepage.go-dev.png
rename to internal/screentest/testdata/screenshots/cached/homepage.go-dev.png
Binary files differ
diff --git a/internal/screentest/testdata/screenshots/headers/headers-test.localhost-6061.png b/internal/screentest/testdata/screenshots/headers/headers-test.localhost-6061.png
new file mode 100755
index 0000000..04f9f49
--- /dev/null
+++ b/internal/screentest/testdata/screenshots/headers/headers-test.localhost-6061.png
Binary files differ