internal/lsp: add highlighting for import statement

This change adds highlights for imports when the cursor is over the use of that import. It also adds it for the opposite direction when the cursor is on the import, it will highlight uses of that import.

Fixes golang/go#36590

Change-Id: Ifd04d81ec9b4fdf2be1b763f31b44d0ef7d92f47
Reviewed-on: https://go-review.googlesource.com/c/tools/+/215258
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
Run-TryBot: Rohan Challa <rohan@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/lsp/source/highlight.go b/internal/lsp/source/highlight.go
index b453157..07a2ac6 100644
--- a/internal/lsp/source/highlight.go
+++ b/internal/lsp/source/highlight.go
@@ -9,6 +9,8 @@
 	"fmt"
 	"go/ast"
 	"go/token"
+	"go/types"
+	"strings"
 
 	"golang.org/x/tools/go/ast/astutil"
 	"golang.org/x/tools/internal/lsp/protocol"
@@ -54,18 +56,25 @@
 	}
 
 	switch path[0].(type) {
-	case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType, *ast.BasicLit:
-		return highlightFuncControlFlow(ctx, snapshot, pkg, path)
+	case *ast.BasicLit:
+		if len(path) > 1 {
+			if _, ok := path[1].(*ast.ImportSpec); ok {
+				return highlightImportUses(ctx, snapshot.View(), pkg, path)
+			}
+		}
+		return highlightFuncControlFlow(ctx, snapshot.View(), pkg, path)
+	case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType:
+		return highlightFuncControlFlow(ctx, snapshot.View(), pkg, path)
 	case *ast.Ident:
-		return highlightIdentifiers(ctx, snapshot, pkg, path)
+		return highlightIdentifiers(ctx, snapshot.View(), pkg, path)
 	case *ast.BranchStmt, *ast.ForStmt, *ast.RangeStmt:
-		return highlightLoopControlFlow(ctx, snapshot, pkg, path)
+		return highlightLoopControlFlow(ctx, snapshot.View(), pkg, path)
 	}
 	// If the cursor is in an unidentified area, return empty results.
 	return nil, nil
 }
 
-func highlightFuncControlFlow(ctx context.Context, snapshot Snapshot, pkg Package, path []ast.Node) ([]protocol.Range, error) {
+func highlightFuncControlFlow(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) {
 	var enclosingFunc ast.Node
 	var returnStmt *ast.ReturnStmt
 	var resultsList *ast.FieldList
@@ -137,7 +146,7 @@
 	result := make(map[protocol.Range]bool)
 	// Highlight the correct argument in the function declaration return types.
 	if resultsList != nil && -1 < index && index < len(resultsList.List) {
-		rng, err := nodeToProtocolRange(snapshot.View(), pkg, resultsList.List[index])
+		rng, err := nodeToProtocolRange(view, pkg, resultsList.List[index])
 		if err != nil {
 			log.Error(ctx, "Error getting range for node", err)
 		} else {
@@ -146,7 +155,7 @@
 	}
 	// Add the "func" part of the func declaration.
 	if highlightAllReturnsAndFunc {
-		funcStmt, err := posToMappedRange(snapshot.View(), pkg, enclosingFunc.Pos(), enclosingFunc.Pos()+token.Pos(len("func")))
+		funcStmt, err := posToMappedRange(view, pkg, enclosingFunc.Pos(), enclosingFunc.Pos()+token.Pos(len("func")))
 		if err != nil {
 			return nil, err
 		}
@@ -174,7 +183,7 @@
 				toAdd = n.Results[index]
 			}
 			if toAdd != nil {
-				rng, err := nodeToProtocolRange(snapshot.View(), pkg, toAdd)
+				rng, err := nodeToProtocolRange(view, pkg, toAdd)
 				if err != nil {
 					log.Error(ctx, "Error getting range for node", err)
 				} else {
@@ -188,7 +197,7 @@
 	return rangeMapToSlice(result), nil
 }
 
-func highlightLoopControlFlow(ctx context.Context, snapshot Snapshot, pkg Package, path []ast.Node) ([]protocol.Range, error) {
+func highlightLoopControlFlow(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) {
 	var loop ast.Node
 Outer:
 	// Reverse walk the path till we get to the for loop.
@@ -205,7 +214,7 @@
 	}
 	result := make(map[protocol.Range]bool)
 	// Add the for statement.
-	forStmt, err := posToMappedRange(snapshot.View(), pkg, loop.Pos(), loop.Pos()+token.Pos(len("for")))
+	forStmt, err := posToMappedRange(view, pkg, loop.Pos(), loop.Pos()+token.Pos(len("for")))
 	if err != nil {
 		return nil, err
 	}
@@ -223,7 +232,7 @@
 		}
 		// Add all branch statements in same scope as the identified one.
 		if n, ok := n.(*ast.BranchStmt); ok {
-			rng, err := nodeToProtocolRange(snapshot.View(), pkg, n)
+			rng, err := nodeToProtocolRange(view, pkg, n)
 			if err != nil {
 				log.Error(ctx, "Error getting range for node", err)
 				return false
@@ -235,14 +244,49 @@
 	return rangeMapToSlice(result), nil
 }
 
-func highlightIdentifiers(ctx context.Context, snapshot Snapshot, pkg Package, path []ast.Node) ([]protocol.Range, error) {
+func highlightImportUses(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) {
+	result := make(map[protocol.Range]bool)
+	basicLit, ok := path[0].(*ast.BasicLit)
+	if !ok {
+		return nil, errors.Errorf("highlightImportUses called with an ast.Node of type %T", basicLit)
+	}
+
+	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
+		if imp, ok := node.(*ast.ImportSpec); ok && imp.Path == basicLit {
+			if rng, err := nodeToProtocolRange(view, pkg, node); err == nil {
+				result[rng] = true
+				return false
+			}
+		}
+		n, ok := node.(*ast.Ident)
+		if !ok {
+			return true
+		}
+		obj, ok := pkg.GetTypesInfo().ObjectOf(n).(*types.PkgName)
+		if !ok {
+			return true
+		}
+		if !strings.Contains(basicLit.Value, obj.Name()) {
+			return true
+		}
+		if rng, err := nodeToProtocolRange(view, pkg, n); err == nil {
+			result[rng] = true
+		} else {
+			log.Error(ctx, "Error getting range for node", err)
+		}
+		return false
+	})
+	return rangeMapToSlice(result), nil
+}
+
+func highlightIdentifiers(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) {
 	result := make(map[protocol.Range]bool)
 	id, ok := path[0].(*ast.Ident)
 	if !ok {
 		return nil, errors.Errorf("highlightIdentifiers called with an ast.Node of type %T", id)
 	}
 	// Check if ident is inside return or func decl.
-	if toAdd, err := highlightFuncControlFlow(ctx, snapshot, pkg, path); toAdd != nil && err == nil {
+	if toAdd, err := highlightFuncControlFlow(ctx, view, pkg, path); toAdd != nil && err == nil {
 		for _, r := range toAdd {
 			result[r] = true
 		}
@@ -251,7 +295,13 @@
 	// TODO: maybe check if ident is a reserved word, if true then don't continue and return results.
 
 	idObj := pkg.GetTypesInfo().ObjectOf(id)
+	pkgObj, isImported := idObj.(*types.PkgName)
 	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
+		if imp, ok := node.(*ast.ImportSpec); ok && isImported {
+			if rng, err := highlightImport(view, pkg, pkgObj, imp); rng != nil && err == nil {
+				result[*rng] = true
+			}
+		}
 		n, ok := node.(*ast.Ident)
 		if !ok {
 			return true
@@ -262,7 +312,7 @@
 		if nObj := pkg.GetTypesInfo().ObjectOf(n); nObj != idObj {
 			return false
 		}
-		if rng, err := nodeToProtocolRange(snapshot.View(), pkg, n); err == nil {
+		if rng, err := nodeToProtocolRange(view, pkg, n); err == nil {
 			result[rng] = true
 		} else {
 			log.Error(ctx, "Error getting range for node", err)
@@ -272,6 +322,20 @@
 	return rangeMapToSlice(result), nil
 }
 
+func highlightImport(view View, pkg Package, obj *types.PkgName, imp *ast.ImportSpec) (*protocol.Range, error) {
+	if imp.Name != nil || imp.Path == nil {
+		return nil, nil
+	}
+	if !strings.Contains(imp.Path.Value, obj.Name()) {
+		return nil, nil
+	}
+	rng, err := nodeToProtocolRange(view, pkg, imp.Path)
+	if err != nil {
+		return nil, err
+	}
+	return &rng, nil
+}
+
 func rangeMapToSlice(rangeMap map[protocol.Range]bool) []protocol.Range {
 	var list []protocol.Range
 	for i := range rangeMap {
diff --git a/internal/lsp/testdata/highlights/highlights.go b/internal/lsp/testdata/highlights/highlights.go
index c8c37d2..de67efe 100644
--- a/internal/lsp/testdata/highlights/highlights.go
+++ b/internal/lsp/testdata/highlights/highlights.go
@@ -1,7 +1,8 @@
 package highlights
 
 import (
-	"fmt"
+	"fmt"         //@mark(fmtImp, "\"fmt\""),highlight(fmtImp, fmtImp, fmt1, fmt2, fmt3, fmt4)
+	h2 "net/http" //@mark(hImp, "h2"),highlight(hImp, hImp, hUse)
 	"sort"
 
 	"golang.org/x/tools/internal/lsp/protocol"
@@ -18,8 +19,10 @@
 var foo = F{bar: 52} //@mark(fooDeclaration, "foo"),mark(bar2, "bar"),highlight(fooDeclaration, fooDeclaration, fooUse),highlight(bar2, barDeclaration, bar1, bar2, bar3)
 
 func Print() { //@mark(printFunc, "Print"),highlight(printFunc, printFunc, printTest)
-	fmt.Println(foo) //@mark(fooUse, "foo"),highlight(fooUse, fooDeclaration, fooUse)
-	fmt.Print("yo")  //@mark(printSep, "Print"),highlight(printSep, printSep, print1, print2)
+	_ = h2.Client{} //@mark(hUse, "h2"),highlight(hUse, hImp, hUse)
+
+	fmt.Println(foo) //@mark(fooUse, "foo"),highlight(fooUse, fooDeclaration, fooUse),mark(fmt1, "fmt"),highlight(fmt1, fmtImp, fmt1, fmt2, fmt3, fmt4)
+	fmt.Print("yo")  //@mark(printSep, "Print"),highlight(printSep, printSep, print1, print2),mark(fmt2, "fmt"),highlight(fmt2, fmtImp, fmt1, fmt2, fmt3, fmt4)
 }
 
 func (x *F) Inc() { //@mark(xRightDecl, "x"),mark(xLeftDecl, " *"),highlight(xRightDecl, xRightDecl, xUse),highlight(xLeftDecl, xRightDecl, xUse)
@@ -27,8 +30,8 @@
 }
 
 func testFunctions() {
-	fmt.Print("main start") //@mark(print1, "Print"),highlight(print1, printSep, print1, print2)
-	fmt.Print("ok")         //@mark(print2, "Print"),highlight(print2, printSep, print1, print2)
+	fmt.Print("main start") //@mark(print1, "Print"),highlight(print1, printSep, print1, print2),mark(fmt3, "fmt"),highlight(fmt3, fmtImp, fmt1, fmt2, fmt3, fmt4)
+	fmt.Print("ok")         //@mark(print2, "Print"),highlight(print2, printSep, print1, print2),mark(fmt4, "fmt"),highlight(fmt4, fmtImp, fmt1, fmt2, fmt3, fmt4)
 	Print()                 //@mark(printTest, "Print"),highlight(printTest, printFunc, printTest)
 }
 
diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden
index ee564e3..51c0290 100644
--- a/internal/lsp/testdata/summary.txt.golden
+++ b/internal/lsp/testdata/summary.txt.golden
@@ -13,7 +13,7 @@
 SuggestedFixCount = 1
 DefinitionsCount = 43
 TypeDefinitionsCount = 2
-HighlightsCount = 45
+HighlightsCount = 52
 ReferencesCount = 8
 RenamesCount = 22
 PrepareRenamesCount = 8