net: separate DNS transport from DNS query-response interaction

In preparation for fixing issue 6579, this CL separates the DNS
transport from the DNS query-response interaction, making it easier
to add control logic to the builtin DNS resolver.

Update #6579

LGTM=alex, kevlar
R=golang-codereviews, alex, gobot, iant, minux, kevlar
CC=golang-codereviews
https://golang.org/cl/101220044
diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go
index 3713efd..eb4b590 100644
--- a/src/pkg/net/dnsclient_unix.go
+++ b/src/pkg/net/dnsclient_unix.go
@@ -16,6 +16,7 @@
 package net
 
 import (
+	"errors"
 	"io"
 	"math/rand"
 	"os"
@@ -23,118 +24,187 @@
 	"time"
 )
 
-// Send a request on the connection and hope for a reply.
-// Up to cfg.attempts attempts.
-func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) {
-	_, useTCP := c.(*TCPConn)
-	if len(name) >= 256 {
-		return nil, &DNSError{Err: "name too long", Name: name}
-	}
-	out := new(dnsMsg)
-	out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
-	out.question = []dnsQuestion{
-		{name, qtype, dnsClassINET},
-	}
-	out.recursion_desired = true
-	msg, ok := out.Pack()
-	if !ok {
-		return nil, &DNSError{Err: "internal error - cannot pack message", Name: name}
-	}
-	if useTCP {
-		mlen := uint16(len(msg))
-		msg = append([]byte{byte(mlen >> 8), byte(mlen)}, msg...)
-	}
-	for attempt := 0; attempt < cfg.attempts; attempt++ {
-		n, err := c.Write(msg)
-		if err != nil {
-			return nil, err
-		}
+// A dnsConn represents a DNS transport endpoint.
+type dnsConn interface {
+	Conn
 
-		if cfg.timeout == 0 {
-			c.SetReadDeadline(noDeadline)
-		} else {
-			c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second))
-		}
-		buf := make([]byte, 2000)
-		if useTCP {
-			n, err = io.ReadFull(c, buf[:2])
-			if err != nil {
-				if e, ok := err.(Error); ok && e.Timeout() {
-					continue
-				}
-			}
-			mlen := int(buf[0])<<8 | int(buf[1])
-			if mlen > len(buf) {
-				buf = make([]byte, mlen)
-			}
-			n, err = io.ReadFull(c, buf[:mlen])
-		} else {
-			n, err = c.Read(buf)
-		}
+	// readDNSResponse reads a DNS response message from the DNS
+	// transport endpoint and returns the received DNS response
+	// message.
+	readDNSResponse() (*dnsMsg, error)
+
+	// writeDNSQuery writes a DNS query message to the DNS
+	// transport endpoint.
+	writeDNSQuery(*dnsMsg) error
+}
+
+func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
+	b := make([]byte, 512) // see RFC 1035
+	n, err := c.Read(b)
+	if err != nil {
+		return nil, err
+	}
+	msg := &dnsMsg{}
+	if !msg.Unpack(b[:n]) {
+		return nil, errors.New("cannot unmarshal DNS message")
+	}
+	return msg, nil
+}
+
+func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
+	b, ok := msg.Pack()
+	if !ok {
+		return errors.New("cannot marshal DNS message")
+	}
+	if _, err := c.Write(b); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
+	b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
+	if _, err := io.ReadFull(c, b[:2]); err != nil {
+		return nil, err
+	}
+	l := int(b[0])<<8 | int(b[1])
+	if l > len(b) {
+		b = make([]byte, l)
+	}
+	n, err := io.ReadFull(c, b[:l])
+	if err != nil {
+		return nil, err
+	}
+	msg := &dnsMsg{}
+	if !msg.Unpack(b[:n]) {
+		return nil, errors.New("cannot unmarshal DNS message")
+	}
+	return msg, nil
+}
+
+func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
+	b, ok := msg.Pack()
+	if !ok {
+		return errors.New("cannot marshal DNS message")
+	}
+	l := uint16(len(b))
+	b = append([]byte{byte(l >> 8), byte(l)}, b...)
+	if _, err := c.Write(b); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
+	switch network {
+	case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+	default:
+		return nil, UnknownNetworkError(network)
+	}
+	// Calling Dial here is scary -- we have to be sure not to
+	// dial a name that will require a DNS lookup, or Dial will
+	// call back here to translate it. The DNS config parser has
+	// already checked that all the cfg.servers[i] are IP
+	// addresses, which Dial will use without a DNS lookup.
+	c, err := d.Dial(network, server)
+	if err != nil {
+		return nil, err
+	}
+	switch network {
+	case "tcp", "tcp4", "tcp6":
+		return c.(*TCPConn), nil
+	case "udp", "udp4", "udp6":
+		return c.(*UDPConn), nil
+	}
+	panic("unreachable")
+}
+
+// exchange sends a query to server, over UDP and then over TCP if the response is truncated, and hopes for a response.
+func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
+	d := Dialer{Timeout: timeout}
+	out := dnsMsg{
+		dnsMsgHdr: dnsMsgHdr{
+			recursion_desired: true,
+		},
+		question: []dnsQuestion{
+			{name, qtype, dnsClassINET},
+		},
+	}
+	for _, network := range []string{"udp", "tcp"} {
+		c, err := d.dialDNS(network, server)
 		if err != nil {
-			if e, ok := err.(Error); ok && e.Timeout() {
-				continue
-			}
 			return nil, err
 		}
-		buf = buf[:n]
-		in := new(dnsMsg)
-		if !in.Unpack(buf) || in.id != out.id {
+		defer c.Close()
+		if timeout > 0 {
+			c.SetDeadline(time.Now().Add(timeout))
+		}
+		out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
+		if err := c.writeDNSQuery(&out); err != nil {
+			return nil, err
+		}
+		in, err := c.readDNSResponse()
+		if err != nil {
+			return nil, err
+		}
+		if in.id != out.id {
+			return nil, errors.New("DNS message ID mismatch")
+		}
+		if in.truncated { // see RFC 5966
 			continue
 		}
 		return in, nil
 	}
-	var server string
-	if a := c.RemoteAddr(); a != nil {
-		server = a.String()
-	}
-	return nil, &DNSError{Err: "no answer from server", Name: name, Server: server, IsTimeout: true}
+	return nil, errors.New("no answer from DNS server")
 }
 
 // Do a lookup for a single name, which must be rooted
 // (otherwise answer will not find the answers).
-func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
+func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
 	if len(cfg.servers) == 0 {
 		return "", nil, &DNSError{Err: "no DNS servers", Name: name}
 	}
-	for i := 0; i < len(cfg.servers); i++ {
-		// Calling Dial here is scary -- we have to be sure
-		// not to dial a name that will require a DNS lookup,
-		// or Dial will call back here to translate it.
-		// The DNS config parser has already checked that
-		// all the cfg.servers[i] are IP addresses, which
-		// Dial will use without a DNS lookup.
-		server := cfg.servers[i] + ":53"
-		c, cerr := Dial("udp", server)
-		if cerr != nil {
-			err = cerr
-			continue
+	if len(name) >= 256 {
+		return "", nil, &DNSError{Err: "DNS name too long", Name: name}
+	}
+	timeout := time.Duration(cfg.timeout) * time.Second
+	var lastErr error
+	for _, server := range cfg.servers {
+		server += ":53"
+		lastErr = &DNSError{
+			Err:       "no answer from DNS server",
+			Name:      name,
+			Server:    server,
+			IsTimeout: true,
 		}
-		msg, merr := exchange(cfg, c, name, qtype)
-		c.Close()
-		if merr != nil {
-			err = merr
-			continue
-		}
-		if msg.truncated { // see RFC 5966
-			c, cerr = Dial("tcp", server)
-			if cerr != nil {
-				err = cerr
-				continue
+		for i := 0; i < cfg.attempts; i++ {
+			msg, err := exchange(server, name, qtype, timeout)
+			if err != nil {
+				if nerr, ok := err.(Error); ok && nerr.Timeout() {
+					lastErr = &DNSError{
+						Err:       nerr.Error(),
+						Name:      name,
+						Server:    server,
+						IsTimeout: true,
+					}
+					continue
+
+				}
+				lastErr = &DNSError{
+					Err:    err.Error(),
+					Name:   name,
+					Server: server,
+				}
+				break
 			}
-			msg, merr = exchange(cfg, c, name, qtype)
-			c.Close()
-			if merr != nil {
-				err = merr
-				continue
+			cname, addrs, err := answer(name, server, msg, qtype)
+			if err == nil || err.(*DNSError).Err == noSuchHost {
+				return cname, addrs, err
 			}
-		}
-		cname, addrs, err = answer(name, server, msg, qtype)
-		if err == nil || err.(*DNSError).Err == noSuchHost {
-			break
+			lastErr = err
 		}
 	}
-	return
+	return "", nil, lastErr
 }
 
 func convertRR_A(records []dnsRR) []IP {