// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package search performs nearest-neighbor searches over
// vector databases, allowing the caller to specify filters
// for the results.
package search

import (
	"context"
	"fmt"
	"math"
	"net/url"
	"path"
	"regexp"
	"strings"
	"testing"

	"golang.org/x/oscar/internal/docs"
	"golang.org/x/oscar/internal/llm"
	"golang.org/x/oscar/internal/storage"
)

// QueryRequest is a [Query] request.
// It includes the document to search for neighbors of, and
// (optional) result filters.
//
// Both fields are embedded, so their fields are promoted onto
// QueryRequest.
type QueryRequest struct {
	Options      // result filters; the zero value applies the defaults
	llm.EmbedDoc // document to embed and find the nearest neighbors of
}

// Options are the results filters that can be passed to the search
// functions as part of a [QueryRequest] or [VectorRequest].
//
// The zero value is usable: it applies no score threshold, keeps
// documents of every kind, and caps the results at a fixed default.
//
// The kind filters are applied to the results already returned by the
// vector database for the given Limit, so a search that filters by
// kind may return fewer than Limit results.
//
// Use [Options.Validate] to check the options before searching.
//
// TODO(tatianabradley): Make kinds case insensitive.
type Options struct {
	Threshold float64  // lowest score to keep; default 0. Max is 1.
	Limit     int      // max results (fewer if Threshold or a kind filter is set); 0 means use a fixed default
	AllowKind []string // kinds of documents to keep; empty means keep all
	DenyKind  []string // kinds of documents to remove; empty means remove none
}

// Result is a single result of a search ([Query] or [Vector]).
// It represents a single document in a vector database which is a
// nearest neighbor of the request.
type Result struct {
	Kind  string // kind of document: issue, doc page, etc.
	Title string // title of the document; empty if the corpus has no document with this ID
	// The underlying vector search result.
	// Its ID and Score fields are promoted onto Result.
	storage.VectorResult
}

// Query performs a nearest neighbors search for the request's document
// over the given vector database, respecting the options set in [QueryRequest].
//
// It embeds the request's document onto the vector space using the given embedder.
//
// It expects that vdb is a vector database containing embeddings of
// the documents in dc, embedded using embed.
//
// Query returns an error if the embedder fails, or if it reports
// success but produces no vector for the document.
func Query(ctx context.Context, vdb storage.VectorDB, dc *docs.Corpus, embed llm.Embedder, req *QueryRequest) ([]Result, error) {
	vecs, err := embed.EmbedDocs(ctx, []llm.EmbedDoc{req.EmbedDoc})
	if err != nil {
		return nil, fmt.Errorf("EmbedDocs: %w", err)
	}
	// Exactly one document was submitted, so the first vector is its
	// embedding. Guard the index so a misbehaving embedder yields an
	// error rather than an index-out-of-range panic.
	if len(vecs) == 0 {
		return nil, fmt.Errorf("EmbedDocs: no vector returned for document")
	}
	return vector(vdb, dc, vecs[0], &req.Options), nil
}

// VectorRequest is a [Vector] request.
// It includes the vector to search for neighbors of, and
// (optional) result filters.
//
// Both fields are embedded, so the fields of Options are promoted
// onto VectorRequest.
type VectorRequest struct {
	Options    // result filters; the zero value applies the defaults
	llm.Vector // vector to find the nearest neighbors of
}

// Vector returns the nearest neighbors of the request's vector in the
// given vector database, filtered according to the options set in
// [VectorRequest].
//
// vdb is expected to hold embeddings of the documents in dc, produced
// by the same embedder that produced the request's vector.
func Vector(vdb storage.VectorDB, dc *docs.Corpus, req *VectorRequest) []Result {
	vec, opts := req.Vector, &req.Options
	return vector(vdb, dc, vec, opts)
}

// Validate checks the options and returns an error describing the
// first invalid one it finds, or nil if all of them are valid.
//
// Limit must not be negative, Threshold must lie in [0, 1], and every
// entry of AllowKind and DenyKind must be a recognized document kind
// (kinds are matched case-sensitively).
func (o *Options) Validate() error {
	switch {
	case o.Limit < 0:
		return fmt.Errorf("limit must be >= 0 (got: %d)", o.Limit)
	case o.Threshold < 0 || o.Threshold > 1:
		return fmt.Errorf("threshold must be >= 0 and <= 1 (got: %.3f)", o.Threshold)
	}
	// Allowed kinds are checked before denied kinds.
	for _, k := range o.AllowKind {
		if _, known := kinds[k]; !known {
			return fmt.Errorf("unrecognized allow kind %q (case-sensitive)", k)
		}
	}
	for _, k := range o.DenyKind {
		if _, known := kinds[k]; !known {
			return fmt.Errorf("unrecognized deny kind %q (case-sensitive)", k)
		}
	}
	return nil
}

// vector is the shared implementation of [Query] and [Vector].
// It asks vdb for the nearest neighbors of vec and converts them to
// Results, dropping those excluded by opts.
//
// The limit is passed to the database search, and the kind filters are
// applied afterwards, so fewer than the limit may be returned.
func vector(vdb storage.VectorDB, dc *docs.Corpus, vec llm.Vector, opts *Options) []Result {
	// A non-positive limit selects the package default.
	limit := opts.Limit
	if limit <= 0 {
		limit = defaultLimit
	}

	// Search uses normalized dot product, so higher numbers are better.
	// Max is 1, min is 0. Anything that is not a positive threshold
	// means "no threshold".
	minScore := opts.Threshold
	if !(minScore > 0) {
		minScore = 0.0
	}

	// By default, every kind of document is allowed and none is denied.
	// An explicit list replaces the corresponding default.
	allowed := func(string) bool { return true }
	if len(opts.AllowKind) > 0 {
		allowed = containsFunc(opts.AllowKind)
	}
	denied := func(string) bool { return false }
	if len(opts.DenyKind) > 0 {
		denied = containsFunc(opts.DenyKind)
	}

	var results []Result
	for _, hit := range vdb.Search(vec, limit) {
		// Stop at the first hit below the threshold; this relies on
		// the search returning its best matches first.
		if hit.Score < minScore {
			break
		}
		kind := docIDKind(hit.ID)
		if denied(kind) || !allowed(kind) {
			continue
		}
		// The title stays empty when the corpus has no such document.
		var title string
		if doc, found := dc.Get(hit.ID); found {
			title = doc.Title
		}
		results = append(results, Result{Kind: kind, Title: title, VectorResult: hit})
	}
	return results
}

// containsFunc returns a predicate that reports whether its argument
// is one of the strings in s. The strings are copied into a set up
// front, so later changes to s do not affect the predicate.
func containsFunc(s []string) func(string) bool {
	set := make(map[string]struct{}, len(s))
	for _, elem := range s {
		set[elem] = struct{}{}
	}
	return func(key string) bool {
		_, found := set[key]
		return found
	}
}

// Round replaces r.Score, in place, with its value rounded to three
// decimal places.
func (r *Result) Round() {
	const scale = 1e3 // 10^3: three decimal places
	r.Score = math.Round(scale*r.Score) / scale
}

// IDIsURL reports whether the Result's ID is a valid absolute URL,
// that is, one that parses without error and has a scheme
// (for example "https://go.dev/doc/").
//
// Parsing alone is not a sufficient check: [url.Parse] also accepts
// relative references, so the empty string and arbitrary words would
// otherwise be reported as URLs.
func (r *Result) IDIsURL() bool {
	u, err := url.Parse(r.ID)
	return err == nil && u.IsAbs()
}

// defaultLimit is the maximum number of search results to return
// when the request's Limit option is not positive.
const defaultLimit = 20

// Recognized kinds of documents.
// The kind of a document is derived from the host and path of its ID.
const (
	KindGitHubIssue      = "GitHubIssue"      // github.com/<owner>/<repo>/issues/<num>
	KindGitHubDiscussion = "GitHubDiscussion" // github.com/<owner>/<repo>/discussions/<num>
	KindGoWiki           = "GoWiki"           // pages under go.dev/wiki/
	KindGoDocumentation  = "GoDocumentation"  // pages under go.dev/doc/
	KindGoReference      = "GoReference"      // pages under go.dev/ref/
	KindGoBlog           = "GoBlog"           // pages under go.dev/blog/
	KindGoDevPage        = "GoDevPage"        // any other page under go.dev/
	KindGoGerritChange   = "GoGerritChange"   // pages under go-review.googlesource.com/
	// Unknown document.
	KindUnknown = "Unknown"
)

// Set of recognized document kinds.
// It must contain every Kind constant declared above; a kind missing
// from the set cannot be used in the AllowKind or DenyKind options.
var kinds = map[string]bool{
	KindGitHubIssue:      true,
	KindGitHubDiscussion: true,
	KindGoWiki:           true,
	KindGoDocumentation:  true,
	KindGoReference:      true,
	KindGoBlog:           true,
	KindGoDevPage:        true,
	KindGoGerritChange:   true,
	KindUnknown:          true,
}

// docIDKind determines the kind of document from its ID, which is
// parsed as a URL and classified by its host and path.
// It returns [KindUnknown] if the ID cannot be parsed or does not
// match any recognized location.
//
// The function assumes that we only care about the Go project.
func docIDKind(id string) string {
	u, err := url.Parse(id)
	if err != nil {
		return KindUnknown
	}
	hostPath := path.Join(u.Host, u.Path)

	// GitHub URLs need a closer look at the project and API name.
	if githubRE.MatchString(hostPath) {
		return githubKind(hostPath, u.Fragment)
	}

	// The first matching prefix wins, so the specific go.dev sections
	// must come before the go.dev/ catch-all.
	prefixKinds := [...]struct{ prefix, kind string }{
		{"go.dev/wiki/", KindGoWiki},
		{"go.dev/doc/", KindGoDocumentation},
		{"go.dev/ref/", KindGoReference},
		{"go.dev/blog/", KindGoBlog},
		{"go.dev/", KindGoDevPage},
		{"go-review.googlesource.com/", KindGoGerritChange},
	}
	for _, pk := range prefixKinds {
		if strings.HasPrefix(hostPath, pk.prefix) {
			return pk.kind
		}
	}
	return KindUnknown
}

// githubKind returns the kind of the GitHub document located at
// hostPath (a URL's host and path, without the scheme) with the given
// URL fragment. It returns [KindUnknown] for anything other than an
// issue or a discussion.
func githubKind(hostPath string, fragment string) string {
	// We don't currently recognize Github URLs with fragments.
	if fragment != "" {
		return KindUnknown
	}

	// A successful match yields the whole match plus two submatches:
	// the project ("owner/repo") and the API name.
	match := githubRE.FindStringSubmatch(hostPath)
	if len(match) != 3 { // malformed
		return KindUnknown
	}

	// Project must be "golang/go", except in tests.
	if match[1] != "golang/go" && !testing.Testing() {
		return KindUnknown
	}

	if match[2] == "issues" {
		return KindGitHubIssue
	}
	if match[2] == "discussions" {
		return KindGitHubDiscussion
	}
	return KindUnknown
}

// githubRE matches GitHub URLs (host and path only, without a scheme)
// in any project, of the form github.com/owner/repo/api/num.
// Submatch 1 is the project ("owner/repo") and submatch 2 is the api
// name. The owner, repo and api elements may contain only word
// characters and hyphens, and num must consist of decimal digits.
var githubRE = regexp.MustCompile(`^github\.com/([\w-]+/[\w-]+)/([\w-]+)/\d+$`)
