| // Copyright 2022 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
// Package cveclient implements a client for interacting with the MITRE CVE
// Services API, as described at https://cveawg.mitre.org/api-docs/openapi.json.
| package cveclient |
| |
| import ( |
| "encoding/json" |
| "fmt" |
| "io" |
| "net/http" |
| "net/url" |
| "strconv" |
| "strings" |
| "time" |
| ) |
| |
// Base URLs of the CVE Services API deployments, suitable for
// Config.Endpoint. API paths such as "/api/cve-id" are appended to them.
const (
	// ProdEndpoint is the production endpoint.
	ProdEndpoint = "https://cveawg.mitre.org"
	// TestEndpoint is the test endpoint.
	TestEndpoint = "https://cveawg-test.mitre.org"
	// DevEndpoint is the dev endpoint.
	DevEndpoint = "https://cveawg-dev.mitre.org"
)
| |
// Client is a MITRE CVE Services API client.
//
// The embedded Config supplies the endpoint and the credentials that are
// attached to every request the client sends.
type Client struct {
	Config
	c *http.Client
}

// Config contains client configuration data.
type Config struct {
	// Endpoint is the base URL to access when making API calls
	// (for example ProdEndpoint). Required.
	Endpoint string
	// Org is the shortname for the organization that is authenticated when
	// making API calls. Required.
	Org string
	// Key is the user's API key. Required.
	Key string
	// User is the username for the account that is making API calls. Required.
	User string
}

// New returns a Client configured via cfg that sends its requests through
// http.DefaultClient.
func New(cfg Config) *Client {
	return &Client{
		Config: cfg,
		c:      http.DefaultClient,
	}
}
| |
// AssignedCVE contains information about an assigned CVE.
type AssignedCVE struct {
	ID          string      `json:"cve_id"`
	Year        string      `json:"cve_year"`
	State       string      `json:"state"`
	CNA         string      `json:"owning_cna"`
	Reserved    time.Time   `json:"reserved"`
	RequestedBy RequestedBy `json:"requested_by"`
}

// RequestedBy indicates the requesting user and organization for a CVE.
type RequestedBy struct {
	CNA  string `json:"cna"`
	User string `json:"user"`
}

// String returns a one-line summary of c: its ID, state, owning CNA and
// the user that requested it.
func (c AssignedCVE) String() string {
	const format = "%s: state=%s, cna=%s, requester=%s"
	return fmt.Sprintf(format, c.ID, c.State, c.CNA, c.RequestedBy.User)
}

// AssignedCVEList is a list of AssignedCVEs.
type AssignedCVEList []AssignedCVE

// ShortString returns the IDs of the CVEs in cs, separated by ", ".
// An empty list yields the empty string.
func (cs AssignedCVEList) ShortString() string {
	ids := make([]string, len(cs))
	for i := range cs {
		ids[i] = cs[i].ID
	}
	return strings.Join(ids, ", ")
}

// String returns the String form of every CVE in cs, one per line, with no
// trailing newline. An empty list yields the empty string.
func (cs AssignedCVEList) String() string {
	var b strings.Builder
	for i, c := range cs {
		if i > 0 {
			b.WriteByte('\n')
		}
		b.WriteString(c.String())
	}
	return b.String()
}
| |
// ReserveOptions contains the configuration options for reserving new
// CVE IDs.
type ReserveOptions struct {
	// NumIDs is the number of CVE IDs to reserve. Required.
	NumIDs int
	// Year is the CVE ID year for new IDs, indicating the year the
	// vulnerability was discovered. Required. (A zero Year is simply not
	// sent to the API.)
	Year int
	// Mode indicates whether the block of CVEs should be in sequence.
	// Relevant only if NumIDs > 1.
	Mode RequestType
}

// RequestType is the type of CVE ID reserve request.
type RequestType string

const (
	// SequentialRequest requests CVE IDs be reserved in a sequential fashion.
	SequentialRequest RequestType = "sequential"
	// NonsequentialRequest requests CVE IDs be reserved in a nonsequential fashion.
	NonsequentialRequest RequestType = "nonsequential"
)

// getURLParams returns the query parameters for a reserve request made on
// behalf of org. The amount and short_name parameters are always present.
// cve_year is included only when Year is non-zero. batch_type is included
// only when more than one ID is requested.
func (o *ReserveOptions) getURLParams(org string) url.Values {
	params := url.Values{
		"amount":     {fmt.Sprint(o.NumIDs)},
		"short_name": {org},
	}
	if o.Year != 0 {
		params.Set("cve_year", strconv.Itoa(o.Year))
	}
	if o.NumIDs > 1 {
		params.Set("batch_type", string(o.Mode))
	}
	return params
}
| |
| func (c *Client) createReserveIDsRequest(opts ReserveOptions) (*http.Request, error) { |
| req, err := c.createRequest(http.MethodPost, |
| fmt.Sprintf("%s/api/cve-id", c.Endpoint)) |
| if err != nil { |
| return nil, err |
| } |
| req.URL.RawQuery = opts.getURLParams(c.Org).Encode() |
| return req, nil |
| } |
| |
| type reserveIDsResponse struct { |
| CVEs AssignedCVEList `json:"cve_ids"` |
| } |
| |
| // ReserveIDs sends a request to the CVE API to reserve a block of CVE IDs. |
| // Returns a list of the reserved CVE IDs and their associated data. |
| // There may be fewer IDs than requested if, for example, the organization's |
| // quota is reached. |
| func (c *Client) ReserveIDs(opts ReserveOptions) (AssignedCVEList, error) { |
| req, err := c.createReserveIDsRequest(opts) |
| if err != nil { |
| return nil, err |
| } |
| var assigned reserveIDsResponse |
| checkStatus := func(s int) bool { |
| return s == http.StatusOK || s == http.StatusPartialContent |
| } |
| err = c.sendRequest(req, checkStatus, &assigned) |
| if err != nil { |
| return nil, err |
| } |
| return assigned.CVEs, nil |
| } |
| |
| // Quota contains information about an organizations reservation quota. |
| type Quota struct { |
| Quota int `json:"id_quota"` |
| Reserved int `json:"total_reserved"` |
| Available int `json:"available"` |
| } |
| |
| // RetrieveQuota queries the API for the organizations reservation quota. |
| func (c *Client) RetrieveQuota() (Quota, error) { |
| req, err := c.createRequest(http.MethodGet, fmt.Sprintf("%s/api/org/%s/id_quota", c.Endpoint, c.Org)) |
| if err != nil { |
| return Quota{}, err |
| } |
| |
| var q Quota |
| err = c.sendRequest(req, nil, &q) |
| if err != nil { |
| return Quota{}, err |
| } |
| return q, nil |
| } |
| |
| // RetrieveCVE requests information about an assigned CVE ID. |
| func (c *Client) RetrieveCVE(id string) (AssignedCVE, error) { |
| req, err := c.createRequest(http.MethodGet, fmt.Sprintf("%s/api/cve-id/%s", c.Endpoint, id)) |
| if err != nil { |
| return AssignedCVE{}, err |
| } |
| |
| var cve AssignedCVE |
| err = c.sendRequest(req, nil, &cve) |
| if err != nil { |
| return AssignedCVE{}, err |
| } |
| return cve, nil |
| } |
| |
// ListOptions contains filters to be used when requesting a list of
// assigned CVEs. Fields left at their zero value (empty State, zero Year,
// nil times) are not applied as filters.
type ListOptions struct {
	State          string
	Year           int
	ReservedBefore *time.Time
	ReservedAfter  *time.Time
	ModifiedBefore *time.Time
	ModifiedAfter  *time.Time
}

// String returns the filters set in o as comma-separated key=value pairs,
// always in the same order. Times are formatted as RFC 3339.
func (o ListOptions) String() string {
	var parts []string
	if o.State != "" {
		parts = append(parts, fmt.Sprintf("state=%s", o.State))
	}
	if o.Year != 0 {
		parts = append(parts, fmt.Sprintf("year=%d", o.Year))
	}
	for _, f := range []struct {
		format string
		t      *time.Time
	}{
		{"reserved_before=%s", o.ReservedBefore},
		{"reserved_after=%s", o.ReservedAfter},
		{"modified_before=%s", o.ModifiedBefore},
		{"modified_after=%s", o.ModifiedAfter},
	} {
		if f.t != nil {
			parts = append(parts, fmt.Sprintf(f.format, f.t.Format(time.RFC3339)))
		}
	}
	return strings.Join(parts, ", ")
}

// getURLParams converts the filters set in o into the query parameters
// understood by the cve-id listing endpoint. "Before" and "after" bounds
// map to the .lt and .gt parameters respectively.
func (o *ListOptions) getURLParams() url.Values {
	params := make(url.Values)
	if o.State != "" {
		params.Set("state", o.State)
	}
	if o.Year != 0 {
		params.Set("cve_id_year", strconv.Itoa(o.Year))
	}
	for _, f := range []struct {
		key string
		t   *time.Time
	}{
		{"time_reserved.lt", o.ReservedBefore},
		{"time_reserved.gt", o.ReservedAfter},
		{"time_modified.lt", o.ModifiedBefore},
		{"time_modified.gt", o.ModifiedAfter},
	} {
		if f.t != nil {
			params.Set(f.key, f.t.Format(time.RFC3339))
		}
	}
	return params
}
| |
| type listOrgCVEsResponse struct { |
| CurrentPage int `json:"currentPage"` |
| NextPage int `json:"nextPage"` |
| CVEs AssignedCVEList `json:"cve_ids"` |
| } |
| |
| func (c Client) createListOrgCVEsRequest(opts *ListOptions, page int) (*http.Request, error) { |
| req, err := c.createRequest(http.MethodGet, fmt.Sprintf("%s/api/cve-id", c.Endpoint)) |
| if err != nil { |
| return nil, err |
| } |
| params := url.Values{} |
| if opts != nil { |
| params = opts.getURLParams() |
| } |
| if page > 0 { |
| params.Set("page", fmt.Sprint(page)) |
| } |
| req.URL.RawQuery = params.Encode() |
| return req, nil |
| } |
| |
| // ListOrgCVEs requests information about the CVEs the organization has been |
| // assigned. This list can be filtered by setting the fields in opts. |
| func (c *Client) ListOrgCVEs(opts *ListOptions) (AssignedCVEList, error) { |
| var cves []AssignedCVE |
| page := 0 |
| for { |
| req, err := c.createListOrgCVEsRequest(opts, page) |
| if err != nil { |
| return nil, err |
| } |
| var result listOrgCVEsResponse |
| err = c.sendRequest(req, nil, &result) |
| if err != nil { |
| return nil, err |
| } |
| cves = append(cves, result.CVEs...) |
| if result.NextPage <= result.CurrentPage { |
| break |
| } |
| page = result.NextPage |
| } |
| return cves, nil |
| } |
| |
// HTTP header names that carry the API user, organization and key on each
// request. They never change, so they are constants rather than mutable
// package-level variables.
const (
	headerApiUser = "CVE-API-USER"
	headerApiOrg  = "CVE-API-ORG"
	headerApiKey  = "CVE-API-KEY"
)
| |
| // createRequest creates a new HTTP request and sets the header fields. |
| func (c *Client) createRequest(method, url string) (*http.Request, error) { |
| req, err := http.NewRequest(method, url, nil) |
| if err != nil { |
| return nil, err |
| } |
| req.Header.Set(headerApiUser, c.User) |
| req.Header.Set(headerApiOrg, c.Org) |
| req.Header.Set(headerApiKey, c.Key) |
| return req, nil |
| } |
| |
| // sendRequest sends an HTTP request, checks the returned status via |
| // checkStatus, and attempts to unmarshal the response into result. |
| // if checkStatus is nil, checks for http.StatusOK. |
| func (c *Client) sendRequest(req *http.Request, checkStatus func(int) bool, result any) (err error) { |
| resp, err := c.c.Do(req) |
| if err != nil { |
| return fmt.Errorf("could not send HTTP request: %v", err) |
| } |
| defer resp.Body.Close() |
| if checkStatus == nil { |
| checkStatus = func(s int) bool { |
| return s == http.StatusOK |
| } |
| } |
| if !checkStatus(resp.StatusCode) { |
| return fmt.Errorf("HTTP request %s %q returned error: %w", req.Method, req.URL, extractError(resp)) |
| } |
| body, err := io.ReadAll(resp.Body) |
| if err != nil { |
| return err |
| } |
| if err := json.Unmarshal(body, result); err != nil { |
| return err |
| } |
| return nil |
| } |
| |
// apiError is the JSON error body that may accompany a failed API response.
type apiError struct {
	Error   string `json:"error"`
	Message string `json:"message"`
}

// extractError builds an error describing a failed HTTP response.
//
// The message starts with resp.Status. If the body decodes as an apiError,
// its Error and Message fields are appended, each prefixed by ": ". A field
// is skipped when it is empty or merely repeats the HTTP status text,
// compared case-insensitively. If the body cannot be read or decoded, only
// the status is reported. resp.Body is consumed but not closed.
func extractError(resp *http.Response) error {
	msg := resp.Status
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		// Best effort: discard the read error and return the HTTP status.
		return fmt.Errorf("%s", msg)
	}
	var apiErr apiError
	if err := json.Unmarshal(body, &apiErr); err != nil {
		// Best effort: discard the unmarshal error and return the HTTP status.
		return fmt.Errorf("%s", msg)
	}

	// Append the error and message text if they add extra information
	// beyond the HTTP status text.
	statusText := http.StatusText(resp.StatusCode)
	for _, text := range []string{apiErr.Error, apiErr.Message} {
		if text != "" && !strings.EqualFold(text, statusText) {
			msg += ": " + text
		}
	}
	// msg contains server-supplied text that may include '%', so it must
	// never be used as a format string.
	return fmt.Errorf("%s", msg)
}