blob: e03e7145a38c266f037776892f7bf1721ae53653 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vuln
import (
"bytes"
"context"
"encoding/json"
"path/filepath"
"sort"
"strings"
"golang.org/x/exp/slices"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/osv"
"golang.org/x/pkgsite/internal/stdlib"
"golang.org/x/sync/errgroup"
)
// Client reads Go vulnerability databases.
type Client struct {
src source
}
// NewClient returns a client that can read from the vulnerability
// database in src (a URL representing either a http or file source).
func NewClient(src string) (*Client, error) {
s, err := NewSource(src)
if err != nil {
return nil, err
}
return &Client{src: s}, nil
}
// NewInMemoryClient creates an in-memory vulnerability client for use
// in tests.
func NewInMemoryClient(entries []*osv.Entry) (*Client, error) {
inMemory, err := newInMemorySource(entries)
if err != nil {
return nil, err
}
return &Client{src: inMemory}, nil
}
type PackageRequest struct {
// Module is the module path to filter on.
// ByPackage will only return entries that affect this module.
// This must be set (if empty, ByPackage will always return nil).
Module string
// The package path to filter on.
// ByPackage will only return entries that affect this package.
// If empty, ByPackage will not filter based on the package.
Package string
// The version to filter on.
// ByPackage will only return entries affected at this module
// version.
// If empty, ByPackage will not filter based on version.
Version string
}
// ByPackage returns the OSV entries matching the package request.
func (c *Client) ByPackage(ctx context.Context, req *PackageRequest) (_ []*osv.Entry, err error) {
derrors.Wrap(&err, "ByPackage(%v)", req)
// Find the metadata for the module with the given module path.
ms, err := c.modulesFilter(ctx, func(m *ModuleMeta) bool {
return m.Path == req.Module
}, 1)
if err != nil {
return nil, err
}
if len(ms) == 0 {
return nil, nil
}
// Figure out which vulns we actually need to download.
var ids []string
for _, v := range ms[0].Vulns {
// We need to download the full entry if there is no fix,
// or the requested version is less than the vuln's
// highest fixed version.
if v.Fixed == "" || osv.LessSemver(req.Version, v.Fixed) {
ids = append(ids, v.ID)
}
}
if len(ids) == 0 {
return nil, nil
}
return c.byIDsFilter(ctx, ids, func(e *osv.Entry) bool {
return isAffected(e, req)
})
}
func (c *Client) modulesFilter(ctx context.Context, filter func(*ModuleMeta) bool, n int) ([]*ModuleMeta, error) {
if n == 0 {
return nil, nil
}
b, err := c.modules(ctx)
if err != nil {
return nil, err
}
dec, err := newStreamDecoder(b)
if err != nil {
return nil, err
}
ms := make([]*ModuleMeta, 0)
for dec.More() {
var m ModuleMeta
err := dec.Decode(&m)
if err != nil {
return nil, err
}
if filter(&m) {
ms = append(ms, &m)
if len(ms) == n {
return ms, nil
}
}
}
if len(ms) == 0 {
return nil, nil
}
return ms, nil
}
func isAffected(e *osv.Entry, req *PackageRequest) bool {
for _, a := range e.Affected {
if a.Module.Path != req.Module || !osv.AffectsSemver(a.Ranges, req.Version) {
continue
}
if packageMatches := func() bool {
if req.Package == "" {
return true // match module only
}
if len(a.EcosystemSpecific.Packages) == 0 {
return true // no package info available, so match on module
}
for _, p := range a.EcosystemSpecific.Packages {
if req.Package == p.Path {
return true // package matches
}
}
return false
}(); !packageMatches {
continue
}
return true
}
return false
}
// ByID returns the OSV entry with the given ID or (nil, nil)
// if there isn't one.
func (c *Client) ByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
derrors.Wrap(&err, "ByID(%s)", id)
b, err := c.entry(ctx, id)
if err != nil {
// entry only fails if the entry is not found, so do not return
// the error.
return nil, nil
}
var entry osv.Entry
if err := json.Unmarshal(b, &entry); err != nil {
return nil, err
}
return &entry, nil
}
// ByAlias returns the Go ID of the OSV entry that has the given alias,
// or a NotFound error if there isn't one.
func (c *Client) ByAlias(ctx context.Context, alias string) (_ string, err error) {
derrors.Wrap(&err, "ByAlias(%s)", alias)
b, err := c.vulns(ctx)
if err != nil {
return "", err
}
dec, err := newStreamDecoder(b)
if err != nil {
return "", err
}
for dec.More() {
var v VulnMeta
err := dec.Decode(&v)
if err != nil {
return "", err
}
for _, vAlias := range v.Aliases {
if alias == vAlias {
return v.ID, nil
}
}
}
return "", derrors.NotFound
}
// Entries returns all entries in the database, sorted in descending
// order by Go ID (most recent to least recent).
// If n >= 0, only the n most recent entries are returned.
func (c *Client) Entries(ctx context.Context, n int) (_ []*osv.Entry, err error) {
derrors.Wrap(&err, "Entries(n=%d)", n)
if n == 0 {
return nil, nil
}
ids, err := c.IDs(ctx)
if err != nil {
return nil, err
}
sortIDs(ids)
if n >= 0 && len(ids) > n {
ids = ids[:n]
}
return c.byIDs(ctx, ids)
}
func sortIDs(ids []string) {
sort.Slice(ids, func(i, j int) bool { return ids[i] > ids[j] })
}
// ByPackagePrefix returns all the OSV entries that match the given
// package prefix, in descending order by ID, or (nil, nil) if there
// are none.
//
// An entry matches a prefix if:
// - Any affected module or package equals the given prefix, OR
// - Any affected module or package's path begins with the given prefix
// interpreted as a full path. (E.g. "example.com/module/package" matches
// the prefix "example.com/module" but not "example.com/mod")
func (c *Client) ByPackagePrefix(ctx context.Context, prefix string) (_ []*osv.Entry, err error) {
derrors.Wrap(&err, "ByPackagePrefix(%s)", prefix)
prefix = strings.TrimSuffix(prefix, "/")
prefixPath := prefix + "/"
prefixMatch := func(s string) bool {
return s == prefix || strings.HasPrefix(s, prefixPath)
}
moduleMatch := func(m *ModuleMeta) bool {
// If the prefix possibly refers to a standard library package,
// always look at the stdlib and toolchain modules.
if stdlib.Contains(prefix) &&
(m.Path == osv.GoStdModulePath || m.Path == osv.GoCmdModulePath) {
return true
}
// Look at the module if it is either prefixed by the prefix,
// or it is itself a prefix of the prefix.
// (The latter case catches queries that are prefixes of the package
// path but longer than the module path).
return prefixMatch(m.Path) || strings.HasPrefix(prefix, m.Path)
}
entryMatch := func(e *osv.Entry) bool {
for _, aff := range e.Affected {
if prefixMatch(aff.Module.Path) {
return true
}
for _, pkg := range aff.EcosystemSpecific.Packages {
if prefixMatch(pkg.Path) {
return true
}
}
}
return false
}
ms, err := c.modulesFilter(ctx, moduleMatch, -1)
if err != nil {
return nil, err
}
if len(ms) == 0 {
return nil, nil
}
var ids []string
for _, m := range ms {
for _, vs := range m.Vulns {
ids = append(ids, vs.ID)
}
}
sortIDs(ids)
// Remove any duplicates.
ids = slices.Compact(ids)
return c.byIDsFilter(ctx, ids, entryMatch)
}
func (c *Client) byIDsFilter(ctx context.Context, ids []string, filter func(*osv.Entry) bool) (_ []*osv.Entry, err error) {
entries, err := c.byIDs(ctx, ids)
if err != nil {
return nil, err
}
var filtered []*osv.Entry
for _, entry := range entries {
if filter(entry) {
filtered = append(filtered, entry)
}
}
if len(filtered) == 0 {
return nil, nil
}
return filtered, nil
}
func (c *Client) byIDs(ctx context.Context, ids []string) (_ []*osv.Entry, err error) {
entries := make([]*osv.Entry, len(ids))
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(10)
for i, id := range ids {
i, id := i, id
g.Go(func() error {
e, err := c.ByID(gctx, id)
if err != nil {
return err
}
entries[i] = e
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return entries, nil
}
// IDs returns a list of the IDs of all the entries in the database.
func (c *Client) IDs(ctx context.Context) (_ []string, err error) {
b, err := c.vulns(ctx)
if err != nil {
return nil, err
}
dec, err := newStreamDecoder(b)
if err != nil {
return nil, err
}
var ids []string
for dec.More() {
var v VulnMeta
err := dec.Decode(&v)
if err != nil {
return nil, err
}
ids = append(ids, v.ID)
}
return ids, nil
}
// newStreamDecoder returns a decoder that can be used
// to read an array of JSON objects.
func newStreamDecoder(b []byte) (*json.Decoder, error) {
dec := json.NewDecoder(bytes.NewBuffer(b))
// skip open bracket
_, err := dec.Token()
if err != nil {
return nil, err
}
return dec, nil
}
func (c *Client) modules(ctx context.Context) ([]byte, error) {
return c.src.get(ctx, modulesEndpoint)
}
func (c *Client) vulns(ctx context.Context) ([]byte, error) {
return c.src.get(ctx, vulnsEndpoint)
}
func (c *Client) entry(ctx context.Context, id string) ([]byte, error) {
return c.src.get(ctx, filepath.Join(idDir, id))
}