dns/dnsmessage: reduce Parser size

In the net package the Parser is copied a lot, the
size of the Parser can be reduced easily by not storing the
entire ResourceHeader in the Parser.

It reduces the size from 328B to 80B.

Also it makes sure that the resource header parsing
methods don't return stale headers (from different
sections).

Change-Id: If05b03ba654ca5c03d536e86446c5a2a7dc79ec3
GitHub-Last-Rev: dacd25cc355269ff2a89d855d2094bb8f152c83c
GitHub-Pull-Request: golang/net#186
Reviewed-on: https://go-review.googlesource.com/c/net/+/514855
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Auto-Submit: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Mateusz Poliwczak <mpoliwczak34@gmail.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go
index 69938d5..19ea8f1 100644
--- a/dns/dnsmessage/message.go
+++ b/dns/dnsmessage/message.go
@@ -542,11 +542,13 @@
 	msg    []byte
 	header header
 
-	section        section
-	off            int
-	index          int
-	resHeaderValid bool
-	resHeader      ResourceHeader
+	section         section
+	off             int
+	index           int
+	resHeaderValid  bool
+	resHeaderOffset int
+	resHeaderType   Type
+	resHeaderLength uint16
 }
 
 // Start parses the header and enables the parsing of Questions.
@@ -597,8 +599,9 @@
 
 func (p *Parser) resourceHeader(sec section) (ResourceHeader, error) {
 	if p.resHeaderValid {
-		return p.resHeader, nil
+		p.off = p.resHeaderOffset
 	}
+
 	if err := p.checkAdvance(sec); err != nil {
 		return ResourceHeader{}, err
 	}
@@ -608,14 +611,16 @@
 		return ResourceHeader{}, err
 	}
 	p.resHeaderValid = true
-	p.resHeader = hdr
+	p.resHeaderOffset = p.off
+	p.resHeaderType = hdr.Type
+	p.resHeaderLength = hdr.Length
 	p.off = off
 	return hdr, nil
 }
 
 func (p *Parser) skipResource(sec section) error {
 	if p.resHeaderValid {
-		newOff := p.off + int(p.resHeader.Length)
+		newOff := p.off + int(p.resHeaderLength)
 		if newOff > len(p.msg) {
 			return errResourceLen
 		}
@@ -866,14 +871,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) CNAMEResource() (CNAMEResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeCNAME {
+	if !p.resHeaderValid || p.resHeaderType != TypeCNAME {
 		return CNAMEResource{}, ErrNotStarted
 	}
 	r, err := unpackCNAMEResource(p.msg, p.off)
 	if err != nil {
 		return CNAMEResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -884,14 +889,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) MXResource() (MXResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeMX {
+	if !p.resHeaderValid || p.resHeaderType != TypeMX {
 		return MXResource{}, ErrNotStarted
 	}
 	r, err := unpackMXResource(p.msg, p.off)
 	if err != nil {
 		return MXResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -902,14 +907,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) NSResource() (NSResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeNS {
+	if !p.resHeaderValid || p.resHeaderType != TypeNS {
 		return NSResource{}, ErrNotStarted
 	}
 	r, err := unpackNSResource(p.msg, p.off)
 	if err != nil {
 		return NSResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -920,14 +925,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) PTRResource() (PTRResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypePTR {
+	if !p.resHeaderValid || p.resHeaderType != TypePTR {
 		return PTRResource{}, ErrNotStarted
 	}
 	r, err := unpackPTRResource(p.msg, p.off)
 	if err != nil {
 		return PTRResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -938,14 +943,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) SOAResource() (SOAResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeSOA {
+	if !p.resHeaderValid || p.resHeaderType != TypeSOA {
 		return SOAResource{}, ErrNotStarted
 	}
 	r, err := unpackSOAResource(p.msg, p.off)
 	if err != nil {
 		return SOAResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -956,14 +961,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) TXTResource() (TXTResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeTXT {
+	if !p.resHeaderValid || p.resHeaderType != TypeTXT {
 		return TXTResource{}, ErrNotStarted
 	}
-	r, err := unpackTXTResource(p.msg, p.off, p.resHeader.Length)
+	r, err := unpackTXTResource(p.msg, p.off, p.resHeaderLength)
 	if err != nil {
 		return TXTResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -974,14 +979,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) SRVResource() (SRVResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeSRV {
+	if !p.resHeaderValid || p.resHeaderType != TypeSRV {
 		return SRVResource{}, ErrNotStarted
 	}
 	r, err := unpackSRVResource(p.msg, p.off)
 	if err != nil {
 		return SRVResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -992,14 +997,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) AResource() (AResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeA {
+	if !p.resHeaderValid || p.resHeaderType != TypeA {
 		return AResource{}, ErrNotStarted
 	}
 	r, err := unpackAResource(p.msg, p.off)
 	if err != nil {
 		return AResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -1010,14 +1015,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) AAAAResource() (AAAAResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeAAAA {
+	if !p.resHeaderValid || p.resHeaderType != TypeAAAA {
 		return AAAAResource{}, ErrNotStarted
 	}
 	r, err := unpackAAAAResource(p.msg, p.off)
 	if err != nil {
 		return AAAAResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -1028,14 +1033,14 @@
 // One of the XXXHeader methods must have been called before calling this
 // method.
 func (p *Parser) OPTResource() (OPTResource, error) {
-	if !p.resHeaderValid || p.resHeader.Type != TypeOPT {
+	if !p.resHeaderValid || p.resHeaderType != TypeOPT {
 		return OPTResource{}, ErrNotStarted
 	}
-	r, err := unpackOPTResource(p.msg, p.off, p.resHeader.Length)
+	r, err := unpackOPTResource(p.msg, p.off, p.resHeaderLength)
 	if err != nil {
 		return OPTResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
@@ -1049,11 +1054,11 @@
 	if !p.resHeaderValid {
 		return UnknownResource{}, ErrNotStarted
 	}
-	r, err := unpackUnknownResource(p.resHeader.Type, p.msg, p.off, p.resHeader.Length)
+	r, err := unpackUnknownResource(p.resHeaderType, p.msg, p.off, p.resHeaderLength)
 	if err != nil {
 		return UnknownResource{}, err
 	}
-	p.off += int(p.resHeader.Length)
+	p.off += int(p.resHeaderLength)
 	p.resHeaderValid = false
 	p.index++
 	return r, nil
diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go
index 83fac78..ddb062b 100644
--- a/dns/dnsmessage/message_test.go
+++ b/dns/dnsmessage/message_test.go
@@ -1670,3 +1670,117 @@
 		}
 	})
 }
+
+func TestParseResourceHeaderMultipleTimes(t *testing.T) {
+	msg := Message{
+		Header: Header{Response: true, Authoritative: true},
+		Answers: []Resource{
+			{
+				ResourceHeader{
+					Name:  MustNewName("go.dev."),
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				&AResource{[4]byte{127, 0, 0, 1}},
+			},
+		},
+		Authorities: []Resource{
+			{
+				ResourceHeader{
+					Name:  MustNewName("go.dev."),
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				&AResource{[4]byte{127, 0, 0, 1}},
+			},
+		},
+	}
+
+	raw, err := msg.Pack()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var p Parser
+
+	if _, err := p.Start(raw); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := p.SkipAllQuestions(); err != nil {
+		t.Fatal(err)
+	}
+
+	hdr1, err := p.AnswerHeader()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	hdr2, err := p.AnswerHeader()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if hdr1 != hdr2 {
+		t.Fatal("AnswerHeader called multiple times without parsing the RData returned different headers")
+	}
+
+	if _, err := p.AResource(); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := p.AnswerHeader(); err != ErrSectionDone {
+		t.Fatalf("unexpected error: %v, want: %v", err, ErrSectionDone)
+	}
+
+	hdr3, err := p.AuthorityHeader()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	hdr4, err := p.AuthorityHeader()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if hdr3 != hdr4 {
+		t.Fatal("AuthorityHeader called multiple times without parsing the RData returned different headers")
+	}
+
+	if _, err := p.AResource(); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := p.AuthorityHeader(); err != ErrSectionDone {
+		t.Fatalf("unexpected error: %v, want: %v", err, ErrSectionDone)
+	}
+}
+
+func TestParseDifferentResourceHeadersWithoutParsingRData(t *testing.T) {
+	msg := smallTestMsg()
+	raw, err := msg.Pack()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var p Parser
+	if _, err := p.Start(raw); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := p.SkipAllQuestions(); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := p.AnswerHeader(); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := p.AdditionalHeader(); err == nil {
+		t.Errorf("p.AdditionalHeader() unexpected success")
+	}
+
+	if _, err := p.AuthorityHeader(); err == nil {
+		t.Errorf("p.AuthorityHeader() unexpected success")
+	}
+}