From 3e820ececc845e719ade074b7c4b4d495be8b20c Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 11 Jul 2025 15:29:10 -0300 Subject: [PATCH] chore(deps): update mcp-go (#155) * chore(deps): update mcp-go Signed-off-by: Carlos Alexandro Becker * fix: vendoring Signed-off-by: Carlos Alexandro Becker --------- Signed-off-by: Carlos Alexandro Becker --- go.mod | 2 +- go.sum | 4 +- .../mark3labs/mcp-go/client/client.go | 102 +++++- .../mark3labs/mcp-go/client/http.go | 7 +- .../mark3labs/mcp-go/client/sampling.go | 20 + .../mark3labs/mcp-go/client/stdio.go | 22 +- .../mcp-go/client/transport/inprocess.go | 4 + .../mcp-go/client/transport/interface.go | 20 +- .../mark3labs/mcp-go/client/transport/sse.go | 6 + .../mcp-go/client/transport/stdio.go | 204 ++++++++++- .../client/transport/streamable_http.go | 343 ++++++++++++------ .../github.com/mark3labs/mcp-go/mcp/tools.go | 106 +++++- .../github.com/mark3labs/mcp-go/mcp/types.go | 21 ++ .../github.com/mark3labs/mcp-go/mcp/utils.go | 21 ++ .../mark3labs/mcp-go/server/sampling.go | 37 ++ .../mark3labs/mcp-go/server/stdio.go | 180 ++++++++- .../mcp-go/server/streamable_http.go | 6 +- vendor/modules.txt | 2 +- 18 files changed, 959 insertions(+), 148 deletions(-) create mode 100644 vendor/github.com/mark3labs/mcp-go/client/sampling.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/sampling.go diff --git a/go.mod b/go.mod index 35907121af5791acc5cfc5f3aa07f10df9eba763..2a9d6d5dfbaa827a5c8a57cadbe716dd956e1401 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( github.com/fsnotify/fsnotify v1.8.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 - github.com/mark3labs/mcp-go v0.32.0 + github.com/mark3labs/mcp-go v0.33.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.25.0 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index 50e30a46d4a47cb210add9c3fe61f0c9fb8e6c26..1d40961a3dce4180d9a06d17e3843f8d8709567b 100644 --- a/go.sum +++ b/go.sum @@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= -github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc= +github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= diff --git a/vendor/github.com/mark3labs/mcp-go/client/client.go b/vendor/github.com/mark3labs/mcp-go/client/client.go index dd0e31a013595ccbb900a10fe413e02d1ed9d0ad..e2c466586050cf69e2015e83056fdaf6eda949f6 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/client.go +++ b/vendor/github.com/mark3labs/mcp-go/client/client.go @@ -22,6 +22,7 @@ type Client struct { requestID atomic.Int64 clientCapabilities mcp.ClientCapabilities serverCapabilities mcp.ServerCapabilities + samplingHandler SamplingHandler } type ClientOption func(*Client) @@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { } } +// WithSamplingHandler sets the sampling handler for the client. +// When set, the client will declare sampling capability during initialization. +func WithSamplingHandler(handler SamplingHandler) ClientOption { + return func(c *Client) { + c.samplingHandler = handler + } +} + +// WithSession assumes a MCP Session has already been initialized +func WithSession() ClientOption { + return func(c *Client) { + c.initialized = true + } +} + // NewClient creates a new MCP client with the given transport. // Usage: // @@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error { handler(notification) } }) + + // Set up request handler for bidirectional communication (e.g., sampling) + if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok { + bidirectional.SetRequestHandler(c.handleIncomingRequest) + } + return nil } @@ -127,6 +149,12 @@ func (c *Client) Initialize( ctx context.Context, request mcp.InitializeRequest, ) (*mcp.InitializeResult, error) { + // Merge client capabilities with sampling capability if handler is configured + capabilities := request.Params.Capabilities + if c.samplingHandler != nil { + capabilities.Sampling = &struct{}{} + } + // Ensure we send a params object with all required fields params := struct { ProtocolVersion string `json:"protocolVersion"` @@ -135,7 +163,7 @@ func (c *Client) Initialize( }{ ProtocolVersion: request.Params.ProtocolVersion, ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set + Capabilities: capabilities, } response, err := c.sendRequest(ctx, "initialize", params) @@ -398,6 +426,64 @@ func (c *Client) Complete( return &result, nil } +// handleIncomingRequest processes incoming requests from the server. +// This is the main entry point for server-to-client requests like sampling. +func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + return c.handleSamplingRequestTransport(ctx, request) + default: + return nil, fmt.Errorf("unsupported request method: %s", request.Method) + } +} + +// handleSamplingRequestTransport handles sampling requests at the transport level. +func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.samplingHandler == nil { + return nil, fmt.Errorf("no sampling handler configured") + } + + // Parse the request parameters + var params mcp.CreateMessageParams + if request.Params != nil { + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal params: %w", err) + } + } + + // Create the MCP request + mcpRequest := mcp.CreateMessageRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + CreateMessageParams: params, + } + + // Call the sampling handler + result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: json.RawMessage(resultBytes), + } + + return response, nil +} func listByPage[T any]( ctx context.Context, client *Client, @@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { return c.clientCapabilities } + +// GetSessionId returns the session ID of the transport. +// If the transport does not support sessions, it returns an empty string. +func (c *Client) GetSessionId() string { + if c.transport == nil { + return "" + } + return c.transport.GetSessionId() +} + +// IsInitialized returns true if the client has been initialized. +func (c *Client) IsInitialized() bool { + return c.initialized +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/http.go b/vendor/github.com/mark3labs/mcp-go/client/http.go index cb3be35d64cfc731efe2cef0c268a018c53a9538..d001a1e63d08e42d7457adbf5c497d93d029e203 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/http.go +++ b/vendor/github.com/mark3labs/mcp-go/client/http.go @@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP if err != nil { return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return NewClient(trans), nil + clientOptions := make([]ClientOption, 0) + sessionID := trans.GetSessionId() + if sessionID != "" { + clientOptions = append(clientOptions, WithSession()) + } + return NewClient(trans, clientOptions...), nil } diff --git a/vendor/github.com/mark3labs/mcp-go/client/sampling.go b/vendor/github.com/mark3labs/mcp-go/client/sampling.go new file mode 100644 index 0000000000000000000000000000000000000000..245e2c1f7f305ddb75658a345eddcaba5e2898e3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/sampling.go @@ -0,0 +1,20 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +// Clients can implement this interface to provide LLM sampling capabilities to servers. +type SamplingHandler interface { + // CreateMessage handles a sampling request from the server and returns the generated message. + // The implementation should: + // 1. Validate the request parameters + // 2. Optionally prompt the user for approval (human-in-the-loop) + // 3. Select an appropriate model based on preferences + // 4. Generate the response using the selected model + // 5. Return the result with model information and stop reason + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/stdio.go index 100c08a7cc0529ca30ca1386f74f6ea4f9be4654..199ec14c381b57c12d691495179cf0c45029d29e 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/client/stdio.go @@ -19,10 +19,26 @@ func NewStdioMCPClient( env []string, args ...string, ) (*Client, error) { + return NewStdioMCPClientWithOptions(command, env, args) +} + +// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command function. +// +// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport. +// Don't call the Start method manually. +// This is for backward compatibility. +func NewStdioMCPClientWithOptions( + command string, + env []string, + args []string, + opts ...transport.StdioOption, +) (*Client, error) { + stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...) - stdioTransport := transport.NewStdio(command, env, args...) - err := stdioTransport.Start(context.Background()) - if err != nil { + if err := stdioTransport.Start(context.Background()); err != nil { return nil, fmt.Errorf("failed to start stdio transport: %w", err) } diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go b/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go index 90fc2fae1f05ebf635b46d0fc415e0260348d3a0..0e2393f0731bcf361d5544da99304d1f07e08706 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go @@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc func (*InProcessTransport) Close() error { return nil } + +func (c *InProcessTransport) GetSessionId() string { + return "" +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go b/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go index c83c7c65a3a8b0c7a301564144516242919fe2a5..5f8ed6180b6404a1a0f4085c5557aeb1789e3485 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go @@ -29,6 +29,22 @@ type Interface interface { // Close the connection. Close() error + + // GetSessionId returns the session ID of the transport. + GetSessionId() string +} + +// RequestHandler defines a function that handles incoming requests from the server. +type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + +// BidirectionalInterface extends Interface to support incoming requests from the server. +// This is used for features like sampling where the server can send requests to the client. +type BidirectionalInterface interface { + Interface + + // SetRequestHandler sets the handler for incoming requests from the server. + // The handler should process the request and return a response. + SetRequestHandler(handler RequestHandler) } type JSONRPCRequest struct { @@ -41,10 +57,10 @@ type JSONRPCRequest struct { type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID mcp.RequestId `json:"id"` - Result json.RawMessage `json:"result"` + Result json.RawMessage `json:"result,omitempty"` Error *struct { Code int `json:"code"` Message string `json:"message"` Data json.RawMessage `json:"data"` - } `json:"error"` + } `json:"error,omitempty"` } diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go b/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go index b22ff62d40124b765b633d2b1700c7407d92d041..ffe3247f0ecd87a4e9c68df72b726a8cc44a7736 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go @@ -428,6 +428,12 @@ func (c *SSE) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since SSE does not maintain a session ID, it returns an empty string. +func (c *SSE) GetSessionId() string { + return "" +} + // SendNotification sends a JSON-RPC notification to the server without expecting a response. func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { if c.endpoint == nil { diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go index c300c405f7e3880f0b94e1e09e3ee5ca7def732a..c36dc2d37737d71d9028ba11485932c23bb09f9f 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go @@ -23,6 +23,7 @@ type Stdio struct { env []string cmd *exec.Cmd + cmdFunc CommandFunc stdin io.WriteCloser stdout *bufio.Reader stderr io.ReadCloser @@ -31,6 +32,28 @@ type Stdio struct { done chan struct{} onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + onRequest RequestHandler + requestMu sync.RWMutex + ctx context.Context + ctxMu sync.RWMutex +} + +// StdioOption defines a function that configures a Stdio transport instance. +// Options can be used to customize the behavior of the transport before it starts, +// such as setting a custom command function. +type StdioOption func(*Stdio) + +// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess. +// It can be used to apply sandboxing, custom environment control, working directories, etc. +type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error) + +// WithCommandFunc sets a custom command factory function for the stdio transport. +// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess, +// allowing control over attributes like environment, working directory, and system-level sandboxing. +func WithCommandFunc(f CommandFunc) StdioOption { + return func(s *Stdio) { + s.cmdFunc = f + } } // NewIO returns a new stdio-based transport using existing input, output, and @@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), } } @@ -55,20 +79,43 @@ func NewStdio( env []string, args ...string, ) *Stdio { + return NewStdioWithOptions(command, env, args) +} - client := &Stdio{ +// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command factory. +func NewStdioWithOptions( + command string, + env []string, + args []string, + opts ...StdioOption, +) *Stdio { + s := &Stdio{ command: command, args: args, env: env, responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), + } + + for _, opt := range opts { + opt(s) } - return client + return s } func (c *Stdio) Start(ctx context.Context) error { + // Store the context for use in request handling + c.ctxMu.Lock() + c.ctx = ctx + c.ctxMu.Unlock() + if err := c.spawnCommand(ctx); err != nil { return err } @@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error { return nil } -// spawnCommand spawns a new process running c.command. +// spawnCommand spawns a new process running the configured command, args, and env. +// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess; +// otherwise, the default behavior uses exec.CommandContext with the merged environment. +// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication. func (c *Stdio) spawnCommand(ctx context.Context) error { if c.command == "" { return nil } - cmd := exec.CommandContext(ctx, c.command, c.args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, c.env...) + var cmd *exec.Cmd + var err error - cmd.Env = mergedEnv + // Standard behavior if no command func present. + if c.cmdFunc == nil { + cmd = exec.CommandContext(ctx, c.command, c.args...) + cmd.Env = append(os.Environ(), c.env...) + } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil { + return err + } stdin, err := cmd.StdinPipe() if err != nil { @@ -148,6 +202,12 @@ func (c *Stdio) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since stdio does not maintain a session ID, it returns an empty string. +func (c *Stdio) GetSessionId() string { + return "" +} + // SetNotificationHandler sets the handler function to be called when a notification is received. // Only one handler can be set at a time; setting a new one replaces the previous handler. func (c *Stdio) SetNotificationHandler( @@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler( c.onNotification = handler } +// SetRequestHandler sets the handler function to be called when a request is received from the server. +// This enables bidirectional communication for features like sampling. +func (c *Stdio) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.onRequest = handler +} + // readResponses continuously reads and processes responses from the server's stdout. // It handles both responses to requests and notifications, routing them appropriately. // Runs until the done channel is closed or an error occurs reading from stdout. @@ -175,13 +243,18 @@ func (c *Stdio) readResponses() { return } - var baseMessage JSONRPCResponse + // First try to parse as a generic message to check for ID field + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Method string `json:"method,omitempty"` + } if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { continue } - // Handle notification - if baseMessage.ID.IsNil() { + // If it has a method but no ID, it's a notification + if baseMessage.Method != "" && baseMessage.ID == nil { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(line), ¬ification); err != nil { continue @@ -194,15 +267,30 @@ func (c *Stdio) readResponses() { continue } + // If it has a method and an ID, it's an incoming request + if baseMessage.Method != "" && baseMessage.ID != nil { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(line), &request); err == nil { + c.handleIncomingRequest(request) + continue + } + } + + // Otherwise, it's a response to our request + var response JSONRPCResponse + if err := json.Unmarshal([]byte(line), &response); err != nil { + continue + } + // Create string key for map lookup - idKey := baseMessage.ID.String() + idKey := response.ID.String() c.mu.RLock() ch, exists := c.responses[idKey] c.mu.RUnlock() if exists { - ch <- &baseMessage + ch <- &response c.mu.Lock() delete(c.responses, idKey) c.mu.Unlock() @@ -281,6 +369,96 @@ func (c *Stdio) SendNotification( return nil } +// handleIncomingRequest processes incoming requests from the server. +// It calls the registered request handler and sends the response back to the server. +func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.onRequest + c.requestMu.RUnlock() + + if handler == nil { + // Send error response if no handler is configured + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.METHOD_NOT_FOUND, + Message: "No request handler configured", + }, + } + c.sendResponse(errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking + go func() { + c.ctxMu.RLock() + ctx := c.ctx + c.ctxMu.RUnlock() + + // Check if context is already cancelled before processing + select { + case <-ctx.Done(): + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: ctx.Err().Error(), + }, + } + c.sendResponse(errorResponse) + return + default: + } + + response, err := handler(ctx, request) + + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: err.Error(), + }, + } + c.sendResponse(errorResponse) + return + } + + if response != nil { + c.sendResponse(*response) + } + }() +} + +// sendResponse sends a response back to the server. +func (c *Stdio) sendResponse(response JSONRPCResponse) { + responseBytes, err := json.Marshal(response) + if err != nil { + fmt.Printf("Error marshaling response: %v\n", err) + return + } + responseBytes = append(responseBytes, '\n') + + if _, err := c.stdin.Write(responseBytes); err != nil { + fmt.Printf("Error writing response: %v\n", err) + } +} + // Stderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. func (c *Stdio) Stderr() io.Reader { diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go index 50bde9c288d39e32e00bc5691cbaf75addd740f5..e358751b3344c3783be539cc5daa3e09ffa81020 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go @@ -17,10 +17,24 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) type StreamableHTTPCOption func(*StreamableHTTP) +// WithContinuousListening enables receiving server-to-client notifications when no request is in flight. +// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification), +// you should enable this option. +// +// It will establish a standalone long-live GET HTTP connection to the server. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server +// NOTICE: Even enabled, the server may not support this feature. +func WithContinuousListening() StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.getListeningEnabled = true + } +} + // WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport. func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption { return func(sc *StreamableHTTP) { @@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { } } +func WithLogger(logger util.Logger) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.logger = logger + } +} + +// WithSession creates a client with a pre-configured session +func WithSession(sessionID string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.sessionID.Store(sessionID) + } +} + // StreamableHTTP implements Streamable HTTP transport. // // It transmits JSON-RPC messages over individual HTTP requests. One message per request. @@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { // // The current implementation does not support the following features: // - batching -// - continuously listening for server notifications when no request is in flight -// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) // - resuming stream // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) // - server -> client request type StreamableHTTP struct { - serverURL *url.URL - httpClient *http.Client - headers map[string]string - headerFunc HTTPHeaderFunc + serverURL *url.URL + httpClient *http.Client + headers map[string]string + headerFunc HTTPHeaderFunc + logger util.Logger + getListeningEnabled bool sessionID atomic.Value // string + initialized chan struct{} + initializedOnce sync.Once + notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str } smc := &StreamableHTTP{ - serverURL: parsedURL, - httpClient: &http.Client{}, - headers: make(map[string]string), - closed: make(chan struct{}), + serverURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + logger: util.DefaultLogger(), + initialized: make(chan struct{}), } smc.sessionID.Store("") // set initial value to simplify later usage for _, opt := range options { - opt(smc) + if opt != nil { + opt(smc) + } } // If OAuth is configured, set the base URL for metadata discovery @@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str // Start initiates the HTTP connection to the server. func (c *StreamableHTTP) Start(ctx context.Context) error { - // For Streamable HTTP, we don't need to establish a persistent connection + // For Streamable HTTP, we don't need to establish a persistent connection by default + if c.getListeningEnabled { + go func() { + select { + case <-c.initialized: + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + c.listenForever(ctx) + case <-c.closed: + return + } + }() + } + return nil } @@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error { defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) if err != nil { - fmt.Printf("failed to create close request\n: %v", err) + c.logger.Errorf("failed to create close request: %v", err) return } req.Header.Set(headerKeySessionID, sessionId) res, err := c.httpClient.Do(req) if err != nil { - fmt.Printf("failed to send close request\n: %v", err) + c.logger.Errorf("failed to send close request: %v", err) return } res.Body.Close() @@ -185,77 +232,29 @@ func (c *StreamableHTTP) SendRequest( request JSONRPCRequest, ) (*JSONRPCResponse, error) { - // Create a combined context that could be canceled when the client is closed - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-c.closed: - cancel() - case <-newCtx.Done(): - // The original context was canceled, no need to do anything - } - }() - ctx = newCtx - // Marshal request requestBody, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - sessionID := c.sessionID.Load() - if sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { - return nil, &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return nil, fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) + if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) { + // If the request is initialize, should not return a SessionTerminated error + // It should be a genuine endpoint-routing issue. + // ( Fall through to return StatusCode checking. ) + } else { + return nil, fmt.Errorf("failed to send request: %w", err) + } } defer resp.Body.Close() // Check if we got an error response if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - // handle session closed - if resp.StatusCode == http.StatusNotFound { - c.sessionID.CompareAndSwap(sessionID, "") - return nil, fmt.Errorf("session terminated (404). need to re-initialize") - } // Handle OAuth unauthorized error if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { @@ -279,6 +278,10 @@ func (c *StreamableHTTP) SendRequest( if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { c.sessionID.Store(sessionID) } + + c.initializedOnce.Do(func() { + close(c.initialized) + }) } // Handle different response types @@ -300,16 +303,77 @@ func (c *StreamableHTTP) SendRequest( case "text/event-stream": // Server is using SSE for streaming responses - return c.handleSSEResponse(ctx, resp.Body) + return c.handleSSEResponse(ctx, resp.Body, false) default: return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) } } +func (c *StreamableHTTP) sendHTTP( + ctx context.Context, + method string, + body io.Reader, + acceptType string, +) (resp *http.Response, err error) { + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", acceptType) + sessionID := c.sessionID.Load().(string) + if sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return nil, fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } + + if c.headerFunc != nil { + for k, v := range c.headerFunc(ctx) { + req.Header.Set(k, v) + } + } + + // Send request + resp, err = c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + // universal handling for session terminated + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, ErrSessionTerminated + } + + return resp, nil +} + // handleSSEResponse processes an SSE stream for a specific request. // It returns the final result for the request once received, or an error. -func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { +// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) { // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) @@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl var message JSONRPCResponse if err := json.Unmarshal([]byte(data), &message); err != nil { - fmt.Printf("failed to unmarshal message: %v\n", err) + c.logger.Errorf("failed to unmarshal message: %v", err) return } @@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl if message.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - fmt.Printf("failed to unmarshal notification: %v\n", err) + c.logger.Errorf("failed to unmarshal notification: %v", err) return } c.notifyMu.RLock() @@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } - responseChan <- &message + if !ignoreResponse { + responseChan <- &message + } }) }() @@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand case <-ctx.Done(): return default: - fmt.Printf("SSE stream error: %v\n", err) + c.logger.Errorf("SSE stream error: %v", err) return } } @@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - if sessionID := c.sessionID.Load(); sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if errors.Is(err, ErrOAuthAuthorizationRequired) { - return &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { func (c *StreamableHTTP) IsOAuthEnabled() bool { return c.oauthHandler != nil } + +func (c *StreamableHTTP) listenForever(ctx context.Context) { + c.logger.Infof("listening to server forever") + for { + err := c.createGETConnectionToServer(ctx) + if errors.Is(err, ErrGetMethodNotAllowed) { + // server does not support listening + c.logger.Errorf("server does not support listening") + return + } + + select { + case <-ctx.Done(): + return + default: + } + + if err != nil { + c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) + } + time.Sleep(retryInterval) + } +} + +var ( + ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize") + ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed") + + retryInterval = 1 * time.Second // a variable is convenient for testing +) + +func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { + + resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode == http.StatusMethodNotAllowed { + return ErrGetMethodNotAllowed + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + // handle SSE response + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + return fmt.Errorf("unexpected content type: %s", contentType) + } + + // When ignoreResponse is true, the function will never return expect context is done. + // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response + // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based, + // currently, there is no convenient way to handle this response. + // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs. + _, err = c.handleSSEResponse(ctx, resp.Body, true) + if err != nil { + return fmt.Errorf("failed to handle SSE response: %w", err) + } + + return nil +} + +func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) { + newCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled + cancel() + } + }() + return newCtx, cancel +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go index 5f3524b0212d12a89b14fd1e2f4b2e6ba4dbd806..3e3931b09c9aedfce1f6e58a80be180e107b3116 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption { } } -// Items defines the schema for array items +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. func Items(schema any) PropertyOption { return func(schemaMap map[string]any) { schemaMap["items"] = schema @@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption { schema["uniqueItems"] = unique } } + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. +func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. +func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go index 0091d2e42d380253ee03d0d1b5cde8597775be8f..241b55ce9b549941d764a2ca5b4ba11e551d301d 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/types.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -763,6 +763,11 @@ const ( /* Sampling */ +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + // CreateMessageRequest is a request from the server to sample an LLM via the // client. The client has full discretion over which model to select. The client // should also inform the user before beginning sampling, to allow them to inspect @@ -865,6 +870,22 @@ type AudioContent struct { func (AudioContent) isContent() {} +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. + URI string `json:"uri"` + // The name of the resource. + Name string `json:"name"` + // The description of the resource. + Description string `json:"description"` + // The MIME type of the resource. + MIMEType string `json:"mimeType"` +} + +func (ResourceLink) isContent() {} + // EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. // // It is up to the client how best to render embedded resources for the diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go index 55bef7a997e2a406f111b2fb399812ca1941ab96..3e652efd7e842d24bc6ab13fa119d21f272a8ba7 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent { } } +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: "resource_link", + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + // Helper function to create a new EmbeddedResource func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { return EmbeddedResource{ @@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewAudioContent(data, mimeType), nil + case "resource_link": + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + return NewResourceLink(uri, name, description, mimeType), nil + case "resource": resourceMap := ExtractMap(contentMap, "resource") if resourceMap == nil { diff --git a/vendor/github.com/mark3labs/mcp-go/server/sampling.go b/vendor/github.com/mark3labs/mcp-go/server/sampling.go new file mode 100644 index 0000000000000000000000000000000000000000..b633b24d07ebfeeedd9b49468d7aadb411c87b4c --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sampling.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. +func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go index 746a7d96f6c3635ec05c6bc2d7b92820824a8e20..33ac9bb8854527db09ea31a0be4d109521fa0c37 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests +} + +// samplingResponse represents a response to a sampling request +type samplingResponse struct { + result *mcp.CreateMessageResult + err error } func (s *stdioSession) SessionID() string { @@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// RequestSampling sends a sampling request to the client and waits for the response. +func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *samplingResponse, 1) + s.pendingMu.Lock() + s.pendingRequests[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRequests, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.CreateMessageParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodSamplingCreateMessage), + Params: request.CreateMessageParams, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal sampling request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write sampling request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + +// SetWriter sets the writer for sending requests to the client. +func (s *stdioSession) SetWriter(writer io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + s.writer = writer +} + var ( _ ClientSession = (*stdioSession)(nil) _ SessionWithLogging = (*stdioSession)(nil) _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -224,6 +308,9 @@ func (s *StdioServer) Listen( defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) + // Set the writer for sending requests to the client + stdioSessionInstance.SetWriter(stdout) + // Add in any custom context. if s.contextFunc != nil { ctx = s.contextFunc(ctx) @@ -256,7 +343,29 @@ func (s *StdioServer) processMessage( return s.writeResponse(response, writer) } - // Handle the message using the wrapped server + // Check if this is a response to a sampling request + if s.handleSamplingResponse(rawMessage) { + return nil + } + + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) + var baseMessage struct { + Method string `json:"method"` + } + if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { + // Process tool calls concurrently to avoid blocking on sampling requests + go func() { + response := s.server.HandleMessage(ctx, rawMessage) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + }() + return nil + } + + // Handle other messages synchronously response := s.server.HandleMessage(ctx, rawMessage) // Only write response if there is one (not for notifications) @@ -269,6 +378,65 @@ func (s *StdioServer) processMessage( return nil } +// handleSamplingResponse checks if the message is a response to a sampling request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleSamplingResponse(rawMessage) +} + +// handleSamplingResponse handles incoming sampling responses for this session +func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + idInt64, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Look for a pending request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRequests[idInt64] + s.pendingMu.RUnlock() + + if !exists { + return false + } // Parse and send the response + samplingResp := &samplingResponse{} + + if response.Error != nil { + samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message) + } else { + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.Result, &result); err != nil { + samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) + } else { + samplingResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- samplingResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( diff --git a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go index e9a011fb1c31b46771e3baeffe666ef9a71ef1a1..1312c9753a5ddc2d37a2d3c9f6266cc80d517e2e 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go +++ b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go @@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption { // to StatelessSessionIdManager. func WithStateLess(stateLess bool) StreamableHTTPOption { return func(s *StreamableHTTPServer) { - s.sessionIdManager = &StatelessSessionIdManager{} + if stateLess { + s.sessionIdManager = &StatelessSessionIdManager{} + } } } @@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusAccepted) + w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { diff --git a/vendor/modules.txt b/vendor/modules.txt index 33d95285eebb41a1038aa2d95233bbcc96a87151..8cbc2b93ffb7ce6c044bca3f157defbf2db3d00c 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty # github.com/lucasb-eyer/go-colorful v1.2.0 ## explicit; go 1.12 github.com/lucasb-eyer/go-colorful -# github.com/mark3labs/mcp-go v0.32.0 +# github.com/mark3labs/mcp-go v0.33.0 ## explicit; go 1.23 github.com/mark3labs/mcp-go/client github.com/mark3labs/mcp-go/client/transport