|  | // Copyright 2021 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package vta | 
|  |  | 
|  | import ( | 
|  | "fmt" | 
|  | "go/types" | 
|  | "reflect" | 
|  | "sort" | 
|  | "strings" | 
|  | "testing" | 
|  |  | 
|  | "golang.org/x/tools/go/callgraph/cha" | 
|  | "golang.org/x/tools/go/ssa" | 
|  | "golang.org/x/tools/go/ssa/ssautil" | 
|  | ) | 
|  |  | 
|  | func TestNodeInterface(t *testing.T) { | 
|  | // Since ssa package does not allow explicit creation of ssa | 
|  | // values, we use the values from the program testdata/src/simple.go: | 
|  | //   - basic type int | 
|  | //   - struct X with two int fields a and b | 
|  | //   - global variable "gl" | 
|  | //   - "main" function and its | 
|  | //   - first register instruction t0 := *gl | 
|  | prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0)) | 
|  | if err != nil { | 
|  | t.Fatalf("couldn't load testdata/src/simple.go program: %v", err) | 
|  | } | 
|  |  | 
|  | pkg := prog.AllPackages()[0] | 
|  | main := pkg.Func("main") | 
|  | reg := firstRegInstr(main) // t0 := *gl | 
|  | X := pkg.Type("X").Type() | 
|  | gl := pkg.Var("gl") | 
|  | glPtrType, ok := gl.Type().(*types.Pointer) | 
|  | if !ok { | 
|  | t.Fatalf("could not cast gl variable to pointer type") | 
|  | } | 
|  | bint := glPtrType.Elem() | 
|  |  | 
|  | pint := types.NewPointer(bint) | 
|  | i := types.NewInterface(nil, nil) | 
|  |  | 
|  | voidFunc := main.Signature.Underlying() | 
|  |  | 
|  | for _, test := range []struct { | 
|  | n node | 
|  | s string | 
|  | t types.Type | 
|  | }{ | 
|  | {constant{typ: bint}, "Constant(int)", bint}, | 
|  | {pointer{typ: pint}, "Pointer(*int)", pint}, | 
|  | {mapKey{typ: bint}, "MapKey(int)", bint}, | 
|  | {mapValue{typ: pint}, "MapValue(*int)", pint}, | 
|  | {sliceElem{typ: bint}, "Slice([]int)", bint}, | 
|  | {channelElem{typ: pint}, "Channel(chan *int)", pint}, | 
|  | {field{StructType: X, index: 0}, "Field(testdata.X:a)", bint}, | 
|  | {field{StructType: X, index: 1}, "Field(testdata.X:b)", bint}, | 
|  | {global{val: gl}, "Global(gl)", gl.Type()}, | 
|  | {local{val: reg}, "Local(t0)", bint}, | 
|  | {indexedLocal{val: reg, typ: X, index: 0}, "Local(t0[0])", X}, | 
|  | {function{f: main}, "Function(main)", voidFunc}, | 
|  | {nestedPtrInterface{typ: i}, "PtrInterface(interface{})", i}, | 
|  | {nestedPtrFunction{typ: voidFunc}, "PtrFunction(func())", voidFunc}, | 
|  | {panicArg{}, "Panic", nil}, | 
|  | {recoverReturn{}, "Recover", nil}, | 
|  | } { | 
|  | if test.s != test.n.String() { | 
|  | t.Errorf("want %s; got %s", test.s, test.n.String()) | 
|  | } | 
|  | if test.t != test.n.Type() { | 
|  | t.Errorf("want %s; got %s", test.t, test.n.Type()) | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | func TestVtaGraph(t *testing.T) { | 
|  | // Get the basic type int from a real program. | 
|  | prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0)) | 
|  | if err != nil { | 
|  | t.Fatalf("couldn't load testdata/src/simple.go program: %v", err) | 
|  | } | 
|  |  | 
|  | glPtrType, ok := prog.AllPackages()[0].Var("gl").Type().(*types.Pointer) | 
|  | if !ok { | 
|  | t.Fatalf("could not cast gl variable to pointer type") | 
|  | } | 
|  | bint := glPtrType.Elem() | 
|  |  | 
|  | n1 := constant{typ: bint} | 
|  | n2 := pointer{typ: types.NewPointer(bint)} | 
|  | n3 := mapKey{typ: types.NewMap(bint, bint)} | 
|  | n4 := mapValue{typ: types.NewMap(bint, bint)} | 
|  |  | 
|  | // Create graph | 
|  | //   n1   n2 | 
|  | //    \  / / | 
|  | //     n3 / | 
|  | //     | / | 
|  | //     n4 | 
|  | g := make(vtaGraph) | 
|  | g.addEdge(n1, n3) | 
|  | g.addEdge(n2, n3) | 
|  | g.addEdge(n3, n4) | 
|  | g.addEdge(n2, n4) | 
|  | // for checking duplicates | 
|  | g.addEdge(n1, n3) | 
|  |  | 
|  | want := vtaGraph{ | 
|  | n1: map[node]bool{n3: true}, | 
|  | n2: map[node]bool{n3: true, n4: true}, | 
|  | n3: map[node]bool{n4: true}, | 
|  | } | 
|  |  | 
|  | if !reflect.DeepEqual(want, g) { | 
|  | t.Errorf("want %v; got %v", want, g) | 
|  | } | 
|  |  | 
|  | for _, test := range []struct { | 
|  | n node | 
|  | l int | 
|  | }{ | 
|  | {n1, 1}, | 
|  | {n2, 2}, | 
|  | {n3, 1}, | 
|  | {n4, 0}, | 
|  | } { | 
|  | if sl := len(g.successors(test.n)); sl != test.l { | 
|  | t.Errorf("want %d successors; got %d", test.l, sl) | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | // vtaGraphStr stringifies vtaGraph into a list of strings | 
|  | // where each string represents an edge set of the format | 
|  | // node -> succ_1, ..., succ_n. succ_1, ..., succ_n are | 
|  | // sorted in alphabetical order. | 
|  | func vtaGraphStr(g vtaGraph) []string { | 
|  | var vgs []string | 
|  | for n, succ := range g { | 
|  | var succStr []string | 
|  | for s := range succ { | 
|  | succStr = append(succStr, s.String()) | 
|  | } | 
|  | sort.Strings(succStr) | 
|  | entry := fmt.Sprintf("%v -> %v", n.String(), strings.Join(succStr, ", ")) | 
|  | vgs = append(vgs, entry) | 
|  | } | 
|  | return vgs | 
|  | } | 
|  |  | 
|  | // setdiff returns the set difference of `X-Y` or {s | s ∈ X, s ∉ Y }. | 
|  | func setdiff(X, Y []string) []string { | 
|  | y := make(map[string]bool) | 
|  | var delta []string | 
|  | for _, s := range Y { | 
|  | y[s] = true | 
|  | } | 
|  |  | 
|  | for _, s := range X { | 
|  | if _, ok := y[s]; !ok { | 
|  | delta = append(delta, s) | 
|  | } | 
|  | } | 
|  | sort.Strings(delta) | 
|  | return delta | 
|  | } | 
|  |  | 
|  | func TestVTAGraphConstruction(t *testing.T) { | 
|  | for _, file := range []string{ | 
|  | "testdata/src/store.go", | 
|  | "testdata/src/phi.go", | 
|  | "testdata/src/type_conversions.go", | 
|  | "testdata/src/type_assertions.go", | 
|  | "testdata/src/fields.go", | 
|  | "testdata/src/node_uniqueness.go", | 
|  | "testdata/src/store_load_alias.go", | 
|  | "testdata/src/phi_alias.go", | 
|  | "testdata/src/channels.go", | 
|  | "testdata/src/select.go", | 
|  | "testdata/src/stores_arrays.go", | 
|  | "testdata/src/maps.go", | 
|  | "testdata/src/ranges.go", | 
|  | "testdata/src/closures.go", | 
|  | "testdata/src/function_alias.go", | 
|  | "testdata/src/static_calls.go", | 
|  | "testdata/src/dynamic_calls.go", | 
|  | "testdata/src/returns.go", | 
|  | "testdata/src/panic.go", | 
|  | } { | 
|  | t.Run(file, func(t *testing.T) { | 
|  | prog, want, err := testProg(file, ssa.BuilderMode(0)) | 
|  | if err != nil { | 
|  | t.Fatalf("couldn't load test file '%s': %s", file, err) | 
|  | } | 
|  | if len(want) == 0 { | 
|  | t.Fatalf("couldn't find want in `%s`", file) | 
|  | } | 
|  |  | 
|  | g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) | 
|  | got := vtaGraphStr(g) | 
|  | if diff := setdiff(want, got); len(diff) > 0 { | 
|  | t.Errorf("`%s`: want superset of %v;\n got %v\ndiff: %v", file, want, got, diff) | 
|  | } | 
|  | }) | 
|  | } | 
|  | } |