internal/mcp: provide a customizable http.Client to client transports Allow customizing the http.Client used for HTTP MCP client transports, by adding client options structs. Change-Id: I2297acb136f8d0f7fa70d58cd244a6a81cc89751 Reviewed-on: https://go-review.googlesource.com/c/tools/+/682756 Reviewed-by: Jonathan Amsterdam <jba@google.com> Auto-Submit: Robert Findley <rfindley@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/gopls/internal/cmd/mcp_test.go b/gopls/internal/cmd/mcp_test.go index fa7e344..9197f3a 100644 --- a/gopls/internal/cmd/mcp_test.go +++ b/gopls/internal/cmd/mcp_test.go
@@ -132,7 +132,7 @@ t.Logf("failed %d, trying again", i) time.Sleep(50 * time.Millisecond << i) // retry with exponential backoff } - serverConn, err := client.Connect(ctx, mcp.NewSSEClientTransport("http://"+addr)) + serverConn, err := client.Connect(ctx, mcp.NewSSEClientTransport("http://"+addr, nil)) if err != nil { // This shouldn't happen because we already waited for the http server to start listening. t.Fatalf("connecting to server: %v", err)
diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index fa530ca..e12c865 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go
@@ -1014,7 +1014,7 @@ var mcpSession *mcp.ClientSession if enableMCP { client := mcp.NewClient("test", "v1.0.0", nil) - mcpSession, err = client.Connect(ctx, mcp.NewSSEClientTransport(mcpServer.URL)) + mcpSession, err = client.Connect(ctx, mcp.NewSSEClientTransport(mcpServer.URL, nil)) if err != nil { t.Fatalf("fail to connect to mcp server: %v", err) }
diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index b08c60c..f2b07e6 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md
@@ -187,9 +187,17 @@ ```go type SSEClientTransport struct { /* ... */ } +// SSEClientTransportOptions provides options for the [NewSSEClientTransport] +// constructor. +type SSEClientTransportOptions struct { + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. + HTTPClient *http.Client +} + // NewSSEClientTransport returns a new client transport that connects to the // SSE server at the provided URL. -func NewSSEClientTransport(url string) (*SSEClientTransport, error) { +func NewSSEClientTransport(url string, opts *SSEClientTransportOptions) (*SSEClientTransport, error) // Connect connects through the client endpoint. func (*SSEClientTransport) Connect(ctx context.Context) (Connection, error) @@ -217,7 +225,16 @@ // The streamable client handles reconnection transparently to the user. type StreamableClientTransport struct { /* ... */ } -func NewStreamableClientTransport(url string) *StreamableClientTransport { + +// StreamableClientTransportOptions provides options for the +// [NewStreamableClientTransport] constructor. +type StreamableClientTransportOptions struct { + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. + HTTPClient *http.Client +} + +func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport func (*StreamableClientTransport) Connect(context.Context) (Connection, error) ```
diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go index 87462b5..d2c8a0a 100644 --- a/internal/mcp/sse.go +++ b/internal/mcp/sse.go
@@ -322,20 +322,33 @@ // https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEClientTransport struct { sseEndpoint *url.URL + opts SSEClientTransportOptions +} + +// SSEClientTransportOptions provides options for the [NewSSEClientTransport] +// constructor. +type SSEClientTransportOptions struct { + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. + HTTPClient *http.Client } // NewSSEClientTransport returns a new client transport that connects to the // SSE server at the provided URL. // // NewSSEClientTransport panics if the given URL is invalid. -func NewSSEClientTransport(baseURL string) *SSEClientTransport { +func NewSSEClientTransport(baseURL string, opts *SSEClientTransportOptions) *SSEClientTransport { url, err := url.Parse(baseURL) if err != nil { panic(fmt.Sprintf("invalid base url: %v", err)) } - return &SSEClientTransport{ + t := &SSEClientTransport{ sseEndpoint: url, } + if opts != nil { + t.opts = *opts + } + return t } // Connect connects through the client endpoint. @@ -344,8 +357,12 @@ if err != nil { return nil, err } + httpClient := c.opts.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } req.Header.Set("Accept", "text/event-stream") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) if err != nil { return nil, err }
diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index 0ad37c4..947b86d 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go
@@ -33,7 +33,7 @@ defer httpServer.Close() ctx := context.Background() - transport := mcp.NewSSEClientTransport(httpServer.URL) + transport := mcp.NewSSEClientTransport(httpServer.URL, nil) client := mcp.NewClient("test", "v1.0.0", nil) cs, err := client.Connect(ctx, transport) if err != nil {
diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 48e0d5c..7f67c88 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go
@@ -34,7 +34,7 @@ httpServer := httptest.NewServer(sseHandler) defer httpServer.Close() - clientTransport := NewSSEClientTransport(httpServer.URL) + clientTransport := NewSSEClientTransport(httpServer.URL, nil) c := NewClient("testClient", "v1.0.0", nil) cs, err := c.Connect(ctx, clientTransport)
diff --git a/internal/mcp/streamable.go b/internal/mcp/streamable.go index c56102e..2fbce76 100644 --- a/internal/mcp/streamable.go +++ b/internal/mcp/streamable.go
@@ -571,13 +571,26 @@ // // TODO(rfindley): support retries and resumption tokens. type StreamableClientTransport struct { - url string + url string + opts StreamableClientTransportOptions +} + +// StreamableClientTransportOptions provides options for the +// [NewStreamableClientTransport] constructor. +type StreamableClientTransportOptions struct { + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. + HTTPClient *http.Client } // NewStreamableClientTransport returns a new client transport that connects to // the streamable HTTP server at the provided URL. -func NewStreamableClientTransport(url string) *StreamableClientTransport { - return &StreamableClientTransport{url: url} +func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport { + t := &StreamableClientTransport{url: url} + if opts != nil { + t.opts = *opts + } + return t } // Connect implements the [Transport] interface. @@ -589,9 +602,13 @@ // When closed, the connection issues a DELETE request to terminate the logical // session. func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) { + client := t.opts.HTTPClient + if client == nil { + client = http.DefaultClient + } return &streamableClientConn{ url: t.url, - client: http.DefaultClient, + client: client, incoming: make(chan []byte, 100), done: make(chan struct{}), }, nil
diff --git a/internal/mcp/streamable_test.go b/internal/mcp/streamable_test.go index ec9e40f..276afbf 100644 --- a/internal/mcp/streamable_test.go +++ b/internal/mcp/streamable_test.go
@@ -11,7 +11,9 @@ "fmt" "io" "net/http" + "net/http/cookiejar" "net/http/httptest" + "net/url" "strings" "sync" "sync/atomic" @@ -32,13 +34,35 @@ server := NewServer("testServer", "v1.0.0", nil) server.AddTools(NewServerTool("greet", "say hi", sayHi)) - // 2. Start an httptest.Server with the StreamableHTTPHandler. + // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a + // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - httpServer := httptest.NewServer(handler) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("test-cookie") + if err != nil { + t.Errorf("missing cookie: %v", err) + } else if cookie.Value != "test-value" { + t.Errorf("got cookie %q, want %q", cookie.Value, "test-value") + } + handler.ServeHTTP(w, r) + })) defer httpServer.Close() // 3. Create a client and connect it to the server using our StreamableClientTransport. - transport := NewStreamableClientTransport(httpServer.URL) + // Check that all requests honor a custom client. + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } + u, err := url.Parse(httpServer.URL) + if err != nil { + t.Fatal(err) + } + jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}}) + httpClient := &http.Client{Jar: jar} + transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ + HTTPClient: httpClient, + }) client := NewClient("testClient", "v1.0.0", nil) session, err := client.Connect(ctx, transport) if err != nil {