blob: 6eff01c8bdcfcd3ef41df082e94556f2d5b8cd85 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vuln
import (
"bytes"
"compress/gzip"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"testing"
"golang.org/x/pkgsite/internal/osv"
)
// TestNewSource checks that NewSource returns the source implementation that
// matches the URL scheme: *httpSource for an https URL and *localSource for a
// file URL.
func TestNewSource(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("windows is not supported (see convertFileURLPath)")
	}
	t.Run("https", func(t *testing.T) {
		url := "https://vuln.go.dev"
		s, err := NewSource(url)
		if err != nil {
			t.Fatal(err)
		}
		if _, ok := s.(*httpSource); !ok {
			t.Errorf("NewSource(%s) = %#v, want type *httpSource", url, s)
		}
	})
	t.Run("file", func(t *testing.T) {
		// t.TempDir returns an absolute path, which already begins with "/"
		// on the non-Windows platforms this test runs on. Prefixing "file://"
		// therefore yields the canonical file:///abs/path form; "file:///"
		// would produce four slashes.
		fileURL := "file://" + t.TempDir()
		s, err := NewSource(fileURL)
		if err != nil {
			t.Fatal(err)
		}
		if _, ok := s.(*localSource); !ok {
			t.Errorf("NewSource(%s) = %#v, want type *localSource", fileURL, s)
		}
	})
}
// TestHTTPSource checks that httpSource.get requests "<endpoint>.json.gz"
// from the server and returns the decompressed payload.
func TestHTTPSource(t *testing.T) {
	want := []byte("some data")
	// Named compressed (not gzipped) so it does not shadow the gzipped helper.
	compressed, err := gzipped(want)
	if err != nil {
		t.Fatal(err)
	}
	server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/test/endpoint.json.gz" {
			rw.WriteHeader(http.StatusNotFound)
			return
		}
		// Write implicitly sends a 200 status line, so calling WriteHeader
		// after a failed Write would be a superfluous no-op. Surface the
		// failure through the test instead. t.Errorf is safe to call from
		// the handler goroutine, and the deferred server.Close blocks until
		// outstanding requests finish.
		if _, err := rw.Write(compressed); err != nil {
			t.Errorf("writing response: %v", err)
		}
	}))
	defer server.Close()

	src := httpSource{
		url: server.URL,
		c:   server.Client(),
	}
	got, err := src.get(context.Background(), "test/endpoint")
	if err != nil {
		t.Fatal(err)
	}
	if string(got) != string(want) {
		t.Errorf("httpSource.get = %s, want %s", got, want)
	}
}
// TestLocalSource checks that localSource.get reads "<endpoint>.json" from
// the source's directory and returns its contents.
func TestLocalSource(t *testing.T) {
	temp := t.TempDir()
	if err := os.Mkdir(filepath.Join(temp, "test"), 0755); err != nil {
		t.Fatal(err)
	}
	want := []byte("some data")
	// Pass components separately so filepath.Join uses the OS separator.
	if err := os.WriteFile(filepath.Join(temp, "test", "endpoint.json"), want, 0644); err != nil {
		t.Fatal(err)
	}
	src := localSource{
		dir: temp,
	}
	got, err := src.get(context.Background(), "test/endpoint")
	if err != nil {
		t.Fatal(err)
	}
	if string(got) != string(want) {
		t.Errorf("localSource.get = %s, want %s", got, want)
	}
}
func TestInMemorySource(t *testing.T) {
want := []byte("some data")
src := inMemorySource{
data: map[string][]byte{
"test/endpoint": want,
},
}
got, err := src.get(context.Background(), "test/endpoint")
if err != nil {
t.Fatal(err)
}
if string(got) != string(want) {
t.Errorf("inMemorySource.get = %s, want %s", got, want)
}
}
// TestNewInMemorySource checks that the source built by newInMemorySource
// from testOSV1..3 serves the same bytes, endpoint by endpoint, as the src of
// the client built by newTestClientFromTxtar(dbTxtar).
func TestNewInMemorySource(t *testing.T) {
	fromTxtar, err := newTestClientFromTxtar(dbTxtar)
	if err != nil {
		t.Fatal(err)
	}
	fromEntries, err := newInMemorySource([]*osv.Entry{&testOSV1, &testOSV2, &testOSV3})
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	endpoints := []string{
		dbEndpoint,
		modulesEndpoint,
		vulnsEndpoint,
		idDir + "/" + testOSV1.ID,
		idDir + "/" + testOSV2.ID,
		idDir + "/" + testOSV3.ID,
	}
	for _, endpoint := range endpoints {
		// Report failures per endpoint (with its name) and keep going, so a
		// single run shows every mismatching endpoint.
		got, err := fromEntries.get(ctx, endpoint)
		if err != nil {
			t.Errorf("newInMemorySource().get(%q): %v", endpoint, err)
			continue
		}
		want, err := fromTxtar.src.get(ctx, endpoint)
		if err != nil {
			t.Errorf("txtar source get(%q): %v", endpoint, err)
			continue
		}
		if string(got) != string(want) {
			t.Errorf("newInMemorySource().get(%q) = %s, want %s", endpoint, got, want)
		}
	}
}
// gzipped returns data compressed in gzip format. Tests use it to build
// gzip-compressed HTTP response bodies.
func gzipped(data []byte) ([]byte, error) {
	var b bytes.Buffer
	w := gzip.NewWriter(&b)
	if _, err := w.Write(data); err != nil {
		return nil, err
	}
	// Close flushes buffered compressed data and writes the gzip trailer.
	// The output is incomplete until it succeeds, so a single checked Close
	// is all that is needed (no additional deferred Close).
	if err := w.Close(); err != nil {
		return nil, err
	}
	return b.Bytes(), nil
}