| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package firestore |
| |
| import ( |
| "context" |
| "encoding/hex" |
| "fmt" |
| "iter" |
| "math/rand/v2" |
| "net/url" |
| "slices" |
| "time" |
| |
| "cloud.google.com/go/firestore" |
| "golang.org/x/oscar/internal/storage" |
| "google.golang.org/api/option" |
| ) |
| |
| // DB implements [storage.DB]. |
| type DB struct { |
| *fstore |
| uid int64 // unique ID, to identify lock owners |
| } |
| |
| // NewDB constructs a [DB]. |
| func NewDB(ctx context.Context, dbopts *DBOptions, copts ...option.ClientOption) (*DB, error) { |
| fs, err := newFstore(ctx, dbopts, copts) |
| if err != nil { |
| return nil, err |
| } |
| return &DB{fs, rand.Int64()}, nil |
| } |
| |
// Names of the Firestore collections used by DB.
const (
	// lockCollection holds one document per named lock.
	lockCollection = "locks"
	// valueCollection holds the key-value data.
	valueCollection = "values"
)
| |
// The following are variables rather than constants so that tests
// can replace them.

// lockTimeout is how long to wait before stealing a lock.
var lockTimeout = 2 * time.Minute

// timeSince reports the time elapsed since its argument.
var timeSince = time.Since
| |
| // Lock implements [storage.DB.Lock]. |
| func (db *DB) Lock(name string) { |
| // Wait for the lock in a separate function to avoid defers inside a loop, consuming |
| // memory on each iteration. |
| for !db.waitForLock(name) { |
| } |
| } |
| |
| // waitForLock waits for the lock to become available. |
| // It returns true if it acquires the lock. |
| // It returns false if the snapshot iterator timed out without the lock |
| // being acquired. |
| func (db *DB) waitForLock(name string) bool { |
| ctx, cancel := context.WithTimeout(context.Background(), lockTimeout) |
| defer cancel() |
| // A snapshot iterator iterates over changing states of the document. |
| // It yields its first value immediately, and subsequent values only when |
| // the document changes state. |
| dr := db.client.Collection(lockCollection).Doc(url.PathEscape(name)) |
| iter := dr.Snapshots(ctx) |
| defer iter.Stop() |
| for { |
| ds, err := iter.Next() |
| if err == nil { |
| if !ds.Exists() && db.tryLock(name) { |
| // The lock doesn't exist and we managed to get it. |
| return true |
| } |
| // Wait for a change in the lock document. |
| continue |
| } |
| if isTimeout(err) { |
| return db.tryLock(name) |
| } |
| // unreachable except for bad DB |
| db.Panic("firestore waiting for lock", "name", name, "err", err) |
| } |
| } |
| |
| // tryLock tries to acquire the named lock in a transaction. |
| func (db *DB) tryLock(name string) (res bool) { |
| db.runTransaction(func(ctx context.Context, tx *firestore.Transaction) { |
| uid, createTime := db.getLock(tx, name) |
| if createTime.IsZero() || timeSince(createTime) > lockTimeout { |
| // Lock does not exist or timed out. |
| if !createTime.IsZero() { |
| db.slog.Warn("taking lock", "name", name, "old", uid, "new", db.uid) |
| } |
| db.setLock(tx, name) |
| res = true |
| } else { |
| res = false |
| } |
| }) |
| return res |
| } |
| |
| // Unlock releases the lock. It panics if the lock isn't locked by this DB. |
| func (db *DB) Unlock(name string) { |
| db.runTransaction(func(ctx context.Context, tx *firestore.Transaction) { |
| uid, createTime := db.getLock(tx, name) |
| if createTime.IsZero() { |
| db.Panic("unlock of never locked key", "key", name) |
| } |
| if uid != db.uid { |
| db.Panic("unlocker is not owner", "key", name) |
| } |
| db.deleteLock(tx, name) |
| }) |
| } |
| |
// A lock is the value stored in a Firestore lock document.
type lock struct {
	// UID is the unique ID of the DB that acquired the lock.
	UID int64
}
| |
| // setLock sets the value of the named lock in the DB, along with its creation time. |
| func (db *DB) setLock(tx *firestore.Transaction, name string) { |
| db.set(tx, lockCollection, url.PathEscape(name), lock{db.uid}) |
| } |
| |
| // getLock returns the value of the named lock in the DB. |
| func (db *DB) getLock(tx *firestore.Transaction, name string) (int64, time.Time) { |
| ds := db.get(tx, lockCollection, url.PathEscape(name)) |
| if ds == nil { |
| return 0, time.Time{} |
| } |
| uid := dataTo[lock](db.fstore, ds).UID |
| return uid, ds.CreateTime |
| } |
| |
| // deleteLock deletes the named lock in the DB. |
| func (db *DB) deleteLock(tx *firestore.Transaction, name string) { |
| db.delete(tx, lockCollection, url.PathEscape(name)) |
| } |
| |
| // Set implements [storage.DB.Set]. |
| func (db *DB) Set(key, val []byte) { |
| db.set(nil, valueCollection, encodeKey(key), encodeValue(val)) |
| } |
| |
| // Get implements [storage.DB.Get]. |
| func (db *DB) Get(key []byte) ([]byte, bool) { |
| ekey := encodeKey(key) |
| ds := db.get(nil, valueCollection, ekey) |
| if ds == nil { |
| return nil, false |
| } |
| return decodeValue(ds.Data()), true |
| } |
| |
| // Delete implements [storage.DB.Delete]. |
| func (db *DB) Delete(key []byte) { |
| db.delete(nil, valueCollection, encodeKey(key)) |
| } |
| |
| // DeleteRange implements [storage.DB.DeleteRange]. |
| func (db *DB) DeleteRange(start, end []byte) { |
| db.deleteRange(valueCollection, encodeKey(start), encodeKey(end)) |
| } |
| |
| // Scan implements [storage.DB.Scan]. |
| func (db *DB) Scan(start, end []byte) iter.Seq2[[]byte, func() []byte] { |
| return func(yield func(key []byte, valf func() []byte) bool) { |
| for ds := range db.scan(nil, valueCollection, encodeKey(start), encodeKey(end)) { |
| if !yield(decodeKey(ds.Ref.ID), func() []byte { return decodeValue(ds.Data()) }) { |
| return |
| } |
| } |
| } |
| } |
| |
| // Batch implements [storage.DB.Batch]. |
| func (db *DB) Batch() storage.Batch { |
| return &dbBatch{db.newBatch(valueCollection)} |
| } |
| |
| type dbBatch struct { |
| b *batch |
| } |
| |
| // Delete implements [storage.Batch.Delete]. |
| func (b *dbBatch) Delete(key []byte) { |
| b.b.delete(encodeKey(key)) |
| } |
| |
| // DeleteRange implements [storage.Batch.DeleteRange]. |
| func (b *dbBatch) DeleteRange(start, end []byte) { |
| b.b.deleteRange(encodeKey(start), encodeKey(end)) |
| } |
| |
| // Set implements [storage.Batch.Set]. |
| func (b *dbBatch) Set(key, val []byte) { |
| // TODO(jba): account for size of encoded map. |
| b.b.set(encodeKey(key), encodeValue(slices.Clone(val)), len(val)) |
| } |
| |
| // MaybeApply implements [storage.Batch.MaybeApply]. |
| func (b *dbBatch) MaybeApply() bool { |
| return b.b.maybeApply() |
| } |
| |
| // Apply implements [storage.Batch.Apply]. |
| func (b *dbBatch) Apply() { |
| b.b.apply() |
| } |
| |
// encodeKey converts k to a string, preserving ordering.
// The result is the lowercase hexadecimal encoding of k, two characters
// per byte.
func encodeKey(k []byte) string {
	dst := make([]byte, hex.EncodedLen(len(k)))
	hex.Encode(dst, k)
	return string(dst)
}
| |
// decodeKey converts a hexadecimal string, as produced by encodeKey,
// back into the key bytes. It panics if s is not valid hexadecimal.
func decodeKey(s string) []byte {
	key, err := hex.DecodeString(s)
	if err == nil {
		return key
	}
	// unreachable except for bad DB
	panic(fmt.Sprintf("decodeKey(%q) failed: %v", s, err))
}
| |
// encodeValue encodes v in a format acceptable to Firestore.
// (Firestore values must be maps or structs.)
// The result is a single-entry map holding v under the key "v".
func encodeValue(v []byte) any {
	doc := make(map[string][]byte, 1)
	doc["v"] = v
	return doc
}
| |
// decodeValue returns the byte slice stored under the "v" key of m,
// which is the layout that encodeValue produces.
// It panics with a descriptive message if the "v" entry is missing or
// does not hold a []byte.
func decodeValue(m map[string]any) []byte {
	v, ok := m["v"].([]byte)
	if !ok {
		// unreachable except for bad DB
		panic(fmt.Sprintf("decodeValue: field \"v\" has type %T, want []byte", m["v"]))
	}
	return v
}