blob: ffd9f76af2deae6dbb80a79cead18f044a96d2e9 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package embeddocs implements embedding text docs into a vector database.
package embeddocs
import (
"context"
"fmt"
"log/slog"
"golang.org/x/oscar/internal/docs"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/storage/timed"
)
// Sync reads new documents from dc, embeds them using embed,
// and then writes the (docid, vector) pairs to vdb.
//
// Sync uses [docs.DocWatcher] with the name “embeddocs” to
// save its position across multiple calls.
//
// Documents are embedded in batches of up to 100. After a batch has been
// written and flushed to vdb, the watcher is marked old up to the DBTime
// of the last document in that batch.
//
// Sync logs progress to lg. It stops and returns an error if
// embed.EmbedDocs fails or returns a number of vectors that differs from
// the number of documents in the batch. When EmbedDocs returns no more
// vectors than documents, those vectors are still applied to vdb even if
// the batch then fails; in that case the watcher is not marked old for
// the batch.
func Sync(ctx context.Context, lg *slog.Logger, vdb storage.VectorDB, embed llm.Embedder, dc *docs.Corpus) error {
	lg.Info("embeddocs sync")

	const batchSize = 100
	var (
		batch     []llm.EmbedDoc // documents waiting to be embedded
		ids       []string       // ids[i] is the document ID for batch[i]
		batchLast timed.DBTime   // DBTime of the most recent document added to batch
	)
	w := dc.DocWatcher("embeddocs")

	// flush embeds the pending batch and stores the resulting vectors in vdb.
	// On success it marks the batch old in w and resets batch and ids.
	flush := func() error {
		vecs, err := embed.EmbedDocs(ctx, batch)
		if len(vecs) > len(ids) {
			// Too many vectors: indexing ids[i] below would go out of range,
			// so nothing is stored for this batch.
			mismatch := fmt.Errorf("embeddocs length mismatch: batch=%d vecs=%d ids=%d", len(batch), len(vecs), len(ids))
			if err != nil {
				// Report the embedder's own error as well instead of dropping it.
				return fmt.Errorf("%w; embeddocs EmbedDocs error: %w", mismatch, err)
			}
			return mismatch
		}

		// Store whatever vectors came back before looking at err,
		// so that a partially successful call still records its results.
		vbatch := vdb.Batch()
		for i, v := range vecs {
			vbatch.Set(ids[i], v)
		}
		vbatch.Apply()

		if err != nil {
			return fmt.Errorf("embeddocs EmbedDocs error: %w", err)
		}
		if len(vecs) != len(ids) {
			return fmt.Errorf("embeddocs length mismatch: batch=%d vecs=%d ids=%d", len(batch), len(vecs), len(ids))
		}

		// The whole batch was embedded and stored: flush vdb first,
		// then advance and flush the watcher.
		vdb.Flush()
		w.MarkOld(batchLast)
		w.Flush()
		batch = nil
		ids = nil
		return nil
	}

	for d := range w.Recent() {
		lg.Debug("embeddocs sync start", "doc", d.ID)
		batch = append(batch, llm.EmbedDoc{Title: d.Title, Text: d.Text})
		ids = append(ids, d.ID)
		batchLast = d.DBTime
		if len(batch) >= batchSize {
			if err := flush(); err != nil {
				return err
			}
		}
	}

	if len(batch) > 0 {
		// More to flush, but flush uses w.MarkOld,
		// which has to be called during an iteration over w.Recent.
		// Start a new iteration just to call flush and then break out.
		for range w.Recent() {
			if err := flush(); err != nil {
				return err
			}
			break
		}
	}
	return nil
}
// Latest reports the latest known DBTime that has been marked old
// by the “embeddocs” watcher on the corpus dc.
func Latest(dc *docs.Corpus) timed.DBTime {
	watcher := dc.DocWatcher("embeddocs")
	return watcher.Latest()
}