blob: 793a91d394b6210d4f4f42e24b2ccffbd718f94c [file] [log] [blame]
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netutil
import (
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestLimitListenerOverload(t *testing.T) {
const (
max = 5
attempts = max * 2
msg = "bye\n"
)
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
l = LimitListener(l, max)
var wg sync.WaitGroup
wg.Add(1)
saturated := make(chan struct{})
go func() {
defer wg.Done()
accepted := 0
for {
c, err := l.Accept()
if err != nil {
break
}
accepted++
if accepted == max {
close(saturated)
}
io.WriteString(c, msg)
// Leave c open until the listener is closed.
defer c.Close()
}
t.Logf("with limit %d, accepted %d simultaneous connections", max, accepted)
// The listener accounts open connections based on Listener-side Close
// calls, so even if the client hangs up early (for example, because it
// was a random dial from another process instead of from this test), we
// should not end up accepting more connections than expected.
if accepted != max {
t.Errorf("want exactly %d", max)
}
}()
dialCtx, cancelDial := context.WithCancel(context.Background())
defer cancelDial()
dialer := &net.Dialer{}
var dialed, served int32
var pendingDials sync.WaitGroup
for n := attempts; n > 0; n-- {
wg.Add(1)
pendingDials.Add(1)
go func() {
defer wg.Done()
c, err := dialer.DialContext(dialCtx, l.Addr().Network(), l.Addr().String())
pendingDials.Done()
if err != nil {
t.Log(err)
return
}
atomic.AddInt32(&dialed, 1)
defer c.Close()
// The kernel may queue more than max connections (allowing their dials to
// succeed), but only max of them should actually be accepted by the
// server. We can distinguish the two based on whether the listener writes
// anything to the connection — a connection that was queued but not
// accepted will be closed without transferring any data.
if b, err := io.ReadAll(c); len(b) < len(msg) {
t.Log(err)
return
}
atomic.AddInt32(&served, 1)
}()
}
// Give the server a bit of time after it saturates to make sure it doesn't
// exceed its limit after serving this connection, then cancel the remaining
// dials (if any).
<-saturated
time.Sleep(10 * time.Millisecond)
cancelDial()
// Wait for the dials to complete to ensure that the port isn't reused before
// the dials are actually attempted.
pendingDials.Wait()
l.Close()
wg.Wait()
t.Logf("served %d simultaneous connections (of %d dialed, %d attempted)", served, dialed, attempts)
// If some other process (such as a port scan or another test) happens to dial
// the listener at the same time, the listener could end up burning its quota
// on that, resulting in fewer than max test connections being served.
// But the number served certainly cannot be greater.
if served > max {
t.Errorf("expected at most %d served", max)
}
}
func TestLimitListenerSaturation(t *testing.T) {
const (
max = 5
attemptsPerWave = max * 2
waves = 10
msg = "bye\n"
)
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
l = LimitListener(l, max)
acceptDone := make(chan struct{})
defer func() {
l.Close()
<-acceptDone
}()
go func() {
defer close(acceptDone)
var open, peakOpen int32
var (
saturated = make(chan struct{})
saturatedOnce sync.Once
)
var wg sync.WaitGroup
for {
c, err := l.Accept()
if err != nil {
break
}
if n := atomic.AddInt32(&open, 1); n > peakOpen {
peakOpen = n
if n == max {
saturatedOnce.Do(func() {
// Wait a bit to make sure the listener doesn't exceed its limit
// after accepting this connection, then allow the in-flight
// connections to write out and close.
time.AfterFunc(10*time.Millisecond, func() { close(saturated) })
})
}
}
wg.Add(1)
go func() {
<-saturated
io.WriteString(c, msg)
atomic.AddInt32(&open, -1)
c.Close()
wg.Done()
}()
}
wg.Wait()
t.Logf("with limit %d, accepted a peak of %d simultaneous connections", max, peakOpen)
if peakOpen > max {
t.Errorf("want at most %d", max)
}
}()
for wave := 0; wave < waves; wave++ {
var dialed, served int32
var wg sync.WaitGroup
for n := attemptsPerWave; n > 0; n-- {
wg.Add(1)
go func() {
defer wg.Done()
c, err := net.Dial(l.Addr().Network(), l.Addr().String())
if err != nil {
t.Log(err)
return
}
atomic.AddInt32(&dialed, 1)
defer c.Close()
if b, err := io.ReadAll(c); len(b) < len(msg) {
t.Log(err)
return
}
atomic.AddInt32(&served, 1)
}()
}
wg.Wait()
t.Logf("served %d connections (of %d dialed, %d attempted)", served, dialed, attemptsPerWave)
// We expect that the kernel can queue at least attemptsPerWave
// connections at a time (since it's only a small number), so every
// connection should eventually be served.
if served != attemptsPerWave {
t.Errorf("expected %d served", attemptsPerWave)
}
}
}
type errorListener struct {
net.Listener
}
func (errorListener) Accept() (net.Conn, error) {
return nil, errFake
}
var errFake = errors.New("fake error from errorListener")
// This used to hang.
func TestLimitListenerError(t *testing.T) {
const n = 2
ll := LimitListener(errorListener{}, n)
for i := 0; i < n+1; i++ {
_, err := ll.Accept()
if err != errFake {
t.Fatalf("Accept error = %v; want errFake", err)
}
}
}
func TestLimitListenerClose(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
ln = LimitListener(ln, 1)
errCh := make(chan error)
go func() {
defer close(errCh)
c, err := net.Dial(ln.Addr().Network(), ln.Addr().String())
if err != nil {
errCh <- err
return
}
c.Close()
}()
c, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer c.Close()
err = <-errCh
if err != nil {
t.Fatalf("Dial: %v", err)
}
// Allow the subsequent Accept to block before closing the listener.
// (Accept should unblock and return.)
timer := time.AfterFunc(10*time.Millisecond, func() { ln.Close() })
c, err = ln.Accept()
if err == nil {
c.Close()
t.Errorf("Unexpected successful Accept()")
}
if timer.Stop() {
t.Errorf("Accept returned before listener closed: %v", err)
}
}