1package client
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "os"
11 "os/exec"
12 "sync"
13 "sync/atomic"
14
15 "github.com/mark3labs/mcp-go/mcp"
16)
17
// StdioMCPClient implements the MCPClient interface using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using JSON-RPC messages. The client handles message routing between requests and
// responses, and supports asynchronous notifications.
type StdioMCPClient struct {
	// cmd is the subprocess started by NewStdioMCPClient and waited on by Close.
	cmd *exec.Cmd

	// stdin is the pipe to the subprocess's standard input; requests and
	// notifications are written to it as newline-terminated JSON.
	stdin io.WriteCloser

	// stdout is a buffered reader over the subprocess's standard output;
	// readResponses consumes it one line at a time.
	stdout *bufio.Reader

	// requestID is incremented atomically to produce the ID of each request.
	requestID atomic.Int64

	// responses maps the ID of each in-flight request to the channel on which
	// readResponses delivers its reply. Guarded by mu.
	responses map[int64]chan RPCResponse

	// mu guards responses.
	mu sync.RWMutex

	// done is closed by Close; readResponses checks it before each read.
	done chan struct{}

	// initialized is set to true at the end of a successful Initialize call.
	// Until then sendRequest rejects every method other than "initialize".
	initialized bool

	// notifications holds the handlers registered via OnNotification, in
	// registration order. Guarded by notifyMu.
	notifications []func(mcp.JSONRPCNotification)

	// notifyMu guards notifications.
	notifyMu sync.RWMutex

	// capabilities stores the server capabilities taken from the initialize result.
	capabilities mcp.ServerCapabilities
}
35
36// NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess.
37// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
38// Returns an error if the subprocess cannot be started or the pipes cannot be created.
39func NewStdioMCPClient(
40 command string,
41 env []string,
42 args ...string,
43) (*StdioMCPClient, error) {
44 cmd := exec.Command(command, args...)
45
46 mergedEnv := os.Environ()
47 mergedEnv = append(mergedEnv, env...)
48
49 cmd.Env = mergedEnv
50
51 stdin, err := cmd.StdinPipe()
52 if err != nil {
53 return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
54 }
55
56 stdout, err := cmd.StdoutPipe()
57 if err != nil {
58 return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
59 }
60
61 client := &StdioMCPClient{
62 cmd: cmd,
63 stdin: stdin,
64 stdout: bufio.NewReader(stdout),
65 responses: make(map[int64]chan RPCResponse),
66 done: make(chan struct{}),
67 }
68
69 if err := cmd.Start(); err != nil {
70 return nil, fmt.Errorf("failed to start command: %w", err)
71 }
72
73 // Start reading responses in a goroutine and wait for it to be ready
74 ready := make(chan struct{})
75 go func() {
76 close(ready)
77 client.readResponses()
78 }()
79 <-ready
80
81 return client, nil
82}
83
84// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
85// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
86func (c *StdioMCPClient) Close() error {
87 close(c.done)
88 if err := c.stdin.Close(); err != nil {
89 return fmt.Errorf("failed to close stdin: %w", err)
90 }
91 return c.cmd.Wait()
92}
93
94// OnNotification registers a handler function to be called when notifications are received.
95// Multiple handlers can be registered and will be called in the order they were added.
96func (c *StdioMCPClient) OnNotification(
97 handler func(notification mcp.JSONRPCNotification),
98) {
99 c.notifyMu.Lock()
100 defer c.notifyMu.Unlock()
101 c.notifications = append(c.notifications, handler)
102}
103
104// readResponses continuously reads and processes responses from the server's stdout.
105// It handles both responses to requests and notifications, routing them appropriately.
106// Runs until the done channel is closed or an error occurs reading from stdout.
107func (c *StdioMCPClient) readResponses() {
108 for {
109 select {
110 case <-c.done:
111 return
112 default:
113 line, err := c.stdout.ReadString('\n')
114 if err != nil {
115 if err != io.EOF {
116 fmt.Printf("Error reading response: %v\n", err)
117 }
118 return
119 }
120
121 var baseMessage struct {
122 JSONRPC string `json:"jsonrpc"`
123 ID *int64 `json:"id,omitempty"`
124 Method string `json:"method,omitempty"`
125 Result json.RawMessage `json:"result,omitempty"`
126 Error *struct {
127 Code int `json:"code"`
128 Message string `json:"message"`
129 } `json:"error,omitempty"`
130 }
131
132 if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
133 continue
134 }
135
136 // Handle notification
137 if baseMessage.ID == nil {
138 var notification mcp.JSONRPCNotification
139 if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
140 continue
141 }
142 c.notifyMu.RLock()
143 for _, handler := range c.notifications {
144 handler(notification)
145 }
146 c.notifyMu.RUnlock()
147 continue
148 }
149
150 c.mu.RLock()
151 ch, ok := c.responses[*baseMessage.ID]
152 c.mu.RUnlock()
153
154 if ok {
155 if baseMessage.Error != nil {
156 ch <- RPCResponse{
157 Error: &baseMessage.Error.Message,
158 }
159 } else {
160 ch <- RPCResponse{
161 Response: &baseMessage.Result,
162 }
163 }
164 c.mu.Lock()
165 delete(c.responses, *baseMessage.ID)
166 c.mu.Unlock()
167 }
168 }
169 }
170}
171
172// sendRequest sends a JSON-RPC request to the server and waits for a response.
173// It creates a unique request ID, sends the request over stdin, and waits for
174// the corresponding response or context cancellation.
175// Returns the raw JSON response message or an error if the request fails.
176func (c *StdioMCPClient) sendRequest(
177 ctx context.Context,
178 method string,
179 params interface{},
180) (*json.RawMessage, error) {
181 if !c.initialized && method != "initialize" {
182 return nil, fmt.Errorf("client not initialized")
183 }
184
185 id := c.requestID.Add(1)
186
187 // Create the complete request structure
188 request := mcp.JSONRPCRequest{
189 JSONRPC: mcp.JSONRPC_VERSION,
190 ID: id,
191 Request: mcp.Request{
192 Method: method,
193 },
194 Params: params,
195 }
196
197 responseChan := make(chan RPCResponse, 1)
198 c.mu.Lock()
199 c.responses[id] = responseChan
200 c.mu.Unlock()
201
202 requestBytes, err := json.Marshal(request)
203 if err != nil {
204 return nil, fmt.Errorf("failed to marshal request: %w", err)
205 }
206 requestBytes = append(requestBytes, '\n')
207
208 if _, err := c.stdin.Write(requestBytes); err != nil {
209 return nil, fmt.Errorf("failed to write request: %w", err)
210 }
211
212 select {
213 case <-ctx.Done():
214 c.mu.Lock()
215 delete(c.responses, id)
216 c.mu.Unlock()
217 return nil, ctx.Err()
218 case response := <-responseChan:
219 if response.Error != nil {
220 return nil, errors.New(*response.Error)
221 }
222 return response.Response, nil
223 }
224}
225
226func (c *StdioMCPClient) Ping(ctx context.Context) error {
227 _, err := c.sendRequest(ctx, "ping", nil)
228 return err
229}
230
231func (c *StdioMCPClient) Initialize(
232 ctx context.Context,
233 request mcp.InitializeRequest,
234) (*mcp.InitializeResult, error) {
235 // This structure ensures Capabilities is always included in JSON
236 params := struct {
237 ProtocolVersion string `json:"protocolVersion"`
238 ClientInfo mcp.Implementation `json:"clientInfo"`
239 Capabilities mcp.ClientCapabilities `json:"capabilities"`
240 }{
241 ProtocolVersion: request.Params.ProtocolVersion,
242 ClientInfo: request.Params.ClientInfo,
243 Capabilities: request.Params.Capabilities, // Will be empty struct if not set
244 }
245
246 response, err := c.sendRequest(ctx, "initialize", params)
247 if err != nil {
248 return nil, err
249 }
250
251 var result mcp.InitializeResult
252 if err := json.Unmarshal(*response, &result); err != nil {
253 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
254 }
255
256 // Store capabilities
257 c.capabilities = result.Capabilities
258
259 // Send initialized notification
260 notification := mcp.JSONRPCNotification{
261 JSONRPC: mcp.JSONRPC_VERSION,
262 Notification: mcp.Notification{
263 Method: "notifications/initialized",
264 },
265 }
266
267 notificationBytes, err := json.Marshal(notification)
268 if err != nil {
269 return nil, fmt.Errorf(
270 "failed to marshal initialized notification: %w",
271 err,
272 )
273 }
274 notificationBytes = append(notificationBytes, '\n')
275
276 if _, err := c.stdin.Write(notificationBytes); err != nil {
277 return nil, fmt.Errorf(
278 "failed to send initialized notification: %w",
279 err,
280 )
281 }
282
283 c.initialized = true
284 return &result, nil
285}
286
287func (c *StdioMCPClient) ListResources(
288 ctx context.Context,
289 request mcp.ListResourcesRequest,
290) (*mcp.
291 ListResourcesResult, error) {
292 response, err := c.sendRequest(
293 ctx,
294 "resources/list",
295 request.Params,
296 )
297 if err != nil {
298 return nil, err
299 }
300
301 var result mcp.ListResourcesResult
302 if err := json.Unmarshal(*response, &result); err != nil {
303 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
304 }
305
306 return &result, nil
307}
308
309func (c *StdioMCPClient) ListResourceTemplates(
310 ctx context.Context,
311 request mcp.ListResourceTemplatesRequest,
312) (*mcp.
313 ListResourceTemplatesResult, error) {
314 response, err := c.sendRequest(
315 ctx,
316 "resources/templates/list",
317 request.Params,
318 )
319 if err != nil {
320 return nil, err
321 }
322
323 var result mcp.ListResourceTemplatesResult
324 if err := json.Unmarshal(*response, &result); err != nil {
325 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
326 }
327
328 return &result, nil
329}
330
331func (c *StdioMCPClient) ReadResource(
332 ctx context.Context,
333 request mcp.ReadResourceRequest,
334) (*mcp.ReadResourceResult,
335 error) {
336 response, err := c.sendRequest(ctx, "resources/read", request.Params)
337 if err != nil {
338 return nil, err
339 }
340
341 return mcp.ParseReadResourceResult(response)
342}
343
344func (c *StdioMCPClient) Subscribe(
345 ctx context.Context,
346 request mcp.SubscribeRequest,
347) error {
348 _, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
349 return err
350}
351
352func (c *StdioMCPClient) Unsubscribe(
353 ctx context.Context,
354 request mcp.UnsubscribeRequest,
355) error {
356 _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
357 return err
358}
359
360func (c *StdioMCPClient) ListPrompts(
361 ctx context.Context,
362 request mcp.ListPromptsRequest,
363) (*mcp.ListPromptsResult, error) {
364 response, err := c.sendRequest(ctx, "prompts/list", request.Params)
365 if err != nil {
366 return nil, err
367 }
368
369 var result mcp.ListPromptsResult
370 if err := json.Unmarshal(*response, &result); err != nil {
371 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
372 }
373
374 return &result, nil
375}
376
377func (c *StdioMCPClient) GetPrompt(
378 ctx context.Context,
379 request mcp.GetPromptRequest,
380) (*mcp.GetPromptResult, error) {
381 response, err := c.sendRequest(ctx, "prompts/get", request.Params)
382 if err != nil {
383 return nil, err
384 }
385
386 return mcp.ParseGetPromptResult(response)
387}
388
389func (c *StdioMCPClient) ListTools(
390 ctx context.Context,
391 request mcp.ListToolsRequest,
392) (*mcp.ListToolsResult, error) {
393 response, err := c.sendRequest(ctx, "tools/list", request.Params)
394 if err != nil {
395 return nil, err
396 }
397
398 var result mcp.ListToolsResult
399 if err := json.Unmarshal(*response, &result); err != nil {
400 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
401 }
402
403 return &result, nil
404}
405
406func (c *StdioMCPClient) CallTool(
407 ctx context.Context,
408 request mcp.CallToolRequest,
409) (*mcp.CallToolResult, error) {
410 response, err := c.sendRequest(ctx, "tools/call", request.Params)
411 if err != nil {
412 return nil, err
413 }
414
415 return mcp.ParseCallToolResult(response)
416}
417
418func (c *StdioMCPClient) SetLevel(
419 ctx context.Context,
420 request mcp.SetLevelRequest,
421) error {
422 _, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
423 return err
424}
425
426func (c *StdioMCPClient) Complete(
427 ctx context.Context,
428 request mcp.CompleteRequest,
429) (*mcp.CompleteResult, error) {
430 response, err := c.sendRequest(ctx, "completion/complete", request.Params)
431 if err != nil {
432 return nil, err
433 }
434
435 var result mcp.CompleteResult
436 if err := json.Unmarshal(*response, &result); err != nil {
437 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
438 }
439
440 return &result, nil
441}