blob: 5915da4cdae132eb050e0a56bdbb80f6ac32a3d1 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vuln
import (
"context"
vulnc "golang.org/x/vuln/client"
"golang.org/x/vuln/osv"
)
// NewTestClient returns a Client that serves the given entries from memory,
// for use in tests.
//
// Two lookup indexes are built up front:
//   - each alias maps to the IDs of the entries that list it, in entry order;
//   - each affected package name maps to the entries that affect it, in entry
//     order.
//
// The entries slice itself is retained, not copied.
func NewTestClient(entries []*osv.Entry) *Client {
	byAlias := make(map[string][]string)
	byModule := make(map[string][]*osv.Entry)
	for _, entry := range entries {
		for _, alias := range entry.Aliases {
			byAlias[alias] = append(byAlias[alias], entry.ID)
		}
		for _, aff := range entry.Affected {
			name := aff.Package.Name
			byModule[name] = append(byModule[name], entry)
		}
	}
	return &Client{c: &vulndbTestClient{
		entries:          entries,
		aliasToIDs:       byAlias,
		modulesToEntries: byModule,
	}}
}
// vulndbTestClient is the in-memory client constructed by NewTestClient.
//
// It embeds vulnc.Client so that it satisfies that interface. NewTestClient
// never sets the embedded value, so it is nil. Calling an interface method
// that vulndbTestClient does not override would therefore presumably panic.
// NOTE(review): confirm no test relies on un-overridden methods.
type vulndbTestClient struct {
	vulnc.Client

	// entries holds every entry, in the order passed to NewTestClient.
	entries []*osv.Entry

	// aliasToIDs maps an alias (from osv.Entry.Aliases) to the IDs of the
	// entries that list it.
	aliasToIDs map[string][]string

	// modulesToEntries maps an affected package name (from
	// osv.Entry.Affected[i].Package.Name) to the entries that list it.
	modulesToEntries map[string][]*osv.Entry
}
// GetByModule returns the entries whose Affected list names module, in the
// order they were passed to NewTestClient. It returns nil if there are none.
// The context is ignored and the error is always nil.
//
// The returned slice is a fresh copy. Callers that sort, truncate or append
// to it cannot corrupt the index that later calls read from.
func (c *vulndbTestClient) GetByModule(_ context.Context, module string) ([]*osv.Entry, error) {
	// Appending zero elements onto a nil slice yields nil, so unknown modules
	// still produce a nil result, exactly as before.
	return append([]*osv.Entry(nil), c.modulesToEntries[module]...), nil
}
// GetByID scans the entries in order and returns the first one whose ID
// equals id. A missing ID is reported as (nil, nil), not as an error. The
// context is ignored.
func (c *vulndbTestClient) GetByID(_ context.Context, id string) (*osv.Entry, error) {
	for i := range c.entries {
		if entry := c.entries[i]; entry.ID == id {
			return entry, nil
		}
	}
	// Not found: signalled by a nil entry rather than an error.
	return nil, nil
}
// ListIDs reports the ID of every entry, in the order the entries were passed
// to NewTestClient. It returns a nil slice when the client holds no entries.
// The error is always nil.
func (c *vulndbTestClient) ListIDs(context.Context) ([]string, error) {
	if len(c.entries) == 0 {
		// Keep the nil (not empty) result for an empty client.
		return nil, nil
	}
	ids := make([]string, len(c.entries))
	for i, e := range c.entries {
		ids[i] = e.ID
	}
	return ids, nil
}
// GetByAlias returns the entries that list alias among their Aliases, in the
// order they were passed to NewTestClient. It returns nil if the alias is
// unknown.
//
// Each ID is resolved through GetByID. Any error from that lookup is returned
// to the caller instead of being discarded. IDs that fail to resolve are
// skipped, so the result never contains a nil *osv.Entry.
func (c *vulndbTestClient) GetByAlias(ctx context.Context, alias string) ([]*osv.Entry, error) {
	ids := c.aliasToIDs[alias]
	if len(ids) == 0 {
		return nil, nil
	}
	es := make([]*osv.Entry, 0, len(ids))
	for _, id := range ids {
		e, err := c.GetByID(ctx, id)
		if err != nil {
			return nil, err
		}
		// aliasToIDs is built from entries, so e should always be found.
		// Skip defensively rather than hand callers a nil entry.
		if e == nil {
			continue
		}
		es = append(es, e)
	}
	return es, nil
}