blob: e6fbcd2949ab2e71e05c9e7c875cb67b79d42450 [file] [log] [blame]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mcp
import (
"context"
"errors"
"fmt"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/tools/internal/mcp/internal/protocol"
"golang.org/x/tools/internal/mcp/jsonschema"
)
// hiParams holds the arguments accepted by the 'greet' test tool.
type hiParams struct{ Name string }
// sayHi is the handler for the 'greet' test tool. Before answering it pings
// the client back over cc, which exercises server-to-client calls made from
// inside a tool handler. It returns a single text item "hi <Name>".
func sayHi(ctx context.Context, cc *ClientConnection, v hiParams) ([]Content, error) {
	pingErr := cc.Ping(ctx)
	if pingErr != nil {
		return nil, fmt.Errorf("ping failed: %v", pingErr)
	}
	greeting := "hi " + v.Name
	return []Content{TextContent{Text: greeting}}, nil
}
// TestEndToEnd connects a client and a server over a local transport and
// exercises ping, prompts/list, prompts/get, tools/list and tools/call,
// including the failure paths. It then closes the client and checks that the
// server side of the connection shuts down cleanly and that the server no
// longer reports any connected clients.
func TestEndToEnd(t *testing.T) {
	ctx := context.Background()
	ct, st := NewLocalTransport()

	s := NewServer("testServer", "v1.0.0", nil)

	// The 'greet' tool says hi.
	s.AddTools(MakeTool("greet", "say hi", sayHi))

	// The 'fail' tool returns this error.
	failure := errors.New("mcp failure")
	s.AddTools(
		MakeTool("fail", "just fail", func(context.Context, *ClientConnection, struct{}) ([]Content, error) {
			return nil, failure
		}),
	)

	s.AddPrompts(
		MakePrompt("code_review", "do a code review", func(_ context.Context, _ *ClientConnection, params struct{ Code string }) (*protocol.GetPromptResult, error) {
			return &protocol.GetPromptResult{
				Description: "Code review prompt",
				Messages: []protocol.PromptMessage{
					{Role: "user", Content: TextContent{Text: "Please review the following code: " + params.Code}.ToWire()},
				},
			}, nil
		}),
		MakePrompt("fail", "", func(_ context.Context, _ *ClientConnection, params struct{}) (*protocol.GetPromptResult, error) {
			return nil, failure
		}),
	)

	// Connect the server.
	cc, err := s.Connect(ctx, st, nil)
	if err != nil {
		t.Fatal(err)
	}
	if got := slices.Collect(s.Clients()); len(got) != 1 {
		t.Errorf("after connection, Clients() has length %d, want 1", len(got))
	}

	// Wait for the server to exit after the client closes its connection.
	var clientWG sync.WaitGroup
	clientWG.Add(1)
	go func() {
		if err := cc.Wait(); err != nil {
			t.Errorf("server failed: %v", err)
		}
		clientWG.Done()
	}()

	c := NewClient("testClient", "v1.0.0", nil)

	// Connect the client.
	if err := c.Connect(ctx, ct, nil); err != nil {
		t.Fatal(err)
	}

	if err := c.Ping(ctx); err != nil {
		t.Fatalf("ping failed: %v", err)
	}

	gotPrompts, err := c.ListPrompts(ctx)
	if err != nil {
		// Stop here: diffing a nil result below would only add a misleading
		// second failure.
		t.Fatalf("prompts/list failed: %v", err)
	}
	wantPrompts := []protocol.Prompt{
		{
			Name:        "code_review",
			Description: "do a code review",
			Arguments:   []protocol.PromptArgument{{Name: "Code", Required: true}},
		},
		{Name: "fail"},
	}
	if diff := cmp.Diff(wantPrompts, gotPrompts); diff != "" {
		t.Fatalf("prompts/list mismatch (-want +got):\n%s", diff)
	}

	gotReview, err := c.GetPrompt(ctx, "code_review", map[string]string{"Code": "1+1"})
	if err != nil {
		t.Fatal(err)
	}
	wantReview := &protocol.GetPromptResult{
		Description: "Code review prompt",
		Messages: []protocol.PromptMessage{{
			Content: TextContent{Text: "Please review the following code: 1+1"}.ToWire(),
			Role:    "user",
		}},
	}
	if diff := cmp.Diff(wantReview, gotReview); diff != "" {
		t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff)
	}

	if _, err := c.GetPrompt(ctx, "fail", map[string]string{}); err == nil || !strings.Contains(err.Error(), failure.Error()) {
		t.Errorf("fail returned unexpected error: got %v, want containing %v", err, failure)
	}

	gotTools, err := c.ListTools(ctx)
	if err != nil {
		// As above: a nil result would only produce a misleading diff.
		t.Fatalf("tools/list failed: %v", err)
	}
	wantTools := []protocol.Tool{{
		Name:        "greet",
		Description: "say hi",
		InputSchema: &jsonschema.Schema{
			Type:     "object",
			Required: []string{"Name"},
			Properties: map[string]*jsonschema.Schema{
				"Name": {Type: "string"},
			},
			AdditionalProperties: falseSchema,
		},
	}, {
		Name:        "fail",
		Description: "just fail",
		InputSchema: &jsonschema.Schema{
			Type:                 "object",
			AdditionalProperties: falseSchema,
		},
	}}
	if diff := cmp.Diff(wantTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
		t.Fatalf("tools/list mismatch (-want +got):\n%s", diff)
	}

	gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"})
	if err != nil {
		t.Fatal(err)
	}
	wantHi := &protocol.CallToolResult{
		Content: []protocol.Content{{Type: "text", Text: "hi user"}},
	}
	if diff := cmp.Diff(wantHi, gotHi); diff != "" {
		t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
	}

	gotFail, err := c.CallTool(ctx, "fail", map[string]any{})
	// Counter-intuitively, when a tool fails, we don't expect an RPC error for
	// call tool: instead, the failure is embedded in the result.
	if err != nil {
		t.Fatal(err)
	}
	wantFail := &protocol.CallToolResult{
		IsError: true,
		Content: []protocol.Content{{Type: "text", Text: failure.Error()}},
	}
	if diff := cmp.Diff(wantFail, gotFail); diff != "" {
		t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff)
	}

	// Disconnect.
	c.Close()
	clientWG.Wait()

	// After disconnecting, the server should no longer report any connected
	// clients.
	for range s.Clients() {
		t.Errorf("unexpected client after disconnection")
	}
}
// basicConnection returns a new basic client-server connection configured with
// the provided tools.
//
// The first result is the server's connection to the client; the second is
// the connected client. Any failure while connecting fails the test
// immediately.
//
// The caller should close either the client connection or server connection
// when the connections are no longer needed.
func basicConnection(t *testing.T, tools ...*Tool) (*ClientConnection, *Client) {
	t.Helper()

	ctx := context.Background()
	ct, st := NewLocalTransport()

	s := NewServer("testServer", "v1.0.0", nil)

	// Register whichever tools the caller supplied (possibly none).
	s.AddTools(tools...)
	cc, err := s.Connect(ctx, st, nil)
	if err != nil {
		t.Fatal(err)
	}

	c := NewClient("testClient", "v1.0.0", nil)
	if err := c.Connect(ctx, ct, nil); err != nil {
		t.Fatal(err)
	}
	return cc, c
}
// TestServerClosing checks that when the server closes its side of the
// connection, the client's Wait returns without error and later calls on the
// client fail with ErrConnectionClosed.
func TestServerClosing(t *testing.T) {
	cc, c := basicConnection(t, MakeTool("greet", "say hi", sayHi))
	defer c.Close()

	ctx := context.Background()
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// c is the client, so report it as such.
		if err := c.Wait(); err != nil {
			t.Errorf("client connection failed: %v", err)
		}
	}()

	if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}); err != nil {
		t.Fatalf("after connecting: %v", err)
	}

	cc.Close()
	wg.Wait()

	if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}); !errors.Is(err, ErrConnectionClosed) {
		t.Errorf("after disconnection, got error %v, want %v", err, ErrConnectionClosed)
	}
}
// TestBatching checks that a client transport configured with a batch size
// withholds results until the batch is full. It also checks that every
// batched ListTools call eventually succeeds.
func TestBatching(t *testing.T) {
	ctx := context.Background()
	ct, st := NewLocalTransport()

	s := NewServer("testServer", "v1.0.0", nil)
	if _, err := s.Connect(ctx, st, nil); err != nil {
		t.Fatal(err)
	}

	c := NewClient("testClient", "v1.0.0", nil)
	opts := new(ConnectionOptions)

	// TODO: this test is broken, because increasing the batch size here causes
	// 'initialize' to block. Therefore, we can only test with a size of 1.
	const batchSize = 1
	BatchSize(ct, batchSize)
	if err := c.Connect(ctx, ct, opts); err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	errs := make(chan error, batchSize)
	pending := batchSize // results not yet received from errs
	for i := range batchSize {
		go func() {
			_, err := c.ListTools(ctx)
			errs <- err
		}()
		time.Sleep(2 * time.Millisecond)
		if i < batchSize-1 {
			select {
			case err := <-errs:
				// Report the error actually received, not an unrelated outer err.
				pending--
				t.Errorf("ListTools: unexpected result for incomplete batch: %v", err)
			default:
			}
		}
	}

	// Once the batch is full, every outstanding call should complete
	// successfully. Draining here also ensures no ListTools goroutine is
	// still using c when the deferred Close runs.
	for range pending {
		if err := <-errs; err != nil {
			t.Errorf("ListTools: %v", err)
		}
	}
}
// TestCancellation checks that cancelling the caller's context while a tool
// call is in flight cancels the context passed to the server-side handler.
func TestCancellation(t *testing.T) {
	// start is signalled once the handler is running.
	start := make(chan struct{})
	// cancelled is buffered so that reporting cancellation never blocks the
	// handler.
	cancelled := make(chan struct{}, 1)

	slowRequest := func(ctx context.Context, cc *ClientConnection, v struct{}) ([]Content, error) {
		start <- struct{}{}
		select {
		case <-ctx.Done():
			cancelled <- struct{}{}
		case <-time.After(5 * time.Second):
		}
		return nil, nil
	}

	_, client := basicConnection(t, MakeTool("slow", "a slow request", slowRequest))
	defer client.Close()

	ctx, cancel := context.WithCancel(context.Background())
	go client.CallTool(ctx, "slow", map[string]any{})
	<-start
	cancel()

	select {
	case <-cancelled:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for cancellation")
	}
}
// falseSchema is the expected input-schema value of AdditionalProperties for
// generated tool schemas: {"not": {}} matches no instance, because the empty
// schema {} matches every instance.
var falseSchema = &jsonschema.Schema{Not: new(jsonschema.Schema)}