blob: af35fb00111cb08529e4c4dcb6247cf106450346 [file] [log] [blame] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mcp
import (
"bytes"
"context"
"fmt"
"maps"
"slices"
"sort"
"github.com/modelcontextprotocol/go-sdk/mcp"
"golang.org/x/tools/gopls/internal/vulncheck/scan"
)
// vulncheckParams holds the arguments of the vulnerability check tool.
// Both fields are optional; the handler substitutes defaults when they are
// left empty.
type vulncheckParams struct {
	// Dir is the directory in which the check runs.
	Dir string `json:"dir,omitempty" jsonschema:"directory to run the vulnerability check within"`

	// Pattern is the package pattern passed to govulncheck.
	Pattern string `json:"pattern,omitempty" jsonschema:"package pattern to check"`
}
// GroupedVulnFinding reports a single vulnerability, identified by its OSV
// ID, together with every package in which it was found.
type GroupedVulnFinding struct {
	// ID is the OSV identifier of the vulnerability.
	ID string `json:"id"`

	// Details is the descriptive text of the vulnerability's OSV entry.
	Details string `json:"details"`

	// AffectedPackages lists the packages the vulnerability was found in.
	AffectedPackages []string `json:"affectedPackages"`
}
// VulncheckResultOutput is the structured output of the vulnerability check
// tool. Both fields are omitted from the JSON encoding when empty.
type VulncheckResultOutput struct {
	// Findings holds one entry per vulnerability found.
	Findings []GroupedVulnFinding `json:"findings,omitempty"`

	// Logs holds the log text captured while running govulncheck.
	Logs string `json:"logs,omitempty"`
}
// vulncheckHandler implements the vulnerability check MCP tool.
//
// It runs govulncheck for params.Pattern (default "./...") in params.Dir
// (default: the root of the session's first view, when there is one). It
// returns the findings grouped by OSV ID, each with the sorted set of
// affected packages, plus any log output the scan produced. Findings are
// ordered by ID so the output is deterministic.
//
// The *mcp.CallToolResult return is always nil; only the structured output
// is produced here. NOTE(review): this presumably relies on the MCP SDK's
// typed tool wrapper to build the tool result from the output — confirm.
//
// A govulncheck failure is returned as an error that wraps the cause and
// includes the captured logs.
func (h *handler) vulncheckHandler(ctx context.Context, req *mcp.CallToolRequest, params *vulncheckParams) (*mcp.CallToolResult, *VulncheckResultOutput, error) {
	countGoVulncheckMCP.Inc()

	snapshot, release, err := h.snapshot()
	if err != nil {
		return nil, nil, err
	}
	defer release()

	dir := params.Dir
	if dir == "" {
		if views := h.session.Views(); len(views) > 0 {
			dir = views[0].Root().Path()
		}
	}
	pattern := params.Pattern
	if pattern == "" {
		pattern = "./..."
	}

	var logBuf bytes.Buffer
	result, err := scan.RunGovulncheck(ctx, pattern, snapshot, dir, &logBuf)
	if err != nil {
		// %w (rather than %v) keeps the cause, e.g. context cancellation,
		// visible to errors.Is/errors.As in callers.
		return nil, nil, fmt.Errorf("running govulncheck failed: %w\nLogs:\n%s", err, logBuf.String())
	}

	// vulnGroup accumulates everything reported for a single OSV ID.
	type vulnGroup struct {
		details string
		pkgs    map[string]struct{} // set: a package reached by several traces is listed once
	}
	groups := make(map[string]*vulnGroup)
	for _, finding := range result.Findings {
		entry := result.Entries[finding.OSV]
		if entry == nil {
			continue
		}
		g, ok := groups[entry.ID]
		if !ok {
			// Take the details from the entry in hand rather than doing a
			// second Entries lookup keyed by entry.ID, which would yield
			// nil if the map keys and the IDs ever disagreed.
			g = &vulnGroup{details: entry.Details, pkgs: make(map[string]struct{})}
			groups[entry.ID] = g
		}
		// Guard against an empty trace or a nil frame instead of panicking.
		var module, pkg string
		if len(finding.Trace) > 0 && finding.Trace[0] != nil {
			module, pkg = finding.Trace[0].Module, finding.Trace[0].Package
		}
		g.pkgs[vulnAffectedPackageLabel(module, pkg)] = struct{}{}
	}

	ids := make([]string, 0, len(groups))
	for id := range groups {
		ids = append(ids, id)
	}
	sort.Strings(ids)

	// An empty Logs string is dropped from the JSON output by omitempty.
	output := VulncheckResultOutput{Logs: logBuf.String()}
	if len(ids) > 0 {
		output.Findings = make([]GroupedVulnFinding, 0, len(ids))
	}
	for _, id := range ids {
		g := groups[id]
		output.Findings = append(output.Findings, GroupedVulnFinding{
			ID:               id,
			Details:          g.details,
			AffectedPackages: slices.Sorted(maps.Keys(g.pkgs)),
		})
	}

	return nil, &output, nil
}

// vulnAffectedPackageLabel returns the name under which a finding is
// reported, given the module path and package import path of its first
// trace frame.
//
// The package path is preferred. The govulncheck JSON format documents
// that module-level findings carry no package path; for those the module
// path is used instead. The standard library (govulncheck's module path
// "stdlib") and frames with no module path are reported as
// "Go standard library".
func vulnAffectedPackageLabel(module, pkg string) string {
	switch {
	case pkg != "":
		return pkg
	case module != "" && module != "stdlib":
		return module
	default:
		return "Go standard library"
	}
}