|  | // Copyright 2023 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package net | 
|  |  | 
import (
	"bytes"
	"context"
	"errors"
	"io"
	"syscall"
	"testing"
)
|  |  | 
|  | func newLocalListenerMPTCP(t *testing.T) Listener { | 
|  | lc := &ListenConfig{} | 
|  | if lc.MultipathTCP() { | 
|  | t.Error("MultipathTCP should be off by default") | 
|  | } | 
|  |  | 
|  | lc.SetMultipathTCP(true) | 
|  | if !lc.MultipathTCP() { | 
|  | t.Fatal("MultipathTCP is not on after having been forced to on") | 
|  | } | 
|  |  | 
|  | ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | return ln | 
|  | } | 
|  |  | 
// postAcceptMPTCP verifies that the first connection recorded in ls.cl is a
// *TCPConn that uses Multipath TCP.
//
// At most one error is sent on ch: the first check that fails. ch is always
// closed on return, so a receiver observes a nil error when every check passed.
func postAcceptMPTCP(ls *localServer, ch chan<- error) {
	defer close(ch)

	check := func() error {
		if len(ls.cl) == 0 {
			return errors.New("no accepted stream")
		}

		conn, isTCP := ls.cl[0].(*TCPConn)
		if !isTCP {
			return errors.New("struct is not a TCPConn")
		}

		usesMPTCP, err := conn.MultipathTCP()
		switch {
		case err != nil:
			return err
		case !usesMPTCP:
			return errors.New("incoming connection is not with MPTCP")
		// On kernels where hasSOLMPTCP is set, additionally cross-check
		// with the protocol-based detection used on older kernels.
		case hasSOLMPTCP && !isUsingMPTCPProto(conn.fd):
			return errors.New("incoming connection is not an MPTCP proto")
		}
		return nil
	}

	if err := check(); err != nil {
		ch <- err
	}
}
|  |  | 
// dialerMPTCP dials addr over "tcp" using a Dialer with Multipath TCP
// explicitly enabled. It exchanges a short payload with the peer and checks
// that the resulting connection reports MPTCP usage.
//
// The peer at addr is expected to echo back exactly what it receives; the
// payload comparison below relies on that.
//
// It also checks that a zero Dialer reports MPTCP as off, and that
// SetMultipathTCP(true) takes effect.
func dialerMPTCP(t *testing.T, addr string) {
	d := &Dialer{}
	if d.MultipathTCP() {
		t.Error("MultipathTCP should be off by default")
	}

	d.SetMultipathTCP(true)
	if !d.MultipathTCP() {
		t.Fatal("MultipathTCP is not on after having been forced to on")
	}

	c, err := d.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	tcp, ok := c.(*TCPConn)
	if !ok {
		t.Fatal("struct is not a TCPConn")
	}

	// Transfer a bit of data to make sure everything is still OK.
	snt := []byte("MPTCP TEST")
	if _, err := c.Write(snt); err != nil {
		t.Fatal(err)
	}
	// A single Read on a stream connection may legitimately return fewer
	// bytes than were written, with a nil error. io.ReadFull keeps reading
	// until the whole echo has arrived, so the comparison below is not flaky.
	b := make([]byte, len(snt))
	if _, err := io.ReadFull(c, b); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(snt, b) {
		t.Errorf("sent bytes (%s) are different from received ones (%s)", snt, b)
	}

	mptcp, err := tcp.MultipathTCP()
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("outgoing connection from %s with mptcp: %t", addr, mptcp)

	if !mptcp {
		t.Error("outgoing connection is not with MPTCP")
	}

	// Where hasSOLMPTCP is set, also cross-check with the protocol-based
	// detection used for older kernels.
	if hasSOLMPTCP && !isUsingMPTCPProto(tcp.fd) {
		t.Error("outgoing connection is not an MPTCP proto")
	}
}
|  |  | 
// canCreateMPTCPSocket reports whether this process can actually open an
// IPv4 MPTCP stream socket right now.
//
// Knowing that MPTCP is available (mptcpAvailable) is not sufficient, because
// an administrator may still block the creation of such sockets. So this
// probes by opening one and immediately closing it again. The error from
// closing the probe socket is intentionally ignored.
func canCreateMPTCPSocket() bool {
	probe, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, _IPPROTO_MPTCP)
	if err == nil {
		syscall.Close(probe)
	}
	return err == nil
}
|  |  | 
|  | func TestMultiPathTCP(t *testing.T) { | 
|  | if !canCreateMPTCPSocket() { | 
|  | t.Skip("Cannot create MPTCP sockets") | 
|  | } | 
|  |  | 
|  | ln := newLocalListenerMPTCP(t) | 
|  |  | 
|  | // similar to tcpsock_test:TestIPv6LinkLocalUnicastTCP | 
|  | ls := (&streamListener{Listener: ln}).newLocalServer() | 
|  | defer ls.teardown() | 
|  |  | 
|  | if g, w := ls.Listener.Addr().Network(), "tcp"; g != w { | 
|  | t.Fatalf("Network type mismatch: got %q, want %q", g, w) | 
|  | } | 
|  |  | 
|  | genericCh := make(chan error) | 
|  | mptcpCh := make(chan error) | 
|  | handler := func(ls *localServer, ln Listener) { | 
|  | ls.transponder(ln, genericCh) | 
|  | postAcceptMPTCP(ls, mptcpCh) | 
|  | } | 
|  | if err := ls.buildup(handler); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | dialerMPTCP(t, ln.Addr().String()) | 
|  |  | 
|  | if err := <-genericCh; err != nil { | 
|  | t.Error(err) | 
|  | } | 
|  | if err := <-mptcpCh; err != nil { | 
|  | t.Error(err) | 
|  | } | 
|  | } |