internal/{frontend,vuln}: update ByAlias
The function vuln.ByAlias now returns at most one result, a Go ID string
that corresponds to the alias. This avoids an extra HTTP request whose
result is discarded anyway.
Change-Id: I93f02a28350d38daace1f8c74b3b7884a227331b
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/484455
Run-TryBot: Tatiana Bradley <tatianabradley@google.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
diff --git a/internal/frontend/search.go b/internal/frontend/search.go
index eeb75ef..7cf96c1 100644
--- a/internal/frontend/search.go
+++ b/internal/frontend/search.go
@@ -404,26 +404,11 @@
if mode != searchModeVuln || !vuln.IsAlias(cq) || vc == nil {
return nil, nil
}
- aliasEntries, err := vc.ByAlias(ctx, cq)
+ id, err := vc.ByAlias(ctx, cq)
if err != nil {
- return nil, err
+ return nil, &serverError{status: derrors.ToStatus(err)}
}
- switch len(aliasEntries) {
- case 0:
- return nil, &serverError{status: http.StatusNotFound}
- case 1:
- return &searchAction{redirectURL: "/vuln/" + aliasEntries[0].ID}, nil
- default:
- var entries []OSVEntry
- for _, e := range aliasEntries {
- entries = append(entries, OSVEntry{e})
- }
- return &searchAction{
- title: fmt.Sprintf("%s - Vulnerability Reports", cq),
- template: "vuln/list",
- page: &VulnListPage{Entries: entries},
- }, nil
- }
+ return &searchAction{redirectURL: "/vuln/" + id}, nil
}
// searchMode reports whether the search performed should be in package or
diff --git a/internal/frontend/search_test.go b/internal/frontend/search_test.go
index 8ade9d4..ba7b2b8 100644
--- a/internal/frontend/search_test.go
+++ b/internal/frontend/search_test.go
@@ -6,6 +6,7 @@
import (
"context"
+ "errors"
"fmt"
"net/http"
"net/http/httptest"
@@ -122,6 +123,11 @@
wantRedirect: "/vuln/GO-1990-01",
},
{
+ name: "vuln alias with no match",
+ query: "q=GHSA-aaaa-bbbb-dddd",
+ wantStatus: http.StatusNotFound,
+ },
+ {
// An explicit mode overrides that.
name: "vuln alias symbol mode",
query: "q=GHSA-aaaa-bbbb-cccc?m=symbol",
@@ -141,9 +147,9 @@
}
gotAction, err := determineSearchAction(req, ds, vc)
if err != nil {
- serr, ok := err.(*serverError)
- if !ok {
- t.Fatal(err)
+ var serr *serverError
+ if !errors.As(err, &serr) {
+ t.Fatalf("got err %#v, want type *serverError", err)
}
if g, w := serr.status, test.wantStatus; g != w {
t.Errorf("got status %d, want %d", g, w)
diff --git a/internal/vuln/client.go b/internal/vuln/client.go
index 3e7c677..da52f2c 100644
--- a/internal/vuln/client.go
+++ b/internal/vuln/client.go
@@ -183,55 +183,35 @@
return &entry, nil
}
-// ByAlias returns the OSV entries that have the given alias, or (nil, nil)
-// if there are none.
-// It returns a list for compatibility with the legacy implementation,
-// but the list always contains at most one element.
-func (c *Client) ByAlias(ctx context.Context, alias string) (_ []*osv.Entry, err error) {
+// ByAlias returns the Go ID of the OSV entry that has the given alias,
+// or a NotFound error if there isn't one.
+func (c *Client) ByAlias(ctx context.Context, alias string) (_ string, err error) {
derrors.Wrap(&err, "ByAlias(%s)", alias)
b, err := c.vulns(ctx)
if err != nil {
- return nil, err
+ return "", err
}
dec, err := newStreamDecoder(b)
if err != nil {
- return nil, err
+ return "", err
}
- var id string
for dec.More() {
var v VulnMeta
err := dec.Decode(&v)
if err != nil {
- return nil, err
+ return "", err
}
for _, vAlias := range v.Aliases {
if alias == vAlias {
- id = v.ID
- break
+ return v.ID, nil
}
}
- if id != "" {
- break
- }
}
- if id == "" {
- return nil, nil
- }
-
- entry, err := c.ByID(ctx, id)
- if err != nil {
- return nil, err
- }
-
- if entry == nil {
- return nil, fmt.Errorf("vulnerability %s was found in %s but could not be retrieved", id, vulnsEndpoint)
- }
-
- return []*osv.Entry{entry}, nil
+ return "", derrors.NotFound
}
// IDs returns a list of the IDs of all the entries in the database.
diff --git a/internal/vuln/client_test.go b/internal/vuln/client_test.go
index 8f3d46c..66437b1 100644
--- a/internal/vuln/client_test.go
+++ b/internal/vuln/client_test.go
@@ -262,24 +262,25 @@
}
tests := []struct {
- name string
- alias string
- want []*osv.Entry
+ name string
+ alias string
+ want string
+ wantErr bool
}{
{
name: "CVE",
alias: "CVE-1999-1111",
- want: []*osv.Entry{&testOSV1},
+ want: testOSV1.ID,
},
{
name: "GHSA",
alias: "GHSA-xxxx-yyyy-zzzz",
- want: []*osv.Entry{&testOSV3},
+ want: testOSV3.ID,
},
{
- name: "Not found",
- alias: "CVE-0000-0000",
- want: nil,
+ name: "Not found",
+ alias: "CVE-0000-0000",
+ wantErr: true,
},
}
@@ -287,11 +288,15 @@
t.Run(test.name, func(t *testing.T) {
ctx := context.Background()
got, err := c.ByAlias(ctx, test.alias)
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, test.want) {
- t.Errorf("ByAlias(%s) = %v, want %v", test.alias, got, test.want)
+ if !test.wantErr {
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("ByAlias(%s) = %v, want %v", test.alias, got, test.want)
+ }
+ } else if err == nil {
+ t.Errorf("ByAlias(%s) = %v, want error", test.alias, got)
}
})
}