blob: 236c9d1b072dc5b0b4caf2c14fae3cb0b985afce [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package client provides an interface for accessing vulnerability
// databases, via either HTTP or local filesystem access.
//
// The protocol is described at https://go.dev/security/vuln/database.
package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
"strings"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/vuln/internal/derrors"
"golang.org/x/vuln/internal/osv"
isem "golang.org/x/vuln/internal/semver"
"golang.org/x/vuln/internal/web"
)
// Client reads a vulnerability database through an embedded source.
// Instances are created by NewClient (HTTP(S) or file URLs) or by
// NewInMemoryClient (a fixed set of OSV entries).
type Client struct {
	source
}
// Options holds optional settings for NewClient.
type Options struct {
	// HTTPClient is forwarded (via Options) to the HTTP source used for
	// http/https databases.
	// NOTE(review): nil/default handling lives in newHTTPSource — confirm there.
	HTTPClient *http.Client
}
// NewClient returns a Client that reads the vulnerability database at
// source, which must be an "http", "https" or "file" URL. Trailing slashes
// on source are ignored.
//
// Supported databases follow the API described in
// https://go.dev/security/vuln/database#api.
func NewClient(source string, opts *Options) (_ *Client, err error) {
	uri, err := url.Parse(strings.TrimRight(source, "/"))
	if err != nil {
		return nil, err
	}
	scheme := uri.Scheme
	if scheme == "http" || scheme == "https" {
		return newHTTPClient(uri, opts)
	}
	if scheme == "file" {
		return newLocalClient(uri)
	}
	return nil, fmt.Errorf("source %q has unsupported scheme", uri)
}
var (
	// errUnknownSchema is returned when a database source does not match
	// any layout this package knows how to read.
	errUnknownSchema = errors.New("unrecognized vulndb format; see https://go.dev/security/vuln/database#api for accepted schema")
)
// newHTTPClient returns a Client for the HTTP(S) database at uri.
//
// The database is accepted only if it is the canonical https://vuln.go.dev
// or it serves the v1 "index/modules.json.gz" endpoint. Otherwise
// errUnknownSchema is returned.
func newHTTPClient(uri *url.URL, opts *Options) (*Client, error) {
	base := uri.String()
	// The short-circuit skips the network probe for the canonical database.
	looksV1 := base == "https://vuln.go.dev" ||
		endpointExistsHTTP(base, "index/modules.json.gz")
	if !looksV1 {
		return nil, errUnknownSchema
	}
	return &Client{source: newHTTPSource(base, opts)}, nil
}
// endpointExistsHTTP reports whether a HEAD request for source + "/" +
// endpoint succeeds with status 200 OK. Any request error is treated as
// "does not exist".
//
// NOTE(review): http.Head uses http.DefaultClient (no timeout), not the
// client configured through Options.
func endpointExistsHTTP(source, endpoint string) bool {
	r, err := http.Head(source + "/" + endpoint)
	if err != nil {
		return false
	}
	// The response body must always be closed so the transport can release
	// or reuse the underlying connection.
	defer r.Body.Close()
	return r.StatusCode == http.StatusOK
}
// newLocalClient returns a Client for the database stored in the local
// directory named by the file URL uri.
//
// A directory containing the v1 "index/modules.json" endpoint is read as a
// v1 database. Any other directory is handed to newHybridSource, which
// tries to interpret it as a flat list of OSV files. If that fails,
// errUnknownSchema is returned.
func newLocalClient(uri *url.URL) (*Client, error) {
	dir, err := toDir(uri)
	if err != nil {
		return nil, err
	}
	if !endpointExistsDir(dir, modulesEndpoint+".json") {
		// Not a v1 database. The flat-OSV fallback is currently a "hidden"
		// feature, so its specific error is deliberately not surfaced.
		hybrid, err := newHybridSource(dir)
		if err != nil {
			return nil, errUnknownSchema
		}
		return &Client{source: hybrid}, nil
	}
	return &Client{source: newLocalSource(dir)}, nil
}
// toDir converts the file URL uri to a local filesystem path and checks
// that the path names an existing directory.
func toDir(uri *url.URL) (string, error) {
	dir, err := web.URLToFilePath(uri)
	if err != nil {
		return "", err
	}
	info, err := os.Stat(dir)
	switch {
	case err != nil:
		return "", err
	case !info.IsDir():
		return "", fmt.Errorf("%s is not a directory", dir)
	}
	return dir, nil
}
// endpointExistsDir reports whether dir/endpoint can be stat'ed, i.e.
// whether a file or directory exists at that path.
func endpointExistsDir(dir, endpoint string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, endpoint)); statErr != nil {
		return false
	}
	return true
}
// NewInMemoryClient returns a Client whose source is built from entries
// by newInMemorySource, rather than from an HTTP or file database.
// Any error from building that source is returned unchanged.
func NewInMemoryClient(entries []*osv.Entry) (*Client, error) {
	src, err := newInMemorySource(entries)
	if err != nil {
		return nil, err
	}
	c := &Client{source: src}
	return c, nil
}
// LastModifiedTime returns the Modified time recorded in the database's
// metadata document (fetched from dbEndpoint).
func (c *Client) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
	// derrors.Wrap only annotates an already non-nil *err, so it must be
	// deferred to see the error this function returns.
	defer derrors.Wrap(&err, "LastModifiedTime()")
	b, err := c.source.get(ctx, dbEndpoint)
	if err != nil {
		return time.Time{}, err
	}
	var meta dbMeta
	if err := json.Unmarshal(b, &meta); err != nil {
		return time.Time{}, err
	}
	return meta.Modified, nil
}
// ModuleRequest describes a query for the vulnerabilities of one module.
type ModuleRequest struct {
	// Path is the module path to filter on. It must be set.
	Path string
	// Version is optional. If set, it must be valid semver, and only
	// vulnerabilities affecting the module at this version are returned.
	Version string
}
// ModuleResponse is the answer to a single ModuleRequest.
type ModuleResponse struct {
	// Path and Version are copied from the corresponding request.
	Path    string
	Version string
	// Entries holds the matching OSV entries, or nil if there are none.
	Entries []*osv.Entry
}
// ByModules returns a list of responses
// containing the OSV entries corresponding to each request.
//
// The order of the requests is preserved, and each request has
// a response even if there are no entries (in which case the Entries
// field is nil). If any request fails, nil and the first error are
// returned.
func (c *Client) ByModules(ctx context.Context, reqs []*ModuleRequest) (_ []*ModuleResponse, err error) {
	// derrors.Wrap only annotates an already non-nil *err, so it must be
	// deferred to see the error this function returns.
	defer derrors.Wrap(&err, "ByModules(%v)", reqs)

	metas, err := c.moduleMetas(ctx, reqs)
	if err != nil {
		return nil, err
	}

	// Resolve requests concurrently, at most 10 at a time. Each goroutine
	// writes only its own slot of resps, so no locking is needed.
	resps := make([]*ModuleResponse, len(reqs))
	g, gctx := errgroup.WithContext(ctx)
	g.SetLimit(10)
	for i, req := range reqs {
		i, req := i, req
		g.Go(func() error {
			entries, err := c.byModule(gctx, req, metas[i])
			if err != nil {
				return err
			}
			resps[i] = &ModuleResponse{
				Path:    req.Path,
				Version: req.Version,
				Entries: entries,
			}
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return resps, nil
}
// moduleMetas fetches the modules index (modulesEndpoint) and returns a
// slice parallel to reqs. Each element is the index entry whose Path equals
// the corresponding request's Path, or nil if the module is not listed.
// If the index lists a path more than once, the last occurrence wins.
func (c *Client) moduleMetas(ctx context.Context, reqs []*ModuleRequest) (_ []*moduleMeta, err error) {
	b, err := c.source.get(ctx, modulesEndpoint)
	if err != nil {
		return nil, err
	}
	dec, err := newStreamDecoder(b)
	if err != nil {
		return nil, err
	}

	// Map each requested path to the request positions that use it, so
	// every decoded index entry is matched in O(1) instead of by scanning
	// all requests.
	positions := make(map[string][]int, len(reqs))
	for i, req := range reqs {
		positions[req.Path] = append(positions[req.Path], i)
	}

	metas := make([]*moduleMeta, len(reqs))
	for dec.More() {
		var m moduleMeta
		if err := dec.Decode(&m); err != nil {
			return nil, err
		}
		for _, i := range positions[m.Path] {
			metas[i] = &m
		}
	}
	return metas, nil
}
// byModule returns the OSV entries matching req, sorted by ID,
// or (nil, nil) if there are none.
//
// m is the modules-index metadata for req.Path, or nil if the module is
// not in the database. If req.Version is set, only entries that list
// req.Path with ranges affecting req.Version are returned.
func (c *Client) byModule(ctx context.Context, req *ModuleRequest, m *moduleMeta) (_ []*osv.Entry, err error) {
	// Validate the request first, so that a malformed request is reported
	// even when the module is absent from the database.
	if req.Path == "" {
		return nil, fmt.Errorf("module path must be set")
	}
	if req.Version != "" && !isem.Valid(req.Version) {
		return nil, fmt.Errorf("version %s is not valid semver", req.Version)
	}
	// This module isn't in the database.
	if m == nil {
		return nil, nil
	}

	// Use the index's fix versions to avoid fetching entries that are
	// already fixed at req.Version. With no version, every entry is wanted.
	var ids []string
	for _, v := range m.Vulns {
		if req.Version == "" || v.Fixed == "" || isem.Less(req.Version, v.Fixed) {
			ids = append(ids, v.ID)
		}
	}
	if len(ids) == 0 {
		return nil, nil
	}

	entries, err := c.byIDs(ctx, ids)
	if err != nil {
		return nil, err
	}

	// Filter by version using each entry's affected ranges for req.Path.
	// The index check above only compared fix versions.
	if req.Version != "" {
		affected := func(e *osv.Entry) bool {
			for _, a := range e.Affected {
				if a.Module.Path == req.Path && isem.Affects(a.Ranges, req.Version) {
					return true
				}
			}
			return false
		}
		var filtered []*osv.Entry
		for _, entry := range entries {
			if affected(entry) {
				filtered = append(filtered, entry)
			}
		}
		if len(filtered) == 0 {
			return nil, nil
		}
		entries = filtered
	}

	sort.SliceStable(entries, func(i, j int) bool {
		return entries[i].ID < entries[j].ID
	})
	return entries, nil
}
// byIDs fetches the OSV entry for every id concurrently, running at most
// 10 fetches at a time. Results are in the same order as ids. If any fetch
// fails, it returns nil and the first error reported by the group.
func (c *Client) byIDs(ctx context.Context, ids []string) (_ []*osv.Entry, err error) {
	out := make([]*osv.Entry, len(ids))
	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(10)
	for idx := range ids {
		idx := idx
		group.Go(func() error {
			entry, fetchErr := c.byID(groupCtx, ids[idx])
			if fetchErr == nil {
				out[idx] = entry
			}
			return fetchErr
		})
	}
	if err = group.Wait(); err != nil {
		return nil, err
	}
	return out, nil
}
// byID returns the OSV entry with the given ID,
// or an error if it does not exist / cannot be unmarshaled.
func (c *Client) byID(ctx context.Context, id string) (_ *osv.Entry, err error) {
	// derrors.Wrap only annotates an already non-nil *err, so it must be
	// deferred to see the error this function returns.
	defer derrors.Wrap(&err, "byID(%s)", id)
	b, err := c.source.get(ctx, entryEndpoint(id))
	if err != nil {
		return nil, err
	}
	var entry osv.Entry
	if err := json.Unmarshal(b, &entry); err != nil {
		return nil, err
	}
	return &entry, nil
}
// newStreamDecoder returns a decoder positioned just inside a JSON array
// held in b, so callers can loop with dec.More / dec.Decode to read its
// elements one at a time.
//
// It returns an error if b is empty, malformed, or does not start with an
// array (for example an object or null). Otherwise such input would be
// silently treated as an empty stream.
func newStreamDecoder(b []byte) (*json.Decoder, error) {
	dec := json.NewDecoder(bytes.NewReader(b))
	// Consume and check the opening bracket.
	tok, err := dec.Token()
	if err != nil {
		return nil, err
	}
	if d, ok := tok.(json.Delim); !ok || d != '[' {
		return nil, fmt.Errorf("expected start of JSON array, got %v", tok)
	}
	return dec, nil
}