1package transport
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "os"
10 "os/exec"
11 "sync"
12
13 "github.com/mark3labs/mcp-go/mcp"
14)
15
// Stdio implements the transport layer of the MCP protocol using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using JSON-RPC messages. The client handles message routing between requests and
// responses, and supports asynchronous notifications.
//
// A Stdio can also be built over caller-supplied streams (see NewIO), in which
// case command is empty and no subprocess is spawned.
type Stdio struct {
	// command is the executable to spawn; when empty, spawnCommand does nothing.
	command string
	// args are the command-line arguments passed to command.
	args []string
	// env holds extra environment entries appended to the parent's environment.
	env []string

	// cmd is the spawned subprocess; it stays nil when no subprocess is used.
	cmd *exec.Cmd
	// stdin is the stream outgoing JSON-RPC messages are written to.
	stdin io.WriteCloser
	// stdout is the buffered stream incoming newline-delimited messages are read from.
	stdout *bufio.Reader
	// stderr is the logging stream exposed through Stderr.
	stderr io.ReadCloser
	// responses maps a request ID (string form) to the channel awaiting its response.
	responses map[string]chan *JSONRPCResponse
	// mu guards responses.
	mu sync.RWMutex
	// done is closed by Close to signal shutdown.
	done chan struct{}
	// onNotification is the optional callback invoked for incoming notifications.
	onNotification func(mcp.JSONRPCNotification)
	// notifyMu guards onNotification.
	notifyMu sync.RWMutex
}
35
36// NewIO returns a new stdio-based transport using existing input, output, and
37// logging streams instead of spawning a subprocess.
38// This is useful for testing and simulating client behavior.
39func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio {
40 return &Stdio{
41 stdin: output,
42 stdout: bufio.NewReader(input),
43 stderr: logging,
44
45 responses: make(map[string]chan *JSONRPCResponse),
46 done: make(chan struct{}),
47 }
48}
49
50// NewStdio creates a new stdio transport to communicate with a subprocess.
51// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
52// Returns an error if the subprocess cannot be started or the pipes cannot be created.
53func NewStdio(
54 command string,
55 env []string,
56 args ...string,
57) *Stdio {
58
59 client := &Stdio{
60 command: command,
61 args: args,
62 env: env,
63
64 responses: make(map[string]chan *JSONRPCResponse),
65 done: make(chan struct{}),
66 }
67
68 return client
69}
70
71func (c *Stdio) Start(ctx context.Context) error {
72 if err := c.spawnCommand(ctx); err != nil {
73 return err
74 }
75
76 ready := make(chan struct{})
77 go func() {
78 close(ready)
79 c.readResponses()
80 }()
81 <-ready
82
83 return nil
84}
85
86// spawnCommand spawns a new process running c.command.
87func (c *Stdio) spawnCommand(ctx context.Context) error {
88 if c.command == "" {
89 return nil
90 }
91
92 cmd := exec.CommandContext(ctx, c.command, c.args...)
93
94 mergedEnv := os.Environ()
95 mergedEnv = append(mergedEnv, c.env...)
96
97 cmd.Env = mergedEnv
98
99 stdin, err := cmd.StdinPipe()
100 if err != nil {
101 return fmt.Errorf("failed to create stdin pipe: %w", err)
102 }
103
104 stdout, err := cmd.StdoutPipe()
105 if err != nil {
106 return fmt.Errorf("failed to create stdout pipe: %w", err)
107 }
108
109 stderr, err := cmd.StderrPipe()
110 if err != nil {
111 return fmt.Errorf("failed to create stderr pipe: %w", err)
112 }
113
114 c.cmd = cmd
115 c.stdin = stdin
116 c.stderr = stderr
117 c.stdout = bufio.NewReader(stdout)
118
119 if err := cmd.Start(); err != nil {
120 return fmt.Errorf("failed to start command: %w", err)
121 }
122
123 return nil
124}
125
126// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
127// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
128func (c *Stdio) Close() error {
129 select {
130 case <-c.done:
131 return nil
132 default:
133 }
134 // cancel all in-flight request
135 close(c.done)
136
137 if err := c.stdin.Close(); err != nil {
138 return fmt.Errorf("failed to close stdin: %w", err)
139 }
140 if err := c.stderr.Close(); err != nil {
141 return fmt.Errorf("failed to close stderr: %w", err)
142 }
143
144 if c.cmd != nil {
145 return c.cmd.Wait()
146 }
147
148 return nil
149}
150
151// SetNotificationHandler sets the handler function to be called when a notification is received.
152// Only one handler can be set at a time; setting a new one replaces the previous handler.
153func (c *Stdio) SetNotificationHandler(
154 handler func(notification mcp.JSONRPCNotification),
155) {
156 c.notifyMu.Lock()
157 defer c.notifyMu.Unlock()
158 c.onNotification = handler
159}
160
161// readResponses continuously reads and processes responses from the server's stdout.
162// It handles both responses to requests and notifications, routing them appropriately.
163// Runs until the done channel is closed or an error occurs reading from stdout.
164func (c *Stdio) readResponses() {
165 for {
166 select {
167 case <-c.done:
168 return
169 default:
170 line, err := c.stdout.ReadString('\n')
171 if err != nil {
172 if err != io.EOF {
173 fmt.Printf("Error reading response: %v\n", err)
174 }
175 return
176 }
177
178 var baseMessage JSONRPCResponse
179 if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
180 continue
181 }
182
183 // Handle notification
184 if baseMessage.ID.IsNil() {
185 var notification mcp.JSONRPCNotification
186 if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
187 continue
188 }
189 c.notifyMu.RLock()
190 if c.onNotification != nil {
191 c.onNotification(notification)
192 }
193 c.notifyMu.RUnlock()
194 continue
195 }
196
197 // Create string key for map lookup
198 idKey := baseMessage.ID.String()
199
200 c.mu.RLock()
201 ch, exists := c.responses[idKey]
202 c.mu.RUnlock()
203
204 if exists {
205 ch <- &baseMessage
206 c.mu.Lock()
207 delete(c.responses, idKey)
208 c.mu.Unlock()
209 }
210 }
211 }
212}
213
214// SendRequest sends a JSON-RPC request to the server and waits for a response.
215// It creates a unique request ID, sends the request over stdin, and waits for
216// the corresponding response or context cancellation.
217// Returns the raw JSON response message or an error if the request fails.
218func (c *Stdio) SendRequest(
219 ctx context.Context,
220 request JSONRPCRequest,
221) (*JSONRPCResponse, error) {
222 if c.stdin == nil {
223 return nil, fmt.Errorf("stdio client not started")
224 }
225
226 // Marshal request
227 requestBytes, err := json.Marshal(request)
228 if err != nil {
229 return nil, fmt.Errorf("failed to marshal request: %w", err)
230 }
231 requestBytes = append(requestBytes, '\n')
232
233 // Create string key for map lookup
234 idKey := request.ID.String()
235
236 // Register response channel
237 responseChan := make(chan *JSONRPCResponse, 1)
238 c.mu.Lock()
239 c.responses[idKey] = responseChan
240 c.mu.Unlock()
241 deleteResponseChan := func() {
242 c.mu.Lock()
243 delete(c.responses, idKey)
244 c.mu.Unlock()
245 }
246
247 // Send request
248 if _, err := c.stdin.Write(requestBytes); err != nil {
249 deleteResponseChan()
250 return nil, fmt.Errorf("failed to write request: %w", err)
251 }
252
253 select {
254 case <-ctx.Done():
255 deleteResponseChan()
256 return nil, ctx.Err()
257 case response := <-responseChan:
258 return response, nil
259 }
260}
261
262// SendNotification sends a json RPC Notification to the server.
263func (c *Stdio) SendNotification(
264 ctx context.Context,
265 notification mcp.JSONRPCNotification,
266) error {
267 if c.stdin == nil {
268 return fmt.Errorf("stdio client not started")
269 }
270
271 notificationBytes, err := json.Marshal(notification)
272 if err != nil {
273 return fmt.Errorf("failed to marshal notification: %w", err)
274 }
275 notificationBytes = append(notificationBytes, '\n')
276
277 if _, err := c.stdin.Write(notificationBytes); err != nil {
278 return fmt.Errorf("failed to write notification: %w", err)
279 }
280
281 return nil
282}
283
284// Stderr returns a reader for the stderr output of the subprocess.
285// This can be used to capture error messages or logs from the subprocess.
286func (c *Stdio) Stderr() io.Reader {
287 return c.stderr
288}