blob: 7364f8a4acf8a04cb9abee8f20c6add627e1138e [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pkgsite
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"sync"
"testing"
"time"
"golang.org/x/time/rate"
"golang.org/x/vulndb/internal/stdlib"
"golang.org/x/vulndb/internal/worker/log"
)
// Client checks whether module and package paths are known to a pkgsite
// instance, remembering the answers it has already obtained.
type Client struct {
// url is the base address that endpoints are appended to when querying.
url string
// cache holds the lookup results collected so far.
cache *cache
}
// Default returns a Client for the pkgsite instance named by the
// package-level URL variable.
func Default() *Client {
	client := New(URL)
	return client
}
// New returns a Client that sends its requests to the pkgsite instance
// at url. The client starts out with a fresh cache.
func New(url string) *Client {
	pc := new(Client)
	pc.url = url
	pc.cache = newCache()
	return pc
}
// SetKnownModules hands the module paths in known to the client's cache,
// which records them as known to pkgsite.
func (pc *Client) SetKnownModules(known []string) {
	c := pc.cache
	c.setKnownModules(known)
}
// pkgsiteQPS is the maximum number of requests per second sent to pkgsite.
const pkgsiteQPS = 20
var (
// pkgsiteRateLimiter throttles pkgsite requests to pkgsiteQPS per second.
// The second argument to rate.NewLimiter is the burst size (3 here),
// which lets callers exceed the steady rate briefly.
pkgsiteRateLimiter = rate.NewLimiter(rate.Every(1*time.Second/pkgsiteQPS), 3)
)
// URL is the address of the pkgsite instance that Default connects to.
var URL = "https://pkg.go.dev"
// KnownModule reports whether pkgsite knows path, judged by looking up
// the module endpoint ("/mod/" followed by path).
func (pc *Client) KnownModule(ctx context.Context, path string) (bool, error) {
	endpoint := moduleEndpoint(path)
	return pc.lookupEndpoint(ctx, endpoint)
}
// KnownAtVersion reports whether pkgsite knows path at the given version.
// The version must be bare, i.e. without a leading "v" or "go"; the
// appropriate prefix is added here.
func (pc *Client) KnownAtVersion(ctx context.Context, path, version string) (bool, error) {
	// Paths in the standard library get a "go" prefix on the version;
	// every other path gets "v".
	tag := "v" + version
	if stdlib.Contains(path) {
		tag = "go" + version
	}
	return pc.lookupEndpoint(ctx, "/"+path+"@"+tag)
}
// lookupEndpoint reports whether pkgsite answers a HEAD request for the
// given endpoint (a URL path appended to the client's base URL) with
// 200 OK.
//
// Answers come from the client's cache when it can provide one.
// Otherwise the call waits on the package-wide rate limiter, issues a
// HEAD request bound to ctx, and caches the outcome. Requests that fail
// outright are not cached, so a later call tries again.
func (pc *Client) lookupEndpoint(ctx context.Context, endpoint string) (bool, error) {
	if found, ok := pc.cache.lookup(endpoint); ok {
		return found, nil
	}

	// Pause to maintain a max QPS.
	if err := pkgsiteRateLimiter.Wait(ctx); err != nil {
		return false, err
	}

	start := time.Now()
	// Build the request with ctx so that cancellation and deadlines
	// also apply to the network round trip, not just the limiter wait.
	var res *http.Response
	req, err := http.NewRequestWithContext(ctx, http.MethodHead, pc.url+endpoint, nil)
	if err == nil {
		res, err = http.DefaultClient.Do(req)
	}
	var status string
	if err == nil {
		// A HEAD response has no payload, but the body must still be
		// closed so the transport can release the connection.
		defer res.Body.Close()
		status = strconv.Quote(res.Status)
	}
	log.With(
		"latency", time.Since(start),
		"status", status,
		"error", err,
	).Debugf(ctx, "checked if %s is known to pkgsite", endpoint)
	if err != nil {
		return false, err
	}

	known := res.StatusCode == http.StatusOK
	pc.cache.add(endpoint, known)
	return known, nil
}
// URL returns the base URL that the client was created with.
func (pc *Client) URL() string {
return pc.url
}
func readKnown(r io.Reader) (map[string]bool, error) {
b, err := io.ReadAll(r)
if err != nil {
return nil, err
}
if len(b) == 0 {
return nil, fmt.Errorf("no data")
}
seen := make(map[string]bool)
if err := json.Unmarshal(b, &seen); err != nil {
return nil, err
}
return seen, nil
}
// writeKnown writes the cache's endpoint results to w as an indented
// JSON object. The cache is locked for the duration of the call.
func (c *cache) writeKnown(w io.Writer) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	data, err := json.MarshalIndent(c.seen, "", " ")
	if err == nil {
		_, err = w.Write(data)
	}
	return err
}
// cacheFile opens, for reading and writing, the per-test cache file
// testdata/pkgsite/<test name>.json, creating its directory if needed.
// A missing or empty file is first seeded with an empty JSON object.
// The file is closed automatically when the test finishes. The result
// can be used as an input to testClient.
//
// For testing.
func cacheFile(t *testing.T) (*os.File, error) {
	name := filepath.Join("testdata", "pkgsite", t.Name()+".json")
	if err := os.MkdirAll(filepath.Dir(name), 0755); err != nil {
		return nil, err
	}

	// Seed the file with an empty map when it is absent or has no content.
	if info, err := os.Stat(name); err != nil || info.Size() == 0 {
		if err := os.WriteFile(name, []byte("{}\n"), 0644); err != nil {
			return nil, err
		}
	}

	file, err := os.OpenFile(name, os.O_RDWR, 0644)
	if err != nil {
		return nil, err
	}
	t.Cleanup(func() {
		if cerr := file.Close(); cerr != nil {
			t.Error(cerr)
		}
	})
	return file, nil
}
// TestClient returns a pkgsite client backed by the calling test's
// default cache file. Depending on useRealPkgsite, the client talks to
// either the real pkg.go.dev or a fake server.
//
// For testing.
func TestClient(t *testing.T, useRealPkgsite bool) (*Client, error) {
	file, err := cacheFile(t)
	if err != nil {
		return nil, err
	}
	client, err := testClient(t, useRealPkgsite, file)
	return client, err
}
// testClient returns a pkgsite client for use in tests.
//
// If useRealPkgsite is true, the client is the default one, and when the
// test finishes the contents of its cache are written to rw so they can
// be replayed later. If rw can be truncated and rewound (as an *os.File
// can), its previous contents are discarded first; otherwise a shorter
// result would leave stale bytes behind and corrupt the JSON.
//
// If useRealPkgsite is false, the known endpoints are read from rw and
// the client talks to a local fake server that answers 404 for every
// path not marked true in that data, and 200 otherwise. The server is
// shut down when the test finishes.
func testClient(t *testing.T, useRealPkgsite bool, rw io.ReadWriter) (*Client, error) {
	if useRealPkgsite {
		c := Default()
		t.Cleanup(func() {
			// Start from an empty destination when rw supports it.
			if f, ok := rw.(interface {
				io.Seeker
				Truncate(size int64) error
			}); ok {
				if err := f.Truncate(0); err != nil {
					t.Error(err)
					return
				}
				if _, err := f.Seek(0, io.SeekStart); err != nil {
					t.Error(err)
					return
				}
			}
			if err := c.cache.writeKnown(rw); err != nil {
				t.Error(err)
			}
		})
		return c, nil
	}

	known, err := readKnown(rw)
	if err != nil {
		return nil, fmt.Errorf("could not read known modules: %w", err)
	}
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Writing nothing yields an implicit 200 OK for known paths.
		if !known[r.URL.Path] {
			http.Error(w, "unknown", http.StatusNotFound)
		}
	}))
	t.Cleanup(s.Close)
	return New(s.URL), nil
}
// cache records which pkgsite endpoints have been looked up and whether
// each one was found.
type cache struct {
// mu is held by the cache's methods while they access seen and complete.
mu sync.Mutex
// Endpoints already seen, mapped to whether pkgsite knows them.
seen map[string]bool
// Does the cache contain all known endpoints? If so, an endpoint
// missing from seen is reported as unknown.
complete bool
}
// newCache returns an empty cache that is not marked complete.
func newCache() *cache {
	c := new(cache) // complete starts out false, its zero value
	c.seen = map[string]bool{}
	return c
}
// setKnownModules marks the module endpoint of every path in known as
// known, then flags the cache as complete.
func (c *cache) setKnownModules(known []string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for i := range known {
		endpoint := moduleEndpoint(known[i])
		c.seen[endpoint] = true
	}
	c.complete = true
}
// moduleEndpoint returns the endpoint for the module at path, which is
// path prefixed with "/mod/".
func moduleEndpoint(path string) string {
	const prefix = "/mod/"
	return prefix + path
}
// lookup consults the cache for endpoint. ok reports whether the cache
// can answer at all; when it can, known is the answer.
func (c *cache) lookup(endpoint string) (known bool, ok bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	found, cached := c.seen[endpoint]
	switch {
	case cached:
		// The endpoint has a recorded result.
		return found, true
	case c.complete:
		// Nothing recorded, but the cache lists every known endpoint,
		// so this one is not known.
		return false, true
	default:
		// The cache has nothing to say about this endpoint.
		return false, false
	}
}
// add records whether endpoint is known, replacing any earlier result
// stored for it.
func (c *cache) add(endpoint string, known bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.seen[endpoint] = known
}