blob: 09c8f11025a2d9c271ddc92e0d6a913d762ba8ed [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.18
// +build go1.18
package main
import (
"fmt"
"os"
"path/filepath"
"reflect"
"testing"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/vuln/client"
"golang.org/x/vuln/osv"
)
// TestCache exercises a single-database round trip through fsCache. It
// checks that:
//   - reads from an empty cache succeed and return nothing, both before the
//     database directory exists and after it is created;
//   - an index written with WriteIndex is returned unchanged by ReadIndex,
//     together with the retrieval time passed to WriteIndex;
//   - entries written with WriteEntries are returned unchanged by ReadEntries.
func TestCache(t *testing.T) {
	tmpDir := t.TempDir()
	cache := &fsCache{rootDir: tmpDir}
	dbName := "vulndb.golang.org"
	pkg := "vuln.example.com"

	// The database directory does not exist yet: reading must not fail and
	// must not yield any index data.
	idx, _, err := cache.ReadIndex(dbName)
	if err != nil {
		t.Fatalf("ReadIndex failed for non-existent database: %v", err)
	}
	if len(idx) != 0 {
		t.Errorf("ReadIndex returned index for non-existent database: %v", idx)
	}

	// The directory exists but holds no index file yet.
	if err = os.Mkdir(filepath.Join(tmpDir, dbName), 0777); err != nil {
		t.Fatalf("os.Mkdir failed: %v", err)
	}
	idx, _, err = cache.ReadIndex(dbName)
	if err != nil {
		t.Fatalf("ReadIndex failed for database without cached index: %v", err)
	}
	if len(idx) != 0 {
		t.Errorf("ReadIndex returned index for database without cached index: %v", idx)
	}

	// Round-trip an index and its retrieval time.
	now := time.Now()
	expectedIdx := client.DBIndex{
		"a.vuln.example.com": time.Time{}.Add(time.Hour),
		"b.vuln.example.com": time.Time{}.Add(time.Hour * 2),
		"c.vuln.example.com": time.Time{}.Add(time.Hour * 3),
	}
	if err = cache.WriteIndex(dbName, expectedIdx, now); err != nil {
		t.Fatalf("WriteIndex failed to write index: %v", err)
	}
	idx, retrieved, err := cache.ReadIndex(dbName)
	if err != nil {
		t.Fatalf("ReadIndex failed for database with cached index: %v", err)
	}
	if !reflect.DeepEqual(idx, expectedIdx) {
		t.Errorf("ReadIndex returned unexpected index, got:\n%v\nwant:\n%v", idx, expectedIdx)
	}
	// Equal rather than == because a time read back from the cache carries
	// no monotonic clock reading, while now does.
	if !retrieved.Equal(now) {
		t.Errorf("ReadIndex returned unexpected retrieved: got %s, want %s", retrieved, now)
	}

	// No entries have been cached for pkg yet.
	entries, err := cache.ReadEntries(dbName, pkg)
	if err != nil {
		t.Fatalf("ReadEntries failed for non-existent package: %v", err)
	}
	if len(entries) != 0 {
		t.Errorf("ReadEntries returned entries for non-existent package: %v", entries)
	}

	// Round-trip a set of entries for pkg.
	expectedEntries := []*osv.Entry{
		{ID: "001"},
		{ID: "002"},
		{ID: "003"},
	}
	if err = cache.WriteEntries(dbName, pkg, expectedEntries); err != nil {
		t.Fatalf("WriteEntries failed: %v", err)
	}
	entries, err = cache.ReadEntries(dbName, pkg)
	if err != nil {
		t.Fatalf("ReadEntries failed for cached package: %v", err)
	}
	if !reflect.DeepEqual(entries, expectedEntries) {
		t.Errorf("ReadEntries returned unexpected entries, got:\n%v\nwant:\n%v", entries, expectedEntries)
	}
}
// TestConcurrency drives a single fsCache from many goroutines at once. It
// runs two phases:
//   - phase 1 performs concurrent read-modify-write cycles on per-package
//     entries;
//   - phase 2 does the same on the shared database index.
//
// Each read-modify-write cycle is not atomic, so concurrent updates may
// overwrite each other. The checks therefore assert only properties that hold
// no matter which write wins:
//   - no operation fails;
//   - every package holds only entries carrying its own ID;
//   - no index timestamp predates the start of phase 2.
func TestConcurrency(t *testing.T) {
	tmpDir := t.TempDir()
	cache := &fsCache{rootDir: tmpDir}
	dbName := "vulndb.golang.org"
	const (
		workers  = 1000
		packages = 5
	)

	// Phase 1: goroutine i appends an entry with ID i%packages to the
	// entries cached for "package<i%packages>".
	g := new(errgroup.Group)
	for i := 0; i < workers; i++ {
		i := i // per-goroutine copy (pre-Go 1.22 loop semantics)
		g.Go(func() error {
			id := i % packages
			p := fmt.Sprintf("package%d", id)
			entries, err := cache.ReadEntries(dbName, p)
			if err != nil {
				return err
			}
			return cache.WriteEntries(dbName, p, append(entries, &osv.Entry{ID: fmt.Sprint(id)}))
		})
	}
	if err := g.Wait(); err != nil {
		t.Errorf("error in parallel cache entries read/write: %v", err)
	}

	// Sanity checking: every package was written at least once, and only
	// ever with entries carrying its own ID.
	for i := 0; i < packages; i++ {
		id := fmt.Sprint(i)
		p := fmt.Sprintf("package%s", id)
		es, err := cache.ReadEntries(dbName, p)
		if err != nil {
			t.Fatalf("failed to read entries: %v", err)
		}
		if len(es) == 0 {
			t.Errorf("no vuln entries cached for %s", p)
		}
		for _, e := range es {
			if e.ID != id {
				t.Errorf("want %s ID for vuln entry; got %s", id, e.ID)
			}
		}
	}

	// Phase 2: the same pattern on the shared index. A fresh group is used
	// because errgroup.Group keeps the first error it records. Reusing the
	// phase-1 group would re-report a phase-1 failure here and hide any
	// phase-2 failure.
	start := time.Now()
	g = new(errgroup.Group)
	for i := 0; i < workers; i++ {
		i := i // per-goroutine copy (pre-Go 1.22 loop semantics)
		g.Go(func() error {
			id := i % packages
			p := fmt.Sprintf("package%d", id)
			idx, _, err := cache.ReadIndex(dbName)
			if err != nil {
				return err
			}
			if idx == nil {
				idx = client.DBIndex{}
			}
			// Sanity checking: every timestamp in the index was produced by
			// time.Now() after start. It may compare equal to start on
			// platforms with a coarse wall clock, so only a strictly earlier
			// time is an error.
			if rt, ok := idx[p]; ok && rt.Before(start) {
				return fmt.Errorf("unexpected past time in index: %v is before start %v", rt, start)
			}
			now := time.Now()
			idx[p] = now
			return cache.WriteIndex(dbName, idx, now)
		})
	}
	if err := g.Wait(); err != nil {
		t.Errorf("error in parallel cache index read/write: %v", err)
	}
}