blob: 25d3805bc662c63fc0514f6791629e32ba7df71f [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trie
// Collision functions combine a left-hand side (lhs) value and a right-hand
// side (rhs) value that are associated with the same key, and produce the
// value that will be stored for that key.
//
// Collision functions must be idempotent:
// collision(x, x) == x for all x.
// When both sides are the very same leaf node, the Builder relies on this and
// skips calling the function altogether. A nil Collision keeps the lhs value.
//
// Collision functions may be applied whenever a value is inserted,
// or when two maps are merged or intersected.
type Collision func(lhs interface{}, rhs interface{}) interface{}
// TakeLhs is a Collision that resolves every conflict in favor of the
// left-hand side value; rhs is ignored.
func TakeLhs(lhs, rhs interface{}) interface{} {
	return lhs
}
// TakeRhs is a Collision that resolves every conflict in favor of the
// right-hand side value; lhs is ignored.
func TakeRhs(lhs, rhs interface{}) interface{} {
	return rhs
}
// Builder creates new Maps. Each Builder has a unique Scope.
//
// IMPORTANT: Nodes are hash-consed internally to reduce memory consumption. To
// support hash-consing, a Builder keeps internal tables of all of the leaf and
// branch nodes it has created. To GC any of the Maps created by the Builder,
// all references to the Builder must be dropped. This includes MutMaps.
// (Rescope replaces these tables with fresh ones.)
type Builder struct {
	scope Scope // the Scope stamped on the Maps this Builder currently creates

	// hash-consing maps for each node type.
	empty    *empty             // the single empty node of the current scope
	leaves   map[leaf]*leaf     // canonical *leaf for each leaf value
	branches map[branch]*branch // canonical *branch for each branch value

	// It may be possible to support more types of patricia tries
	// (e.g. non-hash-consed) by making Builder an interface and abstracting
	// the mkLeaf and mkBranch functions.
}
// NewBuilder returns a Builder bound to a freshly created, unique Scope,
// with empty hash-consing tables.
func NewBuilder() *Builder {
	scope := newScope()
	b := new(Builder)
	b.scope = scope
	b.empty = &empty{scope}
	b.leaves = map[leaf]*leaf{}
	b.branches = map[branch]*branch{}
	return b
}
// Scope returns the Scope that b currently stamps on the Maps it creates.
func (b *Builder) Scope() Scope {
	return b.scope
}
// Rescope moves the builder to a brand-new unique Scope and starts over with
// empty hash-consing tables.
//
// Any Maps created using the previous scope need to be Cloned
// before any operation.
//
// Dropping the old tables makes the old internals of the Builder eligible
// to be GC'ed.
func (b *Builder) Rescope() {
	fresh := newScope()
	*b = Builder{
		scope:    fresh,
		empty:    &empty{fresh},
		leaves:   make(map[leaf]*leaf),
		branches: make(map[branch]*branch),
	}
}
// Empty returns the empty Map in b's current scope.
func (b *Builder) Empty() Map {
	s := b.Scope()
	return Map{s, b.empty}
}
// InsertWith inserts a new association from k to v into the Map m to create a
// new map in the current scope, resolving collisions with the collision
// function c.
//
// This roughly corresponds to updating a map[uint64]interface{} by:
//
//	if _, ok := m[k]; ok { m[k] = c(m[k], v) } else { m[k] = v }
//
// The existing value is passed to c as lhs and v as rhs; a nil c keeps the
// existing value.
//
// An insertion or update happened whenever InsertWith(c, m, ...) != m. If m is
// not in b.Scope(), it is cloned first, so the result then always differs
// from m.
func (b *Builder) InsertWith(c Collision, m Map, k uint64, v interface{}) Map {
	m = b.Clone(m)
	// lhs=false: the new leaf plays the right-hand side in any collision.
	return Map{b.Scope(), b.insert(c, m.n, b.mkLeaf(key(k), v), false)}
}
// Insert inserts a new association from k to v into the Map m to create
// a new map in the current scope.
//
// If m already maps k, the previously mapped value is kept and v is dropped.
// This roughly corresponds to updating a map[uint64]interface{} by:
//
//	if _, ok := m[k]; !ok { m[k] = v }
//
// This is equivalent to b.Merge(m, b.Create(map[uint64]interface{}{k: v})).
func (b *Builder) Insert(m Map, k uint64, v interface{}) Map {
	// TakeLhs keeps the existing value (lhs) on a collision.
	return b.InsertWith(TakeLhs, m, k, v)
}
// Update returns a map in the current scope in which key is bound to val,
// overwriting any previous binding. This roughly corresponds to updating a
// map[uint64]interface{} by:
//
//	m[key] = val
func (b *Builder) Update(m Map, key uint64, val interface{}) Map {
	var preferNew Collision = TakeRhs // the incoming value is the rhs
	return b.InsertWith(preferNew, m, key, val)
}
// MergeWith combines the maps lhs and rhs into a new map in the current scope.
//
// For a key present in both maps (a collision), the resulting value is
// c(lhs[key], rhs[key]). Both inputs are first cloned into b's scope if needed.
func (b *Builder) MergeWith(c Collision, lhs, rhs Map) Map {
	left := b.Clone(lhs)
	right := b.Clone(rhs)
	return Map{b.Scope(), b.merge(c, left.n, right.n)}
}
// Merge combines the maps lhs and rhs into a new map in the current scope.
//
// Whenever there is a key in both maps (a collision), the resulting value
// mapped by the key is the value from lhs, i.e. TakeLhs(lhs[key], rhs[key]).
func (b *Builder) Merge(lhs, rhs Map) Map {
	return b.MergeWith(TakeLhs, lhs, rhs)
}
// Clone returns a Map with the same (key, value) elements as m that lives in
// b.Scope(). If m is already in that scope, m itself is returned. A Map with
// no root node maps to b's empty map. Otherwise m is deep-copied into b.
func (b *Builder) Clone(m Map) Map {
	current := b.Scope()
	switch {
	case m.Scope() == current:
		return m
	case m.n == nil:
		return Map{current, b.empty}
	default:
		return Map{current, b.clone(m.n)}
	}
}
// clone rebuilds n bottom-up out of nodes hash-consed by b, so the result
// belongs to b's current scope. It panics on any node it does not recognize,
// including a nil node.
func (b *Builder) clone(n node) node {
	if lf, ok := n.(*leaf); ok {
		return b.mkLeaf(lf.k, lf.v)
	}
	if br, ok := n.(*branch); ok {
		lo := b.clone(br.left)
		hi := b.clone(br.right)
		return b.mkBranch(br.prefix, br.branching, lo, hi)
	}
	if _, ok := n.(*empty); ok {
		return b.empty
	}
	panic("unreachable")
}
// Remove returns a Map in the current scope that has every element of m
// except the one for key k (if any).
func (b *Builder) Remove(m Map, k uint64) Map {
	scoped := b.Clone(m)
	return Map{b.Scope(), b.remove(scoped.n, key(k))}
}
// Intersect returns the map of keys found in both lhs and rhs, each bound to
// its value from lhs, i.e.
// {(k, lhs[k]) | k in lhs, k in rhs}.
func (b *Builder) Intersect(lhs, rhs Map) Map {
	var keepLeft Collision = TakeLhs
	return b.IntersectWith(keepLeft, lhs, rhs)
}
// IntersectWith takes lhs and rhs and returns their intersection in the
// current scope, with each value produced by the collision function, i.e.
// {(k, c(lhs[k], rhs[k])) | k in lhs, k in rhs}.
func (b *Builder) IntersectWith(c Collision, lhs, rhs Map) Map {
	left := b.Clone(lhs)
	right := b.Clone(rhs)
	return Map{b.Scope(), b.intersect(c, left.n, right.n)}
}
// MutMap is a convenient wrapper for a Map and the *Builder that will be used
// to create new Maps from it. Each MutMap method replaces M with the Map the
// Builder returns and reports whether M changed.
type MutMap struct {
	B *Builder // builder used to derive every new value of M
	M Map      // current map; reassigned by each MutMap method
}
// MutEmpty returns a MutMap backed by b whose map starts out empty.
func (b *Builder) MutEmpty() MutMap {
	return MutMap{B: b, M: b.Empty()}
}
// Insert adds the association k -> v to mm.M unless mm.M already maps k, in
// which case the existing value is kept (see Builder.Insert).
//
// Returns true if mm.M changed. If mm.M was not in mm.B's scope it is cloned
// first, so this also reports true in that case.
func (mm *MutMap) Insert(k uint64, v interface{}) bool {
	old := mm.M
	mm.M = mm.B.Insert(old, k, v)
	return old != mm.M
}
// Update binds k to v in mm.M, replacing any existing value (see
// Builder.Update). It reports whether mm.M changed.
func (mm *MutMap) Update(k uint64, v interface{}) bool {
	prev := mm.M
	next := mm.B.Update(prev, k, v)
	mm.M = next
	return next != prev
}
// Remove deletes k from mm.M (see Builder.Remove). It reports whether
// mm.M changed.
func (mm *MutMap) Remove(k uint64) bool {
	prev := mm.M
	next := mm.B.Remove(prev, k)
	mm.M = next
	return next != prev
}
// Merge merges other into mm.M. For keys present in both, the value already in
// mm.M is kept (see Builder.Merge).
//
// Returns true if mm.M changed. If mm.M was not in mm.B's scope it is cloned
// first, so this also reports true in that case.
func (mm *MutMap) Merge(other Map) bool {
	old := mm.M
	mm.M = mm.B.Merge(old, other)
	return old != mm.M
}
// Intersect restricts mm.M to the keys that also appear in other, keeping the
// values already in mm.M (see Builder.Intersect).
//
// Returns true if mm.M changed. If mm.M was not in mm.B's scope it is cloned
// first, so this also reports true in that case.
func (mm *MutMap) Intersect(other Map) bool {
	old := mm.M
	mm.M = mm.B.Intersect(old, other)
	return old != mm.M
}
// Create returns a Map in the current scope containing exactly the
// (key, value) pairs of m.
//
// The keys of a Go map are distinct, so building the trie never triggers a
// collision.
func (b *Builder) Create(m map[uint64]interface{}) Map {
	// The number of leaves is known up front; avoid repeated slice growth.
	leaves := make([]*leaf, 0, len(m))
	for k, v := range m {
		leaves = append(leaves, b.mkLeaf(key(k), v))
	}
	return Map{b.Scope(), b.create(leaves)}
}
// MergeWith merges other into mm.M, resolving keys present in both with the
// collision function c applied as c(mm.M[k], other[k]) (see Builder.MergeWith).
//
// Returns true if mm.M changed. If mm.M was not in mm.B's scope it is cloned
// first, so this also reports true in that case.
func (mm *MutMap) MergeWith(c Collision, other Map) bool {
	old := mm.M
	mm.M = mm.B.MergeWith(c, old, other)
	return old != mm.M
}
// create builds a map in the current scope from a collection of leaf nodes by
// recursively splitting the slice in half and merging the two halves.
// The leaves are expected to have distinct keys, so no Collision is needed
// (nil is passed to merge).
func (b *Builder) create(leaves []*leaf) node {
	switch len(leaves) {
	case 0:
		return b.empty
	case 1:
		return leaves[0]
	}
	// A more sophisticated algorithm could sort the leaves, compute the
	// prefix and branching bit from the min and max keys, binary search for
	// the split point, and build the branch for that prefix directly.
	// Skipped until this is a performance bottleneck.
	mid := len(leaves) / 2 // len >= 2, so both halves are non-empty.
	return b.merge(nil, b.create(leaves[:mid]), b.create(leaves[mid:]))
}
// mkLeaf returns the hash-consed representative of (k, v) in the current scope.
// The first request for a given (k, v) allocates and records a new leaf; later
// requests return that same pointer, so equal leaves are also identical.
//
// The lookup probes the table with a plain value, so no leaf is heap-allocated
// when a representative already exists.
//
// NOTE: leaf values are used as map keys, so v's dynamic type must be
// comparable; hashing an uncomparable value panics at runtime.
func (b *Builder) mkLeaf(k key, v interface{}) *leaf {
	probe := leaf{k: k, v: v}
	if rep, ok := b.leaves[probe]; ok {
		return rep
	}
	l := new(leaf)
	*l = probe
	b.leaves[probe] = l
	return l
}
// mkBranch returns the hash-consed representative of the tuple
// (prefix, branch, left, right)
// in the current scope. The branch's size is the sum of its children's sizes.
//
// The lookup probes the table with a plain value, so no branch is
// heap-allocated when a representative already exists.
func (b *Builder) mkBranch(p prefix, bp bitpos, left node, right node) *branch {
	probe := branch{
		sz:        left.size() + right.size(),
		prefix:    p,
		branching: bp,
		left:      left,
		right:     right,
	}
	if rep, ok := b.branches[probe]; ok {
		return rep
	}
	br := new(branch)
	*br = probe
	b.branches[probe] = br
	return br
}
// join combines two maps whose prefixes p0 and p1 are *known* to disagree into
// a single branch. The branch splits on the bit where p0 and p1 first differ.
// The subtree whose prefix has a zero at that bit goes on the left.
func (b *Builder) join(p0 prefix, t0 node, p1 prefix, t1 node) *branch {
	bit := branchingBit(p0, p1)
	lo, hi := t1, t0
	if zeroBit(p0, bit) {
		lo, hi = t0, t1
	}
	return b.mkBranch(mask(p0, bit), bit, lo, hi)
}
// collide combines two leaves with the same key into the leaf that holds the
// collided value. If the collided value equals one side's value, that existing
// leaf is reused instead of creating a new one. A nil c keeps the left value.
func (b *Builder) collide(c Collision, left, right *leaf) *leaf {
	if left == right {
		return left // c is idempotent: c(x, x) == x
	}
	var merged interface{}
	if c == nil {
		merged = left.v
	} else {
		merged = c(left.v, right.v)
	}
	if merged == left.v {
		return left
	}
	if merged == right.v {
		return right
	}
	return b.mkLeaf(left.k, merged)
}
// insert adds leaf l to the map m and returns the resulting map; both are in
// the current scope. When lhs is true, l plays the left-hand side in a
// collision, otherwise the right-hand side. Subtrees that do not change are
// returned as-is, so the result is m itself when nothing changed.
func (b *Builder) insert(c Collision, m node, l *leaf, lhs bool) node {
	switch t := m.(type) {
	case *empty:
		return l
	case *leaf:
		if t.k != l.k {
			return b.join(prefix(l.k), l, prefix(t.k), t)
		}
		if lhs {
			return b.collide(c, l, t)
		}
		return b.collide(c, t, l)
	}
	// Only a branch is left.
	br := m.(*branch)
	kp := prefix(l.k)
	if !matchPrefix(kp, br.prefix, br.branching) {
		return b.join(kp, l, br.prefix, br)
	}
	if zeroBit(kp, br.branching) {
		newLeft := b.insert(c, br.left, l, lhs)
		if newLeft == br.left {
			return m
		}
		return b.mkBranch(br.prefix, br.branching, newLeft, br.right)
	}
	newRight := b.insert(c, br.right, l, lhs)
	if newRight == br.right {
		return m
	}
	return b.mkBranch(br.prefix, br.branching, br.left, newRight)
}
// merge returns the union of the maps lhs and rhs, both in the current scope.
// For a key present in both, the value is chosen by collide with the lhs leaf
// as the left argument; a nil c keeps the lhs value.
func (b *Builder) merge(c Collision, lhs, rhs node) node {
	if lhs == rhs {
		return lhs // same node: with an idempotent c, the union is lhs itself.
	}
	switch lhs := lhs.(type) {
	case *empty:
		return rhs
	case *leaf:
		// A single lhs leaf: insert it into rhs as the left-hand side.
		return b.insert(c, rhs, lhs, true)
	case *branch:
		switch rhs := rhs.(type) {
		case *empty:
			return lhs
		case *leaf:
			// A single rhs leaf: insert it into lhs as the right-hand side.
			return b.insert(c, lhs, rhs, false)
		case *branch:
			// fallthrough: handled below.
		}
	}
	// Last remaining case is branch branch merging.
	// For brevity, we adopt the Okasaki and Gill naming conventions
	// for branching and prefixes.
	s, t := lhs.(*branch), rhs.(*branch)
	p, m := s.prefix, s.branching
	q, n := t.prefix, t.branching
	if m == n && p == q { // prefixes are identical.
		left, right := b.merge(c, s.left, t.left), b.merge(c, s.right, t.right)
		return b.mkBranch(p, m, left, right)
	}
	if !prefixesOverlap(p, m, q, n) {
		return b.join(p, s, q, t) // prefixes are disjoint.
	}
	// prefixesOverlap(p, m, q, n) && !(m == n && p == q)
	// By prefixesOverlap(...), either:
	// ord(m, n) && matchPrefix(q, p, m), or
	// ord(n, m) && matchPrefix(p, q, n)
	// (presumably ord(x, y) means branching bit x sits above y, i.e. x's
	// subtree covers y's — TODO confirm against ord's definition).
	// So one of s or t lies entirely under one child of the other and is
	// merged into that child; the default case is ord(n, m) && !zeroBit(p, n).
	switch {
	case ord(m, n) && zeroBit(q, m):
		return b.mkBranch(p, m, b.merge(c, s.left, t), s.right)
	case ord(m, n) && !zeroBit(q, m):
		return b.mkBranch(p, m, s.left, b.merge(c, s.right, t))
	case ord(n, m) && zeroBit(p, n):
		return b.mkBranch(q, n, b.merge(c, s, t.left), t.right)
	default:
		return b.mkBranch(q, n, t.left, b.merge(c, s, t.right))
	}
}
// remove returns the map m without key k. Subtrees that do not contain k are
// returned unchanged, so the result is m itself when k is absent. A branch
// left with a single non-empty child collapses to that child.
func (b *Builder) remove(m node, k key) node {
	switch t := m.(type) {
	case *empty:
		return t
	case *leaf:
		if t.k != k {
			return t
		}
		return b.empty
	}
	// Only a branch is left.
	br := m.(*branch)
	kp := prefix(k)
	if !matchPrefix(kp, br.prefix, br.branching) {
		return br // k lies outside br's prefix, so it cannot be present.
	}
	// Only the child selected by the branching bit can hold k.
	lo, hi := br.left, br.right
	if zeroBit(kp, br.branching) {
		lo = b.remove(lo, k)
		if lo == br.left {
			return br // no update
		}
	} else {
		hi = b.remove(hi, k)
		if hi == br.right {
			return br // no update
		}
	}
	if _, gone := lo.(*empty); gone {
		return hi
	}
	if _, gone := hi.(*empty); gone {
		return lo
	}
	// Both children are non-empty and still share br's prefix and disagree on
	// its branching bit, so the branch can be rebuilt directly.
	return b.mkBranch(br.prefix, br.branching, lo, hi)
}
// intersect returns the map of keys present in both l and r, both in the
// current scope. For each shared key the value is chosen by collide with the
// l leaf as the left argument; a nil c keeps the value from l.
func (b *Builder) intersect(c Collision, l, r node) node {
	if l == r {
		return l // same node: with an idempotent c, the intersection is l itself.
	}
	switch l := l.(type) {
	case *empty:
		return b.empty
	case *leaf:
		// Keep l's single key only if r also contains it.
		if rleaf := r.find(l.k); rleaf != nil {
			return b.collide(c, l, rleaf)
		}
		return b.empty
	case *branch:
		switch r := r.(type) {
		case *empty:
			return b.empty
		case *leaf:
			// Keep r's single key only if l also contains it; l stays the lhs.
			if lleaf := l.find(r.k); lleaf != nil {
				return b.collide(c, lleaf, r)
			}
			return b.empty
		case *branch:
			// fallthrough: handled below.
		}
	}
	// Last remaining case is branch branch intersection.
	s, t := l.(*branch), r.(*branch)
	p, m := s.prefix, s.branching
	q, n := t.prefix, t.branching
	if m == n && p == q {
		// prefixes are identical.
		left, right := b.intersect(c, s.left, t.left), b.intersect(c, s.right, t.right)
		// Unlike merge, intersection can empty a side; collapse instead of
		// building a branch with an empty child.
		if _, ok := left.(*empty); ok {
			return right
		} else if _, ok := right.(*empty); ok {
			return left
		}
		// The left and right branches are both non-empty.
		// They still share the same prefix and disagree on the same branching bit.
		// It is safe to directly create the branch.
		return b.mkBranch(p, m, left, right)
	}
	if !prefixesOverlap(p, m, q, n) {
		return b.empty // The prefixes share no keys.
	}
	// prefixesOverlap(p, m, q, n) && !(m == n && p == q)
	// By prefixesOverlap(...), either:
	// ord(m, n) && matchPrefix(q, p, m), or
	// ord(n, m) && matchPrefix(p, q, n)
	// So either s or t may be a strict subtree of the other. Only the child of
	// the larger one that lies on the smaller one's side can share keys with it,
	// so the intersection recurses into that child alone.
	var lhs, rhs node
	switch {
	case ord(m, n) && zeroBit(q, m):
		lhs, rhs = s.left, t
	case ord(m, n) && !zeroBit(q, m):
		lhs, rhs = s.right, t
	case ord(n, m) && zeroBit(p, n):
		lhs, rhs = s, t.left
	default:
		lhs, rhs = s, t.right
	}
	return b.intersect(c, lhs, rhs)
}