| // Copyright 2023 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package client provides an interface for accessing vulnerability |
| // databases, via either HTTP or local filesystem access. |
| // |
| // The protocol is described at https://go.dev/security/vuln/database. |
| package client |
| |
| import ( |
| "bytes" |
| "context" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "net/http" |
| "net/url" |
| "os" |
| "path/filepath" |
| "sort" |
| "strings" |
| "time" |
| |
| "golang.org/x/sync/errgroup" |
| "golang.org/x/vuln/internal/derrors" |
| "golang.org/x/vuln/internal/osv" |
| isem "golang.org/x/vuln/internal/semver" |
| "golang.org/x/vuln/internal/web" |
| ) |
| |
| // A Client for reading vulnerability databases. |
| type Client struct { |
| source |
| } |
| |
// Options holds optional settings for NewClient.
type Options struct {
	// HTTPClient is passed along (as part of Options) to the HTTP source
	// when the database URL uses the "http" or "https" scheme.
	// It is not consulted for "file" URLs.
	HTTPClient *http.Client
}
| |
| // NewClient returns a client that reads the vulnerability database |
| // in source (an "http" or "file" prefixed URL). |
| // |
| // It supports databases following the API described |
| // in https://go.dev/security/vuln/database#api. |
| func NewClient(source string, opts *Options) (_ *Client, err error) { |
| source = strings.TrimRight(source, "/") |
| uri, err := url.Parse(source) |
| if err != nil { |
| return nil, err |
| } |
| switch uri.Scheme { |
| case "http", "https": |
| return newHTTPClient(uri, opts) |
| case "file": |
| return newLocalClient(uri) |
| default: |
| return nil, fmt.Errorf("source %q has unsupported scheme", uri) |
| } |
| } |
| |
// errUnknownSchema is returned by the client constructors when the
// database at the given location does not look like a supported format.
var errUnknownSchema = errors.New("unrecognized vulndb format; see https://go.dev/security/vuln/database#api for accepted schema")
| |
| func newHTTPClient(uri *url.URL, opts *Options) (*Client, error) { |
| source := uri.String() |
| |
| // v1 returns true if the source likely follows the V1 schema. |
| v1 := func() bool { |
| return source == "https://vuln.go.dev" || |
| endpointExistsHTTP(source, "index/modules.json.gz") |
| } |
| |
| if v1() { |
| return &Client{source: newHTTPSource(uri.String(), opts)}, nil |
| } |
| |
| return nil, errUnknownSchema |
| } |
| |
// endpointExistsHTTP reports whether a HEAD request for
// source + "/" + endpoint succeeds with status 200 OK.
// Request errors and any other status code are reported as false.
func endpointExistsHTTP(source, endpoint string) bool {
	resp, err := http.Head(source + "/" + endpoint)
	if err != nil {
		return false
	}
	// A HEAD response has no payload, but net/http still requires the
	// body to be closed so the underlying connection is released.
	defer resp.Body.Close()

	return resp.StatusCode == http.StatusOK
}
| |
| func newLocalClient(uri *url.URL) (*Client, error) { |
| dir, err := toDir(uri) |
| if err != nil { |
| return nil, err |
| } |
| |
| // Check if the DB likely follows the v1 schema by |
| // looking for the "index/modules.json" endpoint. |
| if endpointExistsDir(dir, modulesEndpoint+".json") { |
| return &Client{source: newLocalSource(dir)}, nil |
| } |
| |
| // If the DB doesn't follow the v1 schema, |
| // attempt to intepret it as a flat list of OSV files. |
| // This is currently a "hidden" feature, so don't output the |
| // specific error if this fails. |
| src, err := newHybridSource(dir) |
| if err != nil { |
| return nil, errUnknownSchema |
| } |
| return &Client{source: src}, nil |
| } |
| |
| func toDir(uri *url.URL) (string, error) { |
| dir, err := web.URLToFilePath(uri) |
| if err != nil { |
| return "", err |
| } |
| fi, err := os.Stat(dir) |
| if err != nil { |
| return "", err |
| } |
| if !fi.IsDir() { |
| return "", fmt.Errorf("%s is not a directory", dir) |
| } |
| return dir, nil |
| } |
| |
// endpointExistsDir reports whether os.Stat succeeds for endpoint joined
// onto dir. Any stat error, not only "does not exist", yields false.
func endpointExistsDir(dir, endpoint string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, endpoint)); statErr != nil {
		return false
	}
	return true
}
| |
| func NewInMemoryClient(entries []*osv.Entry) (*Client, error) { |
| s, err := newInMemorySource(entries) |
| if err != nil { |
| return nil, err |
| } |
| return &Client{source: s}, nil |
| } |
| |
| func (c *Client) LastModifiedTime(ctx context.Context) (_ time.Time, err error) { |
| derrors.Wrap(&err, "LastModifiedTime()") |
| |
| b, err := c.source.get(ctx, dbEndpoint) |
| if err != nil { |
| return time.Time{}, err |
| } |
| |
| var dbMeta dbMeta |
| if err := json.Unmarshal(b, &dbMeta); err != nil { |
| return time.Time{}, err |
| } |
| |
| return dbMeta.Modified, nil |
| } |
| |
// ModuleRequest describes one module to look up with Client.ByModules.
type ModuleRequest struct {
	// Path is the module path to filter on. It must be set.
	//
	// NOTE(review): an empty Path is only rejected (with an error) once
	// index metadata has been matched for the request; when nothing in
	// the index matches, the response simply carries nil Entries.
	Path string
	// Version is optional. If set, only vulnerabilities affecting this
	// version are meant to be returned, and it is expected to be valid
	// semver.
	Version string
}
| |
| type ModuleResponse struct { |
| Path string |
| Version string |
| Entries []*osv.Entry |
| } |
| |
| // ByModules returns a list of responses |
| // containing the OSV entries corresponding to each request. |
| // |
| // The order of the requests is preserved, and each request has |
| // a response even if there are no entries (in which case the Entries |
| // field is nil). |
| func (c *Client) ByModules(ctx context.Context, reqs []*ModuleRequest) (_ []*ModuleResponse, err error) { |
| derrors.Wrap(&err, "ByModules(%v)", reqs) |
| |
| metas, err := c.moduleMetas(ctx, reqs) |
| if err != nil { |
| return nil, err |
| } |
| |
| resps := make([]*ModuleResponse, len(reqs)) |
| g, gctx := errgroup.WithContext(ctx) |
| g.SetLimit(10) |
| for i, req := range reqs { |
| i, req := i, req |
| g.Go(func() error { |
| entries, err := c.byModule(gctx, req, metas[i]) |
| if err != nil { |
| return err |
| } |
| resps[i] = &ModuleResponse{ |
| Path: req.Path, |
| Version: req.Version, |
| Entries: entries, |
| } |
| return nil |
| }) |
| } |
| if err := g.Wait(); err != nil { |
| return nil, err |
| } |
| |
| return resps, nil |
| } |
| |
| func (c *Client) moduleMetas(ctx context.Context, reqs []*ModuleRequest) (_ []*moduleMeta, err error) { |
| b, err := c.source.get(ctx, modulesEndpoint) |
| if err != nil { |
| return nil, err |
| } |
| |
| dec, err := newStreamDecoder(b) |
| if err != nil { |
| return nil, err |
| } |
| |
| metas := make([]*moduleMeta, len(reqs)) |
| for dec.More() { |
| var m moduleMeta |
| err := dec.Decode(&m) |
| if err != nil { |
| return nil, err |
| } |
| for i, req := range reqs { |
| if m.Path == req.Path { |
| metas[i] = &m |
| } |
| } |
| } |
| |
| return metas, nil |
| } |
| |
| // byModule returns the OSV entries matching the ModuleRequest, |
| // or (nil, nil) if there are none. |
| func (c *Client) byModule(ctx context.Context, req *ModuleRequest, m *moduleMeta) (_ []*osv.Entry, err error) { |
| // This module isn't in the database. |
| if m == nil { |
| return nil, nil |
| } |
| |
| if req.Path == "" { |
| return nil, fmt.Errorf("module path must be set") |
| } |
| |
| if req.Version != "" && !isem.Valid(req.Version) { |
| return nil, fmt.Errorf("version %s is not valid semver", req.Version) |
| } |
| |
| var ids []string |
| for _, v := range m.Vulns { |
| if v.Fixed == "" || isem.Less(req.Version, v.Fixed) { |
| ids = append(ids, v.ID) |
| } |
| } |
| |
| if len(ids) == 0 { |
| return nil, nil |
| } |
| |
| entries, err := c.byIDs(ctx, ids) |
| if err != nil { |
| return nil, err |
| } |
| |
| // Filter by version. |
| if req.Version != "" { |
| affected := func(e *osv.Entry) bool { |
| for _, a := range e.Affected { |
| if a.Module.Path == req.Path && isem.Affects(a.Ranges, req.Version) { |
| return true |
| } |
| } |
| return false |
| } |
| |
| var filtered []*osv.Entry |
| for _, entry := range entries { |
| if affected(entry) { |
| filtered = append(filtered, entry) |
| } |
| } |
| if len(filtered) == 0 { |
| return nil, nil |
| } |
| } |
| |
| sort.SliceStable(entries, func(i, j int) bool { |
| return entries[i].ID < entries[j].ID |
| }) |
| |
| return entries, nil |
| } |
| |
| func (c *Client) byIDs(ctx context.Context, ids []string) (_ []*osv.Entry, err error) { |
| entries := make([]*osv.Entry, len(ids)) |
| g, gctx := errgroup.WithContext(ctx) |
| g.SetLimit(10) |
| for i, id := range ids { |
| i, id := i, id |
| g.Go(func() error { |
| e, err := c.byID(gctx, id) |
| if err != nil { |
| return err |
| } |
| entries[i] = e |
| return nil |
| }) |
| } |
| if err := g.Wait(); err != nil { |
| return nil, err |
| } |
| |
| return entries, nil |
| } |
| |
| // byID returns the OSV entry with the given ID, |
| // or an error if it does not exist / cannot be unmarshaled. |
| func (c *Client) byID(ctx context.Context, id string) (_ *osv.Entry, err error) { |
| derrors.Wrap(&err, "byID(%s)", id) |
| |
| b, err := c.source.get(ctx, entryEndpoint(id)) |
| if err != nil { |
| return nil, err |
| } |
| |
| var entry osv.Entry |
| if err := json.Unmarshal(b, &entry); err != nil { |
| return nil, err |
| } |
| |
| return &entry, nil |
| } |
| |
// newStreamDecoder returns a decoder that can be used
// to read an array of JSON objects.
//
// The opening '[' of the array is consumed, so the caller can loop on
// dec.More() and dec.Decode to read one element at a time. An error is
// returned if b is empty, is not valid JSON, or does not start with a
// JSON array.
func newStreamDecoder(b []byte) (*json.Decoder, error) {
	dec := json.NewDecoder(bytes.NewReader(b))

	// Consume the first token and insist that it opens an array. Without
	// this check a payload such as `null` or `123` would be skipped and
	// then read as an empty array.
	tok, err := dec.Token()
	if err != nil {
		return nil, err
	}
	if delim, ok := tok.(json.Delim); !ok || delim != '[' {
		return nil, fmt.Errorf("expected start of JSON array, found %v", tok)
	}

	return dec, nil
}