blob: b11f8e2033117d390fd028a16f6d65b0ea65b973 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package client
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"golang.org/x/vuln/internal"
"golang.org/x/vuln/internal/web"
"golang.org/x/vuln/osv"
)
// testCache is an in-memory cache used only by the tests in this file.
// All state lives in plain maps; nothing is persisted.
type testCache struct {
// indexMap holds the last index written for each db name.
indexMap map[string]DBIndex
// indexStamp holds the timestamp written alongside each db's index.
indexStamp map[string]time.Time
// vulnMap holds cached entries, keyed first by db name and then by module path.
vulnMap map[string]map[string][]*osv.Entry
}
// newTestCache returns an empty testCache whose maps are all initialized,
// so it can be written to immediately.
func newTestCache() *testCache {
	tc := new(testCache)
	tc.indexMap = map[string]DBIndex{}
	tc.indexStamp = map[string]time.Time{}
	tc.vulnMap = map[string]map[string][]*osv.Entry{}
	return tc
}
// ReadIndex returns the index and timestamp stored for db.
// If either one is missing, it reports a nil index, the zero time and no error.
func (tc *testCache) ReadIndex(db string) (DBIndex, time.Time, error) {
	idx, haveIndex := tc.indexMap[db]
	written, haveStamp := tc.indexStamp[db]
	if haveIndex && haveStamp {
		return idx, written, nil
	}
	return nil, time.Time{}, nil
}
// WriteIndex records index and stamp for db, replacing any previous values.
// It never fails.
func (tc *testCache) WriteIndex(db string, index DBIndex, stamp time.Time) error {
	// The two maps are independent; both must be set for ReadIndex to
	// report a hit.
	tc.indexStamp[db] = stamp
	tc.indexMap[db] = index
	return nil
}
// ReadEntries returns the entries cached for module in db.
// It returns nil entries and no error when nothing is cached for them.
func (tc *testCache) ReadEntries(db, module string) ([]*osv.Entry, error) {
	if byModule, found := tc.vulnMap[db]; found {
		return byModule[module], nil
	}
	return nil, nil
}
// WriteEntries replaces the entries cached for module in db.
//
// Each entry is copied and the copy's Details is prefixed with "cached: ",
// which lets tests tell cached results apart from freshly fetched ones
// without modifying the caller's entries.
func (tc *testCache) WriteEntries(db, module string, entries []*osv.Entry) error {
	if _, ok := tc.vulnMap[db]; !ok {
		tc.vulnMap[db] = make(map[string][]*osv.Entry)
	}
	byModule := tc.vulnMap[db]

	// Drop whatever was cached before; the key stays present even when
	// entries is empty.
	byModule[module] = nil
	for i := range entries {
		marked := *entries[i]
		marked.Details = "cached: " + marked.Details
		byModule[module] = append(byModule[module], &marked)
	}
	return nil
}
// newTestServer starts an HTTP test server that serves the contents of
// testdata/vulndb from its root path. The caller must Close the server.
func newTestServer() *httptest.Server {
	files := http.FileServer(http.Dir("testdata/vulndb"))
	router := http.NewServeMux()
	router.Handle("/", files)
	return httptest.NewServer(router)
}
// localURL is the URL form of the absolute path to testdata/vulndb, used as
// the source for the file-based client in the tests below. It is computed
// once at package initialization and panics if the path cannot be resolved
// or converted.
var localURL = func() string {
	const dbDir = "testdata/vulndb"
	must := func(err error) {
		if err != nil {
			panic(fmt.Sprintf("failed to read testdata/vulndb: %v", err))
		}
	}

	abs, err := filepath.Abs(dbDir)
	must(err)
	fileURL, err := web.URLFromFilePath(abs)
	must(err)
	return fileURL.String()
}()
// TestByModule checks GetByModule for the HTTP and file sources, with and
// without an HTTP cache, and for a module path that is entirely lowercase.
// Each case verifies the number of entries returned and the prefix of the
// first entry's Details.
func TestByModule(t *testing.T) {
	if runtime.GOOS == "js" {
		t.Skip("skipping test: no network on js")
	}

	const (
		modulePath           = "github.com/BeeGo/beego"
		detailStart          = "Routes in the beego HTTP router"
		modulePathLowercase  = "github.com/tidwall/gjson"
		detailStartLowercase = "Due to improper bounds checking"
	)

	// Create a local http database.
	srv := newTestServer()
	defer srv.Close()

	ctx := context.Background()

	type byModuleCase struct {
		name         string
		source       string
		module       string
		cache        Cache
		detailPrefix string
		wantVulns    int
	}
	cases := []byModuleCase{
		// Test the http client without any cache.
		{
			name: "http-no-cache", source: srv.URL, module: modulePath,
			cache: nil, detailPrefix: detailStart, wantVulns: 3,
		},
		// With a cache, the second lookup is served from testCache, which
		// marks Details with a "cached: " prefix.
		{
			name: "http-cache", source: srv.URL, module: modulePath,
			cache: newTestCache(), detailPrefix: "cached: Route", wantVulns: 3,
		},
		// Repeat the same for local file client.
		{
			name: "file-no-cache", source: localURL, module: modulePath,
			cache: nil, detailPrefix: detailStart, wantVulns: 3,
		},
		// Cache does not play a role in local file databases.
		{
			name: "file-cache", source: localURL, module: modulePath,
			cache: newTestCache(), detailPrefix: detailStart, wantVulns: 3,
		},
		// Test all-lowercase module path.
		{
			name: "lower-http", source: srv.URL, module: modulePathLowercase,
			cache: nil, detailPrefix: detailStartLowercase, wantVulns: 4,
		},
		{
			name: "lower-file", source: localURL, module: modulePathLowercase,
			cache: nil, detailPrefix: detailStartLowercase, wantVulns: 4,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			c, err := NewClient([]string{tc.source}, Options{HTTPCache: tc.cache})
			if err != nil {
				t.Fatal(err)
			}

			// First call fills cache, if present.
			_, err = c.GetByModule(ctx, tc.module)
			if err != nil {
				t.Fatal(err)
			}

			vulns, err := c.GetByModule(ctx, tc.module)
			if err != nil {
				t.Fatal(err)
			}
			if n := len(vulns); n != tc.wantVulns {
				t.Fatalf("got %d vulns for %s, want %d", n, tc.module, tc.wantVulns)
			}

			details := vulns[0].Details
			if strings.HasPrefix(details, tc.detailPrefix) {
				return
			}
			// Trim the reported text so a failure shows only the part being compared.
			if len(details) > len(tc.detailPrefix) {
				details = details[:len(tc.detailPrefix)] + "..."
			}
			t.Errorf("got\n\t%s\nbut should start with\n\t%s", details, tc.detailPrefix)
		})
	}
}
// TestMustUseIndex checks that httpSource in NewClient(...)
//   - always calls Index function before making an http
//     request in GetByModule.
//   - if an http request was made, then the module path
//     must be in the index.
//
// This test serves as an approximate mechanism to make sure
// that we only send module info to the vuln db server if the
// module has known vulnerabilities. Otherwise, we might send
// unknown private module information to the db.
func TestMustUseIndex(t *testing.T) {
	if runtime.GOOS == "js" {
		t.Skip("skipping test: no network on js")
	}

	// Create a local http database.
	srv := newTestServer()
	defer srv.Close()

	ctx := context.Background()

	// List of modules to query, some are repeated to exercise cache hits.
	modulePaths := []string{
		"github.com/BeeGo/beego",
		"github.com/tidwall/gjson",
		"net/http",
		"abc.xyz",
		"github.com/BeeGo/beego",
	}

	// Run once with a cache and once without.
	for _, cache := range []Cache{newTestCache(), nil} {
		cached := cache != nil

		clt, err := NewClient([]string{srv.URL}, Options{HTTPCache: cache})
		if err != nil {
			t.Fatal(err)
		}
		// Reach into the client to observe the source's call counters.
		hs := clt.(*client).sources[0].(*httpSource)

		for _, modulePath := range modulePaths {
			indexBefore, httpBefore := hs.indexCalls, hs.httpCalls

			if _, err := clt.GetByModule(ctx, modulePath); err != nil {
				t.Fatal(err)
			}

			// Number of index Calls should be increased.
			if hs.indexCalls == indexBefore {
				t.Errorf("GetByModule(ctx, %s) [cache:%t] did not call Index(...)", modulePath, cached)
			}

			// Nothing more to verify unless an http request was made.
			if hs.httpCalls <= httpBefore {
				continue
			}

			// If http request was made, then the modulePath must be in the index.
			index, err := hs.Index(ctx)
			if err != nil {
				t.Fatal(err)
			}
			if _, present := index[modulePath]; !present {
				t.Errorf("GetByModule(ctx, %s) [cache:%t] issued http request for module not in Index(...)", modulePath, cached)
			}
		}
	}
}
// TestSpecialPaths checks that GetByModule succeeds for every key of
// specialCaseModulePaths, against both the file source and the HTTP source.
func TestSpecialPaths(t *testing.T) {
	if runtime.GOOS == "js" {
		t.Skip("skipping test: no network on js")
	}

	srv := newTestServer()
	defer srv.Close()

	ctx := context.Background()

	sources := []struct {
		name   string
		source string
	}{
		{name: "local", source: localURL},
		{name: "http", source: srv.URL},
	}

	for _, src := range sources {
		t.Run(src.name, func(t *testing.T) {
			c, err := NewClient([]string{src.source}, Options{})
			if err != nil {
				t.Fatal(err)
			}
			// One nested subtest per special path so failures are reported individually.
			for specialPath := range specialCaseModulePaths {
				subName := src.name + "-" + specialPath
				t.Run(subName, func(t *testing.T) {
					_, err := c.GetByModule(ctx, specialPath)
					if err != nil {
						t.Fatal(err)
					}
				})
			}
		})
	}
}
// TestCorrectFetchesNoCache counts the requests an httpSource without a cache
// sends to the server. It expects one index fetch per GetByModule call, and a
// module fetch only for the modules listed in the index.
func TestCorrectFetchesNoCache(t *testing.T) {
	// counts records how many times each URL path was requested.
	counts := make(map[string]int)

	handler := func(w http.ResponseWriter, r *http.Request) {
		reqPath := r.URL.Path
		counts[reqPath]++

		// Every module file is an empty list; only the index has content.
		body := []byte("[]")
		if reqPath == "/index.json" {
			body, _ = json.Marshal(DBIndex{
				"m.com/a": time.Now(),
				"m.com/b": time.Now(),
			})
		}
		w.Write(body)
	}
	ts := httptest.NewServer(http.HandlerFunc(handler))
	defer ts.Close()

	hs := &httpSource{url: ts.URL, c: &http.Client{}}

	// m.com/c is not in the index, so it must not trigger a module fetch.
	queried := []string{"m.com/a", "m.com/b", "m.com/c"}
	for i := range queried {
		_, err := hs.GetByModule(context.Background(), queried[i])
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
	}

	expectedFetches := map[string]int{
		"/index.json":   3,
		"/m.com/a.json": 1,
		"/m.com/b.json": 1,
	}
	if reflect.DeepEqual(counts, expectedFetches) {
		return
	}
	t.Errorf("unexpected fetches, got %v, want %v", counts, expectedFetches)
}
// TestCorrectFetchesNoChangeIndex makes sure that a cached index is used in
// the case it is stale but there were no changes to it at the server side.
//
// The server answers /index.json with 304 Not Modified and every other path
// with an empty 200 response. The test pre-populates the cache and expects
// GetByModule to return the cached entry, recognizable by the "cached: "
// prefix that testCache.WriteEntries adds to Details.
func TestCorrectFetchesNoChangeIndex(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/index.json" {
			w.WriteHeader(http.StatusNotModified)
		}
	}))
	defer ts.Close()

	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// The cache is populated under the server's hostname.
	dbName := u.Hostname()

	// set timestamp so that cached index is stale,
	// i.e., more than two hours old.
	timeStamp := time.Now().Add(-3 * time.Hour)
	index := DBIndex{"a": timeStamp}
	cache := newTestCache()
	if err := cache.WriteIndex(dbName, index, timeStamp); err != nil {
		t.Fatal(err)
	}
	e := &osv.Entry{
		ID:       "ID1",
		Details:  "details",
		Modified: timeStamp,
	}
	if err := cache.WriteEntries(dbName, "a", []*osv.Entry{e}); err != nil {
		t.Fatal(err)
	}

	client, err := NewClient([]string{ts.URL}, Options{HTTPCache: cache})
	if err != nil {
		t.Fatal(err)
	}
	gots, err := client.GetByModule(context.Background(), "a")
	if err != nil {
		t.Fatal(err)
	}
	if len(gots) != 1 {
		t.Fatalf("got %d vulns, want 1", len(gots))
	}

	// The expected entry is the cached copy, not the one passed to WriteEntries.
	want := *e
	want.Details = "cached: " + want.Details
	if got := gots[0]; !cmp.Equal(got, &want) {
		t.Errorf("\ngot %+v\nwant %+v", got, &want)
	}
}
// TestClientByID checks GetByID against the HTTP and file sources, both for
// an ID that exists in testdata/vulndb and for one that does not.
func TestClientByID(t *testing.T) {
	if runtime.GOOS == "js" {
		t.Skip("skipping test: no network on js")
	}
	const vulnID = "GO-2022-0463"
	want := mustReadEntry(t, vulnID)
	srv := newTestServer()
	defer srv.Close()
	for _, test := range []struct {
		name   string
		source string
		in     string
		want   *osv.Entry
	}{
		{name: "http", in: vulnID, source: srv.URL, want: want},
		{name: "file", in: vulnID, source: localURL, want: want},
		// NOTE(review): these rows assume GetByID reports an unknown ID as a
		// nil entry with a nil error — confirm against the sources' GetByID.
		{name: "http-missing", in: "NO-SUCH-VULN", source: srv.URL, want: nil},
		{name: "file-missing", in: "NO-SUCH-VULN", source: localURL, want: nil},
	} {
		t.Run(test.name, func(t *testing.T) {
			client, err := NewClient([]string{test.source}, Options{})
			if err != nil {
				t.Fatal(err)
			}
			// Query the ID from the table row, not the outer constant, so the
			// missing-ID rows really exercise the not-found path.
			got, err := client.GetByID(context.Background(), test.in)
			if err != nil {
				t.Fatal(err)
			}
			if !cmp.Equal(got, test.want) {
				t.Errorf("GetByID(%q):\ngot\n%+v\nwant\n%+v", test.in, got, test.want)
			}
		})
	}
}
// mustReadEntry loads the entry with the given ID from the ID directory of
// testdata/vulndb and decodes it from JSON. Any read or decode error fails
// the calling test immediately.
func mustReadEntry(t *testing.T, vulnID string) *osv.Entry {
	t.Helper()

	entryPath := filepath.Join("testdata", "vulndb", internal.IDDirectory, vulnID+".json")
	raw, readErr := os.ReadFile(entryPath)
	if readErr != nil {
		t.Fatal(readErr)
	}

	var entry *osv.Entry
	decodeErr := json.Unmarshal(raw, &entry)
	if decodeErr != nil {
		t.Fatal(decodeErr)
	}
	return entry
}
// TestClientByAlias checks that GetByAlias returns the two expected entries
// for a CVE alias, using both the HTTP and the file source.
func TestClientByAlias(t *testing.T) {
	if runtime.GOOS == "js" {
		t.Skip("skipping test: no network on js")
	}

	const alias = "CVE-2015-5739"
	want := []*osv.Entry{
		mustReadEntry(t, "GO-2021-0157"),
		mustReadEntry(t, "GO-2021-0159"),
	}

	srv := newTestServer()
	defer srv.Close()

	sources := []struct {
		name   string
		source string
	}{
		{name: "http", source: srv.URL},
		{name: "file", source: localURL},
	}
	for _, src := range sources {
		t.Run(src.name, func(t *testing.T) {
			c, err := NewClient([]string{src.source}, Options{})
			if err != nil {
				t.Fatal(err)
			}
			got, err := c.GetByAlias(context.Background(), alias)
			if err != nil {
				t.Fatal(err)
			}
			if cmp.Equal(got, want) {
				return
			}
			t.Errorf("got\n%+v\nwant\n%+v", got, want)
		})
	}
}
// TestListIDs checks that ListIDs returns the expected list of IDs, in order,
// from both the HTTP and the file source.
func TestListIDs(t *testing.T) {
	if runtime.GOOS == "js" {
		t.Skip("skipping test: no network on js")
	}

	srv := newTestServer()
	defer srv.Close()

	want := []string{
		"GO-2022-0463",
		"GO-2022-0569",
		"GO-2022-0572",
	}

	sources := []struct {
		name   string
		source string
	}{
		{name: "http", source: srv.URL},
		{name: "file", source: localURL},
	}
	for _, src := range sources {
		t.Run(src.name, func(t *testing.T) {
			c, err := NewClient([]string{src.source}, Options{})
			if err != nil {
				t.Fatal(err)
			}
			ids, err := c.ListIDs(context.Background())
			if err != nil {
				t.Fatal(err)
			}
			if cmp.Equal(ids, want) {
				return
			}
			t.Errorf("got\n%+v\nwant\n%+v", ids, want)
		})
	}
}
// TestLastModifiedTime checks that LastModifiedTime, for both the HTTP and
// the file source, matches the modification time of testdata/vulndb/index.json
// on disk. Both sides are truncated to whole seconds before comparing.
func TestLastModifiedTime(t *testing.T) {
	if runtime.GOOS == "js" {
		t.Skip("skipping test: no network on js")
	}

	srv := newTestServer()
	defer srv.Close()

	indexPath := filepath.Join("testdata", "vulndb", "index.json")
	info, err := os.Stat(indexPath)
	if err != nil {
		t.Fatal(err)
	}
	modTime := info.ModTime()
	want := modTime.Truncate(time.Second).In(time.UTC)

	sources := []struct {
		name   string
		source string
	}{
		{name: "http", source: srv.URL},
		{name: "file", source: localURL},
	}
	for _, src := range sources {
		t.Run(src.name, func(t *testing.T) {
			c, err := NewClient([]string{src.source}, Options{})
			if err != nil {
				t.Fatal(err)
			}
			reported, err := c.LastModifiedTime(context.Background())
			if err != nil {
				t.Fatal(err)
			}
			got := reported.Truncate(time.Second)
			if got.Equal(want) {
				return
			}
			t.Errorf("got %s, want %s", got, want)
		})
	}
}