cmd/compile: major refactoring of switch walking

There are a lot of complexities to handling switches efficiently:

1. Order matters for expression switches with non-constant cases and
for type expressions with interface types. We have to respect
side-effects, and we also can't allow later cases to accidentally take
precedence over earlier cases.

2. For runs of integers, floats, and string constants in expression
switches or runs of concrete types in type switches, we want to emit
efficient binary searches.

3. For runs of consecutive integers in expression switches, we want to
collapse them into range comparisons.

4. For binary searches of strings, we want to compare by length first,
because that's more efficient and we don't need to respect any
particular ordering.

5. For "switch true { ... }" and "switch false { ... }", we want to
optimize "case x:" as simply "if x" or "if !x", respectively, unless x
is interface-typed.

The current swt.go code reflects how these constraints have been
incrementally added over time, with each of them being handled ad
hocly in different parts of the code. Also, the existing code tries
very hard to reuse logic between expression and type switches, even
though the similarities are very superficial.

This CL rewrites switch handling to better abstract away the logic
involved in constructing the binary searches. In particular, it's
intended to make further optimizations to switch dispatch much easier.

It also eliminates the need for both OXCASE and OCASE ops, and a
subsequent CL can collapse the two.

Passes toolstash-check.

Change-Id: Ifcd1e56f81f858117a412971d82e98abe7c4481f
Reviewed-on: https://go-review.googlesource.com/c/go/+/194660
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
diff --git a/src/cmd/compile/internal/gc/swt.go b/src/cmd/compile/internal/gc/swt.go
index 5089efe..4a8e9bc 100644
--- a/src/cmd/compile/internal/gc/swt.go
+++ b/src/cmd/compile/internal/gc/swt.go
@@ -10,50 +10,6 @@
 	"sort"
 )
 
-const (
-	// expression switch
-	switchKindExpr  = iota // switch a {...} or switch 5 {...}
-	switchKindTrue         // switch true {...} or switch {...}
-	switchKindFalse        // switch false {...}
-)
-
-const (
-	binarySearchMin = 4 // minimum number of cases for binary search
-	integerRangeMin = 2 // minimum size of integer ranges
-)
-
-// An exprSwitch walks an expression switch.
-type exprSwitch struct {
-	exprname *Node // node for the expression being switched on
-	kind     int   // kind of switch statement (switchKind*)
-}
-
-// A typeSwitch walks a type switch.
-type typeSwitch struct {
-	hashname *Node // node for the hash of the type of the variable being switched on
-	facename *Node // node for the concrete type of the variable being switched on
-	okname   *Node // boolean node used for comma-ok type assertions
-}
-
-// A caseClause is a single case clause in a switch statement.
-type caseClause struct {
-	node    *Node  // points at case statement
-	ordinal int    // position in switch
-	hash    uint32 // hash of a type switch
-	// isconst indicates whether this case clause is a constant,
-	// for the purposes of the switch code generation.
-	// For expression switches, that's generally literals (case 5:, not case x:).
-	// For type switches, that's concrete types (case time.Time:), not interfaces (case io.Reader:).
-	isconst bool
-}
-
-// caseClauses are all the case clauses in a switch statement.
-type caseClauses struct {
-	list   []caseClause // general cases
-	defjmp *Node        // OGOTO for default case or OBREAK if no default case present
-	niljmp *Node        // OGOTO for nil type case in a type switch
-}
-
 // typecheckswitch typechecks a switch statement.
 func typecheckswitch(n *Node) {
 	typecheckslice(n.Ninit.Slice(), ctxStmt)
@@ -71,7 +27,6 @@
 		yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right)
 		t = nil
 	}
-	n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
 
 	// We don't actually declare the type switch's guarded
 	// declaration itself. So if there are no cases, we won't
@@ -212,7 +167,6 @@
 			t = nil
 		}
 	}
-	n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
 
 	var defCase *Node
 	var cs constSet
@@ -265,422 +219,267 @@
 
 // walkswitch walks a switch statement.
 func walkswitch(sw *Node) {
-	// convert switch {...} to switch true {...}
-	if sw.Left == nil {
-		sw.Left = nodbool(true)
-		sw.Left = typecheck(sw.Left, ctxExpr)
-		sw.Left = defaultlit(sw.Left, nil)
-	}
-
-	if sw.Left.Op == OTYPESW {
-		var s typeSwitch
-		s.walk(sw)
-	} else {
-		var s exprSwitch
-		s.walk(sw)
-	}
-}
-
-// walk generates an AST implementing sw.
-// sw is an expression switch.
-// The AST is generally of the form of a linear
-// search using if..goto, although binary search
-// is used with long runs of constants.
-func (s *exprSwitch) walk(sw *Node) {
 	// Guard against double walk, see #25776.
 	if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
 		return // Was fatal, but eliminating every possible source of double-walking is hard
 	}
 
-	casebody(sw, nil)
+	if sw.Left != nil && sw.Left.Op == OTYPESW {
+		walkTypeSwitch(sw)
+	} else {
+		walkExprSwitch(sw)
+	}
+}
+
+// walkExprSwitch generates an AST implementing sw.  sw is an
+// expression switch.
+func walkExprSwitch(sw *Node) {
+	lno := setlineno(sw)
 
 	cond := sw.Left
 	sw.Left = nil
 
-	s.kind = switchKindExpr
-	if Isconst(cond, CTBOOL) {
-		s.kind = switchKindTrue
-		if !cond.Val().U.(bool) {
-			s.kind = switchKindFalse
-		}
+	// convert switch {...} to switch true {...}
+	if cond == nil {
+		cond = nodbool(true)
+		cond = typecheck(cond, ctxExpr)
+		cond = defaultlit(cond, nil)
 	}
 
 	// Given "switch string(byteslice)",
-	// with all cases being constants (or the default case),
+	// with all cases being side-effect free,
 	// use a zero-cost alias of the byte slice.
-	// In theory, we could be more aggressive,
-	// allowing any side-effect-free expressions in cases,
-	// but it's a bit tricky because some of that information
-	// is unavailable due to the introduction of temporaries during order.
-	// Restricting to constants is simple and probably powerful enough.
 	// Do this before calling walkexpr on cond,
 	// because walkexpr will lower the string
 	// conversion into a runtime call.
 	// See issue 24937 for more discussion.
-	if cond.Op == OBYTES2STR {
-		ok := true
-		for _, cas := range sw.List.Slice() {
-			if cas.Op != OCASE {
-				Fatalf("switch string(byteslice) bad op: %v", cas.Op)
-			}
-			if cas.Left != nil && !Isconst(cas.Left, CTSTR) {
-				ok = false
-				break
-			}
-		}
-		if ok {
-			cond.Op = OBYTES2STRTMP
-		}
+	if cond.Op == OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
+		cond.Op = OBYTES2STRTMP
 	}
 
 	cond = walkexpr(cond, &sw.Ninit)
-	t := sw.Type
-	if t == nil {
-		return
+	if cond.Op != OLITERAL {
+		cond = copyexpr(cond, cond.Type, &sw.Nbody)
 	}
 
-	// convert the switch into OIF statements
-	var cas []*Node
-	if s.kind == switchKindTrue || s.kind == switchKindFalse {
-		s.exprname = nodbool(s.kind == switchKindTrue)
-	} else if consttype(cond) > 0 {
-		// leave constants to enable dead code elimination (issue 9608)
-		s.exprname = cond
-	} else {
-		s.exprname = temp(cond.Type)
-		cas = []*Node{nod(OAS, s.exprname, cond)} // This gets walk()ed again in walkstmtlist just before end of this function.  See #29562.
-		typecheckslice(cas, ctxStmt)
-	}
-
-	// Enumerate the cases and prepare the default case.
-	clauses := s.genCaseClauses(sw.List.Slice())
-	sw.List.Set(nil)
-	cc := clauses.list
-
-	// handle the cases in order
-	for len(cc) > 0 {
-		run := 1
-		if okforcmp[t.Etype] && cc[0].isconst {
-			// do binary search on runs of constants
-			for ; run < len(cc) && cc[run].isconst; run++ {
-			}
-			// sort and compile constants
-			sort.Sort(caseClauseByConstVal(cc[:run]))
-		}
-
-		a := s.walkCases(cc[:run])
-		cas = append(cas, a)
-		cc = cc[run:]
-	}
-
-	// handle default case
-	if nerrors == 0 {
-		cas = append(cas, clauses.defjmp)
-		sw.Nbody.Prepend(cas...)
-		walkstmtlist(sw.Nbody.Slice())
-	}
-}
-
-// walkCases generates an AST implementing the cases in cc.
-func (s *exprSwitch) walkCases(cc []caseClause) *Node {
-	if len(cc) < binarySearchMin {
-		// linear search
-		var cas []*Node
-		for _, c := range cc {
-			n := c.node
-			lno := setlineno(n)
-
-			a := nod(OIF, nil, nil)
-			if rng := n.List.Slice(); rng != nil {
-				// Integer range.
-				// exprname is a temp or a constant,
-				// so it is safe to evaluate twice.
-				// In most cases, this conjunction will be
-				// rewritten by walkinrange into a single comparison.
-				low := nod(OGE, s.exprname, rng[0])
-				high := nod(OLE, s.exprname, rng[1])
-				a.Left = nod(OANDAND, low, high)
-			} else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
-				a.Left = nod(OEQ, s.exprname, n.Left) // if name == val
-			} else if s.kind == switchKindTrue {
-				a.Left = n.Left // if val
-			} else {
-				// s.kind == switchKindFalse
-				a.Left = nod(ONOT, n.Left, nil) // if !val
-			}
-			a.Left = typecheck(a.Left, ctxExpr)
-			a.Left = defaultlit(a.Left, nil)
-			a.Nbody.Set1(n.Right) // goto l
-
-			cas = append(cas, a)
-			lineno = lno
-		}
-		return liststmt(cas)
-	}
-
-	// find the middle and recur
-	half := len(cc) / 2
-	a := nod(OIF, nil, nil)
-	n := cc[half-1].node
-	var mid *Node
-	if rng := n.List.Slice(); rng != nil {
-		mid = rng[1] // high end of range
-	} else {
-		mid = n.Left
-	}
-	le := nod(OLE, s.exprname, mid)
-	if Isconst(mid, CTSTR) {
-		// Search by length and then by value; see caseClauseByConstVal.
-		lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
-		leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
-		a.Left = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
-	} else {
-		a.Left = le
-	}
-	a.Left = typecheck(a.Left, ctxExpr)
-	a.Left = defaultlit(a.Left, nil)
-	a.Nbody.Set1(s.walkCases(cc[:half]))
-	a.Rlist.Set1(s.walkCases(cc[half:]))
-	return a
-}
-
-// casebody builds separate lists of statements and cases.
-// It makes labels between cases and statements
-// and deals with fallthrough, break, and unreachable statements.
-func casebody(sw *Node, typeswvar *Node) {
-	if sw.List.Len() == 0 {
-		return
-	}
-
-	lno := setlineno(sw)
-
-	var cas []*Node  // cases
-	var stat []*Node // statements
-	var def *Node    // defaults
-	br := nod(OBREAK, nil, nil)
-
-	for _, n := range sw.List.Slice() {
-		setlineno(n)
-		if n.Op != OXCASE {
-			Fatalf("casebody %v", n.Op)
-		}
-		n.Op = OCASE
-		needvar := n.List.Len() != 1 || n.List.First().Op == OLITERAL
-
-		lbl := autolabel(".s")
-		jmp := nodSym(OGOTO, nil, lbl)
-		switch n.List.Len() {
-		case 0:
-			// default
-			if def != nil {
-				yyerrorl(n.Pos, "more than one default case")
-			}
-			// reuse original default case
-			n.Right = jmp
-			def = n
-		case 1:
-			// one case -- reuse OCASE node
-			n.Left = n.List.First()
-			n.Right = jmp
-			n.List.Set(nil)
-			cas = append(cas, n)
-		default:
-			// Expand multi-valued cases and detect ranges of integer cases.
-			if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin {
-				// Can't use integer ranges. Expand each case into a separate node.
-				for _, n1 := range n.List.Slice() {
-					cas = append(cas, nod(OCASE, n1, jmp))
-				}
-				break
-			}
-			// Find integer ranges within runs of constants.
-			s := n.List.Slice()
-			j := 0
-			for j < len(s) {
-				// Find a run of constants.
-				var run int
-				for run = j; run < len(s) && Isconst(s[run], CTINT); run++ {
-				}
-				if run-j >= integerRangeMin {
-					// Search for integer ranges in s[j:run].
-					// Typechecking is done, so all values are already in an appropriate range.
-					search := s[j:run]
-					sort.Sort(constIntNodesByVal(search))
-					for beg, end := 0, 1; end <= len(search); end++ {
-						if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 {
-							continue
-						}
-						if end-beg >= integerRangeMin {
-							// Record range in List.
-							c := nod(OCASE, nil, jmp)
-							c.List.Set2(search[beg], search[end-1])
-							cas = append(cas, c)
-						} else {
-							// Not large enough for range; record separately.
-							for _, n := range search[beg:end] {
-								cas = append(cas, nod(OCASE, n, jmp))
-							}
-						}
-						beg = end
-					}
-					j = run
-				}
-				// Advance to next constant, adding individual non-constant
-				// or as-yet-unhandled constant cases as we go.
-				for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ {
-					cas = append(cas, nod(OCASE, s[j], jmp))
-				}
-			}
-		}
-
-		stat = append(stat, nodSym(OLABEL, nil, lbl))
-		if typeswvar != nil && needvar && n.Rlist.Len() != 0 {
-			l := []*Node{
-				nod(ODCL, n.Rlist.First(), nil),
-				nod(OAS, n.Rlist.First(), typeswvar),
-			}
-			typecheckslice(l, ctxStmt)
-			stat = append(stat, l...)
-		}
-		stat = append(stat, n.Nbody.Slice()...)
-
-		// Search backwards for the index of the fallthrough
-		// statement. Do not assume it'll be in the last
-		// position, since in some cases (e.g. when the statement
-		// list contains autotmp_ variables), one or more OVARKILL
-		// nodes will be at the end of the list.
-		fallIndex := len(stat) - 1
-		for stat[fallIndex].Op == OVARKILL {
-			fallIndex--
-		}
-		last := stat[fallIndex]
-		if last.Op != OFALL {
-			stat = append(stat, br)
-		}
-	}
-
-	stat = append(stat, br)
-	if def != nil {
-		cas = append(cas, def)
-	}
-
-	sw.List.Set(cas)
-	sw.Nbody.Set(stat)
 	lineno = lno
-}
 
-// genCaseClauses generates the caseClauses value for clauses.
-func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses {
-	var cc caseClauses
-	for _, n := range clauses {
-		if n.Left == nil && n.List.Len() == 0 {
-			// default case
-			if cc.defjmp != nil {
+	s := exprSwitch{
+		exprname: cond,
+	}
+
+	br := nod(OBREAK, nil, nil)
+	var defaultGoto *Node
+	var body Nodes
+	for _, ncase := range sw.List.Slice() {
+		label := autolabel(".s")
+		jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
+
+		// Process case dispatch.
+		if ncase.List.Len() == 0 {
+			if defaultGoto != nil {
 				Fatalf("duplicate default case not detected during typechecking")
 			}
-			cc.defjmp = n.Right
-			continue
+			defaultGoto = jmp
 		}
-		c := caseClause{node: n, ordinal: len(cc.list)}
-		if n.List.Len() > 0 {
-			c.isconst = true
+
+		for _, n1 := range ncase.List.Slice() {
+			s.Add(ncase.Pos, n1, jmp)
 		}
-		switch consttype(n.Left) {
-		case CTFLT, CTINT, CTRUNE, CTSTR:
-			c.isconst = true
+
+		// Process body.
+		body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
+		body.Append(ncase.Nbody.Slice()...)
+		if !hasFall(ncase.Nbody.Slice()) {
+			body.Append(br)
 		}
-		cc.list = append(cc.list, c)
+	}
+	sw.List.Set(nil)
+
+	if defaultGoto == nil {
+		defaultGoto = br
 	}
 
-	if cc.defjmp == nil {
-		cc.defjmp = nod(OBREAK, nil, nil)
-	}
-	return cc
+	s.Emit(&sw.Nbody)
+	sw.Nbody.Append(defaultGoto)
+	sw.Nbody.AppendNodes(&body)
+	walkstmtlist(sw.Nbody.Slice())
 }
 
-// genCaseClauses generates the caseClauses value for clauses.
-func (s *typeSwitch) genCaseClauses(clauses []*Node) caseClauses {
-	var cc caseClauses
-	for _, n := range clauses {
-		switch {
-		case n.Left == nil:
-			// default case
-			if cc.defjmp != nil {
-				Fatalf("duplicate default case not detected during typechecking")
-			}
-			cc.defjmp = n.Right
-			continue
-		case n.Left.Op == OLITERAL:
-			// nil case in type switch
-			if cc.niljmp != nil {
-				Fatalf("duplicate nil case not detected during typechecking")
-			}
-			cc.niljmp = n.Right
-			continue
-		}
+// An exprSwitch walks an expression switch.
+type exprSwitch struct {
+	exprname *Node // value being switched on
 
-		// general case
-		c := caseClause{
-			node:    n,
-			ordinal: len(cc.list),
-			isconst: !n.Left.Type.IsInterface(),
-			hash:    typehash(n.Left.Type),
-		}
-		cc.list = append(cc.list, c)
-	}
-
-	if cc.defjmp == nil {
-		cc.defjmp = nod(OBREAK, nil, nil)
-	}
-
-	return cc
+	done    Nodes
+	clauses []exprClause
 }
 
-// walk generates an AST that implements sw,
-// where sw is a type switch.
-// The AST is generally of the form of a linear
-// search using if..goto, although binary search
-// is used with long runs of concrete types.
-func (s *typeSwitch) walk(sw *Node) {
-	cond := sw.Left
+type exprClause struct {
+	pos    src.XPos
+	lo, hi *Node
+	jmp    *Node
+}
+
+func (s *exprSwitch) Add(pos src.XPos, expr, jmp *Node) {
+	c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
+	if okforcmp[s.exprname.Type.Etype] && expr.Op == OLITERAL {
+		s.clauses = append(s.clauses, c)
+		return
+	}
+
+	s.flush()
+	s.clauses = append(s.clauses, c)
+	s.flush()
+}
+
+func (s *exprSwitch) Emit(out *Nodes) {
+	s.flush()
+	out.AppendNodes(&s.done)
+}
+
+func (s *exprSwitch) flush() {
+	cc := s.clauses
+	s.clauses = nil
+	if len(cc) == 0 {
+		return
+	}
+
+	// Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
+	// The code below is structured to implicitly handle this case
+	// (e.g., sort.Slice doesn't need to invoke the less function
+	// when there's only a single slice element).
+
+	// Sort strings by length and then by value.
+	// It is much cheaper to compare lengths than values,
+	// and all we need here is consistency.
+	// We respect this sorting below.
+	sort.Slice(cc, func(i, j int) bool {
+		vi := cc[i].lo.Val()
+		vj := cc[j].lo.Val()
+
+		if s.exprname.Type.IsString() {
+			si := vi.U.(string)
+			sj := vj.U.(string)
+			if len(si) != len(sj) {
+				return len(si) < len(sj)
+			}
+			return si < sj
+		}
+
+		return compareOp(vi, OLT, vj)
+	})
+
+	// Merge consecutive integer cases.
+	if s.exprname.Type.IsInteger() {
+		merged := cc[:1]
+		for _, c := range cc[1:] {
+			last := &merged[len(merged)-1]
+			if last.jmp == c.jmp && last.hi.Int64()+1 == c.lo.Int64() {
+				last.hi = c.lo
+			} else {
+				merged = append(merged, c)
+			}
+		}
+		cc = merged
+	}
+
+	binarySearch(len(cc), &s.done,
+		func(i int) *Node {
+			mid := cc[i-1].hi
+
+			le := nod(OLE, s.exprname, mid)
+			if s.exprname.Type.IsString() {
+				// Compare strings by length and then
+				// by value; see sort.Slice above.
+				lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
+				leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
+				le = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
+			}
+			return le
+		},
+		func(i int, out *Nodes) {
+			c := &cc[i]
+
+			nif := nodl(c.pos, OIF, c.test(s.exprname), nil)
+			nif.Left = typecheck(nif.Left, ctxExpr)
+			nif.Left = defaultlit(nif.Left, nil)
+			nif.Nbody.Set1(c.jmp)
+			out.Append(nif)
+		},
+	)
+}
+
+func (c *exprClause) test(exprname *Node) *Node {
+	// Integer range.
+	if c.hi != c.lo {
+		low := nodl(c.pos, OGE, exprname, c.lo)
+		high := nodl(c.pos, OLE, exprname, c.hi)
+		return nodl(c.pos, OANDAND, low, high)
+	}
+
+	// Optimize "switch true { ...}" and "switch false { ... }".
+	if Isconst(exprname, CTBOOL) && !c.lo.Type.IsInterface() {
+		if exprname.Val().U.(bool) {
+			return c.lo
+		} else {
+			return nodl(c.pos, ONOT, c.lo, nil)
+		}
+	}
+
+	return nodl(c.pos, OEQ, exprname, c.lo)
+}
+
+func allCaseExprsAreSideEffectFree(sw *Node) bool {
+	// In theory, we could be more aggressive, allowing any
+	// side-effect-free expressions in cases, but it's a bit
+	// tricky because some of that information is unavailable due
+	// to the introduction of temporaries during order.
+	// Restricting to constants is simple and probably powerful
+	// enough.
+
+	for _, ncase := range sw.List.Slice() {
+		if ncase.Op != OXCASE {
+			Fatalf("switch string(byteslice) bad op: %v", ncase.Op)
+		}
+		for _, v := range ncase.List.Slice() {
+			if v.Op != OLITERAL {
+				return false
+			}
+		}
+	}
+	return true
+}
+
+// hasFall reports whether stmts ends with a "fallthrough" statement.
+func hasFall(stmts []*Node) bool {
+	// Search backwards for the index of the fallthrough
+	// statement. Do not assume it'll be in the last
+	// position, since in some cases (e.g. when the statement
+	// list contains autotmp_ variables), one or more OVARKILL
+	// nodes will be at the end of the list.
+
+	i := len(stmts) - 1
+	for i >= 0 && stmts[i].Op == OVARKILL {
+		i--
+	}
+	return i >= 0 && stmts[i].Op == OFALL
+}
+
+// walkTypeSwitch generates an AST that implements sw, where sw is a
+// type switch.
+func walkTypeSwitch(sw *Node) {
+	var s typeSwitch
+	s.facename = sw.Left.Right
 	sw.Left = nil
 
-	if cond == nil {
-		sw.List.Set(nil)
-		return
-	}
-	if cond.Right == nil {
-		yyerrorl(sw.Pos, "type switch must have an assignment")
-		return
-	}
-
-	cond.Right = walkexpr(cond.Right, &sw.Ninit)
-	if !cond.Right.Type.IsInterface() {
-		yyerrorl(sw.Pos, "type switch must be on an interface")
-		return
-	}
-
-	var cas []*Node
-
-	// predeclare temporary variables and the boolean var
-	s.facename = temp(cond.Right.Type)
-
-	a := nod(OAS, s.facename, cond.Right)
-	a = typecheck(a, ctxStmt)
-	cas = append(cas, a)
-
+	s.facename = walkexpr(s.facename, &sw.Ninit)
+	s.facename = copyexpr(s.facename, s.facename.Type, &sw.Nbody)
 	s.okname = temp(types.Types[TBOOL])
-	s.okname = typecheck(s.okname, ctxExpr)
 
-	s.hashname = temp(types.Types[TUINT32])
-	s.hashname = typecheck(s.hashname, ctxExpr)
-
-	// set up labels and jumps
-	casebody(sw, s.facename)
-
-	clauses := s.genCaseClauses(sw.List.Slice())
-	sw.List.Set(nil)
-	def := clauses.defjmp
+	// Get interface descriptor word.
+	// For empty interfaces this will be the type.
+	// For non-empty interfaces this will be the itab.
+	itab := nod(OITAB, s.facename, nil)
 
 	// For empty interfaces, do:
 	//     if e._type == nil {
@@ -688,230 +487,235 @@
 	//     }
 	//     h := e._type.hash
 	// Use a similar strategy for non-empty interfaces.
-
-	// Get interface descriptor word.
-	// For empty interfaces this will be the type.
-	// For non-empty interfaces this will be the itab.
-	itab := nod(OITAB, s.facename, nil)
-
-	// Check for nil first.
-	i := nod(OIF, nil, nil)
-	i.Left = nod(OEQ, itab, nodnil())
-	if clauses.niljmp != nil {
-		// Do explicit nil case right here.
-		i.Nbody.Set1(clauses.niljmp)
-	} else {
-		// Jump to default case.
-		lbl := autolabel(".s")
-		i.Nbody.Set1(nodSym(OGOTO, nil, lbl))
-		// Wrap default case with label.
-		blk := nod(OBLOCK, nil, nil)
-		blk.List.Set2(nodSym(OLABEL, nil, lbl), def)
-		def = blk
-	}
-	i.Left = typecheck(i.Left, ctxExpr)
-	i.Left = defaultlit(i.Left, nil)
-	cas = append(cas, i)
+	ifNil := nod(OIF, nil, nil)
+	ifNil.Left = nod(OEQ, itab, nodnil())
+	ifNil.Left = typecheck(ifNil.Left, ctxExpr)
+	ifNil.Left = defaultlit(ifNil.Left, nil)
+	// ifNil.Nbody assigned at end.
+	sw.Nbody.Append(ifNil)
 
 	// Load hash from type or itab.
-	h := nodSym(ODOTPTR, itab, nil)
-	h.Type = types.Types[TUINT32]
-	h.SetTypecheck(1)
-	if cond.Right.Type.IsEmptyInterface() {
-		h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
+	dotHash := nodSym(ODOTPTR, itab, nil)
+	dotHash.Type = types.Types[TUINT32]
+	dotHash.SetTypecheck(1)
+	if s.facename.Type.IsEmptyInterface() {
+		dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
 	} else {
-		h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
+		dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
 	}
-	h.SetBounded(true) // guaranteed not to fault
-	a = nod(OAS, s.hashname, h)
-	a = typecheck(a, ctxStmt)
-	cas = append(cas, a)
+	dotHash.SetBounded(true) // guaranteed not to fault
+	s.hashname = copyexpr(dotHash, dotHash.Type, &sw.Nbody)
 
-	cc := clauses.list
-
-	// insert type equality check into each case block
-	for _, c := range cc {
-		c.node.Right = s.typeone(c.node)
-	}
-
-	// generate list of if statements, binary search for constant sequences
-	for len(cc) > 0 {
-		if !cc[0].isconst {
-			n := cc[0].node
-			cas = append(cas, n.Right)
-			cc = cc[1:]
-			continue
+	br := nod(OBREAK, nil, nil)
+	var defaultGoto, nilGoto *Node
+	var body Nodes
+	for _, ncase := range sw.List.Slice() {
+		var caseVar *Node
+		if ncase.Rlist.Len() != 0 {
+			caseVar = ncase.Rlist.First()
 		}
 
-		// identify run of constants
-		var run int
-		for run = 1; run < len(cc) && cc[run].isconst; run++ {
-		}
+		// For single-type cases, we initialize the case
+		// variable as part of the type assertion; but in
+		// other cases, we initialize it in the body.
+		singleType := ncase.List.Len() == 1 && ncase.List.First().Op == OTYPE
 
-		// sort by hash
-		sort.Sort(caseClauseByType(cc[:run]))
+		label := autolabel(".s")
 
-		// for debugging: linear search
-		if false {
-			for i := 0; i < run; i++ {
-				n := cc[i].node
-				cas = append(cas, n.Right)
+		jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
+		if ncase.List.Len() == 0 { // default:
+			if defaultGoto != nil {
+				Fatalf("duplicate default case not detected during typechecking")
 			}
-			continue
+			defaultGoto = jmp
 		}
 
-		// combine adjacent cases with the same hash
-		var batch []caseClause
-		for i, j := 0, 0; i < run; i = j {
-			hash := []*Node{cc[i].node.Right}
-			for j = i + 1; j < run && cc[i].hash == cc[j].hash; j++ {
-				hash = append(hash, cc[j].node.Right)
+		for _, n1 := range ncase.List.Slice() {
+			if n1.isNil() { // case nil:
+				if nilGoto != nil {
+					Fatalf("duplicate nil case not detected during typechecking")
+				}
+				nilGoto = jmp
+				continue
 			}
-			cc[i].node.Right = liststmt(hash)
-			batch = append(batch, cc[i])
+
+			if singleType {
+				s.Add(n1.Type, caseVar, jmp)
+			} else {
+				s.Add(n1.Type, nil, jmp)
+			}
 		}
 
-		// binary search among cases to narrow by hash
-		cas = append(cas, s.walkCases(batch))
-		cc = cc[run:]
+		body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
+		if caseVar != nil && !singleType {
+			l := []*Node{
+				nodl(ncase.Pos, ODCL, caseVar, nil),
+				nodl(ncase.Pos, OAS, caseVar, s.facename),
+			}
+			typecheckslice(l, ctxStmt)
+			body.Append(l...)
+		}
+		body.Append(ncase.Nbody.Slice()...)
+		body.Append(br)
+	}
+	sw.List.Set(nil)
+
+	if defaultGoto == nil {
+		defaultGoto = br
 	}
 
-	// handle default case
-	if nerrors == 0 {
-		cas = append(cas, def)
-		sw.Nbody.Prepend(cas...)
-		sw.List.Set(nil)
-		walkstmtlist(sw.Nbody.Slice())
+	if nilGoto != nil {
+		ifNil.Nbody.Set1(nilGoto)
+	} else {
+		// TODO(mdempsky): Just use defaultGoto directly.
+
+		// Jump to default case.
+		label := autolabel(".s")
+		ifNil.Nbody.Set1(nodSym(OGOTO, nil, label))
+		// Wrap default case with label.
+		blk := nod(OBLOCK, nil, nil)
+		blk.List.Set2(nodSym(OLABEL, nil, label), defaultGoto)
+		defaultGoto = blk
 	}
+
+	s.Emit(&sw.Nbody)
+	sw.Nbody.Append(defaultGoto)
+	sw.Nbody.AppendNodes(&body)
+
+	walkstmtlist(sw.Nbody.Slice())
 }
 
-// typeone generates an AST that jumps to the
-// case body if the variable is of type t.
-func (s *typeSwitch) typeone(t *Node) *Node {
-	var name *Node
-	var init Nodes
-	if t.Rlist.Len() == 0 {
-		name = nblank
-		nblank = typecheck(nblank, ctxExpr|ctxAssign)
-	} else {
-		name = t.Rlist.First()
-		init.Append(nod(ODCL, name, nil))
-		a := nod(OAS, name, nil)
-		a = typecheck(a, ctxStmt)
-		init.Append(a)
-	}
+// A typeSwitch walks a type switch.
+type typeSwitch struct {
+	// Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
+	facename *Node // value being type-switched on
+	hashname *Node // type hash of the value being type-switched on
+	okname   *Node // boolean used for comma-ok type assertions
 
-	a := nod(OAS2, nil, nil)
-	a.List.Set2(name, s.okname) // name, ok =
-	b := nod(ODOTTYPE, s.facename, nil)
-	b.Type = t.Left.Type // interface.(type)
-	a.Rlist.Set1(b)
-	a = typecheck(a, ctxStmt)
-	a = walkexpr(a, &init)
-	init.Append(a)
-
-	c := nod(OIF, nil, nil)
-	c.Left = s.okname
-	c.Nbody.Set1(t.Right) // if ok { goto l }
-
-	init.Append(c)
-	return init.asblock()
+	done    Nodes
+	clauses []typeClause
 }
 
-// walkCases generates an AST implementing the cases in cc.
-func (s *typeSwitch) walkCases(cc []caseClause) *Node {
-	if len(cc) < binarySearchMin {
-		var cas []*Node
-		for _, c := range cc {
-			n := c.node
-			if !c.isconst {
-				Fatalf("typeSwitch walkCases")
-			}
+type typeClause struct {
+	hash uint32
+	body Nodes
+}
+
+func (s *typeSwitch) Add(typ *types.Type, caseVar *Node, jmp *Node) {
+	var body Nodes
+	if caseVar != nil {
+		l := []*Node{
+			nod(ODCL, caseVar, nil),
+			nod(OAS, caseVar, nil),
+		}
+		typecheckslice(l, ctxStmt)
+		body.Append(l...)
+	} else {
+		caseVar = nblank
+	}
+
+	// cv, ok = iface.(type)
+	as := nod(OAS2, nil, nil)
+	as.List.Set2(caseVar, s.okname) // cv, ok =
+	dot := nod(ODOTTYPE, s.facename, nil)
+	dot.Type = typ // iface.(type)
+	as.Rlist.Set1(dot)
+	as = typecheck(as, ctxStmt)
+	as = walkexpr(as, &body)
+	body.Append(as)
+
+	// if ok { goto label }
+	nif := nod(OIF, nil, nil)
+	nif.Left = s.okname
+	nif.Nbody.Set1(jmp)
+	body.Append(nif)
+
+	if !typ.IsInterface() {
+		s.clauses = append(s.clauses, typeClause{
+			hash: typehash(typ),
+			body: body,
+		})
+		return
+	}
+
+	s.flush()
+	s.done.AppendNodes(&body)
+}
+
+func (s *typeSwitch) Emit(out *Nodes) {
+	s.flush()
+	out.AppendNodes(&s.done)
+}
+
+func (s *typeSwitch) flush() {
+	cc := s.clauses
+	s.clauses = nil
+	if len(cc) == 0 {
+		return
+	}
+
+	sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
+
+	// Combine adjacent cases with the same hash.
+	merged := cc[:1]
+	for _, c := range cc[1:] {
+		last := &merged[len(merged)-1]
+		if last.hash == c.hash {
+			last.body.AppendNodes(&c.body)
+		} else {
+			merged = append(merged, c)
+		}
+	}
+	cc = merged
+
+	binarySearch(len(cc), &s.done,
+		func(i int) *Node {
+			return nod(OLE, s.hashname, nodintconst(int64(cc[i-1].hash)))
+		},
+		func(i int, out *Nodes) {
+			// TODO(mdempsky): Omit hash equality check if
+			// there's only one type.
+			c := cc[i]
 			a := nod(OIF, nil, nil)
 			a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
 			a.Left = typecheck(a.Left, ctxExpr)
 			a.Left = defaultlit(a.Left, nil)
-			a.Nbody.Set1(n.Right)
-			cas = append(cas, a)
+			a.Nbody.AppendNodes(&c.body)
+			out.Append(a)
+		},
+	)
+}
+
+// binarySearch constructs a binary search tree for handling n cases,
+// and appends it to out. It's used for efficiently implementing
+// switch statements.
+//
+// less(i) should return a boolean expression. If it evaluates true,
+// then cases [0, i) will be tested; otherwise, cases [i, n).
+//
+// base(i, out) should append statements to out to test the i'th case.
+func binarySearch(n int, out *Nodes, less func(i int) *Node, base func(i int, out *Nodes)) {
+	const binarySearchMin = 4 // minimum number of cases for binary search
+
+	var do func(lo, hi int, out *Nodes)
+	do = func(lo, hi int, out *Nodes) {
+		n := hi - lo
+		if n < binarySearchMin {
+			for i := lo; i < hi; i++ {
+				base(i, out)
+			}
+			return
 		}
-		return liststmt(cas)
+
+		half := lo + n/2
+		nif := nod(OIF, nil, nil)
+		nif.Left = less(half)
+		nif.Left = typecheck(nif.Left, ctxExpr)
+		nif.Left = defaultlit(nif.Left, nil)
+		do(lo, half, &nif.Nbody)
+		do(half, hi, &nif.Rlist)
+		out.Append(nif)
 	}
 
-	// find the middle and recur
-	half := len(cc) / 2
-	a := nod(OIF, nil, nil)
-	a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash)))
-	a.Left = typecheck(a.Left, ctxExpr)
-	a.Left = defaultlit(a.Left, nil)
-	a.Nbody.Set1(s.walkCases(cc[:half]))
-	a.Rlist.Set1(s.walkCases(cc[half:]))
-	return a
-}
-
-// caseClauseByConstVal sorts clauses by constant value to enable binary search.
-type caseClauseByConstVal []caseClause
-
-func (x caseClauseByConstVal) Len() int      { return len(x) }
-func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
-func (x caseClauseByConstVal) Less(i, j int) bool {
-	// n1 and n2 might be individual constants or integer ranges.
-	// We have checked for duplicates already,
-	// so ranges can be safely represented by any value in the range.
-	n1 := x[i].node
-	var v1 interface{}
-	if s := n1.List.Slice(); s != nil {
-		v1 = s[0].Val().U
-	} else {
-		v1 = n1.Left.Val().U
-	}
-
-	n2 := x[j].node
-	var v2 interface{}
-	if s := n2.List.Slice(); s != nil {
-		v2 = s[0].Val().U
-	} else {
-		v2 = n2.Left.Val().U
-	}
-
-	switch v1 := v1.(type) {
-	case *Mpflt:
-		return v1.Cmp(v2.(*Mpflt)) < 0
-	case *Mpint:
-		return v1.Cmp(v2.(*Mpint)) < 0
-	case string:
-		// Sort strings by length and then by value.
-		// It is much cheaper to compare lengths than values,
-		// and all we need here is consistency.
-		// We respect this sorting in exprSwitch.walkCases.
-		a := v1
-		b := v2.(string)
-		if len(a) != len(b) {
-			return len(a) < len(b)
-		}
-		return a < b
-	}
-
-	Fatalf("caseClauseByConstVal passed bad clauses %v < %v", x[i].node.Left, x[j].node.Left)
-	return false
-}
-
-type caseClauseByType []caseClause
-
-func (x caseClauseByType) Len() int      { return len(x) }
-func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
-func (x caseClauseByType) Less(i, j int) bool {
-	c1, c2 := x[i], x[j]
-	// sort by hash code, then ordinal (for the rare case of hash collisions)
-	if c1.hash != c2.hash {
-		return c1.hash < c2.hash
-	}
-	return c1.ordinal < c2.ordinal
-}
-
-type constIntNodesByVal []*Node
-
-func (x constIntNodesByVal) Len() int      { return len(x) }
-func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
-func (x constIntNodesByVal) Less(i, j int) bool {
-	return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0
+	do(0, n, out)
 }
diff --git a/src/cmd/compile/internal/gc/swt_test.go b/src/cmd/compile/internal/gc/swt_test.go
deleted file mode 100644
index 2f73ef7..0000000
--- a/src/cmd/compile/internal/gc/swt_test.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2015 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package gc
-
-import (
-	"testing"
-)
-
-func nodrune(r rune) *Node {
-	v := new(Mpint)
-	v.SetInt64(int64(r))
-	v.Rune = true
-	return nodlit(Val{v})
-}
-
-func nodflt(f float64) *Node {
-	v := newMpflt()
-	v.SetFloat64(f)
-	return nodlit(Val{v})
-}
-
-func TestCaseClauseByConstVal(t *testing.T) {
-	tests := []struct {
-		a, b *Node
-	}{
-		// CTFLT
-		{nodflt(0.1), nodflt(0.2)},
-		// CTINT
-		{nodintconst(0), nodintconst(1)},
-		// CTRUNE
-		{nodrune('a'), nodrune('b')},
-		// CTSTR
-		{nodlit(Val{"ab"}), nodlit(Val{"abc"})},
-		{nodlit(Val{"ab"}), nodlit(Val{"xyz"})},
-		{nodlit(Val{"abc"}), nodlit(Val{"xyz"})},
-	}
-	for i, test := range tests {
-		a := caseClause{node: nod(OXXX, test.a, nil)}
-		b := caseClause{node: nod(OXXX, test.b, nil)}
-		s := caseClauseByConstVal{a, b}
-		if less := s.Less(0, 1); !less {
-			t.Errorf("%d: caseClauseByConstVal(%v, %v) = false", i, test.a, test.b)
-		}
-		if less := s.Less(1, 0); less {
-			t.Errorf("%d: caseClauseByConstVal(%v, %v) = true", i, test.a, test.b)
-		}
-	}
-}