blob: e0cb9451121421ba4158025d4f9b23ee8b2b8a94 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build goexperiment.rangefunc && !windows
package main
/*
#include <stdint.h> // for uintptr_t
void go_callback_coro(uintptr_t handle);
static void call_go(uintptr_t handle) {
go_callback_coro(handle);
}
*/
import "C"
import (
"fmt"
"iter"
"runtime/cgo"
)
func init() {
register("CoroCgoIterCallback", func() {
println("expect: OK")
CoroCgo(callerExhaust, iterCallback)
})
register("CoroCgoIterCallbackYield", func() {
println("expect: OS thread locking must match")
CoroCgo(callerExhaust, iterCallbackYield)
})
register("CoroCgoCallback", func() {
println("expect: OK")
CoroCgo(callerExhaustCallback, iterSimple)
})
register("CoroCgoCallbackIterNested", func() {
println("expect: OK")
CoroCgo(callerExhaustCallback, iterNested)
})
register("CoroCgoCallbackIterCallback", func() {
println("expect: OK")
CoroCgo(callerExhaustCallback, iterCallback)
})
register("CoroCgoCallbackIterCallbackYield", func() {
println("expect: OS thread locking must match")
CoroCgo(callerExhaustCallback, iterCallbackYield)
})
register("CoroCgoCallbackAfterPull", func() {
println("expect: OS thread locking must match")
CoroCgo(callerCallbackAfterPull, iterSimple)
})
register("CoroCgoStopCallback", func() {
println("expect: OK")
CoroCgo(callerStopCallback, iterSimple)
})
register("CoroCgoStopCallbackIterNested", func() {
println("expect: OK")
CoroCgo(callerStopCallback, iterNested)
})
}
var toCall func()
//export go_callback_coro
func go_callback_coro(handle C.uintptr_t) {
h := cgo.Handle(handle)
h.Value().(func())()
h.Delete()
}
func callFromC(f func()) {
C.call_go(C.uintptr_t(cgo.NewHandle(f)))
}
func CoroCgo(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
if err := driver(seq); err != nil {
println("error:", err.Error())
return
}
println("OK")
}
func callerExhaust(i iter.Seq[int]) error {
next, _ := iter.Pull(i)
for {
v, ok := next()
if !ok {
break
}
if v != 5 {
return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
}
return nil
}
func callerExhaustCallback(i iter.Seq[int]) (err error) {
callFromC(func() {
next, _ := iter.Pull(i)
for {
v, ok := next()
if !ok {
break
}
if v != 5 {
err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
}
})
return err
}
func callerStopCallback(i iter.Seq[int]) (err error) {
callFromC(func() {
next, stop := iter.Pull(i)
v, _ := next()
stop()
if v != 5 {
err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
})
return err
}
func callerCallbackAfterPull(i iter.Seq[int]) (err error) {
next, _ := iter.Pull(i)
callFromC(func() {
for {
v, ok := next()
if !ok {
break
}
if v != 5 {
err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
}
}
})
return err
}
func iterSimple(yield func(int) bool) {
for range 3 {
if !yield(5) {
return
}
}
}
func iterNested(yield func(int) bool) {
next, stop := iter.Pull(iterSimple)
for {
v, ok := next()
if ok {
if !yield(v) {
stop()
}
} else {
return
}
}
}
func iterCallback(yield func(int) bool) {
for range 3 {
callFromC(func() {})
if !yield(5) {
return
}
}
}
func iterCallbackYield(yield func(int) bool) {
for range 3 {
var ok bool
callFromC(func() {
ok = yield(5)
})
if !ok {
return
}
}
}