internal/storage: overlay DB
Implement a DB that overlays an in-memory DB on top of another
DB (the "base").
Reads happen first on the overlay, then the base.
Writes only happen to the overlay.
For golang/oscar#52.
Change-Id: I60bf9c8ec34319ca52b73145711620103a22e7e6
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/626002
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
diff --git a/internal/storage/mem.go b/internal/storage/mem.go
index 13a08d7..9dddcf5 100644
--- a/internal/storage/mem.go
+++ b/internal/storage/mem.go
@@ -87,7 +87,7 @@
return v, ok
}
-// Scan returns an iterator overall key-value pairs
+// Scan returns an iterator over all key-value pairs
// in the range start ≤ key ≤ end.
func (db *memDB) Scan(start, end []byte) iter.Seq2[[]byte, func() []byte] {
lo := string(start)
diff --git a/internal/storage/overlay.go b/internal/storage/overlay.go
new file mode 100644
index 0000000..8d44dca
--- /dev/null
+++ b/internal/storage/overlay.go
@@ -0,0 +1,237 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "bytes"
+ "iter"
+ "sync"
+)
+
+// An overlayDB is a DB that layers an in-memory overlay on top of a base DB.
+// Sets are stored in the overlay and deletions are remembered in
+// deletedKeys and deletedRanges; the methods in this file never write to base.
+type overlayDB struct {
+ MemLocker // embedded; NOTE(review): presumably supplies the DB's locking methods — confirm at its definition
+ mu sync.RWMutex // held by Get, Set, Delete, DeleteRange, Scan and batch Apply around access to the fields below
+ base DB // underlying DB; only Get, Scan and Close are called on it here
+ overlay *memDB // in-memory DB that receives every Set
+ deletedKeys map[string]bool // keys removed by Delete; used to hide matching base entries
+ deletedRanges []keyRange // ranges removed by DeleteRange; used to hide matching base entries
+}
+
+// A keyRange is an inclusive range of keys: it covers every key
+// with start ≤ key ≤ end, compared bytewise.
+type keyRange struct {
+ start, end []byte
+}
+
+// NewOverlayDB returns a DB that layers a fresh in-memory DB over base.
+// Lookups consult the in-memory layer first and fall back to base.
+// Every modification is recorded in the in-memory layer; base is left untouched.
+func NewOverlayDB(base DB) DB {
+	odb := new(overlayDB)
+	odb.base = base
+	odb.overlay = MemDB().(*memDB)
+	odb.deletedKeys = make(map[string]bool)
+	return odb
+}
+
+// Get looks up key and returns its value, with ok reporting whether it was found.
+// A value in the overlay takes precedence. Otherwise a key recorded as
+// deleted is reported as missing, and anything else is answered by the base.
+func (db *overlayDB) Get(key []byte) (val []byte, ok bool) {
+	db.mu.RLock()
+	defer db.mu.RUnlock()
+
+	val, ok = db.overlay.Get(key)
+	switch {
+	case ok:
+		// The overlay always wins, even over a recorded deletion.
+		return val, true
+	case db.deleted(key):
+		return nil, false
+	default:
+		return db.base.Get(key)
+	}
+}
+
+// Set sets the value associated with key to val.
+// It takes db.mu for writing and delegates to setLocked,
+// so the value is stored in the overlay and the base is not touched.
+func (db *overlayDB) Set(key, val []byte) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ db.setLocked(key, val)
+}
+
+// setLocked stores key/val in the overlay.
+// It does not lock db.mu itself; Set and the batch operations run by
+// overlayBatch.Apply call it with db.mu held for writing.
+func (db *overlayDB) setLocked(key, val []byte) {
+ db.overlay.Set(key, val)
+ // Get checks the overlay before consulting the deletion records, so a
+ // leftover deletedKeys entry would not hide the new value; dropping it
+ // only reclaims memory.
+ delete(db.deletedKeys, string(key)) // save some space; not strictly necessary
+}
+
+// Delete deletes any entry with the given key.
+// It takes db.mu for writing and delegates to deleteLocked, which
+// records the deletion in the overlayDB rather than modifying the base.
+func (db *overlayDB) Delete(key []byte) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ db.deleteLocked(key)
+}
+
+// deleteLocked removes key from the overlay and records it in deletedKeys,
+// so that an entry for key in the base is treated as deleted by Get and Scan.
+// It does not lock db.mu itself; Delete and the batch operations run by
+// overlayBatch.Apply call it with db.mu held for writing.
+func (db *overlayDB) deleteLocked(key []byte) {
+ db.overlay.Delete(key)
+ db.deletedKeys[string(key)] = true
+}
+
+// DeleteRange deletes all entries with start ≤ key ≤ end.
+// The base is not modified; the range is recorded in the overlayDB so that
+// base entries inside it are hidden from Get and Scan.
+// The caller may reuse start and end after DeleteRange returns.
+func (db *overlayDB) DeleteRange(start, end []byte) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	// deleteRangeLocked keeps the bounds it is given in db.deletedRanges.
+	// Pass private copies so that a caller mutating its own slices later
+	// cannot silently move the deleted range (the batch path clones too).
+	db.deleteRangeLocked(bytes.Clone(start), bytes.Clone(end))
+}
+
+// deleteRangeLocked removes every overlay entry with start ≤ key ≤ end and
+// appends the range to deletedRanges so that base entries inside it are
+// treated as deleted.
+// The start and end slices themselves are retained in deletedRanges (no copy
+// is made here), so they must not be modified after the call.
+// It does not lock db.mu itself; callers hold it for writing.
+func (db *overlayDB) deleteRangeLocked(start, end []byte) {
+ // TODO(maybe): consolidate ranges
+ db.overlay.DeleteRange(start, end)
+ db.deletedRanges = append(db.deletedRanges, keyRange{start, end})
+}
+
+// Scan returns an iterator over all key-value pairs
+// in the range start ≤ key ≤ end.
+// The pairs are those of the overlay merged with the pairs of the base
+// that have not been deleted; when both hold a key, the overlay's pair is used.
+// It does not guarantee a consistent view of the DB (snapshot); keys and values
+// may change during the iteration.
+func (db *overlayDB) Scan(start, end []byte) iter.Seq2[[]byte, func() []byte] {
+ return func(yield func([]byte, func() []byte) bool) {
+ db.mu.RLock()
+ // locked records whether the read lock is held right now, so that the
+ // deferred function unlocks only when there is something to unlock.
+ locked := true
+ defer func() {
+ if locked {
+ db.mu.RUnlock()
+ }
+ }()
+
+ // Filter out all keys in base that have been deleted.
+ fbase := filter2(db.base.Scan(start, end), func(k []byte, v func() []byte) bool { return !db.deleted(k) })
+ // Merge all the keys in overlay with the undeleted ones in base.
+ for k, vf := range unionFunc2(db.overlay.Scan(start, end), fbase, bytes.Compare) {
+ // Release the lock so yield can call methods on db.
+ db.mu.RUnlock()
+ locked = false
+ if !yield(k, vf) {
+ // locked is still false here, so the deferred function does not unlock a second time.
+ return
+ }
+ // Take the lock again before the merge advances to the next pair.
+ db.mu.RLock()
+ locked = true
+ }
+ }
+}
+
+// filter2 returns a sequence holding exactly those pairs of s that f accepts,
+// in their original order. Iteration ends early as soon as the consumer
+// asks to stop.
+func filter2[K, V any](s iter.Seq2[K, V], f func(K, V) bool) iter.Seq2[K, V] {
+	return func(yield func(K, V) bool) {
+		for key, value := range s {
+			// Rejected pairs are skipped; an accepted pair that the
+			// consumer refuses ends the whole iteration.
+			if f(key, value) && !yield(key, value) {
+				return
+			}
+		}
+	}
+}
+
+// unionFunc2 returns an iterator over all elements of s1 and s2, with keys in the same order.
+// The keys of s1 and s2 must both be sorted according to cmp.
+// If s1 and s2 have the same key, it is yielded once, with s1's value.
+// cmp is only ever called with keys that s1 and s2 actually produced.
+func unionFunc2[K, V any](s1, s2 iter.Seq2[K, V], cmp func(K, K) int) iter.Seq2[K, V] {
+	return func(yield func(K, V) bool) {
+		// s2 is pulled on demand so it can be interleaved with the range over s1.
+		next, stop := iter.Pull2(s2)
+		defer stop()
+
+		k2, v2, ok := next()
+		for k1, v1 := range s1 {
+			// Emit every pending s2 element that sorts strictly before k1.
+			for ok && cmp(k2, k1) < 0 {
+				if !yield(k2, v2) {
+					return
+				}
+				k2, v2, ok = next()
+			}
+			if !yield(k1, v1) {
+				return
+			}
+			// Drop s2's duplicate of k1. Once s2 is exhausted (ok is false),
+			// k2 is merely a zero value: comparing it could panic for
+			// comparators that cannot handle a zero key, and would
+			// otherwise just trigger a pointless extra next call.
+			if ok && cmp(k1, k2) == 0 {
+				k2, v2, ok = next()
+			}
+		}
+		// s1 is finished; whatever remains in s2 follows in order.
+		for ; ok; k2, v2, ok = next() {
+			if !yield(k2, v2) {
+				return
+			}
+		}
+	}
+}
+
+// deleted reports whether key has been recorded as deleted, either
+// individually or as a member of a deleted range.
+// The answer is meaningful only when key is absent from db.overlay:
+// a key present in db.overlay is live no matter what deleted reports.
+// The caller must hold db.mu.
+func (db *overlayDB) deleted(key []byte) bool {
+	hit := db.deletedKeys[string(key)]
+	// Fall back to the range records only if the single-key record missed,
+	// and stop at the first range that contains key.
+	for i := 0; !hit && i < len(db.deletedRanges); i++ {
+		r := db.deletedRanges[i]
+		hit = bytes.Compare(r.start, key) <= 0 && bytes.Compare(key, r.end) <= 0
+	}
+	return hit
+}
+
+// Batch returns a new, empty batch whose operations will be applied to db.
+func (db *overlayDB) Batch() Batch {
+	batch := new(overlayBatch)
+	batch.db = db
+	return batch
+}
+
+// Flush flushes everything to persistent storage.
+// For an overlayDB it does nothing.
+func (db *overlayDB) Flush() {
+ // overlay is a memDB and base is never written; nothing to flush.
+}
+
+// Close closes the base DB and then the in-memory overlay.
+func (db *overlayDB) Close() {
+ db.base.Close()
+ db.overlay.Close()
+}
+
+// Panic passes msg and args on to the package-level Panic function.
+// NOTE(review): presumably that function reports the message and panics — confirm at its definition.
+func (db *overlayDB) Panic(msg string, args ...any) {
+ Panic(msg, args...)
+}
+
+// An overlayBatch is a Batch for an overlayDB.
+// Its Set, Delete and DeleteRange methods only queue work; Apply runs it.
+type overlayBatch struct {
+ db *overlayDB // underlying database
+ ops []func() // operations to apply, in the order queued; Apply runs them with db.mu held
+}
+
+// Set queues an operation that stores val under key when the batch is applied.
+// An empty key is reported through the database's Panic method.
+// key and val are copied, so the caller may reuse them immediately.
+func (b *overlayBatch) Set(key, val []byte) {
+	if len(key) == 0 {
+		b.db.Panic("overlaydb batch set: empty key")
+	}
+	keyCopy, valCopy := bytes.Clone(key), bytes.Clone(val)
+	store := func() { b.db.setLocked(keyCopy, valCopy) }
+	b.ops = append(b.ops, store)
+}
+
+// Delete queues an operation that deletes key when the batch is applied.
+// key is copied, so the caller may reuse it immediately.
+func (b *overlayBatch) Delete(key []byte) {
+	keyCopy := bytes.Clone(key)
+	remove := func() { b.db.deleteLocked(keyCopy) }
+	b.ops = append(b.ops, remove)
+}
+
+// DeleteRange queues an operation that deletes every entry with
+// start ≤ key ≤ end when the batch is applied.
+// start and end are copied, so the caller may reuse them immediately.
+func (b *overlayBatch) DeleteRange(start, end []byte) {
+	lo, hi := bytes.Clone(start), bytes.Clone(end)
+	removeRange := func() { b.db.deleteRangeLocked(lo, hi) }
+	b.ops = append(b.ops, removeRange)
+}
+
+// MaybeApply always reports false and leaves the queued operations untouched;
+// an overlayBatch is applied only by an explicit call to Apply.
+func (b *overlayBatch) MaybeApply() bool {
+ return false
+}
+
+// Apply runs every queued operation, in the order it was queued, while
+// holding the database's write lock, and then empties the batch.
+func (b *overlayBatch) Apply() {
+	b.db.mu.Lock()
+	defer b.db.mu.Unlock()
+
+	pending := b.ops
+	for i := range pending {
+		pending[i]()
+	}
+	b.ops = nil
+}
diff --git a/internal/storage/overlay_test.go b/internal/storage/overlay_test.go
new file mode 100644
index 0000000..48888a2
--- /dev/null
+++ b/internal/storage/overlay_test.go
@@ -0,0 +1,130 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package storage
+
+import (
+ "bytes"
+ "slices"
+ "testing"
+)
+
+// TestOverlayDB checks that an overlay DB behaves like an ordinary DB holding
+// the same initial contents, and that using it never modifies its base.
+// Each case runs the same operations against both and compares a full scan.
+func TestOverlayDB(t *testing.T) {
+	bs := func(n byte) []byte { return []byte{n} }
+
+	// newbase returns a fresh DB holding the fixed starting contents {0:0, 9:9}.
+	newbase := func() DB {
+		db := MemDB()
+		db.Set(bs(0), bs(0))
+		db.Set(bs(9), bs(9))
+		return db
+	}
+
+	for _, test := range []struct {
+		name string
+		// ops receives the subtest's own t. Calling Fatal on the parent
+		// test's t from inside a subtest would run FailNow on the wrong
+		// goroutine and fail the parent instead of the subtest.
+		ops func(*testing.T, DB)
+	}{
+		{
+			"set3",
+			func(_ *testing.T, db DB) { db.Set(bs(3), bs(3)) },
+		},
+		{
+			"del9",
+			func(_ *testing.T, db DB) { db.Delete(bs(9)) },
+		},
+		{
+			"del9set9",
+			func(_ *testing.T, db DB) { db.Delete(bs(9)); db.Set(bs(9), bs(4)) },
+		},
+		{
+			"del5set3",
+			func(_ *testing.T, db DB) { db.Delete(bs(5)); db.Set(bs(3), bs(3)) },
+		},
+		{
+			"get9",
+			func(t *testing.T, db DB) {
+				v, ok := db.Get(bs(9))
+				if !ok || !bytes.Equal(v, bs(9)) {
+					t.Fatal("bad Get")
+				}
+			},
+		},
+		{
+			"set9get9",
+			func(t *testing.T, db DB) {
+				db.Set(bs(9), bs(1))
+				v, ok := db.Get(bs(9))
+				if !ok || !bytes.Equal(v, bs(1)) {
+					t.Fatal("bad Get")
+				}
+			},
+		},
+		{
+			"del9get9",
+			func(t *testing.T, db DB) {
+				db.Delete(bs(9))
+				if _, ok := db.Get(bs(9)); ok {
+					t.Fatal("bad Get")
+				}
+			},
+		},
+		{
+			"batch",
+			func(_ *testing.T, db DB) {
+				b := db.Batch()
+				b.Set(bs(1), bs(1))
+				b.Set(bs(2), bs(2))
+				b.Delete(bs(9))
+				b.Set(bs(9), bs(9))
+				b.DeleteRange(bs(2), bs(6))
+				b.Apply()
+			},
+		},
+		{
+			"delrange",
+			func(_ *testing.T, db DB) {
+				for i := range byte(9) {
+					db.Set(bs(i), bs(i))
+				}
+				db.DeleteRange(bs(3), bs(9))
+				db.Set(bs(3), bs(3))
+			},
+		},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			// The overlay DB should behave exactly like an ordinary DB.
+			base := newbase()
+			gdb := NewOverlayDB(base)
+			test.ops(t, gdb)
+			wdb := newbase()
+			test.ops(t, wdb)
+			got := items(gdb)
+			want := items(wdb)
+			if !slices.EqualFunc(got, want, item.Equal) {
+				t.Errorf("\ngot %v\nwant %v", got, want)
+			}
+
+			// The overlay DB should not change its base.
+			bgot := items(base)
+			if !slices.EqualFunc(bgot, items(newbase()), item.Equal) {
+				t.Errorf("base changed: %v", bgot)
+			}
+		})
+	}
+}
+
+// An item is one key-value pair, as collected from a DB scan by items.
+type item struct {
+ key, val []byte
+}
+
+// Equal reports whether i1 and i2 hold the same key bytes and the same value bytes.
+func (i1 item) Equal(i2 item) bool {
+	if !bytes.Equal(i1.key, i2.key) {
+		return false
+	}
+	return bytes.Equal(i1.val, i2.val)
+}
+
+// items scans db over the key range {0} ≤ key ≤ {255} and returns the pairs
+// it finds, in scan order, with every value materialized.
+func items(db DB) []item {
+	var collected []item
+	lo, hi := []byte{0}, []byte{255}
+	for key, valf := range db.Scan(lo, hi) {
+		collected = append(collected, item{key: key, val: valf()})
+	}
+	return collected
+}