// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package unique

import (
"internal/abi"
"internal/goarch"
"runtime"
"sync"
"sync/atomic"
"unsafe"
"weak"
)
// canonMap is a map of T -> *T. The map controls the creation
// of a canonical *T, and elements of the map are automatically
// deleted when the canonical *T is no longer referenced.
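//
// A minimal usage sketch (illustrative only; within this package, the exported
// Make builds Handles on top of a canonMap):
//
//	m := newCanonMap[string]()
//	a := m.LoadOrStore("gopher") // a is the canonical *string for "gopher"
//	b := m.LoadOrStore("gopher") // b == a while the canonical pointer is still referenced
//	c := m.Load("gopher")        // c == a; Load returns nil once the entry is cleaned up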
type canonMap[T comparable] struct {
root atomic.Pointer[indirect[T]]           // Root of the hash-trie.
hash func(unsafe.Pointer, uintptr) uintptr // Hash function for T, borrowed from the runtime's map implementation.
seed uintptr                               // Per-map seed passed to hash.
}
func newCanonMap[T comparable]() *canonMap[T] {
cm := new(canonMap[T])
cm.root.Store(newIndirectNode[T](nil))
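// Borrow the runtime's hash function for T by inspecting the map type
// map[T]struct{}; this gives a hasher that works for any comparable T.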
var m map[T]struct{}
mapType := abi.TypeOf(m).MapType()
cm.hash = mapType.Hasher
cm.seed = uintptr(runtime_rand())
return cm
}
func (m *canonMap[T]) Load(key T) *T {
hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)
i := m.root.Load()
hashShift := 8 * goarch.PtrSize
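// Walk the trie, consuming nChildrenLog2 bits of the hash per level,
// starting from the most significant bits.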
for hashShift != 0 {
hashShift -= nChildrenLog2
n := i.children[(hash>>hashShift)&nChildrenMask].Load()
if n == nil {
return nil
}
if n.isEntry {
v, _ := n.entry().lookup(key)
return v
}
i = n.indirect()
}
panic("unique.canonMap: ran out of hash bits while iterating")
}
func (m *canonMap[T]) LoadOrStore(key T) *T {
hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)
var i *indirect[T]
var hashShift uint
var slot *atomic.Pointer[node[T]]
var n *node[T]
for {
// Find the key or a candidate location for insertion.
i = m.root.Load()
hashShift = 8 * goarch.PtrSize
haveInsertPoint := false
for hashShift != 0 {
hashShift -= nChildrenLog2
slot = &i.children[(hash>>hashShift)&nChildrenMask]
n = slot.Load()
if n == nil {
// We found a nil slot which is a candidate for insertion.
haveInsertPoint = true
break
}
if n.isEntry {
// We found an existing entry, which is as far as we can go.
// If it stays this way, we'll have to replace it with an
// indirect node.
if v, _ := n.entry().lookup(key); v != nil {
return v
}
haveInsertPoint = true
break
}
i = n.indirect()
}
if !haveInsertPoint {
panic("unique.canonMap: ran out of hash bits while iterating")
}
// Grab the lock and double-check what we saw.
i.mu.Lock()
n = slot.Load()
if (n == nil || n.isEntry) && !i.dead.Load() {
// What we saw is still true, so we can continue with the insert.
break
}
// We have to start over.
i.mu.Unlock()
}
// N.B. This lock is held from when we broke out of the outer loop above.
// We specifically break this out so that we can use defer here safely.
// One option is to break this out into a new function instead, but
// there's so much local iteration state used below that this turns out
// to be cleaner.
defer i.mu.Unlock()
var oldEntry *entry[T]
if n != nil {
oldEntry = n.entry()
if v, _ := oldEntry.lookup(key); v != nil {
// Easy case: by loading again, it turns out exactly what we wanted is here!
return v
}
}
newEntry, canon, wp := newEntryNode(key, hash)
// Prune dead pointers. This is to avoid O(n) lookups when we store the exact same
// value in the set but the cleanup hasn't run yet because it got delayed for some
// reason.
oldEntry = oldEntry.prune()
if oldEntry == nil {
// Easy case: create a new entry and store it.
slot.Store(&newEntry.node)
} else {
// We possibly need to expand the entry already there into one or more new nodes.
//
// Publish the node last, which will make both oldEntry and newEntry visible. We
// don't want readers to be able to observe that oldEntry isn't in the tree.
slot.Store(m.expand(oldEntry, newEntry, hash, hashShift, i))
}
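// Register a cleanup so that once the canonical pointer becomes unreachable
// (and the weak pointer has gone nil), the entry is removed from the map.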
runtime.AddCleanup(canon, func(_ struct{}) {
m.cleanup(hash, wp)
}, struct{}{})
return canon
}
// expand takes oldEntry and newEntry, whose hashes conflict from the top bit down
// to hashShift, and produces a subtree of indirect nodes to hold both entries.
// newHash is the hash of the value in the new entry.
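//
// For example (illustrative): with 4-bit children, suppose expand is called with
// hashShift == 8 and the two hashes agree on every bit at or above bit 8 but end
// in 0x3A and 0x7A. The first level below parent already separates them: the old
// entry lands in child 0x3 and the new entry in child 0x7 of a single new indirect
// node. If that nibble matched too, another indirect node would be added below it,
// and so on, until some nibble differs. Full hash collisions never reach that loop;
// they are chained on the new entry's overflow list instead.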
func (m *canonMap[T]) expand(oldEntry, newEntry *entry[T], newHash uintptr, hashShift uint, parent *indirect[T]) *node[T] {
// Check for a hash collision.
oldHash := oldEntry.hash
if oldHash == newHash {
// Store the old entry in the new entry's overflow list, then store
// the new entry.
newEntry.overflow.Store(oldEntry)
return &newEntry.node
}
// We have to add an indirect node. Worse still, we may need to add more than one.
newIndirect := newIndirectNode(parent)
top := newIndirect
for {
if hashShift == 0 {
panic("unique.canonMap: ran out of hash bits while inserting")
}
hashShift -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper.
oi := (oldHash >> hashShift) & nChildrenMask
ni := (newHash >> hashShift) & nChildrenMask
if oi != ni {
newIndirect.children[oi].Store(&oldEntry.node)
newIndirect.children[ni].Store(&newEntry.node)
break
}
nextIndirect := newIndirectNode(newIndirect)
newIndirect.children[oi].Store(&nextIndirect.node)
newIndirect = nextIndirect
}
return &top.node
}
// cleanup deletes the entry corresponding to wp in the canon map, if it's
// still in the map. wp.Value must return nil by the time this function is
// called. hash must be the hash of the value that wp once pointed to
// (that is, the hash of *wp.Value()).
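//
// cleanup is registered by LoadOrStore via runtime.AddCleanup on the canonical
// pointer, so it runs some time after that pointer becomes unreachable.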
func (m *canonMap[T]) cleanup(hash uintptr, wp weak.Pointer[T]) {
var i *indirect[T]
var hashShift uint
var slot *atomic.Pointer[node[T]]
var n *node[T]
for {
// Find wp in the map by following hash.
i = m.root.Load()
hashShift = 8 * goarch.PtrSize
haveEntry := false
for hashShift != 0 {
hashShift -= nChildrenLog2
slot = &i.children[(hash>>hashShift)&nChildrenMask]
n = slot.Load()
if n == nil {
// We found a nil slot, already deleted.
return
}
if n.isEntry {
if !n.entry().hasWeakPointer(wp) {
// The weak pointer was already pruned.
return
}
haveEntry = true
break
}
i = n.indirect()
}
if !haveEntry {
panic("unique.canonMap: ran out of hash bits while iterating")
}
// Grab the lock and double-check what we saw.
i.mu.Lock()
n = slot.Load()
if n != nil && n.isEntry {
// Prune the entry node without thinking too hard. If we do
// somebody else's work, such as someone trying to insert an
// entry with the same hash (probably the same value), then
// great, they'll back out without taking the lock.
newEntry := n.entry().prune()
if newEntry == nil {
slot.Store(nil)
} else {
slot.Store(&newEntry.node)
}
// Delete interior nodes that are empty, up the tree.
//
// We'll hand-over-hand lock our way up the tree as we do this,
// since we need to delete each empty node's link in its parent,
// which requires the parent's lock.
for i.parent != nil && i.empty() {
if hashShift == 8*goarch.PtrSize {
panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
}
hashShift += nChildrenLog2
// Delete the current node in the parent.
parent := i.parent
parent.mu.Lock()
i.dead.Store(true) // Could be done outside of parent's lock.
parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
i.mu.Unlock()
i = parent
}
i.mu.Unlock()
return
}
// We have to start over.
i.mu.Unlock()
}
}
// node is the header for a node. It's polymorphic and
// is actually either an entry or an indirect.
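//
// Both entry and indirect embed node as their first field, so a *node can be
// converted back to the concrete type with an unsafe pointer cast.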
type node[T comparable] struct {
isEntry bool
}
func (n *node[T]) entry() *entry[T] {
if !n.isEntry {
panic("called entry on non-entry node")
}
return (*entry[T])(unsafe.Pointer(n))
}
func (n *node[T]) indirect() *indirect[T] {
if n.isEntry {
panic("called indirect on entry node")
}
return (*indirect[T])(unsafe.Pointer(n))
}
const (
// 16 children. This seems to be the sweet spot for
// load performance: any smaller and we lose 50% or more
// in CPU performance; any larger and the returns are
// minuscule (~1% improvement for 32 children).
nChildrenLog2 = 4
nChildren = 1 << nChildrenLog2
nChildrenMask = nChildren - 1
)
// indirect is an internal node in the hash-trie.
type indirect[T comparable] struct {
node[T]
dead atomic.Bool // Set when the node has been removed from the trie; insertions must start over.
parent *indirect[T] // Parent indirect node; nil for the root.
mu sync.Mutex // Protects mutation to children and any children that are entry nodes.
children [nChildren]atomic.Pointer[node[T]] // Each child is either an entry or an indirect node.
}
func newIndirectNode[T comparable](parent *indirect[T]) *indirect[T] {
return &indirect[T]{node: node[T]{isEntry: false}, parent: parent}
}
func (i *indirect[T]) empty() bool {
for j := range i.children {
if i.children[j].Load() != nil {
return false
}
}
return true
}
// entry is a leaf node in the hash-trie.
type entry[T comparable] struct {
node[T]
overflow atomic.Pointer[entry[T]] // Overflow for hash collisions.
key weak.Pointer[T] // Weak pointer to the canonical *T.
hash uintptr // Hash of the key; used when expanding into deeper trie levels.
}
func newEntryNode[T comparable](key T, hash uintptr) (*entry[T], *T, weak.Pointer[T]) {
k := new(T)
*k = key
wp := weak.Make(k)
return &entry[T]{
node: node[T]{isEntry: true},
key: wp,
hash: hash,
}, k, wp
}
// lookup finds the entry in the overflow chain that has the provided key.
//
// Returns the key's canonical pointer and the weak pointer for that canonical pointer.
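//
// Entries whose weak pointers have gone nil are skipped; they are removed
// later by prune.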
func (e *entry[T]) lookup(key T) (*T, weak.Pointer[T]) {
for e != nil {
s := e.key.Value()
if s != nil && *s == key {
return s, e.key
}
e = e.overflow.Load()
}
return nil, weak.Pointer[T]{}
}
// hasWeakPointer returns true if the provided weak pointer can be found in the overflow chain.
func (e *entry[T]) hasWeakPointer(wp weak.Pointer[T]) bool {
for e != nil {
if e.key == wp {
return true
}
e = e.overflow.Load()
}
return false
}
// prune removes all entries in the overflow chain whose weak keys have gone nil.
//
// The caller must hold the lock on e's parent node.
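//
// For example (illustrative, with hypothetical entries e1 -> e2 -> e3): if only
// e2's weak key is still live, prune returns e2 with an empty overflow list.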
func (e *entry[T]) prune() *entry[T] {
// Prune the head of the list.
for e != nil {
if e.key.Value() != nil {
break
}
e = e.overflow.Load()
}
if e == nil {
return nil
}
// Prune individual nodes in the list.
newHead := e
i := &e.overflow
e = i.Load()
for e != nil {
if e.key.Value() != nil {
i = &e.overflow
} else {
i.Store(e.overflow.Load())
}
e = e.overflow.Load()
}
return newHead
}
// Pull in runtime.rand so that we don't need to take a dependency
// on math/rand/v2.
//
//go:linkname runtime_rand runtime.rand
func runtime_rand() uint64