internal/lsp: handle completion after defer, go statements

This change adds support for completion of incomplete selectors after a
defer or go statement. We modify the AST before type-checking it with a
fake *ast.CallExpr.

Updates golang/go#29313

Change-Id: Ic9e8c9c49aa569cd7874791692c70a28c3146251
Reviewed-on: https://go-review.googlesource.com/c/tools/+/172974
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go
index 1b5e5ca..f7b1b3e 100644
--- a/internal/lsp/cache/check.go
+++ b/internal/lsp/cache/check.go
@@ -7,11 +7,6 @@
 	"go/parser"
 	"go/scanner"
 	"go/types"
-	"io/ioutil"
-	"os"
-	"path/filepath"
-	"strings"
-	"sync"
 
 	"golang.org/x/tools/go/analysis"
 	"golang.org/x/tools/go/packages"
@@ -41,9 +36,9 @@
 		return nil, fmt.Errorf("no metadata found for %v", f.filename)
 	}
 	imp := &importer{
-		view:     v,
-		circular: make(map[string]struct{}),
-		ctx:      ctx,
+		view: v,
+		seen: make(map[string]struct{}),
+		ctx:  ctx,
 	}
 	// Start prefetching direct imports.
 	for importPath := range f.meta.children {
@@ -178,15 +173,15 @@
 type importer struct {
 	view *View
 
-	// circular maintains the set of previously imported packages.
+	// seen maintains the set of previously imported packages.
 	// If we have seen a package that is already in this map, we have a circular import.
-	circular map[string]struct{}
+	seen map[string]struct{}
 
 	ctx context.Context
 }
 
 func (imp *importer) Import(pkgPath string) (*types.Package, error) {
-	if _, ok := imp.circular[pkgPath]; ok {
+	if _, ok := imp.seen[pkgPath]; ok {
 		return nil, fmt.Errorf("circular import detected")
 	}
 	imp.view.pcache.mu.Lock()
@@ -245,22 +240,25 @@
 	appendError := func(err error) {
 		imp.view.appendPkgError(pkg, err)
 	}
-	files, errs := imp.view.parseFiles(meta.files)
+	files, errs := imp.parseFiles(meta.files)
 	for _, err := range errs {
 		appendError(err)
 	}
 	pkg.syntax = files
 
 	// Handle circular imports by copying previously seen imports.
-	newCircular := copySet(imp.circular)
-	newCircular[pkgPath] = struct{}{}
+	seen := make(map[string]struct{})
+	for k, v := range imp.seen {
+		seen[k] = v
+	}
+	seen[pkgPath] = struct{}{}
 
 	cfg := &types.Config{
 		Error: appendError,
 		Importer: &importer{
-			view:     imp.view,
-			circular: newCircular,
-			ctx:      imp.ctx,
+			view: imp.view,
+			seen: seen,
+			ctx:  imp.ctx,
 		},
 	}
 	check := types.NewChecker(cfg, imp.view.Config.Fset, pkg.types, pkg.typesInfo)
@@ -284,14 +282,6 @@
 	return pkg, nil
 }
 
-func copySet(m map[string]struct{}) map[string]struct{} {
-	result := make(map[string]struct{})
-	for k, v := range m {
-		result[k] = v
-	}
-	return result
-}
-
 func (v *View) appendPkgError(pkg *Package, err error) {
 	if err == nil {
 		return
@@ -322,115 +312,3 @@
 	}
 	pkg.errors = append(pkg.errors, errs...)
 }
-
-// We use a counting semaphore to limit
-// the number of parallel I/O calls per process.
-var ioLimit = make(chan bool, 20)
-
-// parseFiles reads and parses the Go source files and returns the ASTs
-// of the ones that could be at least partially parsed, along with a
-// list of I/O and parse errors encountered.
-//
-// Because files are scanned in parallel, the token.Pos
-// positions of the resulting ast.Files are not ordered.
-//
-func (v *View) parseFiles(filenames []string) ([]*ast.File, []error) {
-	var wg sync.WaitGroup
-	n := len(filenames)
-	parsed := make([]*ast.File, n)
-	errors := make([]error, n)
-	for i, filename := range filenames {
-		if v.Config.Context.Err() != nil {
-			parsed[i] = nil
-			errors[i] = v.Config.Context.Err()
-			continue
-		}
-
-		// First, check if we have already cached an AST for this file.
-		f, err := v.findFile(span.FileURI(filename))
-		if err != nil {
-			parsed[i], errors[i] = nil, err
-		}
-		var fAST *ast.File
-		if f != nil {
-			fAST = f.ast
-		}
-
-		wg.Add(1)
-		go func(i int, filename string) {
-			ioLimit <- true // wait
-
-			if fAST != nil {
-				parsed[i], errors[i] = fAST, nil
-			} else {
-				// We don't have a cached AST for this file.
-				var src []byte
-				// Check for an available overlay.
-				for f, contents := range v.Config.Overlay {
-					if sameFile(f, filename) {
-						src = contents
-					}
-				}
-				var err error
-				// We don't have an overlay, so we must read the file's contents.
-				if src == nil {
-					src, err = ioutil.ReadFile(filename)
-				}
-				if err != nil {
-					parsed[i], errors[i] = nil, err
-				} else {
-					// ParseFile may return both an AST and an error.
-					parsed[i], errors[i] = v.Config.ParseFile(v.Config.Fset, filename, src)
-				}
-			}
-
-			<-ioLimit // signal
-			wg.Done()
-		}(i, filename)
-	}
-	wg.Wait()
-
-	// Eliminate nils, preserving order.
-	var o int
-	for _, f := range parsed {
-		if f != nil {
-			parsed[o] = f
-			o++
-		}
-	}
-	parsed = parsed[:o]
-
-	o = 0
-	for _, err := range errors {
-		if err != nil {
-			errors[o] = err
-			o++
-		}
-	}
-	errors = errors[:o]
-
-	return parsed, errors
-}
-
-// sameFile returns true if x and y have the same basename and denote
-// the same file.
-//
-func sameFile(x, y string) bool {
-	if x == y {
-		// It could be the case that y doesn't exist.
-		// For instance, it may be an overlay file that
-		// hasn't been written to disk. To handle that case
-		// let x == y through. (We added the exact absolute path
-		// string to the CompiledGoFiles list, so the unwritten
-		// overlay case implies x==y.)
-		return true
-	}
-	if strings.EqualFold(filepath.Base(x), filepath.Base(y)) { // (optimisation)
-		if xi, err := os.Stat(x); err == nil {
-			if yi, err := os.Stat(y); err == nil {
-				return os.SameFile(xi, yi)
-			}
-		}
-	}
-	return false
-}
diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go
new file mode 100644
index 0000000..d42b7aa
--- /dev/null
+++ b/internal/lsp/cache/parse.go
@@ -0,0 +1,269 @@
+package cache
+
+import (
+	"context"
+	"fmt"
+	"go/ast"
+	"go/parser"
+	"go/scanner"
+	"go/token"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+
+	"golang.org/x/tools/internal/span"
+)
+
+// We use a counting semaphore to limit
+// the number of parallel I/O calls per process.
+var ioLimit = make(chan bool, 20)
+
+// parseFiles reads and parses the Go source files and returns the ASTs
+// of the ones that could be at least partially parsed, along with a
+// list of I/O and parse errors encountered.
+//
+// Because files are scanned in parallel, the token.Pos
+// positions of the resulting ast.Files are not ordered.
+//
+func (imp *importer) parseFiles(filenames []string) ([]*ast.File, []error) {
+	var wg sync.WaitGroup
+	n := len(filenames)
+	parsed := make([]*ast.File, n)
+	errors := make([]error, n)
+	for i, filename := range filenames {
+		if imp.view.Config.Context.Err() != nil {
+			parsed[i] = nil
+			errors[i] = imp.view.Config.Context.Err()
+			continue
+		}
+
+		// First, check if we have already cached an AST for this file.
+		f, err := imp.view.findFile(span.FileURI(filename))
+		if err != nil {
+			parsed[i], errors[i] = nil, err
+		}
+		var fAST *ast.File
+		if f != nil {
+			fAST = f.ast
+		}
+
+		wg.Add(1)
+		go func(i int, filename string) {
+			ioLimit <- true // wait
+
+			if fAST != nil {
+				parsed[i], errors[i] = fAST, nil
+			} else {
+				// We don't have a cached AST for this file.
+				var src []byte
+				// Check for an available overlay.
+				for f, contents := range imp.view.Config.Overlay {
+					if sameFile(f, filename) {
+						src = contents
+					}
+				}
+				var err error
+				// We don't have an overlay, so we must read the file's contents.
+				if src == nil {
+					src, err = ioutil.ReadFile(filename)
+				}
+				if err != nil {
+					parsed[i], errors[i] = nil, err
+				} else {
+					// ParseFile may return both an AST and an error.
+					parsed[i], errors[i] = imp.view.Config.ParseFile(imp.view.Config.Fset, filename, src)
+
+					// Fix any badly parsed parts of the AST.
+					if file := parsed[i]; file != nil {
+						tok := imp.view.Config.Fset.File(file.Pos())
+						imp.view.fix(imp.ctx, parsed[i], tok, src)
+					}
+				}
+			}
+
+			<-ioLimit // signal
+			wg.Done()
+		}(i, filename)
+	}
+	wg.Wait()
+
+	// Eliminate nils, preserving order.
+	var o int
+	for _, f := range parsed {
+		if f != nil {
+			parsed[o] = f
+			o++
+		}
+	}
+	parsed = parsed[:o]
+
+	o = 0
+	for _, err := range errors {
+		if err != nil {
+			errors[o] = err
+			o++
+		}
+	}
+	errors = errors[:o]
+
+	return parsed, errors
+}
+
+// sameFile returns true if x and y have the same basename and denote
+// the same file.
+//
+func sameFile(x, y string) bool {
+	if x == y {
+		// It could be the case that y doesn't exist.
+		// For instance, it may be an overlay file that
+		// hasn't been written to disk. To handle that case
+		// let x == y through. (We added the exact absolute path
+		// string to the CompiledGoFiles list, so the unwritten
+		// overlay case implies x==y.)
+		return true
+	}
+	if strings.EqualFold(filepath.Base(x), filepath.Base(y)) { // (optimisation)
+		if xi, err := os.Stat(x); err == nil {
+			if yi, err := os.Stat(y); err == nil {
+				return os.SameFile(xi, yi)
+			}
+		}
+	}
+	return false
+}
+
+// fix inspects and potentially modifies any *ast.BadStmts or *ast.BadExprs in the AST.
+
+// We attempt to modify the AST such that we can type-check it more effectively.
+func (v *View) fix(ctx context.Context, file *ast.File, tok *token.File, src []byte) {
+	var parent ast.Node
+	ast.Inspect(file, func(n ast.Node) bool {
+		if n == nil {
+			return false
+		}
+		switch n := n.(type) {
+		case *ast.BadStmt:
+			if err := v.parseDeferOrGoStmt(n, parent, tok, src); err != nil {
+				v.log.Debugf(ctx, "unable to parse defer or go from *ast.BadStmt: %v", err)
+			}
+			return false
+		default:
+			parent = n
+			return true
+		}
+	})
+}
+
+// parseDeferOrGoStmt tries to parse an *ast.BadStmt into a defer or a go statement.
+//
+// go/parser packages a statement of the form "defer x." as an *ast.BadStmt because
+// it does not include a call expression. This means that go/types skips type-checking
+// this statement entirely, and we can't use the type information when completing.
+// Here, we try to generate a fake *ast.DeferStmt or *ast.GoStmt to put into the AST,
+// instead of the *ast.BadStmt.
+func (v *View) parseDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src []byte) error {
+	// Check if we have a bad statement containing either a "go" or "defer".
+	s := &scanner.Scanner{}
+	s.Init(tok, src, nil, 0)
+
+	var pos token.Pos
+	var tkn token.Token
+	var lit string
+	for {
+		if tkn == token.EOF {
+			return fmt.Errorf("reached the end of the file")
+		}
+		if pos >= bad.From {
+			break
+		}
+		pos, tkn, lit = s.Scan()
+	}
+	var stmt ast.Stmt
+	switch lit {
+	case "defer":
+		stmt = &ast.DeferStmt{
+			Defer: pos,
+		}
+	case "go":
+		stmt = &ast.GoStmt{
+			Go: pos,
+		}
+	default:
+		return fmt.Errorf("no defer or go statement found")
+	}
+
+	// The expression after the "defer" or "go" starts at this position.
+	from, _, _ := s.Scan()
+	var to, curr token.Pos
+FindTo:
+	for {
+		curr, tkn, lit = s.Scan()
+		// TODO(rstambler): This still needs more handling to work correctly.
+		// We encounter a specific issue with code that looks like this:
+		//
+		//      defer x.<>
+		//      y := 1
+		//
+		// In this scenario, we parse it as "defer x.y", which then fails to
+		// type-check, and we don't get completions as expected.
+		switch tkn {
+		case token.COMMENT, token.EOF, token.SEMICOLON, token.DEFINE:
+			break FindTo
+		}
+		// to is the end of expression that should become the Fun part of the call.
+		to = curr
+	}
+	if !from.IsValid() || tok.Offset(from) >= len(src) {
+		return fmt.Errorf("invalid from position")
+	}
+	if !to.IsValid() || tok.Offset(to)+1 >= len(src) {
+		return fmt.Errorf("invalid to position")
+	}
+	exprstr := string(src[tok.Offset(from) : tok.Offset(to)+1])
+	expr, err := parser.ParseExpr(exprstr)
+	if expr == nil {
+		return fmt.Errorf("no expr in %s: %v", exprstr, err)
+	}
+	// parser.ParseExpr returns undefined positions.
+	// Adjust them for the current file.
+	v.offsetPositions(expr, from-1)
+
+	// Package the expression into a fake *ast.CallExpr and re-insert into the function.
+	call := &ast.CallExpr{
+		Fun:    expr,
+		Lparen: to,
+		Rparen: to,
+	}
+	switch stmt := stmt.(type) {
+	case *ast.DeferStmt:
+		stmt.Call = call
+	case *ast.GoStmt:
+		stmt.Call = call
+	}
+	switch parent := parent.(type) {
+	case *ast.BlockStmt:
+		for i, s := range parent.List {
+			if s == bad {
+				parent.List[i] = stmt
+				break
+			}
+		}
+	}
+	return nil
+}
+
+// offsetPositions applies an offset to the positions in an ast.Node.
+// TODO(rstambler): Add more cases here as they become necessary.
+func (v *View) offsetPositions(expr ast.Expr, offset token.Pos) {
+	ast.Inspect(expr, func(n ast.Node) bool {
+		switch n := n.(type) {
+		case *ast.Ident:
+			n.NamePos += offset
+			return false
+		default:
+			return true
+		}
+	})
+}
diff --git a/internal/lsp/testdata/badstmt/badstmt.go b/internal/lsp/testdata/badstmt/badstmt.go
new file mode 100644
index 0000000..2ff6e09
--- /dev/null
+++ b/internal/lsp/testdata/badstmt/badstmt.go
@@ -0,0 +1,10 @@
+package badstmt
+
+import (
+	"golang.org/x/tools/internal/lsp/foo"
+)
+
+func _() {
+	defer foo.F //@complete("F", Foo, IntFoo, StructFoo),diag(" //", "LSP", "function must be invoked in defer statement")
+	go foo.F //@complete("F", Foo, IntFoo, StructFoo)
+}
\ No newline at end of file
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 046c9ee..7948b42 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -27,8 +27,8 @@
 // We hardcode the expected number of test cases to ensure that all tests
 // are being executed. If a test is added, this number must be changed.
 const (
-	ExpectedCompletionsCount     = 82
-	ExpectedDiagnosticsCount     = 16
+	ExpectedCompletionsCount     = 84
+	ExpectedDiagnosticsCount     = 17
 	ExpectedFormatCount          = 4
 	ExpectedDefinitionsCount     = 21
 	ExpectedTypeDefinitionsCount = 2