blob: 7f67c88727256a0e09de2033a90adba19cc42e0f [file] [log] [blame]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mcp
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
// TestSSEServer serves an SSE handler from an httptest server, connects a
// client to it through the SSE client transport, and checks that a ping and a
// call to the "greet" tool both succeed. It then closes one end of the
// connection and calls Wait on the other end, once for each closing order.
func TestSSEServer(t *testing.T) {
	for _, closeServerFirst := range []bool{false, true} {
		t.Run(fmt.Sprintf("closeServerFirst=%t", closeServerFirst), func(t *testing.T) {
			ctx := context.Background()
			server := NewServer("testServer", "v1.0.0", nil)
			server.AddTools(NewServerTool("greet", "say hi", sayHi))

			sseHandler := NewSSEHandler(func(*http.Request) *Server { return server })

			// Capture the server-side session handed to the connection hook.
			// The send is non-blocking so the hook never stalls if the
			// one-slot channel is already full.
			conns := make(chan *ServerSession, 1)
			sseHandler.onConnection = func(cc *ServerSession) {
				select {
				case conns <- cc:
				default:
				}
			}
			httpServer := httptest.NewServer(sseHandler)
			defer httpServer.Close()

			clientTransport := NewSSEClientTransport(httpServer.URL, nil)

			c := NewClient("testClient", "v1.0.0", nil)
			cs, err := c.Connect(ctx, clientTransport)
			if err != nil {
				t.Fatal(err)
			}
			if err := cs.Ping(ctx, nil); err != nil {
				t.Fatal(err)
			}
			ss := <-conns

			gotHi, err := cs.CallTool(ctx, &CallToolParams{
				Name:      "greet",
				Arguments: map[string]any{"Name": "user"},
			})
			if err != nil {
				t.Fatal(err)
			}
			wantHi := &CallToolResult{
				Content: []*Content{{Type: "text", Text: "hi user"}},
			}
			if diff := cmp.Diff(wantHi, gotHi); diff != "" {
				t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
			}

			// Test that closing either end of the connection terminates the other
			// end. cs is the client session and ss is the server session, so the
			// server end is the one closed first exactly when closeServerFirst is
			// set, matching the subtest name.
			if closeServerFirst {
				ss.Close()
				cs.Wait()
			} else {
				cs.Close()
				ss.Wait()
			}
		})
	}
}
// TestScanEvents runs scanEvents over literal event-stream text and compares
// the events it yields, or the error it reports, with each table entry.
func TestScanEvents(t *testing.T) {
	type scanCase struct {
		name    string
		input   string
		want    []event
		wantErr string
	}
	cases := []scanCase{
		{
			name:  "simple event",
			input: "event: message\nid: 1\ndata: hello\n\n",
			want:  []event{{name: "message", id: "1", data: []byte("hello")}},
		},
		{
			name:  "multiple data lines",
			input: "data: line 1\ndata: line 2\n\n",
			want:  []event{{data: []byte("line 1\nline 2")}},
		},
		{
			name:  "multiple events",
			input: "data: first\n\nevent: second\ndata: second\n\n",
			want: []event{
				{data: []byte("first")},
				{name: "second", data: []byte("second")},
			},
		},
		{
			name:  "no trailing newline",
			input: "data: hello",
			want:  []event{{data: []byte("hello")}},
		},
		{
			name:    "malformed line",
			input:   "invalid line\n\n",
			wantErr: "malformed line",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Collect events until the sequence ends or yields its first error.
			var (
				events  []event
				scanErr error
			)
			for ev, err := range scanEvents(strings.NewReader(tc.input)) {
				if err != nil {
					scanErr = err
					break
				}
				events = append(events, ev)
			}

			// Error expectations come first; a case that wants an error does
			// not go on to compare events.
			switch {
			case tc.wantErr != "" && scanErr == nil:
				t.Fatalf("scanEvents() got nil error, want error containing %q", tc.wantErr)
			case tc.wantErr != "" && !strings.Contains(scanErr.Error(), tc.wantErr):
				t.Fatalf("scanEvents() error = %q, want containing %q", scanErr, tc.wantErr)
			case tc.wantErr != "":
				return
			case scanErr != nil:
				t.Fatalf("scanEvents() returned unexpected error: %v", scanErr)
			}

			if len(events) != len(tc.want) {
				t.Fatalf("scanEvents() got %d events, want %d", len(events), len(tc.want))
			}
			for i, ev := range events {
				expected := tc.want[i]
				if ev.name != expected.name {
					t.Errorf("event %d: name = %q, want %q", i, ev.name, expected.name)
				}
				if ev.id != expected.id {
					t.Errorf("event %d: id = %q, want %q", i, ev.id, expected.id)
				}
				// Compare payloads as strings so mismatches print readably.
				gotData, wantData := string(ev.data), string(expected.data)
				if gotData != wantData {
					t.Errorf("event %d: data = %q, want %q", i, gotData, wantData)
				}
			}
		})
	}
}