blob: 1171bc3a14ef0139c0e1c1886b94310fdd6a833c [file] [log] [blame]
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
Tobias Klauser3ef80562023-04-28 11:00:06 +02005//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
6// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
Han-Wen Nienhuys7cbb17f2013-06-18 12:43:42 -04007
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -04008package test
9
10import (
11 "bytes"
Lars Lehtonen497ca9f2019-11-15 14:31:24 -080012 "fmt"
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040013 "io"
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040014 "math/rand"
15 "net"
16 "testing"
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +100017 "time"
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040018)
19
// closeWriter is the half-close capability testPortForward needs from the
// locally dialed connection: shutting down only the write direction so the
// peer sees EOF while replies can still be read.
type closeWriter interface{ CloseWrite() error }
23
24func testPortForward(t *testing.T, n, listenAddr string) {
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +100025 server := newServer(t)
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +100026 conn := server.Dial(clientConfig())
27 defer conn.Close()
28
Akihiro Suda3cddcd62017-03-24 09:29:21 +000029 sshListener, err := conn.Listen(n, listenAddr)
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +100030 if err != nil {
31 t.Fatal(err)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040032 }
33
Lars Lehtonen497ca9f2019-11-15 14:31:24 -080034 errCh := make(chan error, 1)
35
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040036 go func() {
Lars Lehtonen497ca9f2019-11-15 14:31:24 -080037 defer close(errCh)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040038 sshConn, err := sshListener.Accept()
39 if err != nil {
Lars Lehtonen497ca9f2019-11-15 14:31:24 -080040 errCh <- fmt.Errorf("listen.Accept failed: %v", err)
41 return
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040042 }
Lars Lehtonen497ca9f2019-11-15 14:31:24 -080043 defer sshConn.Close()
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040044
45 _, err = io.Copy(sshConn, sshConn)
46 if err != nil && err != io.EOF {
Lars Lehtonen497ca9f2019-11-15 14:31:24 -080047 errCh <- fmt.Errorf("ssh client copy: %v", err)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040048 }
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040049 }()
50
51 forwardedAddr := sshListener.Addr().String()
Akihiro Suda3cddcd62017-03-24 09:29:21 +000052 netConn, err := net.Dial(n, forwardedAddr)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040053 if err != nil {
Akihiro Suda3cddcd62017-03-24 09:29:21 +000054 t.Fatalf("net dial failed: %v", err)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040055 }
56
57 readChan := make(chan []byte)
58 go func() {
cui fliter35f42652022-09-16 09:30:45 +000059 data, _ := io.ReadAll(netConn)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040060 readChan <- data
61 }()
62
63 // Invent some data.
64 data := make([]byte, 100*1000)
65 for i := range data {
66 data[i] = byte(i % 255)
67 }
68
69 var sent []byte
70 for len(sent) < 1000*1000 {
71 // Send random sized chunks
72 m := rand.Intn(len(data))
Akihiro Suda3cddcd62017-03-24 09:29:21 +000073 n, err := netConn.Write(data[:m])
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040074 if err != nil {
75 break
76 }
77 sent = append(sent, data[:n]...)
78 }
Akihiro Suda3cddcd62017-03-24 09:29:21 +000079 if err := netConn.(closeWriter).CloseWrite(); err != nil {
80 t.Errorf("netConn.CloseWrite: %v", err)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040081 }
82
Lars Lehtonen497ca9f2019-11-15 14:31:24 -080083 // Check for errors on server goroutine
84 err = <-errCh
85 if err != nil {
86 t.Fatalf("server: %v", err)
87 }
88
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -040089 read := <-readChan
90
91 if len(sent) != len(read) {
92 t.Fatalf("got %d bytes, want %d", len(read), len(sent))
93 }
94 if bytes.Compare(sent, read) != 0 {
95 t.Fatalf("read back data does not match")
96 }
97
98 if err := sshListener.Close(); err != nil {
99 t.Fatalf("sshListener.Close: %v", err)
100 }
101
102 // Check that the forward disappeared.
Akihiro Suda3cddcd62017-03-24 09:29:21 +0000103 netConn, err = net.Dial(n, forwardedAddr)
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -0400104 if err == nil {
Akihiro Suda3cddcd62017-03-24 09:29:21 +0000105 netConn.Close()
Han-Wen Nienhuys0d8dc3c2013-06-11 22:10:15 -0400106 t.Errorf("still listening to %s after closing", forwardedAddr)
107 }
108}
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +1000109
Akihiro Suda3cddcd62017-03-24 09:29:21 +0000110func TestPortForwardTCP(t *testing.T) {
111 testPortForward(t, "tcp", "localhost:0")
112}
113
114func TestPortForwardUnix(t *testing.T) {
115 addr, cleanup := newTempSocket(t)
116 defer cleanup()
117 testPortForward(t, "unix", addr)
118}
119
120func testAcceptClose(t *testing.T, n, listenAddr string) {
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +1000121 server := newServer(t)
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +1000122 conn := server.Dial(clientConfig())
123
Akihiro Suda3cddcd62017-03-24 09:29:21 +0000124 sshListener, err := conn.Listen(n, listenAddr)
Han-Wen Nienhuys7f7cbbf2013-07-22 21:50:13 +1000125 if err != nil {
126 t.Fatal(err)
127 }
128
129 quit := make(chan error, 1)
130 go func() {
131 for {
132 c, err := sshListener.Accept()
133 if err != nil {
134 quit <- err
135 break
136 }
137 c.Close()
138 }
139 }()
140 sshListener.Close()
141
142 select {
143 case <-time.After(1 * time.Second):
144 t.Errorf("timeout: listener did not close.")
145 case err := <-quit:
146 t.Logf("quit as expected (error %v)", err)
147 }
148}
149
Akihiro Suda3cddcd62017-03-24 09:29:21 +0000150func TestAcceptCloseTCP(t *testing.T) {
151 testAcceptClose(t, "tcp", "localhost:0")
152}
153
154func TestAcceptCloseUnix(t *testing.T) {
155 addr, cleanup := newTempSocket(t)
156 defer cleanup()
157 testAcceptClose(t, "unix", addr)
158}
159
Han-Wen Nienhuysa93ee0c2013-08-28 12:41:55 -0400160// Check that listeners exit if the underlying client transport dies.
Akihiro Suda3cddcd62017-03-24 09:29:21 +0000161func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
Han-Wen Nienhuysa93ee0c2013-08-28 12:41:55 -0400162 server := newServer(t)
Bryan C. Mills0ff60052023-05-22 10:51:18 -0400163 client := server.Dial(clientConfig())
Han-Wen Nienhuysa93ee0c2013-08-28 12:41:55 -0400164
Bryan C. Mills0ff60052023-05-22 10:51:18 -0400165 sshListener, err := client.Listen(n, listenAddr)
Han-Wen Nienhuysa93ee0c2013-08-28 12:41:55 -0400166 if err != nil {
167 t.Fatal(err)
168 }
169
170 quit := make(chan error, 1)
171 go func() {
172 for {
173 c, err := sshListener.Accept()
174 if err != nil {
175 quit <- err
176 break
177 }
178 c.Close()
179 }
180 }()
181
182 // It would be even nicer if we closed the server side, but it
183 // is more involved as the fd for that side is dup()ed.
Bryan C. Mills0ff60052023-05-22 10:51:18 -0400184 server.lastDialConn.Close()
Han-Wen Nienhuysa93ee0c2013-08-28 12:41:55 -0400185
Bryan C. Mills0ff60052023-05-22 10:51:18 -0400186 err = <-quit
187 t.Logf("quit as expected (error %v)", err)
Han-Wen Nienhuysa93ee0c2013-08-28 12:41:55 -0400188}
Akihiro Suda3cddcd62017-03-24 09:29:21 +0000189
190func TestPortForwardConnectionCloseTCP(t *testing.T) {
191 testPortForwardConnectionClose(t, "tcp", "localhost:0")
192}
193
194func TestPortForwardConnectionCloseUnix(t *testing.T) {
195 addr, cleanup := newTempSocket(t)
196 defer cleanup()
197 testPortForwardConnectionClose(t, "unix", addr)
198}