1package transport
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "os"
10 "os/exec"
11 "sync"
12
13 "github.com/mark3labs/mcp-go/mcp"
14)
15
// Stdio implements the transport layer of the MCP protocol using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using newline-delimited JSON-RPC messages. The client handles message routing between
// requests and responses, and supports asynchronous notifications as well as requests
// initiated by the server.
type Stdio struct {
	// command, args and env describe the subprocess to launch. An empty command
	// means no subprocess is spawned and pre-supplied streams are used instead.
	command string
	args    []string
	env     []string

	// cmd is the spawned subprocess; it stays nil when no command is configured.
	cmd *exec.Cmd
	// cmdFunc is an optional factory used instead of exec.CommandContext.
	cmdFunc CommandFunc
	// stdin is where outgoing JSON-RPC messages are written.
	stdin io.WriteCloser
	// stdout is where incoming JSON-RPC messages are read from, one per line.
	stdout *bufio.Reader
	// stderr is exposed to callers through Stderr.
	stderr io.ReadCloser
	// responses maps the string form of a request ID to the channel on which
	// the matching response is delivered. Guarded by mu.
	responses map[string]chan *JSONRPCResponse
	mu        sync.RWMutex
	// done is closed by Close; the read loop checks it between messages.
	done chan struct{}
	// onNotification is the notification callback. Guarded by notifyMu.
	onNotification func(mcp.JSONRPCNotification)
	notifyMu       sync.RWMutex
	// onRequest handles requests initiated by the server. Guarded by requestMu.
	onRequest RequestHandler
	requestMu sync.RWMutex
	// ctx is the context given to Start; it is passed to onRequest. Guarded by ctxMu.
	ctx   context.Context
	ctxMu sync.RWMutex
}
40
// StdioOption is a functional option that configures a Stdio transport instance.
// Options are applied by NewStdioWithOptions, in the order given, after the
// transport's defaults have been populated and before it is returned, so they can
// customize behavior (such as the command factory) before the transport starts.
type StdioOption func(*Stdio)
45
// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess.
// It receives the context passed to Start together with the configured command,
// environment entries, and arguments. When a CommandFunc is set, the transport does
// not merge env with os.Environ() itself; the factory decides how env is applied.
// The transport then creates the stdin, stdout, and stderr pipes on the returned
// command and starts it. A non-nil error aborts spawning and is returned unchanged.
// It can be used to apply sandboxing, custom environment control, working directories, etc.
type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error)
49
50// WithCommandFunc sets a custom command factory function for the stdio transport.
51// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess,
52// allowing control over attributes like environment, working directory, and system-level sandboxing.
53func WithCommandFunc(f CommandFunc) StdioOption {
54 return func(s *Stdio) {
55 s.cmdFunc = f
56 }
57}
58
59// NewIO returns a new stdio-based transport using existing input, output, and
60// logging streams instead of spawning a subprocess.
61// This is useful for testing and simulating client behavior.
62func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio {
63 return &Stdio{
64 stdin: output,
65 stdout: bufio.NewReader(input),
66 stderr: logging,
67
68 responses: make(map[string]chan *JSONRPCResponse),
69 done: make(chan struct{}),
70 ctx: context.Background(),
71 }
72}
73
74// NewStdio creates a new stdio transport to communicate with a subprocess.
75// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
76// Returns an error if the subprocess cannot be started or the pipes cannot be created.
77func NewStdio(
78 command string,
79 env []string,
80 args ...string,
81) *Stdio {
82 return NewStdioWithOptions(command, env, args)
83}
84
85// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess.
86// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
87// Returns an error if the subprocess cannot be started or the pipes cannot be created.
88// Optional configuration functions can be provided to customize the transport before it starts,
89// such as setting a custom command factory.
90func NewStdioWithOptions(
91 command string,
92 env []string,
93 args []string,
94 opts ...StdioOption,
95) *Stdio {
96 s := &Stdio{
97 command: command,
98 args: args,
99 env: env,
100
101 responses: make(map[string]chan *JSONRPCResponse),
102 done: make(chan struct{}),
103 ctx: context.Background(),
104 }
105
106 for _, opt := range opts {
107 opt(s)
108 }
109
110 return s
111}
112
113func (c *Stdio) Start(ctx context.Context) error {
114 // Store the context for use in request handling
115 c.ctxMu.Lock()
116 c.ctx = ctx
117 c.ctxMu.Unlock()
118
119 if err := c.spawnCommand(ctx); err != nil {
120 return err
121 }
122
123 ready := make(chan struct{})
124 go func() {
125 close(ready)
126 c.readResponses()
127 }()
128 <-ready
129
130 return nil
131}
132
133// spawnCommand spawns a new process running the configured command, args, and env.
134// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
135// otherwise, the default behavior uses exec.CommandContext with the merged environment.
136// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
137func (c *Stdio) spawnCommand(ctx context.Context) error {
138 if c.command == "" {
139 return nil
140 }
141
142 var cmd *exec.Cmd
143 var err error
144
145 // Standard behavior if no command func present.
146 if c.cmdFunc == nil {
147 cmd = exec.CommandContext(ctx, c.command, c.args...)
148 cmd.Env = append(os.Environ(), c.env...)
149 } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil {
150 return err
151 }
152
153 stdin, err := cmd.StdinPipe()
154 if err != nil {
155 return fmt.Errorf("failed to create stdin pipe: %w", err)
156 }
157
158 stdout, err := cmd.StdoutPipe()
159 if err != nil {
160 return fmt.Errorf("failed to create stdout pipe: %w", err)
161 }
162
163 stderr, err := cmd.StderrPipe()
164 if err != nil {
165 return fmt.Errorf("failed to create stderr pipe: %w", err)
166 }
167
168 c.cmd = cmd
169 c.stdin = stdin
170 c.stderr = stderr
171 c.stdout = bufio.NewReader(stdout)
172
173 if err := cmd.Start(); err != nil {
174 return fmt.Errorf("failed to start command: %w", err)
175 }
176
177 return nil
178}
179
180// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
181// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
182func (c *Stdio) Close() error {
183 select {
184 case <-c.done:
185 return nil
186 default:
187 }
188 // cancel all in-flight request
189 close(c.done)
190
191 if err := c.stdin.Close(); err != nil {
192 return fmt.Errorf("failed to close stdin: %w", err)
193 }
194 if err := c.stderr.Close(); err != nil {
195 return fmt.Errorf("failed to close stderr: %w", err)
196 }
197
198 if c.cmd != nil {
199 return c.cmd.Wait()
200 }
201
202 return nil
203}
204
// GetSessionId returns the session ID of the transport.
// The stdio transport does not maintain a session ID, so this always returns
// the empty string.
func (c *Stdio) GetSessionId() string {
	return ""
}
210
211// SetNotificationHandler sets the handler function to be called when a notification is received.
212// Only one handler can be set at a time; setting a new one replaces the previous handler.
213func (c *Stdio) SetNotificationHandler(
214 handler func(notification mcp.JSONRPCNotification),
215) {
216 c.notifyMu.Lock()
217 defer c.notifyMu.Unlock()
218 c.onNotification = handler
219}
220
221// SetRequestHandler sets the handler function to be called when a request is received from the server.
222// This enables bidirectional communication for features like sampling.
223func (c *Stdio) SetRequestHandler(handler RequestHandler) {
224 c.requestMu.Lock()
225 defer c.requestMu.Unlock()
226 c.onRequest = handler
227}
228
229// readResponses continuously reads and processes responses from the server's stdout.
230// It handles both responses to requests and notifications, routing them appropriately.
231// Runs until the done channel is closed or an error occurs reading from stdout.
232func (c *Stdio) readResponses() {
233 for {
234 select {
235 case <-c.done:
236 return
237 default:
238 line, err := c.stdout.ReadString('\n')
239 if err != nil {
240 if err != io.EOF {
241 fmt.Printf("Error reading response: %v\n", err)
242 }
243 return
244 }
245
246 // First try to parse as a generic message to check for ID field
247 var baseMessage struct {
248 JSONRPC string `json:"jsonrpc"`
249 ID *mcp.RequestId `json:"id,omitempty"`
250 Method string `json:"method,omitempty"`
251 }
252 if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
253 continue
254 }
255
256 // If it has a method but no ID, it's a notification
257 if baseMessage.Method != "" && baseMessage.ID == nil {
258 var notification mcp.JSONRPCNotification
259 if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
260 continue
261 }
262 c.notifyMu.RLock()
263 if c.onNotification != nil {
264 c.onNotification(notification)
265 }
266 c.notifyMu.RUnlock()
267 continue
268 }
269
270 // If it has a method and an ID, it's an incoming request
271 if baseMessage.Method != "" && baseMessage.ID != nil {
272 var request JSONRPCRequest
273 if err := json.Unmarshal([]byte(line), &request); err == nil {
274 c.handleIncomingRequest(request)
275 continue
276 }
277 }
278
279 // Otherwise, it's a response to our request
280 var response JSONRPCResponse
281 if err := json.Unmarshal([]byte(line), &response); err != nil {
282 continue
283 }
284
285 // Create string key for map lookup
286 idKey := response.ID.String()
287
288 c.mu.RLock()
289 ch, exists := c.responses[idKey]
290 c.mu.RUnlock()
291
292 if exists {
293 ch <- &response
294 c.mu.Lock()
295 delete(c.responses, idKey)
296 c.mu.Unlock()
297 }
298 }
299 }
300}
301
302// SendRequest sends a JSON-RPC request to the server and waits for a response.
303// It creates a unique request ID, sends the request over stdin, and waits for
304// the corresponding response or context cancellation.
305// Returns the raw JSON response message or an error if the request fails.
306func (c *Stdio) SendRequest(
307 ctx context.Context,
308 request JSONRPCRequest,
309) (*JSONRPCResponse, error) {
310 if c.stdin == nil {
311 return nil, fmt.Errorf("stdio client not started")
312 }
313
314 // Marshal request
315 requestBytes, err := json.Marshal(request)
316 if err != nil {
317 return nil, fmt.Errorf("failed to marshal request: %w", err)
318 }
319 requestBytes = append(requestBytes, '\n')
320
321 // Create string key for map lookup
322 idKey := request.ID.String()
323
324 // Register response channel
325 responseChan := make(chan *JSONRPCResponse, 1)
326 c.mu.Lock()
327 c.responses[idKey] = responseChan
328 c.mu.Unlock()
329 deleteResponseChan := func() {
330 c.mu.Lock()
331 delete(c.responses, idKey)
332 c.mu.Unlock()
333 }
334
335 // Send request
336 if _, err := c.stdin.Write(requestBytes); err != nil {
337 deleteResponseChan()
338 return nil, fmt.Errorf("failed to write request: %w", err)
339 }
340
341 select {
342 case <-ctx.Done():
343 deleteResponseChan()
344 return nil, ctx.Err()
345 case response := <-responseChan:
346 return response, nil
347 }
348}
349
350// SendNotification sends a json RPC Notification to the server.
351func (c *Stdio) SendNotification(
352 ctx context.Context,
353 notification mcp.JSONRPCNotification,
354) error {
355 if c.stdin == nil {
356 return fmt.Errorf("stdio client not started")
357 }
358
359 notificationBytes, err := json.Marshal(notification)
360 if err != nil {
361 return fmt.Errorf("failed to marshal notification: %w", err)
362 }
363 notificationBytes = append(notificationBytes, '\n')
364
365 if _, err := c.stdin.Write(notificationBytes); err != nil {
366 return fmt.Errorf("failed to write notification: %w", err)
367 }
368
369 return nil
370}
371
372// handleIncomingRequest processes incoming requests from the server.
373// It calls the registered request handler and sends the response back to the server.
374func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
375 c.requestMu.RLock()
376 handler := c.onRequest
377 c.requestMu.RUnlock()
378
379 if handler == nil {
380 // Send error response if no handler is configured
381 errorResponse := JSONRPCResponse{
382 JSONRPC: mcp.JSONRPC_VERSION,
383 ID: request.ID,
384 Error: &struct {
385 Code int `json:"code"`
386 Message string `json:"message"`
387 Data json.RawMessage `json:"data"`
388 }{
389 Code: mcp.METHOD_NOT_FOUND,
390 Message: "No request handler configured",
391 },
392 }
393 c.sendResponse(errorResponse)
394 return
395 }
396
397 // Handle the request in a goroutine to avoid blocking
398 go func() {
399 c.ctxMu.RLock()
400 ctx := c.ctx
401 c.ctxMu.RUnlock()
402
403 // Check if context is already cancelled before processing
404 select {
405 case <-ctx.Done():
406 errorResponse := JSONRPCResponse{
407 JSONRPC: mcp.JSONRPC_VERSION,
408 ID: request.ID,
409 Error: &struct {
410 Code int `json:"code"`
411 Message string `json:"message"`
412 Data json.RawMessage `json:"data"`
413 }{
414 Code: mcp.INTERNAL_ERROR,
415 Message: ctx.Err().Error(),
416 },
417 }
418 c.sendResponse(errorResponse)
419 return
420 default:
421 }
422
423 response, err := handler(ctx, request)
424
425 if err != nil {
426 errorResponse := JSONRPCResponse{
427 JSONRPC: mcp.JSONRPC_VERSION,
428 ID: request.ID,
429 Error: &struct {
430 Code int `json:"code"`
431 Message string `json:"message"`
432 Data json.RawMessage `json:"data"`
433 }{
434 Code: mcp.INTERNAL_ERROR,
435 Message: err.Error(),
436 },
437 }
438 c.sendResponse(errorResponse)
439 return
440 }
441
442 if response != nil {
443 c.sendResponse(*response)
444 }
445 }()
446}
447
448// sendResponse sends a response back to the server.
449func (c *Stdio) sendResponse(response JSONRPCResponse) {
450 responseBytes, err := json.Marshal(response)
451 if err != nil {
452 fmt.Printf("Error marshaling response: %v\n", err)
453 return
454 }
455 responseBytes = append(responseBytes, '\n')
456
457 if _, err := c.stdin.Write(responseBytes); err != nil {
458 fmt.Printf("Error writing response: %v\n", err)
459 }
460}
461
// Stderr returns a reader for the stderr output of the subprocess.
// This can be used to capture error messages or logs from the subprocess.
// For a transport created with NewIO it is the logging stream that was supplied.
// The result is nil if no stderr stream has been set on the transport.
func (c *Stdio) Stderr() io.Reader {
	return c.stderr
}