internal/lsp: fix deadlock in type-checking

There was a situation where we were trying to re-acquire a lock that was
already held. This change solves this issue.

Change-Id: I97cf6bad7e7c219a267e3ca5d174a2573f70ebe2
Reviewed-on: https://go-review.googlesource.com/c/tools/+/184217
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go
index b56be5f..1ea134f 100644
--- a/internal/lsp/cache/check.go
+++ b/internal/lsp/cache/check.go
@@ -119,7 +119,7 @@
 	)
 	for _, filename := range meta.files {
 		uri := span.FileURI(filename)
-		f, err := imp.view.getFile(uri)
+		f, err := imp.view.getFile(imp.ctx, uri)
 		if err != nil {
 			continue
 		}
@@ -187,25 +187,25 @@
 	check.Files(pkg.GetSyntax())
 
 	// Add every file in this package to our cache.
-	imp.cachePackage(imp.ctx, pkg, meta, mode)
+	if err := imp.cachePackage(imp.ctx, pkg, meta, mode); err != nil {
+		return nil, err
+	}
 
 	return pkg, nil
 }
 
-func (imp *importer) cachePackage(ctx context.Context, pkg *pkg, meta *metadata, mode source.ParseMode) {
+func (imp *importer) cachePackage(ctx context.Context, pkg *pkg, meta *metadata, mode source.ParseMode) error {
 	for _, file := range pkg.files {
-		f, err := imp.view.getFile(file.uri)
+		f, err := imp.view.getFile(ctx, file.uri)
 		if err != nil {
-			imp.view.session.log.Errorf(ctx, "no file: %v", err)
-			continue
+			return fmt.Errorf("no such file %s: %v", file.uri, err)
 		}
 		gof, ok := f.(*goFile)
 		if !ok {
-			imp.view.session.log.Errorf(ctx, "%v is not a Go file", file.uri)
-			continue
+			return fmt.Errorf("non Go file %s", file.uri)
 		}
 		if err := imp.cachePerFile(gof, file, pkg); err != nil {
-			imp.view.session.log.Errorf(ctx, "failed to cache file %s: %v", gof.URI(), err)
+			return fmt.Errorf("failed to cache file %s: %v", gof.URI(), err)
 		}
 	}
 
@@ -219,6 +219,8 @@
 		}
 		pkg.imports[importPkg.pkgPath] = importPkg
 	}
+
+	return nil
 }
 
 func (imp *importer) cachePerFile(gof *goFile, file *astFile, p *pkg) error {
diff --git a/internal/lsp/cache/gofile.go b/internal/lsp/cache/gofile.go
index 0351aa6..23ecac6 100644
--- a/internal/lsp/cache/gofile.go
+++ b/internal/lsp/cache/gofile.go
@@ -226,7 +226,7 @@
 	}
 	for _, filename := range m.files {
 		uri := span.FileURI(filename)
-		if f, err := v.getFile(uri); err == nil && v.session.IsOpen(uri) {
+		if f, err := v.getFile(ctx, uri); err == nil && v.session.IsOpen(uri) {
 			results[f.(*goFile)] = struct{}{}
 		}
 	}
diff --git a/internal/lsp/cache/load.go b/internal/lsp/cache/load.go
index 2d707ed..73aeff6 100644
--- a/internal/lsp/cache/load.go
+++ b/internal/lsp/cache/load.go
@@ -21,7 +21,7 @@
 	// don't ignore function bodies.
 	if f.astIsTrimmed() {
 		v.pcache.mu.Lock()
-		f.invalidateAST()
+		f.invalidateAST(ctx)
 		v.pcache.mu.Unlock()
 	}
 
@@ -75,25 +75,20 @@
 // checkMetadata determines if we should run go/packages.Load for this file.
 // If yes, update the metadata for the file and its package.
 func (v *view) checkMetadata(ctx context.Context, f *goFile) (map[packageID]*metadata, []packages.Error, error) {
-	f.mu.Lock()
-	defer f.mu.Unlock()
-
-	if !v.parseImports(ctx, f) {
+	filename, ok := v.runGopackages(ctx, f)
+	if !ok {
 		return f.meta, nil, nil
 	}
 
-	// Reset the file's metadata and type information if we are re-running `go list`.
-	for k := range f.meta {
-		delete(f.meta, k)
-	}
-	for k := range f.pkgs {
-		delete(f.pkgs, k)
+	// Check if the context has been canceled before calling packages.Load.
+	if ctx.Err() != nil {
+		return nil, nil, ctx.Err()
 	}
 
-	pkgs, err := packages.Load(v.buildConfig(), fmt.Sprintf("file=%s", f.filename()))
+	pkgs, err := packages.Load(v.buildConfig(), fmt.Sprintf("file=%s", filename))
 	if len(pkgs) == 0 {
 		if err == nil {
-			err = fmt.Errorf("go/packages.Load: no packages found for %s", f.filename())
+			err = fmt.Errorf("go/packages.Load: no packages found for %s", filename)
 		}
 		// Return this error as a diagnostic to the user.
 		return nil, []packages.Error{
@@ -103,7 +98,6 @@
 			},
 		}, err
 	}
-
 	// Track missing imports as we look at the package's errors.
 	missingImports := make(map[packagePath]struct{})
 	for _, pkg := range pkgs {
@@ -119,52 +113,87 @@
 			}
 		}
 		// Build the import graph for this package.
-		v.link(ctx, packagePath(pkg.PkgPath), pkg, nil)
+		if err := v.link(ctx, packagePath(pkg.PkgPath), pkg, nil); err != nil {
+			return nil, nil, err
+		}
 	}
+	m, err := validateMetadata(ctx, missingImports, f)
+	if err != nil {
+		return nil, nil, err
+	}
+	return m, nil, nil
+}
+
+func validateMetadata(ctx context.Context, missingImports map[packagePath]struct{}, f *goFile) (map[packageID]*metadata, error) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
 
 	// If `go list` failed to get data for the file in question (this should never happen).
 	if len(f.meta) == 0 {
-		return nil, nil, fmt.Errorf("loadParseTypecheck: no metadata found for %v", f.filename())
+		return nil, fmt.Errorf("loadParseTypecheck: no metadata found for %v", f.filename())
 	}
 
 	// If we have already seen these missing imports before, and we have type information,
 	// there is no need to continue.
 	if sameSet(missingImports, f.missingImports) && len(f.pkgs) != 0 {
-		return nil, nil, nil
+		return nil, nil
 	}
 	// Otherwise, update the missing imports map.
 	f.missingImports = missingImports
-
-	return f.meta, nil, nil
+	return f.meta, nil
 }
 
 // reparseImports reparses a file's package and import declarations to
 // determine if they have changed.
-func (v *view) parseImports(ctx context.Context, f *goFile) bool {
+func (v *view) runGopackages(ctx context.Context, f *goFile) (filename string, result bool) {
+	f.mu.Lock()
+	defer func() {
+		// Clear metadata if we are intending to re-run go/packages.
+		if result {
+			// Reset the file's metadata and type information if we are re-running `go list`.
+			for k := range f.meta {
+				delete(f.meta, k)
+			}
+			for k := range f.pkgs {
+				delete(f.pkgs, k)
+			}
+		}
+
+		defer f.mu.Unlock()
+	}()
+
 	if len(f.meta) == 0 || len(f.missingImports) > 0 {
-		return true
+		return f.filename(), true
 	}
 	// Get file content in case we don't already have it.
 	parsed, _ := v.session.cache.ParseGoHandle(f.Handle(ctx), source.ParseHeader).Parse(ctx)
 	if parsed == nil {
-		return true
+		return f.filename(), true
 	}
-	// TODO: Add support for re-running `go list` when the package name changes.
-
-	// If the package's imports have changed, re-run `go list`.
-	if len(f.imports) != len(parsed.Imports) {
-		return true
-	}
-
-	for i, importSpec := range f.imports {
-		if importSpec.Path.Value != parsed.Imports[i].Path.Value {
-			return true
+	// Check if the package's name has changed, by checking if this is a filename
+	// we already know about, and if so, check if its package name has changed.
+	for _, m := range f.meta {
+		for _, filename := range m.files {
+			if filename == f.URI().Filename() {
+				if m.name != parsed.Name.Name {
+					return f.filename(), true
+				}
+			}
 		}
 	}
-	return false
+	// If the package's imports have changed, re-run `go list`.
+	if len(f.imports) != len(parsed.Imports) {
+		return f.filename(), true
+	}
+	for i, importSpec := range f.imports {
+		if importSpec.Path.Value != parsed.Imports[i].Path.Value {
+			return f.filename(), true
+		}
+	}
+	return f.filename(), false
 }
 
-func (v *view) link(ctx context.Context, pkgPath packagePath, pkg *packages.Package, parent *metadata) *metadata {
+func (v *view) link(ctx context.Context, pkgPath packagePath, pkg *packages.Package, parent *metadata) error {
 	id := packageID(pkg.ID)
 	m, ok := v.mcache.packages[id]
 
@@ -172,7 +201,7 @@
 	// so relevant packages get parsed and type-checked again.
 	if ok && !filenamesIdentical(m.files, pkg.CompiledGoFiles) {
 		v.pcache.mu.Lock()
-		v.remove(id, make(map[packageID]struct{}))
+		v.remove(ctx, id, make(map[packageID]struct{}))
 		v.pcache.mu.Unlock()
 	}
 
@@ -192,16 +221,18 @@
 	m.name = pkg.Name
 	m.files = pkg.CompiledGoFiles
 	for _, filename := range m.files {
-		if f, _ := v.getFile(span.FileURI(filename)); f != nil {
-			if gof, ok := f.(*goFile); ok {
-				if gof.meta == nil {
-					gof.meta = make(map[packageID]*metadata)
-				}
-				gof.meta[m.id] = m
-			} else {
-				v.Session().Logger().Errorf(ctx, "not a Go file: %s", f.URI())
-			}
+		f, err := v.getFile(ctx, span.FileURI(filename))
+		if err != nil {
+			return err
 		}
+		gof, ok := f.(*goFile)
+		if !ok {
+			return fmt.Errorf("not a Go file: %s", f.URI())
+		}
+		if gof.meta == nil {
+			gof.meta = make(map[packageID]*metadata)
+		}
+		gof.meta[m.id] = m
 	}
 	// Connect the import graph.
 	if parent != nil {
@@ -209,8 +240,15 @@
 		parent.children[id] = true
 	}
 	for importPath, importPkg := range pkg.Imports {
+		importPkgPath := packagePath(importPath)
+		if importPkgPath == pkgPath {
+			v.session.log.Errorf(ctx, "cycle detected in %s", importPath)
+			return nil
+		}
 		if _, ok := m.children[packageID(importPkg.ID)]; !ok {
-			v.link(ctx, packagePath(importPath), importPkg, m)
+			if err := v.link(ctx, importPkgPath, importPkg, m); err != nil {
+				return err
+			}
 		}
 	}
 	// Clear out any imports that have been removed.
@@ -226,7 +264,7 @@
 		delete(m.children, importID)
 		delete(child.parents, id)
 	}
-	return m
+	return nil
 }
 
 // filenamesIdentical reports whether two sets of file names are identical.
diff --git a/internal/lsp/cache/view.go b/internal/lsp/cache/view.go
index ec848a0..a56dcbb 100644
--- a/internal/lsp/cache/view.go
+++ b/internal/lsp/cache/view.go
@@ -224,7 +224,7 @@
 
 // invalidateContent invalidates the content of a Go file,
 // including any position and type information that depends on it.
-func (f *goFile) invalidateContent() {
+func (f *goFile) invalidateContent(ctx context.Context) {
 	f.handleMu.Lock()
 	defer f.handleMu.Unlock()
 
@@ -234,13 +234,13 @@
 	f.view.pcache.mu.Lock()
 	defer f.view.pcache.mu.Unlock()
 
-	f.invalidateAST()
+	f.invalidateAST(ctx)
 	f.handle = nil
 }
 
 // invalidateAST invalidates the AST of a Go file,
 // including any position and type information that depends on it.
-func (f *goFile) invalidateAST() {
+func (f *goFile) invalidateAST(ctx context.Context) {
 	f.mu.Lock()
 	f.ast = nil
 	f.token = nil
@@ -250,7 +250,7 @@
 	// Remove the package and all of its reverse dependencies from the cache.
 	for id, pkg := range pkgs {
 		if pkg != nil {
-			f.view.remove(id, map[packageID]struct{}{})
+			f.view.remove(ctx, id, map[packageID]struct{}{})
 		}
 	}
 }
@@ -258,7 +258,7 @@
 // remove invalidates a package and its reverse dependencies in the view's
 // package cache. It is assumed that the caller has locked both the mutexes
 // of both the mcache and the pcache.
-func (v *view) remove(id packageID, seen map[packageID]struct{}) {
+func (v *view) remove(ctx context.Context, id packageID, seen map[packageID]struct{}) {
 	if _, ok := seen[id]; ok {
 		return
 	}
@@ -268,20 +268,27 @@
 	}
 	seen[id] = struct{}{}
 	for parentID := range m.parents {
-		v.remove(parentID, seen)
+		v.remove(ctx, parentID, seen)
 	}
 	// All of the files in the package may also be holding a pointer to the
 	// invalidated package.
 	for _, filename := range m.files {
-		if f, _ := v.findFile(span.FileURI(filename)); f != nil {
-			if gof, ok := f.(*goFile); ok {
-				gof.mu.Lock()
-				delete(gof.pkgs, id)
-				gof.mu.Unlock()
-			}
+		f, err := v.findFile(span.FileURI(filename))
+		if err != nil {
+			v.session.log.Errorf(ctx, "cannot find file %s: %v", f.URI(), err)
+			continue
 		}
+		gof, ok := f.(*goFile)
+		if !ok {
+			v.session.log.Errorf(ctx, "non-Go file %v", f.URI())
+			continue
+		}
+		gof.mu.Lock()
+		delete(gof.pkgs, id)
+		gof.mu.Unlock()
 	}
 	delete(v.pcache.packages, id)
+	return
 }
 
 // FindFile returns the file if the given URI is already a part of the view.
@@ -301,11 +308,11 @@
 	v.mu.Lock()
 	defer v.mu.Unlock()
 
-	return v.getFile(uri)
+	return v.getFile(ctx, uri)
 }
 
 // getFile is the unlocked internal implementation of GetFile.
-func (v *view) getFile(uri span.URI) (viewFile, error) {
+func (v *view) getFile(ctx context.Context, uri span.URI) (viewFile, error) {
 	if f, err := v.findFile(uri); err != nil {
 		return nil, err
 	} else if f != nil {
@@ -344,7 +351,7 @@
 			if !ok {
 				return
 			}
-			gof.invalidateContent()
+			gof.invalidateContent(ctx)
 		})
 	}
 	v.mapFile(uri, f)
diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go
index 56be5e0..3a308a7 100644
--- a/internal/lsp/source/format.go
+++ b/internal/lsp/source/format.go
@@ -25,7 +25,7 @@
 		return nil, fmt.Errorf("no AST for %s", f.URI())
 	}
 	pkg := f.GetPackage(ctx)
-	if hasParseErrors(pkg.GetErrors()) {
+	if hasListErrors(pkg.GetErrors()) || hasParseErrors(pkg.GetErrors()) {
 		return nil, fmt.Errorf("%s has parse errors, not formatting", f.URI())
 	}
 	path, exact := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
@@ -45,6 +45,23 @@
 	return computeTextEdits(ctx, f, buf.String()), nil
 }
 
+// Imports formats a file using the goimports tool.
+func Imports(ctx context.Context, f GoFile, rng span.Range) ([]TextEdit, error) {
+	data, _, err := f.Handle(ctx).Read(ctx)
+	if err != nil {
+		return nil, err
+	}
+	pkg := f.GetPackage(ctx)
+	if hasListErrors(pkg.GetErrors()) {
+		return nil, fmt.Errorf("%s has list errors, not running goimports", f.URI())
+	}
+	formatted, err := imports.Process(f.URI().Filename(), data, nil)
+	if err != nil {
+		return nil, err
+	}
+	return computeTextEdits(ctx, f, string(formatted)), nil
+}
+
 func hasParseErrors(errors []packages.Error) bool {
 	for _, err := range errors {
 		if err.Kind == packages.ParseError {
@@ -54,21 +71,13 @@
 	return false
 }
 
-// Imports formats a file using the goimports tool.
-func Imports(ctx context.Context, f GoFile, rng span.Range) ([]TextEdit, error) {
-	data, _, err := f.Handle(ctx).Read(ctx)
-	if err != nil {
-		return nil, err
+func hasListErrors(errors []packages.Error) bool {
+	for _, err := range errors {
+		if err.Kind == packages.ListError {
+			return true
+		}
 	}
-	tok := f.GetToken(ctx)
-	if tok == nil {
-		return nil, fmt.Errorf("no token file for %s", f.URI())
-	}
-	formatted, err := imports.Process(tok.Name(), data, nil)
-	if err != nil {
-		return nil, err
-	}
-	return computeTextEdits(ctx, f, string(formatted)), nil
+	return false
 }
 
 func computeTextEdits(ctx context.Context, file File, formatted string) (edits []TextEdit) {