Detailed changes
@@ -29,7 +29,7 @@ require (
github.com/fsnotify/fsnotify v1.8.0
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
- github.com/mark3labs/mcp-go v0.32.0
+ github.com/mark3labs/mcp-go v0.33.0
github.com/muesli/termenv v0.16.0
github.com/ncruces/go-sqlite3 v0.25.0
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
@@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
-github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8=
-github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
+github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc=
+github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
@@ -22,6 +22,7 @@ type Client struct {
requestID atomic.Int64
clientCapabilities mcp.ClientCapabilities
serverCapabilities mcp.ServerCapabilities
+ samplingHandler SamplingHandler
}
type ClientOption func(*Client)
@@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
}
}
+// WithSamplingHandler sets the sampling handler for the client.
+// When set, the client will declare sampling capability during initialization.
+func WithSamplingHandler(handler SamplingHandler) ClientOption {
+ return func(c *Client) {
+ c.samplingHandler = handler
+ }
+}
+
+// WithSession assumes a MCP Session has already been initialized
+func WithSession() ClientOption {
+ return func(c *Client) {
+ c.initialized = true
+ }
+}
+
// NewClient creates a new MCP client with the given transport.
// Usage:
//
@@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
handler(notification)
}
})
+
+ // Set up request handler for bidirectional communication (e.g., sampling)
+ if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
+ bidirectional.SetRequestHandler(c.handleIncomingRequest)
+ }
+
return nil
}
@@ -127,6 +149,12 @@ func (c *Client) Initialize(
ctx context.Context,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, error) {
+ // Merge client capabilities with sampling capability if handler is configured
+ capabilities := request.Params.Capabilities
+ if c.samplingHandler != nil {
+ capabilities.Sampling = &struct{}{}
+ }
+
// Ensure we send a params object with all required fields
params := struct {
ProtocolVersion string `json:"protocolVersion"`
@@ -135,7 +163,7 @@ func (c *Client) Initialize(
}{
ProtocolVersion: request.Params.ProtocolVersion,
ClientInfo: request.Params.ClientInfo,
- Capabilities: request.Params.Capabilities, // Will be empty struct if not set
+ Capabilities: capabilities,
}
response, err := c.sendRequest(ctx, "initialize", params)
@@ -398,6 +426,64 @@ func (c *Client) Complete(
return &result, nil
}
+// handleIncomingRequest processes incoming requests from the server.
+// This is the main entry point for server-to-client requests like sampling.
+func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
+ switch request.Method {
+ case string(mcp.MethodSamplingCreateMessage):
+ return c.handleSamplingRequestTransport(ctx, request)
+ default:
+ return nil, fmt.Errorf("unsupported request method: %s", request.Method)
+ }
+}
+
+// handleSamplingRequestTransport handles sampling requests at the transport level.
+func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
+ if c.samplingHandler == nil {
+ return nil, fmt.Errorf("no sampling handler configured")
+ }
+
+ // Parse the request parameters
+ var params mcp.CreateMessageParams
+ if request.Params != nil {
+ paramsBytes, err := json.Marshal(request.Params)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal params: %w", err)
+ }
+ if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal params: %w", err)
+ }
+ }
+
+ // Create the MCP request
+ mcpRequest := mcp.CreateMessageRequest{
+ Request: mcp.Request{
+ Method: string(mcp.MethodSamplingCreateMessage),
+ },
+ CreateMessageParams: params,
+ }
+
+ // Call the sampling handler
+ result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ // Marshal the result
+ resultBytes, err := json.Marshal(result)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal result: %w", err)
+ }
+
+ // Create the transport response
+ response := &transport.JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Result: json.RawMessage(resultBytes),
+ }
+
+ return response, nil
+}
func listByPage[T any](
ctx context.Context,
client *Client,
@@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
return c.clientCapabilities
}
+
+// GetSessionId returns the session ID of the transport.
+// If the transport does not support sessions, it returns an empty string.
+func (c *Client) GetSessionId() string {
+ if c.transport == nil {
+ return ""
+ }
+ return c.transport.GetSessionId()
+}
+
+// IsInitialized returns true if the client has been initialized.
+func (c *Client) IsInitialized() bool {
+ return c.initialized
+}
@@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
if err != nil {
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
}
- return NewClient(trans), nil
+ clientOptions := make([]ClientOption, 0)
+ sessionID := trans.GetSessionId()
+ if sessionID != "" {
+ clientOptions = append(clientOptions, WithSession())
+ }
+ return NewClient(trans, clientOptions...), nil
}
@@ -0,0 +1,20 @@
+package client
+
+import (
+ "context"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// SamplingHandler defines the interface for handling sampling requests from servers.
+// Clients can implement this interface to provide LLM sampling capabilities to servers.
+type SamplingHandler interface {
+ // CreateMessage handles a sampling request from the server and returns the generated message.
+ // The implementation should:
+ // 1. Validate the request parameters
+ // 2. Optionally prompt the user for approval (human-in-the-loop)
+ // 3. Select an appropriate model based on preferences
+ // 4. Generate the response using the selected model
+ // 5. Return the result with model information and stop reason
+ CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
+}
@@ -19,10 +19,26 @@ func NewStdioMCPClient(
env []string,
args ...string,
) (*Client, error) {
+ return NewStdioMCPClientWithOptions(command, env, args)
+}
+
+// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess.
+// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
+// Optional configuration functions can be provided to customize the transport before it starts,
+// such as setting a custom command function.
+//
+// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport.
+// Don't call the Start method manually.
+// This is for backward compatibility.
+func NewStdioMCPClientWithOptions(
+ command string,
+ env []string,
+ args []string,
+ opts ...transport.StdioOption,
+) (*Client, error) {
+ stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...)
- stdioTransport := transport.NewStdio(command, env, args...)
- err := stdioTransport.Start(context.Background())
- if err != nil {
+ if err := stdioTransport.Start(context.Background()); err != nil {
return nil, fmt.Errorf("failed to start stdio transport: %w", err)
}
@@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc
func (*InProcessTransport) Close() error {
return nil
}
+
+func (c *InProcessTransport) GetSessionId() string {
+ return ""
+}
@@ -29,6 +29,22 @@ type Interface interface {
// Close the connection.
Close() error
+
+ // GetSessionId returns the session ID of the transport.
+ GetSessionId() string
+}
+
+// RequestHandler defines a function that handles incoming requests from the server.
+type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error)
+
+// BidirectionalInterface extends Interface to support incoming requests from the server.
+// This is used for features like sampling where the server can send requests to the client.
+type BidirectionalInterface interface {
+ Interface
+
+ // SetRequestHandler sets the handler for incoming requests from the server.
+ // The handler should process the request and return a response.
+ SetRequestHandler(handler RequestHandler)
}
type JSONRPCRequest struct {
@@ -41,10 +57,10 @@ type JSONRPCRequest struct {
type JSONRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID mcp.RequestId `json:"id"`
- Result json.RawMessage `json:"result"`
+ Result json.RawMessage `json:"result,omitempty"`
Error *struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
- } `json:"error"`
+ } `json:"error,omitempty"`
}
@@ -428,6 +428,12 @@ func (c *SSE) Close() error {
return nil
}
+// GetSessionId returns the session ID of the transport.
+// Since SSE does not maintain a session ID, it returns an empty string.
+func (c *SSE) GetSessionId() string {
+ return ""
+}
+
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
if c.endpoint == nil {
@@ -23,6 +23,7 @@ type Stdio struct {
env []string
cmd *exec.Cmd
+ cmdFunc CommandFunc
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
@@ -31,6 +32,28 @@ type Stdio struct {
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
+ onRequest RequestHandler
+ requestMu sync.RWMutex
+ ctx context.Context
+ ctxMu sync.RWMutex
+}
+
+// StdioOption defines a function that configures a Stdio transport instance.
+// Options can be used to customize the behavior of the transport before it starts,
+// such as setting a custom command function.
+type StdioOption func(*Stdio)
+
+// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess.
+// It can be used to apply sandboxing, custom environment control, working directories, etc.
+type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error)
+
+// WithCommandFunc sets a custom command factory function for the stdio transport.
+// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess,
+// allowing control over attributes like environment, working directory, and system-level sandboxing.
+func WithCommandFunc(f CommandFunc) StdioOption {
+ return func(s *Stdio) {
+ s.cmdFunc = f
+ }
}
// NewIO returns a new stdio-based transport using existing input, output, and
@@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
+ ctx: context.Background(),
}
}
@@ -55,20 +79,43 @@ func NewStdio(
env []string,
args ...string,
) *Stdio {
+ return NewStdioWithOptions(command, env, args)
+}
- client := &Stdio{
+// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess.
+// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
+// Returns an error if the subprocess cannot be started or the pipes cannot be created.
+// Optional configuration functions can be provided to customize the transport before it starts,
+// such as setting a custom command factory.
+func NewStdioWithOptions(
+ command string,
+ env []string,
+ args []string,
+ opts ...StdioOption,
+) *Stdio {
+ s := &Stdio{
command: command,
args: args,
env: env,
responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
+ ctx: context.Background(),
+ }
+
+ for _, opt := range opts {
+ opt(s)
}
- return client
+ return s
}
func (c *Stdio) Start(ctx context.Context) error {
+ // Store the context for use in request handling
+ c.ctxMu.Lock()
+ c.ctx = ctx
+ c.ctxMu.Unlock()
+
if err := c.spawnCommand(ctx); err != nil {
return err
}
@@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error {
return nil
}
-// spawnCommand spawns a new process running c.command.
+// spawnCommand spawns a new process running the configured command, args, and env.
+// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
+// otherwise, the default behavior uses exec.CommandContext with the merged environment.
+// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
func (c *Stdio) spawnCommand(ctx context.Context) error {
if c.command == "" {
return nil
}
- cmd := exec.CommandContext(ctx, c.command, c.args...)
-
- mergedEnv := os.Environ()
- mergedEnv = append(mergedEnv, c.env...)
+ var cmd *exec.Cmd
+ var err error
- cmd.Env = mergedEnv
+ // Standard behavior if no command func present.
+ if c.cmdFunc == nil {
+ cmd = exec.CommandContext(ctx, c.command, c.args...)
+ cmd.Env = append(os.Environ(), c.env...)
+ } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil {
+ return err
+ }
stdin, err := cmd.StdinPipe()
if err != nil {
@@ -148,6 +202,12 @@ func (c *Stdio) Close() error {
return nil
}
+// GetSessionId returns the session ID of the transport.
+// Since stdio does not maintain a session ID, it returns an empty string.
+func (c *Stdio) GetSessionId() string {
+ return ""
+}
+
// SetNotificationHandler sets the handler function to be called when a notification is received.
// Only one handler can be set at a time; setting a new one replaces the previous handler.
func (c *Stdio) SetNotificationHandler(
@@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler(
c.onNotification = handler
}
+// SetRequestHandler sets the handler function to be called when a request is received from the server.
+// This enables bidirectional communication for features like sampling.
+func (c *Stdio) SetRequestHandler(handler RequestHandler) {
+ c.requestMu.Lock()
+ defer c.requestMu.Unlock()
+ c.onRequest = handler
+}
+
// readResponses continuously reads and processes responses from the server's stdout.
// It handles both responses to requests and notifications, routing them appropriately.
// Runs until the done channel is closed or an error occurs reading from stdout.
@@ -175,13 +243,18 @@ func (c *Stdio) readResponses() {
return
}
- var baseMessage JSONRPCResponse
+ // First try to parse as a generic message to check for ID field
+ var baseMessage struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID *mcp.RequestId `json:"id,omitempty"`
+ Method string `json:"method,omitempty"`
+ }
if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
continue
}
- // Handle notification
- if baseMessage.ID.IsNil() {
+ // If it has a method but no ID, it's a notification
+ if baseMessage.Method != "" && baseMessage.ID == nil {
var notification mcp.JSONRPCNotification
if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
continue
@@ -194,15 +267,30 @@ func (c *Stdio) readResponses() {
continue
}
+ // If it has a method and an ID, it's an incoming request
+ if baseMessage.Method != "" && baseMessage.ID != nil {
+ var request JSONRPCRequest
+ if err := json.Unmarshal([]byte(line), &request); err == nil {
+ c.handleIncomingRequest(request)
+ continue
+ }
+ }
+
+ // Otherwise, it's a response to our request
+ var response JSONRPCResponse
+ if err := json.Unmarshal([]byte(line), &response); err != nil {
+ continue
+ }
+
// Create string key for map lookup
- idKey := baseMessage.ID.String()
+ idKey := response.ID.String()
c.mu.RLock()
ch, exists := c.responses[idKey]
c.mu.RUnlock()
if exists {
- ch <- &baseMessage
+ ch <- &response
c.mu.Lock()
delete(c.responses, idKey)
c.mu.Unlock()
@@ -281,6 +369,96 @@ func (c *Stdio) SendNotification(
return nil
}
+// handleIncomingRequest processes incoming requests from the server.
+// It calls the registered request handler and sends the response back to the server.
+func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
+ c.requestMu.RLock()
+ handler := c.onRequest
+ c.requestMu.RUnlock()
+
+ if handler == nil {
+ // Send error response if no handler is configured
+ errorResponse := JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Error: &struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Code: mcp.METHOD_NOT_FOUND,
+ Message: "No request handler configured",
+ },
+ }
+ c.sendResponse(errorResponse)
+ return
+ }
+
+ // Handle the request in a goroutine to avoid blocking
+ go func() {
+ c.ctxMu.RLock()
+ ctx := c.ctx
+ c.ctxMu.RUnlock()
+
+ // Check if context is already cancelled before processing
+ select {
+ case <-ctx.Done():
+ errorResponse := JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Error: &struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Code: mcp.INTERNAL_ERROR,
+ Message: ctx.Err().Error(),
+ },
+ }
+ c.sendResponse(errorResponse)
+ return
+ default:
+ }
+
+ response, err := handler(ctx, request)
+
+ if err != nil {
+ errorResponse := JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Error: &struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Code: mcp.INTERNAL_ERROR,
+ Message: err.Error(),
+ },
+ }
+ c.sendResponse(errorResponse)
+ return
+ }
+
+ if response != nil {
+ c.sendResponse(*response)
+ }
+ }()
+}
+
+// sendResponse sends a response back to the server.
+func (c *Stdio) sendResponse(response JSONRPCResponse) {
+ responseBytes, err := json.Marshal(response)
+ if err != nil {
+ fmt.Printf("Error marshaling response: %v\n", err)
+ return
+ }
+ responseBytes = append(responseBytes, '\n')
+
+ if _, err := c.stdin.Write(responseBytes); err != nil {
+ fmt.Printf("Error writing response: %v\n", err)
+ }
+}
+
// Stderr returns a reader for the stderr output of the subprocess.
// This can be used to capture error messages or logs from the subprocess.
func (c *Stdio) Stderr() io.Reader {
@@ -17,10 +17,24 @@ import (
"time"
"github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/util"
)
type StreamableHTTPCOption func(*StreamableHTTP)
+// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
+// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
+// you should enable this option.
+//
+// It will establish a standalone long-live GET HTTP connection to the server.
+// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
+// NOTICE: Even enabled, the server may not support this feature.
+func WithContinuousListening() StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.getListeningEnabled = true
+ }
+}
+
// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
@@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
}
}
+func WithLogger(logger util.Logger) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.logger = logger
+ }
+}
+
+// WithSession creates a client with a pre-configured session
+func WithSession(sessionID string) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.sessionID.Store(sessionID)
+ }
+}
+
// StreamableHTTP implements Streamable HTTP transport.
//
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
@@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
//
// The current implementation does not support the following features:
// - batching
-// - continuously listening for server notifications when no request is in flight
-// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
// - resuming stream
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
// - server -> client request
type StreamableHTTP struct {
- serverURL *url.URL
- httpClient *http.Client
- headers map[string]string
- headerFunc HTTPHeaderFunc
+ serverURL *url.URL
+ httpClient *http.Client
+ headers map[string]string
+ headerFunc HTTPHeaderFunc
+ logger util.Logger
+ getListeningEnabled bool
sessionID atomic.Value // string
+ initialized chan struct{}
+ initializedOnce sync.Once
+
notificationHandler func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
@@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
}
smc := &StreamableHTTP{
- serverURL: parsedURL,
- httpClient: &http.Client{},
- headers: make(map[string]string),
- closed: make(chan struct{}),
+ serverURL: parsedURL,
+ httpClient: &http.Client{},
+ headers: make(map[string]string),
+ closed: make(chan struct{}),
+ logger: util.DefaultLogger(),
+ initialized: make(chan struct{}),
}
smc.sessionID.Store("") // set initial value to simplify later usage
for _, opt := range options {
- opt(smc)
+ if opt != nil {
+ opt(smc)
+ }
}
// If OAuth is configured, set the base URL for metadata discovery
@@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
// Start initiates the HTTP connection to the server.
func (c *StreamableHTTP) Start(ctx context.Context) error {
- // For Streamable HTTP, we don't need to establish a persistent connection
+ // For Streamable HTTP, we don't need to establish a persistent connection by default
+ if c.getListeningEnabled {
+ go func() {
+ select {
+ case <-c.initialized:
+ ctx, cancel := c.contextAwareOfClientClose(ctx)
+ defer cancel()
+ c.listenForever(ctx)
+ case <-c.closed:
+ return
+ }
+ }()
+ }
+
return nil
}
@@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error {
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
if err != nil {
- fmt.Printf("failed to create close request\n: %v", err)
+ c.logger.Errorf("failed to create close request: %v", err)
return
}
req.Header.Set(headerKeySessionID, sessionId)
res, err := c.httpClient.Do(req)
if err != nil {
- fmt.Printf("failed to send close request\n: %v", err)
+ c.logger.Errorf("failed to send close request: %v", err)
return
}
res.Body.Close()
@@ -185,77 +232,29 @@ func (c *StreamableHTTP) SendRequest(
request JSONRPCRequest,
) (*JSONRPCResponse, error) {
- // Create a combined context that could be canceled when the client is closed
- newCtx, cancel := context.WithCancel(ctx)
- defer cancel()
- go func() {
- select {
- case <-c.closed:
- cancel()
- case <-newCtx.Done():
- // The original context was canceled, no need to do anything
- }
- }()
- ctx = newCtx
-
// Marshal request
requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
- // Create HTTP request
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
-
- // Set headers
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json, text/event-stream")
- sessionID := c.sessionID.Load()
- if sessionID != "" {
- req.Header.Set(headerKeySessionID, sessionID.(string))
- }
- for k, v := range c.headers {
- req.Header.Set(k, v)
- }
-
- // Add OAuth authorization if configured
- if c.oauthHandler != nil {
- authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
- if err != nil {
- // If we get an authorization error, return a specific error that can be handled by the client
- if err.Error() == "no valid token available, authorization required" {
- return nil, &OAuthAuthorizationRequiredError{
- Handler: c.oauthHandler,
- }
- }
- return nil, fmt.Errorf("failed to get authorization header: %w", err)
- }
- req.Header.Set("Authorization", authHeader)
- }
-
- if c.headerFunc != nil {
- for k, v := range c.headerFunc(ctx) {
- req.Header.Set(k, v)
- }
- }
+ ctx, cancel := c.contextAwareOfClientClose(ctx)
+ defer cancel()
- // Send request
- resp, err := c.httpClient.Do(req)
+ resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
- return nil, fmt.Errorf("failed to send request: %w", err)
+ if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
+ // If the request is initialize, should not return a SessionTerminated error
+ // It should be a genuine endpoint-routing issue.
+ // ( Fall through to return StatusCode checking. )
+ } else {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
}
defer resp.Body.Close()
// Check if we got an error response
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
- // handle session closed
- if resp.StatusCode == http.StatusNotFound {
- c.sessionID.CompareAndSwap(sessionID, "")
- return nil, fmt.Errorf("session terminated (404). need to re-initialize")
- }
// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
@@ -279,6 +278,10 @@ func (c *StreamableHTTP) SendRequest(
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
c.sessionID.Store(sessionID)
}
+
+ c.initializedOnce.Do(func() {
+ close(c.initialized)
+ })
}
// Handle different response types
@@ -300,16 +303,77 @@ func (c *StreamableHTTP) SendRequest(
case "text/event-stream":
// Server is using SSE for streaming responses
- return c.handleSSEResponse(ctx, resp.Body)
+ return c.handleSSEResponse(ctx, resp.Body, false)
default:
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
}
}
+func (c *StreamableHTTP) sendHTTP(
+ ctx context.Context,
+ method string,
+ body io.Reader,
+ acceptType string,
+) (resp *http.Response, err error) {
+
+ // Create HTTP request
+ req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", acceptType)
+ sessionID := c.sessionID.Load().(string)
+ if sessionID != "" {
+ req.Header.Set(headerKeySessionID, sessionID)
+ }
+ for k, v := range c.headers {
+ req.Header.Set(k, v)
+ }
+
+ // Add OAuth authorization if configured
+ if c.oauthHandler != nil {
+ authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+ if err != nil {
+ // If we get an authorization error, return a specific error that can be handled by the client
+ if err.Error() == "no valid token available, authorization required" {
+ return nil, &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return nil, fmt.Errorf("failed to get authorization header: %w", err)
+ }
+ req.Header.Set("Authorization", authHeader)
+ }
+
+ if c.headerFunc != nil {
+ for k, v := range c.headerFunc(ctx) {
+ req.Header.Set(k, v)
+ }
+ }
+
+ // Send request
+ resp, err = c.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+
+ // universal handling for session terminated
+ if resp.StatusCode == http.StatusNotFound {
+ c.sessionID.CompareAndSwap(sessionID, "")
+ return nil, ErrSessionTerminated
+ }
+
+ return resp, nil
+}
+
// handleSSEResponse processes an SSE stream for a specific request.
// It returns the final result for the request once received, or an error.
-func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
+// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening.
+func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {
// Create a channel for this specific request
responseChan := make(chan *JSONRPCResponse, 1)
@@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
var message JSONRPCResponse
if err := json.Unmarshal([]byte(data), &message); err != nil {
- fmt.Printf("failed to unmarshal message: %v\n", err)
+ c.logger.Errorf("failed to unmarshal message: %v", err)
return
}
@@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
if message.ID.IsNil() {
var notification mcp.JSONRPCNotification
if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
- fmt.Printf("failed to unmarshal notification: %v\n", err)
+ c.logger.Errorf("failed to unmarshal notification: %v", err)
return
}
c.notifyMu.RLock()
@@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
return
}
- responseChan <- &message
+ if !ignoreResponse {
+ responseChan <- &message
+ }
})
}()
@@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
case <-ctx.Done():
return
default:
- fmt.Printf("SSE stream error: %v\n", err)
+ c.logger.Errorf("SSE stream error: %v", err)
return
}
}
@@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
}
// Create HTTP request
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
-
- // Set headers
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json, text/event-stream")
- if sessionID := c.sessionID.Load(); sessionID != "" {
- req.Header.Set(headerKeySessionID, sessionID.(string))
- }
- for k, v := range c.headers {
- req.Header.Set(k, v)
- }
-
- // Add OAuth authorization if configured
- if c.oauthHandler != nil {
- authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
- if err != nil {
- // If we get an authorization error, return a specific error that can be handled by the client
- if errors.Is(err, ErrOAuthAuthorizationRequired) {
- return &OAuthAuthorizationRequiredError{
- Handler: c.oauthHandler,
- }
- }
- return fmt.Errorf("failed to get authorization header: %w", err)
- }
- req.Header.Set("Authorization", authHeader)
- }
-
- if c.headerFunc != nil {
- for k, v := range c.headerFunc(ctx) {
- req.Header.Set(k, v)
- }
- }
+ ctx, cancel := c.contextAwareOfClientClose(ctx)
+ defer cancel()
- // Send request
- resp, err := c.httpClient.Do(req)
+ resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
@@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
func (c *StreamableHTTP) IsOAuthEnabled() bool {
return c.oauthHandler != nil
}
+
+func (c *StreamableHTTP) listenForever(ctx context.Context) {
+ c.logger.Infof("listening to server forever")
+ for {
+ err := c.createGETConnectionToServer(ctx)
+ if errors.Is(err, ErrGetMethodNotAllowed) {
+ // server does not support listening
+ c.logger.Errorf("server does not support listening")
+ return
+ }
+
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ if err != nil {
+ c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
+ }
+ time.Sleep(retryInterval)
+ }
+}
+
+var (
+ ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
+ ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
+
+ retryInterval = 1 * time.Second // a variable is convenient for testing
+)
+
+func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
+
+ resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
+ if err != nil {
+ return fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Check if we got an error response
+ if resp.StatusCode == http.StatusMethodNotAllowed {
+ return ErrGetMethodNotAllowed
+ }
+
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
+ }
+
+ // handle SSE response
+ contentType := resp.Header.Get("Content-Type")
+ if contentType != "text/event-stream" {
+ return fmt.Errorf("unexpected content type: %s", contentType)
+ }
+
+ // When ignoreResponse is true, the function will never return expect context is done.
+ // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response
+ // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based,
+ // currently, there is no convenient way to handle this response.
+ // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs.
+ _, err = c.handleSSEResponse(ctx, resp.Body, true)
+ if err != nil {
+ return fmt.Errorf("failed to handle SSE response: %w", err)
+ }
+
+ return nil
+}
+
+func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
+ newCtx, cancel := context.WithCancel(ctx)
+ go func() {
+ select {
+ case <-c.closed:
+ cancel()
+ case <-newCtx.Done():
+ // The original context was canceled
+ cancel()
+ }
+ }()
+ return newCtx, cancel
+}
@@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption {
}
}
-// Items defines the schema for array items
+// Items defines the schema for array items.
+// Accepts any schema definition for maximum flexibility.
+//
+// Example:
+//
+// Items(map[string]any{
+// "type": "object",
+// "properties": map[string]any{
+// "name": map[string]any{"type": "string"},
+// "age": map[string]any{"type": "number"},
+// },
+// })
+//
+// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead.
func Items(schema any) PropertyOption {
return func(schemaMap map[string]any) {
schemaMap["items"] = schema
@@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption {
schema["uniqueItems"] = unique
}
}
+
+// WithStringItems configures an array's items to be of type string.
+//
+// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+// mcp.WithArray("tags", mcp.WithStringItems())
+// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue")))
+// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50)))
+//
+// Limitations: Only supports simple string arrays. Use Items() for complex objects.
+func WithStringItems(opts ...PropertyOption) PropertyOption {
+ return func(schema map[string]any) {
+ itemSchema := map[string]any{
+ "type": "string",
+ }
+
+ for _, opt := range opts {
+ opt(itemSchema)
+ }
+
+ schema["items"] = itemSchema
+ }
+}
+
+// WithStringEnumItems configures an array's items to be of type string with a specified enum.
+// Example:
+//
+// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"}))
+//
+// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility.
+func WithStringEnumItems(values []string) PropertyOption {
+ return func(schema map[string]any) {
+ schema["items"] = map[string]any{
+ "type": "string",
+ "enum": values,
+ }
+ }
+}
+
+// WithNumberItems configures an array's items to be of type number.
+//
+// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100)))
+// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0)))
+//
+// Limitations: Only supports simple number arrays. Use Items() for complex objects.
+func WithNumberItems(opts ...PropertyOption) PropertyOption {
+ return func(schema map[string]any) {
+ itemSchema := map[string]any{
+ "type": "number",
+ }
+
+ for _, opt := range opts {
+ opt(itemSchema)
+ }
+
+ schema["items"] = itemSchema
+ }
+}
+
+// WithBooleanItems configures an array's items to be of type boolean.
+//
+// Supported options: Description(), DefaultBool()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+// mcp.WithArray("flags", mcp.WithBooleanItems())
+// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions")))
+//
+// Limitations: Only supports simple boolean arrays. Use Items() for complex objects.
+func WithBooleanItems(opts ...PropertyOption) PropertyOption {
+ return func(schema map[string]any) {
+ itemSchema := map[string]any{
+ "type": "boolean",
+ }
+
+ for _, opt := range opts {
+ opt(itemSchema)
+ }
+
+ schema["items"] = itemSchema
+ }
+}
@@ -763,6 +763,11 @@ const (
/* Sampling */
+const (
+ // MethodSamplingCreateMessage allows servers to request LLM completions from clients
+ MethodSamplingCreateMessage MCPMethod = "sampling/createMessage"
+)
+
// CreateMessageRequest is a request from the server to sample an LLM via the
// client. The client has full discretion over which model to select. The client
// should also inform the user before beginning sampling, to allow them to inspect
@@ -865,6 +870,22 @@ type AudioContent struct {
func (AudioContent) isContent() {}
+// ResourceLink represents a link to a resource that the client can access.
+type ResourceLink struct {
+ Annotated
+ Type string `json:"type"` // Must be "resource_link"
+ // The URI of the resource.
+ URI string `json:"uri"`
+ // The name of the resource.
+ Name string `json:"name"`
+ // The description of the resource.
+ Description string `json:"description"`
+ // The MIME type of the resource.
+ MIMEType string `json:"mimeType"`
+}
+
+func (ResourceLink) isContent() {}
+
// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
//
// It is up to the client how best to render embedded resources for the
@@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent {
}
}
+// Helper function to create a new ResourceLink
+func NewResourceLink(uri, name, description, mimeType string) ResourceLink {
+ return ResourceLink{
+ Type: "resource_link",
+ URI: uri,
+ Name: name,
+ Description: description,
+ MIMEType: mimeType,
+ }
+}
+
// Helper function to create a new EmbeddedResource
func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
return EmbeddedResource{
@@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) {
}
return NewAudioContent(data, mimeType), nil
+ case "resource_link":
+ uri := ExtractString(contentMap, "uri")
+ name := ExtractString(contentMap, "name")
+ description := ExtractString(contentMap, "description")
+ mimeType := ExtractString(contentMap, "mimeType")
+ if uri == "" || name == "" {
+ return nil, fmt.Errorf("resource_link uri or name is missing")
+ }
+ return NewResourceLink(uri, name, description, mimeType), nil
+
case "resource":
resourceMap := ExtractMap(contentMap, "resource")
if resourceMap == nil {
@@ -0,0 +1,37 @@
+package server
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// EnableSampling enables sampling capabilities for the server.
+// This allows the server to send sampling requests to clients that support it.
+func (s *MCPServer) EnableSampling() {
+ s.capabilitiesMu.Lock()
+ defer s.capabilitiesMu.Unlock()
+}
+
+// RequestSampling sends a sampling request to the client.
+// The client must have declared sampling capability during initialization.
+func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
+ session := ClientSessionFromContext(ctx)
+ if session == nil {
+ return nil, fmt.Errorf("no active session")
+ }
+
+ // Check if the session supports sampling requests
+ if samplingSession, ok := session.(SessionWithSampling); ok {
+ return samplingSession.RequestSampling(ctx, request)
+ }
+
+ return nil, fmt.Errorf("session does not support sampling")
+}
+
+// SessionWithSampling extends ClientSession to support sampling requests.
+type SessionWithSampling interface {
+ ClientSession
+ RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
+}
@@ -9,6 +9,7 @@ import (
"log"
"os"
"os/signal"
+ "sync"
"sync/atomic"
"syscall"
@@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
// stdioSession is a static client session, since stdio has only one client.
type stdioSession struct {
- notifications chan mcp.JSONRPCNotification
- initialized atomic.Bool
- loggingLevel atomic.Value
- clientInfo atomic.Value // stores session-specific client info
+ notifications chan mcp.JSONRPCNotification
+ initialized atomic.Bool
+ loggingLevel atomic.Value
+ clientInfo atomic.Value // stores session-specific client info
+ writer io.Writer // for sending requests to client
+ requestID atomic.Int64 // for generating unique request IDs
+ mu sync.RWMutex // protects writer
+ pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
+ pendingMu sync.RWMutex // protects pendingRequests
+}
+
+// samplingResponse represents a response to a sampling request
+type samplingResponse struct {
+ result *mcp.CreateMessageResult
+ err error
}
func (s *stdioSession) SessionID() string {
@@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
return level.(mcp.LoggingLevel)
}
+// RequestSampling sends a sampling request to the client and waits for the response.
+func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
+ s.mu.RLock()
+ writer := s.writer
+ s.mu.RUnlock()
+
+ if writer == nil {
+ return nil, fmt.Errorf("no writer available for sending requests")
+ }
+
+ // Generate a unique request ID
+ id := s.requestID.Add(1)
+
+ // Create a response channel for this request
+ responseChan := make(chan *samplingResponse, 1)
+ s.pendingMu.Lock()
+ s.pendingRequests[id] = responseChan
+ s.pendingMu.Unlock()
+
+ // Cleanup function to remove the pending request
+ cleanup := func() {
+ s.pendingMu.Lock()
+ delete(s.pendingRequests, id)
+ s.pendingMu.Unlock()
+ }
+ defer cleanup()
+
+ // Create the JSON-RPC request
+ jsonRPCRequest := struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID int64 `json:"id"`
+ Method string `json:"method"`
+ Params mcp.CreateMessageParams `json:"params"`
+ }{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: id,
+ Method: string(mcp.MethodSamplingCreateMessage),
+ Params: request.CreateMessageParams,
+ }
+
+ // Marshal and send the request
+ requestBytes, err := json.Marshal(jsonRPCRequest)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
+ }
+ requestBytes = append(requestBytes, '\n')
+
+ if _, err := writer.Write(requestBytes); err != nil {
+ return nil, fmt.Errorf("failed to write sampling request: %w", err)
+ }
+
+ // Wait for the response or context cancellation
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case response := <-responseChan:
+ if response.err != nil {
+ return nil, response.err
+ }
+ return response.result, nil
+ }
+}
+
+// SetWriter sets the writer for sending requests to the client.
+func (s *stdioSession) SetWriter(writer io.Writer) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.writer = writer
+}
+
var (
_ ClientSession = (*stdioSession)(nil)
_ SessionWithLogging = (*stdioSession)(nil)
_ SessionWithClientInfo = (*stdioSession)(nil)
+ _ SessionWithSampling = (*stdioSession)(nil)
)
var stdioSessionInstance = stdioSession{
- notifications: make(chan mcp.JSONRPCNotification, 100),
+ notifications: make(chan mcp.JSONRPCNotification, 100),
+ pendingRequests: make(map[int64]chan *samplingResponse),
}
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
@@ -224,6 +308,9 @@ func (s *StdioServer) Listen(
defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
ctx = s.server.WithContext(ctx, &stdioSessionInstance)
+ // Set the writer for sending requests to the client
+ stdioSessionInstance.SetWriter(stdout)
+
// Add in any custom context.
if s.contextFunc != nil {
ctx = s.contextFunc(ctx)
@@ -256,7 +343,29 @@ func (s *StdioServer) processMessage(
return s.writeResponse(response, writer)
}
- // Handle the message using the wrapped server
+ // Check if this is a response to a sampling request
+ if s.handleSamplingResponse(rawMessage) {
+ return nil
+ }
+
+ // Check if this is a tool call that might need sampling (and thus should be processed concurrently)
+ var baseMessage struct {
+ Method string `json:"method"`
+ }
+ if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
+ // Process tool calls concurrently to avoid blocking on sampling requests
+ go func() {
+ response := s.server.HandleMessage(ctx, rawMessage)
+ if response != nil {
+ if err := s.writeResponse(response, writer); err != nil {
+ s.errLogger.Printf("Error writing tool response: %v", err)
+ }
+ }
+ }()
+ return nil
+ }
+
+ // Handle other messages synchronously
response := s.server.HandleMessage(ctx, rawMessage)
// Only write response if there is one (not for notifications)
@@ -269,6 +378,65 @@ func (s *StdioServer) processMessage(
return nil
}
+// handleSamplingResponse checks if the message is a response to a sampling request
+// and routes it to the appropriate pending request channel.
+func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
+ return stdioSessionInstance.handleSamplingResponse(rawMessage)
+}
+
+// handleSamplingResponse handles incoming sampling responses for this session
+func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
+ // Try to parse as a JSON-RPC response
+ var response struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID json.Number `json:"id"`
+ Result json.RawMessage `json:"result,omitempty"`
+ Error *struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ } `json:"error,omitempty"`
+ }
+
+ if err := json.Unmarshal(rawMessage, &response); err != nil {
+ return false
+ }
+ // Parse the ID as int64
+ idInt64, err := response.ID.Int64()
+ if err != nil || (response.Result == nil && response.Error == nil) {
+ return false
+ }
+
+ // Look for a pending request with this ID
+ s.pendingMu.RLock()
+ responseChan, exists := s.pendingRequests[idInt64]
+ s.pendingMu.RUnlock()
+
+ if !exists {
+ return false
+ } // Parse and send the response
+ samplingResp := &samplingResponse{}
+
+ if response.Error != nil {
+ samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
+ } else {
+ var result mcp.CreateMessageResult
+ if err := json.Unmarshal(response.Result, &result); err != nil {
+ samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
+ } else {
+ samplingResp.result = &result
+ }
+ }
+
+ // Send the response (non-blocking)
+ select {
+ case responseChan <- samplingResp:
+ default:
+ // Channel is full or closed, ignore
+ }
+
+ return true
+}
+
// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
// Returns an error if marshaling or writing fails.
func (s *StdioServer) writeResponse(
@@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption {
// to StatelessSessionIdManager.
func WithStateLess(stateLess bool) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
- s.sessionIdManager = &StatelessSessionIdManager{}
+ if stateLess {
+ s.sessionIdManager = &StatelessSessionIdManager{}
+ }
}
}
@@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
- w.WriteHeader(http.StatusAccepted)
+ w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
@@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty
# github.com/lucasb-eyer/go-colorful v1.2.0
## explicit; go 1.12
github.com/lucasb-eyer/go-colorful
-# github.com/mark3labs/mcp-go v0.32.0
+# github.com/mark3labs/mcp-go v0.33.0
## explicit; go 1.23
github.com/mark3labs/mcp-go/client
github.com/mark3labs/mcp-go/client/transport