internal/weak: add package implementing weak pointers

This change adds the internal/weak package, which exposes GC-supported
weak pointers to the standard library. This is for the upcoming unique
package, but may be useful for other future constructs.

For #62483.

Change-Id: I4aa8fa9400110ad5ea022a43c094051699ccab9d
Reviewed-on: https://go-review.googlesource.com/c/go/+/576297
Auto-Submit: Michael Knyszek <mknyszek@google.com>
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
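
To make the intended semantics concrete, here is a minimal usage sketch of the
API added in src/internal/weak/pointer.go below. It is illustrative only:
internal/weak is importable only from within the standard library, and the
struct type mirrors the pointer-bearing T used by the new tests so the
allocation stays out of the tiny allocator.

	package main

	import (
		"fmt"
		"internal/weak" // importable only within std; shown for illustration
		"runtime"
	)

	type value struct {
		self *value // pointer field, as in the tests, to avoid the tiny allocator
		n    int
	}

	func main() {
		v := &value{n: 42}
		wp := weak.Make(v) // installs (or reuses) a weak handle for v

		// While v is reachable, Strong returns the original pointer.
		fmt.Println(wp.Strong() == v) // true
		runtime.KeepAlive(v)

		// Once v is unreachable and a GC cycle has noticed, Strong returns nil.
		runtime.GC()
		fmt.Println(wp.Strong() == nil) // true
	}
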
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go
index c1034e5..5954669 100644
--- a/src/go/build/deps_test.go
+++ b/src/go/build/deps_test.go
@@ -77,6 +77,7 @@
 	< internal/race
 	< internal/msan
 	< internal/asan
+	< internal/weak
 	< sync
 	< internal/bisect
 	< internal/godebug
diff --git a/src/internal/abi/escape.go b/src/internal/abi/escape.go
index 8f37563..8cdae14 100644
--- a/src/internal/abi/escape.go
+++ b/src/internal/abi/escape.go
@@ -20,3 +20,14 @@
 	x := uintptr(p)
 	return unsafe.Pointer(x ^ 0)
 }
+
+var alwaysFalse bool
+var escapeSink any
+
+// Escape forces any pointers in x to escape to the heap.
+func Escape[T any](x T) T {
+	if alwaysFalse {
+		escapeSink = x
+	}
+	return x
+}
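
For context on why the Escape helper is introduced here: Make in internal/weak
(added below) must hand the runtime a heap pointer before a weak-handle special
can be attached to the object, so it routes its argument through abi.Escape to
make escape analysis treat the value as escaping. A minimal sketch of the
effect, assuming a hypothetical caller inside the standard library:

	// p would otherwise be eligible for stack allocation, since the
	// compiler can see it never leaves this function.
	func example() {
		p := new(int)
		p = abi.Escape(p) // now *p is guaranteed to be heap-allocated
		*p = 1
	}
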
diff --git a/src/internal/weak/pointer.go b/src/internal/weak/pointer.go
new file mode 100644
index 0000000..44d2673
--- /dev/null
+++ b/src/internal/weak/pointer.go
@@ -0,0 +1,83 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+The weak package is a package for managing weak pointers.
+
+Weak pointers are pointers that explicitly do not keep a value live and
+must be queried for a regular Go pointer.
+The result of such a query may be observed as nil at any point after a
+weakly-pointed-to object becomes eligible for reclamation by the garbage
+collector.
+More specifically, weak pointers become nil as soon as the garbage collector
+identifies that the object is unreachable, before it is made reachable
+again by a finalizer.
+In terms of the C# language, these semantics are roughly equivalent to the
+semantics of "short" weak references.
+In terms of the Java language, these semantics are roughly equivalent to the
+semantics of the WeakReference type.
+
+Using go:linkname to access this package and the functions it references
+is explicitly forbidden by the toolchain because the semantics of this
+package have not gone through the proposal process. By exposing this
+functionality, we risk locking in the existing semantics due to Hyrum's Law.
+
+If you believe you have a good use-case for weak references not already
+covered by the standard library, file a proposal issue at
+https://github.com/golang/go/issues instead of relying on this package.
+*/
+package weak
+
+import (
+	"internal/abi"
+	"runtime"
+	"unsafe"
+)
+
+// Pointer is a weak pointer to a value of type T.
+//
+// Pointer values are comparable, and two Pointer values are guaranteed to
+// compare equal if the pointers that they were created from compare equal.
+// This property is retained even after the object referenced by the pointer
+// used to create the weak reference is reclaimed.
+//
+// If multiple weak pointers are made to different offsets within the same object
+// (for example, pointers to different fields of the same struct), those pointers
+// will not compare equal.
+// If a weak pointer is created from an object that becomes reachable again due
+// to a finalizer, that weak pointer will not compare equal with weak pointers
+// created before it became unreachable.
+type Pointer[T any] struct {
+	u unsafe.Pointer
+}
+
+// Make creates a weak pointer from a strong pointer to some value of type T.
+func Make[T any](ptr *T) Pointer[T] {
+	// Explicitly force ptr to escape to the heap.
+	ptr = abi.Escape(ptr)
+
+	var u unsafe.Pointer
+	if ptr != nil {
+		u = runtime_registerWeakPointer(unsafe.Pointer(ptr))
+	}
+	runtime.KeepAlive(ptr)
+	return Pointer[T]{u}
+}
+
+// Strong creates a strong pointer from the weak pointer.
+// Returns nil if the original value for the weak pointer was reclaimed by
+// the garbage collector.
+// If a weak pointer points to an object with a finalizer, then Strong will
+// return nil as soon as the object's finalizer is queued for execution.
+func (p Pointer[T]) Strong() *T {
+	return (*T)(runtime_makeStrongFromWeak(unsafe.Pointer(p.u)))
+}
+
+// Implemented in runtime.
+
+//go:linkname runtime_registerWeakPointer
+func runtime_registerWeakPointer(unsafe.Pointer) unsafe.Pointer
+
+//go:linkname runtime_makeStrongFromWeak
+func runtime_makeStrongFromWeak(unsafe.Pointer) unsafe.Pointer
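
The comparability guarantees documented on Pointer above can be summarized in a
short sketch (illustrative only; the node type is assumed and carries a pointer
field, like the tests below, to stay out of the tiny allocator):

	package weakdemo // illustrative; assumes access to internal/weak

	import (
		"fmt"
		"internal/weak"
		"runtime"
	)

	type node struct {
		next *node
	}

	func equality() {
		a, b := new(node), new(node)

		w1 := weak.Make(a)
		w2 := weak.Make(a)    // reuses the canonical handle installed for a
		fmt.Println(w1 == w2) // true, and stays true even after a is reclaimed

		w3 := weak.Make(b)
		fmt.Println(w1 == w3) // false: distinct objects never share a handle

		runtime.KeepAlive(a)
		runtime.KeepAlive(b)
	}
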
diff --git a/src/internal/weak/pointer_test.go b/src/internal/weak/pointer_test.go
new file mode 100644
index 0000000..e143749
--- /dev/null
+++ b/src/internal/weak/pointer_test.go
@@ -0,0 +1,130 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package weak_test
+
+import (
+	"internal/weak"
+	"runtime"
+	"testing"
+)
+
+type T struct {
+	// N.B. This must contain a pointer, otherwise the weak handle might get placed
+	// in a tiny block making the tests in this package flaky.
+	t *T
+	a int
+}
+
+func TestPointer(t *testing.T) {
+	bt := new(T)
+	wt := weak.Make(bt)
+	if st := wt.Strong(); st != bt {
+		t.Fatalf("weak pointer is not the same as strong pointer: %p vs. %p", st, bt)
+	}
+	// bt is still referenced.
+	runtime.GC()
+
+	if st := wt.Strong(); st != bt {
+		t.Fatalf("weak pointer is not the same as strong pointer after GC: %p vs. %p", st, bt)
+	}
+	// bt is no longer referenced.
+	runtime.GC()
+
+	if st := wt.Strong(); st != nil {
+		t.Fatalf("expected weak pointer to be nil, got %p", st)
+	}
+}
+
+func TestPointerEquality(t *testing.T) {
+	bt := make([]*T, 10)
+	wt := make([]weak.Pointer[T], 10)
+	for i := range bt {
+		bt[i] = new(T)
+		wt[i] = weak.Make(bt[i])
+	}
+	for i := range bt {
+		st := wt[i].Strong()
+		if st != bt[i] {
+			t.Fatalf("weak pointer is not the same as strong pointer: %p vs. %p", st, bt[i])
+		}
+		if wp := weak.Make(st); wp != wt[i] {
+			t.Fatalf("new weak pointer not equal to existing weak pointer: %v vs. %v", wp, wt[i])
+		}
+		if i == 0 {
+			continue
+		}
+		if wt[i] == wt[i-1] {
+			t.Fatalf("expected weak pointers to not be equal to each other, but got %v", wt[i])
+		}
+	}
+	// bt is still referenced.
+	runtime.GC()
+	for i := range bt {
+		st := wt[i].Strong()
+		if st != bt[i] {
+			t.Fatalf("weak pointer is not the same as strong pointer: %p vs. %p", st, bt[i])
+		}
+		if wp := weak.Make(st); wp != wt[i] {
+			t.Fatalf("new weak pointer not equal to existing weak pointer: %v vs. %v", wp, wt[i])
+		}
+		if i == 0 {
+			continue
+		}
+		if wt[i] == wt[i-1] {
+			t.Fatalf("expected weak pointers to not be equal to each other, but got %v", wt[i])
+		}
+	}
+	bt = nil
+	// bt is no longer referenced.
+	runtime.GC()
+	for i := range wt {
+		st := wt[i].Strong()
+		if st != nil {
+			t.Fatalf("expected weak pointer to be nil, got %p", st)
+		}
+		if i == 0 {
+			continue
+		}
+		if wt[i] == wt[i-1] {
+			t.Fatalf("expected weak pointers to not be equal to each other, but got %v", wt[i])
+		}
+	}
+}
+
+func TestPointerFinalizer(t *testing.T) {
+	bt := new(T)
+	wt := weak.Make(bt)
+	done := make(chan struct{}, 1)
+	runtime.SetFinalizer(bt, func(bt *T) {
+		if wt.Strong() != nil {
+			t.Errorf("weak pointer did not go nil before finalizer ran")
+		}
+		done <- struct{}{}
+	})
+
+	// Make sure the weak pointer stays around while bt is live.
+	runtime.GC()
+	if wt.Strong() == nil {
+		t.Errorf("weak pointer went nil too soon")
+	}
+	runtime.KeepAlive(bt)
+
+	// bt is no longer referenced.
+	//
+	// Run one cycle to queue the finalizer.
+	runtime.GC()
+	if wt.Strong() != nil {
+		t.Errorf("weak pointer did not go nil when finalizer was enqueued")
+	}
+
+	// Wait for the finalizer to run.
+	<-done
+
+	// The weak pointer should still be nil after the finalizer runs.
+	runtime.GC()
+	if wt.Strong() != nil {
+		t.Errorf("weak pointer is non-nil even after finalization: %v", wt)
+	}
+}
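
As one example of the "other future constructs" the commit message mentions,
weak pointers allow a canonicalizing cache that never keeps its entries alive
on its own. The sketch below is hypothetical and not part of this change:
canonMap and its locking strategy are invented for illustration, and stale
entries are simply overwritten on the next lookup rather than cleaned up
eagerly.

	package weakdemo // hypothetical; assumes access to internal/weak

	import (
		"internal/weak"
		"sync"
	)

	// canonMap hands out one canonical *V per key without keeping values live.
	type canonMap[K comparable, V any] struct {
		mu sync.Mutex
		m  map[K]weak.Pointer[V]
	}

	func (c *canonMap[K, V]) load(k K, newV func() *V) *V {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.m == nil {
			c.m = make(map[K]weak.Pointer[V])
		}
		if wp, ok := c.m[k]; ok {
			if v := wp.Strong(); v != nil {
				return v // previous value is still live; reuse it
			}
		}
		v := newV()
		c.m[k] = weak.Make(v) // the map entry does not keep v reachable
		return v
	}
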
diff --git a/src/runtime/mgcmark.go b/src/runtime/mgcmark.go
index a42912e..61e917d 100644
--- a/src/runtime/mgcmark.go
+++ b/src/runtime/mgcmark.go
@@ -328,6 +328,13 @@
 	// 2) Finalizer specials (which are not in the garbage
 	// collected heap) are roots. In practice, this means the fn
 	// field must be scanned.
+	//
+	// Objects with weak handles have only one invariant related
+	// to this function: weak handle specials (which are not in the
+	// garbage collected heap) are roots. In practice, this means
+	// the handle field must be scanned. Note that the value the
+	// handle pointer referenced does *not* need to be scanned. See
+	// the definition of specialWeakHandle for details.
 	sg := mheap_.sweepgen
 
 	// Find the arena and page index into that arena for this shard.
@@ -373,24 +380,28 @@
 			// removed from the list while we're traversing it.
 			lock(&s.speciallock)
 			for sp := s.specials; sp != nil; sp = sp.next {
-				if sp.kind != _KindSpecialFinalizer {
-					continue
-				}
-				// don't mark finalized object, but scan it so we
-				// retain everything it points to.
-				spf := (*specialfinalizer)(unsafe.Pointer(sp))
-				// A finalizer can be set for an inner byte of an object, find object beginning.
-				p := s.base() + uintptr(spf.special.offset)/s.elemsize*s.elemsize
+				switch sp.kind {
+				case _KindSpecialFinalizer:
+					// don't mark finalized object, but scan it so we
+					// retain everything it points to.
+					spf := (*specialfinalizer)(unsafe.Pointer(sp))
+					// A finalizer can be set for an inner byte of an object, find object beginning.
+					p := s.base() + uintptr(spf.special.offset)/s.elemsize*s.elemsize
 
-				// Mark everything that can be reached from
-				// the object (but *not* the object itself or
-				// we'll never collect it).
-				if !s.spanclass.noscan() {
-					scanobject(p, gcw)
-				}
+					// Mark everything that can be reached from
+					// the object (but *not* the object itself or
+					// we'll never collect it).
+					if !s.spanclass.noscan() {
+						scanobject(p, gcw)
+					}
 
-				// The special itself is a root.
-				scanblock(uintptr(unsafe.Pointer(&spf.fn)), goarch.PtrSize, &oneptrmask[0], gcw, nil)
+					// The special itself is a root.
+					scanblock(uintptr(unsafe.Pointer(&spf.fn)), goarch.PtrSize, &oneptrmask[0], gcw, nil)
+				case _KindSpecialWeakHandle:
+					// The special itself is a root.
+					spw := (*specialWeakHandle)(unsafe.Pointer(sp))
+					scanblock(uintptr(unsafe.Pointer(&spw.handle)), goarch.PtrSize, &oneptrmask[0], gcw, nil)
+				}
 			}
 			unlock(&s.speciallock)
 		}
diff --git a/src/runtime/mgcsweep.go b/src/runtime/mgcsweep.go
index 701e0b8..5670b1b 100644
--- a/src/runtime/mgcsweep.go
+++ b/src/runtime/mgcsweep.go
@@ -552,31 +552,44 @@
 		mbits := s.markBitsForIndex(objIndex)
 		if !mbits.isMarked() {
 			// This object is not marked and has at least one special record.
-			// Pass 1: see if it has at least one finalizer.
-			hasFin := false
+			// Pass 1: see if it has a finalizer.
+			hasFinAndRevived := false
 			endOffset := p - s.base() + size
 			for tmp := siter.s; tmp != nil && uintptr(tmp.offset) < endOffset; tmp = tmp.next {
 				if tmp.kind == _KindSpecialFinalizer {
 					// Stop freeing of object if it has a finalizer.
 					mbits.setMarkedNonAtomic()
-					hasFin = true
+					hasFinAndRevived = true
 					break
 				}
 			}
-			// Pass 2: queue all finalizers _or_ handle profile record.
-			for siter.valid() && uintptr(siter.s.offset) < endOffset {
-				// Find the exact byte for which the special was setup
-				// (as opposed to object beginning).
-				special := siter.s
-				p := s.base() + uintptr(special.offset)
-				if special.kind == _KindSpecialFinalizer || !hasFin {
+			if hasFinAndRevived {
+				// Pass 2: queue all finalizers and clear any weak handles. Weak handles are cleared
+				// before finalization as specified by the internal/weak package. See the documentation
+				// for that package for more details.
+				for siter.valid() && uintptr(siter.s.offset) < endOffset {
+					// Find the exact byte for which the special was setup
+					// (as opposed to object beginning).
+					special := siter.s
+					p := s.base() + uintptr(special.offset)
+					if special.kind == _KindSpecialFinalizer || special.kind == _KindSpecialWeakHandle {
+						siter.unlinkAndNext()
+						freeSpecial(special, unsafe.Pointer(p), size)
+					} else {
+						// All other specials only apply when an object is freed,
+						// so just keep the special record.
+						siter.next()
+					}
+				}
+			} else {
+				// Pass 2: the object is truly dead, free (and handle) all specials.
+				for siter.valid() && uintptr(siter.s.offset) < endOffset {
+					// Find the exact byte for which the special was setup
+					// (as opposed to object beginning).
+					special := siter.s
+					p := s.base() + uintptr(special.offset)
 					siter.unlinkAndNext()
 					freeSpecial(special, unsafe.Pointer(p), size)
-				} else {
-					// The object has finalizers, so we're keeping it alive.
-					// All other specials only apply when an object is freed,
-					// so just keep the special record.
-					siter.next()
 				}
 			}
 		} else {
diff --git a/src/runtime/mheap.go b/src/runtime/mheap.go
index 1241f6e..a68f855 100644
--- a/src/runtime/mheap.go
+++ b/src/runtime/mheap.go
@@ -207,6 +207,7 @@
 	specialprofilealloc    fixalloc // allocator for specialprofile*
 	specialReachableAlloc  fixalloc // allocator for specialReachable
 	specialPinCounterAlloc fixalloc // allocator for specialPinCounter
+	specialWeakHandleAlloc fixalloc // allocator for specialWeakHandle
 	speciallock            mutex    // lock for special record allocators.
 	arenaHintAlloc         fixalloc // allocator for arenaHints
 
@@ -745,6 +746,7 @@
 	h.specialprofilealloc.init(unsafe.Sizeof(specialprofile{}), nil, nil, &memstats.other_sys)
 	h.specialReachableAlloc.init(unsafe.Sizeof(specialReachable{}), nil, nil, &memstats.other_sys)
 	h.specialPinCounterAlloc.init(unsafe.Sizeof(specialPinCounter{}), nil, nil, &memstats.other_sys)
+	h.specialWeakHandleAlloc.init(unsafe.Sizeof(specialWeakHandle{}), nil, nil, &memstats.gcMiscSys)
 	h.arenaHintAlloc.init(unsafe.Sizeof(arenaHint{}), nil, nil, &memstats.other_sys)
 
 	// Don't zero mspan allocations. Background sweeping can
@@ -1789,18 +1791,18 @@
 }
 
 const (
+	// _KindSpecialFinalizer is for tracking finalizers.
 	_KindSpecialFinalizer = 1
-	_KindSpecialProfile   = 2
+	// _KindSpecialWeakHandle is used for creating weak pointers.
+	_KindSpecialWeakHandle = 2
+	// _KindSpecialProfile is for memory profiling.
+	_KindSpecialProfile = 3
 	// _KindSpecialReachable is a special used for tracking
 	// reachability during testing.
-	_KindSpecialReachable = 3
+	_KindSpecialReachable = 4
 	// _KindSpecialPinCounter is a special used for objects that are pinned
 	// multiple times
-	_KindSpecialPinCounter = 4
-	// Note: The finalizer special must be first because if we're freeing
-	// an object, a finalizer special will cause the freeing operation
-	// to abort, and we want to keep the other special records around
-	// if that happens.
+	_KindSpecialPinCounter = 5
 )
 
 type special struct {
@@ -1985,6 +1987,155 @@
 	unlock(&mheap_.speciallock)
 }
 
+// The described object has a weak pointer.
+//
+// Weak pointers in the GC have the following invariants:
+//
+//   - Strong-to-weak conversions must ensure the strong pointer
+//     remains live until the weak handle is installed. This ensures
+//     that creating a weak pointer cannot fail.
+//
+//   - Weak-to-strong conversions require the weakly-referenced
+//     object to be swept before the conversion may proceed. This
+//     ensures that weak-to-strong conversions cannot resurrect
+//     dead objects by sweeping them before that happens.
+//
+//   - Weak handles are unique and canonical for each byte offset into
+//     an object that a strong pointer may point to, until an object
+//     becomes unreachable.
+//
+//   - Weak handles contain nil as soon as an object becomes unreachable
+//     the first time, before a finalizer makes it reachable again. New
+//     weak handles created after resurrection are newly unique.
+//
+// specialWeakHandle is allocated from non-GC'd memory, so any heap
+// pointers must be specially handled.
+type specialWeakHandle struct {
+	_       sys.NotInHeap
+	special special
+	// handle is a reference to the actual weak pointer.
+	// It is always heap-allocated and must be explicitly kept
+	// live so long as this special exists.
+	handle *atomic.Uintptr
+}
+
+//go:linkname internal_weak_runtime_registerWeakPointer internal/weak.runtime_registerWeakPointer
+func internal_weak_runtime_registerWeakPointer(p unsafe.Pointer) unsafe.Pointer {
+	return unsafe.Pointer(getOrAddWeakHandle(unsafe.Pointer(p)))
+}
+
+//go:linkname internal_weak_runtime_makeStrongFromWeak internal/weak.runtime_makeStrongFromWeak
+func internal_weak_runtime_makeStrongFromWeak(u unsafe.Pointer) unsafe.Pointer {
+	handle := (*atomic.Uintptr)(u)
+
+	// Prevent preemption. We want to make sure that another GC cycle can't start.
+	mp := acquirem()
+	p := handle.Load()
+	if p == 0 {
+		releasem(mp)
+		return nil
+	}
+	// Be careful. p may or may not refer to valid memory anymore, as it could've been
+	// swept and released already. It's always safe to ensure a span is swept, though,
+	// even if it's just some random span.
+	span := spanOfHeap(p)
+	if span == nil {
+		// The span probably got swept and released.
+		releasem(mp)
+		return nil
+	}
+	// Ensure the span is swept.
+	span.ensureSwept()
+
+	// Now we can trust whatever we get from handle, so make a strong pointer.
+	//
+	// Even if we just swept some random span that doesn't contain this object, because
+	// this object is long dead and its memory has since been reused, we'll just observe nil.
+	ptr := unsafe.Pointer(handle.Load())
+	releasem(mp)
+	return ptr
+}
+
+// Retrieves or creates a weak pointer handle for the object p.
+func getOrAddWeakHandle(p unsafe.Pointer) *atomic.Uintptr {
+	// First try to retrieve without allocating.
+	if handle := getWeakHandle(p); handle != nil {
+		return handle
+	}
+
+	lock(&mheap_.speciallock)
+	s := (*specialWeakHandle)(mheap_.specialWeakHandleAlloc.alloc())
+	unlock(&mheap_.speciallock)
+
+	handle := new(atomic.Uintptr)
+	s.special.kind = _KindSpecialWeakHandle
+	s.handle = handle
+	handle.Store(uintptr(p))
+	if addspecial(p, &s.special) {
+		// This is responsible for maintaining the same
+		// GC-related invariants as markrootSpans in any
+		// situation where it's possible that markrootSpans
+		// has already run but mark termination hasn't yet.
+		if gcphase != _GCoff {
+			mp := acquirem()
+			gcw := &mp.p.ptr().gcw
+			// Mark the weak handle itself, since the
+			// special isn't part of the GC'd heap.
+			scanblock(uintptr(unsafe.Pointer(&s.handle)), goarch.PtrSize, &oneptrmask[0], gcw, nil)
+			releasem(mp)
+		}
+		return s.handle
+	}
+
+	// There was an existing handle. Free the special
+	// and try again. We must succeed because we're explicitly
+	// keeping p live until the end of this function. Either
+	// we, or someone else, must have succeeded, because we can
+	// only fail in the event of a race, and p will still
+	// be valid no matter how much time we spend here.
+	lock(&mheap_.speciallock)
+	mheap_.specialWeakHandleAlloc.free(unsafe.Pointer(s))
+	unlock(&mheap_.speciallock)
+
+	handle = getWeakHandle(p)
+	if handle == nil {
+		throw("failed to get or create weak handle")
+	}
+
+	// Keep p alive for the duration of the function to ensure
+	// that it cannot die while we're trying to do this.
+	KeepAlive(p)
+	return handle
+}
+
+func getWeakHandle(p unsafe.Pointer) *atomic.Uintptr {
+	span := spanOfHeap(uintptr(p))
+	if span == nil {
+		throw("getWeakHandle on invalid pointer")
+	}
+
+	// Ensure that the span is swept.
+	// Sweeping accesses the specials list w/o locks, so we have
+	// to synchronize with it. And it's just much safer.
+	mp := acquirem()
+	span.ensureSwept()
+
+	offset := uintptr(p) - span.base()
+
+	lock(&span.speciallock)
+
+	// Find the existing record and return the handle if one exists.
+	var handle *atomic.Uintptr
+	iter, exists := span.specialFindSplicePoint(offset, _KindSpecialWeakHandle)
+	if exists {
+		handle = ((*specialWeakHandle)(unsafe.Pointer(*iter))).handle
+	}
+	unlock(&span.speciallock)
+	releasem(mp)
+
+	return handle
+}
+
 // The described object is being heap profiled.
 type specialprofile struct {
 	_       sys.NotInHeap
@@ -2056,6 +2207,12 @@
 		lock(&mheap_.speciallock)
 		mheap_.specialfinalizeralloc.free(unsafe.Pointer(sf))
 		unlock(&mheap_.speciallock)
+	case _KindSpecialWeakHandle:
+		sw := (*specialWeakHandle)(unsafe.Pointer(s))
+		sw.handle.Store(0)
+		lock(&mheap_.speciallock)
+		mheap_.specialWeakHandleAlloc.free(unsafe.Pointer(s))
+		unlock(&mheap_.speciallock)
 	case _KindSpecialProfile:
 		sp := (*specialprofile)(unsafe.Pointer(s))
 		mProf_Free(sp.b, size)
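
The fallback path in getOrAddWeakHandle above exists because two goroutines can
race to install the first handle for an object: the loser frees its freshly
allocated special and adopts the winner's handle, so every caller observes one
canonical handle per object. A sketch of that observable guarantee
(illustrative only, with an assumed pointer-bearing type):

	package weakdemo // illustrative; assumes access to internal/weak

	import (
		"internal/weak"
		"runtime"
		"sync"
	)

	type obj struct{ self *obj } // pointer-bearing, as in the new tests

	func concurrentMake() {
		o := new(obj)
		ptrs := make([]weak.Pointer[obj], 8)

		var wg sync.WaitGroup
		for i := range ptrs {
			wg.Add(1)
			go func(i int) {
				defer wg.Done()
				ptrs[i] = weak.Make(o) // all racers must end up with the same handle
			}(i)
		}
		wg.Wait()

		for _, wp := range ptrs[1:] {
			if wp != ptrs[0] {
				panic("weak handles for the same object are not canonical")
			}
		}
		runtime.KeepAlive(o)
	}
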
diff --git a/src/runtime/proc.go b/src/runtime/proc.go
index a029a23..8f5787d 100644
--- a/src/runtime/proc.go
+++ b/src/runtime/proc.go
@@ -6937,6 +6937,18 @@
 	procUnpin()
 }
 
+//go:linkname internal_weak_runtime_procPin internal/weak.runtime_procPin
+//go:nosplit
+func internal_weak_runtime_procPin() int {
+	return procPin()
+}
+
+//go:linkname internal_weak_runtime_procUnpin internal/weak.runtime_procUnpin
+//go:nosplit
+func internal_weak_runtime_procUnpin() {
+	procUnpin()
+}
+
 // Active spinning for sync.Mutex.
 //
 //go:linkname sync_runtime_canSpin sync.runtime_canSpin