// blob: ab8e599247e45458e604593ff19f4fd271b2eb06 [file] [log] [blame]
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netutil
import (
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestLimitListener checks that a listener wrapped by LimitListener does not
// serve more than max connections at the same time.
//
// A server goroutine accepts connections, writes msg to each one, and keeps
// every accepted connection open until its Accept loop ends. Meanwhile
// attempts (twice max) clients dial concurrently and try to read msg. The
// test passes only if exactly max clients were served.
func TestLimitListener(t *testing.T) {
	const (
		max      = 5
		attempts = max * 2
		msg      = "bye\n"
	)

	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	l = LimitListener(l, max)

	var wg sync.WaitGroup

	// Server side: accept until Accept fails (normally because the listener
	// has been closed), counting how many connections were handed out.
	wg.Add(1)
	go func() {
		defer wg.Done()

		accepted := 0
		for {
			c, err := l.Accept()
			if err != nil {
				break
			}
			accepted++
			if _, err := io.WriteString(c, msg); err != nil {
				// The client that owns this connection will not be counted as
				// served; record why to make a failure easier to diagnose.
				t.Logf("writing to accepted connection: %v", err)
			}
			defer c.Close() // Leave c open until the listener is closed.
		}
		// Nothing is closed inside the loop, so the running total is also the
		// number of simultaneously open server-side connections.
		if accepted > max {
			t.Errorf("accepted %d simultaneous connections; want at most %d", accepted, max)
		}
	}()

	// connc keeps the client end of the dialed connections alive until the
	// test completes. Its single buffered slot always holds the current slice,
	// so taking the slice out and putting it back serializes the appends.
	connc := make(chan []net.Conn, 1)
	connc <- nil

	dialCtx, cancelDial := context.WithCancel(context.Background())
	defer cancelDial()
	dialer := &net.Dialer{}

	var served int32
	for i := 0; i < attempts; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			c, err := dialer.DialContext(dialCtx, l.Addr().Network(), l.Addr().String())
			if err != nil {
				t.Log(err)
				return
			}
			defer c.Close()

			// Keep this end of the connection alive until after the Listener
			// finishes.
			conns := append(<-connc, c)
			if len(conns) == max {
				go func() {
					// Give the server a bit of time to make sure it doesn't exceed its
					// limit after serving this connection, then cancel the remaining
					// Dials (if any).
					time.Sleep(10 * time.Millisecond)
					cancelDial()
					l.Close()
				}()
			}
			connc <- conns

			// Read the whole message. A single Read on a stream connection may
			// return fewer bytes than requested without an error, so use
			// io.ReadFull, which reports a non-nil error for any short read.
			b := make([]byte, len(msg))
			if _, err := io.ReadFull(c, b); err != nil {
				t.Log(err)
				return
			}
			atomic.AddInt32(&served, 1)
		}()
	}
	wg.Wait()

	conns := <-connc
	for _, c := range conns {
		c.Close()
	}

	// Every goroutine that updates served has finished (wg.Wait above), so a
	// plain read is safe here.
	t.Logf("with limit %d, served %d connections (of %d dialed, %d attempted)", max, served, len(conns), attempts)
	if served != max {
		t.Errorf("expected exactly %d served", max)
	}
}
// errFake is the sentinel error returned by every errorListener.Accept call.
var errFake = errors.New("fake error from errorListener")

// errorListener is a listener stub whose Accept always fails with errFake.
// It embeds net.Listener only to satisfy the interface; the embedded value is
// not used by Accept.
type errorListener struct {
	net.Listener
}

// Accept never yields a connection: it returns a nil net.Conn and errFake.
func (errorListener) Accept() (net.Conn, error) {
	return nil, errFake
}
// TestLimitListenerError calls Accept on a LimitListener that wraps a
// listener whose Accept always fails, once more than the configured limit,
// and requires every call to return errFake.
//
// This used to hang.
func TestLimitListenerError(t *testing.T) {
	const limit = 2

	ll := LimitListener(errorListener{}, limit)

	// limit+1 calls in total: the last one goes past the limit.
	for calls := 0; calls <= limit; calls++ {
		if _, err := ll.Accept(); err != errFake {
			t.Fatalf("Accept error = %v; want errFake", err)
		}
	}
}
// TestLimitListenerClose checks that closing a LimitListener makes a pending
// Accept call return with an error.
//
// The listener is limited to one connection. The test accepts one connection
// and keeps it open, then calls Accept again while a timer closes the
// listener shortly afterwards. The second Accept must fail, and it must not
// return before the timer has fired.
func TestLimitListenerClose(t *testing.T) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	ln = LimitListener(ln, 1)

	// errCh carries the result of the single Dial below. It is buffered so the
	// dialing goroutine can always deliver its error and exit, even when the
	// test stops (for example via t.Fatal) before receiving from the channel.
	errCh := make(chan error, 1)
	go func() {
		defer close(errCh)
		c, err := net.Dial(ln.Addr().Network(), ln.Addr().String())
		if err != nil {
			errCh <- err
			return
		}
		c.Close()
	}()

	// Accept the one permitted connection and hold it open for the rest of
	// the test.
	c, err := ln.Accept()
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	// A receive on the closed, empty channel yields nil, which means the Dial
	// succeeded.
	err = <-errCh
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}

	// Allow the subsequent Accept to block before closing the listener.
	// (Accept should unblock and return.)
	timer := time.AfterFunc(10*time.Millisecond, func() {
		ln.Close()
	})

	c, err = ln.Accept()
	if err == nil {
		c.Close()
		t.Errorf("Unexpected successful Accept()")
	}
	// Stop reports true only if the timer had not fired yet, i.e. Accept
	// returned before the listener was closed.
	if timer.Stop() {
		t.Errorf("Accept returned before listener closed: %v", err)
	}
}