gopls: fix StmtToInsertVarBefore for switch stmts The function StmtToInsertVarBefore on getting a variable declaration in switch stmt was returning the variable declaration statement instead of the switch statment. Fixes golang/go#67905 Change-Id: Ied1f82061ae4d5bbe6b65e6897e8db44ef43d8c6 GitHub-Last-Rev: 11b8c6d043f0d55169275b026e9b8f76f8aadf95 GitHub-Pull-Request: golang/tools#498 Reviewed-on: https://go-review.googlesource.com/c/tools/+/591496 Reviewed-by: Alan Donovan <adonovan@google.com> Reviewed-by: Robert Findley <rfindley@google.com> Auto-Submit: Alan Donovan <adonovan@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index ddce478..cc82a53 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go
@@ -58,6 +58,15 @@ return nil, nil, fmt.Errorf("cannot extract %T", expr) } + // TODO: There is a bug here: for a variable declared in a labeled + // switch/for statement it returns the for/switch statement itself + // which produces the below code which is a compiler error e.g. + // label: + // switch r1 := r() { ... break label ... } + // On extracting "r()" to a variable + // label: + // x := r() + // switch r1 := x { ... break label ... } // compiler error insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path) if insertBeforeStmt == nil { return nil, nil, fmt.Errorf("cannot find location to insert extraction")
diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt new file mode 100644 index 0000000..c3e75de --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt
@@ -0,0 +1,29 @@ +This test verifies the fix for golang/go#67905: Extract variable from type +switch produces invalid code + +-- go.mod -- +module mod.test/extract + +go 1.18 + +-- extract_switch.go -- +package extract + +import ( + "io" +) + +func f() io.Reader + +func main() { + switch r := f().(type) { //@codeactionedit("f()", "refactor.extract", type_switch_func_call) + default: + _ = r + } +} + +-- @type_switch_func_call/extract_switch.go -- +@@ -10 +10,2 @@ +- switch r := f().(type) { //@codeactionedit("f()", "refactor.extract", type_switch_func_call) ++ x := f() ++ switch r := x.(type) { //@codeactionedit("f()", "refactor.extract", type_switch_func_call)
diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index 9ba3a8e..db639c1 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go
@@ -269,6 +269,8 @@ if expr.Init == enclosingStmt || expr.Post == enclosingStmt { return expr } + case *ast.SwitchStmt, *ast.TypeSwitchStmt: + return expr.(ast.Stmt) } return enclosingStmt.(ast.Stmt) }