cmd/internal/obj/riscv: fix trampoline calls from large functions

On riscv64, the JAL instruction is only capable of reaching +/-1MB. In the case where
a single function and its trampolines exceeds this size, it is possible that the JAL
is unable to reach the trampoline, which is laid down after the function text. In the
case of large functions, switch back to using a AUIPC+JALR pairs rather than using
trampolines.

Fixes #48791

Change-Id: I119cf3bc20ce4933a9b7ab41a8e514437c6addb9
Reviewed-on: https://go-review.googlesource.com/c/go/+/356250
Trust: Joel Sing <joel@sing.id.au>
Reviewed-by: Cherry Mui <cherryyz@google.com>
diff --git a/src/cmd/internal/obj/riscv/asm_test.go b/src/cmd/internal/obj/riscv/asm_test.go
index 684c6b6..b23142d 100644
--- a/src/cmd/internal/obj/riscv/asm_test.go
+++ b/src/cmd/internal/obj/riscv/asm_test.go
@@ -16,32 +16,30 @@
 	"testing"
 )
 
-// TestLarge generates a very large file to verify that large
-// program builds successfully, in particular, too-far
-// conditional branches are fixed.
-func TestLarge(t *testing.T) {
+// TestLargeBranch generates a large function with a very far conditional
+// branch, in order to ensure that it assembles successfully.
+func TestLargeBranch(t *testing.T) {
 	if testing.Short() {
-		t.Skip("Skip in short mode")
+		t.Skip("Skipping test in short mode")
 	}
 	testenv.MustHaveGoBuild(t)
 
-	dir, err := ioutil.TempDir("", "testlarge")
+	dir, err := ioutil.TempDir("", "testlargebranch")
 	if err != nil {
-		t.Fatalf("could not create directory: %v", err)
+		t.Fatalf("Could not create directory: %v", err)
 	}
 	defer os.RemoveAll(dir)
 
 	// Generate a very large function.
 	buf := bytes.NewBuffer(make([]byte, 0, 7000000))
-	gen(buf)
+	genLargeBranch(buf)
 
 	tmpfile := filepath.Join(dir, "x.s")
-	err = ioutil.WriteFile(tmpfile, buf.Bytes(), 0644)
-	if err != nil {
-		t.Fatalf("can't write output: %v\n", err)
+	if err := ioutil.WriteFile(tmpfile, buf.Bytes(), 0644); err != nil {
+		t.Fatalf("Failed to write file: %v", err)
 	}
 
-	// Build generated file.
+	// Assemble generated file.
 	cmd := exec.Command(testenv.GoToolPath(t), "tool", "asm", "-o", filepath.Join(dir, "x.o"), tmpfile)
 	cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
 	out, err := cmd.CombinedOutput()
@@ -50,8 +48,7 @@
 	}
 }
 
-// gen generates a very large program, with a very far conditional branch.
-func gen(buf *bytes.Buffer) {
+func genLargeBranch(buf *bytes.Buffer) {
 	fmt.Fprintln(buf, "TEXT f(SB),0,$0-0")
 	fmt.Fprintln(buf, "BEQ X0, X0, label")
 	for i := 0; i < 1<<19; i++ {
@@ -61,6 +58,76 @@
 	fmt.Fprintln(buf, "ADD $0, X0, X0")
 }
 
+// TestLargeCall generates a large function (>1MB of text) with a call to
+// a following function, in order to ensure that it assembles and links
+// correctly.
+func TestLargeCall(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping test in short mode")
+	}
+	testenv.MustHaveGoBuild(t)
+
+	dir, err := ioutil.TempDir("", "testlargecall")
+	if err != nil {
+		t.Fatalf("could not create directory: %v", err)
+	}
+	defer os.RemoveAll(dir)
+
+	if err := ioutil.WriteFile(filepath.Join(dir, "go.mod"), []byte("module largecall"), 0644); err != nil {
+		t.Fatalf("Failed to write file: %v\n", err)
+	}
+	main := `package main
+func main() {
+        x()
+}
+
+func x()
+func y()
+`
+	if err := ioutil.WriteFile(filepath.Join(dir, "x.go"), []byte(main), 0644); err != nil {
+		t.Fatalf("failed to write main: %v\n", err)
+	}
+
+	// Generate a very large function with call.
+	buf := bytes.NewBuffer(make([]byte, 0, 7000000))
+	genLargeCall(buf)
+
+	if err := ioutil.WriteFile(filepath.Join(dir, "x.s"), buf.Bytes(), 0644); err != nil {
+		t.Fatalf("Failed to write file: %v\n", err)
+	}
+
+	// Build generated files.
+	cmd := exec.Command(testenv.GoToolPath(t), "build", "-ldflags=-linkmode=internal")
+	cmd.Dir = dir
+	cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		t.Errorf("Build failed: %v, output: %s", err, out)
+	}
+
+	if runtime.GOARCH == "riscv64" && testenv.HasCGO() {
+		cmd := exec.Command(testenv.GoToolPath(t), "build", "-ldflags=-linkmode=external")
+		cmd.Dir = dir
+		cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
+		out, err := cmd.CombinedOutput()
+		if err != nil {
+			t.Errorf("Build failed: %v, output: %s", err, out)
+		}
+	}
+}
+
+func genLargeCall(buf *bytes.Buffer) {
+	fmt.Fprintln(buf, "TEXT ·x(SB),0,$0-0")
+	fmt.Fprintln(buf, "CALL ·y(SB)")
+	for i := 0; i < 1<<19; i++ {
+		fmt.Fprintln(buf, "ADD $0, X0, X0")
+	}
+	fmt.Fprintln(buf, "RET")
+	fmt.Fprintln(buf, "TEXT ·y(SB),0,$0-0")
+	fmt.Fprintln(buf, "ADD $0, X0, X0")
+	fmt.Fprintln(buf, "RET")
+}
+
 // Issue 20348.
 func TestNoRet(t *testing.T) {
 	dir, err := ioutil.TempDir("", "testnoret")
diff --git a/src/cmd/internal/obj/riscv/obj.go b/src/cmd/internal/obj/riscv/obj.go
index b346b13..d98806e 100644
--- a/src/cmd/internal/obj/riscv/obj.go
+++ b/src/cmd/internal/obj/riscv/obj.go
@@ -280,14 +280,15 @@
 }
 
 // setPCs sets the Pc field in all instructions reachable from p.
-// It uses pc as the initial value.
-func setPCs(p *obj.Prog, pc int64) {
+// It uses pc as the initial value and returns the next available pc.
+func setPCs(p *obj.Prog, pc int64) int64 {
 	for ; p != nil; p = p.Link {
 		p.Pc = pc
 		for _, ins := range instructionsForProg(p) {
 			pc += int64(ins.length())
 		}
 	}
+	return pc
 }
 
 // stackOffset updates Addr offsets based on the current stack size.
@@ -582,17 +583,26 @@
 		}
 	}
 
+	var callCount int
 	for p := cursym.Func().Text; p != nil; p = p.Link {
 		markRelocs(p)
+		if p.Mark&NEED_CALL_RELOC == NEED_CALL_RELOC {
+			callCount++
+		}
 	}
+	const callTrampSize = 8 // 2 machine instructions.
+	maxTrampSize := int64(callCount * callTrampSize)
 
 	// Compute instruction addresses.  Once we do that, we need to check for
 	// overextended jumps and branches.  Within each iteration, Pc differences
 	// are always lower bounds (since the program gets monotonically longer,
 	// a fixed point will be reached).  No attempt to handle functions > 2GiB.
 	for {
-		rescan := false
-		setPCs(cursym.Func().Text, 0)
+		big, rescan := false, false
+		maxPC := setPCs(cursym.Func().Text, 0)
+		if maxPC+maxTrampSize > (1 << 20) {
+			big = true
+		}
 
 		for p := cursym.Func().Text; p != nil; p = p.Link {
 			switch p.As {
@@ -619,6 +629,24 @@
 			case AJAL:
 				// Linker will handle the intersymbol case and trampolines.
 				if p.To.Target() == nil {
+					if !big {
+						break
+					}
+					// This function is going to be too large for JALs
+					// to reach trampolines. Replace with AUIPC+JALR.
+					jmp := obj.Appendp(p, newprog)
+					jmp.As = AJALR
+					jmp.From = p.From
+					jmp.To = obj.Addr{Type: obj.TYPE_REG, Reg: REG_TMP}
+
+					p.As = AAUIPC
+					p.Mark = (p.Mark &^ NEED_CALL_RELOC) | NEED_PCREL_ITYPE_RELOC
+					p.SetFrom3(obj.Addr{Type: obj.TYPE_CONST, Offset: p.To.Offset, Sym: p.To.Sym})
+					p.From = obj.Addr{Type: obj.TYPE_CONST, Offset: 0}
+					p.Reg = obj.REG_NONE
+					p.To = obj.Addr{Type: obj.TYPE_REG, Reg: REG_TMP}
+
+					rescan = true
 					break
 				}
 				offset := p.To.Target().Pc - p.Pc