blob: 273fedfe1a2cee6f92b4f28b2910755b086bfc4d [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package llmapp
import (
"context"
"crypto/sha256"
"encoding/json"
"golang.org/x/oscar/internal/storage"
"rsc.io/ordered"
)
// generateText answers the prompts with text from the client's generator,
// reusing a response stored in the database when one is found.
//
// The cache key is the ordered encoding of [generateTextKind], the
// generator's model, and the hash of the prompts. The database lock for
// that key is held for the entire call, including the call to the
// generator (presumably so that concurrent identical requests do not
// both reach the model; confirm against the storage package).
//
// The cached result reports whether the text came from the database.
// Generator errors are returned unchanged and nothing is stored.
func (c *Client) generateText(ctx context.Context, prompts []string) (_ string, cached bool, _ error) {
	var (
		modelName = c.g.Model()
		sum       = hash(prompts)
		key       = ordered.Encode(generateTextKind, modelName, sum)
		lockName  = string(key)
	)
	c.db.Lock(lockName)
	defer c.db.Unlock(lockName)

	// Serve from the cache when a usable entry exists.
	if hit := c.load(key); hit != nil {
		return hit.Response, true, nil
	}

	// Nothing usable in the cache: ask the generator.
	text, err := c.g.GenerateText(ctx, prompts...)
	if err != nil {
		return "", false, err
	}

	// Remember the answer for subsequent identical requests.
	entry := response{
		Model:      modelName,
		PromptHash: sum,
		Response:   text,
	}
	c.db.Set(key, storage.JSON(entry))
	return text, false, nil
}
// generateTextKind is the cache key context: the leading component of the
// ordered-encoded database key under which generateText stores and looks
// up cached responses.
const generateTextKind = "llmapp.GenerateText"
// load fetches the cached response stored in the database under key.
// It returns nil when the database has no entry for key, and also when
// the stored bytes cannot be decoded as a [response]; the latter case
// is logged as an error.
func (c *Client) load(key []byte) *response {
	data, found := c.db.Get(key)
	if !found {
		return nil
	}

	entry := new(response)
	// Decoding is only expected to fail after a backwards-incompatible
	// change to the [response] struct.
	if err := json.Unmarshal(data, entry); err != nil {
		c.slog.Error("cannot unmarshal cached response", "err", err)
		return nil
	}
	return entry
}
// response is a cached response to a
// [llm.TextGenerator.GenerateText] query.
// It is the value generateText stores in the database as JSON and
// load decodes again, so a backwards-incompatible change to its fields
// makes previously stored entries fail to unmarshal.
type response struct {
// Model is the generative model used to generate the response.
Model string
// PromptHash is the SHA-256 hash of the prompts used to generate the response.
PromptHash []byte
// Response is the generated response text.
Response string
}
// hash returns a SHA-256 hash of the strings.
//
// Each string is preceded in the hashed stream by its length, written as
// an 8-byte big-endian integer, so the digest depends on where one string
// ends and the next begins. Without the length prefix, lists such as
// {"ab", "c"} and {"a", "bc"} (or nil and {""}) would hash identically
// and could be served each other's cached responses.
//
// The result is always sha256.Size (32) bytes. The hash of an empty or
// nil list is the SHA-256 of no input.
//
// NOTE: digests differ from those of a plain concatenation of the
// strings, so cache entries keyed by the older scheme will simply miss.
func hash(strs []string) []byte {
	h := sha256.New()
	var prefix [8]byte
	for _, s := range strs {
		// Encode len(s) big-endian by hand to avoid needing another import.
		n := uint64(len(s))
		for i := len(prefix) - 1; i >= 0; i-- {
			prefix[i] = byte(n)
			n >>= 8
		}
		// Writes to a hash.Hash never return an error.
		h.Write(prefix[:])
		h.Write([]byte(s))
	}
	return h.Sum(nil)
}