stdio.go

  1package client
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"os"
 11	"os/exec"
 12	"sync"
 13	"sync/atomic"
 14
 15	"github.com/mark3labs/mcp-go/mcp"
 16)
 17
 18// StdioMCPClient implements the MCPClient interface using stdio communication.
 19// It launches a subprocess and communicates with it via standard input/output streams
 20// using JSON-RPC messages. The client handles message routing between requests and
 21// responses, and supports asynchronous notifications.
 22type StdioMCPClient struct {
 23	cmd           *exec.Cmd
 24	stdin         io.WriteCloser
 25	stdout        *bufio.Reader
 26	requestID     atomic.Int64
 27	responses     map[int64]chan RPCResponse
 28	mu            sync.RWMutex
 29	done          chan struct{}
 30	initialized   bool
 31	notifications []func(mcp.JSONRPCNotification)
 32	notifyMu      sync.RWMutex
 33	capabilities  mcp.ServerCapabilities
 34}
 35
 36// NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess.
 37// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
 38// Returns an error if the subprocess cannot be started or the pipes cannot be created.
 39func NewStdioMCPClient(
 40	command string,
 41	env []string,
 42	args ...string,
 43) (*StdioMCPClient, error) {
 44	cmd := exec.Command(command, args...)
 45
 46	mergedEnv := os.Environ()
 47	mergedEnv = append(mergedEnv, env...)
 48
 49	cmd.Env = mergedEnv
 50
 51	stdin, err := cmd.StdinPipe()
 52	if err != nil {
 53		return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
 54	}
 55
 56	stdout, err := cmd.StdoutPipe()
 57	if err != nil {
 58		return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
 59	}
 60
 61	client := &StdioMCPClient{
 62		cmd:       cmd,
 63		stdin:     stdin,
 64		stdout:    bufio.NewReader(stdout),
 65		responses: make(map[int64]chan RPCResponse),
 66		done:      make(chan struct{}),
 67	}
 68
 69	if err := cmd.Start(); err != nil {
 70		return nil, fmt.Errorf("failed to start command: %w", err)
 71	}
 72
 73	// Start reading responses in a goroutine and wait for it to be ready
 74	ready := make(chan struct{})
 75	go func() {
 76		close(ready)
 77		client.readResponses()
 78	}()
 79	<-ready
 80
 81	return client, nil
 82}
 83
 84// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
 85// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
 86func (c *StdioMCPClient) Close() error {
 87	close(c.done)
 88	if err := c.stdin.Close(); err != nil {
 89		return fmt.Errorf("failed to close stdin: %w", err)
 90	}
 91	return c.cmd.Wait()
 92}
 93
 94// OnNotification registers a handler function to be called when notifications are received.
 95// Multiple handlers can be registered and will be called in the order they were added.
 96func (c *StdioMCPClient) OnNotification(
 97	handler func(notification mcp.JSONRPCNotification),
 98) {
 99	c.notifyMu.Lock()
100	defer c.notifyMu.Unlock()
101	c.notifications = append(c.notifications, handler)
102}
103
104// readResponses continuously reads and processes responses from the server's stdout.
105// It handles both responses to requests and notifications, routing them appropriately.
106// Runs until the done channel is closed or an error occurs reading from stdout.
107func (c *StdioMCPClient) readResponses() {
108	for {
109		select {
110		case <-c.done:
111			return
112		default:
113			line, err := c.stdout.ReadString('\n')
114			if err != nil {
115				if err != io.EOF {
116					fmt.Printf("Error reading response: %v\n", err)
117				}
118				return
119			}
120
121			var baseMessage struct {
122				JSONRPC string          `json:"jsonrpc"`
123				ID      *int64          `json:"id,omitempty"`
124				Method  string          `json:"method,omitempty"`
125				Result  json.RawMessage `json:"result,omitempty"`
126				Error   *struct {
127					Code    int    `json:"code"`
128					Message string `json:"message"`
129				} `json:"error,omitempty"`
130			}
131
132			if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
133				continue
134			}
135
136			// Handle notification
137			if baseMessage.ID == nil {
138				var notification mcp.JSONRPCNotification
139				if err := json.Unmarshal([]byte(line), &notification); err != nil {
140					continue
141				}
142				c.notifyMu.RLock()
143				for _, handler := range c.notifications {
144					handler(notification)
145				}
146				c.notifyMu.RUnlock()
147				continue
148			}
149
150			c.mu.RLock()
151			ch, ok := c.responses[*baseMessage.ID]
152			c.mu.RUnlock()
153
154			if ok {
155				if baseMessage.Error != nil {
156					ch <- RPCResponse{
157						Error: &baseMessage.Error.Message,
158					}
159				} else {
160					ch <- RPCResponse{
161						Response: &baseMessage.Result,
162					}
163				}
164				c.mu.Lock()
165				delete(c.responses, *baseMessage.ID)
166				c.mu.Unlock()
167			}
168		}
169	}
170}
171
172// sendRequest sends a JSON-RPC request to the server and waits for a response.
173// It creates a unique request ID, sends the request over stdin, and waits for
174// the corresponding response or context cancellation.
175// Returns the raw JSON response message or an error if the request fails.
176func (c *StdioMCPClient) sendRequest(
177	ctx context.Context,
178	method string,
179	params interface{},
180) (*json.RawMessage, error) {
181	if !c.initialized && method != "initialize" {
182		return nil, fmt.Errorf("client not initialized")
183	}
184
185	id := c.requestID.Add(1)
186
187	// Create the complete request structure
188	request := mcp.JSONRPCRequest{
189		JSONRPC: mcp.JSONRPC_VERSION,
190		ID:      id,
191		Request: mcp.Request{
192			Method: method,
193		},
194		Params: params,
195	}
196
197	responseChan := make(chan RPCResponse, 1)
198	c.mu.Lock()
199	c.responses[id] = responseChan
200	c.mu.Unlock()
201
202	requestBytes, err := json.Marshal(request)
203	if err != nil {
204		return nil, fmt.Errorf("failed to marshal request: %w", err)
205	}
206	requestBytes = append(requestBytes, '\n')
207
208	if _, err := c.stdin.Write(requestBytes); err != nil {
209		return nil, fmt.Errorf("failed to write request: %w", err)
210	}
211
212	select {
213	case <-ctx.Done():
214		c.mu.Lock()
215		delete(c.responses, id)
216		c.mu.Unlock()
217		return nil, ctx.Err()
218	case response := <-responseChan:
219		if response.Error != nil {
220			return nil, errors.New(*response.Error)
221		}
222		return response.Response, nil
223	}
224}
225
226func (c *StdioMCPClient) Ping(ctx context.Context) error {
227	_, err := c.sendRequest(ctx, "ping", nil)
228	return err
229}
230
231func (c *StdioMCPClient) Initialize(
232	ctx context.Context,
233	request mcp.InitializeRequest,
234) (*mcp.InitializeResult, error) {
235	// This structure ensures Capabilities is always included in JSON
236	params := struct {
237		ProtocolVersion string                 `json:"protocolVersion"`
238		ClientInfo      mcp.Implementation     `json:"clientInfo"`
239		Capabilities    mcp.ClientCapabilities `json:"capabilities"`
240	}{
241		ProtocolVersion: request.Params.ProtocolVersion,
242		ClientInfo:      request.Params.ClientInfo,
243		Capabilities:    request.Params.Capabilities, // Will be empty struct if not set
244	}
245
246	response, err := c.sendRequest(ctx, "initialize", params)
247	if err != nil {
248		return nil, err
249	}
250
251	var result mcp.InitializeResult
252	if err := json.Unmarshal(*response, &result); err != nil {
253		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
254	}
255
256	// Store capabilities
257	c.capabilities = result.Capabilities
258
259	// Send initialized notification
260	notification := mcp.JSONRPCNotification{
261		JSONRPC: mcp.JSONRPC_VERSION,
262		Notification: mcp.Notification{
263			Method: "notifications/initialized",
264		},
265	}
266
267	notificationBytes, err := json.Marshal(notification)
268	if err != nil {
269		return nil, fmt.Errorf(
270			"failed to marshal initialized notification: %w",
271			err,
272		)
273	}
274	notificationBytes = append(notificationBytes, '\n')
275
276	if _, err := c.stdin.Write(notificationBytes); err != nil {
277		return nil, fmt.Errorf(
278			"failed to send initialized notification: %w",
279			err,
280		)
281	}
282
283	c.initialized = true
284	return &result, nil
285}
286
287func (c *StdioMCPClient) ListResources(
288	ctx context.Context,
289	request mcp.ListResourcesRequest,
290) (*mcp.
291	ListResourcesResult, error) {
292	response, err := c.sendRequest(
293		ctx,
294		"resources/list",
295		request.Params,
296	)
297	if err != nil {
298		return nil, err
299	}
300
301	var result mcp.ListResourcesResult
302	if err := json.Unmarshal(*response, &result); err != nil {
303		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
304	}
305
306	return &result, nil
307}
308
309func (c *StdioMCPClient) ListResourceTemplates(
310	ctx context.Context,
311	request mcp.ListResourceTemplatesRequest,
312) (*mcp.
313	ListResourceTemplatesResult, error) {
314	response, err := c.sendRequest(
315		ctx,
316		"resources/templates/list",
317		request.Params,
318	)
319	if err != nil {
320		return nil, err
321	}
322
323	var result mcp.ListResourceTemplatesResult
324	if err := json.Unmarshal(*response, &result); err != nil {
325		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
326	}
327
328	return &result, nil
329}
330
331func (c *StdioMCPClient) ReadResource(
332	ctx context.Context,
333	request mcp.ReadResourceRequest,
334) (*mcp.ReadResourceResult,
335	error) {
336	response, err := c.sendRequest(ctx, "resources/read", request.Params)
337	if err != nil {
338		return nil, err
339	}
340
341	return mcp.ParseReadResourceResult(response)
342}
343
344func (c *StdioMCPClient) Subscribe(
345	ctx context.Context,
346	request mcp.SubscribeRequest,
347) error {
348	_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
349	return err
350}
351
352func (c *StdioMCPClient) Unsubscribe(
353	ctx context.Context,
354	request mcp.UnsubscribeRequest,
355) error {
356	_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
357	return err
358}
359
360func (c *StdioMCPClient) ListPrompts(
361	ctx context.Context,
362	request mcp.ListPromptsRequest,
363) (*mcp.ListPromptsResult, error) {
364	response, err := c.sendRequest(ctx, "prompts/list", request.Params)
365	if err != nil {
366		return nil, err
367	}
368
369	var result mcp.ListPromptsResult
370	if err := json.Unmarshal(*response, &result); err != nil {
371		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
372	}
373
374	return &result, nil
375}
376
377func (c *StdioMCPClient) GetPrompt(
378	ctx context.Context,
379	request mcp.GetPromptRequest,
380) (*mcp.GetPromptResult, error) {
381	response, err := c.sendRequest(ctx, "prompts/get", request.Params)
382	if err != nil {
383		return nil, err
384	}
385
386	return mcp.ParseGetPromptResult(response)
387}
388
389func (c *StdioMCPClient) ListTools(
390	ctx context.Context,
391	request mcp.ListToolsRequest,
392) (*mcp.ListToolsResult, error) {
393	response, err := c.sendRequest(ctx, "tools/list", request.Params)
394	if err != nil {
395		return nil, err
396	}
397
398	var result mcp.ListToolsResult
399	if err := json.Unmarshal(*response, &result); err != nil {
400		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
401	}
402
403	return &result, nil
404}
405
406func (c *StdioMCPClient) CallTool(
407	ctx context.Context,
408	request mcp.CallToolRequest,
409) (*mcp.CallToolResult, error) {
410	response, err := c.sendRequest(ctx, "tools/call", request.Params)
411	if err != nil {
412		return nil, err
413	}
414
415	return mcp.ParseCallToolResult(response)
416}
417
418func (c *StdioMCPClient) SetLevel(
419	ctx context.Context,
420	request mcp.SetLevelRequest,
421) error {
422	_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
423	return err
424}
425
426func (c *StdioMCPClient) Complete(
427	ctx context.Context,
428	request mcp.CompleteRequest,
429) (*mcp.CompleteResult, error) {
430	response, err := c.sendRequest(ctx, "completion/complete", request.Params)
431	if err != nil {
432		return nil, err
433	}
434
435	var result mcp.CompleteResult
436	if err := json.Unmarshal(*response, &result); err != nil {
437		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
438	}
439
440	return &result, nil
441}