blob: 07ffd63386305a8317d08cab7424c72ca9fd86f2 [file] [log] [blame]
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"fmt"
"io/ioutil"
"os"
"sync"
"time"
)
// testUnixAddr returns a file name that can be used as the address of
// a UNIX domain socket. The name is obtained by letting
// ioutil.TempFile create a uniquely named file in the default
// temporary directory; the file itself is closed and removed again
// before the name is returned. It panics if the temporary file cannot
// be created.
func testUnixAddr() string {
	tmp, err := ioutil.TempFile("", "nettest")
	if err != nil {
		panic(err)
	}
	name := tmp.Name()
	// Deferred calls run last-in first-out, so the file is closed
	// before it is removed.
	defer os.Remove(name)
	defer tmp.Close()
	return name
}
// newLocalListener opens a listener for local testing on the given
// network.
//
// For "unix" and "unixpacket" it listens on a fresh name obtained from
// testUnixAddr. For "tcp", "tcp4" and "tcp6" it listens on an
// ephemeral port of the IPv4 loopback address when supportsIPv4 is
// set, otherwise of the IPv6 loopback address when supportsIPv6 is
// set. In every other case it returns an error saying that the
// network is not supported.
func newLocalListener(network string) (Listener, error) {
	switch network {
	case "unix", "unixpacket":
		return Listen(network, testUnixAddr())
	case "tcp", "tcp4", "tcp6":
		switch {
		case supportsIPv4:
			return Listen("tcp4", "127.0.0.1:0")
		case supportsIPv6:
			return Listen("tcp6", "[::1]:0")
		}
	}
	return nil, fmt.Errorf("%s is not supported", network)
}
// localServer couples a Listener with the state needed to run a
// handler goroutine against it (buildup) and to shut it down again
// (teardown).
type localServer struct {
// lnmu is held by teardown while it closes and clears Listener.
lnmu sync.RWMutex
// Listener is the embedded listener passed to the handler; teardown
// sets it to nil after closing it.
Listener
// done is closed by the goroutine started in buildup once the handler
// has returned; teardown blocks on it after closing Listener.
done chan bool // signal that indicates server stopped
}
// buildup starts handler in a new goroutine, passing it the server and
// its listener. The done channel is closed once handler returns.
// buildup itself never blocks and always returns nil.
func (ls *localServer) buildup(handler func(*localServer, Listener)) error {
	serve := func() {
		handler(ls, ls.Listener)
		close(ls.done)
	}
	go serve()
	return nil
}
// teardown shuts the server down while holding lnmu. If a listener is
// still set it is closed, teardown waits for the done channel, and the
// listener field is cleared. For "unix" and "unixpacket" listeners the
// socket file is removed afterwards. It always returns nil.
func (ls *localServer) teardown() error {
	ls.lnmu.Lock()
	defer ls.lnmu.Unlock()

	if ls.Listener == nil {
		return nil
	}

	// Capture the address before closing the listener.
	laddr := ls.Listener.Addr()
	network, address := laddr.Network(), laddr.String()

	ls.Listener.Close()
	<-ls.done
	ls.Listener = nil

	if network == "unix" || network == "unixpacket" {
		os.Remove(address)
	}
	return nil
}
// newLocalServer opens a local listener on network via
// newLocalListener and wraps it in a localServer with a fresh done
// channel. The error from newLocalListener is returned unchanged.
func newLocalServer(network string) (*localServer, error) {
	ln, err := newLocalListener(network)
	if err != nil {
		return nil, err
	}
	ls := &localServer{
		Listener: ln,
		done:     make(chan bool),
	}
	return ls, nil
}
// streamListener describes one listener of a dualStackServer.
type streamListener struct {
// network and address are passed to Listen by newDualStackServer; the
// address is joined with the server's port first.
network, address string
// Listener is the embedded listener; newDualStackServer fills it in.
Listener
// done is created by newDualStackServer and closed by the goroutine
// started in dualStackServer.buildup once its handler has returned.
done chan bool // signal that indicates server stopped
}
// newLocalServer wraps the listener held by sl in a new localServer
// that has its own done channel. The returned error is always nil.
func (sl *streamListener) newLocalServer() (*localServer, error) {
	ls := new(localServer)
	ls.Listener = sl.Listener
	ls.done = make(chan bool)
	return ls, nil
}
// dualStackServer manages a set of stream listeners together with the
// connections registered through putConn.
type dualStackServer struct {
// lnmu is held by teardownNetwork and teardown while they close the
// listeners in lns.
lnmu sync.RWMutex
// lns holds one entry per listener managed by the server.
lns []streamListener
// port starts out as "0" in newDualStackServer and is replaced there
// by the port of the first listener that was opened.
port string
// cmu is held by putConn and teardown while they access cs.
cmu sync.RWMutex
cs []Conn // established connections at the passive open side
}
// buildup starts one goroutine per entry in lns. Each goroutine runs
// handler with the server and that entry's listener and closes the
// entry's done channel once handler returns. buildup does not wait
// for the goroutines and always returns nil.
func (dss *dualStackServer) buildup(handler func(*dualStackServer, Listener)) error {
	serve := func(idx int) {
		handler(dss, dss.lns[idx].Listener)
		close(dss.lns[idx].done)
	}
	for idx := range dss.lns {
		go serve(idx)
	}
	return nil
}
// putConn appends c to the list of connections kept by the server,
// holding cmu while doing so. It always returns nil.
func (dss *dualStackServer) putConn(c Conn) error {
	dss.cmu.Lock()
	defer dss.cmu.Unlock()
	dss.cs = append(dss.cs, c)
	return nil
}
// teardownNetwork closes every listener whose network equals network
// and that is still set. After closing a listener it waits for that
// entry's done channel and clears the listener field. lnmu is held
// for the whole operation. It always returns nil.
func (dss *dualStackServer) teardownNetwork(network string) error {
	dss.lnmu.Lock()
	defer dss.lnmu.Unlock()

	for i := range dss.lns {
		sl := &dss.lns[i]
		if network != sl.network || sl.Listener == nil {
			continue
		}
		sl.Listener.Close()
		<-sl.done
		sl.Listener = nil
	}
	return nil
}
// teardown shuts the whole server down in two steps. Under lnmu it
// closes every listener that is still set, waits for the matching
// done channel, and then truncates lns. Under cmu it closes every
// registered connection and truncates cs. It always returns nil.
func (dss *dualStackServer) teardown() error {
	// Step 1: listeners.
	func() {
		dss.lnmu.Lock()
		defer dss.lnmu.Unlock()
		for i := range dss.lns {
			sl := &dss.lns[i]
			if sl.Listener == nil {
				continue
			}
			sl.Listener.Close()
			<-sl.done
		}
		dss.lns = dss.lns[:0]
	}()

	// Step 2: connections.
	func() {
		dss.cmu.Lock()
		defer dss.cmu.Unlock()
		for i := range dss.cs {
			dss.cs[i].Close()
		}
		dss.cs = dss.cs[:0]
	}()

	return nil
}
// newDualStackServer opens a listener for every entry in lns and
// returns a dualStackServer that manages them.
//
// The server's port starts out as "0". Once the first listener has
// been opened, the port taken from its address replaces "0" and is
// used for all remaining listeners.
//
// If opening a listener or splitting its address fails, the listeners
// opened so far by this call are closed and the error is returned.
// Entries that have not been opened yet still carry a nil Listener and
// must not be closed, so the cleanup loops are limited to the entries
// that were actually populated.
func newDualStackServer(lns []streamListener) (*dualStackServer, error) {
	dss := &dualStackServer{lns: lns, port: "0"}
	for i := range dss.lns {
		ln, err := Listen(dss.lns[i].network, JoinHostPort(dss.lns[i].address, dss.port))
		if err != nil {
			// Entry i failed to open; only entries before it hold a listener.
			for _, ln := range dss.lns[:i] {
				ln.Listener.Close()
			}
			return nil, err
		}
		dss.lns[i].Listener = ln
		dss.lns[i].done = make(chan bool)
		if dss.port == "0" {
			if _, dss.port, err = SplitHostPort(ln.Addr().String()); err != nil {
				// Entry i was opened just above, so it is included in the cleanup.
				for _, ln := range dss.lns[:i+1] {
					ln.Listener.Close()
				}
				return nil, err
			}
		}
	}
	return dss, nil
}
// transponder accepts a single connection on ln, reads one chunk of up
// to 256 bytes from it and writes the same bytes back. Failures are
// reported on ch, and ch is closed when transponder returns.
//
// When accepting, reading or writing fails, a non-nil result of the
// matching parseAcceptError, parseReadError or parseWriteError call is
// sent on ch first, followed by the original error.
func transponder(ln Listener, ch chan<- error) {
	defer close(ch)

	// Set a deadline on the listener types that offer SetDeadline.
	switch ln := ln.(type) {
	case *TCPListener:
		ln.SetDeadline(time.Now().Add(someTimeout))
	case *UnixListener:
		ln.SetDeadline(time.Now().Add(someTimeout))
	}
	c, err := ln.Accept()
	if err != nil {
		if perr := parseAcceptError(err); perr != nil {
			ch <- perr
		}
		ch <- err
		return
	}
	defer c.Close()

	// Both the local and the remote address of the accepted connection
	// must report the listener's network.
	network := ln.Addr().Network()
	if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
		ch <- fmt.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
		return
	}

	c.SetDeadline(time.Now().Add(someTimeout))
	c.SetReadDeadline(time.Now().Add(someTimeout))
	c.SetWriteDeadline(time.Now().Add(someTimeout))

	b := make([]byte, 256)
	n, err := c.Read(b)
	if err != nil {
		if perr := parseReadError(err); perr != nil {
			ch <- perr
		}
		ch <- err
		return
	}
	// Echo exactly the bytes that were read.
	if _, err := c.Write(b[:n]); err != nil {
		if perr := parseWriteError(err); perr != nil {
			ch <- perr
		}
		ch <- err
		return
	}
}
// transceiver writes wb to c and then performs a single read into a
// buffer of the same length. Failures are reported on ch, and ch is
// closed when transceiver returns.
//
// When a write or read fails, a non-nil result of parseWriteError or
// parseReadError is sent on ch first, followed by the original error.
// A byte count that differs from len(wb) is reported as well; a short
// write does not stop the subsequent read.
func transceiver(c Conn, wb []byte, ch chan<- error) {
	defer close(ch)

	c.SetDeadline(time.Now().Add(someTimeout))
	c.SetReadDeadline(time.Now().Add(someTimeout))
	c.SetWriteDeadline(time.Now().Add(someTimeout))

	want := len(wb)

	// Send the payload.
	sent, werr := c.Write(wb)
	if werr != nil {
		if perr := parseWriteError(werr); perr != nil {
			ch <- perr
		}
		ch <- werr
		return
	}
	if sent != want {
		ch <- fmt.Errorf("wrote %d; want %d", sent, want)
	}

	// Receive the reply.
	rb := make([]byte, want)
	got, rerr := c.Read(rb)
	if rerr != nil {
		if perr := parseReadError(rerr); perr != nil {
			ch <- perr
		}
		ch <- rerr
		return
	}
	if got != want {
		ch <- fmt.Errorf("read %d; want %d", got, want)
	}
}
// newLocalPacketListener opens a packet connection for local testing
// on the given network.
//
// For "unixgram" it listens on a fresh name obtained from
// testUnixAddr. For "udp", "udp4" and "udp6" it listens on an
// ephemeral port of the IPv4 loopback address when supportsIPv4 is
// set, otherwise of the IPv6 loopback address when supportsIPv6 is
// set. In every other case it returns an error saying that the
// network is not supported.
func newLocalPacketListener(network string) (PacketConn, error) {
	switch network {
	case "unixgram":
		return ListenPacket(network, testUnixAddr())
	case "udp", "udp4", "udp6":
		switch {
		case supportsIPv4:
			return ListenPacket("udp4", "127.0.0.1:0")
		case supportsIPv6:
			return ListenPacket("udp6", "[::1]:0")
		}
	}
	return nil, fmt.Errorf("%s is not supported", network)
}
// localPacketServer couples a PacketConn with the state needed to run
// a handler goroutine against it (buildup) and to shut it down again
// (teardown).
type localPacketServer struct {
// pcmu is held by teardown while it closes and clears PacketConn.
pcmu sync.RWMutex
// PacketConn is the embedded connection passed to the handler;
// teardown sets it to nil after closing it.
PacketConn
// done is closed by the goroutine started in buildup once the handler
// has returned; teardown blocks on it after closing PacketConn.
done chan bool // signal that indicates server stopped
}
// buildup starts handler in a new goroutine, passing it the server and
// its packet connection. The done channel is closed once handler
// returns. buildup itself never blocks and always returns nil.
func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
	serve := func() {
		handler(ls, ls.PacketConn)
		close(ls.done)
	}
	go serve()
	return nil
}
// teardown shuts the server down while holding pcmu. If a packet
// connection is still set it is closed, teardown waits for the done
// channel, and the connection field is cleared. For "unixgram"
// connections the socket file is removed afterwards. It always
// returns nil.
func (ls *localPacketServer) teardown() error {
	ls.pcmu.Lock()
	defer ls.pcmu.Unlock()

	if ls.PacketConn == nil {
		return nil
	}

	// Capture the address before closing the connection.
	laddr := ls.PacketConn.LocalAddr()
	network, address := laddr.Network(), laddr.String()

	ls.PacketConn.Close()
	<-ls.done
	ls.PacketConn = nil

	if network == "unixgram" {
		os.Remove(address)
	}
	return nil
}
// newLocalPacketServer opens a local packet connection on network via
// newLocalPacketListener and wraps it in a localPacketServer with a
// fresh done channel. The error from newLocalPacketListener is
// returned unchanged.
func newLocalPacketServer(network string) (*localPacketServer, error) {
	pc, err := newLocalPacketListener(network)
	if err != nil {
		return nil, err
	}
	ls := &localPacketServer{
		PacketConn: pc,
		done:       make(chan bool),
	}
	return ls, nil
}
// packetListener wraps a PacketConn so that a localPacketServer can be
// created from it with newLocalServer.
type packetListener struct {
PacketConn
}
// newLocalServer wraps the packet connection held by pl in a new
// localPacketServer that has its own done channel. The returned error
// is always nil.
func (pl *packetListener) newLocalServer() (*localPacketServer, error) {
	ls := new(localPacketServer)
	ls.PacketConn = pl.PacketConn
	ls.done = make(chan bool)
	return ls, nil
}
// packetTransponder reads one packet of up to 256 bytes from c and
// writes the same bytes back to the sender. Failures are reported on
// ch, and ch is closed when packetTransponder returns.
//
// If ReadFrom reports no peer address, the packet payload itself is
// resolved as the peer address for "udp" and "unixgram" connections.
// When a read or write fails, a non-nil result of parseReadError or
// parseWriteError is sent on ch first, followed by the original error.
func packetTransponder(c PacketConn, ch chan<- error) {
	defer close(ch)

	c.SetDeadline(time.Now().Add(someTimeout))
	c.SetReadDeadline(time.Now().Add(someTimeout))
	c.SetWriteDeadline(time.Now().Add(someTimeout))

	buf := make([]byte, 256)
	n, peer, err := c.ReadFrom(buf)
	if err != nil {
		if perr := parseReadError(err); perr != nil {
			ch <- perr
		}
		ch <- err
		return
	}
	data := buf[:n]

	if peer == nil { // for connected-mode sockets
		text := string(data)
		if network := c.LocalAddr().Network(); network == "udp" {
			peer, err = ResolveUDPAddr("udp", text)
		} else if network == "unixgram" {
			peer, err = ResolveUnixAddr("unixgram", text)
		}
		if err != nil {
			ch <- err
			return
		}
	}

	if _, err := c.WriteTo(data, peer); err != nil {
		if perr := parseWriteError(err); perr != nil {
			ch <- perr
		}
		ch <- err
		return
	}
}
// packetTransceiver writes wb to dst through c and then performs a
// single ReadFrom into a buffer of the same length. Failures are
// reported on ch, and ch is closed when packetTransceiver returns.
//
// When a write or read fails, a non-nil result of parseWriteError or
// parseReadError is sent on ch first, followed by the original error.
// A byte count that differs from len(wb) is reported as well; a short
// write does not stop the subsequent read.
func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
	defer close(ch)

	c.SetDeadline(time.Now().Add(someTimeout))
	c.SetReadDeadline(time.Now().Add(someTimeout))
	c.SetWriteDeadline(time.Now().Add(someTimeout))

	want := len(wb)

	// Send the payload.
	sent, werr := c.WriteTo(wb, dst)
	if werr != nil {
		if perr := parseWriteError(werr); perr != nil {
			ch <- perr
		}
		ch <- werr
		return
	}
	if sent != want {
		ch <- fmt.Errorf("wrote %d; want %d", sent, want)
	}

	// Receive the reply; the sender address is not needed.
	rb := make([]byte, want)
	got, _, rerr := c.ReadFrom(rb)
	if rerr != nil {
		if perr := parseReadError(rerr); perr != nil {
			ch <- perr
		}
		ch <- rerr
		return
	}
	if got != want {
		ch <- fmt.Errorf("read %d; want %d", got, want)
	}
}