blob: f8bc3eea9a6a157a02d65c7183c053c8d09aa075 [file] [log] [blame]
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cveclient
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"time"
"golang.org/x/vulndb/internal/cveschema"
"golang.org/x/vulndb/internal/cveschema5"
)
// Identifiers shared by the fake CVE Services client and mock server used
// throughout these tests.
const (
	// Credentials the test client is configured with; checkHeaders expects
	// every mock request to carry exactly these values.
	testApiUser = "test_api_user"
	testApiOrg  = "test_api_org"
	testApiKey  = "test_api_key"

	// ID of the defaultTestCVE fixture.
	defaultTestCVEID = "CVE-2022-0000"
)
// Canned fixtures served by the mock server and compared against client
// results.
var (
	// defaultTestCVE is a reserved 2022 CVE whose ID is defaultTestCVEID.
	defaultTestCVE = newTestCVE(defaultTestCVEID, cveschema.StateReserved, "2022")

	// defaultTestCVEs is a two-element list whose first entry is
	// defaultTestCVE.
	defaultTestCVEs = AssignedCVEList{
		defaultTestCVE,
		newTestCVE("CVE-2022-0001", cveschema.StateReserved, "2022"),
	}

	// defaultTestQuota is internally consistent: Reserved + Available == Quota.
	defaultTestQuota = &Quota{
		Available: 7,
		Reserved:  3,
		Quota:     10,
	}

	// defaultTestOrg describes the org identified by testApiOrg.
	defaultTestOrg = &Org{
		UUID:      "000-000-000",
		ShortName: testApiOrg,
		Name:      "An Org",
	}
)
// readTestData reads the cveschema5.CVERecord stored in filename under
// ../cveschema5/testdata. The path is relative to this package's directory,
// which is the working directory under `go test`. Any read error aborts the
// calling test.
func readTestData(t *testing.T, filename string) *cveschema5.CVERecord {
	// Mark this as a helper so a failure is reported at the caller's line
	// rather than inside readTestData.
	t.Helper()
	record, err := cveschema5.Read(fmt.Sprintf("../cveschema5/testdata/%s", filename))
	if err != nil {
		t.Fatalf("could not read test data from file %s: %v", filename, err)
	}
	return record
}
// getDefaultTestCVERecord loads the record in basic-example.json, the
// fixture used by the record-related queries and expectations in this file.
var getDefaultTestCVERecord = func(t *testing.T) *cveschema5.CVERecord {
	// Without this, a read failure inside readTestData would be attributed
	// to this closure instead of to the test that asked for the record.
	t.Helper()
	return readTestData(t, "basic-example.json")
}
var (
testTime2022 = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
testTime2000 = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
testTime1999 = time.Date(1999, 1, 1, 0, 0, 0, 0, time.UTC)
testTime1992 = time.Date(1992, 1, 1, 0, 0, 0, 0, time.UTC)
)
// newTestCVE returns an AssignedCVE with the given id, state and year. Both
// the owning CNA and the requester belong to the test org/user, and the
// reservation time is pinned to testTime2022.
func newTestCVE(id string, state cveschema5.State, year string) AssignedCVE {
	cve := AssignedCVE{ID: id, Year: year, State: state}
	cve.CNA = testApiOrg
	cve.Reserved = testTime2022
	cve.RequestedBy = RequestedBy{User: testApiUser, CNA: testApiOrg}
	return cve
}
// newTestClientAndServer starts an httptest server that dispatches to handler
// and returns a Client whose endpoint and credentials point at it. The
// client's c field is replaced with srv.Client() so requests go to the test
// server. Callers are responsible for closing the returned server.
func newTestClientAndServer(handler http.HandlerFunc) (*Client, *httptest.Server) {
	srv := httptest.NewServer(handler)
	cfg := Config{
		Endpoint: srv.URL,
		Key:      testApiKey,
		Org:      testApiOrg,
		User:     testApiUser,
	}
	client := New(cfg)
	client.c = srv.Client()
	return client, srv
}
// checkHeaders reports a test error for each authentication header on r
// (user, org, key) whose value differs from the credentials the test client
// is configured with.
func checkHeaders(t *testing.T, r *http.Request) {
	// Attribute mismatches to the handler that called us, not to this loop.
	t.Helper()
	for _, h := range []struct{ name, want string }{
		{headerApiUser, testApiUser},
		{headerApiOrg, testApiOrg},
		{headerApiKey, testApiKey},
	} {
		if got := r.Header.Get(h.name); got != h.want {
			t.Errorf("HTTP Header %q = %s, want %s", h.name, got, h.want)
		}
	}
}
// newTestHandler returns a handler that, for every request, runs
// validateRequest (if non-nil) and checkHeaders, then replies with mockStatus
// and the JSON encoding of mockResponse. The response is marshaled once, up
// front; a marshaling failure aborts the calling test.
func newTestHandler(t *testing.T, mockStatus int, mockResponse any, validateRequest func(t *testing.T, r *http.Request)) http.HandlerFunc {
	// Report a marshaling failure at the caller's line, not inside this helper.
	t.Helper()
	body, err := json.Marshal(mockResponse)
	if err != nil {
		t.Fatalf("could not marshal mock response: %v", err)
	}
	return func(w http.ResponseWriter, r *http.Request) {
		if validateRequest != nil {
			validateRequest(t, r)
		}
		checkHeaders(t, r)
		w.WriteHeader(mockStatus)
		if _, err := w.Write(body); err != nil {
			t.Errorf("could not write mock response body: %v", err)
		}
	}
}
// newTestHandlerMultiPage returns a handler that emulates a paginated API.
// A request whose query has page=N (N defaults to 0 when absent) receives the
// JSON encoding of mockResponses[N] with status 200. Every request is first
// passed to validateRequest (if non-nil) and checkHeaders.
//
// Malformed or out-of-range page values are reported as test errors and
// answered with 400 Bad Request. Without this, indexing mockResponses would
// panic inside the server goroutine and the test would only see an opaque
// connection error.
func newTestHandlerMultiPage(t *testing.T, mockResponses []any, validateRequest func(t *testing.T, r *http.Request)) http.HandlerFunc {
	t.Helper()
	mrs := make([][]byte, 0, len(mockResponses))
	for _, r := range mockResponses {
		mr, err := json.Marshal(r)
		if err != nil {
			t.Fatalf("could not marshal mock response: %v", err)
		}
		mrs = append(mrs, mr)
	}
	return func(w http.ResponseWriter, r *http.Request) {
		if validateRequest != nil {
			validateRequest(t, r)
		}
		checkHeaders(t, r)
		parsed, err := url.ParseQuery(r.URL.RawQuery)
		if err != nil {
			t.Errorf("could not parse URL query: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		page := 0
		if pages := parsed["page"]; len(pages) >= 1 {
			page, err = strconv.Atoi(pages[0])
			if err != nil {
				t.Errorf("could not parse page as int: %v", err)
				w.WriteHeader(http.StatusBadRequest)
				return
			}
		}
		if page < 0 || page >= len(mrs) {
			t.Errorf("requested page %d out of range [0, %d)", page, len(mrs))
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
		if _, err := w.Write(mrs[page]); err != nil {
			t.Errorf("could not write mock response body: %v", err)
		}
	}
}
// TestCreateReserveIDsRequest checks the query string built for reserving
// IDs under different counts, years and batch modes.
func TestCreateReserveIDsRequest(t *testing.T) {
	for _, tc := range []struct {
		opts       ReserveOptions
		wantParams string
	}{
		{
			// With NumIDs=1 no batch_type parameter is expected.
			opts:       ReserveOptions{Year: 2000, NumIDs: 1, Mode: SequentialRequest},
			wantParams: "amount=1&cve_year=2000&short_name=test_api_org",
		},
		{
			opts:       ReserveOptions{Year: 2022, NumIDs: 2, Mode: SequentialRequest},
			wantParams: "amount=2&batch_type=sequential&cve_year=2022&short_name=test_api_org",
		},
		{
			opts:       ReserveOptions{Year: 2010, NumIDs: 3, Mode: NonsequentialRequest},
			wantParams: "amount=3&batch_type=nonsequential&cve_year=2010&short_name=test_api_org",
		},
	} {
		name := fmt.Sprintf("NumIDs=%d/Year=%d/Mode=%s", tc.opts.NumIDs, tc.opts.Year, tc.opts.Mode)
		t.Run(name, func(t *testing.T) {
			// No request is sent, so the server needs no handler.
			client, srv := newTestClientAndServer(nil)
			defer srv.Close()

			req, err := client.createReserveIDsRequest(tc.opts)
			if err != nil {
				t.Fatalf("unexpected error getting reserve ID request: %v", err)
			}
			if got := req.URL.RawQuery; got != tc.wantParams {
				t.Errorf("incorrect request params: got %v, want %v", got, tc.wantParams)
			}
		})
	}
}
// queryFunc invokes one Client API method and returns its result as an
// untyped value, so table-driven tests can exercise every endpoint through
// the same code path.
type queryFunc func(t *testing.T, c *Client) (any, error)

// Canned invocations of each Client method, used by TestAllSuccess and
// TestAllFail.
var (
	reserveIDsQuery = func(t *testing.T, c *Client) (any, error) {
		opts := ReserveOptions{
			Mode:   SequentialRequest,
			Year:   2002,
			NumIDs: 2,
		}
		return c.ReserveIDs(opts)
	}
	retrieveQuotaQuery = func(t *testing.T, c *Client) (any, error) {
		return c.RetrieveQuota()
	}
	retrieveIDQuery = func(t *testing.T, c *Client) (any, error) {
		id := getDefaultTestCVERecord(t).Metadata.ID
		return c.RetrieveID(id)
	}
	retrieveRecordQuery = func(t *testing.T, c *Client) (any, error) {
		id := getDefaultTestCVERecord(t).Metadata.ID
		return c.RetrieveRecord(id)
	}
	createRecordQuery = func(t *testing.T, c *Client) (any, error) {
		record := getDefaultTestCVERecord(t)
		return c.CreateRecord(defaultTestCVE.ID, &record.Containers)
	}
	updateRecordQuery = func(t *testing.T, c *Client) (any, error) {
		record := getDefaultTestCVERecord(t)
		return c.UpdateRecord(defaultTestCVE.ID, &record.Containers)
	}
	retrieveOrgQuery = func(t *testing.T, c *Client) (any, error) {
		return c.RetrieveOrg()
	}
	listOrgCVEsQuery = func(t *testing.T, c *Client) (any, error) {
		return c.ListOrgCVEs(new(ListOptions))
	}
)
// TestAllSuccess runs every Client query against a mock server returning a
// successful response. It checks the HTTP method and path of the outgoing
// request and that the decoded result matches the expected fixture.
func TestAllSuccess(t *testing.T) {
	record := getDefaultTestCVERecord(t)
	cases := []struct {
		name           string
		query          queryFunc
		mockStatus     int
		mockResponse   any
		wantHTTPMethod string
		wantPath       string
		want           any
	}{
		{
			name:           "ReserveIDs",
			query:          reserveIDsQuery,
			mockStatus:     http.StatusOK,
			mockResponse:   reserveIDsResponse{CVEs: defaultTestCVEs},
			wantHTTPMethod: http.MethodPost,
			wantPath:       "/api/cve-id",
			want:           defaultTestCVEs,
		},
		{
			name:           "ReserveIDs/partial content ok",
			query:          reserveIDsQuery,
			mockStatus:     http.StatusPartialContent,
			mockResponse:   reserveIDsResponse{CVEs: AssignedCVEList{defaultTestCVE}},
			wantHTTPMethod: http.MethodPost,
			wantPath:       "/api/cve-id",
			want:           AssignedCVEList{defaultTestCVE},
		},
		{
			name:           "RetrieveQuota",
			query:          retrieveQuotaQuery,
			mockStatus:     http.StatusOK,
			mockResponse:   defaultTestQuota,
			wantHTTPMethod: http.MethodGet,
			wantPath:       "/api/org/test_api_org/id_quota",
			want:           defaultTestQuota,
		},
		{
			name:           "RetrieveID",
			query:          retrieveIDQuery,
			mockStatus:     http.StatusOK,
			mockResponse:   defaultTestCVE,
			wantHTTPMethod: http.MethodGet,
			wantPath:       "/api/cve-id/CVE-2022-0000",
			want:           &defaultTestCVE,
		},
		{
			name:           "RetrieveRecord",
			query:          retrieveRecordQuery,
			mockStatus:     http.StatusOK,
			mockResponse:   record,
			wantHTTPMethod: http.MethodGet,
			wantPath:       "/api/cve/CVE-2022-0000",
			want:           record,
		},
		{
			name:           "CreateRecord",
			query:          createRecordQuery,
			mockStatus:     http.StatusOK,
			mockResponse:   createResponse{*record},
			wantHTTPMethod: http.MethodPost,
			wantPath:       "/api/cve/CVE-2022-0000/cna",
			want:           record,
		},
		{
			name:           "UpdateRecord",
			query:          updateRecordQuery,
			mockStatus:     http.StatusOK,
			mockResponse:   updateResponse{*record},
			wantHTTPMethod: http.MethodPut,
			wantPath:       "/api/cve/CVE-2022-0000/cna",
			want:           record,
		},
		{
			name:           "RetrieveOrg",
			query:          retrieveOrgQuery,
			mockStatus:     http.StatusOK,
			mockResponse:   defaultTestOrg,
			wantHTTPMethod: http.MethodGet,
			wantPath:       "/api/org/test_api_org",
			want:           defaultTestOrg,
		},
		{
			name:  "ListOrgCVEs/single page",
			query: listOrgCVEsQuery,
			// NextPage of -1 marks this as the only page.
			mockStatus: http.StatusOK,
			mockResponse: listOrgCVEsResponse{
				CurrentPage: 0,
				NextPage:    -1,
				CVEs:        defaultTestCVEs,
			},
			wantHTTPMethod: http.MethodGet,
			wantPath:       "/api/cve-id",
			want:           defaultTestCVEs,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			checkRequest := func(t *testing.T, r *http.Request) {
				if r.Method != tc.wantHTTPMethod {
					t.Errorf("incorrect HTTP method: got %v, want %v", r.Method, tc.wantHTTPMethod)
				}
				if r.URL.Path != tc.wantPath {
					t.Errorf("incorrect request URL path: got %v, want %v", r.URL.Path, tc.wantPath)
				}
			}
			handler := newTestHandler(t, tc.mockStatus, tc.mockResponse, checkRequest)
			client, srv := newTestClientAndServer(handler)
			defer srv.Close()

			got, err := tc.query(t, client)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("got %v, want %v", got, tc.want)
			}
		})
	}
}
// TestAllFail runs every Client query against a mock server that answers
// 401 Unauthorized with an API error body. It checks that each query returns
// an error whose text includes the status and both error fields.
func TestAllFail(t *testing.T) {
	const want = "401 Unauthorized: more info: even more info"
	for _, tc := range []struct {
		name  string
		query queryFunc
	}{
		{name: "ReserveIDs", query: reserveIDsQuery},
		{name: "RetrieveQuota", query: retrieveQuotaQuery},
		{name: "RetrieveID", query: retrieveIDQuery},
		{name: "RetrieveRecord", query: retrieveRecordQuery},
		{name: "CreateRecord", query: createRecordQuery},
		{name: "UpdateRecord", query: updateRecordQuery},
		{name: "RetrieveOrg", query: retrieveOrgQuery},
		{name: "ListOrgCVEs", query: listOrgCVEsQuery},
	} {
		t.Run(tc.name, func(t *testing.T) {
			failure := apiError{
				Error:   "more info",
				Message: "even more info",
			}
			client, srv := newTestClientAndServer(
				newTestHandler(t, http.StatusUnauthorized, failure, nil))
			defer srv.Close()

			_, err := tc.query(t, client)
			switch {
			case err == nil:
				t.Fatalf("unexpected success: want err %v", want)
			case !strings.Contains(err.Error(), want):
				t.Errorf("unexpected error string: got %v, want %v", err.Error(), want)
			}
		})
	}
}
// TestCreateListOrgCVEsRequest checks the query string built for listing an
// org's CVEs from various filter combinations and page numbers.
func TestCreateListOrgCVEsRequest(t *testing.T) {
	cases := []struct {
		opts       ListOptions
		page       int
		wantParams string
	}{
		{
			// For page 0 no page parameter is expected.
			page: 0,
			opts: ListOptions{
				State:          cveschema.StateReserved,
				Year:           2000,
				ReservedBefore: &testTime2022,
				ReservedAfter:  &testTime1999,
				ModifiedBefore: &testTime2000,
				ModifiedAfter:  &testTime1992,
			},
			wantParams: "cve_id_year=2000&state=RESERVED&time_modified.gt=1992-01-01T00%3A00%3A00Z&time_modified.lt=2000-01-01T00%3A00%3A00Z&time_reserved.gt=1999-01-01T00%3A00%3A00Z&time_reserved.lt=2022-01-01T00%3A00%3A00Z",
		},
		{
			page: 1,
			opts: ListOptions{
				State:          cveschema.StateRejected,
				Year:           1999,
				ReservedBefore: &testTime1999,
				ReservedAfter:  &testTime2000,
				ModifiedBefore: &testTime1992,
				ModifiedAfter:  &testTime2022,
			},
			wantParams: "cve_id_year=1999&page=1&state=REJECT&time_modified.gt=2022-01-01T00%3A00%3A00Z&time_modified.lt=1992-01-01T00%3A00%3A00Z&time_reserved.gt=2000-01-01T00%3A00%3A00Z&time_reserved.lt=1999-01-01T00%3A00%3A00Z",
		},
		{
			// Unset time bounds are omitted from the query entirely.
			page: 2,
			opts: ListOptions{
				State: cveschema.StatePublic,
				Year:  2000,
			},
			wantParams: "cve_id_year=2000&page=2&state=PUBLIC",
		},
	}
	for _, tc := range cases {
		o := tc.opts
		name := fmt.Sprintf("State=%s/Year=%d/ReservedBefore=%s/ReservedAfter=%s/ModifiedBefore=%s/ModifiedAfter=%s",
			o.State, o.Year, o.ReservedBefore, o.ReservedAfter, o.ModifiedBefore, o.ModifiedAfter)
		t.Run(name, func(t *testing.T) {
			// No request is sent, so the server needs no handler.
			client, srv := newTestClientAndServer(nil)
			defer srv.Close()

			req, err := client.createListOrgCVEsRequest(&o, tc.page)
			if err != nil {
				t.Fatalf("unexpected error creating ListOrgCVEs request: %v", err)
			}
			if got := req.URL.RawQuery; got != tc.wantParams {
				t.Errorf("incorrect request params: got %v, want %v", got, tc.wantParams)
			}
		})
	}
}
// TestListOrgCVEsMultiPage checks that ListOrgCVEs follows pagination. The
// mock serves page 0 (NextPage 1) and then page 1 (NextPage -1), and the
// client is expected to return the CVEs of both pages concatenated in order.
func TestListOrgCVEsMultiPage(t *testing.T) {
	extraCVE := newTestCVE("CVE-2000-1234", cveschema.StateReserved, "2000")
	mockResponses := []any{
		listOrgCVEsResponse{
			CurrentPage: 0,
			NextPage:    1,
			CVEs:        defaultTestCVEs,
		},
		listOrgCVEsResponse{
			CurrentPage: 1,
			NextPage:    -1,
			CVEs:        AssignedCVEList{extraCVE},
		},
	}
	c, s := newTestClientAndServer(
		newTestHandlerMultiPage(t, mockResponses, nil))
	defer s.Close()
	got, err := c.ListOrgCVEs(nil)
	if err != nil {
		t.Fatalf("unexpected error listing org cves: %v", err)
	}
	// Build the expectation in a fresh slice. Appending directly to the
	// package-level defaultTestCVEs could write into its backing array and
	// silently mutate the fixture shared with other tests.
	want := make(AssignedCVEList, 0, len(defaultTestCVEs)+1)
	want = append(want, defaultTestCVEs...)
	want = append(want, extraCVE)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}
}