// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package client provides an interface for accessing vulnerability
// databases, via either HTTP or local filesystem access.
//
// The protocol is described at https://go.dev/security/vulndb/#protocol.
//
// The expected database layout is the same for both HTTP and local
// databases. The database index is located at the root of the
// database, and contains a list of all of the vulnerable modules
// documented in the database and the time the most recent vulnerability
// was added. The index file is called index.json, and has the
// following format:
//
//	map[string]time.Time (DBIndex)
//
// Each vulnerable module is represented by an individual JSON file
// which contains all of the vulnerabilities in that module. The path
// for each module file is simply the import path of the module.
// For example, vulnerabilities in golang.org/x/crypto are contained in the
// golang.org/x/crypto.json file. The per-module JSON files contain a slice of
// https://pkg.go.dev/golang.org/x/vuln/osv#Entry.
//
// A single client.Client can be used to access multiple vulnerability
// databases. When looking up vulnerable modules, each database is
// consulted, and results are merged together.
package client

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"golang.org/x/mod/module"
	"golang.org/x/vuln/internal"
	"golang.org/x/vuln/internal/derrors"
	"golang.org/x/vuln/internal/web"
	"golang.org/x/vuln/osv"
)

// DBIndex is the decoded form of a database's index.json file. It maps the
// path of each vulnerable module to the last time a new vulnerability was
// added to the database for it.
type DBIndex map[string]time.Time

// Client is the interface for fetching vulnerabilities based on module path,
// ID or alias. It cannot be implemented outside this package because it
// contains an unexported method.
type Client interface {
	// GetByModule returns the entries that affect the given module path.
	// It returns (nil, nil) if there are none.
	GetByModule(context.Context, string) ([]*osv.Entry, error)

	// GetByID returns the entry with the given ID, or (nil, nil) if there isn't
	// one.
	GetByID(context.Context, string) (*osv.Entry, error)

	// GetByAlias returns the entries that have the given alias, or (nil, nil)
	// if there are none.
	GetByAlias(context.Context, string) ([]*osv.Entry, error)

	// ListIDs returns the IDs of all entries in the database.
	ListIDs(context.Context) ([]string, error)

	// LastModifiedTime returns the time that the database was last modified.
	// It can be used by tools that periodically check for vulnerabilities
	// to avoid repeating work.
	LastModifiedTime(context.Context) (time.Time, error)

	unexported() // ensures that adding a method won't break users
}

// source is a single vulnerability database. In addition to the Client
// methods it exposes the database index.
type source interface {
	Client
	// Index returns the database index (the decoded index.json).
	Index(context.Context) (DBIndex, error)
}

// localSource is a source that reads a vulnerability database from the local
// filesystem. All files are read relative to dir, the root directory of the
// database.
type localSource struct {
	dir string
}

// unexported is the no-op marker method required by the Client interface.
func (*localSource) unexported() {}

// GetByModule reads the entries for modulePath from the module's JSON file,
// named after the escaped module path, under the database directory. A
// missing file is not an error: it yields (nil, nil).
func (ls *localSource) GetByModule(_ context.Context, modulePath string) (_ []*osv.Entry, err error) {
	defer derrors.Wrap(&err, "localSource.GetByModule(%q)", modulePath)

	escaped, err := EscapeModulePath(modulePath)
	if err != nil {
		return nil, err
	}
	file := filepath.Join(ls.dir, escaped+".json")
	data, err := os.ReadFile(file)
	switch {
	case os.IsNotExist(err):
		// No file for this module: report "no entries" rather than an error.
		return nil, nil
	case err != nil:
		return nil, err
	}
	var entries []*osv.Entry
	if decodeErr := json.Unmarshal(data, &entries); decodeErr != nil {
		return nil, decodeErr
	}
	return entries, nil
}

// GetByID reads the entry with the given ID from <id>.json inside the ID
// directory of the database. A missing file yields (nil, nil).
func (ls *localSource) GetByID(_ context.Context, id string) (_ *osv.Entry, err error) {
	defer derrors.Wrap(&err, "GetByID(%q)", id)

	file := filepath.Join(ls.dir, internal.IDDirectory, id+".json")
	data, err := os.ReadFile(file)
	switch {
	case os.IsNotExist(err):
		// Unknown ID: not an error.
		return nil, nil
	case err != nil:
		return nil, err
	}
	entry := new(osv.Entry)
	if err = json.Unmarshal(data, entry); err != nil {
		return nil, err
	}
	return entry, nil
}

// GetByAlias looks alias up in the database's aliases.json file, which maps
// each alias to a list of entry IDs, and returns the entries for those IDs.
// It returns (nil, nil) when the alias is not listed.
func (ls *localSource) GetByAlias(ctx context.Context, alias string) (entries []*osv.Entry, err error) {
	defer derrors.Wrap(&err, "localSource.GetByAlias(%q)", alias)

	idsByAlias, err := localReadJSON[map[string][]string](ctx, ls, "aliases.json")
	if err != nil {
		return nil, err
	}
	if ids := idsByAlias[alias]; len(ids) > 0 {
		return getByIDs(ctx, ls, ids)
	}
	return nil, nil
}

// getByIDs fetches the entry for each of ids from s, preserving the order of
// ids.
//
// s.GetByID reports a missing entry as (nil, nil). Such IDs are skipped so
// that the result never contains nil pointers; callers such as
// client.unionEntries dereference every element and would otherwise panic.
// The first error returned by s.GetByID aborts the lookup and is returned
// with a nil slice.
func getByIDs(ctx context.Context, s source, ids []string) ([]*osv.Entry, error) {
	var entries []*osv.Entry
	for _, id := range ids {
		e, err := s.GetByID(ctx, id)
		if err != nil {
			return nil, err
		}
		if e == nil {
			// The ID is listed but the source has no entry for it.
			continue
		}
		entries = append(entries, e)
	}
	return entries, nil
}

// ListIDs returns the IDs stored in the index.json file of the database's ID
// directory.
func (ls *localSource) ListIDs(ctx context.Context) (_ []string, err error) {
	defer derrors.Wrap(&err, "ListIDs()")

	idIndex := filepath.Join(internal.IDDirectory, "index.json")
	return localReadJSON[[]string](ctx, ls, idIndex)
}

// LastModifiedTime returns the filesystem modification time of the
// database's top-level index.json file.
func (ls *localSource) LastModifiedTime(context.Context) (_ time.Time, err error) {
	defer derrors.Wrap(&err, "LastModifiedTime()")

	// Assume that if anything changes, the index does.
	indexFile := filepath.Join(ls.dir, "index.json")
	fi, statErr := os.Stat(indexFile)
	if statErr != nil {
		return time.Time{}, statErr
	}
	return fi.ModTime(), nil
}

// Index decodes and returns the database's top-level index.json file.
func (ls *localSource) Index(ctx context.Context) (_ DBIndex, err error) {
	defer derrors.Wrap(&err, "Index()")

	const indexFile = "index.json"
	return localReadJSON[DBIndex](ctx, ls, indexFile)
}

// localReadJSON reads the file at relativePath (relative to ls.dir) and
// decodes its JSON content into a value of type T. On any read or decode
// error it returns the zero value of T together with the error.
func localReadJSON[T any](_ context.Context, ls *localSource, relativePath string) (T, error) {
	var (
		zero    T
		decoded T
	)
	raw, err := os.ReadFile(filepath.Join(ls.dir, relativePath))
	if err == nil {
		err = json.Unmarshal(raw, &decoded)
	}
	if err != nil {
		return zero, err
	}
	return decoded, nil
}

// httpSource is a source that reads a vulnerability database over HTTP(S).
//
// c is the HTTP client used for every request. cache is optional: when it is
// nil, nothing is cached. dbName is the name passed to the cache to identify
// this database.
type httpSource struct {
	url    string // the base URI of the source (without trailing "/"). e.g. https://vuln.golang.org
	c      *http.Client
	cache  Cache
	dbName string

	// indexCalls counts the number of times Index()
	// method has been called. httpCalls counts the
	// number of times GetByModule makes an http request
	// to the vuln db for a module path. Used for testing
	// privacy properties of httpSource.
	indexCalls int
	httpCalls  int
}

// Index returns the database index for this source, consulting the cache
// when one is configured.
//
// A cached index that was retrieved less than two hours ago is returned
// without contacting the server. An older cached index is revalidated with a
// conditional GET (If-Modified-Since); if the server answers 304 Not
// Modified, the cached index is returned and its retrieval time is refreshed
// in the cache. Otherwise the index is downloaded, decoded, written to the
// cache (if any) and returned. Any status other than 200 (or 304 for a
// conditional request) is an error.
func (hs *httpSource) Index(ctx context.Context) (_ DBIndex, err error) {
	hs.indexCalls++ // for testing privacy properties
	defer derrors.Wrap(&err, "Index()")

	var cachedIndex DBIndex
	var cachedIndexRetrieved *time.Time

	if hs.cache != nil {
		index, retrieved, err := hs.cache.ReadIndex(hs.dbName)
		if err != nil {
			return nil, err
		}

		cachedIndex = index
		if cachedIndex != nil {
			if time.Since(retrieved) < time.Hour*2 {
				return cachedIndex, nil
			}

			cachedIndexRetrieved = &retrieved
		}
	}

	req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/index.json", hs.url), nil)
	if err != nil {
		return nil, err
	}
	if cachedIndexRetrieved != nil {
		// http.TimeFormat hard-codes "GMT" as the zone, so the time must be
		// converted to UTC before formatting; formatting a local time would
		// put the wrong instant in the header.
		req.Header.Add("If-Modified-Since", cachedIndexRetrieved.UTC().Format(http.TimeFormat))
	}
	resp, err := hs.c.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if cachedIndexRetrieved != nil && resp.StatusCode == http.StatusNotModified {
		// If status has not been modified, this is equivalent to returning the
		// same index. We update the timestamp so the next cache index read does
		// not require a roundtrip to the server.
		if err = hs.cache.WriteIndex(hs.dbName, cachedIndex, time.Now()); err != nil {
			return nil, err
		}
		return cachedIndex, nil
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var index DBIndex
	if err = json.Unmarshal(b, &index); err != nil {
		return nil, err
	}

	if hs.cache != nil {
		if err = hs.cache.WriteIndex(hs.dbName, index, time.Now()); err != nil {
			return nil, err
		}
	}

	return index, nil
}

// unexported is the no-op marker method required by the Client interface.
func (*httpSource) unexported() {}

// GetByModule returns the entries for modulePath.
//
// The index is consulted first: a module that is not listed there yields
// (nil, nil) without a per-module request being made. If a cache is
// configured and its entries for the module are at least as recent as the
// index timestamp, they are returned. Otherwise the module's JSON file is
// fetched from the server and, when a cache is configured, stored in it.
func (hs *httpSource) GetByModule(ctx context.Context, modulePath string) (_ []*osv.Entry, err error) {
	defer derrors.Wrap(&err, "httpSource.GetByModule(%q)", modulePath)

	idx, err := hs.Index(ctx)
	if err != nil {
		return nil, err
	}
	indexTime, listed := idx[modulePath]
	if !listed {
		return nil, nil
	}

	if hs.cache != nil {
		fromCache, err := hs.cache.ReadEntries(hs.dbName, modulePath)
		if err != nil {
			return nil, err
		}
		if len(fromCache) > 0 {
			// The cached entries are usable unless the index reports a newer
			// modification than any of them carries.
			if newest := latestModifiedTime(fromCache); !newest.Before(indexTime) {
				return fromCache, nil
			}
		}
	}

	escaped, err := EscapeModulePath(modulePath)
	if err != nil {
		return nil, err
	}
	hs.httpCalls++ // for testing privacy properties
	fetched, err := httpReadJSON[[]*osv.Entry](ctx, hs, escaped+".json")
	if err != nil {
		return nil, err
	}
	if fetched == nil {
		return nil, nil
	}
	// TODO: we may want to check that the returned entries actually match
	// the module we asked about, so that the cache cannot be poisoned
	if hs.cache != nil {
		if err := hs.cache.WriteEntries(hs.dbName, modulePath, fetched); err != nil {
			return nil, err
		}
	}
	return fetched, nil
}

// specialCaseModulePaths is the set of pseudo-module paths used for parts of
// the Go system. These are technically not valid module paths, so we
// mustn't pass them to module.EscapePath; EscapeModulePath returns them
// unchanged instead.
// Keep in sync with vulndb/internal/database/generate.go.
var specialCaseModulePaths = map[string]bool{
	internal.GoStdModulePath: true,
	internal.GoCmdModulePath: true,
}

// EscapeModulePath should be called by cache implementations or other users of
// this package that want to use module paths as filesystem paths. It is like
// golang.org/x/mod/module.EscapePath, but returns the special pseudo-module
// paths used by the vulnerability database unchanged.
func EscapeModulePath(path string) (string, error) {
	if special := specialCaseModulePaths[path]; !special {
		return module.EscapePath(path)
	}
	return path, nil
}

// latestModifiedTime returns the most recent Modified time among entries, or
// the zero time when entries is empty.
func latestModifiedTime(entries []*osv.Entry) time.Time {
	latest := time.Time{}
	for i := range entries {
		if modified := entries[i].Modified; modified.After(latest) {
			latest = modified
		}
	}
	return latest
}

// GetByID fetches <id>.json from the ID directory of the database and
// returns the decoded entry. A 404 or empty response yields (nil, nil).
func (hs *httpSource) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
	defer derrors.Wrap(&err, "GetByID(%q)", id)

	relPath := fmt.Sprintf("%s/%s.json", internal.IDDirectory, id)
	return httpReadJSON[*osv.Entry](ctx, hs, relPath)
}

// GetByAlias downloads the database's aliases.json file, which maps each
// alias to a list of entry IDs, and returns the entries for the IDs listed
// under alias. It returns (nil, nil) when the alias is not listed.
func (hs *httpSource) GetByAlias(ctx context.Context, alias string) (entries []*osv.Entry, err error) {
	defer derrors.Wrap(&err, "httpSource.GetByAlias(%q)", alias)

	idsByAlias, err := httpReadJSON[map[string][]string](ctx, hs, "aliases.json")
	if err != nil {
		return nil, err
	}
	if ids := idsByAlias[alias]; len(ids) > 0 {
		return getByIDs(ctx, hs, ids)
	}
	return nil, nil
}

// ListIDs returns the IDs stored in the index.json file of the database's ID
// directory, fetched over HTTP.
func (hs *httpSource) ListIDs(ctx context.Context) (_ []string, err error) {
	defer derrors.Wrap(&err, "ListIDs()")

	idIndex := path.Join(internal.IDDirectory, "index.json")
	return httpReadJSON[[]string](ctx, hs, idIndex)
}

// httpReadJSON fetches relativePath (relative to hs.url) and decodes the JSON
// response body into a value of type T. An empty body, which is what
// readBody returns for a 404, yields the zero value of T and a nil error.
// Request and decode failures yield the zero value of T and the error.
func httpReadJSON[T any](ctx context.Context, hs *httpSource, relativePath string) (T, error) {
	var zero, decoded T
	body, err := hs.readBody(ctx, hs.url+"/"+relativePath)
	switch {
	case err != nil:
		return zero, err
	case len(body) == 0:
		return zero, nil
	}
	if err := json.Unmarshal(body, &decoded); err != nil {
		return zero, err
	}
	return decoded, nil
}

// lastModifiedFormat is the time layout used to parse the Last-Modified
// response header, as described at
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified.
// Note that the day-of-month element is "2" rather than the zero-padded "02"
// used by http.TimeFormat.
// NOTE(review): time.Parse is expected to accept both one- and two-digit
// days for this element — confirm before reusing the layout for formatting.
var lastModifiedFormat = "Mon, 2 Jan 2006 15:04:05 GMT"

// LastModifiedTime issues a HEAD request for the database's index.json and
// returns the time given by the response's Last-Modified header. It returns
// an error if the request fails, the status is not 200 OK, or the header is
// missing or cannot be parsed with lastModifiedFormat.
func (hs *httpSource) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
	defer derrors.Wrap(&err, "LastModifiedTime()")

	// Assume that if anything changes, the index does.
	url := fmt.Sprintf("%s/index.json", hs.url)
	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
	if err != nil {
		return time.Time{}, err
	}
	resp, err := hs.c.Do(req)
	if err != nil {
		return time.Time{}, err
	}
	// The body must be closed even for a HEAD response, or the underlying
	// connection is leaked instead of being returned to the transport.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return time.Time{}, fmt.Errorf("got status code %d", resp.StatusCode)
	}
	h := resp.Header.Get("Last-Modified")
	return time.Parse(lastModifiedFormat, h)
}

// readBody performs a GET request for url and returns the response body.
// A 404 Not Found response yields (nil, nil); any status other than 200 OK
// or 404 is reported as an error.
func (hs *httpSource) readBody(ctx context.Context, url string) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	resp, err := hs.c.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		// might want this to be a LimitedReader
		return io.ReadAll(resp.Body)
	case http.StatusNotFound:
		return nil, nil
	default:
		return nil, fmt.Errorf("got HTTP status %s", resp.Status)
	}
}

// client is the Client implementation returned by NewClient. It consults
// each of its sources in order and combines their results.
type client struct {
	sources []source
}

// Options configures the HTTP sources created by NewClient.
type Options struct {
	// HTTPClient is the client used for requests to HTTP sources. If nil, a
	// new http.Client is created for each HTTP source.
	HTTPClient *http.Client
	// HTTPCache, if non-nil, is used as the cache of every HTTP source.
	HTTPCache  Cache
}

// NewClient returns a Client that reads from the given database sources.
// Each source is a URL whose scheme must be "http", "https" or "file"; a
// "file" source must refer to an existing directory. Any other scheme, or a
// URL that cannot be parsed, is an error. opts applies to HTTP sources only.
func NewClient(sources []string, opts Options) (_ Client, err error) {
	defer derrors.Wrap(&err, "NewClient(%v, opts)", sources)
	c := &client{}
	for _, raw := range sources {
		raw = strings.TrimRight(raw, "/") // TODO: why?
		uri, err := url.Parse(raw)
		if err != nil {
			return nil, err
		}
		switch uri.Scheme {
		case "http", "https":
			httpClient := opts.HTTPClient
			if httpClient == nil {
				httpClient = new(http.Client)
			}
			c.sources = append(c.sources, &httpSource{
				url:    uri.String(),
				c:      httpClient,
				cache:  opts.HTTPCache,
				dbName: uri.Hostname(),
			})
		case "file":
			dir, err := web.URLToFilePath(uri)
			if err != nil {
				return nil, err
			}
			info, err := os.Stat(dir)
			if err != nil {
				return nil, err
			}
			if !info.IsDir() {
				return nil, fmt.Errorf("%s is not a directory", dir)
			}
			c.sources = append(c.sources, &localSource{dir: dir})
		default:
			return nil, fmt.Errorf("source %q has unsupported scheme", uri)
		}
	}
	return c, nil
}

// unexported is the no-op marker method required by the Client interface.
func (*client) unexported() {}

// GetByModule returns the union, across all sources, of the entries that
// affect the given module path.
func (c *client) GetByModule(ctx context.Context, module string) (_ []*osv.Entry, err error) {
	defer derrors.Wrap(&err, "GetByModule(%q)", module)

	lookup := func(src Client) ([]*osv.Entry, error) {
		return src.GetByModule(ctx, module)
	}
	return c.unionEntries(ctx, lookup)
}

// GetByAlias returns the union, across all sources, of the entries that have
// the given alias.
func (c *client) GetByAlias(ctx context.Context, alias string) (entries []*osv.Entry, err error) {
	defer derrors.Wrap(&err, "GetByAlias(%q)", alias)

	lookup := func(src Client) ([]*osv.Entry, error) {
		return src.GetByAlias(ctx, alias)
	}
	return c.unionEntries(ctx, lookup)
}

// unionEntries calls get on each of the client's sources in order and
// returns the combined entries. Entries are de-duplicated by ID, keeping the
// first one seen. The first error from get aborts the whole operation.
func (c *client) unionEntries(ctx context.Context, get func(Client) ([]*osv.Entry, error)) ([]*osv.Entry, error) {
	var union []*osv.Entry
	seenIDs := make(map[string]bool)
	// probably should be parallelized
	for _, src := range c.sources {
		found, err := get(src)
		if err != nil {
			return nil, err // be failure tolerant?
		}
		for _, entry := range found {
			if seenIDs[entry.ID] {
				continue
			}
			seenIDs[entry.ID] = true
			union = append(union, entry)
		}
	}
	return union, nil
}

// GetByID asks each source in order for the entry with the given ID and
// returns the first one found. It returns (nil, nil) if no source has the
// entry, and stops at the first source that reports an error.
func (c *client) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
	defer derrors.Wrap(&err, "GetByID(%q)", id)

	for _, src := range c.sources {
		found, err := src.GetByID(ctx, id)
		switch {
		case err != nil:
			return nil, err // be failure tolerant?
		case found != nil:
			return found, nil
		}
	}
	return nil, nil
}

// ListIDs collects the IDs reported by every source, removes duplicates and
// returns them in lexical order. An error from any source aborts the call.
func (c *client) ListIDs(ctx context.Context) (_ []string, err error) {
	defer derrors.Wrap(&err, "ListIDs()")

	unique := make(map[string]bool)
	for _, src := range c.sources {
		srcIDs, err := src.ListIDs(ctx)
		if err != nil {
			return nil, err
		}
		for i := range srcIDs {
			unique[srcIDs[i]] = true
		}
	}

	var merged []string
	for id := range unique {
		merged = append(merged, id)
	}
	sort.Strings(merged)
	return merged, nil
}

// LastModifiedTime queries every source and returns the most recent of their
// last-modified times, or the zero time if there are no sources. An error
// from any source aborts the call.
func (c *client) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
	defer derrors.Wrap(&err, "LastModifiedTime()")

	latest := time.Time{}
	for _, src := range c.sources {
		modified, err := src.LastModifiedTime(ctx)
		if err != nil {
			return time.Time{}, err
		}
		if !modified.After(latest) {
			continue
		}
		latest = modified
	}
	return latest, nil
}
