cmd/compile: fix mishandling of unsafe-uintptr arguments in go/defer

Currently, the statement:

	go g(uintptr(f()))

gets rewritten into:

	tmp := f()
	newproc(8, g, uintptr(tmp))
	runtime.KeepAlive(tmp)

which doesn't guarantee that tmp is still alive by time the g call is
scheduled to run.

This CL fixes the issue, by wrapping g call in a closure:

	go func(p unsafe.Pointer) {
		g(uintptr(p))
	}(f())

then this will be rewritten into:

	tmp := f()
	go func(p unsafe.Pointer) {
		g(uintptr(p))
		runtime.KeepAlive(p)
	}(tmp)
	runtime.KeepAlive(tmp)  // superfluous, but harmless

So the unsafe.Pointer p will be kept alive at the time g call runs.

Updates #24491

Change-Id: Ic10821251cbb1b0073daec92b82a866c6ebaf567
Reviewed-on: https://go-review.googlesource.com/c/go/+/253457
Run-TryBot: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/cmd/compile/internal/gc/order.go b/src/cmd/compile/internal/gc/order.go
index aa91160..412f073 100644
--- a/src/cmd/compile/internal/gc/order.go
+++ b/src/cmd/compile/internal/gc/order.go
@@ -502,6 +502,7 @@
 			x := o.copyExpr(arg.Left, arg.Left.Type, false)
 			x.Name.SetKeepalive(true)
 			arg.Left = x
+			n.SetNeedsWrapper(true)
 		}
 	}
 
diff --git a/src/cmd/compile/internal/gc/syntax.go b/src/cmd/compile/internal/gc/syntax.go
index 47e5e59..5580f78 100644
--- a/src/cmd/compile/internal/gc/syntax.go
+++ b/src/cmd/compile/internal/gc/syntax.go
@@ -141,19 +141,20 @@
 	nodeInitorder, _                   // tracks state during init1; two bits
 	_, _                               // second nodeInitorder bit
 	_, nodeHasBreak
-	_, nodeNoInline  // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
-	_, nodeImplicit  // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
-	_, nodeIsDDD     // is the argument variadic
-	_, nodeDiag      // already printed error about this
-	_, nodeColas     // OAS resulting from :=
-	_, nodeNonNil    // guaranteed to be non-nil
-	_, nodeTransient // storage can be reused immediately after this statement
-	_, nodeBounded   // bounds check unnecessary
-	_, nodeHasCall   // expression contains a function call
-	_, nodeLikely    // if statement condition likely
-	_, nodeHasVal    // node.E contains a Val
-	_, nodeHasOpt    // node.E contains an Opt
-	_, nodeEmbedded  // ODCLFIELD embedded type
+	_, nodeNoInline     // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
+	_, nodeImplicit     // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
+	_, nodeIsDDD        // is the argument variadic
+	_, nodeDiag         // already printed error about this
+	_, nodeColas        // OAS resulting from :=
+	_, nodeNonNil       // guaranteed to be non-nil
+	_, nodeTransient    // storage can be reused immediately after this statement
+	_, nodeBounded      // bounds check unnecessary
+	_, nodeHasCall      // expression contains a function call
+	_, nodeLikely       // if statement condition likely
+	_, nodeHasVal       // node.E contains a Val
+	_, nodeHasOpt       // node.E contains an Opt
+	_, nodeEmbedded     // ODCLFIELD embedded type
+	_, nodeNeedsWrapper // OCALLxxx node that needs to be wrapped
 )
 
 func (n *Node) Class() Class     { return Class(n.flags.get3(nodeClass)) }
@@ -286,6 +287,20 @@
 	n.Xoffset = x
 }
 
+func (n *Node) NeedsWrapper() bool {
+	return n.flags&nodeNeedsWrapper != 0
+}
+
+// SetNeedsWrapper indicates that OCALLxxx node needs to be wrapped by a closure.
+func (n *Node) SetNeedsWrapper(b bool) {
+	switch n.Op {
+	case OCALLFUNC, OCALLMETH, OCALLINTER:
+	default:
+		Fatalf("Node.SetNeedsWrapper %v", n.Op)
+	}
+	n.flags.set(nodeNeedsWrapper, b)
+}
+
 // mayBeShared reports whether n may occur in multiple places in the AST.
 // Extra care must be taken when mutating such a node.
 func (n *Node) mayBeShared() bool {
diff --git a/src/cmd/compile/internal/gc/walk.go b/src/cmd/compile/internal/gc/walk.go
index 0158af8..ab7f857 100644
--- a/src/cmd/compile/internal/gc/walk.go
+++ b/src/cmd/compile/internal/gc/walk.go
@@ -232,7 +232,11 @@
 			n.Left = copyany(n.Left, &n.Ninit, true)
 
 		default:
-			n.Left = walkexpr(n.Left, &n.Ninit)
+			if n.Left.NeedsWrapper() {
+				n.Left = wrapCall(n.Left, &n.Ninit)
+			} else {
+				n.Left = walkexpr(n.Left, &n.Ninit)
+			}
 		}
 
 	case OFOR, OFORUNTIL:
@@ -3857,6 +3861,14 @@
 //		builtin(a1, a2, a3)
 //	}(x, y, z)
 // for print, println, and delete.
+//
+// Rewrite
+//	go f(x, y, uintptr(unsafe.Pointer(z)))
+// into
+//	go func(a1, a2, a3) {
+//		builtin(a1, a2, uintptr(a3))
+//	}(x, y, unsafe.Pointer(z))
+// for function contains unsafe-uintptr arguments.
 
 var wrapCall_prgen int
 
@@ -3868,9 +3880,17 @@
 		init.AppendNodes(&n.Ninit)
 	}
 
+	isBuiltinCall := n.Op != OCALLFUNC && n.Op != OCALLMETH && n.Op != OCALLINTER
+	// origArgs keeps track of what argument is uintptr-unsafe/unsafe-uintptr conversion.
+	origArgs := make([]*Node, n.List.Len())
 	t := nod(OTFUNC, nil, nil)
 	for i, arg := range n.List.Slice() {
 		s := lookupN("a", i)
+		if !isBuiltinCall && arg.Op == OCONVNOP && arg.Type.Etype == TUINTPTR && arg.Left.Type.Etype == TUNSAFEPTR {
+			origArgs[i] = arg
+			arg = arg.Left
+			n.List.SetIndex(i, arg)
+		}
 		t.List.Append(symfield(s, arg.Type))
 	}
 
@@ -3878,10 +3898,22 @@
 	sym := lookupN("wrap·", wrapCall_prgen)
 	fn := dclfunc(sym, t)
 
-	a := nod(n.Op, nil, nil)
-	a.List.Set(paramNnames(t.Type))
-	a = typecheck(a, ctxStmt)
-	fn.Nbody.Set1(a)
+	args := paramNnames(t.Type)
+	for i, origArg := range origArgs {
+		if origArg == nil {
+			continue
+		}
+		arg := nod(origArg.Op, args[i], nil)
+		arg.Type = origArg.Type
+		args[i] = arg
+	}
+	call := nod(n.Op, nil, nil)
+	if !isBuiltinCall {
+		call.Op = OCALL
+		call.Left = n.Left
+	}
+	call.List.Set(args)
+	fn.Nbody.Set1(call)
 
 	funcbody()
 
@@ -3889,12 +3921,12 @@
 	typecheckslice(fn.Nbody.Slice(), ctxStmt)
 	xtop = append(xtop, fn)
 
-	a = nod(OCALL, nil, nil)
-	a.Left = fn.Func.Nname
-	a.List.Set(n.List.Slice())
-	a = typecheck(a, ctxStmt)
-	a = walkexpr(a, init)
-	return a
+	call = nod(OCALL, nil, nil)
+	call.Left = fn.Func.Nname
+	call.List.Set(n.List.Slice())
+	call = typecheck(call, ctxStmt)
+	call = walkexpr(call, init)
+	return call
 }
 
 // substArgTypes substitutes the given list of types for
diff --git a/test/fixedbugs/issue24491.go b/test/fixedbugs/issue24491.go
new file mode 100644
index 0000000..4703368
--- /dev/null
+++ b/test/fixedbugs/issue24491.go
@@ -0,0 +1,45 @@
+// run
+
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This test makes sure unsafe-uintptr arguments are handled correctly.
+
+package main
+
+import (
+	"runtime"
+	"unsafe"
+)
+
+var done = make(chan bool, 1)
+
+func setup() unsafe.Pointer {
+	s := "ok"
+	runtime.SetFinalizer(&s, func(p *string) { *p = "FAIL" })
+	return unsafe.Pointer(&s)
+}
+
+//go:noinline
+//go:uintptrescapes
+func test(s string, p uintptr) {
+	runtime.GC()
+	if *(*string)(unsafe.Pointer(p)) != "ok" {
+		panic(s + " return unexpected result")
+	}
+	done <- true
+}
+
+func main() {
+	test("normal", uintptr(setup()))
+	<-done
+
+	go test("go", uintptr(setup()))
+	<-done
+
+	func() {
+		defer test("defer", uintptr(setup()))
+	}()
+	<-done
+}