internal/lsp: refactor line folding range code

This CL removes duplicate code in lineFoldingRange function under
lsp/source/folding_range.go and generally improves code quality.

Fixes bug with composite literal folding where gopls was folding literals
with braces on the same line as end token (paranthesis/braces).

Change-Id: I742f285d866d72a243129c0aef0935fe2a1ad0dd
Reviewed-on: https://go-review.googlesource.com/c/tools/+/245757
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/lsp/source/folding_range.go b/internal/lsp/source/folding_range.go
index 1b7cea9..e374d50 100644
--- a/internal/lsp/source/folding_range.go
+++ b/internal/lsp/source/folding_range.go
@@ -9,6 +9,7 @@
 	"golang.org/x/tools/internal/lsp/protocol"
 )
 
+// FoldingRangeInfo holds range and kind info of folding for an ast.Node
 type FoldingRangeInfo struct {
 	mappedRange
 	Kind protocol.FoldingRangeKind
@@ -27,13 +28,8 @@
 	// Get folding ranges for comments separately as they are not walked by ast.Inspect.
 	ranges = append(ranges, commentsFoldingRange(fset, pgf.Mapper, pgf.File)...)
 
-	foldingFunc := foldingRange
-	if lineFoldingOnly {
-		foldingFunc = lineFoldingRange
-	}
-
 	visit := func(n ast.Node) bool {
-		rng := foldingFunc(fset, pgf.Mapper, n)
+		rng := foldingRangeFunc(fset, pgf.Mapper, n, lineFoldingOnly)
 		if rng != nil {
 			ranges = append(ranges, rng)
 		}
@@ -51,14 +47,20 @@
 	return ranges, nil
 }
 
-// foldingRange calculates the folding range for n.
-func foldingRange(fset *token.FileSet, m *protocol.ColumnMapper, n ast.Node) *FoldingRangeInfo {
+// foldingRangeFunc calculates the line folding range for ast.Node n
+func foldingRangeFunc(fset *token.FileSet, m *protocol.ColumnMapper, n ast.Node, lineFoldingOnly bool) *FoldingRangeInfo {
+	// TODO(suzmue): include trailing empty lines before the closing
+	// parenthesis/brace.
 	var kind protocol.FoldingRangeKind
 	var start, end token.Pos
 	switch n := n.(type) {
 	case *ast.BlockStmt:
-		// Fold from position of "{" to position of "}".
-		start, end = n.Lbrace+1, n.Rbrace
+		// Fold between positions of or lines between "{" and "}".
+		var startList, endList token.Pos
+		if num := len(n.List); num != 0 {
+			startList, endList = n.List[0].Pos(), n.List[num-1].End()
+		}
+		start, end = validLineFoldingRange(fset, n.Lbrace, n.Rbrace, startList, endList, lineFoldingOnly)
 	case *ast.CaseClause:
 		// Fold from position of ":" to end.
 		start, end = n.Colon+1, n.End()
@@ -69,123 +71,66 @@
 		// Fold from position of "(" to position of ")".
 		start, end = n.Lparen+1, n.Rparen
 	case *ast.FieldList:
-		// Fold from position of opening parenthesis/brace, to position of
-		// closing parenthesis/brace.
-		start, end = n.Opening+1, n.Closing
+		// Fold between positions of or lines between opening parenthesis/brace and closing parenthesis/brace.
+		var startList, endList token.Pos
+		if num := len(n.List); num != 0 {
+			startList, endList = n.List[0].Pos(), n.List[num-1].End()
+		}
+		start, end = validLineFoldingRange(fset, n.Opening, n.Closing, startList, endList, lineFoldingOnly)
 	case *ast.GenDecl:
 		// If this is an import declaration, set the kind to be protocol.Imports.
 		if n.Tok == token.IMPORT {
 			kind = protocol.Imports
 		}
-		start, end = n.Lparen+1, n.Rparen
+		// Fold between positions of or lines between "(" and ")".
+		var startSpecs, endSpecs token.Pos
+		if num := len(n.Specs); num != 0 {
+			startSpecs, endSpecs = n.Specs[0].Pos(), n.Specs[num-1].End()
+		}
+		start, end = validLineFoldingRange(fset, n.Lparen, n.Rparen, startSpecs, endSpecs, lineFoldingOnly)
 	case *ast.CompositeLit:
-		// Fold from position of "{" to position of "}".
-		start, end = n.Lbrace+1, n.Rbrace
+		// Fold between positions of or lines between "{" and "}".
+		var startElts, endElts token.Pos
+		if num := len(n.Elts); num != 0 {
+			startElts, endElts = n.Elts[0].Pos(), n.Elts[num-1].End()
+		}
+		start, end = validLineFoldingRange(fset, n.Lbrace, n.Rbrace, startElts, endElts, lineFoldingOnly)
 	}
+
+	// Check that folding positions are valid.
 	if !start.IsValid() || !end.IsValid() {
 		return nil
 	}
+	// in line folding mode, do not fold if the start and end lines are the same.
+	if lineFoldingOnly && fset.Position(start).Line == fset.Position(end).Line {
+		return nil
+	}
 	return &FoldingRangeInfo{
 		mappedRange: newMappedRange(fset, m, start, end),
 		Kind:        kind,
 	}
 }
 
-// lineFoldingRange calculates the line folding range for n.
-func lineFoldingRange(fset *token.FileSet, m *protocol.ColumnMapper, n ast.Node) *FoldingRangeInfo {
+// validLineFoldingRange returns start and end token.Pos for folding range if the range is valid.
+// returns token.NoPos otherwise, which fails token.IsValid check
+func validLineFoldingRange(fset *token.FileSet, open, close, start, end token.Pos, lineFoldingOnly bool) (token.Pos, token.Pos) {
+	if lineFoldingOnly {
+		if !open.IsValid() || !close.IsValid() {
+			return token.NoPos, token.NoPos
+		}
 
-	// TODO(suzmue): include trailing empty lines before the closing
-	// parenthesis/brace.
-	var kind protocol.FoldingRangeKind
-	var start, end token.Pos
-	switch n := n.(type) {
-	case *ast.BlockStmt:
-		// Fold lines between "{" and "}".
-		if !n.Lbrace.IsValid() || !n.Rbrace.IsValid() {
-			break
+		// Don't want to fold if the start/end is on the same line as the open/close
+		// as an example, the example below should *not* fold:
+		// var x = [2]string{"d",
+		// "e" }
+		if fset.Position(open).Line == fset.Position(start).Line ||
+			fset.Position(close).Line == fset.Position(end).Line {
+			return token.NoPos, token.NoPos
 		}
-		nStmts := len(n.List)
-		if nStmts == 0 {
-			break
-		}
-		// Don't want to fold if the start is on the same line as the brace.
-		if fset.Position(n.Lbrace).Line == fset.Position(n.List[0].Pos()).Line {
-			break
-		}
-		// Don't want to fold if the end is on the same line as the brace.
-		if fset.Position(n.Rbrace).Line == fset.Position(n.List[nStmts-1].End()).Line {
-			break
-		}
-		start, end = n.Lbrace+1, n.List[nStmts-1].End()
-	case *ast.CaseClause:
-		// Fold from position of ":" to end.
-		start, end = n.Colon+1, n.End()
-	case *ast.CommClause:
-		// Fold from position of ":" to end.
-		start, end = n.Colon+1, n.End()
-	case *ast.FieldList:
-		// Fold lines between opening parenthesis/brace and closing parenthesis/brace.
-		if !n.Opening.IsValid() || !n.Closing.IsValid() {
-			break
-		}
-		nFields := len(n.List)
-		if nFields == 0 {
-			break
-		}
-		// Don't want to fold if the start is on the same line as the parenthesis/brace.
-		if fset.Position(n.Opening).Line == fset.Position(n.List[nFields-1].End()).Line {
-			break
-		}
-		// Don't want to fold if the end is on the same line as the parenthesis/brace.
-		if fset.Position(n.Closing).Line == fset.Position(n.List[nFields-1].End()).Line {
-			break
-		}
-		start, end = n.Opening+1, n.List[nFields-1].End()
-	case *ast.GenDecl:
-		// If this is an import declaration, set the kind to be protocol.Imports.
-		if n.Tok == token.IMPORT {
-			kind = protocol.Imports
-		}
-		// Fold from position of "(" to position of ")".
-		if !n.Lparen.IsValid() || !n.Rparen.IsValid() {
-			break
-		}
-		nSpecs := len(n.Specs)
-		if nSpecs == 0 {
-			break
-		}
-		// Don't want to fold if the end is on the same line as the parenthesis/brace.
-		if fset.Position(n.Lparen).Line == fset.Position(n.Specs[0].Pos()).Line {
-			break
-		}
-		// Don't want to fold if the end is on the same line as the parenthesis/brace.
-		if fset.Position(n.Rparen).Line == fset.Position(n.Specs[nSpecs-1].End()).Line {
-			break
-		}
-		start, end = n.Lparen+1, n.Specs[nSpecs-1].End()
-	case *ast.CompositeLit:
-		// Fold lines between "{" and "}".
-		if !n.Lbrace.IsValid() || !n.Rbrace.IsValid() {
-			break
-		}
-		if len(n.Elts) == 0 {
-			break
-		}
-		start, end = n.Lbrace+1, n.Elts[len(n.Elts)-1].End()
-	}
 
-	// Check that folding positions are valid.
-	if !start.IsValid() || !end.IsValid() {
-		return nil
+		return open + 1, end
 	}
-	// Do not fold if the start and end lines are the same.
-	if fset.Position(start).Line == fset.Position(end).Line {
-		return nil
-	}
-	return &FoldingRangeInfo{
-		mappedRange: newMappedRange(fset, m, start, end),
-		Kind:        kind,
-	}
+	return open + 1, close
 }
 
 // commentsFoldingRange returns the folding ranges for all comment blocks in file.
diff --git a/internal/lsp/testdata/lsp/primarymod/folding/a.go b/internal/lsp/testdata/lsp/primarymod/folding/a.go
index e472a33..ffdc2a5 100644
--- a/internal/lsp/testdata/lsp/primarymod/folding/a.go
+++ b/internal/lsp/testdata/lsp/primarymod/folding/a.go
@@ -28,7 +28,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
diff --git a/internal/lsp/testdata/lsp/primarymod/folding/a.go.golden b/internal/lsp/testdata/lsp/primarymod/folding/a.go.golden
index 6f71e1c..f822736 100644
--- a/internal/lsp/testdata/lsp/primarymod/folding/a.go.golden
+++ b/internal/lsp/testdata/lsp/primarymod/folding/a.go.golden
@@ -59,7 +59,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -115,7 +116,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -178,7 +180,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -218,7 +221,7 @@
 3:9-6:0
 10:22-11:32
 12:10-12:9
-12:20-65:0
+12:20-66:0
 13:10-24:1
 14:12-19:3
 15:12-17:2
@@ -230,23 +233,23 @@
 22:10-23:24
 23:15-23:23
 25:12-29:1
-30:16-31:5
-32:21-36:1
-37:17-41:1
-42:8-46:1
-47:15-47:23
-47:32-47:40
-48:10-59:1
-49:18-54:3
-50:11-52:2
-51:16-51:28
-52:11-54:2
-53:16-53:29
-55:11-56:18
-56:15-56:17
-57:10-58:24
-58:15-58:23
-60:32-61:30
+30:16-32:1
+33:21-37:1
+38:17-42:1
+43:8-47:1
+48:15-48:23
+48:32-48:40
+49:10-60:1
+50:18-55:3
+51:11-53:2
+52:16-52:28
+53:11-55:2
+54:16-54:29
+56:11-57:18
+57:15-57:17
+58:10-59:24
+59:15-59:23
+61:32-62:30
 
 -- foldingRange-comment-0 --
 package folding //@fold("package")
@@ -278,7 +281,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -341,7 +345,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -406,7 +411,9 @@
 	}
 	_ = []int{<>,
 	}
-	_ = [2]string{<>}
+	_ = [2]string{"d",
+		"e"
+	}
 	_ = map[string]int{<>,
 	}
 	type T struct {<>
@@ -446,7 +453,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -504,7 +512,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -568,7 +577,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,
@@ -632,7 +642,8 @@
 		3,
 	}
 	_ = [2]string{"d",
-		"e"}
+		"e"
+	}
 	_ = map[string]int{
 		"a": 1,
 		"b": 2,