vulncheck/internal/gosym: support inline tree

Provide a way to get the inline tree from a Func.
The inline tree describes all functions that were inlined into
the Func.
Using it, vulncheck will be able to find the names of all inlined
functions in a binary.

Change-Id: I90c46f8b746009c1fed69ef961f6d4f56b8cbbcc
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/398296
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulncheck/internal/gosym/pclntab.go b/vulncheck/internal/gosym/pclntab.go
index 2ceea3d..9ef0f71 100644
--- a/vulncheck/internal/gosym/pclntab.go
+++ b/vulncheck/internal/gosym/pclntab.go
@@ -11,6 +11,7 @@
 import (
 	"bytes"
 	"encoding/binary"
+	"io"
 	"sort"
 	"sync"
 )
@@ -144,6 +145,64 @@
 	return pc - oldQuantum
 }
 
+// InlineTree returns the inline tree for Func f as a sequence of InlinedCalls,
+// or nil if f has no inline tree. goFuncValue is the value of the "go.func.*" symbol.
+// baseAddr is the address of the memory region (ELF Prog) containing goFuncValue.
+// progReader is a ReaderAt positioned at the start of that region.
+func (t *LineTable) InlineTree(f *Func, goFuncValue, baseAddr uint64, progReader io.ReaderAt) ([]InlinedCall, error) {
+	if f.inlineTreeCount == 0 {
+		return nil, nil
+	}
+	if f.inlineTreeOffset == ^uint32(0) {
+		return nil, nil
+	}
+	var offset int64
+	if t.version >= ver118 {
+		offset = int64(goFuncValue - baseAddr + uint64(f.inlineTreeOffset))
+	} else {
+		offset = int64(uint64(f.inlineTreeOffset) - baseAddr)
+	}
+
+	r := io.NewSectionReader(progReader, offset, 1<<32) // pick a size larger than we need
+	var ics []InlinedCall
+	for i := 0; i < f.inlineTreeCount; i++ {
+		var ric rawInlinedCall
+		if err := binary.Read(r, t.binary, &ric); err != nil {
+			return nil, err
+		}
+		ics = append(ics, InlinedCall{
+			Parent:   ric.Parent,
+			FuncID:   ric.FuncID,
+			File:     ric.File,
+			Line:     ric.Line,
+			Name:     t.funcName(uint32(ric.Func_)),
+			ParentPC: ric.ParentPC,
+		})
+	}
+	return ics, nil
+}
+
+// An InlinedCall describes a call to an inlined function.
+type InlinedCall struct {
+	Parent   int16  // index of parent in the inltree, or < 0
+	FuncID   uint8  // type of the called function
+	File     int32  // perCU file index for inlined call. See cmd/link:pcln.go
+	Line     int32  // line number of the call site
+	Name     string // name of called function
+	ParentPC int32  // position of an instruction whose source position is the call site (offset from entry)
+}
+
+// rawInlinedCall is the encoding of entries in the FUNCDATA_InlTree table.
+type rawInlinedCall struct {
+	Parent   int16 // index of parent in the inltree, or < 0
+	FuncID   uint8 // type of the called function
+	_        byte
+	File     int32 // perCU file index for inlined call. See cmd/link:pcln.go
+	Line     int32 // line number of the call site
+	Func_    int32 // offset into pclntab for name of called function
+	ParentPC int32 // position of an instruction whose source position is the call site (offset from entry)
+}
+
 // NewLineTable returns a new PC/line table
 // corresponding to the encoded data.
 // Text must be the start address of the
@@ -286,6 +345,12 @@
 	}
 }
 
+// from cmd/internal/objabi/funcdata.go
+const (
+	pcdata_InlTreeIndex = 2
+	funcdata_InlTree    = 3
+)
+
 // go12Funcs returns a slice of Funcs derived from the Go 1.2+ pcln table.
 func (t *LineTable) go12Funcs() []Func {
 	// Assume it is malformed and return nil on error.
@@ -305,6 +370,8 @@
 		info := t.funcData(uint32(i))
 		f.LineTable = t
 		f.FrameSize = int(info.deferreturn())
+		f.inlineTreeOffset = info.funcdataOffset(funcdata_InlTree)
+		f.inlineTreeCount = 1 + t.maxInlineTreeIndexValue(info)
 		syms[i] = Sym{
 			Value:  f.Entry,
 			Type:   'T',
@@ -317,6 +384,27 @@
 	return funcs
 }
 
+// maxInlineTreeIndexValue returns the maximum value of the inline tree index
+// pc-value table in info. This is the only way to determine how many
+// InlinedCalls are in an inline tree, since the data of the tree itself is not
+// delimited in any way.
+func (t *LineTable) maxInlineTreeIndexValue(info funcData) int {
+	if info.npcdata() <= pcdata_InlTreeIndex {
+		return -1
+	}
+	off := info.pcdataOffset(pcdata_InlTreeIndex)
+	p := t.pctab[off:]
+	val := int32(-1)
+	max := int32(-1)
+	var pc uint64
+	for t.step(&p, &pc, &val, pc == 0) {
+		if val > max {
+			max = val
+		}
+	}
+	return int(max)
+}
+
 // findFunc returns the funcData corresponding to the given program counter.
 func (t *LineTable) findFunc(pc uint64) funcData {
 	ft := t.funcTab()
@@ -453,7 +541,19 @@
 func (f funcData) deferreturn() uint32 { return f.field(3) }
 func (f funcData) pcfile() uint32      { return f.field(5) }
 func (f funcData) pcln() uint32        { return f.field(6) }
+func (f funcData) npcdata() uint32     { return f.field(7) }
 func (f funcData) cuOffset() uint32    { return f.field(8) }
+func (f funcData) nfuncdata() uint32   { return uint32(f.data[f.fieldOffset(9)+3]) } // nfuncdata is the final uint8 of _func (after funcID, flag, pad), not a whole 32-bit field
+
+func (f funcData) fieldOffset(n uint32) uint32 {
+	// In Go 1.18, the first field of _func changed
+	// from a uintptr entry PC to a uint32 entry offset.
+	sz0 := f.t.ptrsize
+	if f.t.version >= ver118 {
+		sz0 = 4
+	}
+	return sz0 + (n-1)*4 // subsequent fields are 4 bytes each
+}
 
 // field returns the nth field of the _func struct.
 // It panics if n == 0 or n > 9; for n == 0, call f.entryPC.
@@ -462,17 +562,38 @@
 	if n == 0 || n > 9 {
 		panic("bad funcdata field")
 	}
-	// In Go 1.18, the first field of _func changed
-	// from a uintptr entry PC to a uint32 entry offset.
-	sz0 := f.t.ptrsize
-	if f.t.version >= ver118 {
-		sz0 = 4
-	}
-	off := sz0 + (n-1)*4 // subsequent fields are 4 bytes each
+	off := f.fieldOffset(n)
 	data := f.data[off:]
 	return f.t.binary.Uint32(data)
 }
 
+func (f funcData) funcdataOffset(i uint8) uint32 {
+	if uint32(i) >= f.nfuncdata() {
+		return ^uint32(0)
+	}
+	var off uint32
+	if f.t.version >= ver118 {
+		off = f.fieldOffset(10) + // skip fixed part of _func
+			f.npcdata()*4 + // skip pcdata
+			uint32(i)*4 // index of i'th FUNCDATA
+		return f.t.binary.Uint32(f.data[off:])
+	} else {
+		off = f.fieldOffset(10) + // skip fixed part of _func
+			f.npcdata()*4
+		off += uint32(i) * 8
+		return f.t.binary.Uint32(f.data[off:])
+	}
+}
+
+func (f funcData) pcdataOffset(i uint8) uint32 {
+	if uint32(i) >= f.npcdata() {
+		return ^uint32(0)
+	}
+	off := f.fieldOffset(10) + // skip fixed part of _func
+		uint32(i)*4 // index of i'th PCDATA
+	return f.t.binary.Uint32(f.data[off:])
+}
+
 // step advances to the next pc, value pair in the encoded table.
 func (t *LineTable) step(p *[]byte, pc *uint64, val *int32, first bool) bool {
 	uvdelta := t.readvarint(p)
diff --git a/vulncheck/internal/gosym/pclntab_test.go b/vulncheck/internal/gosym/pclntab_test.go
index d4cd44b..cab8ad2 100644
--- a/vulncheck/internal/gosym/pclntab_test.go
+++ b/vulncheck/internal/gosym/pclntab_test.go
@@ -15,6 +15,9 @@
 	"runtime"
 	"strings"
 	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
 )
 
 var (
@@ -35,13 +38,23 @@
 	if runtime.GOOS != "linux" && testing.Short() {
 		t.Skipf("skipping in short mode on non-Linux system %s", runtime.GOARCH)
 	}
+
 	var err error
+	var exeSuffix string
+	if runtime.GOOS == "windows" {
+		exeSuffix = ".exe"
+	}
+	goCommandPath := filepath.Join(runtime.GOROOT(), "bin", "go"+exeSuffix)
+	if _, err := os.Stat(goCommandPath); err != nil {
+		t.Fatal(err)
+	}
+
 	pclineTempDir, err = os.MkdirTemp("", "pclinetest")
 	if err != nil {
 		t.Fatal(err)
 	}
 	pclinetestBinary = filepath.Join(pclineTempDir, "pclinetest")
-	cmd := exec.Command("go", "build", "-o", pclinetestBinary)
+	cmd := exec.Command(goCommandPath, "build", "-o", pclinetestBinary)
 	cmd.Dir = "testdata"
 	cmd.Env = append(os.Environ(), "GOOS=linux")
 	cmd.Stdout = os.Stdout
@@ -224,7 +237,6 @@
 			break
 		}
 		wantLine += int(textdat[off])
-		t.Logf("off is %d %#x (max %d)", off, textdat[off], sym.End-pc)
 		file, line, fn := tab.PCToLine(pc)
 		if fn == nil {
 			t.Errorf("failed to get line of PC %#x", pc)
@@ -270,6 +282,77 @@
 	}
 }
 
+func TestInlineTree(t *testing.T) {
+	dotest(t)
+	defer endtest()
+
+	f, err := elf.Open(pclinetestBinary)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+	pclndat, err := f.Section(".gopclntab").Data()
+	if err != nil {
+		t.Fatalf("reading %s gopclntab: %v", pclinetestBinary, err)
+	}
+	goFunc := lookupSymbol(f, "go.func.*")
+	if goFunc == nil {
+		t.Fatal("couldn't find go.func.*")
+	}
+	prog := progContaining(f, goFunc.Value)
+	if prog == nil {
+		t.Fatal("couldn't find go.func.* Prog")
+	}
+	pcln := NewLineTable(pclndat, f.Section(".text").Addr)
+	s := f.Section(".gosymtab")
+	if s == nil {
+		t.Fatal("no .gosymtab section")
+	}
+	d, err := s.Data()
+	if err != nil {
+		t.Fatal(err)
+	}
+	tab, err := NewTable(d, pcln)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	fun := tab.LookupFunc("main.main")
+	got, err := pcln.InlineTree(fun, goFunc.Value, prog.Vaddr, prog.ReaderAt)
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := []InlinedCall{
+		{Parent: -1, FuncID: 0, File: 1, Name: "main.inline1"},
+		{Parent: 0, FuncID: 0, File: 1, Name: "main.inline2"},
+	}
+	if !cmp.Equal(got, want, cmpopts.IgnoreFields(InlinedCall{}, "Line", "ParentPC")) {
+		t.Errorf("got\n%+v\nwant\n%+v", got, want)
+	}
+}
+
+func progContaining(f *elf.File, addr uint64) *elf.Prog {
+	for _, p := range f.Progs {
+		if addr >= p.Vaddr && addr < p.Vaddr+p.Filesz {
+			return p
+		}
+	}
+	return nil
+}
+
+func lookupSymbol(f *elf.File, name string) *elf.Symbol {
+	syms, err := f.Symbols()
+	if err != nil {
+		return nil
+	}
+	for _, s := range syms {
+		if s.Name == name {
+			return &s
+		}
+	}
+	return nil
+}
+
 // read115Executable returns a hello world executable compiled by Go 1.15.
 //
 // The file was compiled in /tmp/hello.go:
diff --git a/vulncheck/internal/gosym/symtab.go b/vulncheck/internal/gosym/symtab.go
index 72490dc..19bd90b 100644
--- a/vulncheck/internal/gosym/symtab.go
+++ b/vulncheck/internal/gosym/symtab.go
@@ -122,12 +122,14 @@
 type Func struct {
 	Entry uint64
 	*Sym
-	End       uint64
-	Params    []*Sym // nil for Go 1.3 and later binaries
-	Locals    []*Sym // nil for Go 1.3 and later binaries
-	FrameSize int
-	LineTable *LineTable
-	Obj       *Obj
+	End              uint64
+	Params           []*Sym // nil for Go 1.3 and later binaries
+	Locals           []*Sym // nil for Go 1.3 and later binaries
+	FrameSize        int
+	LineTable        *LineTable
+	Obj              *Obj
+	inlineTreeOffset uint32 // Go 1.18+: offset from go.func.* symbol; earlier versions: used as an address (see InlineTree); ^uint32(0) if absent
+	inlineTreeCount  int    // number of entries in inline tree
 }
 
 // An Obj represents a collection of functions in a symbol table.
@@ -515,6 +517,7 @@
 
 	if t.go12line != nil && nf == 0 {
 		t.Funcs = t.go12line.go12Funcs()
+
 	}
 	if obj != nil {
 		obj.Funcs = t.Funcs[lastf:]
diff --git a/vulncheck/internal/gosym/testdata/main.go b/vulncheck/internal/gosym/testdata/main.go
index b770218..902ab76 100644
--- a/vulncheck/internal/gosym/testdata/main.go
+++ b/vulncheck/internal/gosym/testdata/main.go
@@ -7,4 +7,13 @@
 	// Prevent GC of our test symbols
 	linefrompc()
 	pcfromline()
+	inline1()
+}
+
+func inline1() {
+	inline2()
+}
+
+func inline2() {
+	println(1)
 }