| // Copyright 2013 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package netutil |
| |
| import ( |
| "context" |
| "errors" |
| "io" |
| "net" |
| "sync" |
| "sync/atomic" |
| "testing" |
| "time" |
| ) |
| |
| func TestLimitListenerOverload(t *testing.T) { |
| const ( |
| max = 5 |
| attempts = max * 2 |
| msg = "bye\n" |
| ) |
| |
| l, err := net.Listen("tcp", "127.0.0.1:0") |
| if err != nil { |
| t.Fatal(err) |
| } |
| l = LimitListener(l, max) |
| |
| var wg sync.WaitGroup |
| wg.Add(1) |
| saturated := make(chan struct{}) |
| go func() { |
| defer wg.Done() |
| |
| accepted := 0 |
| for { |
| c, err := l.Accept() |
| if err != nil { |
| break |
| } |
| accepted++ |
| if accepted == max { |
| close(saturated) |
| } |
| io.WriteString(c, msg) |
| |
| // Leave c open until the listener is closed. |
| defer c.Close() |
| } |
| t.Logf("with limit %d, accepted %d simultaneous connections", max, accepted) |
| // The listener accounts open connections based on Listener-side Close |
| // calls, so even if the client hangs up early (for example, because it |
| // was a random dial from another process instead of from this test), we |
| // should not end up accepting more connections than expected. |
| if accepted != max { |
| t.Errorf("want exactly %d", max) |
| } |
| }() |
| |
| dialCtx, cancelDial := context.WithCancel(context.Background()) |
| defer cancelDial() |
| dialer := &net.Dialer{} |
| |
| var dialed, served int32 |
| var pendingDials sync.WaitGroup |
| for n := attempts; n > 0; n-- { |
| wg.Add(1) |
| pendingDials.Add(1) |
| go func() { |
| defer wg.Done() |
| |
| c, err := dialer.DialContext(dialCtx, l.Addr().Network(), l.Addr().String()) |
| pendingDials.Done() |
| if err != nil { |
| t.Log(err) |
| return |
| } |
| atomic.AddInt32(&dialed, 1) |
| defer c.Close() |
| |
| // The kernel may queue more than max connections (allowing their dials to |
| // succeed), but only max of them should actually be accepted by the |
| // server. We can distinguish the two based on whether the listener writes |
| // anything to the connection — a connection that was queued but not |
| // accepted will be closed without transferring any data. |
| if b, err := io.ReadAll(c); len(b) < len(msg) { |
| t.Log(err) |
| return |
| } |
| atomic.AddInt32(&served, 1) |
| }() |
| } |
| |
| // Give the server a bit of time after it saturates to make sure it doesn't |
| // exceed its limit after serving this connection, then cancel the remaining |
| // dials (if any). |
| <-saturated |
| time.Sleep(10 * time.Millisecond) |
| cancelDial() |
| // Wait for the dials to complete to ensure that the port isn't reused before |
| // the dials are actually attempted. |
| pendingDials.Wait() |
| l.Close() |
| wg.Wait() |
| |
| t.Logf("served %d simultaneous connections (of %d dialed, %d attempted)", served, dialed, attempts) |
| |
| // If some other process (such as a port scan or another test) happens to dial |
| // the listener at the same time, the listener could end up burning its quota |
| // on that, resulting in fewer than max test connections being served. |
| // But the number served certainly cannot be greater. |
| if served > max { |
| t.Errorf("expected at most %d served", max) |
| } |
| } |
| |
| func TestLimitListenerSaturation(t *testing.T) { |
| const ( |
| max = 5 |
| attemptsPerWave = max * 2 |
| waves = 10 |
| msg = "bye\n" |
| ) |
| |
| l, err := net.Listen("tcp", "127.0.0.1:0") |
| if err != nil { |
| t.Fatal(err) |
| } |
| l = LimitListener(l, max) |
| |
| acceptDone := make(chan struct{}) |
| defer func() { |
| l.Close() |
| <-acceptDone |
| }() |
| go func() { |
| defer close(acceptDone) |
| |
| var open, peakOpen int32 |
| var ( |
| saturated = make(chan struct{}) |
| saturatedOnce sync.Once |
| ) |
| var wg sync.WaitGroup |
| for { |
| c, err := l.Accept() |
| if err != nil { |
| break |
| } |
| if n := atomic.AddInt32(&open, 1); n > peakOpen { |
| peakOpen = n |
| if n == max { |
| saturatedOnce.Do(func() { |
| // Wait a bit to make sure the listener doesn't exceed its limit |
| // after accepting this connection, then allow the in-flight |
| // connections to write out and close. |
| time.AfterFunc(10*time.Millisecond, func() { close(saturated) }) |
| }) |
| } |
| } |
| wg.Add(1) |
| go func() { |
| <-saturated |
| io.WriteString(c, msg) |
| atomic.AddInt32(&open, -1) |
| c.Close() |
| wg.Done() |
| }() |
| } |
| wg.Wait() |
| |
| t.Logf("with limit %d, accepted a peak of %d simultaneous connections", max, peakOpen) |
| if peakOpen > max { |
| t.Errorf("want at most %d", max) |
| } |
| }() |
| |
| for wave := 0; wave < waves; wave++ { |
| var dialed, served int32 |
| var wg sync.WaitGroup |
| for n := attemptsPerWave; n > 0; n-- { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| |
| c, err := net.Dial(l.Addr().Network(), l.Addr().String()) |
| if err != nil { |
| t.Log(err) |
| return |
| } |
| atomic.AddInt32(&dialed, 1) |
| defer c.Close() |
| |
| if b, err := io.ReadAll(c); len(b) < len(msg) { |
| t.Log(err) |
| return |
| } |
| atomic.AddInt32(&served, 1) |
| }() |
| } |
| wg.Wait() |
| |
| t.Logf("served %d connections (of %d dialed, %d attempted)", served, dialed, attemptsPerWave) |
| |
| // Depending on the kernel's queueing behavior, we could get unlucky |
| // and drop one or more connections. However, we should certainly |
| // be able to serve at least max attempts out of each wave. |
| // (In the typical case, the kernel will queue all of the connections |
| // and they will all be served successfully.) |
| if dialed < max { |
| t.Errorf("expected at least %d dialed", max) |
| } |
| if served < dialed { |
| t.Errorf("expected all dialed connections to be served") |
| } |
| } |
| } |
| |
| type errorListener struct { |
| net.Listener |
| } |
| |
| func (errorListener) Accept() (net.Conn, error) { |
| return nil, errFake |
| } |
| |
| var errFake = errors.New("fake error from errorListener") |
| |
| // This used to hang. |
| func TestLimitListenerError(t *testing.T) { |
| const n = 2 |
| ll := LimitListener(errorListener{}, n) |
| for i := 0; i < n+1; i++ { |
| _, err := ll.Accept() |
| if err != errFake { |
| t.Fatalf("Accept error = %v; want errFake", err) |
| } |
| } |
| } |
| |
| func TestLimitListenerClose(t *testing.T) { |
| ln, err := net.Listen("tcp", "127.0.0.1:0") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer ln.Close() |
| ln = LimitListener(ln, 1) |
| |
| errCh := make(chan error) |
| go func() { |
| defer close(errCh) |
| c, err := net.Dial(ln.Addr().Network(), ln.Addr().String()) |
| if err != nil { |
| errCh <- err |
| return |
| } |
| c.Close() |
| }() |
| |
| c, err := ln.Accept() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer c.Close() |
| |
| err = <-errCh |
| if err != nil { |
| t.Fatalf("Dial: %v", err) |
| } |
| |
| // Allow the subsequent Accept to block before closing the listener. |
| // (Accept should unblock and return.) |
| timer := time.AfterFunc(10*time.Millisecond, func() { ln.Close() }) |
| |
| c, err = ln.Accept() |
| if err == nil { |
| c.Close() |
| t.Errorf("Unexpected successful Accept()") |
| } |
| if timer.Stop() { |
| t.Errorf("Accept returned before listener closed: %v", err) |
| } |
| } |