// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package firestore

import (
	"context"
	"encoding/hex"
	"fmt"
	"iter"
	"log/slog"
	"path"

	"cloud.google.com/go/firestore"
	"golang.org/x/oscar/internal/gcp/grpcerrors"
	"golang.org/x/oscar/internal/llm"
	"golang.org/x/oscar/internal/storage"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

// A VectorDB is a [storage.VectorDB] using Firestore.
//
// Each vector is stored as one document in a single Firestore collection.
// Vector IDs are hex-encoded to form the document IDs (see encodeVectorID).
type VectorDB struct {
	// fs is the underlying Firestore store. It supplies the client and is
	// also used to report unrecoverable errors via its Panic method.
	fs *fstore
	// namespace is the namespace passed to NewVectorDB.
	namespace string
	// coll is the collection "vectorDBs/<namespace>/vectors" that holds
	// the vector documents.
	coll *firestore.CollectionRef
}

// NewVectorDB creates a [VectorDB] with the given logger, GCP project ID,
// Firestore database, namespace and client options.
// The projectID must not be empty.
// If the database is empty, the default database will be used.
// The namespace must be a [valid Firestore collection ID].
// Namespaces allow multiple vector DBs to be stored in the same Firestore DB.
//
// Vectors in a VectorDB with namespace NS are stored in the Firestore collection
// "vectorDBs/NS/vectors".
//
// NewVectorDB returns an error if the underlying Firestore store cannot be
// created, or if the namespace does not produce a valid collection path
// (for example, if it is empty).
//
// [valid Firestore collection ID]: https://firebase.google.com/docs/firestore/quotas#collections_documents_and_fields
func NewVectorDB(ctx context.Context, lg *slog.Logger, projectID, database, namespace string, opts ...option.ClientOption) (*VectorDB, error) {
	fs, err := newFstore(ctx, lg, projectID, database, opts)
	if err != nil {
		return nil, err
	}
	// The Firestore client's Collection method reports an invalid path
	// (an even number of IDs, or an empty ID) by returning nil, not an error.
	// An empty namespace is dropped by path.Join, leaving the two-ID path
	// "vectorDBs/vectors", which is such a case. Fail now rather than return
	// a VectorDB whose operations would all act on a nil collection.
	coll := fs.client.Collection(path.Join("vectorDBs", namespace, "vectors"))
	if coll == nil {
		// Release the client we just created; the caller gets no handle to close.
		fs.Close()
		return nil, fmt.Errorf("firestore.NewVectorDB: invalid namespace %q", namespace)
	}
	return &VectorDB{fs: fs, namespace: namespace, coll: coll}, nil
}

// Close closes the [VectorDB], releasing its resources.
// It does so by closing the underlying Firestore store.
func (db *VectorDB) Close() {
	db.fs.Close()
}

// A vectorDoc holds an embedding as the Firestore type for a
// vector of float32s.
// All Firestore documents must be sets of key-value pairs (structs or maps in Go),
// so we cannot represent an embedding as a lone firestore.Vector32.
// Firestore's vector DB search requires that the type of the vector is either
// firestore.Vector32 or firestore.Vector64.
type vectorDoc struct {
	// Embedding is the stored vector.
	// The field name must stay in sync with the "Embedding" field path
	// that VectorDB.Search passes to FindNearest.
	Embedding firestore.Vector32
}

// Set implements [storage.VectorDB.Set].
// It writes vec to the document for id, replacing any existing document.
// An empty id, or a failure of the Firestore write, is reported via db.fs.Panic.
func (db *VectorDB) Set(id string, vec llm.Vector) {
	if id == "" {
		db.fs.Panic("firestore VectorDB Set: empty ID")
	}
	ref := db.docref(id)
	payload := vectorDoc{Embedding: firestore.Vector32(vec)}
	_, err := ref.Set(context.TODO(), payload)
	if err != nil {
		db.fs.Panic("firestore VectorDB Set", "id", id, "err", err)
	}
}

// Get implements [storage.VectorDB.Get].
// It returns the vector stored for id and true, or nil and false if
// Firestore reports that no such document exists.
// Any other lookup or decoding failure is reported via db.fs.Panic.
func (db *VectorDB) Get(id string) (llm.Vector, bool) {
	snap, err := db.docref(id).Get(context.TODO())
	switch {
	case err == nil:
		// Document found; decode it below.
	case grpcerrors.IsNotFound(err):
		return nil, false
	default:
		db.fs.Panic("firestore VectorDB Get", "id", id, "err", err)
	}

	var stored vectorDoc
	if decodeErr := snap.DataTo(&stored); decodeErr != nil {
		db.fs.Panic("firestore VectorDB Get", "id", id, "err", decodeErr)
	}
	return llm.Vector(stored.Embedding), true
}

// Delete implements [storage.VectorDB.Delete].
// It removes the document for id from the vector collection,
// using the same hex encoding of the ID as Set and Get.
func (db *VectorDB) Delete(id string) {
	docID := encodeVectorID(id)
	db.fs.delete(nil, db.coll, docID)
}

// All implements [storage.VectorDB.All].
//
// The collection is read in pages of at most db.fs.docQueryLimit documents,
// ordered by document ID. For each document, the sequence yields the decoded
// vector ID and a function returning the vector.
// Every page's Firestore iterator is stopped when the page is finished,
// including when the consumer stops iterating early or an error is
// reported via db.fs.Panic.
func (db *VectorDB) All() iter.Seq2[string, func() llm.Vector] {
	return func(yield func(string, func() llm.Vector) bool) {
		// page yields the documents of one query page. The page begins at
		// document ID start, or at the beginning of the collection if start
		// is empty. It returns the number of documents read, the (encoded)
		// ID of the last document read, and whether iteration should continue.
		page := func(start string) (n int, last string, more bool) {
			// OrderBy is required for StartAt to work.
			query := db.coll.OrderBy(firestore.DocumentID, firestore.Asc).Limit(db.fs.docQueryLimit)
			if start != "" {
				query = query.StartAt(start)
			}
			it := query.Documents(context.Background())
			// Always release the iterator's resources, however we leave the page.
			defer it.Stop()

			for { // should iterate up to db.fs.docQueryLimit number of times
				ds, err := it.Next()
				if err == iterator.Done {
					return n, last, true
				}
				if err != nil {
					// Unreachable except for bad DB or potential 60 seconds timeout
					// (see longer comment in [fstore.scan]).
					// The timeout should not happen now with Query.Limit(db.fs.docQueryLimit).
					db.fs.Panic("firestore VectorDB All", "err", err)
				}
				n++
				last = ds.Ref.ID

				id := db.decodeVectorID(ds.Ref.ID)
				// doc is declared per iteration so that each yielded closure
				// captures its own document.
				var doc vectorDoc
				if err := ds.DataTo(&doc); err != nil {
					db.fs.Panic("firestore VectorDB All", "id", id, "err", err)
				}
				if !yield(id, func() llm.Vector { return llm.Vector(doc.Embedding) }) {
					return n, last, false
				}
			}
		}

		var start string
		for {
			n, last, more := page(start)
			if !more {
				return // consumer stopped early
			}
			if n < db.fs.docQueryLimit { // no more things to fetch
				return
			}
			// Resume just past the last document of this page.
			start = keyAfter(last)
		}
	}
}

// Search implements [storage.VectorDB.Search].
// It asks Firestore for the n documents nearest to vec by dot product and
// returns them in the order Firestore delivers them. Each result's score is
// the dot product of vec with the stored embedding, computed locally.
// Query or decoding failures are reported via db.fs.Panic.
func (db *VectorDB) Search(vec llm.Vector, n int) []storage.VectorResult {
	query := db.coll.FindNearest("Embedding", firestore.Vector32(vec), n, firestore.DistanceMeasureDotProduct, nil)
	// Named it rather than iter so the imported iter package is not shadowed.
	it := query.Documents(context.TODO())
	defer it.Stop()

	var results []storage.VectorResult
	for {
		snap, err := it.Next()
		switch {
		case err == iterator.Done:
			return results
		case err != nil:
			db.fs.Panic("firestore VectorDB Search", "err", err)
		}

		var stored vectorDoc
		if decodeErr := snap.DataTo(&stored); decodeErr != nil {
			db.fs.Panic("firestore VectorDB Search", "err", decodeErr)
		}
		hit := storage.VectorResult{
			ID:    db.decodeVectorID(snap.Ref.ID),
			Score: vec.Dot(llm.Vector(stored.Embedding)),
		}
		results = append(results, hit)
	}
}

// docref returns a reference to the document that stores the vector
// with the given ID.
func (db *VectorDB) docref(id string) *firestore.DocumentRef {
	// Firestore restricts which strings may be document IDs,
	// so the ID is hex-encoded before use.
	docID := encodeVectorID(id)
	return db.coll.Doc(docID)
}

// Flush implements [storage.VectorDB.Flush]. It is a no-op.
func (db *VectorDB) Flush() {
	// Firestore operations do not require flushing.
}

// A vBatch is a [storage.VectorBatch] for a [VectorDB].
// All of its methods delegate to the wrapped batch, after hex-encoding
// vector IDs where one is given.
type vBatch struct {
	b *batch // underlying DB operations
}

// Batch implements [storage.VectorDB.Batch].
// The returned batch operates on this VectorDB's collection.
func (db *VectorDB) Batch() storage.VectorBatch {
	underlying := db.fs.newBatch(db.coll)
	return &vBatch{b: underlying}
}

// perFloatSize is the approximate size of a float64 encoded as a Firestore value.
// (Firestore encodes a float32 as a float64.)
// vBatch.Set multiplies it by the vector length to obtain the size it
// passes to the underlying batch (presumably in bytes; confirm against batch.set).
const perFloatSize = 11

// Set implements [storage.VectorBatch.Set].
// It adds a write of vec under id to the batch, together with an
// approximate size of len(vec)*perFloatSize.
// An empty id is reported via the store's Panic method.
func (b *vBatch) Set(id string, vec llm.Vector) {
	if id == "" {
		b.b.f.Panic("firestore VectorDB Set: empty ID")
	}
	docID := encodeVectorID(id)
	payload := vectorDoc{Embedding: firestore.Vector32(vec)}
	approxSize := len(vec) * perFloatSize
	b.b.set(docID, payload, approxSize)
}

// Delete implements [storage.VectorBatch.Delete].
// It adds a deletion of the document for id to the batch.
func (b *vBatch) Delete(id string) {
	docID := encodeVectorID(id)
	b.b.delete(docID)
}

// MaybeApply implements [storage.VectorBatch.MaybeApply].
// It delegates to the underlying batch and returns its result.
func (b *vBatch) MaybeApply() bool {
	return b.b.maybeApply()
}

// Apply implements [storage.VectorBatch.Apply].
// It delegates to the underlying batch.
func (b *vBatch) Apply() {
	b.b.apply()
}

// encodeVectorID converts a vector ID into a Firestore document ID by
// hex-encoding its bytes (lowercase, two characters per byte).
// The result contains only the characters 0-9 and a-f.
func encodeVectorID(id string) string {
	raw := []byte(id)
	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return string(encoded)
}

// decodeVectorID converts a Firestore document ID back into the vector ID
// it was made from, reversing encodeVectorID.
// A document ID that is not valid hex is reported via db.fs.Panic.
func (db *VectorDB) decodeVectorID(id string) string {
	decoded, decodeErr := hex.DecodeString(id)
	if decodeErr != nil {
		db.fs.Panic("firestore VectorDB ID decode", "id", id, "err", decodeErr)
	}
	return string(decoded)
}
