1package server
  2
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"os"
	"os/signal"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"

	"github.com/mark3labs/mcp-go/mcp"
)
 18
 19// StdioContextFunc is a function that takes an existing context and returns
 20// a potentially modified context.
 21// This can be used to inject context values from environment variables,
 22// for example.
 23type StdioContextFunc func(ctx context.Context) context.Context
 24
 25// StdioServer wraps a MCPServer and handles stdio communication.
 26// It provides a simple way to create command-line MCP servers that
 27// communicate via standard input/output streams using JSON-RPC messages.
 28type StdioServer struct {
 29	server      *MCPServer
 30	errLogger   *log.Logger
 31	contextFunc StdioContextFunc
 32}
 33
 34// StdioOption defines a function type for configuring StdioServer
 35type StdioOption func(*StdioServer)
 36
 37// WithErrorLogger sets the error logger for the server
 38func WithErrorLogger(logger *log.Logger) StdioOption {
 39	return func(s *StdioServer) {
 40		s.errLogger = logger
 41	}
 42}
 43
 44// WithStdioContextFunc sets a function that will be called to customise the context
 45// to the server. Note that the stdio server uses the same context for all requests,
 46// so this function will only be called once per server instance.
 47func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
 48	return func(s *StdioServer) {
 49		s.contextFunc = fn
 50	}
 51}
 52
 53// stdioSession is a static client session, since stdio has only one client.
 54type stdioSession struct {
 55	notifications   chan mcp.JSONRPCNotification
 56	initialized     atomic.Bool
 57	loggingLevel    atomic.Value
 58	clientInfo      atomic.Value                     // stores session-specific client info
 59	writer          io.Writer                        // for sending requests to client
 60	requestID       atomic.Int64                     // for generating unique request IDs
 61	mu              sync.RWMutex                     // protects writer
 62	pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
 63	pendingMu       sync.RWMutex                     // protects pendingRequests
 64}
 65
 66// samplingResponse represents a response to a sampling request
 67type samplingResponse struct {
 68	result *mcp.CreateMessageResult
 69	err    error
 70}
 71
 72func (s *stdioSession) SessionID() string {
 73	return "stdio"
 74}
 75
 76func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
 77	return s.notifications
 78}
 79
 80func (s *stdioSession) Initialize() {
 81	// set default logging level
 82	s.loggingLevel.Store(mcp.LoggingLevelError)
 83	s.initialized.Store(true)
 84}
 85
 86func (s *stdioSession) Initialized() bool {
 87	return s.initialized.Load()
 88}
 89
 90func (s *stdioSession) GetClientInfo() mcp.Implementation {
 91	if value := s.clientInfo.Load(); value != nil {
 92		if clientInfo, ok := value.(mcp.Implementation); ok {
 93			return clientInfo
 94		}
 95	}
 96	return mcp.Implementation{}
 97}
 98
 99func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
100	s.clientInfo.Store(clientInfo)
101}
102
103func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
104	s.loggingLevel.Store(level)
105}
106
107func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
108	level := s.loggingLevel.Load()
109	if level == nil {
110		return mcp.LoggingLevelError
111	}
112	return level.(mcp.LoggingLevel)
113}
114
115// RequestSampling sends a sampling request to the client and waits for the response.
116func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
117	s.mu.RLock()
118	writer := s.writer
119	s.mu.RUnlock()
120
121	if writer == nil {
122		return nil, fmt.Errorf("no writer available for sending requests")
123	}
124
125	// Generate a unique request ID
126	id := s.requestID.Add(1)
127
128	// Create a response channel for this request
129	responseChan := make(chan *samplingResponse, 1)
130	s.pendingMu.Lock()
131	s.pendingRequests[id] = responseChan
132	s.pendingMu.Unlock()
133
134	// Cleanup function to remove the pending request
135	cleanup := func() {
136		s.pendingMu.Lock()
137		delete(s.pendingRequests, id)
138		s.pendingMu.Unlock()
139	}
140	defer cleanup()
141
142	// Create the JSON-RPC request
143	jsonRPCRequest := struct {
144		JSONRPC string                  `json:"jsonrpc"`
145		ID      int64                   `json:"id"`
146		Method  string                  `json:"method"`
147		Params  mcp.CreateMessageParams `json:"params"`
148	}{
149		JSONRPC: mcp.JSONRPC_VERSION,
150		ID:      id,
151		Method:  string(mcp.MethodSamplingCreateMessage),
152		Params:  request.CreateMessageParams,
153	}
154
155	// Marshal and send the request
156	requestBytes, err := json.Marshal(jsonRPCRequest)
157	if err != nil {
158		return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
159	}
160	requestBytes = append(requestBytes, '\n')
161
162	if _, err := writer.Write(requestBytes); err != nil {
163		return nil, fmt.Errorf("failed to write sampling request: %w", err)
164	}
165
166	// Wait for the response or context cancellation
167	select {
168	case <-ctx.Done():
169		return nil, ctx.Err()
170	case response := <-responseChan:
171		if response.err != nil {
172			return nil, response.err
173		}
174		return response.result, nil
175	}
176}
177
178// SetWriter sets the writer for sending requests to the client.
179func (s *stdioSession) SetWriter(writer io.Writer) {
180	s.mu.Lock()
181	defer s.mu.Unlock()
182	s.writer = writer
183}
184
185var (
186	_ ClientSession         = (*stdioSession)(nil)
187	_ SessionWithLogging    = (*stdioSession)(nil)
188	_ SessionWithClientInfo = (*stdioSession)(nil)
189	_ SessionWithSampling   = (*stdioSession)(nil)
190)
191
192var stdioSessionInstance = stdioSession{
193	notifications:   make(chan mcp.JSONRPCNotification, 100),
194	pendingRequests: make(map[int64]chan *samplingResponse),
195}
196
197// NewStdioServer creates a new stdio server wrapper around an MCPServer.
198// It initializes the server with a default error logger that discards all output.
199func NewStdioServer(server *MCPServer) *StdioServer {
200	return &StdioServer{
201		server: server,
202		errLogger: log.New(
203			os.Stderr,
204			"",
205			log.LstdFlags,
206		), // Default to discarding logs
207	}
208}
209
210// SetErrorLogger configures where error messages from the StdioServer are logged.
211// The provided logger will receive all error messages generated during server operation.
212func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
213	s.errLogger = logger
214}
215
216// SetContextFunc sets a function that will be called to customise the context
217// to the server. Note that the stdio server uses the same context for all requests,
218// so this function will only be called once per server instance.
219func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
220	s.contextFunc = fn
221}
222
223// handleNotifications continuously processes notifications from the session's notification channel
224// and writes them to the provided output. It runs until the context is cancelled.
225// Any errors encountered while writing notifications are logged but do not stop the handler.
226func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
227	for {
228		select {
229		case notification := <-stdioSessionInstance.notifications:
230			if err := s.writeResponse(notification, stdout); err != nil {
231				s.errLogger.Printf("Error writing notification: %v", err)
232			}
233		case <-ctx.Done():
234			return
235		}
236	}
237}
238
239// processInputStream continuously reads and processes messages from the input stream.
240// It handles EOF gracefully as a normal termination condition.
241// The function returns when either:
242// - The context is cancelled (returns context.Err())
243// - EOF is encountered (returns nil)
244// - An error occurs while reading or processing messages (returns the error)
245func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
246	for {
247		if err := ctx.Err(); err != nil {
248			return err
249		}
250
251		line, err := s.readNextLine(ctx, reader)
252		if err != nil {
253			if err == io.EOF {
254				return nil
255			}
256			s.errLogger.Printf("Error reading input: %v", err)
257			return err
258		}
259
260		if err := s.processMessage(ctx, line, stdout); err != nil {
261			if err == io.EOF {
262				return nil
263			}
264			s.errLogger.Printf("Error handling message: %v", err)
265			return err
266		}
267	}
268}
269
270// readNextLine reads a single line from the input reader in a context-aware manner.
271// It uses channels to make the read operation cancellable via context.
272// Returns the read line and any error encountered. If the context is cancelled,
273// returns an empty string and the context's error. EOF is returned when the input
274// stream is closed.
275func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
276	type result struct {
277		line string
278		err  error
279	}
280
281	resultCh := make(chan result, 1)
282
283	go func() {
284		line, err := reader.ReadString('\n')
285		resultCh <- result{line: line, err: err}
286	}()
287
288	select {
289	case <-ctx.Done():
290		return "", nil
291	case res := <-resultCh:
292		return res.line, res.err
293	}
294}
295
296// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
297// It runs until the context is cancelled or an error occurs.
298// Returns an error if there are issues with reading input or writing output.
299func (s *StdioServer) Listen(
300	ctx context.Context,
301	stdin io.Reader,
302	stdout io.Writer,
303) error {
304	// Set a static client context since stdio only has one client
305	if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
306		return fmt.Errorf("register session: %w", err)
307	}
308	defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
309	ctx = s.server.WithContext(ctx, &stdioSessionInstance)
310
311	// Set the writer for sending requests to the client
312	stdioSessionInstance.SetWriter(stdout)
313
314	// Add in any custom context.
315	if s.contextFunc != nil {
316		ctx = s.contextFunc(ctx)
317	}
318
319	reader := bufio.NewReader(stdin)
320
321	// Start notification handler
322	go s.handleNotifications(ctx, stdout)
323	return s.processInputStream(ctx, reader, stdout)
324}
325
326// processMessage handles a single JSON-RPC message and writes the response.
327// It parses the message, processes it through the wrapped MCPServer, and writes any response.
328// Returns an error if there are issues with message processing or response writing.
329func (s *StdioServer) processMessage(
330	ctx context.Context,
331	line string,
332	writer io.Writer,
333) error {
334	// If line is empty, likely due to ctx cancellation
335	if len(line) == 0 {
336		return nil
337	}
338
339	// Parse the message as raw JSON
340	var rawMessage json.RawMessage
341	if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
342		response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
343		return s.writeResponse(response, writer)
344	}
345
346	// Check if this is a response to a sampling request
347	if s.handleSamplingResponse(rawMessage) {
348		return nil
349	}
350
351	// Check if this is a tool call that might need sampling (and thus should be processed concurrently)
352	var baseMessage struct {
353		Method string `json:"method"`
354	}
355	if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
356		// Process tool calls concurrently to avoid blocking on sampling requests
357		go func() {
358			response := s.server.HandleMessage(ctx, rawMessage)
359			if response != nil {
360				if err := s.writeResponse(response, writer); err != nil {
361					s.errLogger.Printf("Error writing tool response: %v", err)
362				}
363			}
364		}()
365		return nil
366	}
367
368	// Handle other messages synchronously
369	response := s.server.HandleMessage(ctx, rawMessage)
370
371	// Only write response if there is one (not for notifications)
372	if response != nil {
373		if err := s.writeResponse(response, writer); err != nil {
374			return fmt.Errorf("failed to write response: %w", err)
375		}
376	}
377
378	return nil
379}
380
381// handleSamplingResponse checks if the message is a response to a sampling request
382// and routes it to the appropriate pending request channel.
383func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
384	return stdioSessionInstance.handleSamplingResponse(rawMessage)
385}
386
387// handleSamplingResponse handles incoming sampling responses for this session
388func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
389	// Try to parse as a JSON-RPC response
390	var response struct {
391		JSONRPC string          `json:"jsonrpc"`
392		ID      json.Number     `json:"id"`
393		Result  json.RawMessage `json:"result,omitempty"`
394		Error   *struct {
395			Code    int    `json:"code"`
396			Message string `json:"message"`
397		} `json:"error,omitempty"`
398	}
399
400	if err := json.Unmarshal(rawMessage, &response); err != nil {
401		return false
402	}
403	// Parse the ID as int64
404	idInt64, err := response.ID.Int64()
405	if err != nil || (response.Result == nil && response.Error == nil) {
406		return false
407	}
408
409	// Look for a pending request with this ID
410	s.pendingMu.RLock()
411	responseChan, exists := s.pendingRequests[idInt64]
412	s.pendingMu.RUnlock()
413
414	if !exists {
415		return false
416	} // Parse and send the response
417	samplingResp := &samplingResponse{}
418
419	if response.Error != nil {
420		samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
421	} else {
422		var result mcp.CreateMessageResult
423		if err := json.Unmarshal(response.Result, &result); err != nil {
424			samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
425		} else {
426			samplingResp.result = &result
427		}
428	}
429
430	// Send the response (non-blocking)
431	select {
432	case responseChan <- samplingResp:
433	default:
434		// Channel is full or closed, ignore
435	}
436
437	return true
438}
439
440// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
441// Returns an error if marshaling or writing fails.
442func (s *StdioServer) writeResponse(
443	response mcp.JSONRPCMessage,
444	writer io.Writer,
445) error {
446	responseBytes, err := json.Marshal(response)
447	if err != nil {
448		return err
449	}
450
451	// Write response followed by newline
452	if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
453		return err
454	}
455
456	return nil
457}
458
459// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
460// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
461// Returns an error if the server encounters any issues during operation.
462func ServeStdio(server *MCPServer, opts ...StdioOption) error {
463	s := NewStdioServer(server)
464
465	for _, opt := range opts {
466		opt(s)
467	}
468
469	ctx, cancel := context.WithCancel(context.Background())
470	defer cancel()
471
472	// Set up signal handling
473	sigChan := make(chan os.Signal, 1)
474	signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
475
476	go func() {
477		<-sigChan
478		cancel()
479	}()
480
481	return s.Listen(ctx, os.Stdin, os.Stdout)
482}