1package transport
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"sync"
 12
 13	"github.com/mark3labs/mcp-go/mcp"
 14)
 15
 16// Stdio implements the transport layer of the MCP protocol using stdio communication.
 17// It launches a subprocess and communicates with it via standard input/output streams
 18// using JSON-RPC messages. The client handles message routing between requests and
 19// responses, and supports asynchronous notifications.
 20type Stdio struct {
 21	command string
 22	args    []string
 23	env     []string
 24
 25	cmd            *exec.Cmd
 26	stdin          io.WriteCloser
 27	stdout         *bufio.Reader
 28	stderr         io.ReadCloser
 29	responses      map[string]chan *JSONRPCResponse
 30	mu             sync.RWMutex
 31	done           chan struct{}
 32	onNotification func(mcp.JSONRPCNotification)
 33	notifyMu       sync.RWMutex
 34}
 35
 36// NewIO returns a new stdio-based transport using existing input, output, and
 37// logging streams instead of spawning a subprocess.
 38// This is useful for testing and simulating client behavior.
 39func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio {
 40	return &Stdio{
 41		stdin:  output,
 42		stdout: bufio.NewReader(input),
 43		stderr: logging,
 44
 45		responses: make(map[string]chan *JSONRPCResponse),
 46		done:      make(chan struct{}),
 47	}
 48}
 49
 50// NewStdio creates a new stdio transport to communicate with a subprocess.
 51// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
 52// Returns an error if the subprocess cannot be started or the pipes cannot be created.
 53func NewStdio(
 54	command string,
 55	env []string,
 56	args ...string,
 57) *Stdio {
 58
 59	client := &Stdio{
 60		command: command,
 61		args:    args,
 62		env:     env,
 63
 64		responses: make(map[string]chan *JSONRPCResponse),
 65		done:      make(chan struct{}),
 66	}
 67
 68	return client
 69}
 70
 71func (c *Stdio) Start(ctx context.Context) error {
 72	if err := c.spawnCommand(ctx); err != nil {
 73		return err
 74	}
 75
 76	ready := make(chan struct{})
 77	go func() {
 78		close(ready)
 79		c.readResponses()
 80	}()
 81	<-ready
 82
 83	return nil
 84}
 85
 86// spawnCommand spawns a new process running c.command.
 87func (c *Stdio) spawnCommand(ctx context.Context) error {
 88	if c.command == "" {
 89		return nil
 90	}
 91
 92	cmd := exec.CommandContext(ctx, c.command, c.args...)
 93
 94	mergedEnv := os.Environ()
 95	mergedEnv = append(mergedEnv, c.env...)
 96
 97	cmd.Env = mergedEnv
 98
 99	stdin, err := cmd.StdinPipe()
100	if err != nil {
101		return fmt.Errorf("failed to create stdin pipe: %w", err)
102	}
103
104	stdout, err := cmd.StdoutPipe()
105	if err != nil {
106		return fmt.Errorf("failed to create stdout pipe: %w", err)
107	}
108
109	stderr, err := cmd.StderrPipe()
110	if err != nil {
111		return fmt.Errorf("failed to create stderr pipe: %w", err)
112	}
113
114	c.cmd = cmd
115	c.stdin = stdin
116	c.stderr = stderr
117	c.stdout = bufio.NewReader(stdout)
118
119	if err := cmd.Start(); err != nil {
120		return fmt.Errorf("failed to start command: %w", err)
121	}
122
123	return nil
124}
125
126// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
127// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
128func (c *Stdio) Close() error {
129	select {
130	case <-c.done:
131		return nil
132	default:
133	}
134	// cancel all in-flight request
135	close(c.done)
136
137	if err := c.stdin.Close(); err != nil {
138		return fmt.Errorf("failed to close stdin: %w", err)
139	}
140	if err := c.stderr.Close(); err != nil {
141		return fmt.Errorf("failed to close stderr: %w", err)
142	}
143
144	if c.cmd != nil {
145		return c.cmd.Wait()
146	}
147
148	return nil
149}
150
151// SetNotificationHandler sets the handler function to be called when a notification is received.
152// Only one handler can be set at a time; setting a new one replaces the previous handler.
153func (c *Stdio) SetNotificationHandler(
154	handler func(notification mcp.JSONRPCNotification),
155) {
156	c.notifyMu.Lock()
157	defer c.notifyMu.Unlock()
158	c.onNotification = handler
159}
160
161// readResponses continuously reads and processes responses from the server's stdout.
162// It handles both responses to requests and notifications, routing them appropriately.
163// Runs until the done channel is closed or an error occurs reading from stdout.
164func (c *Stdio) readResponses() {
165	for {
166		select {
167		case <-c.done:
168			return
169		default:
170			line, err := c.stdout.ReadString('\n')
171			if err != nil {
172				if err != io.EOF {
173					fmt.Printf("Error reading response: %v\n", err)
174				}
175				return
176			}
177
178			var baseMessage JSONRPCResponse
179			if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
180				continue
181			}
182
183			// Handle notification
184			if baseMessage.ID.IsNil() {
185				var notification mcp.JSONRPCNotification
186				if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
187					continue
188				}
189				c.notifyMu.RLock()
190				if c.onNotification != nil {
191					c.onNotification(notification)
192				}
193				c.notifyMu.RUnlock()
194				continue
195			}
196
197			// Create string key for map lookup
198			idKey := baseMessage.ID.String()
199
200			c.mu.RLock()
201			ch, exists := c.responses[idKey]
202			c.mu.RUnlock()
203
204			if exists {
205				ch <- &baseMessage
206				c.mu.Lock()
207				delete(c.responses, idKey)
208				c.mu.Unlock()
209			}
210		}
211	}
212}
213
214// SendRequest sends a JSON-RPC request to the server and waits for a response.
215// It creates a unique request ID, sends the request over stdin, and waits for
216// the corresponding response or context cancellation.
217// Returns the raw JSON response message or an error if the request fails.
218func (c *Stdio) SendRequest(
219	ctx context.Context,
220	request JSONRPCRequest,
221) (*JSONRPCResponse, error) {
222	if c.stdin == nil {
223		return nil, fmt.Errorf("stdio client not started")
224	}
225
226	// Marshal request
227	requestBytes, err := json.Marshal(request)
228	if err != nil {
229		return nil, fmt.Errorf("failed to marshal request: %w", err)
230	}
231	requestBytes = append(requestBytes, '\n')
232
233	// Create string key for map lookup
234	idKey := request.ID.String()
235
236	// Register response channel
237	responseChan := make(chan *JSONRPCResponse, 1)
238	c.mu.Lock()
239	c.responses[idKey] = responseChan
240	c.mu.Unlock()
241	deleteResponseChan := func() {
242		c.mu.Lock()
243		delete(c.responses, idKey)
244		c.mu.Unlock()
245	}
246
247	// Send request
248	if _, err := c.stdin.Write(requestBytes); err != nil {
249		deleteResponseChan()
250		return nil, fmt.Errorf("failed to write request: %w", err)
251	}
252
253	select {
254	case <-ctx.Done():
255		deleteResponseChan()
256		return nil, ctx.Err()
257	case response := <-responseChan:
258		return response, nil
259	}
260}
261
262// SendNotification sends a json RPC Notification to the server.
263func (c *Stdio) SendNotification(
264	ctx context.Context,
265	notification mcp.JSONRPCNotification,
266) error {
267	if c.stdin == nil {
268		return fmt.Errorf("stdio client not started")
269	}
270
271	notificationBytes, err := json.Marshal(notification)
272	if err != nil {
273		return fmt.Errorf("failed to marshal notification: %w", err)
274	}
275	notificationBytes = append(notificationBytes, '\n')
276
277	if _, err := c.stdin.Write(notificationBytes); err != nil {
278		return fmt.Errorf("failed to write notification: %w", err)
279	}
280
281	return nil
282}
283
284// Stderr returns a reader for the stderr output of the subprocess.
285// This can be used to capture error messages or logs from the subprocess.
286func (c *Stdio) Stderr() io.Reader {
287	return c.stderr
288}