stdio.go

  1package server
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"log"
 10	"os"
 11	"os/signal"
 12	"sync/atomic"
 13	"syscall"
 14
 15	"github.com/mark3labs/mcp-go/mcp"
 16)
 17
 18// StdioContextFunc is a function that takes an existing context and returns
 19// a potentially modified context.
 20// This can be used to inject context values from environment variables,
 21// for example.
 22type StdioContextFunc func(ctx context.Context) context.Context
 23
 24// StdioServer wraps a MCPServer and handles stdio communication.
 25// It provides a simple way to create command-line MCP servers that
 26// communicate via standard input/output streams using JSON-RPC messages.
 27type StdioServer struct {
 28	server      *MCPServer
 29	errLogger   *log.Logger
 30	contextFunc StdioContextFunc
 31}
 32
 33// StdioOption defines a function type for configuring StdioServer
 34type StdioOption func(*StdioServer)
 35
 36// WithErrorLogger sets the error logger for the server
 37func WithErrorLogger(logger *log.Logger) StdioOption {
 38	return func(s *StdioServer) {
 39		s.errLogger = logger
 40	}
 41}
 42
 43// WithStdioContextFunc sets a function that will be called to customise the context
 44// to the server. Note that the stdio server uses the same context for all requests,
 45// so this function will only be called once per server instance.
 46func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
 47	return func(s *StdioServer) {
 48		s.contextFunc = fn
 49	}
 50}
 51
 52// stdioSession is a static client session, since stdio has only one client.
 53type stdioSession struct {
 54	notifications chan mcp.JSONRPCNotification
 55	initialized   atomic.Bool
 56	loggingLevel  atomic.Value
 57	clientInfo    atomic.Value // stores session-specific client info
 58}
 59
 60func (s *stdioSession) SessionID() string {
 61	return "stdio"
 62}
 63
 64func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
 65	return s.notifications
 66}
 67
 68func (s *stdioSession) Initialize() {
 69	// set default logging level
 70	s.loggingLevel.Store(mcp.LoggingLevelError)
 71	s.initialized.Store(true)
 72}
 73
 74func (s *stdioSession) Initialized() bool {
 75	return s.initialized.Load()
 76}
 77
 78func (s *stdioSession) GetClientInfo() mcp.Implementation {
 79	if value := s.clientInfo.Load(); value != nil {
 80		if clientInfo, ok := value.(mcp.Implementation); ok {
 81			return clientInfo
 82		}
 83	}
 84	return mcp.Implementation{}
 85}
 86
 87func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
 88	s.clientInfo.Store(clientInfo)
 89}
 90
 91func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
 92	s.loggingLevel.Store(level)
 93}
 94
 95func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
 96	level := s.loggingLevel.Load()
 97	if level == nil {
 98		return mcp.LoggingLevelError
 99	}
100	return level.(mcp.LoggingLevel)
101}
102
103var (
104	_ ClientSession         = (*stdioSession)(nil)
105	_ SessionWithLogging    = (*stdioSession)(nil)
106	_ SessionWithClientInfo = (*stdioSession)(nil)
107)
108
109var stdioSessionInstance = stdioSession{
110	notifications: make(chan mcp.JSONRPCNotification, 100),
111}
112
113// NewStdioServer creates a new stdio server wrapper around an MCPServer.
114// It initializes the server with a default error logger that discards all output.
115func NewStdioServer(server *MCPServer) *StdioServer {
116	return &StdioServer{
117		server: server,
118		errLogger: log.New(
119			os.Stderr,
120			"",
121			log.LstdFlags,
122		), // Default to discarding logs
123	}
124}
125
126// SetErrorLogger configures where error messages from the StdioServer are logged.
127// The provided logger will receive all error messages generated during server operation.
128func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
129	s.errLogger = logger
130}
131
132// SetContextFunc sets a function that will be called to customise the context
133// to the server. Note that the stdio server uses the same context for all requests,
134// so this function will only be called once per server instance.
135func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
136	s.contextFunc = fn
137}
138
139// handleNotifications continuously processes notifications from the session's notification channel
140// and writes them to the provided output. It runs until the context is cancelled.
141// Any errors encountered while writing notifications are logged but do not stop the handler.
142func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
143	for {
144		select {
145		case notification := <-stdioSessionInstance.notifications:
146			if err := s.writeResponse(notification, stdout); err != nil {
147				s.errLogger.Printf("Error writing notification: %v", err)
148			}
149		case <-ctx.Done():
150			return
151		}
152	}
153}
154
155// processInputStream continuously reads and processes messages from the input stream.
156// It handles EOF gracefully as a normal termination condition.
157// The function returns when either:
158// - The context is cancelled (returns context.Err())
159// - EOF is encountered (returns nil)
160// - An error occurs while reading or processing messages (returns the error)
161func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
162	for {
163		if err := ctx.Err(); err != nil {
164			return err
165		}
166
167		line, err := s.readNextLine(ctx, reader)
168		if err != nil {
169			if err == io.EOF {
170				return nil
171			}
172			s.errLogger.Printf("Error reading input: %v", err)
173			return err
174		}
175
176		if err := s.processMessage(ctx, line, stdout); err != nil {
177			if err == io.EOF {
178				return nil
179			}
180			s.errLogger.Printf("Error handling message: %v", err)
181			return err
182		}
183	}
184}
185
186// readNextLine reads a single line from the input reader in a context-aware manner.
187// It uses channels to make the read operation cancellable via context.
188// Returns the read line and any error encountered. If the context is cancelled,
189// returns an empty string and the context's error. EOF is returned when the input
190// stream is closed.
191func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
192	type result struct {
193		line string
194		err  error
195	}
196
197	resultCh := make(chan result, 1)
198
199	go func() {
200		line, err := reader.ReadString('\n')
201		resultCh <- result{line: line, err: err}
202	}()
203
204	select {
205	case <-ctx.Done():
206		return "", nil
207	case res := <-resultCh:
208		return res.line, res.err
209	}
210}
211
212// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
213// It runs until the context is cancelled or an error occurs.
214// Returns an error if there are issues with reading input or writing output.
215func (s *StdioServer) Listen(
216	ctx context.Context,
217	stdin io.Reader,
218	stdout io.Writer,
219) error {
220	// Set a static client context since stdio only has one client
221	if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
222		return fmt.Errorf("register session: %w", err)
223	}
224	defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
225	ctx = s.server.WithContext(ctx, &stdioSessionInstance)
226
227	// Add in any custom context.
228	if s.contextFunc != nil {
229		ctx = s.contextFunc(ctx)
230	}
231
232	reader := bufio.NewReader(stdin)
233
234	// Start notification handler
235	go s.handleNotifications(ctx, stdout)
236	return s.processInputStream(ctx, reader, stdout)
237}
238
239// processMessage handles a single JSON-RPC message and writes the response.
240// It parses the message, processes it through the wrapped MCPServer, and writes any response.
241// Returns an error if there are issues with message processing or response writing.
242func (s *StdioServer) processMessage(
243	ctx context.Context,
244	line string,
245	writer io.Writer,
246) error {
247	// If line is empty, likely due to ctx cancellation
248	if len(line) == 0 {
249		return nil
250	}
251
252	// Parse the message as raw JSON
253	var rawMessage json.RawMessage
254	if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
255		response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
256		return s.writeResponse(response, writer)
257	}
258
259	// Handle the message using the wrapped server
260	response := s.server.HandleMessage(ctx, rawMessage)
261
262	// Only write response if there is one (not for notifications)
263	if response != nil {
264		if err := s.writeResponse(response, writer); err != nil {
265			return fmt.Errorf("failed to write response: %w", err)
266		}
267	}
268
269	return nil
270}
271
272// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
273// Returns an error if marshaling or writing fails.
274func (s *StdioServer) writeResponse(
275	response mcp.JSONRPCMessage,
276	writer io.Writer,
277) error {
278	responseBytes, err := json.Marshal(response)
279	if err != nil {
280		return err
281	}
282
283	// Write response followed by newline
284	if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
285		return err
286	}
287
288	return nil
289}
290
291// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
292// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
293// Returns an error if the server encounters any issues during operation.
294func ServeStdio(server *MCPServer, opts ...StdioOption) error {
295	s := NewStdioServer(server)
296
297	for _, opt := range opts {
298		opt(s)
299	}
300
301	ctx, cancel := context.WithCancel(context.Background())
302	defer cancel()
303
304	// Set up signal handling
305	sigChan := make(chan os.Signal, 1)
306	signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
307
308	go func() {
309		<-sigChan
310		cancel()
311	}()
312
313	return s.Listen(ctx, os.Stdin, os.Stdout)
314}