encoding/asn1: reduce allocations in Marshal

The current code uses trees of bytes.Buffer as its data representation.
Each bytes.Buffer takes at least 4k bytes, so this wastes memory.
This change introduces trees of lazy encoders as an
alternative representation, which reduces allocations.

name       old time/op    new time/op    delta
Marshal-4    64.7µs ± 2%    42.0µs ± 1%  -35.07%   (p=0.000 n=9+10)

name       old alloc/op   new alloc/op   delta
Marshal-4    35.1kB ± 0%     7.6kB ± 0%  -78.27%  (p=0.000 n=10+10)

name       old allocs/op  new allocs/op  delta
Marshal-4       503 ± 0%       293 ± 0%  -41.75%  (p=0.000 n=10+10)

Change-Id: I32b96c20b8df00414b282d69743d71a598a11336
Reviewed-on: https://go-review.googlesource.com/27030
Reviewed-by: Adam Langley <agl@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Adam Langley <agl@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/encoding/asn1/asn1_test.go b/src/encoding/asn1/asn1_test.go
index f8623fa..81f4dba 100644
--- a/src/encoding/asn1/asn1_test.go
+++ b/src/encoding/asn1/asn1_test.go
@@ -132,9 +132,9 @@
 			if ret.String() != test.base10 {
 				t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
 			}
-			fw := newForkableWriter()
-			marshalBigInt(fw, ret)
-			result := fw.Bytes()
+			e := makeBigInt(ret)
+			result := make([]byte, e.Len())
+			e.Encode(result)
 			if !bytes.Equal(result, test.in) {
 				t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
 			}
diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go
index 30797ef..f0664d3 100644
--- a/src/encoding/asn1/marshal.go
+++ b/src/encoding/asn1/marshal.go
@@ -5,77 +5,125 @@
 package asn1
 
 import (
-	"bytes"
 	"errors"
 	"fmt"
-	"io"
 	"math/big"
 	"reflect"
 	"time"
 	"unicode/utf8"
 )
 
-// A forkableWriter is an in-memory buffer that can be
-// 'forked' to create new forkableWriters that bracket the
-// original. After
-//    pre, post := w.fork()
-// the overall sequence of bytes represented is logically w+pre+post.
-type forkableWriter struct {
-	*bytes.Buffer
-	pre, post *forkableWriter
+var (
+	byte00Encoder encoder = byteEncoder(0x00)
+	byteFFEncoder encoder = byteEncoder(0xff)
+)
+
+// encoder represents an ASN.1 element that is waiting to be marshaled.
+type encoder interface {
+	// Len returns the number of bytes needed to marshal this element.
+	Len() int
+	// Encode encodes this element by writing Len() bytes to dst.
+	Encode(dst []byte)
 }
 
-func newForkableWriter() *forkableWriter {
-	return &forkableWriter{new(bytes.Buffer), nil, nil}
+type byteEncoder byte
+
+func (c byteEncoder) Len() int {
+	return 1
 }
 
-func (f *forkableWriter) fork() (pre, post *forkableWriter) {
-	if f.pre != nil || f.post != nil {
-		panic("have already forked")
-	}
-	f.pre = newForkableWriter()
-	f.post = newForkableWriter()
-	return f.pre, f.post
+func (c byteEncoder) Encode(dst []byte) {
+	dst[0] = byte(c)
 }
 
-func (f *forkableWriter) Len() (l int) {
-	l += f.Buffer.Len()
-	if f.pre != nil {
-		l += f.pre.Len()
-	}
-	if f.post != nil {
-		l += f.post.Len()
-	}
-	return
+type bytesEncoder []byte
+
+func (b bytesEncoder) Len() int {
+	return len(b)
 }
 
-func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
-	n, err = out.Write(f.Bytes())
-	if err != nil {
-		return
+func (b bytesEncoder) Encode(dst []byte) {
+	if copy(dst, b) != len(b) {
+		panic("internal error")
 	}
-
-	var nn int
-
-	if f.pre != nil {
-		nn, err = f.pre.writeTo(out)
-		n += nn
-		if err != nil {
-			return
-		}
-	}
-
-	if f.post != nil {
-		nn, err = f.post.writeTo(out)
-		n += nn
-	}
-	return
 }
 
-func marshalBase128Int(out *forkableWriter, n int64) (err error) {
+type stringEncoder string
+
+func (s stringEncoder) Len() int {
+	return len(s)
+}
+
+func (s stringEncoder) Encode(dst []byte) {
+	if copy(dst, s) != len(s) {
+		panic("internal error")
+	}
+}
+
+type multiEncoder []encoder
+
+func (m multiEncoder) Len() int {
+	var size int
+	for _, e := range m {
+		size += e.Len()
+	}
+	return size
+}
+
+func (m multiEncoder) Encode(dst []byte) {
+	var off int
+	for _, e := range m {
+		e.Encode(dst[off:])
+		off += e.Len()
+	}
+}
+
+type taggedEncoder struct {
+	// scratch contains temporary space for encoding the tag and length of
+	// an element in order to avoid extra allocations.
+	scratch [8]byte
+	tag     encoder
+	body    encoder
+}
+
+func (t *taggedEncoder) Len() int {
+	return t.tag.Len() + t.body.Len()
+}
+
+func (t *taggedEncoder) Encode(dst []byte) {
+	t.tag.Encode(dst)
+	t.body.Encode(dst[t.tag.Len():])
+}
+
+type int64Encoder int64
+
+func (i int64Encoder) Len() int {
+	n := 1
+
+	for i > 127 {
+		n++
+		i >>= 8
+	}
+
+	for i < -128 {
+		n++
+		i >>= 8
+	}
+
+	return n
+}
+
+func (i int64Encoder) Encode(dst []byte) {
+	n := i.Len()
+
+	for j := 0; j < n; j++ {
+		dst[j] = byte(i >> uint((n-1-j)*8))
+	}
+}
+
+func base128IntLength(n int64) int {
 	if n == 0 {
-		err = out.WriteByte(0)
-		return
+		return 1
 	}
 
 	l := 0
@@ -83,54 +131,29 @@
 		l++
 	}
 
+	return l
+}
+
+func appendBase128Int(dst []byte, n int64) []byte {
+	l := base128IntLength(n)
+
 	for i := l - 1; i >= 0; i-- {
 		o := byte(n >> uint(i*7))
 		o &= 0x7f
 		if i != 0 {
 			o |= 0x80
 		}
-		err = out.WriteByte(o)
-		if err != nil {
-			return
-		}
+
+		dst = append(dst, o)
 	}
 
-	return nil
+	return dst
 }
 
-func marshalInt64(out *forkableWriter, i int64) (err error) {
-	n := int64Length(i)
-
-	for ; n > 0; n-- {
-		err = out.WriteByte(byte(i >> uint((n-1)*8)))
-		if err != nil {
-			return
-		}
-	}
-
-	return nil
-}
-
-func int64Length(i int64) (numBytes int) {
-	numBytes = 1
-
-	for i > 127 {
-		numBytes++
-		i >>= 8
-	}
-
-	for i < -128 {
-		numBytes++
-		i >>= 8
-	}
-
-	return
-}
-
-func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
+func makeBigInt(n *big.Int) encoder {
 	if n.Sign() < 0 {
 		// A negative number has to be converted to two's-complement
-		// form. So we'll subtract 1 and invert. If the
+		// form. So we'll invert and subtract 1. If the
 		// most-significant-bit isn't set then we'll need to pad the
 		// beginning with 0xff in order to keep the number negative.
 		nMinus1 := new(big.Int).Neg(n)
@@ -140,41 +163,31 @@
 			bytes[i] ^= 0xff
 		}
 		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
-			err = out.WriteByte(0xff)
-			if err != nil {
-				return
-			}
+			return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)})
 		}
-		_, err = out.Write(bytes)
+		return bytesEncoder(bytes)
 	} else if n.Sign() == 0 {
 		// Zero is written as a single 0 zero rather than no bytes.
-		err = out.WriteByte(0x00)
+		return byte00Encoder
 	} else {
 		bytes := n.Bytes()
 		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
 			// We'll have to pad this with 0x00 in order to stop it
 			// looking like a negative number.
-			err = out.WriteByte(0)
-			if err != nil {
-				return
-			}
+			return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)})
 		}
-		_, err = out.Write(bytes)
+		return bytesEncoder(bytes)
 	}
-	return
 }
 
-func marshalLength(out *forkableWriter, i int) (err error) {
+func appendLength(dst []byte, i int) []byte {
 	n := lengthLength(i)
 
 	for ; n > 0; n-- {
-		err = out.WriteByte(byte(i >> uint((n-1)*8)))
-		if err != nil {
-			return
-		}
+		dst = append(dst, byte(i>>uint((n-1)*8)))
 	}
 
-	return nil
+	return dst
 }
 
 func lengthLength(i int) (numBytes int) {
@@ -186,123 +199,104 @@
 	return
 }
 
-func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
+func appendTagAndLength(dst []byte, t tagAndLength) []byte {
 	b := uint8(t.class) << 6
 	if t.isCompound {
 		b |= 0x20
 	}
 	if t.tag >= 31 {
 		b |= 0x1f
-		err = out.WriteByte(b)
-		if err != nil {
-			return
-		}
-		err = marshalBase128Int(out, int64(t.tag))
-		if err != nil {
-			return
-		}
+		dst = append(dst, b)
+		dst = appendBase128Int(dst, int64(t.tag))
 	} else {
 		b |= uint8(t.tag)
-		err = out.WriteByte(b)
-		if err != nil {
-			return
-		}
+		dst = append(dst, b)
 	}
 
 	if t.length >= 128 {
 		l := lengthLength(t.length)
-		err = out.WriteByte(0x80 | byte(l))
-		if err != nil {
-			return
-		}
-		err = marshalLength(out, t.length)
-		if err != nil {
-			return
-		}
+		dst = append(dst, 0x80|byte(l))
+		dst = appendLength(dst, t.length)
 	} else {
-		err = out.WriteByte(byte(t.length))
-		if err != nil {
-			return
-		}
+		dst = append(dst, byte(t.length))
 	}
 
-	return nil
+	return dst
 }
 
-func marshalBitString(out *forkableWriter, b BitString) (err error) {
-	paddingBits := byte((8 - b.BitLength%8) % 8)
-	err = out.WriteByte(paddingBits)
-	if err != nil {
-		return
-	}
-	_, err = out.Write(b.Bytes)
-	return
+type bitStringEncoder BitString
+
+func (b bitStringEncoder) Len() int {
+	return len(b.Bytes) + 1
 }
 
-func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
-	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
-		return StructuralError{"invalid object identifier"}
+func (b bitStringEncoder) Encode(dst []byte) {
+	dst[0] = byte((8 - b.BitLength%8) % 8)
+	if copy(dst[1:], b.Bytes) != len(b.Bytes) {
+		panic("internal error")
 	}
+}
 
-	err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
-	if err != nil {
-		return
-	}
+type oidEncoder []int
+
+func (oid oidEncoder) Len() int {
+	l := base128IntLength(int64(oid[0]*40 + oid[1]))
 	for i := 2; i < len(oid); i++ {
-		err = marshalBase128Int(out, int64(oid[i]))
-		if err != nil {
-			return
+		l += base128IntLength(int64(oid[i]))
+	}
+	return l
+}
+
+func (oid oidEncoder) Encode(dst []byte) {
+	dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
+	for i := 2; i < len(oid); i++ {
+		dst = appendBase128Int(dst, int64(oid[i]))
+	}
+}
+
+func makeObjectIdentifier(oid []int) (e encoder, err error) {
+	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
+		return nil, StructuralError{"invalid object identifier"}
+	}
+
+	return oidEncoder(oid), nil
+}
+
+func makePrintableString(s string) (e encoder, err error) {
+	for i := 0; i < len(s); i++ {
+		if !isPrintable(s[i]) {
+			return nil, StructuralError{"PrintableString contains invalid character"}
 		}
 	}
 
-	return
+	return stringEncoder(s), nil
 }
 
-func marshalPrintableString(out *forkableWriter, s string) (err error) {
-	b := []byte(s)
-	for _, c := range b {
-		if !isPrintable(c) {
-			return StructuralError{"PrintableString contains invalid character"}
+func makeIA5String(s string) (e encoder, err error) {
+	for i := 0; i < len(s); i++ {
+		if s[i] > 127 {
+			return nil, StructuralError{"IA5String contains invalid character"}
 		}
 	}
 
-	_, err = out.Write(b)
-	return
+	return stringEncoder(s), nil
 }
 
-func marshalIA5String(out *forkableWriter, s string) (err error) {
-	b := []byte(s)
-	for _, c := range b {
-		if c > 127 {
-			return StructuralError{"IA5String contains invalid character"}
-		}
-	}
-
-	_, err = out.Write(b)
-	return
+func makeUTF8String(s string) encoder {
+	return stringEncoder(s)
 }
 
-func marshalUTF8String(out *forkableWriter, s string) (err error) {
-	_, err = out.Write([]byte(s))
-	return
+func appendTwoDigits(dst []byte, v int) []byte {
+	return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
 }
 
-func marshalTwoDigits(out *forkableWriter, v int) (err error) {
-	err = out.WriteByte(byte('0' + (v/10)%10))
-	if err != nil {
-		return
-	}
-	return out.WriteByte(byte('0' + v%10))
-}
-
-func marshalFourDigits(out *forkableWriter, v int) (err error) {
+func appendFourDigits(dst []byte, v int) []byte {
 	var bytes [4]byte
 	for i := range bytes {
 		bytes[3-i] = '0' + byte(v%10)
 		v /= 10
 	}
-	_, err = out.Write(bytes[:])
-	return
+	return append(dst, bytes[:]...)
 }
 
 func outsideUTCRange(t time.Time) bool {
@@ -310,80 +304,75 @@
 	return year < 1950 || year >= 2050
 }
 
-func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
+func makeUTCTime(t time.Time) (e encoder, err error) {
+	dst := make([]byte, 0, 18)
+
+	dst, err = appendUTCTime(dst, t)
+	if err != nil {
+		return nil, err
+	}
+
+	return bytesEncoder(dst), nil
+}
+
+func makeGeneralizedTime(t time.Time) (e encoder, err error) {
+	dst := make([]byte, 0, 20)
+
+	dst, err = appendGeneralizedTime(dst, t)
+	if err != nil {
+		return nil, err
+	}
+
+	return bytesEncoder(dst), nil
+}
+
+func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
 	year := t.Year()
 
 	switch {
 	case 1950 <= year && year < 2000:
-		err = marshalTwoDigits(out, year-1900)
+		dst = appendTwoDigits(dst, year-1900)
 	case 2000 <= year && year < 2050:
-		err = marshalTwoDigits(out, year-2000)
+		dst = appendTwoDigits(dst, year-2000)
 	default:
-		return StructuralError{"cannot represent time as UTCTime"}
-	}
-	if err != nil {
-		return
+		return nil, StructuralError{"cannot represent time as UTCTime"}
 	}
 
-	return marshalTimeCommon(out, t)
+	return appendTimeCommon(dst, t), nil
 }
 
-func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
+func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
 	year := t.Year()
 	if year < 0 || year > 9999 {
-		return StructuralError{"cannot represent time as GeneralizedTime"}
-	}
-	if err = marshalFourDigits(out, year); err != nil {
-		return
+		return nil, StructuralError{"cannot represent time as GeneralizedTime"}
 	}
 
-	return marshalTimeCommon(out, t)
+	dst = appendFourDigits(dst, year)
+
+	return appendTimeCommon(dst, t), nil
 }
 
-func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
+func appendTimeCommon(dst []byte, t time.Time) []byte {
 	_, month, day := t.Date()
 
-	err = marshalTwoDigits(out, int(month))
-	if err != nil {
-		return
-	}
-
-	err = marshalTwoDigits(out, day)
-	if err != nil {
-		return
-	}
+	dst = appendTwoDigits(dst, int(month))
+	dst = appendTwoDigits(dst, day)
 
 	hour, min, sec := t.Clock()
 
-	err = marshalTwoDigits(out, hour)
-	if err != nil {
-		return
-	}
-
-	err = marshalTwoDigits(out, min)
-	if err != nil {
-		return
-	}
-
-	err = marshalTwoDigits(out, sec)
-	if err != nil {
-		return
-	}
+	dst = appendTwoDigits(dst, hour)
+	dst = appendTwoDigits(dst, min)
+	dst = appendTwoDigits(dst, sec)
 
 	_, offset := t.Zone()
 
 	switch {
 	case offset/60 == 0:
-		err = out.WriteByte('Z')
-		return
+		return append(dst, 'Z')
 	case offset > 0:
-		err = out.WriteByte('+')
+		dst = append(dst, '+')
 	case offset < 0:
-		err = out.WriteByte('-')
-	}
-
-	if err != nil {
-		return
+		dst = append(dst, '-')
 	}
 
 	offsetMinutes := offset / 60
@@ -391,13 +380,10 @@
 		offsetMinutes = -offsetMinutes
 	}
 
-	err = marshalTwoDigits(out, offsetMinutes/60)
-	if err != nil {
-		return
-	}
+	dst = appendTwoDigits(dst, offsetMinutes/60)
+	dst = appendTwoDigits(dst, offsetMinutes%60)
 
-	err = marshalTwoDigits(out, offsetMinutes%60)
-	return
+	return dst
 }
 
 func stripTagAndLength(in []byte) []byte {
@@ -408,114 +394,124 @@
 	return in[offset:]
 }
 
-func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
+func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
 	switch value.Type() {
 	case flagType:
-		return nil
+		return bytesEncoder(nil), nil
 	case timeType:
 		t := value.Interface().(time.Time)
 		if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
-			return marshalGeneralizedTime(out, t)
-		} else {
-			return marshalUTCTime(out, t)
+			return makeGeneralizedTime(t)
 		}
+		return makeUTCTime(t)
 	case bitStringType:
-		return marshalBitString(out, value.Interface().(BitString))
+		return bitStringEncoder(value.Interface().(BitString)), nil
 	case objectIdentifierType:
-		return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
+		return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
 	case bigIntType:
-		return marshalBigInt(out, value.Interface().(*big.Int))
+		return makeBigInt(value.Interface().(*big.Int)), nil
 	}
 
 	switch v := value; v.Kind() {
 	case reflect.Bool:
 		if v.Bool() {
-			return out.WriteByte(255)
-		} else {
-			return out.WriteByte(0)
+			return byteFFEncoder, nil
 		}
+		return byte00Encoder, nil
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-		return marshalInt64(out, v.Int())
+		return int64Encoder(v.Int()), nil
 	case reflect.Struct:
 		t := v.Type()
 
 		startingField := 0
 
+		n := t.NumField()
+		if n == 0 {
+			return bytesEncoder(nil), nil
+		}
+
 		// If the first element of the structure is a non-empty
 		// RawContents, then we don't bother serializing the rest.
-		if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
+		if t.Field(0).Type == rawContentsType {
 			s := v.Field(0)
 			if s.Len() > 0 {
-				bytes := make([]byte, s.Len())
-				for i := 0; i < s.Len(); i++ {
-					bytes[i] = uint8(s.Index(i).Uint())
-				}
+				bytes := s.Bytes()
 				/* The RawContents will contain the tag and
 				 * length fields but we'll also be writing
 				 * those ourselves, so we strip them out of
 				 * bytes */
-				_, err = out.Write(stripTagAndLength(bytes))
-				return
-			} else {
-				startingField = 1
+				return bytesEncoder(stripTagAndLength(bytes)), nil
 			}
+
+			startingField = 1
 		}
 
-		for i := startingField; i < t.NumField(); i++ {
-			var pre *forkableWriter
-			pre, out = out.fork()
-			err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
-			if err != nil {
-				return
+		switch n1 := n - startingField; n1 {
+		case 0:
+			return bytesEncoder(nil), nil
+		case 1:
+			return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
+		default:
+			m := make([]encoder, n1)
+			for i := 0; i < n1; i++ {
+				m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
+				if err != nil {
+					return nil, err
+				}
 			}
+
+			return multiEncoder(m), nil
 		}
-		return
 	case reflect.Slice:
 		sliceType := v.Type()
 		if sliceType.Elem().Kind() == reflect.Uint8 {
-			bytes := make([]byte, v.Len())
-			for i := 0; i < v.Len(); i++ {
-				bytes[i] = uint8(v.Index(i).Uint())
-			}
-			_, err = out.Write(bytes)
-			return
+			return bytesEncoder(v.Bytes()), nil
 		}
 
 		var fp fieldParameters
-		for i := 0; i < v.Len(); i++ {
-			var pre *forkableWriter
-			pre, out = out.fork()
-			err = marshalField(pre, v.Index(i), fp)
-			if err != nil {
-				return
+
+		switch l := v.Len(); l {
+		case 0:
+			return bytesEncoder(nil), nil
+		case 1:
+			return makeField(v.Index(0), fp)
+		default:
+			m := make([]encoder, l)
+
+			for i := 0; i < l; i++ {
+				m[i], err = makeField(v.Index(i), fp)
+				if err != nil {
+					return nil, err
+				}
 			}
+
+			return multiEncoder(m), nil
 		}
-		return
 	case reflect.String:
 		switch params.stringType {
 		case TagIA5String:
-			return marshalIA5String(out, v.String())
+			return makeIA5String(v.String())
 		case TagPrintableString:
-			return marshalPrintableString(out, v.String())
+			return makePrintableString(v.String())
 		default:
-			return marshalUTF8String(out, v.String())
+			return makeUTF8String(v.String()), nil
 		}
 	}
 
-	return StructuralError{"unknown Go type"}
+	return nil, StructuralError{"unknown Go type"}
 }
 
-func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
+func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
 	if !v.IsValid() {
-		return fmt.Errorf("asn1: cannot marshal nil value")
+		return nil, fmt.Errorf("asn1: cannot marshal nil value")
 	}
 	// If the field is an interface{} then recurse into it.
 	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
-		return marshalField(out, v.Elem(), params)
+		return makeField(v.Elem(), params)
 	}
 
 	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
-		return
+		return bytesEncoder(nil), nil
 	}
 
 	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
@@ -523,7 +519,7 @@
 		defaultValue.SetInt(*params.defaultValue)
 
 		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
-			return
+			return bytesEncoder(nil), nil
 		}
 	}
 
@@ -532,37 +528,36 @@
 	// behaviour, but it's what Go has traditionally done.
 	if params.optional && params.defaultValue == nil {
 		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
-			return
+			return bytesEncoder(nil), nil
 		}
 	}
 
 	if v.Type() == rawValueType {
 		rv := v.Interface().(RawValue)
 		if len(rv.FullBytes) != 0 {
-			_, err = out.Write(rv.FullBytes)
-		} else {
-			err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
-			if err != nil {
-				return
-			}
-			_, err = out.Write(rv.Bytes)
+			return bytesEncoder(rv.FullBytes), nil
 		}
-		return
+
+		t := new(taggedEncoder)
+
+		t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
+		t.body = bytesEncoder(rv.Bytes)
+
+		return t, nil
 	}
 
 	tag, isCompound, ok := getUniversalType(v.Type())
 	if !ok {
-		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
-		return
+		return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
 	}
 	class := ClassUniversal
 
 	if params.timeType != 0 && tag != TagUTCTime {
-		return StructuralError{"explicit time type given to non-time member"}
+		return nil, StructuralError{"explicit time type given to non-time member"}
 	}
 
 	if params.stringType != 0 && tag != TagPrintableString {
-		return StructuralError{"explicit string type given to non-string member"}
+		return nil, StructuralError{"explicit string type given to non-string member"}
 	}
 
 	switch tag {
@@ -574,7 +569,7 @@
 			for _, r := range v.String() {
 				if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
 					if !utf8.ValidString(v.String()) {
-						return errors.New("asn1: string not valid UTF-8")
+						return nil, errors.New("asn1: string not valid UTF-8")
 					}
 					tag = TagUTF8String
 					break
@@ -591,46 +586,46 @@
 
 	if params.set {
 		if tag != TagSequence {
-			return StructuralError{"non sequence tagged as set"}
+			return nil, StructuralError{"non sequence tagged as set"}
 		}
 		tag = TagSet
 	}
 
-	tags, body := out.fork()
+	t := new(taggedEncoder)
 
-	err = marshalBody(body, v, params)
+	t.body, err = makeBody(v, params)
 	if err != nil {
-		return
+		return nil, err
 	}
 
-	bodyLen := body.Len()
+	bodyLen := t.body.Len()
 
-	var explicitTag *forkableWriter
 	if params.explicit {
-		explicitTag, tags = tags.fork()
+		t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
+
+		tt := new(taggedEncoder)
+
+		tt.body = t
+
+		tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
+			class:      ClassContextSpecific,
+			tag:        *params.tag,
+			length:     bodyLen + t.tag.Len(),
+			isCompound: true,
+		}))
+
+		return tt, nil
 	}
 
-	if !params.explicit && params.tag != nil {
+	if params.tag != nil {
 		// implicit tag.
 		tag = *params.tag
 		class = ClassContextSpecific
 	}
 
-	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
-	if err != nil {
-		return
-	}
+	t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
 
-	if params.explicit {
-		err = marshalTagAndLength(explicitTag, tagAndLength{
-			class:      ClassContextSpecific,
-			tag:        *params.tag,
-			length:     bodyLen + tags.Len(),
-			isCompound: true,
-		})
-	}
-
-	return err
+	return t, nil
 }
 
 // Marshal returns the ASN.1 encoding of val.
@@ -643,13 +638,11 @@
 //	printable:	causes strings to be marshaled as ASN.1, PrintableString strings.
 //	utf8:		causes strings to be marshaled as ASN.1, UTF8 strings
 func Marshal(val interface{}) ([]byte, error) {
-	var out bytes.Buffer
-	v := reflect.ValueOf(val)
-	f := newForkableWriter()
-	err := marshalField(f, v, fieldParameters{})
+	e, err := makeField(reflect.ValueOf(val), fieldParameters{})
 	if err != nil {
 		return nil, err
 	}
-	_, err = f.writeTo(&out)
-	return out.Bytes(), err
+	b := make([]byte, e.Len())
+	e.Encode(b)
+	return b, nil
 }
diff --git a/src/encoding/asn1/marshal_test.go b/src/encoding/asn1/marshal_test.go
index cdca8aa..6af770f 100644
--- a/src/encoding/asn1/marshal_test.go
+++ b/src/encoding/asn1/marshal_test.go
@@ -173,3 +173,13 @@
 		t.Errorf("invalid UTF8 string was accepted")
 	}
 }
+
+func BenchmarkMarshal(b *testing.B) {
+	b.ReportAllocs()
+
+	for i := 0; i < b.N; i++ {
+		for _, test := range marshalTests {
+			Marshal(test.in)
+		}
+	}
+}