stdio.go

  1package transport
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"sync"
 12
 13	"github.com/mark3labs/mcp-go/mcp"
 14)
 15
 16// Stdio implements the transport layer of the MCP protocol using stdio communication.
 17// It launches a subprocess and communicates with it via standard input/output streams
 18// using JSON-RPC messages. The client handles message routing between requests and
 19// responses, and supports asynchronous notifications.
 20type Stdio struct {
 21	command string
 22	args    []string
 23	env     []string
 24
 25	cmd            *exec.Cmd
 26	cmdFunc        CommandFunc
 27	stdin          io.WriteCloser
 28	stdout         *bufio.Reader
 29	stderr         io.ReadCloser
 30	responses      map[string]chan *JSONRPCResponse
 31	mu             sync.RWMutex
 32	done           chan struct{}
 33	onNotification func(mcp.JSONRPCNotification)
 34	notifyMu       sync.RWMutex
 35	onRequest      RequestHandler
 36	requestMu      sync.RWMutex
 37	ctx            context.Context
 38	ctxMu          sync.RWMutex
 39}
 40
 41// StdioOption defines a function that configures a Stdio transport instance.
 42// Options can be used to customize the behavior of the transport before it starts,
 43// such as setting a custom command function.
 44type StdioOption func(*Stdio)
 45
 46// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess.
 47// It can be used to apply sandboxing, custom environment control, working directories, etc.
 48type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error)
 49
 50// WithCommandFunc sets a custom command factory function for the stdio transport.
 51// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess,
 52// allowing control over attributes like environment, working directory, and system-level sandboxing.
 53func WithCommandFunc(f CommandFunc) StdioOption {
 54	return func(s *Stdio) {
 55		s.cmdFunc = f
 56	}
 57}
 58
 59// NewIO returns a new stdio-based transport using existing input, output, and
 60// logging streams instead of spawning a subprocess.
 61// This is useful for testing and simulating client behavior.
 62func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio {
 63	return &Stdio{
 64		stdin:  output,
 65		stdout: bufio.NewReader(input),
 66		stderr: logging,
 67
 68		responses: make(map[string]chan *JSONRPCResponse),
 69		done:      make(chan struct{}),
 70		ctx:       context.Background(),
 71	}
 72}
 73
 74// NewStdio creates a new stdio transport to communicate with a subprocess.
 75// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
 76// Returns an error if the subprocess cannot be started or the pipes cannot be created.
 77func NewStdio(
 78	command string,
 79	env []string,
 80	args ...string,
 81) *Stdio {
 82	return NewStdioWithOptions(command, env, args)
 83}
 84
 85// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess.
 86// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
 87// Returns an error if the subprocess cannot be started or the pipes cannot be created.
 88// Optional configuration functions can be provided to customize the transport before it starts,
 89// such as setting a custom command factory.
 90func NewStdioWithOptions(
 91	command string,
 92	env []string,
 93	args []string,
 94	opts ...StdioOption,
 95) *Stdio {
 96	s := &Stdio{
 97		command: command,
 98		args:    args,
 99		env:     env,
100
101		responses: make(map[string]chan *JSONRPCResponse),
102		done:      make(chan struct{}),
103		ctx:       context.Background(),
104	}
105
106	for _, opt := range opts {
107		opt(s)
108	}
109
110	return s
111}
112
113func (c *Stdio) Start(ctx context.Context) error {
114	// Store the context for use in request handling
115	c.ctxMu.Lock()
116	c.ctx = ctx
117	c.ctxMu.Unlock()
118
119	if err := c.spawnCommand(ctx); err != nil {
120		return err
121	}
122
123	ready := make(chan struct{})
124	go func() {
125		close(ready)
126		c.readResponses()
127	}()
128	<-ready
129
130	return nil
131}
132
133// spawnCommand spawns a new process running the configured command, args, and env.
134// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
135// otherwise, the default behavior uses exec.CommandContext with the merged environment.
136// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
137func (c *Stdio) spawnCommand(ctx context.Context) error {
138	if c.command == "" {
139		return nil
140	}
141
142	var cmd *exec.Cmd
143	var err error
144
145	// Standard behavior if no command func present.
146	if c.cmdFunc == nil {
147		cmd = exec.CommandContext(ctx, c.command, c.args...)
148		cmd.Env = append(os.Environ(), c.env...)
149	} else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil {
150		return err
151	}
152
153	stdin, err := cmd.StdinPipe()
154	if err != nil {
155		return fmt.Errorf("failed to create stdin pipe: %w", err)
156	}
157
158	stdout, err := cmd.StdoutPipe()
159	if err != nil {
160		return fmt.Errorf("failed to create stdout pipe: %w", err)
161	}
162
163	stderr, err := cmd.StderrPipe()
164	if err != nil {
165		return fmt.Errorf("failed to create stderr pipe: %w", err)
166	}
167
168	c.cmd = cmd
169	c.stdin = stdin
170	c.stderr = stderr
171	c.stdout = bufio.NewReader(stdout)
172
173	if err := cmd.Start(); err != nil {
174		return fmt.Errorf("failed to start command: %w", err)
175	}
176
177	return nil
178}
179
180// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
181// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
182func (c *Stdio) Close() error {
183	select {
184	case <-c.done:
185		return nil
186	default:
187	}
188	// cancel all in-flight request
189	close(c.done)
190
191	if err := c.stdin.Close(); err != nil {
192		return fmt.Errorf("failed to close stdin: %w", err)
193	}
194	if err := c.stderr.Close(); err != nil {
195		return fmt.Errorf("failed to close stderr: %w", err)
196	}
197
198	if c.cmd != nil {
199		return c.cmd.Wait()
200	}
201
202	return nil
203}
204
205// GetSessionId returns the session ID of the transport.
206// Since stdio does not maintain a session ID, it returns an empty string.
207func (c *Stdio) GetSessionId() string {
208	return ""
209}
210
211// SetNotificationHandler sets the handler function to be called when a notification is received.
212// Only one handler can be set at a time; setting a new one replaces the previous handler.
213func (c *Stdio) SetNotificationHandler(
214	handler func(notification mcp.JSONRPCNotification),
215) {
216	c.notifyMu.Lock()
217	defer c.notifyMu.Unlock()
218	c.onNotification = handler
219}
220
221// SetRequestHandler sets the handler function to be called when a request is received from the server.
222// This enables bidirectional communication for features like sampling.
223func (c *Stdio) SetRequestHandler(handler RequestHandler) {
224	c.requestMu.Lock()
225	defer c.requestMu.Unlock()
226	c.onRequest = handler
227}
228
229// readResponses continuously reads and processes responses from the server's stdout.
230// It handles both responses to requests and notifications, routing them appropriately.
231// Runs until the done channel is closed or an error occurs reading from stdout.
232func (c *Stdio) readResponses() {
233	for {
234		select {
235		case <-c.done:
236			return
237		default:
238			line, err := c.stdout.ReadString('\n')
239			if err != nil {
240				if err != io.EOF {
241					fmt.Printf("Error reading response: %v\n", err)
242				}
243				return
244			}
245
246			// First try to parse as a generic message to check for ID field
247			var baseMessage struct {
248				JSONRPC string         `json:"jsonrpc"`
249				ID      *mcp.RequestId `json:"id,omitempty"`
250				Method  string         `json:"method,omitempty"`
251			}
252			if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
253				continue
254			}
255
256			// If it has a method but no ID, it's a notification
257			if baseMessage.Method != "" && baseMessage.ID == nil {
258				var notification mcp.JSONRPCNotification
259				if err := json.Unmarshal([]byte(line), &notification); err != nil {
260					continue
261				}
262				c.notifyMu.RLock()
263				if c.onNotification != nil {
264					c.onNotification(notification)
265				}
266				c.notifyMu.RUnlock()
267				continue
268			}
269
270			// If it has a method and an ID, it's an incoming request
271			if baseMessage.Method != "" && baseMessage.ID != nil {
272				var request JSONRPCRequest
273				if err := json.Unmarshal([]byte(line), &request); err == nil {
274					c.handleIncomingRequest(request)
275					continue
276				}
277			}
278
279			// Otherwise, it's a response to our request
280			var response JSONRPCResponse
281			if err := json.Unmarshal([]byte(line), &response); err != nil {
282				continue
283			}
284
285			// Create string key for map lookup
286			idKey := response.ID.String()
287
288			c.mu.RLock()
289			ch, exists := c.responses[idKey]
290			c.mu.RUnlock()
291
292			if exists {
293				ch <- &response
294				c.mu.Lock()
295				delete(c.responses, idKey)
296				c.mu.Unlock()
297			}
298		}
299	}
300}
301
302// SendRequest sends a JSON-RPC request to the server and waits for a response.
303// It creates a unique request ID, sends the request over stdin, and waits for
304// the corresponding response or context cancellation.
305// Returns the raw JSON response message or an error if the request fails.
306func (c *Stdio) SendRequest(
307	ctx context.Context,
308	request JSONRPCRequest,
309) (*JSONRPCResponse, error) {
310	if c.stdin == nil {
311		return nil, fmt.Errorf("stdio client not started")
312	}
313
314	// Marshal request
315	requestBytes, err := json.Marshal(request)
316	if err != nil {
317		return nil, fmt.Errorf("failed to marshal request: %w", err)
318	}
319	requestBytes = append(requestBytes, '\n')
320
321	// Create string key for map lookup
322	idKey := request.ID.String()
323
324	// Register response channel
325	responseChan := make(chan *JSONRPCResponse, 1)
326	c.mu.Lock()
327	c.responses[idKey] = responseChan
328	c.mu.Unlock()
329	deleteResponseChan := func() {
330		c.mu.Lock()
331		delete(c.responses, idKey)
332		c.mu.Unlock()
333	}
334
335	// Send request
336	if _, err := c.stdin.Write(requestBytes); err != nil {
337		deleteResponseChan()
338		return nil, fmt.Errorf("failed to write request: %w", err)
339	}
340
341	select {
342	case <-ctx.Done():
343		deleteResponseChan()
344		return nil, ctx.Err()
345	case response := <-responseChan:
346		return response, nil
347	}
348}
349
350// SendNotification sends a json RPC Notification to the server.
351func (c *Stdio) SendNotification(
352	ctx context.Context,
353	notification mcp.JSONRPCNotification,
354) error {
355	if c.stdin == nil {
356		return fmt.Errorf("stdio client not started")
357	}
358
359	notificationBytes, err := json.Marshal(notification)
360	if err != nil {
361		return fmt.Errorf("failed to marshal notification: %w", err)
362	}
363	notificationBytes = append(notificationBytes, '\n')
364
365	if _, err := c.stdin.Write(notificationBytes); err != nil {
366		return fmt.Errorf("failed to write notification: %w", err)
367	}
368
369	return nil
370}
371
372// handleIncomingRequest processes incoming requests from the server.
373// It calls the registered request handler and sends the response back to the server.
374func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
375	c.requestMu.RLock()
376	handler := c.onRequest
377	c.requestMu.RUnlock()
378
379	if handler == nil {
380		// Send error response if no handler is configured
381		errorResponse := JSONRPCResponse{
382			JSONRPC: mcp.JSONRPC_VERSION,
383			ID:      request.ID,
384			Error: &struct {
385				Code    int             `json:"code"`
386				Message string          `json:"message"`
387				Data    json.RawMessage `json:"data"`
388			}{
389				Code:    mcp.METHOD_NOT_FOUND,
390				Message: "No request handler configured",
391			},
392		}
393		c.sendResponse(errorResponse)
394		return
395	}
396
397	// Handle the request in a goroutine to avoid blocking
398	go func() {
399		c.ctxMu.RLock()
400		ctx := c.ctx
401		c.ctxMu.RUnlock()
402
403		// Check if context is already cancelled before processing
404		select {
405		case <-ctx.Done():
406			errorResponse := JSONRPCResponse{
407				JSONRPC: mcp.JSONRPC_VERSION,
408				ID:      request.ID,
409				Error: &struct {
410					Code    int             `json:"code"`
411					Message string          `json:"message"`
412					Data    json.RawMessage `json:"data"`
413				}{
414					Code:    mcp.INTERNAL_ERROR,
415					Message: ctx.Err().Error(),
416				},
417			}
418			c.sendResponse(errorResponse)
419			return
420		default:
421		}
422
423		response, err := handler(ctx, request)
424
425		if err != nil {
426			errorResponse := JSONRPCResponse{
427				JSONRPC: mcp.JSONRPC_VERSION,
428				ID:      request.ID,
429				Error: &struct {
430					Code    int             `json:"code"`
431					Message string          `json:"message"`
432					Data    json.RawMessage `json:"data"`
433				}{
434					Code:    mcp.INTERNAL_ERROR,
435					Message: err.Error(),
436				},
437			}
438			c.sendResponse(errorResponse)
439			return
440		}
441
442		if response != nil {
443			c.sendResponse(*response)
444		}
445	}()
446}
447
448// sendResponse sends a response back to the server.
449func (c *Stdio) sendResponse(response JSONRPCResponse) {
450	responseBytes, err := json.Marshal(response)
451	if err != nil {
452		fmt.Printf("Error marshaling response: %v\n", err)
453		return
454	}
455	responseBytes = append(responseBytes, '\n')
456
457	if _, err := c.stdin.Write(responseBytes); err != nil {
458		fmt.Printf("Error writing response: %v\n", err)
459	}
460}
461
462// Stderr returns a reader for the stderr output of the subprocess.
463// This can be used to capture error messages or logs from the subprocess.
464func (c *Stdio) Stderr() io.Reader {
465	return c.stderr
466}