blob: c71b20c7fad2866e5387b3c164ea24e761b1bcd3 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package server
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
"golang.org/x/benchmarks/sweet/benchmarks/internal/driver"
"golang.org/x/benchmarks/sweet/common/diagnostics"
)
// FetchDiagnostic reads a profile or trace from the pprof endpoint at host. The
// returned stop function finalizes the diagnostic file on disk and returns the
// total size in bytes. Because of limitations of net/http/pprof, this cannot
// actually stop collection on the server side, so stop should only be called
// when the server is about to be shut down.
func FetchDiagnostic(host string, diag *driver.Diagnostics, typ diagnostics.Type, name string) (stop func()) {
if typ.HTTPEndpoint() == "" {
panic("diagnostic " + string(typ) + " has no endpoint")
}
if !driver.DiagnosticEnabled(typ) {
return func() {}
}
// If this is a snapshot-type diagnostic, wait until the end to collect it.
if typ.IsSnapshot() {
return func() {
err := collectTo(context.Background(), host, diag, typ, name)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to read diagnostic %s: %v", typ, err)
}
}
}
// Otherwise, start collecting it now. If it can be truncated, then we try
// to collect it in one long run and cut if off when stop is called.
// If it can be merged, we can collect several of them.
var wg sync.WaitGroup
wg.Add(1)
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer wg.Done()
// If we can't truncate this diagnostic, make sure we collect it at
// least once. This is important for PGO, which first does a profiling run.
ctx1 := ctx
if typ.CanMerge() && !typ.CanTruncate() {
var cancel1 func()
ctx1, cancel1 = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1()
}
for {
err := collectTo(ctx1, host, diag, typ, name)
ctx1 = ctx
if err != nil {
if !errors.Is(err, context.Canceled) {
fmt.Fprintf(os.Stderr, "failed to read diagnostic %s: %v", typ, err)
}
break
}
if !typ.CanMerge() {
break
}
}
}()
return func() {
// Stop the loop.
cancel()
wg.Wait()
}
}
func collectTo(ctx context.Context, host string, diag *driver.Diagnostics, typ diagnostics.Type, name string) error {
// Construct the endpoint URL
var endpoint string
endpoint = fmt.Sprintf("http://%s/%s", host, typ.HTTPEndpoint())
if typ.CanMerge() && !typ.CanTruncate() {
// Collect in lots of small increments because we won't be able to just
// stop it.
endpoint += "?seconds=1"
} else if typ.CanTruncate() {
// Collect a long run that we can cut off.
endpoint += "?seconds=999999"
}
// Start profile collection.
req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// Read into a diagnostic file
f, err := diag.CreateNamed(typ, name)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
if err == nil || typ.CanTruncate() {
// If we got a complete file, or it's fine to truncate it anyway, commit
// the diagnostic file.
f.Close()
f.Commit()
}
return err
}
// TODO: Delete below here
func CollectDiagnostic(host, tmpDir, benchName string, typ diagnostics.Type) (int64, error) {
// We attempt to use the benchmark name to create a temp file so replace all
// path separators with "_".
benchName = strings.Replace(benchName, "/", "_", -1)
benchName = strings.Replace(benchName, string(os.PathSeparator), "_", -1)
f, err := os.CreateTemp(tmpDir, benchName+"."+string(typ))
if err != nil {
return 0, err
}
defer f.Close()
resp, err := http.Get(fmt.Sprintf("http://%s/debug/pprof/%s", host, endpoint(typ)))
if err != nil {
return 0, err
}
defer resp.Body.Close()
n, err := io.Copy(f, resp.Body)
if err != nil {
return 0, err
}
return n, driver.CopyDiagnosticData(f.Name(), typ, benchName)
}
func endpoint(typ diagnostics.Type) string {
switch typ {
case diagnostics.CPUProfile:
return "profile?seconds=1"
case diagnostics.MemProfile:
return "heap"
case diagnostics.Trace:
return "trace?seconds=1"
}
panic("diagnostic " + string(typ) + " has no endpoint")
}
func PollDiagnostic(host, tmpDir, benchName string, typ diagnostics.Type) (stop func() uint64) {
// TODO(mknyszek): This is kind of a hack. We really should find a way to just
// enable diagnostic collection at a lower level for the entire server run.
var stopc chan struct{}
var wg sync.WaitGroup
var size uint64
wg.Add(1)
stopc = make(chan struct{})
go func() {
defer wg.Done()
for {
select {
case <-stopc:
return
default:
}
n, err := CollectDiagnostic(host, tmpDir, benchName, typ)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to read diagnostic %s: %v", typ, err)
return
}
size += uint64(n)
}
}()
return func() uint64 {
// Stop the loop.
close(stopc)
wg.Wait()
return size
}
}