blob: a39d74d5b1662244ed2c3fb924830c58aaa24869 [file] [log] [blame]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
"golang.org/x/tools/internal/xcontext"
)
// ErrConnectionClosed is returned when sending a message to a connection that
// is closed or in the process of closing.
//
// Callers may test for it with errors.Is.
var ErrConnectionClosed = errors.New("connection closed")
// A Transport is used to create a bidirectional connection between MCP client
// and server.
//
// Transports should be used for at most one call to [Server.Connect] or
// [Client.Start].
type Transport interface {
	// Connect returns the logical stream.
	//
	// It is called exactly once by [Connect].
	Connect(ctx context.Context) (Stream, error)
}
// A Stream is a bidirectional jsonrpc2 Stream: a reader and writer of
// jsonrpc2 messages, which must be closed when the stream is no longer
// needed.
type Stream interface {
	jsonrpc2.Reader
	jsonrpc2.Writer
	io.Closer
}
// ConnectionOptions configures the behavior of an individual client<->server
// connection.
type ConnectionOptions struct {
	SessionID string    // if set, the session ID
	Logger    io.Writer // if set, write RPC logs
}
// A StdIOTransport is a [Transport] that communicates over stdin/stdout using
// newline-delimited JSON.
//
// Use [NewStdIOTransport] to construct one.
type StdIOTransport struct {
	ioTransport
}
// An ioTransport is a [Transport] that communicates using newline-delimited
// JSON over an io.ReadWriteCloser.
type ioTransport struct {
	rwc io.ReadWriteCloser // the underlying bidirectional byte stream
}
// Connect implements [Transport], wrapping the underlying
// io.ReadWriteCloser in a newline-delimited JSON stream.
func (t *ioTransport) Connect(context.Context) (Stream, error) {
	stream := newIOStream(t.rwc)
	return stream, nil
}
// NewStdIOTransport constructs a transport that communicates over
// stdin/stdout.
func NewStdIOTransport() *StdIOTransport {
	stdio := rwc{rc: os.Stdin, wc: os.Stdout}
	return &StdIOTransport{ioTransport{rwc: stdio}}
}
// An InMemoryTransport is a [Transport] that communicates over an in-memory
// network connection, using newline-delimited JSON.
//
// Use [NewInMemoryTransport] to construct a connected pair.
type InMemoryTransport struct {
	ioTransport
}
// NewInMemoryTransport returns two InMemoryTransports that connect to each
// other, using the two ends of an in-memory pipe.
func NewInMemoryTransport() (*InMemoryTransport, *InMemoryTransport) {
	clientEnd, serverEnd := net.Pipe()
	t1 := &InMemoryTransport{ioTransport{clientEnd}}
	t2 := &InMemoryTransport{ioTransport{serverEnd}}
	return t1, t2
}
// handler is an unexported version of jsonrpc2.Handler.
type handler interface {
	// handle processes a single incoming request, returning its result, or an
	// error to be reported to the peer.
	handle(ctx context.Context, req *jsonrpc2.Request) (result any, err error)
}
// A binder produces a handler of type T bound to a jsonrpc2 connection, and
// releases it when the connection is done (see connect).
type binder[T handler] interface {
	// bind returns the handler for the given connection.
	bind(*jsonrpc2.Connection) T
	// disconnect is called when the connection is done, to release any
	// resources associated with the handler.
	disconnect(T)
}
// connect creates a new jsonrpc2 connection over the given Transport, using b
// to bind the resulting connection to a handler of type H.
//
// The returned handler wraps the new connection; on connection shutdown,
// b.disconnect releases it. If opts is nil, default options are used.
func connect[H handler](ctx context.Context, t Transport, opts *ConnectionOptions, b binder[H]) (H, error) {
	if opts == nil {
		opts = new(ConnectionOptions)
	}
	var zero H // zero value of H, to return on error
	stream, err := t.Connect(ctx)
	if err != nil {
		return zero, err
	}

	// If logging is configured, write message logs.
	reader, writer := jsonrpc2.Reader(stream), jsonrpc2.Writer(stream)
	if opts.Logger != nil {
		reader = loggingReader(opts.Logger, reader)
		writer = loggingWriter(opts.Logger, writer)
	}

	var (
		h         H
		preempter canceller // cancels in-flight requests on MCP cancellation notifications
	)
	// bind captures the new connection into both the handler and the
	// preempter.
	bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler {
		h = b.bind(conn)
		preempter.conn = conn
		return jsonrpc2.HandlerFunc(h.handle)
	}
	// The connection itself is not retained here: the bound handler h wraps
	// it, and OnDone handles cleanup.
	_ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{
		Reader:    reader,
		Writer:    writer,
		Closer:    stream,
		Bind:      bind,
		Preempter: &preempter,
		OnDone: func() {
			b.disconnect(h)
		},
	})
	// NewConnection must have invoked bind, populating preempter.conn.
	assert(preempter.conn != nil, "unbound preempter")
	return h, nil
}
// A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP
// cancelled notifications.
type canceller struct {
	conn *jsonrpc2.Connection // set by connect, before any requests arrive
}
// Preempt implements jsonrpc2.Preempter. It watches for MCP
// "notifications/cancelled" notifications and cancels the corresponding
// in-flight request, if any.
//
// It always reports jsonrpc2.ErrNotHandled so that the message continues to
// the ordinary handler.
func (c *canceller) Preempt(ctx context.Context, req *jsonrpc2.Request) (result any, err error) {
	if req.Method != "notifications/cancelled" {
		return nil, jsonrpc2.ErrNotHandled
	}
	var params CancelledParams
	if err := json.Unmarshal(req.Params, &params); err != nil {
		return nil, err
	}
	id, err := jsonrpc2.MakeID(params.RequestID)
	if err != nil {
		return nil, err
	}
	// Cancel in a separate goroutine, so that this preempter does not block.
	go c.conn.Cancel(id)
	return nil, jsonrpc2.ErrNotHandled
}
// call executes and awaits a jsonrpc2 call on the given connection,
// translating errors into the mcp domain.
//
// On success, the response is unmarshaled into result.
func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error {
	// TODO: the "%w"s in this function effectively make jsonrpc2.WireError part of the API.
	// Consider alternatives.
	call := conn.Call(ctx, method, params)
	err := call.Await(ctx, result)
	switch {
	case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing):
		// The connection is shutting down: translate to the mcp-domain
		// sentinel error.
		return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed)
	case ctx.Err() != nil:
		// Notify the peer of cancellation.
		// Detach the context: ctx itself is already done, so the notification
		// would otherwise fail immediately.
		err := conn.Notify(xcontext.Detach(ctx), "notifications/cancelled", &CancelledParams{
			Reason:    ctx.Err().Error(),
			RequestID: call.ID().Raw(),
		})
		// Report both the cancellation and any failure to notify the peer.
		return errors.Join(ctx.Err(), err)
	case err != nil:
		return fmt.Errorf("calling %q: %w", method, err)
	}
	return nil
}
// The helpers below are used to bind transports to jsonrpc2.
// A readerFunc implements jsonrpc2.Reader.Read, allowing a function to be
// used directly as a jsonrpc2.Reader.
type readerFunc func(context.Context) (jsonrpc2.Message, int64, error)

// Read calls f.
func (f readerFunc) Read(ctx context.Context) (jsonrpc2.Message, int64, error) {
	return f(ctx)
}
// A writerFunc implements jsonrpc2.Writer.Write, allowing a function to be
// used directly as a jsonrpc2.Writer.
type writerFunc func(context.Context, jsonrpc2.Message) (int64, error)

// Write calls f.
func (f writerFunc) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) {
	return f(ctx, msg)
}
// loggingReader is a stream middleware that logs incoming messages to w.
//
// Logging is best effort: failures to encode the log payload are themselves
// logged, and do not affect the delegated read.
func loggingReader(w io.Writer, delegate jsonrpc2.Reader) jsonrpc2.Reader {
	return readerFunc(func(ctx context.Context) (jsonrpc2.Message, int64, error) {
		msg, n, err := delegate.Read(ctx)
		if err != nil {
			fmt.Fprintf(w, "read error: %v", err)
		} else if data, err := jsonrpc2.EncodeMessage(msg); err != nil {
			// Fix: previously, after logging this encoding failure we also
			// logged "read: " with nil data. Skip the payload log instead.
			fmt.Fprintf(w, "LoggingFramer: failed to marshal: %v", err)
		} else {
			fmt.Fprintf(w, "read: %s", string(data))
		}
		return msg, n, err
	})
}
// loggingWriter is a stream middleware that logs outgoing messages to w.
//
// Logging is best effort: failures to encode the log payload are themselves
// logged, and do not affect the delegated write.
func loggingWriter(w io.Writer, delegate jsonrpc2.Writer) jsonrpc2.Writer {
	return writerFunc(func(ctx context.Context, msg jsonrpc2.Message) (int64, error) {
		n, err := delegate.Write(ctx, msg)
		if err != nil {
			fmt.Fprintf(w, "write error: %v", err)
		} else if data, err := jsonrpc2.EncodeMessage(msg); err != nil {
			// Fix: previously, after logging this encoding failure we also
			// logged "write: " with nil data. Skip the payload log instead.
			fmt.Fprintf(w, "LoggingFramer: failed to marshal: %v", err)
		} else {
			fmt.Fprintf(w, "write: %s", string(data))
		}
		return n, err
	})
}
// A rwc binds an io.ReadCloser and io.WriteCloser together to create an
// io.ReadWriteCloser.
type rwc struct {
rc io.ReadCloser
wc io.WriteCloser
}
func (r rwc) Read(p []byte) (n int, err error) {
return r.rc.Read(p)
}
func (r rwc) Write(p []byte) (n int, err error) {
return r.wc.Write(p)
}
func (r rwc) Close() error {
return errors.Join(r.rc.Close(), r.wc.Close())
}
// An ioStream is a transport that delimits messages with newlines across
// a bidirectional stream, and supports JSONRPC2 message batching.
//
// See https://github.com/ndjson/ndjson-spec for discussion of newline
// delimited JSON.
//
// See [msgBatch] for more discussion of message batching.
type ioStream struct {
	rwc io.ReadWriteCloser // the underlying stream
	in  *json.Decoder      // a decoder bound to rwc

	// If outgoingBatch has a positive capacity, it will be used to batch requests
	// and notifications before sending.
	outgoingBatch []jsonrpc2.Message

	// Unread messages in the last batch. Since reads are serialized, there is no
	// need to guard here.
	queue []jsonrpc2.Message

	// batches correlate incoming requests to the batch in which they arrived.
	// Since writes may be concurrent to reads, we need to guard this with a mutex.
	batchMu sync.Mutex
	batches map[jsonrpc2.ID]*msgBatch // lazily allocated
}
// newIOStream returns an ioStream reading and writing newline-delimited
// JSON over rwc.
func newIOStream(rwc io.ReadWriteCloser) *ioStream {
	s := &ioStream{rwc: rwc}
	s.in = json.NewDecoder(rwc)
	return s
}
// Connect implements [Transport]. Since an ioStream is itself a logical
// stream, it simply returns the receiver.
func (t *ioStream) Connect(_ context.Context) (Stream, error) {
	return t, nil
}
// addBatch records a msgBatch for an incoming batch payload.
// It returns an error if batch is malformed, containing previously seen IDs.
//
// See [msgBatch] for more.
func (t *ioStream) addBatch(batch *msgBatch) error {
	t.batchMu.Lock()
	defer t.batchMu.Unlock()

	// Verify all IDs up front, so that a malformed batch never leaves
	// t.batches partially updated.
	for id := range batch.unresolved {
		if _, ok := t.batches[id]; ok {
			return fmt.Errorf("%w: batch contains previously seen request %v", jsonrpc2.ErrInvalidRequest, id.Raw())
		}
	}
	// Lazily allocate the tracking map, then record each request.
	if len(batch.unresolved) > 0 && t.batches == nil {
		t.batches = make(map[jsonrpc2.ID]*msgBatch)
	}
	for id := range batch.unresolved {
		t.batches[id] = batch
	}
	return nil
}
// updateBatch records a response in the message batch tracking the
// corresponding incoming call, if any.
//
// The second result reports whether resp was part of a batch. If this is true,
// the first result is nil if the batch is still incomplete, or the full set of
// batch responses if resp completed the batch.
func (t *ioStream) updateBatch(resp *jsonrpc2.Response) ([]*jsonrpc2.Response, bool) {
	t.batchMu.Lock()
	defer t.batchMu.Unlock()

	batch, found := t.batches[resp.ID]
	if !found {
		return nil, false // not part of a batch
	}
	idx, ok := batch.unresolved[resp.ID]
	if !ok {
		panic("internal error: inconsistent batches")
	}
	// Slot the response into its original position and stop tracking its ID.
	batch.responses[idx] = resp
	delete(batch.unresolved, resp.ID)
	delete(t.batches, resp.ID)
	if len(batch.unresolved) > 0 {
		return nil, true // batch still incomplete
	}
	return batch.responses, true // resp completed the batch
}
// A msgBatch records information about an incoming batch of JSONRPC2 calls.
//
// The JSONRPC2 spec (https://www.jsonrpc.org/specification#batch) says:
//
// "The Server should respond with an Array containing the corresponding
// Response objects, after all of the batch Request objects have been
// processed. A Response object SHOULD exist for each Request object, except
// that there SHOULD NOT be any Response objects for notifications. The Server
// MAY process a batch rpc call as a set of concurrent tasks, processing them
// in any order and with any width of parallelism."
//
// Therefore, a msgBatch keeps track of outstanding calls and their responses.
// When there are no unresolved calls, the response payload is sent.
type msgBatch struct {
	unresolved map[jsonrpc2.ID]int // request ID -> index in responses
	responses  []*jsonrpc2.Response // responses in batch order; nil until resolved
}
// Read implements jsonrpc2.Reader, reading the next message from the
// stream's decoder.
func (t *ioStream) Read(ctx context.Context) (jsonrpc2.Message, int64, error) {
	return t.read(ctx, t.in)
}
// read returns the next message from in, delivering any messages queued from
// a previously read batch first.
//
// The reported byte count is 0 for queued messages, and the length of the raw
// payload otherwise.
func (t *ioStream) read(ctx context.Context, in *json.Decoder) (jsonrpc2.Message, int64, error) {
	// Check for cancellation up front: Decode below is not interruptible by
	// ctx.
	select {
	case <-ctx.Done():
		return nil, 0, ctx.Err()
	default:
	}
	// Deliver leftovers from the last batch, if any. (Reads are serialized, so
	// t.queue needs no locking; see the field comment.)
	if len(t.queue) > 0 {
		next := t.queue[0]
		t.queue = t.queue[1:]
		return next, 0, nil
	}
	var raw json.RawMessage
	if err := in.Decode(&raw); err != nil {
		return nil, 0, err
	}
	// If the payload parses as a JSON array, it is a batch.
	var rawBatch []json.RawMessage
	if err := json.Unmarshal(raw, &rawBatch); err == nil {
		msg, err := t.readBatch(rawBatch)
		if err != nil {
			return nil, 0, err
		}
		return msg, int64(len(raw)), nil
	}
	// Otherwise, decode a single message.
	msg, err := jsonrpc2.DecodeMessage(raw)
	return msg, int64(len(raw)), err
}
// readBatch reads a batch of jsonrpc2 messages, and records the batch
// in the framer so that responses can be collected and sent back together.
//
// It returns the first message of the batch; the rest are appended to t.queue
// for delivery by subsequent reads.
func (t *ioStream) readBatch(rawBatch []json.RawMessage) (jsonrpc2.Message, error) {
	if len(rawBatch) == 0 {
		return nil, fmt.Errorf("empty batch")
	}

	// From the spec:
	// "If the batch rpc call itself fails to be recognized as an valid JSON or
	// as an Array with at least one value, the response from the Server MUST be
	// a single Response object. If there are no Response objects contained
	// within the Response array as it is to be sent to the client, the server
	// MUST NOT return an empty Array and should return nothing at all."
	//
	// In our case, an error actually breaks the jsonrpc2 connection entirely,
	// but defensively we collect batch information before recording it, so that
	// we don't leave the framer in an inconsistent state.
	var (
		first     jsonrpc2.Message   // first message, to return
		queue     []jsonrpc2.Message // remaining messages
		respBatch *msgBatch          // tracks incoming requests in the batch
	)
	for i, raw := range rawBatch {
		msg, err := jsonrpc2.DecodeMessage(raw)
		if err != nil {
			return nil, err
		}
		if i == 0 {
			first = msg
		} else {
			queue = append(queue, msg)
		}
		// Record each incoming request, so its response can later be slotted
		// into the batched reply (see updateBatch).
		if req, ok := msg.(*jsonrpc2.Request); ok {
			if respBatch == nil {
				respBatch = &msgBatch{
					unresolved: make(map[jsonrpc2.ID]int),
				}
			}
			respBatch.unresolved[req.ID] = len(respBatch.responses)
			respBatch.responses = append(respBatch.responses, nil)
		}
	}
	if respBatch != nil {
		// The batch contains one or more incoming requests to track.
		if err := t.addBatch(respBatch); err != nil {
			return nil, err
		}
	}
	t.queue = append(t.queue, queue...)
	return first, nil
}
// Write implements jsonrpc2.Writer, writing msg to the underlying stream as
// newline-delimited JSON.
//
// Responses belonging to an incoming batch are held until the batch is
// complete, then written together as a JSON array. Outgoing requests and
// notifications may similarly be collected into an outgoing batch, when
// outgoingBatch has positive capacity.
func (t *ioStream) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) {
	// Fail fast if the caller's context is already done.
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	default:
	}

	// Batching support: if msg is a Response, it may have completed a batch, so
	// check that first. Otherwise, it is a request or notification, and we may
	// want to collect it into a batch before sending, if we're configured to use
	// outgoing batches.
	if resp, ok := msg.(*jsonrpc2.Response); ok {
		if batch, ok := t.updateBatch(resp); ok {
			// resp belongs to an incoming batch.
			if len(batch) > 0 {
				// resp completed the batch: write all responses as one array.
				data, err := marshalMessages(batch)
				if err != nil {
					return 0, err
				}
				data = append(data, '\n')
				n, err := t.rwc.Write(data)
				return int64(n), err
			}
			// The batch is still incomplete: nothing to write yet.
			return 0, nil
		}
	} else if len(t.outgoingBatch) < cap(t.outgoingBatch) {
		// Outgoing batching is enabled (positive capacity): collect msg.
		t.outgoingBatch = append(t.outgoingBatch, msg)
		if len(t.outgoingBatch) == cap(t.outgoingBatch) {
			// The outgoing batch is full: flush it as a single array.
			data, err := marshalMessages(t.outgoingBatch)
			t.outgoingBatch = t.outgoingBatch[:0] // reset, keeping capacity
			if err != nil {
				return 0, err
			}
			data = append(data, '\n')
			n, err := t.rwc.Write(data)
			return int64(n), err
		}
		return 0, nil
	}

	// Default: write msg as a single newline-delimited message.
	data, err := jsonrpc2.EncodeMessage(msg)
	if err != nil {
		return 0, fmt.Errorf("marshaling message: %v", err)
	}
	data = append(data, '\n') // newline delimited
	n, err := t.rwc.Write(data)
	return int64(n), err
}
// Close closes the underlying stream.
func (t *ioStream) Close() error {
	return t.rwc.Close()
}
// marshalMessages encodes msgs as a JSON array of individually encoded
// jsonrpc2 messages, for use as a batch payload.
func marshalMessages[T jsonrpc2.Message](msgs []T) ([]byte, error) {
	// Preallocate: exactly one raw message per input. Keep rawMsgs nil when
	// msgs is empty so the marshaled output is unchanged ("null").
	var rawMsgs []json.RawMessage
	if len(msgs) > 0 {
		rawMsgs = make([]json.RawMessage, 0, len(msgs))
	}
	for _, msg := range msgs {
		raw, err := jsonrpc2.EncodeMessage(msg)
		if err != nil {
			return nil, fmt.Errorf("encoding batch message: %w", err)
		}
		rawMsgs = append(rawMsgs, raw)
	}
	return json.Marshal(rawMsgs)
}