blob: c577e2af868762741c89a0f786454aff7568f38b [file] [log] [blame]
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package test
import (
"sort"
"golang.org/x/vuln/internal/govulncheck"
"golang.org/x/vuln/internal/osv"
)
// MockHandler implements govulncheck.Handler by recording every message it
// receives, in arrival order, in the corresponding exported slice. Write
// replays the recorded messages to another handler.
//
// The zero value is ready to use, since every method only appends to a
// (possibly nil) slice. For use in tests.
type MockHandler struct {
	// ConfigMessages holds every value passed to Config.
	ConfigMessages []*govulncheck.Config
	// ProgressMessages holds every value passed to Progress.
	ProgressMessages []*govulncheck.Progress
	// OSVMessages holds every value passed to OSV.
	OSVMessages []*osv.Entry
	// FindingMessages holds every value passed to Finding; Sort reorders it.
	FindingMessages []*govulncheck.Finding
}

// Compile-time check that MockHandler satisfies govulncheck.Handler, as the
// doc comment above promises.
var _ govulncheck.Handler = (*MockHandler)(nil)
// NewMockHandler returns a pointer to a new MockHandler with no recorded
// messages. A zero MockHandler is equally usable.
func NewMockHandler() *MockHandler {
	return new(MockHandler)
}
// Config records config by appending it to ConfigMessages. It always
// returns nil.
func (h *MockHandler) Config(config *govulncheck.Config) error {
	recorded := append(h.ConfigMessages, config)
	h.ConfigMessages = recorded
	return nil
}
// Progress records progress by appending it to ProgressMessages. It always
// returns nil.
func (h *MockHandler) Progress(progress *govulncheck.Progress) error {
	recorded := append(h.ProgressMessages, progress)
	h.ProgressMessages = recorded
	return nil
}
// OSV records entry by appending it to OSVMessages. It always returns nil.
func (h *MockHandler) OSV(entry *osv.Entry) error {
	recorded := append(h.OSVMessages, entry)
	h.OSVMessages = recorded
	return nil
}
// Finding records finding by appending it to FindingMessages. It always
// returns nil.
func (h *MockHandler) Finding(finding *govulncheck.Finding) error {
	recorded := append(h.FindingMessages, finding)
	h.FindingMessages = recorded
	return nil
}
// Sort orders FindingMessages in place. The primary key is the OSV ID, in
// descending order. The secondary keys are the module, package, receiver and
// function of each finding's first trace frame, in ascending order.
//
// The sort is stable, so findings that compare equal keep the order in which
// they were received. A finding whose trace is empty, or whose first frame is
// nil, sorts as if that frame had all-empty fields instead of panicking.
func (h *MockHandler) Sort() {
	sort.SliceStable(h.FindingMessages, func(i, j int) bool {
		fi, fj := h.FindingMessages[i], h.FindingMessages[j]
		if fi.OSV != fj.OSV {
			// NOTE(review): OSV IDs sort descending while the frame keys
			// sort ascending; kept as-is because recorded output may
			// depend on this order.
			return fi.OSV > fj.OSV
		}
		ki, kj := findingSortKey(fi), findingSortKey(fj)
		for n := range ki {
			if ki[n] != kj[n] {
				return ki[n] < kj[n]
			}
		}
		return false
	})
}

// findingSortKey returns the module, package, receiver and function of the
// first frame of f's trace, in comparison order. It returns all-empty strings
// when the trace is empty or its first frame is nil.
func findingSortKey(f *govulncheck.Finding) [4]string {
	if len(f.Trace) == 0 || f.Trace[0] == nil {
		return [4]string{}
	}
	fr := f.Trace[0]
	return [4]string{fr.Module, fr.Package, fr.Receiver, fr.Function}
}
// Write sorts the recorded findings (see Sort), then replays every recorded
// message to the handler to. It stops at and returns the first error that to
// reports.
//
// Config messages are written first, then Progress messages, both in arrival
// order. Findings follow in sorted order. Immediately before the first finding
// for a given OSV ID, every recorded OSV entry with that ID is written. Any
// recorded OSV entries whose IDs had no findings are written last, in arrival
// order.
func (h *MockHandler) Write(to govulncheck.Handler) error {
	h.Sort()
	for _, config := range h.ConfigMessages {
		if err := to.Config(config); err != nil {
			return err
		}
	}
	for _, progress := range h.ProgressMessages {
		if err := to.Progress(progress); err != nil {
			return err
		}
	}
	// Index the OSV entries by ID once, rather than rescanning OSVMessages for
	// every distinct OSV referenced by a finding. Each ID keeps its entries in
	// arrival order, so the emitted sequence matches a linear scan.
	entriesByID := make(map[string][]*osv.Entry)
	for _, entry := range h.OSVMessages {
		entriesByID[entry.ID] = append(entriesByID[entry.ID], entry)
	}
	seen := make(map[string]bool)
	for _, finding := range h.FindingMessages {
		if !seen[finding.OSV] {
			seen[finding.OSV] = true
			// First finding for this OSV: emit its entries before the finding.
			for _, entry := range entriesByID[finding.OSV] {
				if err := to.OSV(entry); err != nil {
					return err
				}
			}
		}
		if err := to.Finding(finding); err != nil {
			return err
		}
	}
	for _, entry := range h.OSVMessages {
		if !seen[entry.ID] {
			if err := to.OSV(entry); err != nil {
				return err
			}
		}
	}
	return nil
}