notary/internal/notecheck: refactor for integration into go command

Also fix up a few of the recent URL format changes.

Change-Id: I9c62f1ff3782dc5f6a93fb97470d103cc3c1b264
Reviewed-on: https://go-review.googlesource.com/c/exp/+/172413
Run-TryBot: Russ Cox <rsc@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/notary/internal/notecheck/main.go b/notary/internal/notecheck/main.go
index deb57cf..e0923f7 100644
--- a/notary/internal/notecheck/main.go
+++ b/notary/internal/notecheck/main.go
@@ -20,14 +20,13 @@
 
 import (
 	"bytes"
-	"encoding/hex"
 	"flag"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"log"
 	"net/http"
 	"os"
-	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -63,29 +62,15 @@
 	if *url == "" {
 		*url = "https://" + verifier.Name()
 	}
-	msg, err := httpGet(*url + "/latest")
-	if err != nil {
-		log.Fatal(err)
-	}
-	treeNote, err := note.Open(msg, note.VerifierList(verifier))
-	if err != nil {
-		log.Fatalf("reading note: %v\nnote:\n%s", err, msg)
-	}
-	tree, err := tlog.ParseTree([]byte(treeNote.Text))
-	if err != nil {
-		log.Fatal(err)
-	}
 
-	if *vflag {
-		log.Printf("validating against %s @%d", verifier.Name(), tree.N)
+	// TODO(rsc): Load initial db.latest, db.latestNote from on-disk cache.
+	db := &GoSumDB{
+		url:       *url,
+		verifiers: note.VerifierList(verifier),
 	}
-
-	verifierURL := *url
-	tr := &tileReader{url: verifierURL + "/"}
-	thr := tlog.TileHashReader(tree, tr)
-	if _, err := tlog.TreeHash(tree.N, thr); err != nil {
-		log.Fatal(err)
-	}
+	db.httpClient.Timeout = 1 * time.Minute
+	db.tileReader.db = db
+	db.tileReader.url = db.url + "/"
 
 	for _, arg := range flag.Args()[1:] {
 		data, err := ioutil.ReadFile(arg)
@@ -93,17 +78,21 @@
 			log.Fatal(err)
 		}
 		log.SetPrefix("notecheck: " + arg + ": ")
-		checkGoSum(data, verifierURL, thr)
+		checkGoSum(db, data)
 		log.SetPrefix("notecheck: ")
 	}
 }
 
-func checkGoSum(data []byte, verifierURL string, thr tlog.HashReader) {
-	lines := strings.SplitAfter(string(data), "\n")
+func checkGoSum(db *GoSumDB, data []byte) {
+	lines := strings.Split(string(data), "\n")
 	if lines[len(lines)-1] != "" {
 		log.Printf("error: final line missing newline")
 		return
 	}
+	// TODO(rsc): This assumes that the /go.mod and the whole-tree hashes
+	// always appear together in a go.sum.
+	// Sometimes the /go.mod can appear alone.
+	// The code needs to be updated to handle that case.
 	lines = lines[:len(lines)-1]
 	if len(lines)%2 != 0 {
 		log.Printf("error: odd number of lines")
@@ -116,89 +105,195 @@
 			continue
 		}
 
-		data, err := httpGet(verifierURL + "/lookup/" + f1[0] + "@" + f1[1])
+		dbLines, err := db.Lookup(f1[0], f1[1])
 		if err != nil {
 			log.Printf("%s@%s: %v", f1[0], f1[1], err)
 			continue
 		}
-		j := bytes.IndexByte(data, '\n')
-		if j < 0 {
-			log.Printf("%s@%s: short response from lookup", f1[0], f1[1])
-			continue
-		}
-		id, err := strconv.ParseInt(strings.TrimSpace(string(data[:j])), 10, 64)
-		if err != nil {
-			log.Printf("%s@%s: unexpected response:\n%s", f1[0], f1[1], data)
-			continue
-		}
-		ldata := data[j+1:]
 
-		c := make(chan *tlog.Hash, 1)
-		go func() {
-			hashes, err := thr.ReadHashes([]int64{tlog.StoredHashIndex(0, id)})
-			if err != nil {
-				log.Printf("%s@%s: %v", f1[0], f1[1], err)
-				c <- nil
-				return
+		if strings.Join(lines[i:i+2], "\n") != strings.Join(dbLines, "\n") {
+			log.Printf("%s@%s: invalid go.sum entries:\ngo.sum:\n\t%s\nsum.golang.org:\n\t%s", f1[0], f1[1], strings.Join(lines[i:i+2], "\n\t"), strings.Join(dbLines, "\n\t"))
+		}
+	}
+}
+
+// A GoSumDB is a client for a go.sum database.
+type GoSumDB struct {
+	url        string         // root url of database, without trailing slash
+	verifiers  note.Verifiers // accepted verifiers for signed trees
+	tileReader tileReader     // tlog.TileReader implementation
+	httpCache  parCache
+	httpClient http.Client
+
+	// latest accepted tree head
+	mu         sync.Mutex
+	latest     tlog.Tree
+	latestNote []byte // signed note
+}
+
+// parCache is a minimal simulation of cmd/go's par.Cache.
+// When this code moves into cmd/go, it should use the real par.Cache.
+type parCache struct {
+}
+
+func (c *parCache) Do(key interface{}, f func() interface{}) interface{} {
+	return f()
+}
+
+// Lookup returns the go.sum lines for the given module path and version.
+func (db *GoSumDB) Lookup(path, vers string) ([]string, error) {
+	// TODO(rsc): !-encode the path.
+	data, err := db.httpGet(db.url + "/lookup/" + path + "@" + vers)
+	if err != nil {
+		return nil, err
+	}
+
+	id, text, treeMsg, err := tlog.ParseRecord(data)
+	if err != nil {
+		return nil, fmt.Errorf("%s@%s: %v", path, vers, err)
+	}
+	if err := db.updateLatest(treeMsg); err != nil {
+		return nil, fmt.Errorf("%s@%s: %v", path, vers, err)
+	}
+	if err := db.checkRecord(id, text); err != nil {
+		return nil, fmt.Errorf("%s@%s: %v", path, vers, err)
+	}
+
+	prefix := path + " " + vers + " "
+	prefixGoMod := path + " " + vers + "/go.mod "
+	var hashes []string
+	for _, line := range strings.Split(string(text), "\n") {
+		if strings.HasPrefix(line, prefix) || strings.HasPrefix(line, prefixGoMod) {
+			hashes = append(hashes, line)
+		}
+	}
+	return hashes, nil
+}
+
+// updateLatest updates db's idea of the latest tree head
+// to incorporate the signed tree head in msg.
+// If msg is before the current latest tree head,
+// updateLatest still checks that it fits into the known timeline.
+// updateLatest returns an error for non-malicious problems.
+// If it detects a fork in the tree history, it prints a detailed
+// message and calls log.Fatal.
+func (db *GoSumDB) updateLatest(msg []byte) error {
+	if len(msg) == 0 {
+		return nil
+	}
+	note, err := note.Open(msg, db.verifiers)
+	if err != nil {
+		return fmt.Errorf("reading tree note: %v\nnote:\n%s", err, msg)
+	}
+	tree, err := tlog.ParseTree([]byte(note.Text))
+	if err != nil {
+		return fmt.Errorf("reading tree: %v\ntree:\n%s", err, note.Text)
+	}
+
+Update:
+	for {
+		db.mu.Lock()
+		latest := db.latest
+		latestNote := db.latestNote
+		db.mu.Unlock()
+
+		switch {
+		case tree.N <= latest.N:
+			return db.checkTrees(tree, msg, latest, latestNote)
+
+		case tree.N > latest.N:
+			if err := db.checkTrees(latest, latestNote, tree, msg); err != nil {
+				return err
 			}
-			c <- &hashes[0]
-		}()
-
-		// The record lookup can be skipped in favor of using the /lookup response
-		// but we fetch record and test that they match, to check the server.
-		data, err = httpGet(verifierURL + "/record/" + fmt.Sprint(id))
-		if err != nil {
-			log.Printf("%s@%s: %v", f1[0], f1[1], err)
-			continue
-		}
-		if !bytes.Equal(data, ldata) {
-			log.Printf("%s@%s: different data from lookup and record:\n%s\n%s", f1[0], f1[1], hex.Dump(ldata), hex.Dump(data))
-			continue
-		}
-
-		hash := tlog.RecordHash(data)
-		hash1 := <-c
-		if hash1 == nil {
-			continue
-		}
-		if *hash1 != hash {
-			log.Printf("%s@%s: inconsistent records on notary!", f1[0], f1[1])
-			continue
-		}
-		if string(data) != lines[i]+lines[i+1] {
-			log.Printf("%s@%s: invalid go.sum entries:\nhave:\n\t%s\t%swant:\n\t%s", f1[0], f1[1], lines[i], lines[i+1], strings.Replace(string(data), "\n", "\n\t", -1))
+			db.mu.Lock()
+			if db.latest != latest {
+				if db.latest.N > latest.N {
+					db.mu.Unlock()
+					continue Update
+				}
+				log.Fatalf("go.sum database changed underfoot:\n\t%v ->\n\t%v", latest, db.latest)
+			}
+			db.latest = tree
+			db.latestNote = msg
+			db.mu.Unlock()
+			return nil
 		}
 	}
 }
 
-func init() {
-	http.DefaultClient.Timeout = 10 * time.Second
+// checkTrees checks that older (from olderNote) is contained in newer (from newerNote).
+// If an error occurs, such as malformed data or a network problem, checkTrees returns that error.
+// If on the other hand checkTrees finds evidence of misbehavior, it prepares a detailed
+// message and calls log.Fatal.
+func (db *GoSumDB) checkTrees(older tlog.Tree, olderNote []byte, newer tlog.Tree, newerNote []byte) error {
+	thr := tlog.TileHashReader(newer, &db.tileReader)
+	h, err := tlog.TreeHash(older.N, thr)
+	if err != nil {
+		return fmt.Errorf("checking tree#%d against tree#%d: %v", older.N, newer.N, err)
+	}
+	if h == older.Hash {
+		return nil
+	}
+
+	// Detected a fork in the tree timeline.
+	// Start by reporting the inconsistent signed tree notes as text (%s, since they are []byte).
+	var buf bytes.Buffer
+	fmt.Fprintf(&buf, "SECURITY ERROR\n")
+	fmt.Fprintf(&buf, "go.sum database server misbehavior detected!\n\n")
+	indent := func(b []byte) []byte {
+		return bytes.Replace(b, []byte("\n"), []byte("\n\t"), -1)
+	}
+	fmt.Fprintf(&buf, "old database:\n\t%s\n", indent(olderNote))
+	fmt.Fprintf(&buf, "new database:\n\t%s\n", indent(newerNote))
+
+	// The notes alone are not enough to prove the inconsistency.
+	// We also need to show that the newer note's tree hash for older.N
+	// does not match older.Hash. The consumer of this report could
+	// of course consult the server to try to verify the inconsistency,
+	// but we are holding all the bits we need to prove it right now,
+	// so we might as well print them and make the report not depend
+	// on the continued availability of the misbehaving server.
+	// Preparing this data only reuses the tiled hashes needed for
+	// tlog.TreeHash(older.N, thr) above, so assuming thr is caching tiles,
+	// there is no new access to the server here, and these operations cannot fail.
+	fmt.Fprintf(&buf, "proof of misbehavior:\n\t%v", h)
+	if p, err := tlog.ProveTree(newer.N, older.N, thr); err != nil {
+		fmt.Fprintf(&buf, "\tinternal error: %v\n", err)
+	} else if err := tlog.CheckTree(p, newer.N, newer.Hash, older.N, h); err != nil {
+		fmt.Fprintf(&buf, "\tinternal error: generated inconsistent proof\n")
+	} else {
+		for _, h := range p {
+			fmt.Fprintf(&buf, "\n\t%v", h)
+		}
+	}
+	log.Fatalf("%v", buf.String())
+	panic("not reached")
 }
 
-func httpGet(url string) ([]byte, error) {
-	start := time.Now()
-	resp, err := http.Get(url)
+// checkRecord checks that record #id's hash matches data.
+func (db *GoSumDB) checkRecord(id int64, data []byte) error {
+	db.mu.Lock()
+	tree := db.latest
+	db.mu.Unlock()
+
+	if id >= tree.N {
+		return fmt.Errorf("cannot validate record %d in tree of size %d", id, tree.N)
+	}
+	hashes, err := tlog.TileHashReader(tree, &db.tileReader).ReadHashes([]int64{tlog.StoredHashIndex(0, id)})
 	if err != nil {
-		return nil, err
+		return err
 	}
-	defer resp.Body.Close()
-	if resp.StatusCode != 200 {
-		return nil, fmt.Errorf("GET %v: %v", url, resp.Status)
+	if hashes[0] == tlog.RecordHash(data) {
+		return nil
 	}
-	data, err := ioutil.ReadAll(resp.Body)
-	if err != nil {
-		return nil, err
-	}
-	if *vflag {
-		fmt.Fprintf(os.Stderr, "%.3fs %s\n", time.Since(start).Seconds(), url)
-	}
-	return data, nil
+	return fmt.Errorf("cannot authenticate record data in server response")
 }
 
 type tileReader struct {
 	url     string
 	cache   map[tlog.Tile][]byte
 	cacheMu sync.Mutex
+	db      *GoSumDB
 }
 
 func (r *tileReader) Height() int {
@@ -206,13 +301,12 @@
 }
 
 func (r *tileReader) SaveTiles(tiles []tlog.Tile, data [][]byte) {
-	// no on-disk cache here
+	// TODO(rsc): On-disk cache in GOPATH.
 }
 
-// TODO(rsc): Move some variant of this to package tlog
-// once we are more sure of the API.
-
 func (r *tileReader) ReadTiles(tiles []tlog.Tile) ([][]byte, error) {
+	// TODO(rsc): Look in on-disk cache in GOPATH.
+
 	var wg sync.WaitGroup
 	out := make([][]byte, len(tiles))
 	errs := make([]error, len(tiles))
@@ -228,11 +322,11 @@
 		wg.Add(1)
 		go func(i int, tile tlog.Tile) {
 			defer wg.Done()
-			data, err := httpGet(r.url + tile.Path())
+			data, err := r.db.httpGet(r.url + tile.Path())
 			if err != nil && tile.W != 1<<uint(tile.H) {
 				fullTile := tile
 				fullTile.W = 1 << uint(tile.H)
-				if fullData, err1 := httpGet(r.url + fullTile.Path()); err1 == nil {
+				if fullData, err1 := r.db.httpGet(r.url + fullTile.Path()); err1 == nil {
 					data = fullData[:tile.W*tlog.HashSize]
 					err = nil
 				}
@@ -258,3 +352,32 @@
 
 	return out, nil
 }
+
+func (db *GoSumDB) httpGet(url string) ([]byte, error) {
+	type cached struct {
+		data []byte
+		err  error
+	}
+
+	c := db.httpCache.Do(url, func() interface{} {
+		start := time.Now()
+		resp, err := db.httpClient.Get(url)
+		if err != nil {
+			return cached{nil, err}
+		}
+		defer resp.Body.Close()
+		if resp.StatusCode != 200 {
+			return cached{nil, fmt.Errorf("GET %v: %v", url, resp.Status)}
+		}
+		data, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
+		if err != nil {
+			return cached{nil, err}
+		}
+		if *vflag {
+			fmt.Fprintf(os.Stderr, "%.3fs %s\n", time.Since(start).Seconds(), url)
+		}
+		return cached{data, nil}
+	}).(cached)
+
+	return c.data, c.err
+}