encoding/json: add Encoder.DisableHTMLEscaping

This provides a way to disable the escaping of <, >, and & in JSON
strings.

Fixes #14749.

Change-Id: I1afeb0244455fc8b06c6cce920444532f229555b
Reviewed-on: https://go-review.googlesource.com/21796
Run-TryBot: Caleb Spare <cespare@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/src/encoding/json/encode.go b/src/encoding/json/encode.go
index 0088f25..d8c7798 100644
--- a/src/encoding/json/encode.go
+++ b/src/encoding/json/encode.go
@@ -49,6 +49,7 @@
 // The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e"
 // to keep some browsers from misinterpreting JSON output as HTML.
 // Ampersand "&" is also escaped to "\u0026" for the same reason.
+// This escaping can be disabled using an Encoder with DisableHTMLEscaping.
 //
 // Array and slice values encode as JSON arrays, except that
 // []byte encodes as a base64-encoded string, and a nil slice
@@ -136,7 +137,7 @@
 //
 func Marshal(v interface{}) ([]byte, error) {
 	e := &encodeState{}
-	err := e.marshal(v)
+	err := e.marshal(v, encOpts{escapeHTML: true})
 	if err != nil {
 		return nil, err
 	}
@@ -259,7 +260,7 @@
 	return new(encodeState)
 }
 
-func (e *encodeState) marshal(v interface{}) (err error) {
+func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
 	defer func() {
 		if r := recover(); r != nil {
 			if _, ok := r.(runtime.Error); ok {
@@ -271,7 +272,7 @@
 			err = r.(error)
 		}
 	}()
-	e.reflectValue(reflect.ValueOf(v))
+	e.reflectValue(reflect.ValueOf(v), opts)
 	return nil
 }
 
@@ -297,11 +298,18 @@
 	return false
 }
 
-func (e *encodeState) reflectValue(v reflect.Value) {
-	valueEncoder(v)(e, v, false)
+func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
+	valueEncoder(v)(e, v, opts)
 }
 
-type encoderFunc func(e *encodeState, v reflect.Value, quoted bool)
+type encOpts struct {
+	// quoted causes primitive fields to be encoded inside JSON strings.
+	quoted bool
+	// escapeHTML causes '<', '>', and '&' to be escaped in JSON strings.
+	escapeHTML bool
+}
+
+type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
 
 var encoderCache struct {
 	sync.RWMutex
@@ -333,9 +341,9 @@
 	}
 	var wg sync.WaitGroup
 	wg.Add(1)
-	encoderCache.m[t] = func(e *encodeState, v reflect.Value, quoted bool) {
+	encoderCache.m[t] = func(e *encodeState, v reflect.Value, opts encOpts) {
 		wg.Wait()
-		f(e, v, quoted)
+		f(e, v, opts)
 	}
 	encoderCache.Unlock()
 
@@ -405,11 +413,11 @@
 	}
 }
 
-func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func invalidValueEncoder(e *encodeState, v reflect.Value, _ encOpts) {
 	e.WriteString("null")
 }
 
-func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.Kind() == reflect.Ptr && v.IsNil() {
 		e.WriteString("null")
 		return
@@ -418,14 +426,14 @@
 	b, err := m.MarshalJSON()
 	if err == nil {
 		// copy JSON into buffer, checking validity.
-		err = compact(&e.Buffer, b, true)
+		err = compact(&e.Buffer, b, opts.escapeHTML)
 	}
 	if err != nil {
 		e.error(&MarshalerError{v.Type(), err})
 	}
 }
 
-func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func addrMarshalerEncoder(e *encodeState, v reflect.Value, _ encOpts) {
 	va := v.Addr()
 	if va.IsNil() {
 		e.WriteString("null")
@@ -442,7 +450,7 @@
 	}
 }
 
-func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.Kind() == reflect.Ptr && v.IsNil() {
 		e.WriteString("null")
 		return
@@ -452,10 +460,10 @@
 	if err != nil {
 		e.error(&MarshalerError{v.Type(), err})
 	}
-	e.stringBytes(b)
+	e.stringBytes(b, opts.escapeHTML)
 }
 
-func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
 	va := v.Addr()
 	if va.IsNil() {
 		e.WriteString("null")
@@ -466,11 +474,11 @@
 	if err != nil {
 		e.error(&MarshalerError{v.Type(), err})
 	}
-	e.stringBytes(b)
+	e.stringBytes(b, opts.escapeHTML)
 }
 
-func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
-	if quoted {
+func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 	if v.Bool() {
@@ -478,46 +486,46 @@
 	} else {
 		e.WriteString("false")
 	}
-	if quoted {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 }
 
-func intEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
 	b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
-	if quoted {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 	e.Write(b)
-	if quoted {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 }
 
-func uintEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
 	b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
-	if quoted {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 	e.Write(b)
-	if quoted {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 }
 
 type floatEncoder int // number of bits
 
-func (bits floatEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (bits floatEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
 	f := v.Float()
 	if math.IsInf(f, 0) || math.IsNaN(f) {
 		e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))})
 	}
 	b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, int(bits))
-	if quoted {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 	e.Write(b)
-	if quoted {
+	if opts.quoted {
 		e.WriteByte('"')
 	}
 }
@@ -527,7 +535,7 @@
 	float64Encoder = (floatEncoder(64)).encode
 )
 
-func stringEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.Type() == numberType {
 		numStr := v.String()
 		// In Go1.5 the empty string encodes to "0", while this is not a valid number literal
@@ -541,26 +549,26 @@
 		e.WriteString(numStr)
 		return
 	}
-	if quoted {
+	if opts.quoted {
 		sb, err := Marshal(v.String())
 		if err != nil {
 			e.error(err)
 		}
-		e.string(string(sb))
+		e.string(string(sb), opts.escapeHTML)
 	} else {
-		e.string(v.String())
+		e.string(v.String(), opts.escapeHTML)
 	}
 }
 
-func interfaceEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func interfaceEncoder(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.IsNil() {
 		e.WriteString("null")
 		return
 	}
-	e.reflectValue(v.Elem())
+	e.reflectValue(v.Elem(), opts)
 }
 
-func unsupportedTypeEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func unsupportedTypeEncoder(e *encodeState, v reflect.Value, _ encOpts) {
 	e.error(&UnsupportedTypeError{v.Type()})
 }
 
@@ -569,7 +577,7 @@
 	fieldEncs []encoderFunc
 }
 
-func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
 	e.WriteByte('{')
 	first := true
 	for i, f := range se.fields {
@@ -582,9 +590,10 @@
 		} else {
 			e.WriteByte(',')
 		}
-		e.string(f.name)
+		e.string(f.name, opts.escapeHTML)
 		e.WriteByte(':')
-		se.fieldEncs[i](e, fv, f.quoted)
+		opts.quoted = f.quoted
+		se.fieldEncs[i](e, fv, opts)
 	}
 	e.WriteByte('}')
 }
@@ -605,7 +614,7 @@
 	elemEnc encoderFunc
 }
 
-func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+func (me *mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.IsNil() {
 		e.WriteString("null")
 		return
@@ -627,9 +636,9 @@
 		if i > 0 {
 			e.WriteByte(',')
 		}
-		e.string(kv.s)
+		e.string(kv.s, opts.escapeHTML)
 		e.WriteByte(':')
-		me.elemEnc(e, v.MapIndex(kv.v), false)
+		me.elemEnc(e, v.MapIndex(kv.v), opts)
 	}
 	e.WriteByte('}')
 }
@@ -642,7 +651,7 @@
 	return me.encode
 }
 
-func encodeByteSlice(e *encodeState, v reflect.Value, _ bool) {
+func encodeByteSlice(e *encodeState, v reflect.Value, _ encOpts) {
 	if v.IsNil() {
 		e.WriteString("null")
 		return
@@ -669,12 +678,12 @@
 	arrayEnc encoderFunc
 }
 
-func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.IsNil() {
 		e.WriteString("null")
 		return
 	}
-	se.arrayEnc(e, v, false)
+	se.arrayEnc(e, v, opts)
 }
 
 func newSliceEncoder(t reflect.Type) encoderFunc {
@@ -692,14 +701,14 @@
 	elemEnc encoderFunc
 }
 
-func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
 	e.WriteByte('[')
 	n := v.Len()
 	for i := 0; i < n; i++ {
 		if i > 0 {
 			e.WriteByte(',')
 		}
-		ae.elemEnc(e, v.Index(i), false)
+		ae.elemEnc(e, v.Index(i), opts)
 	}
 	e.WriteByte(']')
 }
@@ -713,12 +722,12 @@
 	elemEnc encoderFunc
 }
 
-func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.IsNil() {
 		e.WriteString("null")
 		return
 	}
-	pe.elemEnc(e, v.Elem(), quoted)
+	pe.elemEnc(e, v.Elem(), opts)
 }
 
 func newPtrEncoder(t reflect.Type) encoderFunc {
@@ -730,11 +739,11 @@
 	canAddrEnc, elseEnc encoderFunc
 }
 
-func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
 	if v.CanAddr() {
-		ce.canAddrEnc(e, v, quoted)
+		ce.canAddrEnc(e, v, opts)
 	} else {
-		ce.elseEnc(e, v, quoted)
+		ce.elseEnc(e, v, opts)
 	}
 }
 
@@ -812,13 +821,14 @@
 func (sv byString) Less(i, j int) bool { return sv[i].s < sv[j].s }
 
 // NOTE: keep in sync with stringBytes below.
-func (e *encodeState) string(s string) int {
+func (e *encodeState) string(s string, escapeHTML bool) int {
 	len0 := e.Len()
 	e.WriteByte('"')
 	start := 0
 	for i := 0; i < len(s); {
 		if b := s[i]; b < utf8.RuneSelf {
-			if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
+			if 0x20 <= b && b != '\\' && b != '"' &&
+				(!escapeHTML || b != '<' && b != '>' && b != '&') {
 				i++
 				continue
 			}
@@ -839,10 +849,11 @@
 				e.WriteByte('\\')
 				e.WriteByte('t')
 			default:
-				// This encodes bytes < 0x20 except for \n and \r,
-				// as well as <, > and &. The latter are escaped because they
-				// can lead to security holes when user-controlled strings
-				// are rendered into JSON and served to some browsers.
+				// This encodes bytes < 0x20 except for \t, \n and \r.
+				// If escapeHTML is set, it also escapes <, >, and &
+				// because they can lead to security holes when
+				// user-controlled strings are rendered into JSON
+				// and served to some browsers.
 				e.WriteString(`\u00`)
 				e.WriteByte(hex[b>>4])
 				e.WriteByte(hex[b&0xF])
@@ -888,13 +899,14 @@
 }
 
 // NOTE: keep in sync with string above.
-func (e *encodeState) stringBytes(s []byte) int {
+func (e *encodeState) stringBytes(s []byte, escapeHTML bool) int {
 	len0 := e.Len()
 	e.WriteByte('"')
 	start := 0
 	for i := 0; i < len(s); {
 		if b := s[i]; b < utf8.RuneSelf {
-			if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
+			if 0x20 <= b && b != '\\' && b != '"' &&
+				(!escapeHTML || b != '<' && b != '>' && b != '&') {
 				i++
 				continue
 			}
@@ -915,10 +927,11 @@
 				e.WriteByte('\\')
 				e.WriteByte('t')
 			default:
-				// This encodes bytes < 0x20 except for \n and \r,
-				// as well as <, >, and &. The latter are escaped because they
-				// can lead to security holes when user-controlled strings
-				// are rendered into JSON and served to some browsers.
+				// This encodes bytes < 0x20 except for \t, \n and \r.
+				// If escapeHTML is set, it also escapes <, >, and &
+				// because they can lead to security holes when
+				// user-controlled strings are rendered into JSON
+				// and served to some browsers.
 				e.WriteString(`\u00`)
 				e.WriteByte(hex[b>>4])
 				e.WriteByte(hex[b&0xF])
diff --git a/src/encoding/json/encode_test.go b/src/encoding/json/encode_test.go
index eee59cc..b484022 100644
--- a/src/encoding/json/encode_test.go
+++ b/src/encoding/json/encode_test.go
@@ -376,41 +376,45 @@
 
 func TestStringBytes(t *testing.T) {
 	// Test that encodeState.stringBytes and encodeState.string use the same encoding.
-	es := &encodeState{}
 	var r []rune
 	for i := '\u0000'; i <= unicode.MaxRune; i++ {
 		r = append(r, i)
 	}
 	s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too
-	es.string(s)
 
-	esBytes := &encodeState{}
-	esBytes.stringBytes([]byte(s))
+	for _, escapeHTML := range []bool{true, false} {
+		es := &encodeState{}
+		es.string(s, escapeHTML)
 
-	enc := es.Buffer.String()
-	encBytes := esBytes.Buffer.String()
-	if enc != encBytes {
-		i := 0
-		for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
-			i++
-		}
-		enc = enc[i:]
-		encBytes = encBytes[i:]
-		i = 0
-		for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
-			i++
-		}
-		enc = enc[:len(enc)-i]
-		encBytes = encBytes[:len(encBytes)-i]
+		esBytes := &encodeState{}
+		esBytes.stringBytes([]byte(s), escapeHTML)
 
-		if len(enc) > 20 {
-			enc = enc[:20] + "..."
-		}
-		if len(encBytes) > 20 {
-			encBytes = encBytes[:20] + "..."
-		}
+		enc := es.Buffer.String()
+		encBytes := esBytes.Buffer.String()
+		if enc != encBytes {
+			i := 0
+			for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
+				i++
+			}
+			enc = enc[i:]
+			encBytes = encBytes[i:]
+			i = 0
+			for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
+				i++
+			}
+			enc = enc[:len(enc)-i]
+			encBytes = encBytes[:len(encBytes)-i]
 
-		t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
+			if len(enc) > 20 {
+				enc = enc[:20] + "..."
+			}
+			if len(encBytes) > 20 {
+				encBytes = encBytes[:20] + "..."
+			}
+
+			t.Errorf("with escapeHTML=%t, encodings differ at %#q vs %#q",
+				escapeHTML, enc, encBytes)
+		}
 	}
 }
 
diff --git a/src/encoding/json/stream.go b/src/encoding/json/stream.go
index 422837b..d6b2992 100644
--- a/src/encoding/json/stream.go
+++ b/src/encoding/json/stream.go
@@ -166,8 +166,9 @@
 
 // An Encoder writes JSON values to an output stream.
 type Encoder struct {
-	w   io.Writer
-	err error
+	w          io.Writer
+	err        error
+	escapeHTML bool
 
 	indentBuf    *bytes.Buffer
 	indentPrefix string
@@ -176,7 +177,7 @@
 
 // NewEncoder returns a new encoder that writes to w.
 func NewEncoder(w io.Writer) *Encoder {
-	return &Encoder{w: w}
+	return &Encoder{w: w, escapeHTML: true}
 }
 
 // Encode writes the JSON encoding of v to the stream,
@@ -189,7 +190,7 @@
 		return enc.err
 	}
 	e := newEncodeState()
-	err := e.marshal(v)
+	err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
 	if err != nil {
 		return err
 	}
@@ -225,6 +226,12 @@
 	enc.indentValue = indent
 }
 
+// DisableHTMLEscaping causes the encoder not to escape angle brackets
+// ("<" and ">") or ampersands ("&") in JSON strings.
+func (enc *Encoder) DisableHTMLEscaping() {
+	enc.escapeHTML = false
+}
+
 // RawMessage is a raw encoded JSON value.
 // It implements Marshaler and Unmarshaler and can
 // be used to delay JSON decoding or precompute a JSON encoding.
diff --git a/src/encoding/json/stream_test.go b/src/encoding/json/stream_test.go
index db25708..3516ac3 100644
--- a/src/encoding/json/stream_test.go
+++ b/src/encoding/json/stream_test.go
@@ -87,6 +87,39 @@
 	}
 }
 
+func TestEncoderDisableHTMLEscaping(t *testing.T) {
+	var c C
+	var ct CText
+	for _, tt := range []struct {
+		name       string
+		v          interface{}
+		wantEscape string
+		want       string
+	}{
+		{"c", c, `"\u003c\u0026\u003e"`, `"<&>"`},
+		{"ct", ct, `"\"\u003c\u0026\u003e\""`, `"\"<&>\""`},
+		{`"<&>"`, "<&>", `"\u003c\u0026\u003e"`, `"<&>"`},
+	} {
+		var buf bytes.Buffer
+		enc := NewEncoder(&buf)
+		if err := enc.Encode(tt.v); err != nil {
+			t.Fatalf("Encode(%s): %s", tt.name, err)
+		}
+		if got := strings.TrimSpace(buf.String()); got != tt.wantEscape {
+			t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.wantEscape)
+		}
+		buf.Reset()
+		enc.DisableHTMLEscaping()
+		if err := enc.Encode(tt.v); err != nil {
+			t.Fatalf("DisableHTMLEscaping Encode(%s): %s", tt.name, err)
+		}
+		if got := strings.TrimSpace(buf.String()); got != tt.want {
+			t.Errorf("DisableHTMLEscaping Encode(%s) = %#q, want %#q",
+				tt.name, got, tt.want)
+		}
+	}
+}
+
 func TestDecoder(t *testing.T) {
 	for i := 0; i <= len(streamTest); i++ {
 		// Use stream without newlines as input,