internal/storage: generalize overlay DB

The overlay part of an overlay DB can now be any DB, not just an
in-memory one.

Change-Id: I42b63fef57c6f0dda9c6b53f91b77c1a25047c60
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/629975
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/gaby/main.go b/internal/gaby/main.go
index 276783d..1493c9b 100644
--- a/internal/gaby/main.go
+++ b/internal/gaby/main.go
@@ -268,7 +268,7 @@
 	}
 	g.db = db
 	if flags.overlay {
-		g.db = storage.NewOverlayDB(g.db)
+		g.db = storage.NewOverlayDB(storage.MemDB(), g.db)
 	}
 
 	vdb, err := firestore.NewVectorDB(g.ctx, g.slog, flags.project, flags.firestoredb, "gaby")
diff --git a/internal/storage/overlay.go b/internal/storage/overlay.go
index bfee888..6766916 100644
--- a/internal/storage/overlay.go
+++ b/internal/storage/overlay.go
@@ -8,29 +8,41 @@
 	"bytes"
 	"iter"
 	"sync"
+
+	"rsc.io/ordered"
 )
 
 type overlayDB struct {
 	MemLocker
-	mu            sync.RWMutex
-	base          DB
-	overlay       *memDB
-	deletedKeys   map[string]bool
-	deletedRanges []keyRange
+	mu      sync.RWMutex
+	overlay DB
+	base    DB
 }
 
 type keyRange struct {
 	start, end []byte
 }
 
-// NewOverlayDB returns a DB that overlays a MemDB over a base DB.
+// overlayPrefix is the leading component of every key reserved for the overlay implementation.
+const overlayPrefix = "__overlay"
+
+// NewOverlayDB returns a DB that combines overlay with base.
 // Reads happen from the overlay first, then the base.
 // All writes go to the overlay.
-func NewOverlayDB(base DB) DB {
+//
+// An overlay DB should only be used for testing. It can violate the
+// specification of [DB] when a process is writing to the base concurrently.
+// Locks held in the overlay will not be observed by the base, so changes
+// from other processes can occur even while the process with the overlay
+// has the lock.
+//
+// The overlay DB assumes that all keys are encoded with [rsc.io/ordered].
+// The part of the key space beginning with ordered.Encode(overlayPrefix) in the overlay
+// DB is reserved for use by the implementation.
+func NewOverlayDB(overlay, base DB) DB {
 	return &overlayDB{
-		base:        base,
-		overlay:     MemDB().(*memDB),
-		deletedKeys: map[string]bool{},
+		overlay: overlay,
+		base:    base,
 	}
 }
 
@@ -58,7 +70,7 @@
 
 func (db *overlayDB) setLocked(key, val []byte) {
 	db.overlay.Set(key, val)
-	delete(db.deletedKeys, string(key)) // save some space; not strictly necessary
+	db.unmarkDeleted(key)
 }
 
 // Delete deletes any entry with the given key.
@@ -70,7 +82,7 @@
 
 func (db *overlayDB) deleteLocked(key []byte) {
 	db.overlay.Delete(key)
-	db.deletedKeys[string(key)] = true
+	db.markDeleted(key)
 }
 
 // DeleteRange deletes all entries with start ≤ key ≤ end.
@@ -83,7 +95,7 @@
 func (db *overlayDB) deleteRangeLocked(start, end []byte) {
 	// TODO(maybe): consolidate ranges
 	db.overlay.DeleteRange(start, end)
-	db.deletedRanges = append(db.deletedRanges, keyRange{start, end})
+	db.markRangeDeleted(start, end)
 }
 
 // Scan returns an iterator over all key-value pairs
@@ -162,16 +174,47 @@
 	}
 }
 
-// deleted reports whether key is a deleted key.
+// markDeleted marks key as deleted.
+// It isn't sufficient to simply delete the key in the overlay, because
+// the key may exist in the base as well.
+func (db *overlayDB) markDeleted(key []byte) {
+	tombstone := ordered.Encode(overlayPrefix, ordered.Raw(key))
+	db.overlay.Set(tombstone, nil)
+}
+
+// unmarkDeleted removes the marker recording that key is deleted.
+// It is not strictly necessary to do this when a key is set, but it
+// saves some space.
+func (db *overlayDB) unmarkDeleted(key []byte) {
+	tombstone := ordered.Encode(overlayPrefix, ordered.Raw(key))
+	db.overlay.Delete(tombstone)
+}
+
+// markRangeDeleted marks every key in [start, end] as deleted. Deleting the
+// keys now present is not enough: one added to base later would be visible.
+// Markers are keyed by start, so one that already reaches end is left alone.
+func (db *overlayDB) markRangeDeleted(start, end []byte) {
+	key := ordered.Encode(overlayPrefix, "ranges", ordered.Raw(start))
+	if old, ok := db.overlay.Get(key); !ok || bytes.Compare(old, end) < 0 {
+		db.overlay.Set(key, end)
+	}
+}
+
+// deleted reports whether key is deleted.
 // The result is only valid for db if key is not present in db.overlay;
 // keys in db.overlay are not deleted, regardless of what this function reports.
 // deleted must be called with the lock held.
 func (db *overlayDB) deleted(key []byte) bool {
-	if db.deletedKeys[string(key)] {
+	tombstone := ordered.Encode(overlayPrefix, ordered.Raw(key))
+	if _, ok := db.overlay.Get(tombstone); ok {
 		return true
 	}
-	for _, r := range db.deletedRanges {
-		if bytes.Compare(key, r.start) >= 0 && bytes.Compare(key, r.end) <= 0 {
+	// Scan the deleted-range markers whose start is <= key.
+	prefix := ordered.Encode(overlayPrefix, "ranges")
+	for _, vf := range db.overlay.Scan(prefix, append(prefix, ordered.Encode(ordered.Raw(key))...)) {
+		// We know start <= key, so just compare end >= key.
+		end := vf()
+		if bytes.Compare(end, key) >= 0 {
 			return true
 		}
 	}
diff --git a/internal/storage/overlay_test.go b/internal/storage/overlay_test.go
index 48888a2..8784c6b 100644
--- a/internal/storage/overlay_test.go
+++ b/internal/storage/overlay_test.go
@@ -8,15 +8,17 @@
 	"bytes"
 	"slices"
 	"testing"
+
+	"rsc.io/ordered"
 )
 
 func TestOverlayDB(t *testing.T) {
-	bs := func(n byte) []byte { return []byte{n} }
+	o := func(n int) []byte { return ordered.Encode(n) }
 
 	newbase := func() DB {
 		db := MemDB()
-		db.Set(bs(0), bs(0))
-		db.Set(bs(9), bs(9))
+		db.Set(o(0), o(0))
+		db.Set(o(9), o(9))
 		return db
 	}
 
@@ -26,25 +28,25 @@
 	}{
 		{
 			"set3",
-			func(db DB) { db.Set(bs(3), bs(3)) },
+			func(db DB) { db.Set(o(3), o(3)) },
 		},
 		{
 			"del9",
-			func(db DB) { db.Delete(bs(9)) },
+			func(db DB) { db.Delete(o(9)) },
 		},
 		{
 			"del9set9",
-			func(db DB) { db.Delete(bs(9)); db.Set(bs(9), bs(4)) },
+			func(db DB) { db.Delete(o(9)); db.Set(o(9), o(4)) },
 		},
 		{
 			"del5set3",
-			func(db DB) { db.Delete(bs(5)); db.Set(bs(3), bs(3)) },
+			func(db DB) { db.Delete(o(5)); db.Set(o(3), o(3)) },
 		},
 		{
 			"get9",
 			func(db DB) {
-				v, ok := db.Get(bs(9))
-				if !ok || !bytes.Equal(v, bs(9)) {
+				v, ok := db.Get(o(9))
+				if !ok || !bytes.Equal(v, o(9)) {
 					t.Fatal("bad Get")
 				}
 			},
@@ -52,9 +54,9 @@
 		{
 			"set9get9",
 			func(db DB) {
-				db.Set(bs(9), bs(1))
-				v, ok := db.Get(bs(9))
-				if !ok || !bytes.Equal(v, bs(1)) {
+				db.Set(o(9), o(1))
+				v, ok := db.Get(o(9))
+				if !ok || !bytes.Equal(v, o(1)) {
 					t.Fatal("bad Get")
 				}
 			},
@@ -62,8 +64,8 @@
 		{
 			"del9get9",
 			func(db DB) {
-				db.Delete(bs(9))
-				if _, ok := db.Get(bs(9)); ok {
+				db.Delete(o(9))
+				if _, ok := db.Get(o(9)); ok {
 					t.Fatal("bad Get")
 				}
 			},
@@ -72,32 +74,52 @@
 			"batch",
 			func(db DB) {
 				b := db.Batch()
-				b.Set(bs(1), bs(1))
-				b.Set(bs(2), bs(2))
-				b.Delete(bs(9))
-				b.Set(bs(9), bs(9))
-				b.DeleteRange(bs(2), bs(6))
+				b.Set(o(1), o(1))
+				b.Set(o(2), o(2))
+				b.Delete(o(9))
+				b.Set(o(9), o(9))
+				b.DeleteRange(o(2), o(6))
 				b.Apply()
 			},
 		},
 		{
 			"delrange",
 			func(db DB) {
-				for i := range byte(9) {
-					db.Set(bs(i), bs(i))
+				for i := range 9 {
+					db.Set(o(i), o(i))
 				}
-				db.DeleteRange(bs(3), bs(9))
-				db.Set(bs(3), bs(3))
+				db.DeleteRange(o(3), o(9))
+				db.Set(o(3), o(3))
+			},
+		},
+		{
+			"delrange2",
+			func(db DB) {
+				for i := range 20 {
+					db.Set(o(i), o(i))
+				}
+				db.DeleteRange(o(3), o(9))
+				db.DeleteRange(o(8), o(12))
+				db.DeleteRange(o(18), o(22))
+				for _, k := range []int{4, 12, 15, 18, 23} {
+					db.Set(o(k), o(k))
+				}
+				db.Delete(o(4))
+				db.Delete(o(15))
 			},
 		},
 	} {
 		t.Run(test.name, func(t *testing.T) {
-			// The overlay DB should behave exactly like an ordinary DB.
+			// Run the operations on an overlay DB.
 			base := newbase()
-			gdb := NewOverlayDB(base)
+			over := MemDB()
+			gdb := NewOverlayDB(over, base)
 			test.ops(gdb)
+
+			// Run the operations directly on the base DB.
 			wdb := newbase()
 			test.ops(wdb)
+			// The overlay DB should behave exactly like an ordinary DB.
 			got := items(gdb)
 			want := items(wdb)
 			if !slices.EqualFunc(got, want, item.Equal) {
@@ -123,7 +145,11 @@
 
 func items(db DB) []item {
 	var items []item
-	for k, vf := range db.Scan([]byte{0}, []byte{255}) {
+	for k, vf := range db.Scan(nil, ordered.Encode(ordered.Inf)) {
+		var prefix string
+		if _, err := ordered.DecodePrefix(k, &prefix); err == nil && prefix == overlayPrefix {
+			continue
+		}
 		items = append(items, item{k, vf()})
 	}
 	return items