gopls: fix StmtToInsertVarBefore for switch stmts

The function StmtToInsertVarBefore on getting a
variable declaration in switch stmt was returning
the variable declaration statement instead of the
switch statment.

Fixes golang/go#67905

Change-Id: Ied1f82061ae4d5bbe6b65e6897e8db44ef43d8c6
GitHub-Last-Rev: 11b8c6d043f0d55169275b026e9b8f76f8aadf95
GitHub-Pull-Request: golang/tools#498
Reviewed-on: https://go-review.googlesource.com/c/tools/+/591496
Reviewed-by: Alan Donovan <adonovan@google.com>
Reviewed-by: Robert Findley <rfindley@google.com>
Auto-Submit: Alan Donovan <adonovan@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go
index ddce478..cc82a53 100644
--- a/gopls/internal/golang/extract.go
+++ b/gopls/internal/golang/extract.go
@@ -58,6 +58,15 @@
 		return nil, nil, fmt.Errorf("cannot extract %T", expr)
 	}
 
+	// TODO: There is a bug here: for a variable declared in a labeled
+	// switch/for statement it returns the for/switch statement itself
+	// which produces the below code which is a compiler error e.g.
+	// label:
+	// switch r1 := r() { ... break label ... }
+	// On extracting "r()" to a variable
+	// label:
+	// x := r()
+	// switch r1 := x { ... break label ... } // compiler error
 	insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
 	if insertBeforeStmt == nil {
 		return nil, nil, fmt.Errorf("cannot find location to insert extraction")
diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt
new file mode 100644
index 0000000..c3e75de
--- /dev/null
+++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt
@@ -0,0 +1,29 @@
+This test verifies the fix for golang/go#67905: Extract variable from type
+switch produces invalid code
+
+-- go.mod --
+module mod.test/extract
+
+go 1.18
+
+-- extract_switch.go --
+package extract
+
+import (
+	"io"
+)
+
+func f() io.Reader
+
+func main() {
+	switch r := f().(type) { //@codeactionedit("f()", "refactor.extract", type_switch_func_call)
+	default:
+		_ = r
+	}
+}
+
+-- @type_switch_func_call/extract_switch.go --
+@@ -10 +10,2 @@
+-	switch r := f().(type) { //@codeactionedit("f()", "refactor.extract", type_switch_func_call)
++	x := f()
++	switch r := x.(type) { //@codeactionedit("f()", "refactor.extract", type_switch_func_call)
diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go
index 9ba3a8e..db639c1 100644
--- a/internal/analysisinternal/analysis.go
+++ b/internal/analysisinternal/analysis.go
@@ -269,6 +269,8 @@
 		if expr.Init == enclosingStmt || expr.Post == enclosingStmt {
 			return expr
 		}
+	case *ast.SwitchStmt, *ast.TypeSwitchStmt:
+		return expr.(ast.Stmt)
 	}
 	return enclosingStmt.(ast.Stmt)
 }