icmp: add simple multipart message validation

This change adds simple validation for multipart messages to avoid
generating incorrect messages and introduces RawBody and RawExtension
to control message validation. RawBody and RawExtension are excluded
from normal message processing and can be used to construct crafted
messages for applications such as wire format testing.

Fixes golang/go#28686.

Change-Id: I56f51d6566142f5e1dcc75cfce5250801e583d6d
Reviewed-on: https://go-review.googlesource.com/c/net/+/155859
Run-TryBot: Mikio Hara <mikioh.public.networking@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/icmp/dstunreach.go b/icmp/dstunreach.go
index 7464bf7..8615cf5 100644
--- a/icmp/dstunreach.go
+++ b/icmp/dstunreach.go
@@ -4,6 +4,12 @@
 
 package icmp
 
+import (
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
 // A DstUnreach represents an ICMP destination unreachable message
 // body.
 type DstUnreach struct {
@@ -17,11 +23,23 @@
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *DstUnreach) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeDestinationUnreachable
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeDestinationUnreachable
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 }
 
diff --git a/icmp/echo.go b/icmp/echo.go
index c611f65..b591864 100644
--- a/icmp/echo.go
+++ b/icmp/echo.go
@@ -4,7 +4,13 @@
 
 package icmp
 
-import "encoding/binary"
+import (
+	"encoding/binary"
+
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
 
 // An Echo represents an ICMP echo request or reply message body.
 type Echo struct {
@@ -59,29 +65,39 @@
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, false, nil, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *ExtendedEchoRequest) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeExtendedEchoRequest
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeExtendedEchoRequest
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	b, err := marshalMultipartMessageBody(proto, false, nil, p.Extensions)
 	if err != nil {
 		return nil, err
 	}
-	bb := make([]byte, 4)
-	binary.BigEndian.PutUint16(bb[:2], uint16(p.ID))
-	bb[2] = byte(p.Seq)
+	binary.BigEndian.PutUint16(b[:2], uint16(p.ID))
+	b[2] = byte(p.Seq)
 	if p.Local {
-		bb[3] |= 0x01
+		b[3] |= 0x01
 	}
-	bb = append(bb, b...)
-	return bb, nil
+	return b, nil
 }
 
 // parseExtendedEchoRequest parses b as an ICMP extended echo request
 // message body.
 func parseExtendedEchoRequest(proto int, typ Type, b []byte) (MessageBody, error) {
-	if len(b) < 4+4 {
+	if len(b) < 4 {
 		return nil, errMessageTooShort
 	}
 	p := &ExtendedEchoRequest{ID: int(binary.BigEndian.Uint16(b[:2])), Seq: int(b[2])}
@@ -89,7 +105,7 @@
 		p.Local = true
 	}
 	var err error
-	_, p.Extensions, err = parseMultipartMessageBody(proto, typ, b[4:])
+	_, p.Extensions, err = parseMultipartMessageBody(proto, typ, b)
 	if err != nil {
 		return nil, err
 	}
diff --git a/icmp/extension.go b/icmp/extension.go
index d27c462..eeb85c3 100644
--- a/icmp/extension.go
+++ b/icmp/extension.go
@@ -103,8 +103,68 @@
 				return nil, -1, err
 			}
 			exts = append(exts, ext)
+		default:
+			ext := &RawExtension{Data: make([]byte, ol)}
+			copy(ext.Data, b[:ol])
+			exts = append(exts, ext)
 		}
 		b = b[ol:]
 	}
 	return exts, l, nil
 }
+
+func validExtensions(typ Type, exts []Extension) bool {
+	switch typ {
+	case ipv4.ICMPTypeDestinationUnreachable, ipv4.ICMPTypeTimeExceeded, ipv4.ICMPTypeParameterProblem,
+		ipv6.ICMPTypeDestinationUnreachable, ipv6.ICMPTypeTimeExceeded:
+		for i := range exts {
+			switch exts[i].(type) {
+			case *MPLSLabelStack, *InterfaceInfo, *RawExtension:
+			default:
+				return false
+			}
+		}
+		return true
+	case ipv4.ICMPTypeExtendedEchoRequest, ipv6.ICMPTypeExtendedEchoRequest:
+		var n int
+		for i := range exts {
+			switch exts[i].(type) {
+			case *InterfaceIdent:
+				n++
+			case *RawExtension:
+			default:
+				return false
+			}
+		}
+		// Not a single InterfaceIdent object or a combo of
+		// RawExtension and InterfaceIdent objects is not
+		// allowed.
+		if n == 1 && len(exts) > 1 {
+			return false
+		}
+		return true
+	default:
+		return false
+	}
+}
+
+// A RawExtension represents a raw extension.
+//
+// A raw extension is excluded from message processing and can be used
+// to construct applications such as protocol conformance testing.
+type RawExtension struct {
+	Data []byte // data
+}
+
+// Len implements the Len method of Extension interface.
+func (p *RawExtension) Len(proto int) int {
+	if p == nil {
+		return 0
+	}
+	return len(p.Data)
+}
+
+// Marshal implements the Marshal method of Extension interface.
+func (p *RawExtension) Marshal(proto int) ([]byte, error) {
+	return p.Data, nil
+}
diff --git a/icmp/message.go b/icmp/message.go
index 7ccefaa..a9b70df 100644
--- a/icmp/message.go
+++ b/icmp/message.go
@@ -34,6 +34,7 @@
 	errHeaderTooShort   = errors.New("header too short")
 	errBufferTooShort   = errors.New("buffer too short")
 	errOpNoSupport      = errors.New("operation not supported")
+	errInvalidBody      = errors.New("invalid body")
 	errNoExtension      = errors.New("no extension")
 	errInvalidExtension = errors.New("invalid extension")
 )
@@ -150,7 +151,7 @@
 		return nil, errInvalidProtocol
 	}
 	if fn, ok := parseFns[m.Type]; !ok {
-		m.Body, err = parseDefaultMessageBody(proto, b[4:])
+		m.Body, err = parseRawBody(proto, b[4:])
 	} else {
 		m.Body, err = fn(proto, m.Type, b[4:])
 	}
diff --git a/icmp/message_test.go b/icmp/message_test.go
index c278b8b..d04ee8b 100644
--- a/icmp/message_test.go
+++ b/icmp/message_test.go
@@ -5,6 +5,7 @@
 package icmp_test
 
 import (
+	"bytes"
 	"net"
 	"reflect"
 	"testing"
@@ -31,17 +32,19 @@
 			for _, psh := range pshs {
 				b, err := tm.Marshal(psh)
 				if err != nil {
-					t.Fatal(err)
+					t.Fatalf("#%d: %v", i, err)
 				}
 				m, err := icmp.ParseMessage(proto, b)
 				if err != nil {
-					t.Fatal(err)
+					t.Fatalf("#%d: %v", i, err)
 				}
 				if m.Type != tm.Type || m.Code != tm.Code {
 					t.Errorf("#%d: got %#v; want %#v", i, m, &tm)
+					continue
 				}
 				if !reflect.DeepEqual(m.Body, tm.Body) {
 					t.Errorf("#%d: got %#v; want %#v", i, m.Body, tm.Body)
+					continue
 				}
 			}
 		}
@@ -80,6 +83,13 @@
 					Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
 					Body: &icmp.ExtendedEchoRequest{
 						ID: 1, Seq: 2,
+						Extensions: []icmp.Extension{
+							&icmp.InterfaceIdent{
+								Class: 3,
+								Type:  1,
+								Name:  "en101",
+							},
+						},
 					},
 				},
 				{
@@ -88,12 +98,6 @@
 						State: 4 /* Delay */, Active: true, IPv4: true,
 					},
 				},
-				{
-					Type: ipv4.ICMPTypePhoturis,
-					Body: &icmp.DefaultMessageBody{
-						Data: []byte{0x80, 0x40, 0x20, 0x10},
-					},
-				},
 			})
 	})
 	t.Run("IPv6", func(t *testing.T) {
@@ -136,6 +140,13 @@
 					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
 					Body: &icmp.ExtendedEchoRequest{
 						ID: 1, Seq: 2,
+						Extensions: []icmp.Extension{
+							&icmp.InterfaceIdent{
+								Class: 3,
+								Type:  2,
+								Index: 911,
+							},
+						},
 					},
 				},
 				{
@@ -144,12 +155,189 @@
 						State: 5 /* Probe */, Active: true, IPv6: true,
 					},
 				},
-				{
-					Type: ipv6.ICMPTypeDuplicateAddressConfirmation,
-					Body: &icmp.DefaultMessageBody{
-						Data: []byte{0x80, 0x40, 0x20, 0x10},
+			})
+	})
+}
+
+func TestMarshalAndParseRawMessage(t *testing.T) {
+	t.Run("RawBody", func(t *testing.T) {
+		for i, tt := range []struct {
+			m               icmp.Message
+			wire            []byte
+			parseShouldFail bool
+		}{
+			{ // Nil body
+				m: icmp.Message{
+					Type: ipv4.ICMPTypeDestinationUnreachable, Code: 127,
+				},
+				wire: []byte{
+					0x03, 0x7f, 0xfc, 0x80,
+				},
+				parseShouldFail: true,
+			},
+			{ // Empty body
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 128,
+					Body: &icmp.RawBody{},
+				},
+				wire: []byte{
+					0x01, 0x80, 0x00, 0x00,
+				},
+				parseShouldFail: true,
+			},
+			{ // Crafted body
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDuplicateAddressConfirmation, Code: 129,
+					Body: &icmp.RawBody{
+						Data: []byte{0xca, 0xfe},
 					},
 				},
-			})
+				wire: []byte{
+					0x9e, 0x81, 0x00, 0x00,
+					0xca, 0xfe,
+				},
+				parseShouldFail: false,
+			},
+		} {
+			b, err := tt.m.Marshal(nil)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if !bytes.Equal(b, tt.wire) {
+				t.Errorf("#%d: got %#v; want %#v", i, b, tt.wire)
+				continue
+			}
+			m, err := icmp.ParseMessage(tt.m.Type.Protocol(), b)
+			if err != nil != tt.parseShouldFail {
+				t.Errorf("#%d: got %v, %v", i, m, err)
+				continue
+			}
+			if tt.parseShouldFail {
+				continue
+			}
+			if m.Type != tt.m.Type || m.Code != tt.m.Code {
+				t.Errorf("#%d: got %v; want %v", i, m, tt.m)
+				continue
+			}
+			if !bytes.Equal(m.Body.(*icmp.RawBody).Data, tt.m.Body.(*icmp.RawBody).Data) {
+				t.Errorf("#%d: got %#v; want %#v", i, m.Body, tt.m.Body)
+				continue
+			}
+		}
+	})
+	t.Run("RawExtension", func(t *testing.T) {
+		for i, tt := range []struct {
+			m    icmp.Message
+			wire []byte
+		}{
+			{ // Unaligned data and nil extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 130,
+					Body: &icmp.DstUnreach{
+						Data: []byte("ERROR-INVOKING-PACKET"),
+					},
+				},
+				wire: []byte{
+					0x01, 0x82, 0x00, 0x00,
+					0x00, 0x00, 0x00, 0x00,
+					'E', 'R', 'R', 'O',
+					'R', '-', 'I', 'N',
+					'V', 'O', 'K', 'I',
+					'N', 'G', '-', 'P',
+					'A', 'C', 'K', 'E',
+					'T',
+				},
+			},
+			{ // Unaligned data and empty extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 131,
+					Body: &icmp.DstUnreach{
+						Data: []byte("ERROR-INVOKING-PACKET"),
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{},
+						},
+					},
+				},
+				wire: []byte{
+					0x01, 0x83, 0x00, 0x00,
+					0x02, 0x00, 0x00, 0x00,
+					'E', 'R', 'R', 'O',
+					'R', '-', 'I', 'N',
+					'V', 'O', 'K', 'I',
+					'N', 'G', '-', 'P',
+					'A', 'C', 'K', 'E',
+					'T',
+					0x20, 0x00, 0xdf, 0xff,
+				},
+			},
+			{ // Nil extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 132,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+					},
+				},
+				wire: []byte{
+					0xa0, 0x84, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+				},
+			},
+			{ // Empty extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 133,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{},
+						},
+					},
+				},
+				wire: []byte{
+					0xa0, 0x85, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+					0x20, 0x00, 0xdf, 0xff,
+				},
+			},
+			{ // Crafted extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 134,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{
+								Data: []byte("CRAFTED"),
+							},
+						},
+					},
+				},
+				wire: []byte{
+					0xa0, 0x86, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+					0x20, 0x00, 0xc3, 0x21,
+					'C', 'R', 'A', 'F',
+					'T', 'E', 'D',
+				},
+			},
+		} {
+			b, err := tt.m.Marshal(nil)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if !bytes.Equal(b, tt.wire) {
+				t.Errorf("#%d: got %#v; want %#v", i, b, tt.wire)
+				continue
+			}
+			m, err := icmp.ParseMessage(tt.m.Type.Protocol(), b)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if m.Type != tt.m.Type || m.Code != tt.m.Code {
+				t.Errorf("#%d: got %v; want %v", i, m, tt.m)
+				continue
+			}
+		}
 	})
 }
diff --git a/icmp/messagebody.go b/icmp/messagebody.go
index f12250c..e2d9bfa 100644
--- a/icmp/messagebody.go
+++ b/icmp/messagebody.go
@@ -17,13 +17,17 @@
 	Marshal(proto int) ([]byte, error)
 }
 
-// A DefaultMessageBody represents the default message body.
-type DefaultMessageBody struct {
+// A RawBody represents a raw message body.
+//
+// A raw message body is excluded from message processing and can be
+// used to construct applications such as protocol conformance
+// testing.
+type RawBody struct {
 	Data []byte // data
 }
 
 // Len implements the Len method of MessageBody interface.
-func (p *DefaultMessageBody) Len(proto int) int {
+func (p *RawBody) Len(proto int) int {
 	if p == nil {
 		return 0
 	}
@@ -31,13 +35,18 @@
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
-func (p *DefaultMessageBody) Marshal(proto int) ([]byte, error) {
+func (p *RawBody) Marshal(proto int) ([]byte, error) {
 	return p.Data, nil
 }
 
-// parseDefaultMessageBody parses b as an ICMP message body.
-func parseDefaultMessageBody(proto int, b []byte) (MessageBody, error) {
-	p := &DefaultMessageBody{Data: make([]byte, len(b))}
+// parseRawBody parses b as an ICMP message body.
+func parseRawBody(proto int, b []byte) (MessageBody, error) {
+	p := &RawBody{Data: make([]byte, len(b))}
 	copy(p.Data, b)
 	return p, nil
 }
+
+// A DefaultMessageBody represents the default message body.
+//
+// Deprecated: Use RawBody instead.
+type DefaultMessageBody = RawBody
diff --git a/icmp/multipart.go b/icmp/multipart.go
index 9ebbbaf..7f88a4d 100644
--- a/icmp/multipart.go
+++ b/icmp/multipart.go
@@ -11,18 +11,24 @@
 // and a required length for a padded original datagram in wire
 // format.
 func multipartMessageBodyDataLen(proto int, withOrigDgram bool, b []byte, exts []Extension) (bodyLen, dataLen int) {
+	bodyLen = 4 // length of leading octets
+	var extLen int
+	var rawExt bool // raw extension may contain an empty object
 	for _, ext := range exts {
-		bodyLen += ext.Len(proto)
-	}
-	if bodyLen > 0 {
-		if withOrigDgram {
-			dataLen = multipartMessageOrigDatagramLen(proto, b)
+		extLen += ext.Len(proto)
+		if _, ok := ext.(*RawExtension); ok {
+			rawExt = true
 		}
-		bodyLen += 4 // length of extension header
+	}
+	if extLen > 0 && withOrigDgram {
+		dataLen = multipartMessageOrigDatagramLen(proto, b)
 	} else {
 		dataLen = len(b)
 	}
-	bodyLen += dataLen
+	if extLen > 0 || rawExt {
+		bodyLen += 4 // length of extension header
+	}
+	bodyLen += dataLen + extLen
 	return bodyLen, dataLen
 }
 
@@ -54,12 +60,11 @@
 // It can be used for non-multipart message bodies when exts is nil.
 func marshalMultipartMessageBody(proto int, withOrigDgram bool, data []byte, exts []Extension) ([]byte, error) {
 	bodyLen, dataLen := multipartMessageBodyDataLen(proto, withOrigDgram, data, exts)
-	b := make([]byte, 4+bodyLen)
+	b := make([]byte, bodyLen)
 	copy(b[4:], data)
-	off := dataLen + 4
 	if len(exts) > 0 {
-		b[dataLen+4] = byte(extensionVersion << 4)
-		off += 4 // length of object header
+		b[4+dataLen] = byte(extensionVersion << 4)
+		off := 4 + dataLen + 4 // leading octets, data, extension header
 		for _, ext := range exts {
 			switch ext := ext.(type) {
 			case *MPLSLabelStack:
@@ -78,11 +83,14 @@
 					return nil, err
 				}
 				off += ext.Len(proto)
+			case *RawExtension:
+				copy(b[off:], ext.Data)
+				off += ext.Len(proto)
 			}
 		}
-		s := checksum(b[dataLen+4:])
-		b[dataLen+4+2] ^= byte(s)
-		b[dataLen+4+3] ^= byte(s >> 8)
+		s := checksum(b[4+dataLen:])
+		b[4+dataLen+2] ^= byte(s)
+		b[4+dataLen+3] ^= byte(s >> 8)
 		if withOrigDgram {
 			switch proto {
 			case iana.ProtocolICMP:
diff --git a/icmp/multipart_test.go b/icmp/multipart_test.go
index 7440882..4601790 100644
--- a/icmp/multipart_test.go
+++ b/icmp/multipart_test.go
@@ -232,11 +232,6 @@
 							Type:  2,
 							Index: 911,
 						},
-						&icmp.InterfaceIdent{
-							Class: 3,
-							Type:  1,
-							Name:  "en101",
-						},
 					},
 				},
 			},
@@ -361,11 +356,6 @@
 					Extensions: []icmp.Extension{
 						&icmp.InterfaceIdent{
 							Class: 3,
-							Type:  1,
-							Name:  "en101",
-						},
-						&icmp.InterfaceIdent{
-							Class: 3,
 							Type:  2,
 							Index: 911,
 						},
@@ -413,10 +403,12 @@
 			if !reflect.DeepEqual(got, want) {
 				s += fmt.Sprintf("#%d: got %#v; want %#v\n", i, got, want)
 			}
+		case *icmp.RawExtension:
+			s += fmt.Sprintf("#%d: raw extension\n", i)
 		}
 	}
 	if len(s) == 0 {
-		return "<nil>"
+		s += "empty extension"
 	}
 	return s[:len(s)-1]
 }
diff --git a/icmp/paramprob.go b/icmp/paramprob.go
index 8587255..f16fd33 100644
--- a/icmp/paramprob.go
+++ b/icmp/paramprob.go
@@ -6,7 +6,9 @@
 
 import (
 	"encoding/binary"
+
 	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
 )
 
 // A ParamProb represents an ICMP parameter problem message body.
@@ -22,23 +24,30 @@
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *ParamProb) Marshal(proto int) ([]byte, error) {
-	if proto == iana.ProtocolIPv6ICMP {
+	switch proto {
+	case iana.ProtocolICMP:
+		if !validExtensions(ipv4.ICMPTypeParameterProblem, p.Extensions) {
+			return nil, errInvalidExtension
+		}
+		b, err := marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
+		if err != nil {
+			return nil, err
+		}
+		b[0] = byte(p.Pointer)
+		return b, nil
+	case iana.ProtocolIPv6ICMP:
 		b := make([]byte, p.Len(proto))
 		binary.BigEndian.PutUint32(b[:4], uint32(p.Pointer))
 		copy(b[4:], p.Data)
 		return b, nil
+	default:
+		return nil, errInvalidProtocol
 	}
-	b, err := marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
-	if err != nil {
-		return nil, err
-	}
-	b[0] = byte(p.Pointer)
-	return b, nil
 }
 
 // parseParamProb parses b as an ICMP parameter problem message body.
diff --git a/icmp/timeexceeded.go b/icmp/timeexceeded.go
index 14e9e23..ffa986f 100644
--- a/icmp/timeexceeded.go
+++ b/icmp/timeexceeded.go
@@ -4,6 +4,12 @@
 
 package icmp
 
+import (
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
 // A TimeExceeded represents an ICMP time exceeded message body.
 type TimeExceeded struct {
 	Data       []byte      // data, known as original datagram field
@@ -16,11 +22,23 @@
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *TimeExceeded) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeTimeExceeded
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeTimeExceeded
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 }