1package server
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "log"
10 "os"
11 "os/signal"
12 "sync"
13 "sync/atomic"
14 "syscall"
15
16 "github.com/mark3labs/mcp-go/mcp"
17)
18
// StdioContextFunc derives the context used to handle stdio requests from the
// context passed to Listen. It may return the input unchanged or a child that
// carries extra values, for example settings read from environment variables.
type StdioContextFunc func(ctx context.Context) context.Context
24
25// StdioServer wraps a MCPServer and handles stdio communication.
26// It provides a simple way to create command-line MCP servers that
27// communicate via standard input/output streams using JSON-RPC messages.
28type StdioServer struct {
29 server *MCPServer
30 errLogger *log.Logger
31 contextFunc StdioContextFunc
32}
33
34// StdioOption defines a function type for configuring StdioServer
35type StdioOption func(*StdioServer)
36
37// WithErrorLogger sets the error logger for the server
38func WithErrorLogger(logger *log.Logger) StdioOption {
39 return func(s *StdioServer) {
40 s.errLogger = logger
41 }
42}
43
44// WithStdioContextFunc sets a function that will be called to customise the context
45// to the server. Note that the stdio server uses the same context for all requests,
46// so this function will only be called once per server instance.
47func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
48 return func(s *StdioServer) {
49 s.contextFunc = fn
50 }
51}
52
53// stdioSession is a static client session, since stdio has only one client.
54type stdioSession struct {
55 notifications chan mcp.JSONRPCNotification
56 initialized atomic.Bool
57 loggingLevel atomic.Value
58 clientInfo atomic.Value // stores session-specific client info
59 writer io.Writer // for sending requests to client
60 requestID atomic.Int64 // for generating unique request IDs
61 mu sync.RWMutex // protects writer
62 pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
63 pendingMu sync.RWMutex // protects pendingRequests
64}
65
66// samplingResponse represents a response to a sampling request
67type samplingResponse struct {
68 result *mcp.CreateMessageResult
69 err error
70}
71
72func (s *stdioSession) SessionID() string {
73 return "stdio"
74}
75
76func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
77 return s.notifications
78}
79
80func (s *stdioSession) Initialize() {
81 // set default logging level
82 s.loggingLevel.Store(mcp.LoggingLevelError)
83 s.initialized.Store(true)
84}
85
86func (s *stdioSession) Initialized() bool {
87 return s.initialized.Load()
88}
89
90func (s *stdioSession) GetClientInfo() mcp.Implementation {
91 if value := s.clientInfo.Load(); value != nil {
92 if clientInfo, ok := value.(mcp.Implementation); ok {
93 return clientInfo
94 }
95 }
96 return mcp.Implementation{}
97}
98
99func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
100 s.clientInfo.Store(clientInfo)
101}
102
103func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
104 s.loggingLevel.Store(level)
105}
106
107func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
108 level := s.loggingLevel.Load()
109 if level == nil {
110 return mcp.LoggingLevelError
111 }
112 return level.(mcp.LoggingLevel)
113}
114
115// RequestSampling sends a sampling request to the client and waits for the response.
116func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
117 s.mu.RLock()
118 writer := s.writer
119 s.mu.RUnlock()
120
121 if writer == nil {
122 return nil, fmt.Errorf("no writer available for sending requests")
123 }
124
125 // Generate a unique request ID
126 id := s.requestID.Add(1)
127
128 // Create a response channel for this request
129 responseChan := make(chan *samplingResponse, 1)
130 s.pendingMu.Lock()
131 s.pendingRequests[id] = responseChan
132 s.pendingMu.Unlock()
133
134 // Cleanup function to remove the pending request
135 cleanup := func() {
136 s.pendingMu.Lock()
137 delete(s.pendingRequests, id)
138 s.pendingMu.Unlock()
139 }
140 defer cleanup()
141
142 // Create the JSON-RPC request
143 jsonRPCRequest := struct {
144 JSONRPC string `json:"jsonrpc"`
145 ID int64 `json:"id"`
146 Method string `json:"method"`
147 Params mcp.CreateMessageParams `json:"params"`
148 }{
149 JSONRPC: mcp.JSONRPC_VERSION,
150 ID: id,
151 Method: string(mcp.MethodSamplingCreateMessage),
152 Params: request.CreateMessageParams,
153 }
154
155 // Marshal and send the request
156 requestBytes, err := json.Marshal(jsonRPCRequest)
157 if err != nil {
158 return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
159 }
160 requestBytes = append(requestBytes, '\n')
161
162 if _, err := writer.Write(requestBytes); err != nil {
163 return nil, fmt.Errorf("failed to write sampling request: %w", err)
164 }
165
166 // Wait for the response or context cancellation
167 select {
168 case <-ctx.Done():
169 return nil, ctx.Err()
170 case response := <-responseChan:
171 if response.err != nil {
172 return nil, response.err
173 }
174 return response.result, nil
175 }
176}
177
178// SetWriter sets the writer for sending requests to the client.
179func (s *stdioSession) SetWriter(writer io.Writer) {
180 s.mu.Lock()
181 defer s.mu.Unlock()
182 s.writer = writer
183}
184
185var (
186 _ ClientSession = (*stdioSession)(nil)
187 _ SessionWithLogging = (*stdioSession)(nil)
188 _ SessionWithClientInfo = (*stdioSession)(nil)
189 _ SessionWithSampling = (*stdioSession)(nil)
190)
191
192var stdioSessionInstance = stdioSession{
193 notifications: make(chan mcp.JSONRPCNotification, 100),
194 pendingRequests: make(map[int64]chan *samplingResponse),
195}
196
197// NewStdioServer creates a new stdio server wrapper around an MCPServer.
198// It initializes the server with a default error logger that discards all output.
199func NewStdioServer(server *MCPServer) *StdioServer {
200 return &StdioServer{
201 server: server,
202 errLogger: log.New(
203 os.Stderr,
204 "",
205 log.LstdFlags,
206 ), // Default to discarding logs
207 }
208}
209
210// SetErrorLogger configures where error messages from the StdioServer are logged.
211// The provided logger will receive all error messages generated during server operation.
212func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
213 s.errLogger = logger
214}
215
216// SetContextFunc sets a function that will be called to customise the context
217// to the server. Note that the stdio server uses the same context for all requests,
218// so this function will only be called once per server instance.
219func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
220 s.contextFunc = fn
221}
222
223// handleNotifications continuously processes notifications from the session's notification channel
224// and writes them to the provided output. It runs until the context is cancelled.
225// Any errors encountered while writing notifications are logged but do not stop the handler.
226func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
227 for {
228 select {
229 case notification := <-stdioSessionInstance.notifications:
230 if err := s.writeResponse(notification, stdout); err != nil {
231 s.errLogger.Printf("Error writing notification: %v", err)
232 }
233 case <-ctx.Done():
234 return
235 }
236 }
237}
238
239// processInputStream continuously reads and processes messages from the input stream.
240// It handles EOF gracefully as a normal termination condition.
241// The function returns when either:
242// - The context is cancelled (returns context.Err())
243// - EOF is encountered (returns nil)
244// - An error occurs while reading or processing messages (returns the error)
245func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
246 for {
247 if err := ctx.Err(); err != nil {
248 return err
249 }
250
251 line, err := s.readNextLine(ctx, reader)
252 if err != nil {
253 if err == io.EOF {
254 return nil
255 }
256 s.errLogger.Printf("Error reading input: %v", err)
257 return err
258 }
259
260 if err := s.processMessage(ctx, line, stdout); err != nil {
261 if err == io.EOF {
262 return nil
263 }
264 s.errLogger.Printf("Error handling message: %v", err)
265 return err
266 }
267 }
268}
269
270// readNextLine reads a single line from the input reader in a context-aware manner.
271// It uses channels to make the read operation cancellable via context.
272// Returns the read line and any error encountered. If the context is cancelled,
273// returns an empty string and the context's error. EOF is returned when the input
274// stream is closed.
275func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
276 type result struct {
277 line string
278 err error
279 }
280
281 resultCh := make(chan result, 1)
282
283 go func() {
284 line, err := reader.ReadString('\n')
285 resultCh <- result{line: line, err: err}
286 }()
287
288 select {
289 case <-ctx.Done():
290 return "", nil
291 case res := <-resultCh:
292 return res.line, res.err
293 }
294}
295
296// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
297// It runs until the context is cancelled or an error occurs.
298// Returns an error if there are issues with reading input or writing output.
299func (s *StdioServer) Listen(
300 ctx context.Context,
301 stdin io.Reader,
302 stdout io.Writer,
303) error {
304 // Set a static client context since stdio only has one client
305 if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
306 return fmt.Errorf("register session: %w", err)
307 }
308 defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
309 ctx = s.server.WithContext(ctx, &stdioSessionInstance)
310
311 // Set the writer for sending requests to the client
312 stdioSessionInstance.SetWriter(stdout)
313
314 // Add in any custom context.
315 if s.contextFunc != nil {
316 ctx = s.contextFunc(ctx)
317 }
318
319 reader := bufio.NewReader(stdin)
320
321 // Start notification handler
322 go s.handleNotifications(ctx, stdout)
323 return s.processInputStream(ctx, reader, stdout)
324}
325
326// processMessage handles a single JSON-RPC message and writes the response.
327// It parses the message, processes it through the wrapped MCPServer, and writes any response.
328// Returns an error if there are issues with message processing or response writing.
329func (s *StdioServer) processMessage(
330 ctx context.Context,
331 line string,
332 writer io.Writer,
333) error {
334 // If line is empty, likely due to ctx cancellation
335 if len(line) == 0 {
336 return nil
337 }
338
339 // Parse the message as raw JSON
340 var rawMessage json.RawMessage
341 if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
342 response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
343 return s.writeResponse(response, writer)
344 }
345
346 // Check if this is a response to a sampling request
347 if s.handleSamplingResponse(rawMessage) {
348 return nil
349 }
350
351 // Check if this is a tool call that might need sampling (and thus should be processed concurrently)
352 var baseMessage struct {
353 Method string `json:"method"`
354 }
355 if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
356 // Process tool calls concurrently to avoid blocking on sampling requests
357 go func() {
358 response := s.server.HandleMessage(ctx, rawMessage)
359 if response != nil {
360 if err := s.writeResponse(response, writer); err != nil {
361 s.errLogger.Printf("Error writing tool response: %v", err)
362 }
363 }
364 }()
365 return nil
366 }
367
368 // Handle other messages synchronously
369 response := s.server.HandleMessage(ctx, rawMessage)
370
371 // Only write response if there is one (not for notifications)
372 if response != nil {
373 if err := s.writeResponse(response, writer); err != nil {
374 return fmt.Errorf("failed to write response: %w", err)
375 }
376 }
377
378 return nil
379}
380
381// handleSamplingResponse checks if the message is a response to a sampling request
382// and routes it to the appropriate pending request channel.
383func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
384 return stdioSessionInstance.handleSamplingResponse(rawMessage)
385}
386
387// handleSamplingResponse handles incoming sampling responses for this session
388func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
389 // Try to parse as a JSON-RPC response
390 var response struct {
391 JSONRPC string `json:"jsonrpc"`
392 ID json.Number `json:"id"`
393 Result json.RawMessage `json:"result,omitempty"`
394 Error *struct {
395 Code int `json:"code"`
396 Message string `json:"message"`
397 } `json:"error,omitempty"`
398 }
399
400 if err := json.Unmarshal(rawMessage, &response); err != nil {
401 return false
402 }
403 // Parse the ID as int64
404 idInt64, err := response.ID.Int64()
405 if err != nil || (response.Result == nil && response.Error == nil) {
406 return false
407 }
408
409 // Look for a pending request with this ID
410 s.pendingMu.RLock()
411 responseChan, exists := s.pendingRequests[idInt64]
412 s.pendingMu.RUnlock()
413
414 if !exists {
415 return false
416 } // Parse and send the response
417 samplingResp := &samplingResponse{}
418
419 if response.Error != nil {
420 samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
421 } else {
422 var result mcp.CreateMessageResult
423 if err := json.Unmarshal(response.Result, &result); err != nil {
424 samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
425 } else {
426 samplingResp.result = &result
427 }
428 }
429
430 // Send the response (non-blocking)
431 select {
432 case responseChan <- samplingResp:
433 default:
434 // Channel is full or closed, ignore
435 }
436
437 return true
438}
439
440// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
441// Returns an error if marshaling or writing fails.
442func (s *StdioServer) writeResponse(
443 response mcp.JSONRPCMessage,
444 writer io.Writer,
445) error {
446 responseBytes, err := json.Marshal(response)
447 if err != nil {
448 return err
449 }
450
451 // Write response followed by newline
452 if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
453 return err
454 }
455
456 return nil
457}
458
459// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
460// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
461// Returns an error if the server encounters any issues during operation.
462func ServeStdio(server *MCPServer, opts ...StdioOption) error {
463 s := NewStdioServer(server)
464
465 for _, opt := range opts {
466 opt(s)
467 }
468
469 ctx, cancel := context.WithCancel(context.Background())
470 defer cancel()
471
472 // Set up signal handling
473 sigChan := make(chan os.Signal, 1)
474 signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
475
476 go func() {
477 <-sigChan
478 cancel()
479 }()
480
481 return s.Listen(ctx, os.Stdin, os.Stdout)
482}