1package server
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "log"
10 "os"
11 "os/signal"
12 "sync/atomic"
13 "syscall"
14
15 "github.com/mark3labs/mcp-go/mcp"
16)
17
// StdioContextFunc receives the server's base context and returns the
// context that request handling should use, possibly the same one.
// A typical use is attaching values (for example, ones read from
// environment variables) before any request is processed.
type StdioContextFunc func(ctx context.Context) context.Context
23
24// StdioServer wraps a MCPServer and handles stdio communication.
25// It provides a simple way to create command-line MCP servers that
26// communicate via standard input/output streams using JSON-RPC messages.
27type StdioServer struct {
28 server *MCPServer
29 errLogger *log.Logger
30 contextFunc StdioContextFunc
31}
32
33// StdioOption defines a function type for configuring StdioServer
34type StdioOption func(*StdioServer)
35
36// WithErrorLogger sets the error logger for the server
37func WithErrorLogger(logger *log.Logger) StdioOption {
38 return func(s *StdioServer) {
39 s.errLogger = logger
40 }
41}
42
43// WithStdioContextFunc sets a function that will be called to customise the context
44// to the server. Note that the stdio server uses the same context for all requests,
45// so this function will only be called once per server instance.
46func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
47 return func(s *StdioServer) {
48 s.contextFunc = fn
49 }
50}
51
52// stdioSession is a static client session, since stdio has only one client.
53type stdioSession struct {
54 notifications chan mcp.JSONRPCNotification
55 initialized atomic.Bool
56 loggingLevel atomic.Value
57 clientInfo atomic.Value // stores session-specific client info
58}
59
60func (s *stdioSession) SessionID() string {
61 return "stdio"
62}
63
64func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
65 return s.notifications
66}
67
68func (s *stdioSession) Initialize() {
69 // set default logging level
70 s.loggingLevel.Store(mcp.LoggingLevelError)
71 s.initialized.Store(true)
72}
73
74func (s *stdioSession) Initialized() bool {
75 return s.initialized.Load()
76}
77
78func (s *stdioSession) GetClientInfo() mcp.Implementation {
79 if value := s.clientInfo.Load(); value != nil {
80 if clientInfo, ok := value.(mcp.Implementation); ok {
81 return clientInfo
82 }
83 }
84 return mcp.Implementation{}
85}
86
87func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
88 s.clientInfo.Store(clientInfo)
89}
90
91func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
92 s.loggingLevel.Store(level)
93}
94
95func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
96 level := s.loggingLevel.Load()
97 if level == nil {
98 return mcp.LoggingLevelError
99 }
100 return level.(mcp.LoggingLevel)
101}
102
103var (
104 _ ClientSession = (*stdioSession)(nil)
105 _ SessionWithLogging = (*stdioSession)(nil)
106 _ SessionWithClientInfo = (*stdioSession)(nil)
107)
108
109var stdioSessionInstance = stdioSession{
110 notifications: make(chan mcp.JSONRPCNotification, 100),
111}
112
113// NewStdioServer creates a new stdio server wrapper around an MCPServer.
114// It initializes the server with a default error logger that discards all output.
115func NewStdioServer(server *MCPServer) *StdioServer {
116 return &StdioServer{
117 server: server,
118 errLogger: log.New(
119 os.Stderr,
120 "",
121 log.LstdFlags,
122 ), // Default to discarding logs
123 }
124}
125
126// SetErrorLogger configures where error messages from the StdioServer are logged.
127// The provided logger will receive all error messages generated during server operation.
128func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
129 s.errLogger = logger
130}
131
132// SetContextFunc sets a function that will be called to customise the context
133// to the server. Note that the stdio server uses the same context for all requests,
134// so this function will only be called once per server instance.
135func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
136 s.contextFunc = fn
137}
138
139// handleNotifications continuously processes notifications from the session's notification channel
140// and writes them to the provided output. It runs until the context is cancelled.
141// Any errors encountered while writing notifications are logged but do not stop the handler.
142func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
143 for {
144 select {
145 case notification := <-stdioSessionInstance.notifications:
146 if err := s.writeResponse(notification, stdout); err != nil {
147 s.errLogger.Printf("Error writing notification: %v", err)
148 }
149 case <-ctx.Done():
150 return
151 }
152 }
153}
154
155// processInputStream continuously reads and processes messages from the input stream.
156// It handles EOF gracefully as a normal termination condition.
157// The function returns when either:
158// - The context is cancelled (returns context.Err())
159// - EOF is encountered (returns nil)
160// - An error occurs while reading or processing messages (returns the error)
161func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
162 for {
163 if err := ctx.Err(); err != nil {
164 return err
165 }
166
167 line, err := s.readNextLine(ctx, reader)
168 if err != nil {
169 if err == io.EOF {
170 return nil
171 }
172 s.errLogger.Printf("Error reading input: %v", err)
173 return err
174 }
175
176 if err := s.processMessage(ctx, line, stdout); err != nil {
177 if err == io.EOF {
178 return nil
179 }
180 s.errLogger.Printf("Error handling message: %v", err)
181 return err
182 }
183 }
184}
185
186// readNextLine reads a single line from the input reader in a context-aware manner.
187// It uses channels to make the read operation cancellable via context.
188// Returns the read line and any error encountered. If the context is cancelled,
189// returns an empty string and the context's error. EOF is returned when the input
190// stream is closed.
191func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
192 type result struct {
193 line string
194 err error
195 }
196
197 resultCh := make(chan result, 1)
198
199 go func() {
200 line, err := reader.ReadString('\n')
201 resultCh <- result{line: line, err: err}
202 }()
203
204 select {
205 case <-ctx.Done():
206 return "", nil
207 case res := <-resultCh:
208 return res.line, res.err
209 }
210}
211
212// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
213// It runs until the context is cancelled or an error occurs.
214// Returns an error if there are issues with reading input or writing output.
215func (s *StdioServer) Listen(
216 ctx context.Context,
217 stdin io.Reader,
218 stdout io.Writer,
219) error {
220 // Set a static client context since stdio only has one client
221 if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
222 return fmt.Errorf("register session: %w", err)
223 }
224 defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
225 ctx = s.server.WithContext(ctx, &stdioSessionInstance)
226
227 // Add in any custom context.
228 if s.contextFunc != nil {
229 ctx = s.contextFunc(ctx)
230 }
231
232 reader := bufio.NewReader(stdin)
233
234 // Start notification handler
235 go s.handleNotifications(ctx, stdout)
236 return s.processInputStream(ctx, reader, stdout)
237}
238
239// processMessage handles a single JSON-RPC message and writes the response.
240// It parses the message, processes it through the wrapped MCPServer, and writes any response.
241// Returns an error if there are issues with message processing or response writing.
242func (s *StdioServer) processMessage(
243 ctx context.Context,
244 line string,
245 writer io.Writer,
246) error {
247 // If line is empty, likely due to ctx cancellation
248 if len(line) == 0 {
249 return nil
250 }
251
252 // Parse the message as raw JSON
253 var rawMessage json.RawMessage
254 if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
255 response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
256 return s.writeResponse(response, writer)
257 }
258
259 // Handle the message using the wrapped server
260 response := s.server.HandleMessage(ctx, rawMessage)
261
262 // Only write response if there is one (not for notifications)
263 if response != nil {
264 if err := s.writeResponse(response, writer); err != nil {
265 return fmt.Errorf("failed to write response: %w", err)
266 }
267 }
268
269 return nil
270}
271
272// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
273// Returns an error if marshaling or writing fails.
274func (s *StdioServer) writeResponse(
275 response mcp.JSONRPCMessage,
276 writer io.Writer,
277) error {
278 responseBytes, err := json.Marshal(response)
279 if err != nil {
280 return err
281 }
282
283 // Write response followed by newline
284 if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
285 return err
286 }
287
288 return nil
289}
290
291// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
292// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
293// Returns an error if the server encounters any issues during operation.
294func ServeStdio(server *MCPServer, opts ...StdioOption) error {
295 s := NewStdioServer(server)
296
297 for _, opt := range opts {
298 opt(s)
299 }
300
301 ctx, cancel := context.WithCancel(context.Background())
302 defer cancel()
303
304 // Set up signal handling
305 sigChan := make(chan os.Signal, 1)
306 signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
307
308 go func() {
309 <-sigChan
310 cancel()
311 }()
312
313 return s.Listen(ctx, os.Stdin, os.Stdout)
314}