blob: 2521db1723396df2ad62d67fb58963d78d3c8839 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package llm
import (
"context"
"math"
"strings"
)
// quoteLen is the number of entries in every vector produced by quote:
// up to quoteLen-1 entries encoding text bytes, followed by a final
// sentinel entry that records the normalization scale.
const quoteLen = 123
// QuoteEmbedder returns an [Embedder] that is handy in tests
// and of no use whatsoever as a real embedding.
// Each document is embedded as a 123-element unit vector whose
// first 122 elements directly encode up to the first 122 bytes
// of the document's text.
func QuoteEmbedder() Embedder {
	// quoter is stateless, so its zero value is all we need.
	var q quoter
	return q
}
// quote encodes text as a unit-length Vector with quoteLen entries.
//
// Entry i holds byte i of text divided by 256, for at most the first
// quoteLen-1 bytes; any remaining bytes are dropped. When text is
// shorter than quoteLen-1 bytes, a -1 terminator is stored right after
// the last byte, so the text ends at the first negative entry.
//
// The last entry is always set to -1 before the vector is normalized.
// After normalization that entry therefore records the scale factor,
// which is what a decoder needs to undo the scaling and recover the
// original bytes.
func quote(text string) Vector {
	vec := make(Vector, quoteLen)
	last := len(vec) - 1

	// Copy at most last bytes, accumulating the squared norm as we go.
	n := min(len(text), last)
	var sumSq float64
	for i := range n {
		x := float32(text[i]) / 256
		vec[i] = x
		sumSq += float64(x) * float64(x)
	}

	// A short text gets an explicit terminator right after its data.
	if n < last {
		vec[n] = -1
		sumSq += 1
	}

	// The sentinel that carries the normalization scale.
	vec[last] = -1
	sumSq += 1

	// Scale every entry so the vector has unit length.
	scale := float32(1 / math.Sqrt(sumSq))
	for i := range vec {
		vec[i] *= scale
	}
	return vec
}
// quoter is the quoting Embedder returned by QuoteEmbedder.
// It holds no state.
type quoter struct{}

// EmbedDocs implements Embedder by quoting the Text of each document.
// The result holds one vector per document, in the same order as docs,
// and is nil when docs is empty.
// EmbedDocs does not use ctx and never returns a non-nil error.
func (quoter) EmbedDocs(ctx context.Context, docs []EmbedDoc) ([]Vector, error) {
	if len(docs) == 0 {
		// Keep returning a nil slice for empty input.
		return nil, nil
	}
	// The output length is known up front, so allocate it once
	// instead of growing the slice with repeated appends.
	vecs := make([]Vector, len(docs))
	for i := range docs {
		vecs[i] = quote(docs[i].Text)
	}
	return vecs, nil
}
// UnquoteVector recovers the original text prefix
// passed to a [QuoteEmbedder]'s EmbedDocs method.
// Like QuoteEmbedder, UnquoteVector is only useful in tests.
//
// UnquoteVector panics if v cannot be a quotation vector: that is,
// if v does not have exactly quoteLen entries or if its final entry
// is not negative.
func UnquoteVector(v Vector) string {
	if len(v) != quoteLen {
		panic("UnquoteVector of non-quotation vector")
	}
	// A quotation vector always ends in a negative sentinel whose
	// magnitude is the normalization scale. Zero, positive, or NaN
	// here would make the scale below infinite, negative, or NaN and
	// the decoded bytes meaningless, so reject it. The comparison is
	// written as !(x < 0) rather than x >= 0 so that NaN is rejected too.
	sentinel := v[len(v)-1]
	if !(sentinel < 0) {
		panic("UnquoteVector of non-quotation vector")
	}
	scale := -1 / sentinel

	// At most len(v)-1 entries carry text; the sentinel always stops the loop.
	b := make([]byte, 0, len(v)-1)
	for _, f := range v {
		if f < 0 {
			// The text ends at the first negative entry.
			break
		}
		// Undo the normalization and the division by 256,
		// rounding to the nearest byte value.
		b = append(b, byte(256*f*scale+0.5))
	}
	return string(b)
}
// EchoTextGenerator returns a [TextGenerator] whose response
// to any prompt is the prompt itself.
//
// It is intended for use in tests.
func EchoTextGenerator() TextGenerator {
	// echo is stateless, so its zero value is all we need.
	var gen echo
	return gen
}
// echo is the TextGenerator returned by EchoTextGenerator.
// It holds no state.
type echo struct{}

// Model returns the model name, which is always "echo".
// It is part of echo's [TextGenerator] implementation.
func (echo) Model() string {
	const name = "echo"
	return name
}
// GenerateText implements [TextGenerator] for testing:
// the response is exactly what EchoResponse returns for promptParts.
// It does not use ctx and always returns a nil error.
func (echo) GenerateText(ctx context.Context, promptParts ...string) (string, error) {
	response := EchoResponse(promptParts...)
	return response, nil
}
// EchoResponse joins the prompt parts, in order and with nothing
// between them, and returns the resulting string.
// With no parts it returns the empty string.
//
// It is intended for use in tests.
func EchoResponse(promptParts ...string) string {
	const noSeparator = ""
	return strings.Join(promptParts, noSeparator)
}