// blob: 0c22fb9573d72048be960a0f531dbc817b6179c5 [file] [log] [blame]
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cveclient implements a client for interacting with MITRE CVE
// Services API as described at https://cveawg.mitre.org/api-docs/openapi.json.
package cveclient
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// Base URLs of the MITRE CVE Services deployments. API paths such as
// "/api/cve-id" are appended directly to the configured endpoint, so these
// values carry no trailing slash.
const (
// ProdEndpoint is the production endpoint.
ProdEndpoint = "https://cveawg.mitre.org"
// TestEndpoint is the test endpoint.
TestEndpoint = "https://cveawg-test.mitre.org"
// DevEndpoint is the dev endpoint.
DevEndpoint = "https://cveawg-dev.mitre.org"
)
// Client is a MITRE CVE Services API client. It embeds its Config, so the
// configuration fields are accessible directly on the client.
type Client struct {
	Config
	c *http.Client
}

// Config holds the settings a Client uses to reach and authenticate
// against the API.
type Config struct {
	// Endpoint is the base URL that API paths are appended to. Required.
	Endpoint string
	// Org is the shortname of the organization that is authenticated when
	// making API calls. Required.
	Org string
	// Key is the user's API key. Required.
	Key string
	// User is the username of the account making API calls. Required.
	User string
}

// New returns a client configured via cfg. Requests are sent with
// http.DefaultClient.
func New(cfg Config) *Client {
	client := new(Client)
	client.Config = cfg
	client.c = http.DefaultClient
	return client
}
// AssignedCVE contains information about an assigned CVE, as decoded from
// the API's JSON.
type AssignedCVE struct {
	ID          string      `json:"cve_id"`
	Year        string      `json:"cve_year"`
	State       string      `json:"state"`
	CNA         string      `json:"owning_cna"`
	Reserved    time.Time   `json:"reserved"`
	RequestedBy RequestedBy `json:"requested_by"`
}

// RequestedBy identifies the user and organization that requested a CVE.
type RequestedBy struct {
	CNA  string `json:"cna"`
	User string `json:"user"`
}

// String renders the CVE on a single line in the form
// "<id>: state=<state>, cna=<owning cna>, requester=<requesting user>".
func (c AssignedCVE) String() string {
	requester := c.RequestedBy.User
	return fmt.Sprintf("%s: state=%s, cna=%s, requester=%s", c.ID, c.State, c.CNA, requester)
}

// AssignedCVEList is a list of AssignedCVEs.
type AssignedCVEList []AssignedCVE

// ShortString returns only the CVE IDs, joined with ", ".
func (cs AssignedCVEList) ShortString() string {
	ids := make([]string, len(cs))
	for i := range cs {
		ids[i] = cs[i].ID
	}
	return strings.Join(ids, ", ")
}

// String returns the String form of every CVE in the list, one per line.
func (cs AssignedCVEList) String() string {
	lines := make([]string, 0, len(cs))
	for i := range cs {
		lines = append(lines, cs[i].String())
	}
	return strings.Join(lines, "\n")
}
// ReserveOptions contains the configuration options for reserving new
// CVE IDs.
type ReserveOptions struct {
	// NumIDs is the number of CVE IDs to reserve. Required.
	NumIDs int
	// Year is the CVE ID year for new IDs, indicating the year the
	// vulnerability was discovered. Required.
	Year int
	// Mode indicates whether the block of CVEs should be in sequence.
	// Relevant only if NumIDs > 1.
	Mode RequestType
}

// RequestType is the type of CVE ID reserve request.
type RequestType string

const (
	// SequentialRequest requests CVE IDs be reserved in a sequential fashion.
	SequentialRequest RequestType = "sequential"
	// NonsequentialRequest requests CVE IDs be reserved in a nonsequential fashion.
	NonsequentialRequest RequestType = "nonsequential"
)

// getURLParams converts the options into the query parameters of a reserve
// request made on behalf of org. "amount" and "short_name" are always
// present; "cve_year" is added only for a non-zero Year, and "batch_type"
// only when more than one ID is requested.
func (o *ReserveOptions) getURLParams(org string) url.Values {
	query := url.Values{
		"amount":     {fmt.Sprint(o.NumIDs)},
		"short_name": {org},
	}
	if year := o.Year; year != 0 {
		query.Set("cve_year", strconv.Itoa(year))
	}
	if isBatch := o.NumIDs > 1; isBatch {
		query.Set("batch_type", string(o.Mode))
	}
	return query
}
// createReserveIDsRequest builds the POST request to <Endpoint>/api/cve-id
// that reserves IDs as described by opts. The client's Org is supplied to
// the options' query parameters.
func (c *Client) createReserveIDsRequest(opts ReserveOptions) (*http.Request, error) {
	target := fmt.Sprintf("%s/api/cve-id", c.Endpoint)
	req, err := c.createRequest(http.MethodPost, target)
	if err != nil {
		return nil, err
	}
	query := opts.getURLParams(c.Org)
	req.URL.RawQuery = query.Encode()
	return req, nil
}
// reserveIDsResponse is the JSON body decoded from a reserve request.
type reserveIDsResponse struct {
	CVEs AssignedCVEList `json:"cve_ids"`
}

// ReserveIDs sends a request to the CVE API to reserve a block of CVE IDs.
// Returns a list of the reserved CVE IDs and their associated data.
// There may be fewer IDs than requested if, for example, the organization's
// quota is reached.
func (c *Client) ReserveIDs(opts ReserveOptions) (AssignedCVEList, error) {
	req, err := c.createReserveIDsRequest(opts)
	if err != nil {
		return nil, err
	}
	// 206 Partial Content is accepted alongside 200 OK; any other status
	// is reported as an error by sendRequest.
	accept := func(status int) bool {
		switch status {
		case http.StatusOK, http.StatusPartialContent:
			return true
		}
		return false
	}
	var reserved reserveIDsResponse
	if err := c.sendRequest(req, accept, &reserved); err != nil {
		return nil, err
	}
	return reserved.CVEs, nil
}
// Quota contains information about an organization's reservation quota.
type Quota struct {
	Quota     int `json:"id_quota"`
	Reserved  int `json:"total_reserved"`
	Available int `json:"available"`
}

// RetrieveQuota queries the API for the organization's reservation quota.
// It issues a GET to <Endpoint>/api/org/<Org>/id_quota and decodes the JSON
// response.
func (c *Client) RetrieveQuota() (Quota, error) {
	target := fmt.Sprintf("%s/api/org/%s/id_quota", c.Endpoint, c.Org)
	req, err := c.createRequest(http.MethodGet, target)
	if err != nil {
		return Quota{}, err
	}
	var quota Quota
	if err := c.sendRequest(req, nil, &quota); err != nil {
		return Quota{}, err
	}
	return quota, nil
}
// RetrieveCVE requests information about an assigned CVE ID.
// It issues a GET to <Endpoint>/api/cve-id/<id> and decodes the JSON
// response. On failure the zero AssignedCVE is returned with the error.
func (c *Client) RetrieveCVE(id string) (AssignedCVE, error) {
	// Escape id so that characters such as '/', '?' or '#' stay inside the
	// final path segment instead of altering the request target. Well-formed
	// IDs (e.g. "CVE-2022-1234") are unaffected by escaping.
	target := fmt.Sprintf("%s/api/cve-id/%s", c.Endpoint, url.PathEscape(id))
	req, err := c.createRequest(http.MethodGet, target)
	if err != nil {
		return AssignedCVE{}, err
	}
	var cve AssignedCVE
	if err := c.sendRequest(req, nil, &cve); err != nil {
		return AssignedCVE{}, err
	}
	return cve, nil
}
// ListOptions contains filters to be used when requesting a list of
// assigned CVEs. A field left at its zero value (empty State, zero Year,
// nil time) contributes no filter.
type ListOptions struct {
	State          string
	Year           int
	ReservedBefore *time.Time
	ReservedAfter  *time.Time
	ModifiedBefore *time.Time
	ModifiedAfter  *time.Time
}

// String lists the filters that are set as comma-separated key=value
// pairs, with times formatted as RFC 3339.
func (o ListOptions) String() string {
	var parts []string
	// addTime appends a time filter only when it has been set.
	addTime := func(label string, when *time.Time) {
		if when == nil {
			return
		}
		parts = append(parts, fmt.Sprintf("%s=%s", label, when.Format(time.RFC3339)))
	}
	if o.State != "" {
		parts = append(parts, fmt.Sprintf("state=%s", o.State))
	}
	if o.Year != 0 {
		parts = append(parts, fmt.Sprintf("year=%d", o.Year))
	}
	addTime("reserved_before", o.ReservedBefore)
	addTime("reserved_after", o.ReservedAfter)
	addTime("modified_before", o.ModifiedBefore)
	addTime("modified_after", o.ModifiedAfter)
	return strings.Join(parts, ", ")
}

// getURLParams converts the filters that are set into the query parameters
// of a list request. Times are sent in RFC 3339 format.
func (o *ListOptions) getURLParams() url.Values {
	query := make(url.Values)
	// setTime records a time filter only when it has been set.
	setTime := func(key string, when *time.Time) {
		if when != nil {
			query.Set(key, when.Format(time.RFC3339))
		}
	}
	if o.State != "" {
		query.Set("state", o.State)
	}
	if o.Year != 0 {
		query.Set("cve_id_year", strconv.Itoa(o.Year))
	}
	setTime("time_reserved.lt", o.ReservedBefore)
	setTime("time_reserved.gt", o.ReservedAfter)
	setTime("time_modified.lt", o.ModifiedBefore)
	setTime("time_modified.gt", o.ModifiedAfter)
	return query
}
// listOrgCVEsResponse is the JSON body decoded from one page of a
// list-CVEs request.
type listOrgCVEsResponse struct {
// CurrentPage and NextPage are the page numbers reported in the response.
// ListOrgCVEs keeps fetching while NextPage is greater than CurrentPage.
CurrentPage int `json:"currentPage"`
NextPage int `json:"nextPage"`
// CVEs holds the CVEs returned on this page.
CVEs AssignedCVEList `json:"cve_ids"`
}
// createListOrgCVEsRequest builds the GET request to <Endpoint>/api/cve-id
// for one page of the organization's CVEs.
//
// opts may be nil, in which case no filters are applied. The "page" query
// parameter is added only when page is greater than zero, so a page of 0
// requests the API's default (first) page.
//
// The method uses a pointer receiver, like every other Client method, so
// the client (including its API key) is not copied on each call.
func (c *Client) createListOrgCVEsRequest(opts *ListOptions, page int) (*http.Request, error) {
	req, err := c.createRequest(http.MethodGet, fmt.Sprintf("%s/api/cve-id", c.Endpoint))
	if err != nil {
		return nil, err
	}
	params := url.Values{}
	if opts != nil {
		params = opts.getURLParams()
	}
	if page > 0 {
		params.Set("page", strconv.Itoa(page))
	}
	req.URL.RawQuery = params.Encode()
	return req, nil
}
// ListOrgCVEs requests information about the CVEs the organization has been
// assigned. This list can be filtered by setting the fields in opts.
//
// Results are fetched page by page and concatenated. Fetching stops as soon
// as a response reports a next page that is not greater than its current
// page. Any request failure aborts the listing and discards the CVEs
// gathered so far.
func (c *Client) ListOrgCVEs(opts *ListOptions) (AssignedCVEList, error) {
	var all []AssignedCVE
	page := 0
	for {
		req, err := c.createListOrgCVEsRequest(opts, page)
		if err != nil {
			return nil, err
		}
		var batch listOrgCVEsResponse
		if err := c.sendRequest(req, nil, &batch); err != nil {
			return nil, err
		}
		all = append(all, batch.CVEs...)
		if batch.NextPage <= batch.CurrentPage {
			// No further page is advertised.
			return all, nil
		}
		page = batch.NextPage
	}
}
// Names of the HTTP request headers that carry the caller's user, org and
// API key. They are constants so they cannot be reassigned at runtime.
const (
	headerApiUser = "CVE-API-USER"
	headerApiOrg  = "CVE-API-ORG"
	headerApiKey  = "CVE-API-KEY"
)
// createRequest creates a new body-less HTTP request for method and rawURL
// and sets the CVE-API-USER, CVE-API-ORG and CVE-API-KEY header fields from
// the client's configuration. It returns the error from http.NewRequest
// unchanged, e.g. when rawURL cannot be parsed.
//
// The URL parameter is named rawURL rather than url so that it does not
// shadow the imported net/url package inside the function.
func (c *Client) createRequest(method, rawURL string) (*http.Request, error) {
	req, err := http.NewRequest(method, rawURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set(headerApiUser, c.User)
	req.Header.Set(headerApiOrg, c.Org)
	req.Header.Set(headerApiKey, c.Key)
	return req, nil
}
// sendRequest sends an HTTP request, checks the returned status via
// checkStatus, and attempts to unmarshal the JSON response body into result.
// If checkStatus is nil, only http.StatusOK is accepted.
//
// A rejected status yields an error that includes the request method and
// URL together with whatever detail extractError recovers from the body.
// Transport, read and decode failures are wrapped with %w so callers can
// still inspect the underlying error with errors.Is / errors.As.
func (c *Client) sendRequest(req *http.Request, checkStatus func(int) bool, result any) (err error) {
	resp, err := c.c.Do(req)
	if err != nil {
		return fmt.Errorf("could not send HTTP request: %w", err)
	}
	defer resp.Body.Close()
	if checkStatus == nil {
		checkStatus = func(s int) bool {
			return s == http.StatusOK
		}
	}
	if !checkStatus(resp.StatusCode) {
		return fmt.Errorf("HTTP request %s %q returned error: %w", req.Method, req.URL, extractError(resp))
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("could not read response to %s %q: %w", req.Method, req.URL, err)
	}
	if err := json.Unmarshal(body, result); err != nil {
		return fmt.Errorf("could not decode response to %s %q: %w", req.Method, req.URL, err)
	}
	return nil
}
// apiError is the JSON shape of the error detail decoded from a failed
// response body.
type apiError struct {
	Error   string `json:"error"`
	Message string `json:"message"`
}

// extractError extracts additional error messages from the HTTP response
// if available, and wraps them into a single error.
//
// The returned error always starts with the response's status line. If the
// body is a JSON object, its "error" and "message" texts are appended
// (separated by ": ") unless they merely repeat the standard status text.
// A body that cannot be read or decoded yields the status line alone.
//
// The message is built from server-supplied text, so it is always passed
// to fmt.Errorf as an argument and never as the format string; otherwise a
// '%' in the API's message would be misinterpreted as a formatting verb.
func extractError(resp *http.Response) error {
	errMsg := resp.Status
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		// Discard the read error and return the HTTP status.
		return fmt.Errorf("%s", errMsg)
	}
	var apiErr apiError
	if err := json.Unmarshal(body, &apiErr); err != nil {
		// Discard the unmarshal error and return the HTTP status.
		return fmt.Errorf("%s", errMsg)
	}
	// Append the error and message text if they add extra information
	// beyond the HTTP status text.
	statusText := strings.ToLower(http.StatusText(resp.StatusCode))
	for _, errText := range []string{apiErr.Error, apiErr.Message} {
		if errText != "" && strings.ToLower(errText) != statusText {
			errMsg = fmt.Sprintf("%s: %s", errMsg, errText)
		}
	}
	return fmt.Errorf("%s", errMsg)
}