// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package scan

import (
	"context"
	"fmt"
	"go/ast"
	"go/token"
	"path/filepath"
	"strconv"
	"strings"

	"golang.org/x/tools/go/packages"
	"golang.org/x/vuln/internal"
	"golang.org/x/vuln/internal/client"
	"golang.org/x/vuln/internal/govulncheck"
	"golang.org/x/vuln/internal/osv"
	"golang.org/x/vuln/internal/vulncheck"
)

// runSource loads the packages matched by cfg.patterns in dir, analyzes them
// for known vulnerabilities and reports the outcome to handler: first a
// progress message, then the OSV entries and findings.
//
// A vulnerability is either called (a vulnerable symbol is actually
// exercised, so the package is affected) or merely imported by the package
// (likely a non-affecting outcome).
func runSource(ctx context.Context, handler govulncheck.Handler, cfg *config, client *client.Client, dir string) error {
	graph := vulncheck.NewPackageGraph(cfg.GoVersion)
	loadCfg := &packages.Config{Dir: dir, Tests: cfg.test}
	pkgs, err := graph.LoadPackages(loadCfg, cfg.tags, cfg.patterns)
	if err != nil {
		// Turn the most common failure modes into actionable messages,
		// checked in order of likelihood.
		switch {
		case !fileExists(filepath.Join(dir, "go.mod")):
			return fmt.Errorf("govulncheck: %v", errNoGoMod)
		case isGoVersionMismatchError(err):
			return fmt.Errorf("govulncheck: %v\n\n%v", errGoVersionMismatch, err)
		default:
			return fmt.Errorf("govulncheck: loading packages: %w", err)
		}
	}

	if err := handler.Progress(sourceProgressMessage(pkgs)); err != nil {
		return err
	}

	result, err := vulncheck.Source(ctx, pkgs, &cfg.Config, client, graph)
	if err != nil {
		return err
	}
	stacks := vulncheck.CallStacks(result)
	filterCallStacks(stacks)
	return emitResult(handler, result, stacks)
}

// filterCallStacks reduces, in place, the call stacks recorded for each
// vulnerability in callstacks to at most one stack.
//
// A vulnerability with a call sink keeps the stack chosen by uniqueCallStack
// when compared against all called vulnerabilities that share its OSV ID,
// package path and module path. Every other entry (no call sink, or no
// suitable stack) is reset to nil.
func filterCallStacks(callstacks map[*vulncheck.Vuln][]vulncheck.CallStack) {
	type groupKey struct {
		id  string
		pkg string
		mod string
	}
	keyOf := func(v *vulncheck.Vuln) groupKey {
		return groupKey{id: v.OSV.ID, pkg: v.ImportSink.PkgPath, mod: v.ImportSink.Module.Path}
	}

	// Group called vulnerabilities by OSV ID, package and module; each group
	// is what a stack must avoid passing through to count as unique.
	groups := make(map[groupKey][]*vulncheck.Vuln)
	for v := range callstacks {
		if v.CallSink == nil {
			continue
		}
		k := keyOf(v)
		groups[k] = append(groups[k], v)
	}

	for v, stacks := range callstacks {
		var kept []vulncheck.CallStack
		if v.CallSink != nil {
			if cs := uniqueCallStack(v, stacks, groups[keyOf(v)]); cs != nil {
				kept = []vulncheck.CallStack{cs}
			}
		}
		callstacks[v] = kept
	}
}

// emitResult turns the analysis result vr and the (already filtered) call
// stacks into findings and passes them to handler.
//
// Each call stack of a vulnerability becomes one finding whose trace is that
// stack. An OSV ID for which no vulnerability has any call stack gets exactly
// one finding, with a single-frame trace naming the module, version and
// package where the vulnerable code is imported. Findings are ordered with
// sortResult, and the OSV entry for an ID is sent to handler just before the
// first finding carrying that ID. The first handler error is returned.
func emitResult(handler govulncheck.Handler, vr *vulncheck.Result, callstacks map[*vulncheck.Vuln][]vulncheck.CallStack) error {
	entries := make(map[string]*osv.Entry)
	reported := make(map[string]bool)
	var findings []*govulncheck.Finding

	// Pass 1: called vulnerabilities produce one finding per call stack.
	for _, v := range vr.Vulns {
		entries[v.OSV.ID] = v.OSV
		fix := fixedVersion(v.ImportSink.Module.Path, v.OSV.Affected)
		for _, cs := range callstacks[v] {
			reported[v.OSV.ID] = true
			findings = append(findings, &govulncheck.Finding{
				OSV:          v.OSV.ID,
				FixedVersion: fix,
				Trace:        tracefromEntries(cs),
			})
		}
	}

	// Pass 2: IDs that were never called get a single import-level finding.
	for _, v := range vr.Vulns {
		if reported[v.OSV.ID] || len(callstacks[v]) != 0 {
			continue
		}
		reported[v.OSV.ID] = true
		sink := v.ImportSink
		findings = append(findings, &govulncheck.Finding{
			OSV:          v.OSV.ID,
			FixedVersion: fixedVersion(sink.Module.Path, v.OSV.Affected),
			Trace: []*govulncheck.Frame{{
				Module:  sink.Module.Path,
				Version: sink.Module.Version,
				Package: sink.PkgPath,
			}},
		})
	}

	// Write everything out in sorted order, announcing each OSV entry once.
	sortResult(findings)
	announced := make(map[string]bool)
	for _, f := range findings {
		if !announced[f.OSV] {
			announced[f.OSV] = true
			if err := handler.OSV(entries[f.OSV]); err != nil {
				return err
			}
		}
		if err := handler.Finding(f); err != nil {
			return err
		}
	}
	return nil
}

// tracefromEntries converts the call stack vcs into govulncheck frames.
//
// Frames are produced in the reverse order of vcs: the last stack entry
// becomes the first frame. Module, version and package fields are filled only
// when the entry's function has a package. A frame's Position is the position
// of the call recorded in the corresponding entry, or nil when the entry has
// no call or the call has no position.
func tracefromEntries(vcs vulncheck.CallStack) []*govulncheck.Frame {
	var frames []*govulncheck.Frame
	for i := len(vcs) - 1; i >= 0; i-- {
		fn := vcs[i].Function
		call := vcs[i].Call
		frame := &govulncheck.Frame{
			Function: fn.Name,
			Receiver: fn.Receiver(),
		}
		if pkg := fn.Package; pkg != nil {
			frame.Module = pkg.Module.Path
			frame.Version = pkg.Module.Version
			frame.Package = pkg.PkgPath
		}
		if call != nil && call.Pos != nil {
			p := call.Pos
			frame.Position = &govulncheck.Position{
				Filename: p.Filename,
				Offset:   p.Offset,
				Line:     p.Line,
				Column:   p.Column,
			}
		}
		frames = append(frames, frame)
	}
	return frames
}

// sourceProgressMessage returns a progress message of the form
//
//	"Scanning your code and P packages across M dependent modules for known vulnerabilities..."
//
// where P is the number of packages that topPkgs strictly depend on (the
// top-level packages themselves are not counted) and M is the number of
// distinct modules of those packages, excluding the standard library and the
// unknown module. Each noun is pluralized unless its count is exactly 1.
func sourceProgressMessage(topPkgs []*packages.Package) *govulncheck.Progress {
	pkgs, mods := depPkgsAndMods(topPkgs)

	// plural renders "<n> <noun>", appending "s" unless n is exactly 1.
	plural := func(n int, noun string) string {
		s := fmt.Sprintf("%d %s", n, noun)
		if n != 1 {
			s += "s"
		}
		return s
	}

	msg := fmt.Sprintf("Scanning your code and %s across %s for known vulnerabilities...",
		plural(pkgs, "package"), plural(mods, "dependent module"))
	return &govulncheck.Progress{Message: msg}
}

// depPkgsAndMods returns the number of distinct packages that topPkgs
// transitively import, not counting the top-level packages themselves, and
// the number of distinct modules those packages belong to.
//
// Packages are identified by PkgPath. Packages with no module, standard
// library packages and packages of the unknown module do not contribute to
// the module count.
func depPkgsAndMods(topPkgs []*packages.Package) (int, int) {
	isTop := make(map[string]bool, len(topPkgs))
	for _, p := range topPkgs {
		isTop[p.PkgPath] = true
	}

	seenPkgs := make(map[string]bool)
	seenMods := make(map[string]bool)

	var walk func(p *packages.Package, root bool)
	walk = func(p *packages.Package, root bool) {
		path := p.PkgPath
		switch {
		case seenPkgs[path]:
			// Already counted and traversed.
			return
		case isTop[path] && !root:
			// A package with a top-level path is traversed only from the
			// root calls below, and is never counted as a dependency.
			return
		case !isTop[path]:
			seenPkgs[path] = true
			if m := p.Module; m != nil &&
				m.Path != internal.GoStdModulePath && // no module for stdlib
				m.Path != internal.UnknownModulePath { // no module for unknown
				seenMods[m.Path] = true
			}
		}
		for _, dep := range p.Imports {
			walk(dep, false)
		}
	}

	for _, p := range topPkgs {
		walk(p, true)
	}
	return len(seenPkgs), len(seenMods)
}

// summarizeTrace returns a short, one-line description of the call stack in
// finding, or "" when the trace has fewer than two frames.
//
// It prefers to show the edge where the top module (the module of the last
// frame, the entry point) calls into other code, along with the vulnerable
// symbol (the first frame). When the vulnerable symbol is called directly by
// the user's code, only those two points are shown. When the vulnerable
// symbol is itself in the user's code, the summary starts from the entry
// point instead and still ends with the vulnerable symbol.
func summarizeTrace(finding *govulncheck.Finding) string {
	trace := finding.Trace
	n := len(trace)
	if n < 2 {
		return ""
	}
	last := n - 1
	entryModule := trace[last].Module

	// Find the frame closest to the vulnerable symbol that still belongs to
	// the top module: that is where control leaves the top module.
	exit := last
	for i := 0; i < n; i++ {
		if trace[i].Module == entryModule {
			exit = i
			break
		}
	}
	if exit == 0 {
		// Everything is in one module; summarize from the entry point.
		exit = last
	}

	var sb strings.Builder
	if pos := posToString(trace[exit].Position); pos != "" {
		sb.WriteString(pos)
		sb.WriteString(": ")
	}
	addSymbolName(&sb, trace[exit])
	sb.WriteString(" calls ")
	addSymbolName(&sb, trace[exit-1])
	if exit > 1 {
		sb.WriteString(", which")
		if exit > 2 {
			sb.WriteString(" eventually")
		}
		sb.WriteString(" calls ")
		addSymbolName(&sb, trace[0])
	}
	return sb.String()
}

// addSymbolName writes the qualified name of frame's function to buf as
// "Package.Receiver.Function", leaving out the package and receiver parts
// when they are empty. A leading '*' on the receiver is dropped, and the
// function name is cut at its first '$', so that, for example, a function
// named "f$1" is written as "f".
func addSymbolName(buf *strings.Builder, frame *govulncheck.Frame) {
	if frame.Package != "" {
		buf.WriteString(frame.Package)
		buf.WriteString(".")
	}
	if frame.Receiver != "" {
		// Pointer receivers are shown by their base type name.
		buf.WriteString(strings.TrimPrefix(frame.Receiver, "*"))
		buf.WriteString(".")
	}
	name := frame.Function
	if i := strings.IndexByte(name, '$'); i >= 0 {
		name = name[:i]
	}
	buf.WriteString(name)
}

// updateInitPositions fills in missing positions of init functions, and of
// the calls to init functions, in every call stack of callStacks (see
// #51575). Stack entries and the function nodes they point to are modified
// in place.
func updateInitPositions(callStacks map[*vulncheck.Vuln][]vulncheck.CallStack) {
	for _, stacks := range callStacks {
		for _, stack := range stacks {
			last := len(stack) - 1
			for i := range stack {
				updateInitPosition(&stack[i])
				// The final entry has no callee to position a call for.
				if i < last {
					updateInitCallPosition(&stack[i], stack[i+1])
				}
			}
		}
	}
}

// updateInitCallPosition sets the position of the call from curr to next when
// next is an init function and that call has no valid position yet:
//
//	P1.init -> P2.init: position of call to P2.init is the position of "import P2"
//	statement in P1
//
//	P.init -> P.init#d: P.init is an implicit init. We say it calls the explicit
//	P.init#d at the place of "package P" statement.
//
// The call site of curr is updated in place. Nothing is done when curr has
// no call site, or when either function has no package, because no position
// can be derived in those cases.
func updateInitCallPosition(curr *vulncheck.StackEntry, next vulncheck.StackEntry) {
	call := curr.Call
	// Entries may lack a call site (tracefromEntries handles a nil Call), so
	// guard before dereferencing it.
	if call == nil || !isInit(next.Function) || (call.Pos != nil && call.Pos.IsValid()) {
		// Skip missing call sites, non-init callees, and calls whose
		// position is already available.
		return
	}

	from, to := curr.Function.Package, next.Function.Package
	if from == nil || to == nil {
		// Both helpers below dereference the package; without it there is
		// no statement to point at.
		return
	}

	var pos token.Position
	if curr.Function.Name == "init" && from == to {
		// Implicit P.init calling P.init#d: use the "package P" statement.
		pos = packageStatementPos(from)
	} else {
		// Choose the beginning of the import statement as the position.
		pos = importStatementPos(from, to.PkgPath)
	}
	call.Pos = &pos
}

// importStatementPos returns the position of the first import spec in pkg's
// syntax files (searched in order) that imports importPath. It returns the
// zero token.Position when no such spec exists, for sanity in case of a wild
// call graph imprecision.
//
// Import specs whose path is not a valid quoted string are skipped instead of
// causing a panic. This can happen, for example, when go/parser recovered from
// a syntax error in the import declaration.
func importStatementPos(pkg *packages.Package, importPath string) token.Position {
	// importsPath reports whether spec imports importPath. Import spec paths
	// carry quotation marks, so they are unquoted before comparison.
	importsPath := func(spec *ast.ImportSpec) bool {
		path, err := strconv.Unquote(spec.Path.Value)
		return err == nil && path == importPath
	}

	for _, f := range pkg.Syntax {
		for _, spec := range f.Imports {
			if importsPath(spec) {
				// Choose the beginning of the import statement as the position.
				return pkg.Fset.Position(spec.Pos())
			}
		}
	}
	return token.Position{}
}

// packageStatementPos returns the position of the "package" clause in the
// first syntax file of pkg, or the zero token.Position when pkg has no syntax
// files. Any file would do, so the first one is used.
func packageStatementPos(pkg *packages.Package) token.Position {
	var pos token.Position
	if files := pkg.Syntax; len(files) > 0 {
		pos = pkg.Fset.Position(files[0].Package)
	}
	return pos
}

// updateInitPosition sets the position of the init function in se to that of
// the "package P" statement of its package, but only when the function has no
// valid position yet. Non-init functions and inits that already have a valid
// position are left untouched. The function node is modified in place.
func updateInitPosition(se *vulncheck.StackEntry) {
	fn := se.Function
	hasPos := fn.Pos != nil && fn.Pos.IsValid()
	if hasPos || !isInit(fn) {
		return
	}
	pos := packageStatementPos(fn.Package)
	fn.Pos = &pos
}

// isInit reports whether f is an init function. Implicit package
// initializers are named "init". Source init functions, and anonymous
// functions used in inits, are named "init#x" by vulncheck (more precisely,
// by ssa), where x is a positive integer.
func isInit(f *vulncheck.FuncNode) bool {
	name := f.Name
	if name == "init" {
		return true
	}
	return strings.HasPrefix(name, "init#")
}

// uniqueCallStack returns the first call stack in css that does not pass
// through the call sink of any vulnerability in vg, or nil if there is none.
//
// v's own call sink is exempt from that check. Callers may therefore include
// v in vg, and a stack that reaches v.CallSink is not rejected for doing so.
func uniqueCallStack(v *vulncheck.Vuln, css []vulncheck.CallStack, vg []*vulncheck.Vuln) vulncheck.CallStack {
	sinks := make(map[*vulncheck.FuncNode]bool, len(vg))
	// Named "other" so that the parameter v is not shadowed.
	for _, other := range vg {
		sinks[other.CallSink] = true
	}

	// throughOtherSink reports whether cs visits the vulnerable symbol of a
	// vulnerability other than v.
	throughOtherSink := func(cs vulncheck.CallStack) bool {
		for _, e := range cs {
			if e.Function != v.CallSink && sinks[e.Function] {
				return true
			}
		}
		return false
	}

	for _, cs := range css {
		if !throughOtherSink(cs) {
			return cs
		}
	}
	return nil
}
