// Source blob: 15e519dcbc7283bb0e5b68b74f65d6f73eedde4d
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ghsa
import (
"context"
"flag"
"os"
"reflect"
"strings"
"testing"
"time"
)
var (
	// githubTokenFile is the -ghtokenfile flag: the path of a file whose
	// contents are used as the GitHub access token when -ghtoken is empty.
	githubTokenFile = flag.String("ghtokenfile", "",
		"path to file containing GitHub access token")

	// githubToken is the -ghtoken flag: a GitHub access token given directly.
	// Its default is the value of the VULN_GITHUB_ACCESS_TOKEN environment
	// variable.
	githubToken = flag.String("ghtoken", os.Getenv("VULN_GITHUB_ACCESS_TOKEN"), "GitHub access token")
)
// mustGetAccessToken returns the GitHub access token to use in tests that
// talk to the live GitHub API.
//
// The -ghtoken flag (whose default comes from VULN_GITHUB_ACCESS_TOKEN) takes
// precedence. Otherwise the contents of the file named by -ghtokenfile are
// used. Surrounding whitespace, such as a trailing newline in the token file,
// is trimmed.
//
// If neither source is set, t is skipped. If the token file cannot be read, or
// the resulting token is empty, t fails immediately.
func mustGetAccessToken(t *testing.T) string {
	// Attribute Fatal/Skip to the calling test rather than to this helper.
	t.Helper()
	var token string
	switch {
	case *githubToken != "":
		token = *githubToken
	case *githubTokenFile != "":
		data, err := os.ReadFile(*githubTokenFile)
		if err != nil {
			t.Fatal(err)
		}
		token = string(data)
	default:
		t.Skip("neither -ghtokenfile nor -ghtoken provided")
	}
	token = strings.TrimSpace(token)
	if token == "" {
		// A blank token would otherwise surface later as confusing API
		// authentication errors; fail here with a clear message instead.
		t.Fatal("GitHub access token is empty after trimming whitespace")
	}
	return token
}
// setupClient builds a Client authenticated with the token returned by
// mustGetAccessToken. When no usable token is available, mustGetAccessToken
// skips or fails t, so callers never receive an unauthenticated client.
func setupClient(ctx context.Context, t *testing.T) *Client {
	t.Helper()
	return NewClient(ctx, mustGetAccessToken(t))
}
// TestList checks that Client.List reports at least three security
// advisories for a fixed starting date. It talks to the live GitHub API and is
// skipped unless an access token is supplied (see mustGetAccessToken).
func TestList(t *testing.T) {
	ctx := context.Background()
	c := setupClient(ctx, t)
	// There were at least three relevant SAs since this date.
	since := time.Date(2022, 9, 1, 0, 0, 0, 0, time.UTC)
	// Use the test's ctx rather than a fresh Background context so that any
	// future deadline or cancellation on ctx also applies to this call.
	got, err := c.List(ctx, since)
	if err != nil {
		t.Fatal(err)
	}
	const want = 3
	if len(got) < want {
		t.Errorf("got %d, want at least %d", len(got), want)
	}
}
// TestFetchGHSA checks that Client.FetchGHSA retrieves a known, real advisory
// and that the returned advisory carries the requested ID. It talks to the
// live GitHub API and is skipped unless an access token is supplied (see
// mustGetAccessToken).
func TestFetchGHSA(t *testing.T) {
	ctx := context.Background()
	c := setupClient(ctx, t)
	// Real GHSA that should be found.
	const ghsaID = "GHSA-g9mp-8g3h-3c5c"
	// Use the test's ctx (not a fresh Background) for consistency with the
	// client created above.
	got, err := c.FetchGHSA(ctx, ghsaID)
	if err != nil {
		t.Fatal(err)
	}
	if gotID, want := got.ID, ghsaID; gotID != want {
		t.Errorf("got GHSA with id %q, want %q", gotID, want)
	}
}
// TestListForCVE checks that Client.ListForCVE returns exactly the expected
// GHSA IDs for known CVEs. It talks to the live GitHub API and is skipped
// unless an access token is supplied (see mustGetAccessToken).
func TestListForCVE(t *testing.T) {
	ctx := context.Background()
	c := setupClient(ctx, t)
	tests := []struct {
		name string
		cve  string
		// want lists the expected GHSA IDs; the comparison is order-sensitive.
		want []string
	}{
		{
			name: "Real CVE/GHSA",
			cve:  "CVE-2022-27191",
			want: []string{"GHSA-8c26-wmh5-6g9v"},
		},
		{
			// "CVE-2022-2529" is a string prefix of longer CVE IDs; this case
			// presumably guards against prefix (rather than exact) matching.
			name: "Check exact matching",
			cve:  "CVE-2022-2529",
			want: []string{"GHSA-9rpw-2h95-666c"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := c.ListForCVE(ctx, tt.cve)
			if err != nil {
				t.Fatalf("ListForCVE() error = %v", err)
			}
			// Non-nil (possibly empty) slice, so DeepEqual against a non-nil
			// want behaves the same as before.
			gotIDs := make([]string, 0, len(got))
			for _, sa := range got {
				gotIDs = append(gotIDs, sa.ID)
			}
			if !reflect.DeepEqual(gotIDs, tt.want) {
				t.Errorf("ListForCVE() = %v, want %v", gotIDs, tt.want)
			}
		})
	}
}