blob: 032215b801282013faa5e426f80da17b5d92d2b9 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build goexperiment.rangefunc
package main
import (
"fmt"
"iter"
"runtime"
)
func init() {
register("CoroLockOSThreadIterLock", func() {
println("expect: OK")
CoroLockOSThread(callerExhaust, iterLock)
})
register("CoroLockOSThreadIterLockYield", func() {
println("expect: OS thread locking must match")
CoroLockOSThread(callerExhaust, iterLockYield)
})
register("CoroLockOSThreadLock", func() {
println("expect: OK")
CoroLockOSThread(callerExhaustLocked, iterSimple)
})
register("CoroLockOSThreadLockIterNested", func() {
println("expect: OK")
CoroLockOSThread(callerExhaustLocked, iterNested)
})
register("CoroLockOSThreadLockIterLock", func() {
println("expect: OK")
CoroLockOSThread(callerExhaustLocked, iterLock)
})
register("CoroLockOSThreadLockIterLockYield", func() {
println("expect: OS thread locking must match")
CoroLockOSThread(callerExhaustLocked, iterLockYield)
})
register("CoroLockOSThreadLockIterYieldNewG", func() {
println("expect: OS thread locking must match")
CoroLockOSThread(callerExhaustLocked, iterYieldNewG)
})
register("CoroLockOSThreadLockAfterPull", func() {
println("expect: OS thread locking must match")
CoroLockOSThread(callerLockAfterPull, iterSimple)
})
register("CoroLockOSThreadStopLocked", func() {
println("expect: OK")
CoroLockOSThread(callerStopLocked, iterSimple)
})
register("CoroLockOSThreadStopLockedIterNested", func() {
println("expect: OK")
CoroLockOSThread(callerStopLocked, iterNested)
})
}
func CoroLockOSThread(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
if err := driver(seq); err != nil {
println("error:", err.Error())
return
}
println("OK")
}
func callerExhaust(i iter.Seq[int]) error {
next, _ := iter.Pull(i)
for {
v, ok := next()
if !ok {
break
}
if v != 5 {
return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
}
return nil
}
func callerExhaustLocked(i iter.Seq[int]) error {
runtime.LockOSThread()
next, _ := iter.Pull(i)
for {
v, ok := next()
if !ok {
break
}
if v != 5 {
return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
}
runtime.UnlockOSThread()
return nil
}
func callerLockAfterPull(i iter.Seq[int]) error {
n := 0
next, _ := iter.Pull(i)
for {
runtime.LockOSThread()
n++
v, ok := next()
if !ok {
break
}
if v != 5 {
return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
}
for range n {
runtime.UnlockOSThread()
}
return nil
}
func callerStopLocked(i iter.Seq[int]) error {
runtime.LockOSThread()
next, stop := iter.Pull(i)
v, _ := next()
stop()
if v != 5 {
return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
runtime.UnlockOSThread()
return nil
}
// iterSimple is a plain sequence that produces the value 5 three times. It
// stops as soon as the consumer's yield returns false. It never touches OS
// thread locks.
func iterSimple(yield func(int) bool) {
	for remaining := 3; remaining > 0; remaining-- {
		if !yield(5) {
			return
		}
	}
}
// iterNested re-exposes iterSimple through an inner iter.Pull, so each value
// reaches the consumer via a second pull iterator. If the consumer stops
// early, the inner iterator is stopped and iterNested returns immediately.
// Previously the loop fell through after stop() and relied on the following
// next() call reporting the end of the sequence.
func iterNested(yield func(int) bool) {
	next, stop := iter.Pull(iterSimple)
	for {
		v, ok := next()
		if !ok {
			// Inner sequence exhausted; nothing left to stop.
			return
		}
		if !yield(v) {
			stop()
			return
		}
	}
}
// iterLock produces the value 5 three times. Before each value it locks the
// OS thread and immediately unlocks it again, so no thread lock is ever held
// while yield runs. It stops as soon as yield returns false.
func iterLock(yield func(int) bool) {
	for round := 0; round < 3; round++ {
		runtime.LockOSThread()
		runtime.UnlockOSThread()
		if cont := yield(5); !cont {
			return
		}
	}
}
// iterLockYield produces the value 5 three times. The OS thread stays locked
// for the whole duration of each yield call and is unlocked once yield
// returns. It stops after the first yield that returns false; that call's
// lock is still released first.
func iterLockYield(yield func(int) bool) {
	for round := 0; round < 3; round++ {
		runtime.LockOSThread()
		keepGoing := yield(5)
		runtime.UnlockOSThread()
		if !keepGoing {
			return
		}
	}
}
// iterYieldNewG produces the value 5 up to three times. Each yield call runs
// on a freshly started goroutine, and the iterator waits on an unbuffered
// channel until that goroutine has finished yielding. It stops after the
// first yield that returns false.
func iterYieldNewG(yield func(int) bool) {
	for round := 0; round < 3; round++ {
		var cont bool
		finished := make(chan struct{})
		go func() {
			cont = yield(5)
			finished <- struct{}{}
		}()
		// The receive orders the goroutine's write to cont before the read
		// below.
		<-finished
		if !cont {
			return
		}
	}
}