blob: a9f596d03b648b9e9ce9b80955a3a0790bc5946e [file] [log] [blame]
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netutil
import (
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
)
const (
	// defaultMaxOpenFiles is an open-file budget of 256.
	// NOTE(review): presumably the fallback returned by maxOpenFiles when the
	// real limit cannot be determined; that helper is not defined in this
	// file, so confirm against its definition.
	defaultMaxOpenFiles = 256

	// timeout bounds the blocking steps in these tests: dialing the listener
	// and waiting for a goroutine that is expected to finish.
	timeout = 5 * time.Second
)
// TestLimitListener serves HTTP on a listener wrapped by LimitListener, hits
// it with many concurrent GET requests, and reports an error if the handler
// ever counts more requests in flight than the configured limit.
func TestLimitListener(t *testing.T) {
	const limit = 5

	// Derive the number of concurrent requests from maxOpenFiles, but never
	// launch more than 256 of them.
	// NOTE(review): the remark that used to sit here said the accept queue is
	// 128 entries long by default, which does not match the 256 cap — confirm
	// which of the two numbers is intended.
	attempts := (maxOpenFiles() - limit) / 2
	if attempts > 256 {
		attempts = 256
	}

	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer l.Close()
	l = LimitListener(l, limit)

	// inFlight is the number of handler invocations currently running.
	var inFlight int32
	countingHandler := func(w http.ResponseWriter, _ *http.Request) {
		cur := atomic.AddInt32(&inFlight, 1)
		if cur > limit {
			t.Errorf("%d open connections, want <= %d", cur, limit)
		}
		defer atomic.AddInt32(&inFlight, -1)
		// Keep the request open for a moment before answering.
		time.Sleep(10 * time.Millisecond)
		fmt.Fprint(w, "some body")
	}
	go http.Serve(l, http.HandlerFunc(countingHandler))

	var (
		wg     sync.WaitGroup
		failed int32 // number of GETs that returned an error; updated atomically
	)
	target := "http://" + l.Addr().String()

	// fetch performs one GET against the server, drains the response body,
	// and records a failure if the request itself errors out.
	fetch := func() {
		defer wg.Done()
		client := http.Client{Timeout: 3 * time.Second}
		resp, err := client.Get(target)
		if err != nil {
			t.Log(err)
			atomic.AddInt32(&failed, 1)
			return
		}
		defer resp.Body.Close()
		io.Copy(ioutil.Discard, resp.Body)
	}

	for remaining := attempts; remaining > 0; remaining-- {
		wg.Add(1)
		go fetch()
	}
	wg.Wait()

	// Some GETs are expected to fail once the kernel's accept queue fills up,
	// but fewer than half of them may do so.
	if int(failed) >= attempts/2 {
		t.Errorf("%d requests failed within %d attempts", failed, attempts)
	}
}
// errFake is the sentinel error that errorListener.Accept always returns.
var errFake = errors.New("fake error from errorListener")

// errorListener is a net.Listener whose Accept never succeeds. The embedded
// Listener only supplies the remaining interface methods; the test in this
// file constructs it as errorListener{} with that field nil, so only Accept
// is safe to call on such a value.
type errorListener struct {
	net.Listener
}

// Compile-time check that errorListener satisfies net.Listener.
var _ net.Listener = errorListener{}

// Accept always fails, returning a nil connection and errFake.
func (errorListener) Accept() (net.Conn, error) {
	return nil, errFake
}
// TestLimitListenerError calls Accept on a LimitListener one more time than
// its limit while the wrapped listener fails every Accept with errFake, and
// checks that each call returns errFake within the timeout.
//
// This used to hang.
func TestLimitListenerError(t *testing.T) {
	const limit = 2

	// result carries the first unexpected Accept error, if any; it is closed
	// when the goroutine is done, so a nil receive means success.
	result := make(chan error, 1)
	go func() {
		defer close(result)
		ll := LimitListener(errorListener{}, limit)
		// limit+1 calls in total: one more than the listener's limit.
		for call := 0; call <= limit; call++ {
			if _, err := ll.Accept(); err != errFake {
				result <- fmt.Errorf("Accept error = %v; want errFake", err)
				return
			}
		}
	}()

	watchdog := time.After(timeout)
	select {
	case <-watchdog:
		t.Fatal("timeout. deadlock?")
	case err := <-result:
		if err != nil {
			t.Fatalf("server: %v", err)
		}
	}
}
// TestLimitListenerClose checks that closing a LimitListener makes a pending
// Accept return (with an error) instead of blocking forever. It accepts one
// connection on a listener limited to 1, keeps that connection open, starts a
// second Accept, closes the listener, and requires the second Accept to
// finish within the timeout.
func TestLimitListenerClose(t *testing.T) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	ln = LimitListener(ln, 1)

	// errCh reports the outcome of the dial: the error on failure, or a plain
	// close on success. It is buffered so the dialer can always deliver its
	// error and exit, even if this goroutine has already left through
	// t.Fatal below and will never receive; an unbuffered channel would
	// leave the dialer blocked on the send forever.
	errCh := make(chan error, 1)
	go func() {
		defer close(errCh)
		c, err := net.DialTimeout("tcp", ln.Addr().String(), timeout)
		if err != nil {
			errCh <- err
			return
		}
		c.Close()
	}()

	c, err := ln.Accept()
	if err != nil {
		t.Fatal(err)
	}
	// Keep the accepted connection open until the test returns.
	defer c.Close()

	err = <-errCh
	if err != nil {
		t.Fatalf("DialTimeout: %v", err)
	}

	// Start a second Accept. No further dial is made and the first connection
	// is still open, so this call is expected to block until the listener is
	// closed and then return an error.
	acceptDone := make(chan struct{})
	go func() {
		c, err := ln.Accept()
		if err == nil {
			c.Close()
			t.Errorf("Unexpected successful Accept()")
		}
		close(acceptDone)
	}()

	// Wait a tiny bit to ensure the Accept() is blocking.
	time.Sleep(10 * time.Millisecond)
	ln.Close()

	select {
	case <-acceptDone:
	case <-time.After(timeout):
		t.Fatalf("Accept() still blocking")
	}
}