internal/{database,osv}: add more robust validation for osv entries

Adds functions to validate OSV entries, and calls these functions in
both unit tests and pre-deploy checks.

Change-Id: Id5ddbb6c1a5c81b9176491d5cf1a88fbae928606
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/495498
Run-TryBot: Tatiana Bradley <tatianabradley@google.com>
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/all_test.go b/all_test.go
index d632c29..2e16df4 100644
--- a/all_test.go
+++ b/all_test.go
@@ -21,6 +21,7 @@
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"golang.org/x/vulndb/internal/cveschema5"
+	"golang.org/x/vulndb/internal/osvutils"
 	"golang.org/x/vulndb/internal/report"
 )
 
@@ -99,6 +100,9 @@
 				if diff := cmp.Diff(generated, current, cmpopts.EquateEmpty()); diff != "" {
 					t.Errorf("%s does not match report:\n%v", osvFilename, diff)
 				}
+				if err := osvutils.ValidateExceptTimestamps(&current); err != nil {
+					t.Error(err)
+				}
 			}
 			if r.CVEMetadata != nil {
 				generated, err := r.ToCVE5(goID)
diff --git a/internal/database/load.go b/internal/database/load.go
index b37fd7b..094b65c 100644
--- a/internal/database/load.go
+++ b/internal/database/load.go
@@ -14,6 +14,7 @@
 	"golang.org/x/exp/slices"
 	"golang.org/x/vulndb/internal/derrors"
 	"golang.org/x/vulndb/internal/osv"
+	"golang.org/x/vulndb/internal/osvutils"
 	"golang.org/x/vulndb/internal/report"
 )
 
@@ -117,7 +118,7 @@
 		"index.json", // index.json is OK to accommodate legacy spec
 	}
 	for _, entry := range db.Entries {
-		if err = validateEntry(entry); err != nil {
+		if err = osvutils.Validate(&entry); err != nil {
 			return err
 		}
 		path := filepath.Join(idPath, entry.ID+".json")
@@ -145,16 +146,6 @@
 	return nil
 }
 
-func validateEntry(entry osv.Entry) error {
-	if entry.Modified.IsZero() {
-		return fmt.Errorf("%s: modified time must be non-zero (found %s)", entry.ID, entry.Modified)
-	}
-	if entry.Published.After(entry.Modified.Time) {
-		return fmt.Errorf("%s: published time (%s) cannot be after modified time (%s)", entry.ID, entry.Published, entry.Modified)
-	}
-	return nil
-}
-
 // checkFiles ensures that filepath and filepath+".gz" exist and
 // have contents consistent with v.
 // Returns an error if:
diff --git a/internal/osvutils/validate.go b/internal/osvutils/validate.go
new file mode 100644
index 0000000..e3ebba5
--- /dev/null
+++ b/internal/osvutils/validate.go
@@ -0,0 +1,251 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package osvutils provides utilities for working with Go OSV entries.
+// It is separated from package osv because that package
+// promises to only import from the standard library.
+package osvutils
+
+import (
+	"errors"
+	"fmt"
+	"regexp"
+	"strings"
+
+	"golang.org/x/mod/semver"
+	"golang.org/x/vulndb/internal/derrors"
+	"golang.org/x/vulndb/internal/osv"
+)
+
+// Validate errors if there are any problems with the OSV Entry.
+// It is used to validate OSV entries before publishing them to the
+// Go vulnerability database, and has stricter requirements than
+// the general OSV format.
+func Validate(e *osv.Entry) (err error) {
+	derrors.Wrap(&err, "Validate(%s)", e.ID)
+	return validate(e, true)
+}
+
+// ValidateExceptTimestamps errors if there are any problems with the
+// OSV Entry, with the exception of the timestamps (published, modified and
+// withdrawn) which are not checked.
+// This is used to validate entries at CL submit time, before their timestamps
+// are corrected.
+func ValidateExceptTimestamps(e *osv.Entry) (err error) {
+	derrors.Wrap(&err, "ValidateExceptTimestamps(%s)", e.ID)
+	return validate(e, false)
+}
+
+var (
+	// Errors for incorrect timestamps.
+	errNoModified             = errors.New("modified time must be non-zero")
+	errNoPublished            = errors.New("published time must be non-zero")
+	errPublishedAfterModified = errors.New("published time cannot be after modified time")
+
+	// Errors for missing fields.
+	errNoID                = errors.New("id field is empty")
+	errNoSchemaVersion     = errors.New("schema_version field is empty")
+	errNoDetails           = errors.New("details field is empty")
+	errNoAffected          = errors.New("affected field is empty")
+	errNoReferences        = errors.New("references field is empty")
+	errNoDatabaseSpecific  = errors.New("database_specific field is empty")
+	errNoModule            = errors.New("affected field missing module path")
+	errNotGoEcosystem      = errors.New("affected ecosystem is not Go")
+	errNoRanges            = errors.New("affected field contains no ranges")
+	errNoEcosystemSpecific = errors.New("affected field contains no ecosytem_specific field")
+	errNoPackages          = errors.New("affected.ecosystem_specific field has no packages")
+	errNoPackagePath       = errors.New("affected.ecosystem_specific.imports field has no package path")
+
+	// Errors for invalid fields.
+	errInvalidAlias           = errors.New("alias must be CVE or GHSA ID")
+	errInvalidPkgsiteURL      = errors.New("database_specific.URL must be a link to https://pkg.go.dev/vuln/<Go id>")
+	errInvalidPackagePath     = errors.New("package path must be prefixed by module path")
+	errTooManyRanges          = errors.New("each module should have exactly one version range")
+	errRangeTypeNotSemver     = errors.New("range type must be SEMVER")
+	errNoRangeEvents          = errors.New("range must contain one or more events")
+	errOutOfOrderRange        = errors.New("introduced and fixed versions must alternate")
+	errUnsortedRange          = errors.New("range events must be in strictly ascending order")
+	errNoIntroducedOrFixed    = errors.New("introduced or fixed must be set")
+	errBothIntroducedAndFixed = errors.New("introduced and fixed cannot both be set in same event")
+	errInvalidSemver          = errors.New("invalid or non-canonical semver version")
+
+	// Regular expressions.
+	ghsaRegex        = regexp.MustCompile(`^GHSA-[^-]{4}-[^-]{4}-[^-]{4}$`)
+	cveRegex         = regexp.MustCompile(`^CVE-\d{4}-\d{4,}$`)
+	pkgsiteLinkRegex = regexp.MustCompile(`^https://pkg.go.dev/vuln/GO-\d{4}-\d{4,}$`)
+)
+
+func validate(e *osv.Entry, checkTimestamps bool) (err error) {
+	if checkTimestamps {
+		switch {
+		case e.Modified.IsZero():
+			return errNoModified
+		case e.Published.IsZero():
+			return errNoPublished
+		case e.Published.After(e.Modified.Time):
+			return fmt.Errorf("%w (published=%s, modified=%s)", errPublishedAfterModified, e.Published, e.Modified)
+		}
+	}
+
+	// Check for missing required fields.
+	switch {
+	case e.ID == "":
+		return errNoID
+	case e.SchemaVersion == "":
+		return errNoSchemaVersion
+	case e.Details == "":
+		return errNoDetails
+	case len(e.Affected) == 0:
+		return errNoAffected
+	case len(e.References) == 0:
+		return errNoReferences
+	case e.DatabaseSpecific == nil:
+		return errNoDatabaseSpecific
+	}
+
+	for _, a := range e.Affected {
+		if err := validateAffected(&a); err != nil {
+			return err
+		}
+	}
+	for _, alias := range e.Aliases {
+		if !ghsaRegex.MatchString(alias) && !cveRegex.MatchString(alias) {
+			return fmt.Errorf("%w (found alias %s)", errInvalidAlias, alias)
+		}
+	}
+
+	return validateDatabaseSpecific(e.DatabaseSpecific)
+}
+
+func validateAffected(a *osv.Affected) error {
+	switch {
+	case a.Module.Path == "":
+		return errNoModule
+	case a.Module.Ecosystem != osv.GoEcosystem:
+		return errNotGoEcosystem
+	}
+
+	if err := validateRanges(a.Ranges); err != nil {
+		return err
+	}
+
+	return validateEcosystemSpecific(a.EcosystemSpecific, a.Module.Path)
+}
+
+func validateRanges(ranges []osv.Range) error {
+	switch {
+	case len(ranges) == 0:
+		return errNoRanges
+	case len(ranges) > 1:
+		return fmt.Errorf("%w (found %d ranges)", errTooManyRanges, len(ranges))
+	}
+
+	return validateRange(&ranges[0])
+}
+
+func validateRange(r *osv.Range) error {
+	switch {
+	case r.Type != osv.RangeTypeSemver:
+		return fmt.Errorf("%w (found range type %q)",
+			errRangeTypeNotSemver, r.Type)
+	case len(r.Events) == 0:
+		return errNoRangeEvents
+	}
+
+	// Check that all the events are valid and sorted in ascending order.
+	prev, err := parseRangeEvent(&r.Events[0])
+	if err != nil {
+		return err
+	}
+	for _, event := range r.Events[1:] {
+		current, err := parseRangeEvent(&event)
+		if err != nil {
+			return fmt.Errorf("invalid range event: %w", err)
+		}
+		// Introduced and fixed versions must alternate.
+		if current.introduced == prev.introduced {
+			return errOutOfOrderRange
+		}
+		if !less(prev.v, current.v) {
+			return fmt.Errorf("%w (found %s>=%s)", errUnsortedRange, prev.v, current.v)
+		}
+		prev = current
+	}
+
+	return nil
+}
+
+func less(v, w string) bool {
+	// Ensure that version 0 is always lowest.
+	if v == "0" {
+		return true
+	}
+	if w == "0" {
+		return false
+	}
+	return semver.Compare("v"+v, "v"+w) < 0
+}
+
+type version struct {
+	v          string
+	introduced bool
+}
+
+func parseRangeEvent(e *osv.RangeEvent) (*version, error) {
+	introduced, fixed := e.Introduced, e.Fixed
+
+	var v string
+	var isIntroduced bool
+	switch {
+	case introduced == "" && fixed == "":
+		return nil, errNoIntroducedOrFixed
+	case introduced != "" && fixed != "":
+		return nil, errBothIntroducedAndFixed
+	case introduced == "0":
+		return &version{v: "0", introduced: true}, nil
+	case introduced != "":
+		v = introduced
+		isIntroduced = true
+	case fixed != "":
+		v = fixed
+		isIntroduced = false
+	}
+
+	if sv := "v" + v; !semver.IsValid(sv) || semver.Canonical(sv) != sv {
+		return nil, fmt.Errorf("%w (found %s)", errInvalidSemver, v)
+	}
+
+	return &version{v: v, introduced: isIntroduced}, nil
+}
+
+func validateEcosystemSpecific(es *osv.EcosystemSpecific, module string) error {
+	if es == nil {
+		return errNoEcosystemSpecific
+	}
+
+	if len(es.Packages) == 0 {
+		return errNoPackages
+	}
+
+	for _, pkg := range es.Packages {
+		if pkg.Path == "" {
+			return errNoPackagePath
+		}
+		// Package path must be prefixed by module path unless it is
+		// in the Go standard library or toolchain.
+		if (module != osv.GoStdModulePath && module != osv.GoCmdModulePath) &&
+			!strings.HasPrefix(pkg.Path, module) {
+			return fmt.Errorf("%w (found module=%q, package=%q)", errInvalidPackagePath, module, pkg.Path)
+		}
+	}
+
+	return nil
+}
+
+func validateDatabaseSpecific(d *osv.DatabaseSpecific) error {
+	if !pkgsiteLinkRegex.MatchString(d.URL) {
+		return fmt.Errorf("%w (found URL %q)", errInvalidPkgsiteURL, d.URL)
+	}
+	return nil
+}
diff --git a/internal/osvutils/validate_test.go b/internal/osvutils/validate_test.go
new file mode 100644
index 0000000..5a371ad
--- /dev/null
+++ b/internal/osvutils/validate_test.go
@@ -0,0 +1,328 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package osvutils
+
+import (
+	"errors"
+	"testing"
+	"time"
+
+	"golang.org/x/vulndb/internal/osv"
+)
+
+var (
+	jan1999 = osv.Time{Time: time.Date(1999, 1, 1, 0, 0, 0, 0, time.UTC)}
+	jan2000 = osv.Time{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}
+)
+
+// testEntry creates a valid Entry and modifies it by running transform.
+// If transform is nil, it returns the base test Entry.
+func testEntry(transform func(e *osv.Entry)) *osv.Entry {
+	e := &osv.Entry{
+		SchemaVersion: "1.3.1",
+		ID:            "GO-1999-0001",
+		Published:     jan1999,
+		Modified:      jan2000,
+		Aliases:       []string{"CVE-1999-1111"},
+		Details:       "Some details",
+		Affected: []osv.Affected{
+			{
+				Module: osv.Module{
+					Path:      "example.com/module",
+					Ecosystem: "Go",
+				},
+				Ranges: []osv.Range{
+					{
+						Type: "SEMVER",
+						Events: []osv.RangeEvent{
+							{Introduced: "0"}, {Fixed: "1.1.0"},
+							{Introduced: "1.2.0"},
+							{Fixed: "1.2.2"},
+						}}},
+				EcosystemSpecific: &osv.EcosystemSpecific{
+					Packages: []osv.Package{{Path: "example.com/module/package", Symbols: []string{"Symbol"}}}}},
+			{
+				Module: osv.Module{
+					Path:      "stdlib",
+					Ecosystem: "Go",
+				},
+				Ranges: []osv.Range{
+					{
+						Type: "SEMVER",
+						Events: []osv.RangeEvent{
+							{Introduced: "0"}, {Fixed: "1.1.0"},
+							{Introduced: "1.2.0"},
+							{Fixed: "1.2.2"},
+						}}},
+				EcosystemSpecific: &osv.EcosystemSpecific{
+					Packages: []osv.Package{{Path: "package", Symbols: []string{"Symbol"}}}}},
+		},
+		References: []osv.Reference{
+			{Type: "FIX", URL: "https://example.com/cl/123"},
+		},
+		DatabaseSpecific: &osv.DatabaseSpecific{
+			URL: "https://pkg.go.dev/vuln/GO-1999-0001"}}
+
+	if transform == nil {
+		return e
+	}
+
+	transform(e)
+	return e
+}
+
+func TestValidate(t *testing.T) {
+	t.Run("ok", func(t *testing.T) {
+		if err := Validate(testEntry(nil)); err != nil {
+			t.Error(err)
+		}
+		if err := ValidateExceptTimestamps(testEntry(nil)); err != nil {
+			t.Error("ValidateExceptTimestamps():", err)
+		}
+	})
+
+	t.Run("timestamps", func(t *testing.T) {
+		for _, tc := range []struct {
+			name    string
+			entry   *osv.Entry
+			wantErr error
+		}{
+			{
+				name: "no modified",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Modified = osv.Time{}
+				}),
+				wantErr: errNoModified,
+			},
+			{
+				name: "no published",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Published = osv.Time{}
+				}),
+				wantErr: errNoPublished,
+			},
+			{
+				name: "published after modified",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Modified = jan1999
+					e.Published = jan2000
+				}),
+				wantErr: errPublishedAfterModified,
+			},
+		} {
+			t.Run(tc.name, func(t *testing.T) {
+				want := tc.wantErr
+				if got := Validate(tc.entry); !errors.Is(got, want) {
+					t.Errorf("Validate() error = %v, want error %v", got, want)
+				}
+
+				// There should be no error when we don't check timestamps.
+				if err := ValidateExceptTimestamps(tc.entry); err != nil {
+					t.Errorf("ValidateExceptTimestamps() error = %v", err)
+				}
+			})
+		}
+
+	})
+
+	t.Run("fail", func(t *testing.T) {
+		for _, tc := range []struct {
+			name    string
+			entry   *osv.Entry
+			wantErr error
+		}{
+			{
+				name: "no ID",
+				entry: testEntry(func(e *osv.Entry) {
+					e.ID = ""
+				}),
+				wantErr: errNoID,
+			},
+			{
+				name: "no schema version",
+				entry: testEntry(func(e *osv.Entry) {
+					e.SchemaVersion = ""
+				}),
+				wantErr: errNoSchemaVersion,
+			},
+			{
+				name: "no details",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Details = ""
+				}),
+				wantErr: errNoDetails,
+			},
+			{
+				name: "no affected",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected = nil
+				}),
+				wantErr: errNoAffected,
+			},
+			{
+				name: "no references",
+				entry: testEntry(func(e *osv.Entry) {
+					e.References = nil
+				}),
+				wantErr: errNoReferences,
+			},
+			{
+				name: "no database specific",
+				entry: testEntry(func(e *osv.Entry) {
+					e.DatabaseSpecific = nil
+				}),
+				wantErr: errNoDatabaseSpecific,
+			},
+			{
+				name: "missing module path",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Module.Path = ""
+				}),
+				wantErr: errNoModule,
+			},
+			{
+				name: "non-Go ecosystem",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Module.Ecosystem = "Goo"
+				}),
+				wantErr: errNotGoEcosystem,
+			},
+			{
+				name: "no version ranges",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges = nil
+				}),
+				wantErr: errNoRanges,
+			},
+			{
+				name: "no ecosystem specific",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].EcosystemSpecific = nil
+				}),
+				wantErr: errNoEcosystemSpecific,
+			},
+			{
+				name: "no packages",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].EcosystemSpecific.Packages = nil
+				}),
+				wantErr: errNoPackages,
+			},
+			{
+				name: "no package path",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].EcosystemSpecific.Packages[0].Path = ""
+				}),
+				wantErr: errNoPackagePath,
+			},
+			{
+				name: "invalid alias",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Aliases = append(e.Aliases, "CVE-GHSA-123")
+				}),
+				wantErr: errInvalidAlias,
+			},
+			{
+				name: "invalid pkgsite URL",
+				entry: testEntry(func(e *osv.Entry) {
+					// missing "/vuln/"
+					e.DatabaseSpecific.URL = "https://pkg.go.dev/GO-1234-5667"
+				}),
+				wantErr: errInvalidPkgsiteURL,
+			},
+			{
+				name: "package path not prefixed by module path",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Module.Path = "example.com/module"
+					e.Affected[0].EcosystemSpecific.Packages[0].Path = "example.com/package"
+				}),
+				wantErr: errInvalidPackagePath,
+			},
+			{
+				name: "more than one version range",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges = append(e.Affected[0].Ranges, osv.Range{})
+				}),
+				wantErr: errTooManyRanges,
+			},
+			{
+				name: "non-SEMVER range",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Type = "unknown"
+				}),
+				wantErr: errRangeTypeNotSemver,
+			},
+			{
+				name: "no range events",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Events = nil
+				}),
+				wantErr: errNoRangeEvents,
+			},
+			{
+				name: "out of order range",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Events = []osv.RangeEvent{
+						{Fixed: "1.1.1"}, {Fixed: "1.1.2"},
+					}
+				}),
+				wantErr: errOutOfOrderRange,
+			},
+			{
+				name: "unsorted range",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Events = []osv.RangeEvent{
+						{Fixed: "1.1.1"}, {Introduced: "1.1.0"},
+					}
+				}),
+				wantErr: errUnsortedRange,
+			},
+			{
+				name: "no introduced or fixed in range event",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Events[0] = osv.RangeEvent{}
+				}),
+				wantErr: errNoIntroducedOrFixed,
+			},
+			{
+				name: "both introduced and fixed in range event",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Events = []osv.RangeEvent{
+						{Introduced: "1.1.0", Fixed: "1.1.1"},
+					}
+				}),
+				wantErr: errBothIntroducedAndFixed,
+			},
+			{
+				name: "non-canonical semver",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Events = []osv.RangeEvent{
+						{Introduced: "1.1"},
+					}
+				}),
+				wantErr: errInvalidSemver,
+			},
+			{
+				name: "invalid semver",
+				entry: testEntry(func(e *osv.Entry) {
+					e.Affected[0].Ranges[0].Events = []osv.RangeEvent{
+						{Introduced: "1x2x3"},
+					}
+				}),
+				wantErr: errInvalidSemver,
+			},
+		} {
+			t.Run(tc.name, func(t *testing.T) {
+				want := tc.wantErr
+				if got := Validate(tc.entry); !errors.Is(got, want) {
+					t.Errorf("Validate() error = %v, want error %v", got, want)
+				}
+				if got := ValidateExceptTimestamps(tc.entry); !errors.Is(got, want) {
+					t.Errorf("ValidateExceptTimestamps() error = %v, want error %v", got, want)
+				}
+			})
+		}
+	})
+}