blob: 77d0c0496ed938da8d733a95cf17859a2a592988 [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package godoc
import (
"bytes"
"context"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io"
"path/filepath"
"reflect"
"runtime"
"sort"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
var packageToTest string = filepath.Join(runtime.GOROOT(), "src", "net", "http")
func TestEncodeDecodePackage(t *testing.T) {
p, err := packageForDir(packageToTest, true)
if err != nil {
t.Fatal(err)
}
var want, got bytes.Buffer
printPackage(&want, p)
data, err := p.Encode(context.Background())
if err != nil {
t.Fatal(err)
}
p2, err := DecodePackage(data)
if err != nil {
t.Fatal(err)
}
printPackage(&got, p2)
// Diff the textual output of printPackage, because cmp.Diff takes too long
// on the Packages themselves.
if diff := cmp.Diff(want.String(), got.String()); diff != "" {
t.Errorf("package differs after decoding (-want, +got):\n%s", diff)
}
}
func TestObjectIdentity(t *testing.T) {
// Check that encoding and decoding preserves object identity.
ctx := context.Background()
const file = `
package p
var a int
func main() { a = 1 }
`
compareObjs := func(f *ast.File) {
t.Helper()
// We know (from hand-inspecting the output of ast.Fprintf) that these two
// objects are identical in the above program.
o1 := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec).Names[0].Obj
o2 := f.Decls[1].(*ast.FuncDecl).Body.List[0].(*ast.AssignStmt).Lhs[0].(*ast.Ident).Obj
if o1 != o2 {
t.Fatal("objects not identical")
}
}
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "test.go", file, parser.ParseComments)
if err != nil {
t.Fatal(err)
}
compareObjs(f)
p := NewPackage(fset, nil)
p.AddFile(f, false)
data, err := p.Encode(ctx)
if err != nil {
t.Fatal(err)
}
p, err = DecodePackage(data)
if err != nil {
t.Fatal(err)
}
compareObjs(p.Files[0].AST)
}
func packageForDir(dir string, removeNodes bool) (*Package, error) {
fset := token.NewFileSet()
pkgs, err := parser.ParseDir(fset, dir, nil, parser.ParseComments)
if err != nil {
return nil, err
}
p := NewPackage(fset, nil)
for _, pkg := range pkgs {
for _, f := range pkg.Files {
p.AddFile(f, removeNodes)
}
}
return p, nil
}
// Compare the time to decode AST files with and without
// removing parts of the AST not relevant to documentation.
//
// Run on a cloudtop 9/29/2020:
// - data size is 3.5x smaller
// - decode time is 4.5x faster
func BenchmarkRemovingAST(b *testing.B) {
for _, removeNodes := range []bool{false, true} {
b.Run(fmt.Sprintf("removeNodes=%t", removeNodes), func(b *testing.B) {
p, err := packageForDir(packageToTest, removeNodes)
if err != nil {
b.Fatal(err)
}
data, err := p.Encode(context.Background())
if err != nil {
b.Fatal(err)
}
b.Logf("len(data) = %d", len(data))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := DecodePackage(data); err != nil {
b.Fatal(err)
}
}
})
}
}
// printPackage outputs a human-readable form of p to w, deterministically. (The
// ast.Fprint function does not print ASTs deterministically: it is subject to
// random-order map iteration.) The output is designed to be diffed.
func printPackage(w io.Writer, p *Package) error {
if err := printFileSet(w, p.Fset); err != nil {
return err
}
var mpps []string
for k := range p.ModulePackagePaths {
mpps = append(mpps, k)
}
sort.Strings(mpps)
if _, err := fmt.Fprintf(w, "ModulePackagePaths: %v\n", mpps); err != nil {
return err
}
for _, pf := range p.Files {
if _, err := fmt.Fprintf(w, "---- %s\n", pf.Name); err != nil {
return err
}
if err := printNode(w, pf.AST); err != nil {
return err
}
}
return nil
}
func printNode(w io.Writer, root ast.Node) error {
var err error
seen := map[interface{}]int{}
pr := func(format string, args ...interface{}) {
if err == nil {
_, err = fmt.Fprintf(w, format, args...)
}
}
indent := func(d int) {
for i := 0; i < d; i++ {
pr(" ")
}
}
var prValue func(interface{}, int)
prValue = func(x interface{}, depth int) {
indent(depth)
if x == nil || reflect.ValueOf(x).IsNil() {
pr("nil\n")
return
}
ts := strings.TrimPrefix(fmt.Sprintf("%T", x), "*ast.")
if idx, ok := seen[x]; ok {
pr("%s@%d\n", ts, idx)
return
}
idx := len(seen)
seen[x] = idx
pr("%s#%d", ts, idx)
if obj, ok := x.(*ast.Object); ok {
pr(" %s %s %v\n", obj.Name, obj.Kind, obj.Data)
prValue(obj.Decl, depth+1)
return
}
n, ok := x.(ast.Node)
if !ok {
pr(" %v\n", x)
return
}
pr(" %d-%d", n.Pos(), n.End())
switch n := n.(type) {
case *ast.Ident:
pr(" %q\n", n.Name)
if n.Obj != nil {
prValue(n.Obj, depth+1)
}
case *ast.BasicLit:
pr(" %s %s %d\n", n.Value, n.Kind, n.ValuePos)
case *ast.UnaryExpr:
pr(" %s\n", n.Op)
case *ast.BinaryExpr:
pr(" %s\n", n.Op)
case *ast.Comment:
pr(" %q\n", n.Text)
case *ast.File:
// Doc, Name and Decls are walked, but not Scope or Unresolved.
if n.Scope != nil {
pr(" Scope.Outer: %p\n", n.Scope.Outer)
var keys []string
for k := range n.Scope.Objects {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
pr(" key %q\n", k)
prValue(n.Scope.Objects[k], depth+1)
}
}
indent(depth)
pr("unresolved:\n")
for _, id := range n.Unresolved {
prValue(id, depth+1)
}
default:
pr("\n")
}
ast.Inspect(n, func(m ast.Node) bool {
if m == n {
return true
}
if m != nil {
prValue(m, depth+1)
}
return false
})
}
prValue(root, 0)
return err
}
func printFileSet(w io.Writer, fset *token.FileSet) error {
var err error
fset.Iterate(func(f *token.File) bool {
_, err = fmt.Fprintf(w, "%s %d %d %d\n", f.Name(), f.Base(), f.Size(), f.LineCount())
return err == nil
})
return err
}