blob: 34726f5a91e7ba8ad4ac64247d34945e91796659 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vuln
import (
"bytes"
"context"
"encoding/json"
"golang.org/x/tools/txtar"
vulnc "golang.org/x/vuln/client"
"golang.org/x/vuln/osv"
)
// NewInMemoryClient returns a Client, for use in tests, whose legacy and
// v1 clients are both backed by the given entries. It returns an error if
// the v1 in-memory source cannot be built from entries.
func NewInMemoryClient(entries []*osv.Entry) (*Client, error) {
	src, err := newInMemorySource(entries)
	if err != nil {
		return nil, err
	}
	legacy := newTestLegacyClient(entries)
	return &Client{
		legacy: legacy,
		v1:     &client{src},
	}, nil
}
// newTestClientFromTxtar builds an in-memory client for use in tests from
// the txtar archive at txtarFile. Every file in the archive must contain
// JSON following the v1 database schema. Each file's contents are stored
// in compacted form, keyed by the file's name. Parsing the archive or
// compacting any file's JSON aborts with that error.
func newTestClientFromTxtar(txtarFile string) (*client, error) {
	archive, err := txtar.ParseFile(txtarFile)
	if err != nil {
		return nil, err
	}
	contents := make(map[string][]byte)
	for _, file := range archive.Files {
		compacted, cerr := removeWhitespace(file.Data)
		if cerr != nil {
			return nil, cerr
		}
		contents[file.Name] = compacted
	}
	return &client{&inMemorySource{data: contents}}, nil
}
// removeWhitespace returns data with insignificant JSON whitespace
// stripped, as done by json.Compact. If data is not valid JSON it returns
// a nil slice together with the error from json.Compact.
func removeWhitespace(data []byte) ([]byte, error) {
	compacted := new(bytes.Buffer)
	if err := json.Compact(compacted, data); err != nil {
		return nil, err
	}
	return compacted.Bytes(), nil
}
// newTestLegacyClient wraps entries in a legacyClient backed by a
// testVulnClient. It precomputes two indexes over entries:
//   - alias -> IDs of the entries that list that alias, and
//   - affected package name -> entries that affect it.
//
// Both indexes preserve the order of entries. An entry with several
// Affected items naming the same package is recorded once per item.
func newTestLegacyClient(entries []*osv.Entry) *legacyClient {
	byAlias := map[string][]string{}
	byModule := map[string][]*osv.Entry{}
	for _, entry := range entries {
		for _, alias := range entry.Aliases {
			byAlias[alias] = append(byAlias[alias], entry.ID)
		}
		for _, aff := range entry.Affected {
			name := aff.Package.Name
			byModule[name] = append(byModule[name], entry)
		}
	}
	return &legacyClient{&testVulnClient{
		entries:          entries,
		aliasToIDs:       byAlias,
		modulesToEntries: byModule,
	}}
}
// testVulnClient is an in-memory implementation of x/vuln's Client
// interface for use in tests. It serves the entries it was built with.
//
// NOTE(review): newTestLegacyClient leaves the embedded vulnc.Client nil.
// Any interface method not defined on testVulnClient (here or elsewhere in
// the package) is therefore promoted from a nil interface and would panic
// if called. Presumably tests only exercise the methods defined below.
type testVulnClient struct {
	vulnc.Client
	// entries holds every entry the client serves, in the order supplied.
	entries []*osv.Entry
	// aliasToIDs maps an alias to the IDs of the entries that list it.
	aliasToIDs map[string][]string
	// modulesToEntries maps an affected package name to the entries affecting it.
	modulesToEntries map[string][]*osv.Entry
}
// GetByModule returns the entries indexed under module, or nil if there
// are none. It never returns an error.
func (c *testVulnClient) GetByModule(_ context.Context, module string) ([]*osv.Entry, error) {
	matches, ok := c.modulesToEntries[module]
	if !ok {
		return nil, nil
	}
	return matches, nil
}
// GetByID returns the first entry whose ID equals id. If no entry matches,
// it returns (nil, nil) rather than an error.
func (c *testVulnClient) GetByID(_ context.Context, id string) (*osv.Entry, error) {
	for i := range c.entries {
		if candidate := c.entries[i]; candidate.ID == id {
			return candidate, nil
		}
	}
	return nil, nil
}
// ListIDs returns the ID of every entry, in the order the entries were
// supplied. The result is nil when there are no entries; the error is
// always nil.
func (c *testVulnClient) ListIDs(context.Context) ([]string, error) {
	var out []string
	for i := range c.entries {
		out = append(out, c.entries[i].ID)
	}
	return out, nil
}
// GetByAlias returns the entries recorded under alias in aliasToIDs. Each
// entry is resolved through GetByID, in the order the IDs were recorded.
// It returns (nil, nil) when the alias is unknown. Any error from GetByID
// is propagated rather than silently discarded.
func (c *testVulnClient) GetByAlias(ctx context.Context, alias string) ([]*osv.Entry, error) {
	ids := c.aliasToIDs[alias]
	if len(ids) == 0 {
		return nil, nil
	}
	es := make([]*osv.Entry, 0, len(ids))
	for _, id := range ids {
		e, err := c.GetByID(ctx, id)
		if err != nil {
			return nil, err
		}
		es = append(es, e)
	}
	return es, nil
}