internal/mcp: provide a customizable http.Client to client transports

Allow customizing the http.Client used for HTTP MCP client transports,
by adding client options structs.

Change-Id: I2297acb136f8d0f7fa70d58cd244a6a81cc89751
Reviewed-on: https://go-review.googlesource.com/c/tools/+/682756
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Auto-Submit: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/gopls/internal/cmd/mcp_test.go b/gopls/internal/cmd/mcp_test.go
index fa7e344..9197f3a 100644
--- a/gopls/internal/cmd/mcp_test.go
+++ b/gopls/internal/cmd/mcp_test.go
@@ -132,7 +132,7 @@
 		t.Logf("failed %d, trying again", i)
 		time.Sleep(50 * time.Millisecond << i) // retry with exponential backoff
 	}
-	serverConn, err := client.Connect(ctx, mcp.NewSSEClientTransport("http://"+addr))
+	serverConn, err := client.Connect(ctx, mcp.NewSSEClientTransport("http://"+addr, nil))
 	if err != nil {
 		// This shouldn't happen because we already waited for the http server to start listening.
 		t.Fatalf("connecting to server: %v", err)
diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go
index fa530ca..e12c865 100644
--- a/gopls/internal/test/marker/marker_test.go
+++ b/gopls/internal/test/marker/marker_test.go
@@ -1014,7 +1014,7 @@
 	var mcpSession *mcp.ClientSession
 	if enableMCP {
 		client := mcp.NewClient("test", "v1.0.0", nil)
-		mcpSession, err = client.Connect(ctx, mcp.NewSSEClientTransport(mcpServer.URL))
+		mcpSession, err = client.Connect(ctx, mcp.NewSSEClientTransport(mcpServer.URL, nil))
 		if err != nil {
 			t.Fatalf("fail to connect to mcp server: %v", err)
 		}
diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md
index b08c60c..f2b07e6 100644
--- a/internal/mcp/design/design.md
+++ b/internal/mcp/design/design.md
@@ -187,9 +187,17 @@
 ```go
 type SSEClientTransport struct { /* ... */ }
 
+// SSEClientTransportOptions provides options for the [NewSSEClientTransport]
+// constructor.
+type SSEClientTransportOptions struct {
+	// HTTPClient is the client to use for making HTTP requests. If nil,
+	// http.DefaultClient is used.
+	HTTPClient *http.Client
+}
+
 // NewSSEClientTransport returns a new client transport that connects to the
 // SSE server at the provided URL.
-func NewSSEClientTransport(url string) (*SSEClientTransport, error) {
+func NewSSEClientTransport(url string, opts *SSEClientTransportOptions) (*SSEClientTransport, error)
 
 // Connect connects through the client endpoint.
 func (*SSEClientTransport) Connect(ctx context.Context) (Connection, error)
@@ -217,7 +225,16 @@
 
 // The streamable client handles reconnection transparently to the user.
 type StreamableClientTransport struct { /* ... */ }
-func NewStreamableClientTransport(url string) *StreamableClientTransport {
+
+// StreamableClientTransportOptions provides options for the
+// [NewStreamableClientTransport] constructor.
+type StreamableClientTransportOptions struct {
+	// HTTPClient is the client to use for making HTTP requests. If nil,
+	// http.DefaultClient is used.
+	HTTPClient *http.Client
+}
+
+func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport
 func (*StreamableClientTransport) Connect(context.Context) (Connection, error)
 ```
 
diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go
index 87462b5..d2c8a0a 100644
--- a/internal/mcp/sse.go
+++ b/internal/mcp/sse.go
@@ -322,20 +322,33 @@
 // https://modelcontextprotocol.io/specification/2024-11-05/basic/transports
 type SSEClientTransport struct {
 	sseEndpoint *url.URL
+	opts        SSEClientTransportOptions
+}
+
+// SSEClientTransportOptions provides options for the [NewSSEClientTransport]
+// constructor.
+type SSEClientTransportOptions struct {
+	// HTTPClient is the client to use for making HTTP requests. If nil,
+	// http.DefaultClient is used.
+	HTTPClient *http.Client
 }
 
 // NewSSEClientTransport returns a new client transport that connects to the
 // SSE server at the provided URL.
 //
 // NewSSEClientTransport panics if the given URL is invalid.
-func NewSSEClientTransport(baseURL string) *SSEClientTransport {
+func NewSSEClientTransport(baseURL string, opts *SSEClientTransportOptions) *SSEClientTransport {
 	url, err := url.Parse(baseURL)
 	if err != nil {
 		panic(fmt.Sprintf("invalid base url: %v", err))
 	}
-	return &SSEClientTransport{
+	t := &SSEClientTransport{
 		sseEndpoint: url,
 	}
+	if opts != nil {
+		t.opts = *opts
+	}
+	return t
 }
 
 // Connect connects through the client endpoint.
@@ -344,8 +357,12 @@
 	if err != nil {
 		return nil, err
 	}
+	httpClient := c.opts.HTTPClient
+	if httpClient == nil {
+		httpClient = http.DefaultClient
+	}
 	req.Header.Set("Accept", "text/event-stream")
-	resp, err := http.DefaultClient.Do(req)
+	resp, err := httpClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go
index 0ad37c4..947b86d 100644
--- a/internal/mcp/sse_example_test.go
+++ b/internal/mcp/sse_example_test.go
@@ -33,7 +33,7 @@
 	defer httpServer.Close()
 
 	ctx := context.Background()
-	transport := mcp.NewSSEClientTransport(httpServer.URL)
+	transport := mcp.NewSSEClientTransport(httpServer.URL, nil)
 	client := mcp.NewClient("test", "v1.0.0", nil)
 	cs, err := client.Connect(ctx, transport)
 	if err != nil {
diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go
index 48e0d5c..7f67c88 100644
--- a/internal/mcp/sse_test.go
+++ b/internal/mcp/sse_test.go
@@ -34,7 +34,7 @@
 			httpServer := httptest.NewServer(sseHandler)
 			defer httpServer.Close()
 
-			clientTransport := NewSSEClientTransport(httpServer.URL)
+			clientTransport := NewSSEClientTransport(httpServer.URL, nil)
 
 			c := NewClient("testClient", "v1.0.0", nil)
 			cs, err := c.Connect(ctx, clientTransport)
diff --git a/internal/mcp/streamable.go b/internal/mcp/streamable.go
index c56102e..2fbce76 100644
--- a/internal/mcp/streamable.go
+++ b/internal/mcp/streamable.go
@@ -571,13 +571,26 @@
 //
 // TODO(rfindley): support retries and resumption tokens.
 type StreamableClientTransport struct {
-	url string
+	url  string
+	opts StreamableClientTransportOptions
+}
+
+// StreamableClientTransportOptions provides options for the
+// [NewStreamableClientTransport] constructor.
+type StreamableClientTransportOptions struct {
+	// HTTPClient is the client to use for making HTTP requests. If nil,
+	// http.DefaultClient is used.
+	HTTPClient *http.Client
 }
 
 // NewStreamableClientTransport returns a new client transport that connects to
 // the streamable HTTP server at the provided URL.
-func NewStreamableClientTransport(url string) *StreamableClientTransport {
-	return &StreamableClientTransport{url: url}
+func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport {
+	t := &StreamableClientTransport{url: url}
+	if opts != nil {
+		t.opts = *opts
+	}
+	return t
 }
 
 // Connect implements the [Transport] interface.
@@ -589,9 +602,13 @@
 // When closed, the connection issues a DELETE request to terminate the logical
 // session.
 func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) {
+	client := t.opts.HTTPClient
+	if client == nil {
+		client = http.DefaultClient
+	}
 	return &streamableClientConn{
 		url:      t.url,
-		client:   http.DefaultClient,
+		client:   client,
 		incoming: make(chan []byte, 100),
 		done:     make(chan struct{}),
 	}, nil
diff --git a/internal/mcp/streamable_test.go b/internal/mcp/streamable_test.go
index ec9e40f..276afbf 100644
--- a/internal/mcp/streamable_test.go
+++ b/internal/mcp/streamable_test.go
@@ -11,7 +11,9 @@
 	"fmt"
 	"io"
 	"net/http"
+	"net/http/cookiejar"
 	"net/http/httptest"
+	"net/url"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -32,13 +34,35 @@
 	server := NewServer("testServer", "v1.0.0", nil)
 	server.AddTools(NewServerTool("greet", "say hi", sayHi))
 
-	// 2. Start an httptest.Server with the StreamableHTTPHandler.
+	// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
+	// cookie-checking middleware.
 	handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
-	httpServer := httptest.NewServer(handler)
+	httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		cookie, err := r.Cookie("test-cookie")
+		if err != nil {
+			t.Errorf("missing cookie: %v", err)
+		} else if cookie.Value != "test-value" {
+			t.Errorf("got cookie %q, want %q", cookie.Value, "test-value")
+		}
+		handler.ServeHTTP(w, r)
+	}))
 	defer httpServer.Close()
 
 	// 3. Create a client and connect it to the server using our StreamableClientTransport.
-	transport := NewStreamableClientTransport(httpServer.URL)
+	// Check that all requests honor a custom client.
+	jar, err := cookiejar.New(nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	u, err := url.Parse(httpServer.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}})
+	httpClient := &http.Client{Jar: jar}
+	transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{
+		HTTPClient: httpClient,
+	})
 	client := NewClient("testClient", "v1.0.0", nil)
 	session, err := client.Connect(ctx, transport)
 	if err != nil {