blob: 356743f0479a3c4d490f91d004cd1c4f26347a2d [file] [log] [blame]
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package maintner
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/golang/protobuf/proto"
"golang.org/x/build/maintner/maintpb"
"golang.org/x/build/maintner/reclog"
)
// NewNetworkMutationSource returns a mutation source from a master server.
// The server argument should be a URL to the JSON logs index; log segments
// downloaded from it are cached as files in cacheDir.
//
// It panics if server cannot be parsed as a URL.
func NewNetworkMutationSource(server, cacheDir string) MutationSource {
	base, err := url.Parse(server)
	if err != nil {
		panic(fmt.Sprintf("invalid URL: %q", server))
	}
	ns := new(netMutSource)
	ns.server = server
	ns.base = base
	ns.cacheDir = cacheDir
	return ns
}
// netMutSource is the MutationSource returned by NewNetworkMutationSource.
// It reads mutation log segments listed by a server's JSON logs index and
// keeps local copies of them as files in cacheDir.
type netMutSource struct {
server string // URL of the JSON logs index, as given to NewNetworkMutationSource
base *url.URL // parsed server URL; segment URLs are resolved relative to it
cacheDir string // directory holding the downloaded *.mutlog segment files
// last is the segment list most recently recorded by getNewSegments.
// Its total size is sent as the long-poll "waitsizenot" value, and it is
// compared with each fresh server listing to detect a split in the log.
last []fileSeg
// Hooks for testing. If nil, unused:
testHookGetServerSegments func(context.Context, int64) ([]LogSegmentJSON, error)
testHookWaitAfterServerDupData func(context.Context) error
testHookSyncSeg func(context.Context, LogSegmentJSON) (fileSeg, error)
testHookFilePrefixSum224 func(file string, n int64) string
}
// GetMutations returns a channel on which a background goroutine sends every
// mutation that is new since the previous call (see sendMutations), followed
// by one final event. The final event carries the error that stopped delivery,
// or has End set if delivery finished cleanly. If ctx is done before the final
// event can be sent, it is dropped. The channel is never closed.
func (ns *netMutSource) GetMutations(ctx context.Context) <-chan MutationStreamEvent {
	events := make(chan MutationStreamEvent, 50)
	go func() {
		var last MutationStreamEvent
		if err := ns.sendMutations(ctx, events); err != nil {
			last.Err = err
		} else {
			last.End = true
		}
		select {
		case events <- last:
		case <-ctx.Done():
		}
	}()
	return events
}
// isNoInternetError reports whether the provided error is because there's no
// network connectivity. It unwraps *url.Error and *net.OpError layers and
// reports true only for a *net.DNSError whose message is "no such host".
// Error types it does not recognize are logged and reported as false.
func isNoInternetError(err error) bool {
	for err != nil {
		switch e := err.(type) {
		case *url.Error:
			err = e.Err
		case *net.OpError:
			err = e.Err
		case *net.DNSError:
			// Matching on the message text is fragile, but the resolver
			// offers nothing better here.
			return e.Err == "no such host"
		default:
			log.Printf("Unknown error type %T: %#v", e, e)
			return false
		}
	}
	return false
}
// locallyCachedSegments returns the contiguous run of segments, starting at
// segment 0, that are already present in ns.cacheDir. It is the fallback used
// when the server cannot be reached.
//
// Immutable segments are files named "NNNN.<sha224>.mutlog". The first segment
// number without such a file may instead have a "NNNN.growing.mutlog" file.
// That file is read in full to compute its size and SHA-224, and it ends the
// run. The outcome (segment count or error) is logged.
func (ns *netMutSource) locallyCachedSegments() (segs []fileSeg, err error) {
	defer func() {
		if err != nil {
			log.Printf("No network connection and failed to use local cache: %v", err)
		} else {
			log.Printf("No network connection; using %d locally cached segments.", len(segs))
		}
	}()
	fis, err := ioutil.ReadDir(ns.cacheDir)
	if err != nil {
		return nil, err
	}
	// Length of a frozen segment file name: "NNNN." + 56 hex digits + ".mutlog".
	const frozenNameLen = len("0000.6897fab4d3afcda332424b2a2a1a4469021074282bc7be5606aaa221.mutlog")

	// Remember the actual directory entries, keyed by segment number, instead
	// of rebuilding file names from the parsed number. strconv.Atoi also
	// accepts forms such as "+001", and a rebuilt "0001..." name would then
	// not exist. For frozen segments that would dereference a nil FileInfo;
	// for growing ones the read would fail.
	frozen := map[int]os.FileInfo{}
	growing := map[int]string{}
	for _, fi := range fis {
		name := fi.Name()
		if !strings.HasSuffix(name, ".mutlog") {
			continue
		}
		isFrozen := len(name) == frozenNameLen
		if !isFrozen && !strings.HasSuffix(name, ".growing.mutlog") {
			continue
		}
		num, err := strconv.Atoi(name[:4])
		if err != nil {
			continue
		}
		if isFrozen {
			frozen[num] = fi
		} else {
			growing[num] = name
		}
	}

	for num := 0; ; num++ {
		if fi, ok := frozen[num]; ok {
			name := fi.Name()
			segs = append(segs, fileSeg{
				seg:    num,
				file:   filepath.Join(ns.cacheDir, name),
				size:   fi.Size(),
				sha224: strings.TrimSuffix(name[5:], ".mutlog"),
			})
			continue
		}
		if name, ok := growing[num]; ok {
			path := filepath.Join(ns.cacheDir, name)
			slurp, err := ioutil.ReadFile(path)
			if err != nil {
				return nil, err
			}
			segs = append(segs, fileSeg{
				seg:    num,
				file:   path,
				size:   int64(len(slurp)),
				sha224: fmt.Sprintf("%x", sha256.Sum224(slurp)),
			})
		}
		// Segments must be contiguous: the first number without a frozen
		// file ends the run, whether or not it had a growing file.
		return segs, nil
	}
}
// getServerSegments fetches and decodes the server's JSON logs index.
//
// waitSizeNot optionally specifies that the request should long-poll waiting for the server
// to have a sum of log segment sizes different than the value specified. Values <= 0 disable
// long polling. While long polling, 304 Not Modified replies are treated as heartbeats and
// the request is retried.
func (ns *netMutSource) getServerSegments(ctx context.Context, waitSizeNot int64) ([]LogSegmentJSON, error) {
	if fn := ns.testHookGetServerSegments; fn != nil {
		return fn(ctx, waitSizeNot)
	}
	logsURL := ns.server
	if waitSizeNot > 0 {
		// Extend any query string the server URL already carries instead of
		// producing a malformed "...?a=b?waitsizenot=N" URL.
		sep := "?"
		if strings.Contains(logsURL, "?") {
			sep = "&"
		}
		logsURL += fmt.Sprintf("%swaitsizenot=%d", sep, waitSizeNot)
	}
	for {
		req, err := http.NewRequest("GET", logsURL, nil)
		if err != nil {
			return nil, err
		}
		req = req.WithContext(ctx)
		res, err := http.DefaultClient.Do(req)
		if err != nil {
			return nil, err
		}
		// If we're doing a long poll and the server replies
		// with a 304 response, that means the server is just
		// heart-beating us and trying to get a response back
		// within its various deadlines. But we should just
		// try again.
		if waitSizeNot > 0 && res.StatusCode == http.StatusNotModified {
			res.Body.Close()
			continue
		}
		// The body is closed explicitly on every path rather than
		// deferred, because this is a loop.
		if res.StatusCode != http.StatusOK {
			res.Body.Close()
			return nil, fmt.Errorf("%s: %v", ns.server, res.Status)
		}
		var segs []LogSegmentJSON
		err = json.NewDecoder(res.Body).Decode(&segs)
		res.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("decoding %s JSON: %v", ns.server, err)
		}
		return segs, nil
	}
}
// getNewSegments waits until the server reports log data beyond what is
// recorded in ns.last. It then syncs every listed segment into the local cache
// and returns the segments trimmed so that together they cover only the new
// bytes.
//
// If the network is unreachable and nothing has been recorded yet, it falls
// back to the segments already in ns.cacheDir. If the network is unreachable
// after data has been recorded, it retries every 15 seconds until ctx is done.
// It returns ErrSplit if the server's log no longer begins with the data
// recorded in ns.last.
func (ns *netMutSource) getNewSegments(ctx context.Context) ([]fileSeg, error) {
	for {
		sumLast := sumSegSize(ns.last)
		segs, err := ns.getServerSegments(ctx, sumLast)
		if isNoInternetError(err) {
			if sumLast == 0 {
				cached, err := ns.locallyCachedSegments()
				if err != nil {
					return nil, err
				}
				// Record what is being delivered. The next call then
				// long-polls for data beyond it, instead of delivering the
				// same mutations again (from the cache while still offline,
				// or from the server starting at size 0 once back online).
				ns.last = cached
				return cached, nil
			}
			log.Printf("No internet; blocking.")
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(15 * time.Second):
				continue
			}
		}
		if err != nil {
			return nil, err
		}
		// TODO: optimization: if already on GCE, skip sync to disk part and just
		// read from network. fast & free network inside.
		var fileSegs []fileSeg
		for _, seg := range segs {
			fileSeg, err := ns.syncSeg(ctx, seg)
			if err != nil {
				return nil, fmt.Errorf("syncing segment %d: %v", seg.Number, err)
			}
			fileSegs = append(fileSegs, fileSeg)
		}
		sumCommon := ns.sumCommonPrefixSize(fileSegs, ns.last)
		if sumLast != sumCommon {
			// The server's log no longer starts with what was delivered.
			return nil, ErrSplit
		}
		sumCur := sumSegSize(fileSegs)
		if sumCommon == sumCur {
			// Nothing new. This shouldn't happen once the
			// server is updated to respect the
			// "?waitsizenot=NNN" long polling parameter.
			// But keep this brief pause as a backup to
			// prevent spinning and because clients &
			// servers won't be updated simultaneously.
			if ns.testHookGetServerSegments == nil {
				log.Printf("maintner.netsource: server returned unchanged log segments; old server?")
			}
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(1 * time.Second):
			}
			continue
		}
		ns.last = fileSegs
		return trimLeadingSegBytes(fileSegs, sumCommon), nil
	}
}
// trimLeadingSegBytes drops the first trim bytes from the concatenation of
// the segments in in. Segments that fall entirely within the trimmed range
// are removed. The first surviving segment has its skip set to the number of
// its own bytes that were trimmed. It returns nil if nothing remains. The
// input slice is never modified; a fresh slice is returned.
func trimLeadingSegBytes(in []fileSeg, trim int64) []fileSeg {
	start := 0
	for start < len(in) && in[start].size <= trim {
		trim -= in[start].size
		start++
	}
	rest := in[start:]
	if len(rest) == 0 {
		return nil
	}
	// Copy before writing skip so the caller's slice is left untouched.
	out := make([]fileSeg, len(rest))
	copy(out, rest)
	out[0].skip = trim
	return out
}
// filePrefixSum224 returns the lowercase hex SHA-224 of the first n bytes of file.
// It returns "" if the file cannot be opened or holds fewer than n bytes.
// Every such failure except a nonexistent file is logged.
func (ns *netMutSource) filePrefixSum224(file string, n int64) string {
	if hook := ns.testHookFilePrefixSum224; hook != nil {
		return hook(file, n)
	}
	f, err := os.Open(file)
	switch {
	case os.IsNotExist(err):
		return ""
	case err != nil:
		log.Print(err)
		return ""
	}
	defer f.Close()

	digest := sha256.New224()
	if _, err := io.CopyN(digest, f, n); err != nil {
		log.Print(err)
		return ""
	}
	return fmt.Sprintf("%x", digest.Sum(nil))
}
// sumSegSize returns the total size in bytes of all segments in segs.
func sumSegSize(segs []fileSeg) int64 {
	var total int64
	for i := range segs {
		total += segs[i].size
	}
	return total
}
// sumCommonPrefixSize returns how many leading bytes the segment lists a and
// b have in common.
//
// Segments are compared pairwise. Pairs with equal SHA-224 count in full. At
// the first differing pair, comparison stops. If that pair has equal sizes,
// nothing more is added. Otherwise the smaller segment's size is added if its
// hash equals the hash of the same-length prefix of the larger segment's file.
func (ns *netMutSource) sumCommonPrefixSize(a, b []fileSeg) int64 {
	var common int64
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	for i := 0; i < n; i++ {
		x, y := a[i], b[i]
		if x.sha224 == y.sha224 {
			common += x.size
			continue
		}
		if x.size == y.size {
			// Same size, different contents: the logs forked here.
			return common
		}
		short, long := x, y
		if y.size < x.size {
			short, long = y, x
		}
		// The shorter segment may be an earlier snapshot of the longer one.
		if ns.filePrefixSum224(long.file, short.size) == short.sha224 {
			common += short.size
		}
		return common
	}
	return common
}
// sendMutations fetches the segments that are new since the last call and
// sends every record they contain on ch, in order, decoded as a
// maintpb.Mutation. It returns ctx.Err() if ctx is done while it is blocked
// sending.
func (ns *netMutSource) sendMutations(ctx context.Context, ch chan<- MutationStreamEvent) error {
	segs, err := ns.getNewSegments(ctx)
	if err != nil {
		return err
	}
	emit := func(off int64, hdr, rec []byte) error {
		m := new(maintpb.Mutation)
		if err := proto.Unmarshal(rec, m); err != nil {
			return err
		}
		select {
		case ch <- MutationStreamEvent{Mutation: m}:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return foreachFileSeg(segs, func(seg fileSeg) error {
		f, err := os.Open(seg.file)
		if err != nil {
			return err
		}
		defer f.Close()
		// Start after the bytes that earlier calls already delivered.
		if seg.skip > 0 {
			if _, err := f.Seek(seg.skip, io.SeekStart); err != nil {
				return err
			}
		}
		body := io.LimitReader(f, seg.size-seg.skip)
		return reclog.ForeachRecord(body, seg.skip, emit)
	})
}
// foreachFileSeg calls fn for each segment in segs, in order. It stops at the
// first non-nil error fn returns and returns that error.
func foreachFileSeg(segs []fileSeg, fn func(seg fileSeg) error) error {
	var err error
	for i := 0; i < len(segs) && err == nil; i++ {
		err = fn(segs[i])
	}
	return err
}
// fileSeg describes one mutation log segment stored in a local file.
//
// TODO: add a constructor for this? or simplify it. make it Size +
// File + embedded LogSegmentJSON?
type fileSeg struct {
seg int // segment number
file string // full path
sha224 string // lowercase hex SHA-224 of the segment's size bytes
skip int64 // leading bytes already delivered; set by trimLeadingSegBytes, honored by sendMutations
size int64 // segment length in bytes
}
// syncSeg makes sure the log segment described by seg is stored in
// ns.cacheDir and returns its local description.
//
// A segment whose URL is under https://storage.googleapis.com/ is immutable
// and is stored as "NNNN.<sha224>.mutlog". Any other segment is the server's
// current, still-growing one and is stored as "NNNN.growing.mutlog". If a
// partial growing file already exists locally, only the missing tail is
// requested, using an HTTP Range header. Downloaded data is verified against
// seg.SHA224 and written via a temporary file plus rename.
func (ns *netMutSource) syncSeg(ctx context.Context, seg LogSegmentJSON) (fileSeg, error) {
	if fn := ns.testHookSyncSeg; fn != nil {
		return fn(ctx, seg)
	}
	isFinalSeg := !strings.HasPrefix(seg.URL, "https://storage.googleapis.com/")
	relURL, err := url.Parse(seg.URL)
	if err != nil {
		return fileSeg{}, err
	}
	segURL := ns.base.ResolveReference(relURL)
	frozen := filepath.Join(ns.cacheDir, fmt.Sprintf("%04d.%s.mutlog", seg.Number, seg.SHA224))
	// Do we already have it? Files named in their final form with the sha224 are considered
	// complete and immutable.
	if fi, err := os.Stat(frozen); err == nil && fi.Size() == seg.Size {
		return fileSeg{seg: seg.Number, file: frozen, size: fi.Size(), sha224: seg.SHA224}, nil
	}
	// See how much data we already have in the partial growing file. A missing
	// or unreadable file just means nothing is cached yet.
	partial := filepath.Join(ns.cacheDir, fmt.Sprintf("%04d.growing.mutlog", seg.Number))
	have, _ := ioutil.ReadFile(partial)
	if int64(len(have)) == seg.Size {
		got224 := fmt.Sprintf("%x", sha256.Sum224(have))
		if got224 == seg.SHA224 {
			if !isFinalSeg {
				// This was growing for us, but the server started a new growing segment.
				if err := os.Rename(partial, frozen); err != nil {
					return fileSeg{}, err
				}
				return fileSeg{seg: seg.Number, file: frozen, sha224: seg.SHA224, size: seg.Size}, nil
			}
			return fileSeg{seg: seg.Number, file: partial, sha224: seg.SHA224, size: seg.Size}, nil
		}
	}
	// Local data that is at least as long as the server's segment but did not
	// match above cannot be extended. Resuming from it would send an empty or
	// inverted byte range ("bytes=N-M" with M < N), which servers reject, so
	// the sync would fail on every poll. Discard it and fetch the whole segment.
	if int64(len(have)) >= seg.Size {
		have = nil
	}

	// Otherwise, download.
	req, err := http.NewRequest("GET", segURL.String(), nil)
	if err != nil {
		return fileSeg{}, err
	}
	req = req.WithContext(ctx)
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", len(have), seg.Size-1))
	log.Printf("Downloading %d bytes of %s ...", seg.Size-int64(len(have)), segURL)
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return fileSeg{}, err
	}
	if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusPartialContent {
		res.Body.Close()
		return fileSeg{}, fmt.Errorf("%s: %s", segURL.String(), res.Status)
	}
	slurp, err := ioutil.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		return fileSeg{}, err
	}
	// The server may honor the Range request (206, just the tail) or ignore
	// it (200, the whole segment).
	var newContents []byte
	if int64(len(slurp)) == seg.Size {
		newContents = slurp
	} else if int64(len(have)+len(slurp)) == seg.Size {
		newContents = append(have, slurp...)
	}
	got224 := fmt.Sprintf("%x", sha256.Sum224(newContents))
	if got224 != seg.SHA224 {
		if len(have) == 0 {
			return fileSeg{}, errors.New("corrupt download")
		}
		// The local prefix was wrong. Drop it and try again once, from scratch.
		os.Remove(partial)
		return ns.syncSeg(ctx, seg)
	}

	// Write to a temporary file and rename it into place, so a reader never
	// sees a partially written segment. Remove the temporary file on failure.
	tf, err := ioutil.TempFile(ns.cacheDir, "tempseg")
	if err != nil {
		return fileSeg{}, err
	}
	tmpName := tf.Name()
	if _, err := tf.Write(newContents); err != nil {
		tf.Close()
		os.Remove(tmpName)
		return fileSeg{}, err
	}
	if err := tf.Close(); err != nil {
		os.Remove(tmpName)
		return fileSeg{}, err
	}
	finalName := partial
	if !isFinalSeg {
		finalName = frozen
	}
	if err := os.Rename(tmpName, finalName); err != nil {
		os.Remove(tmpName)
		return fileSeg{}, err
	}
	log.Printf("wrote %v", finalName)
	return fileSeg{seg: seg.Number, file: finalName, size: seg.Size, sha224: seg.SHA224}, nil
}
// LogSegmentJSON describes one mutation log segment as listed in a server's
// JSON logs index (the URL given to NewNetworkMutationSource).
type LogSegmentJSON struct {
Number int `json:"number"` // segment number; local cache files are named with it formatted as "%04d"
Size int64 `json:"size"` // segment length in bytes
SHA224 string `json:"sha224"` // lowercase hex SHA-224 of the segment's contents
URL string `json:"url"` // download URL, resolved relative to the logs index URL; URLs not under https://storage.googleapis.com/ are treated as the still-growing final segment
}