net: filter destination addresses when source address is specified

This change filters out destination addresses by address family when a
source address is specified, to avoid running the Dial operation with
wrong addressing scopes.

Fixes #11837.

Change-Id: I10b7a1fa325add2cd8ed58f105d527700a10d342
Reviewed-on: https://go-review.googlesource.com/20586
Reviewed-by: Paul Marks <pmarks@google.com>
diff --git a/src/net/dial.go b/src/net/dial.go
index e4e44d2..22992d5 100644
--- a/src/net/dial.go
+++ b/src/net/dial.go
@@ -5,7 +5,6 @@
 package net
 
 import (
-	"errors"
 	"runtime"
 	"time"
 )
@@ -140,8 +139,11 @@
 	return "", 0, UnknownNetworkError(net)
 }
 
-func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) {
-	afnet, _, err := parseNetwork(net)
+// resolveAddrList resolves addr using hint and returns a list of
+// addresses. The result contains at least one address when error is
+// nil.
+func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) {
+	afnet, _, err := parseNetwork(network)
 	if err != nil {
 		return nil, err
 	}
@@ -154,9 +156,59 @@
 		if err != nil {
 			return nil, err
 		}
+		if op == "dial" && hint != nil && addr.Network() != hint.Network() {
+			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+		}
 		return addrList{addr}, nil
 	}
-	return internetAddrList(afnet, addr, deadline)
+	addrs, err := internetAddrList(afnet, addr, deadline)
+	if err != nil || op != "dial" || hint == nil {
+		return addrs, err
+	}
+	var (
+		tcp      *TCPAddr
+		udp      *UDPAddr
+		ip       *IPAddr
+		wildcard bool
+	)
+	switch hint := hint.(type) {
+	case *TCPAddr:
+		tcp = hint
+		wildcard = tcp.isWildcard()
+	case *UDPAddr:
+		udp = hint
+		wildcard = udp.isWildcard()
+	case *IPAddr:
+		ip = hint
+		wildcard = ip.isWildcard()
+	}
+	naddrs := addrs[:0]
+	for _, addr := range addrs {
+		if addr.Network() != hint.Network() {
+			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+		}
+		switch addr := addr.(type) {
+		case *TCPAddr:
+			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
+				continue
+			}
+			naddrs = append(naddrs, addr)
+		case *UDPAddr:
+			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
+				continue
+			}
+			naddrs = append(naddrs, addr)
+		case *IPAddr:
+			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
+				continue
+			}
+			naddrs = append(naddrs, addr)
+		}
+	}
+	if len(naddrs) == 0 {
+		return nil, errNoSuitableAddress
+	}
+	return naddrs, nil
 }
 
 // Dial connects to the address on the named network.
@@ -214,7 +266,7 @@
 // parameters.
 func (d *Dialer) Dial(network, address string) (Conn, error) {
 	finalDeadline := d.deadline(time.Now())
-	addrs, err := resolveAddrList("dial", network, address, finalDeadline)
+	addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline)
 	if err != nil {
 		return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
 	}
@@ -387,9 +439,6 @@
 // dial function, because some OSes don't implement the deadline feature.
 func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) {
 	la := ctx.LocalAddr
-	if la != nil && la.Network() != ra.Network() {
-		return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())}
-	}
 	switch ra := ra.(type) {
 	case *TCPAddr:
 		la, _ := la.(*TCPAddr)
@@ -420,7 +469,7 @@
 // instead of just the interface with the given host address.
 // See Dial for more details about address syntax.
 func Listen(net, laddr string) (Listener, error) {
-	addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
+	addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
 	if err != nil {
 		return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
 	}
@@ -447,7 +496,7 @@
 // instead of just the interface with the given host address.
 // See Dial for the syntax of laddr.
 func ListenPacket(net, laddr string) (PacketConn, error) {
-	addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
+	addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
 	if err != nil {
 		return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
 	}
diff --git a/src/net/dial_test.go b/src/net/dial_test.go
index 5fe3e85..3335df5 100644
--- a/src/net/dial_test.go
+++ b/src/net/dial_test.go
@@ -646,41 +646,118 @@
 	}
 }
 
+type dialerLocalAddrTest struct {
+	network, raddr string
+	laddr          Addr
+	error
+}
+
+var dialerLocalAddrTests = []dialerLocalAddrTest{
+	{"tcp4", "127.0.0.1", nil, nil},
+	{"tcp4", "127.0.0.1", &TCPAddr{}, nil},
+	{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+	{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+	{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"}},
+	{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
+	{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
+	{"tcp4", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
+	{"tcp4", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
+	{"tcp4", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+	{"tcp6", "::1", nil, nil},
+	{"tcp6", "::1", &TCPAddr{}, nil},
+	{"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+	{"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+	{"tcp6", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
+	{"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
+	{"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
+	{"tcp6", "::1", &TCPAddr{IP: IPv6loopback}, nil},
+	{"tcp6", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
+	{"tcp6", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+	{"tcp", "127.0.0.1", nil, nil},
+	{"tcp", "127.0.0.1", &TCPAddr{}, nil},
+	{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+	{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+	{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
+	{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
+	{"tcp", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
+	{"tcp", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
+	{"tcp", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+	{"tcp", "::1", nil, nil},
+	{"tcp", "::1", &TCPAddr{}, nil},
+	{"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+	{"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+	{"tcp", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
+	{"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
+	{"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
+	{"tcp", "::1", &TCPAddr{IP: IPv6loopback}, nil},
+	{"tcp", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
+	{"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
+}
+
 func TestDialerLocalAddr(t *testing.T) {
-	ch := make(chan error, 1)
-	handler := func(ls *localServer, ln Listener) {
-		c, err := ln.Accept()
-		if err != nil {
-			ch <- err
-			return
-		}
-		defer c.Close()
-		ch <- nil
-	}
-	ls, err := newLocalServer("tcp")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer ls.teardown()
-	if err := ls.buildup(handler); err != nil {
-		t.Fatal(err)
+	if !supportsIPv4 || !supportsIPv6 {
+		t.Skip("both IPv4 and IPv6 are required")
 	}
 
-	laddr, err := ResolveTCPAddr(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
-	if err != nil {
-		t.Fatal(err)
+	if supportsIPv4map {
+		dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{
+			"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil,
+		})
+	} else {
+		dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{
+			"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"},
+		})
 	}
-	laddr.Port = 0
-	d := &Dialer{LocalAddr: laddr}
-	c, err := d.Dial(ls.Listener.Addr().Network(), ls.Addr().String())
-	if err != nil {
-		t.Fatal(err)
+
+	origTestHookLookupIP := testHookLookupIP
+	defer func() { testHookLookupIP = origTestHookLookupIP }()
+	testHookLookupIP = lookupLocalhost
+	handler := func(ls *localServer, ln Listener) {
+		for {
+			c, err := ln.Accept()
+			if err != nil {
+				return
+			}
+			c.Close()
+		}
 	}
-	defer c.Close()
-	c.Read(make([]byte, 1))
-	err = <-ch
-	if err != nil {
-		t.Error(err)
+	var err error
+	var lss [2]*localServer
+	for i, network := range []string{"tcp4", "tcp6"} {
+		lss[i], err = newLocalServer(network)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer lss[i].teardown()
+		if err := lss[i].buildup(handler); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	for _, tt := range dialerLocalAddrTests {
+		d := &Dialer{LocalAddr: tt.laddr}
+		var addr string
+		ip := ParseIP(tt.raddr)
+		if ip.To4() != nil {
+			addr = lss[0].Listener.Addr().String()
+		}
+		if ip.To16() != nil && ip.To4() == nil {
+			addr = lss[1].Listener.Addr().String()
+		}
+		c, err := d.Dial(tt.network, addr)
+		if err == nil && tt.error != nil || err != nil && tt.error == nil {
+			t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
+		}
+		if err != nil {
+			if perr := parseDialError(err); perr != nil {
+				t.Error(perr)
+			}
+			continue
+		}
+		c.Close()
 	}
 }
 
diff --git a/src/net/error_test.go b/src/net/error_test.go
index ee0979c..c3a4d32 100644
--- a/src/net/error_test.go
+++ b/src/net/error_test.go
@@ -96,7 +96,7 @@
 		goto third
 	}
 	switch nestedErr {
-	case errCanceled, errClosing, errMissingAddress:
+	case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress:
 		return nil
 	}
 	return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
@@ -416,7 +416,7 @@
 		goto third
 	}
 	switch nestedErr {
-	case errCanceled, errClosing, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF:
+	case errCanceled, errClosing, errMissingAddress, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF:
 		return nil
 	}
 	return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
diff --git a/src/net/ip.go b/src/net/ip.go
index a25729c..0501f5a 100644
--- a/src/net/ip.go
+++ b/src/net/ip.go
@@ -377,6 +377,10 @@
 	return true
 }
 
+func (ip IP) matchAddrFamily(x IP) bool {
+	return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil
+}
+
 // If mask is a sequence of 1 bits followed by 0 bits,
 // return the number of 1 bits.
 func simpleMaskLength(mask IPMask) int {
diff --git a/src/net/ipsock.go b/src/net/ipsock.go
index f3ac00d..f093b49 100644
--- a/src/net/ipsock.go
+++ b/src/net/ipsock.go
@@ -6,10 +6,7 @@
 
 package net
 
-import (
-	"errors"
-	"time"
-)
+import "time"
 
 var (
 	// supportsIPv4 reports whether the platform supports IPv4
@@ -73,8 +70,6 @@
 	return
 }
 
-var errNoSuitableAddress = errors.New("no suitable address found")
-
 // filterAddrList applies a filter to a list of IP addresses,
 // yielding a list of Addr objects. Known filters are nil, ipv4only,
 // and ipv6only. It returns every address when the filter is nil.
diff --git a/src/net/net.go b/src/net/net.go
index 2ff1a34..3b37b33 100644
--- a/src/net/net.go
+++ b/src/net/net.go
@@ -364,6 +364,9 @@
 
 // Various errors contained in OpError.
 var (
+	// For connection setup operations.
+	errNoSuitableAddress = errors.New("no suitable address found")
+
 	// For connection setup and write operations.
 	errMissingAddress = errors.New("missing address")