1package server
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log"
8 "net/http"
9 "net/http/httptest"
10 "net/url"
11 "path"
12 "strings"
13 "sync"
14 "sync/atomic"
15 "time"
16
17 "github.com/google/uuid"
18
19 "github.com/mark3labs/mcp-go/mcp"
20)
21
// sseSession represents an active SSE connection and holds the
// per-connection state exposed through the session methods below.
type sseSession struct {
	// done is closed to signal that the session has ended; the goroutines
	// started by handleSSE select on it to stop.
	done chan struct{}
	// eventQueue buffers fully formatted SSE events until the handleSSE
	// loop writes them to the HTTP response.
	eventQueue chan string
	// sessionID is the identifier returned by SessionID.
	sessionID string
	// requestID numbers the keep-alive ping requests sent on this session.
	requestID atomic.Int64
	// notificationChannel receives notifications that handleSSE forwards
	// to eventQueue.
	notificationChannel chan mcp.JSONRPCNotification
	// initialized is set to true by Initialize.
	initialized atomic.Bool
	// loggingLevel holds an mcp.LoggingLevel once Initialize or SetLogLevel
	// has stored one; GetLogLevel falls back to LoggingLevelError before that.
	loggingLevel atomic.Value
	tools        sync.Map     // stores session-specific tools (name -> ServerTool)
	clientInfo   atomic.Value // stores session-specific client info (mcp.Implementation)
}
34
// SSEContextFunc is a function that takes an existing context and the current
// request and returns a potentially modified context based on the request
// content. This can be used to inject context values from headers, for example.
// The server calls it once per incoming message request, after the session
// has been attached to the context.
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
39
// DynamicBasePathFunc allows the user to provide a function to generate the
// base path for a given request and sessionID. This is useful for cases where
// the base path is not known at the time of SSE server creation, such as when
// using a reverse proxy or when the base path is dynamically generated. The
// function should return the base path (e.g., "/mcp/tenant123"); the value is
// normalized with normalizeURLPath when installed via WithDynamicBasePath.
type DynamicBasePathFunc func(r *http.Request, sessionID string) string
46
// SessionID returns the identifier assigned to this session when it was created.
func (s *sseSession) SessionID() string {
	return s.sessionID
}
50
// NotificationChannel returns the send-only side of the channel on which
// notifications for this session are delivered. handleSSE drains it and
// forwards each notification to the client as an SSE "message" event.
func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
	return s.notificationChannel
}
54
// Initialize marks the session as initialized. It also stores
// mcp.LoggingLevelError as the logging level, replacing any level that was
// set earlier through SetLogLevel.
func (s *sseSession) Initialize() {
	// set default logging level
	s.loggingLevel.Store(mcp.LoggingLevelError)
	s.initialized.Store(true)
}
60
// Initialized reports whether Initialize has been called on this session.
func (s *sseSession) Initialized() bool {
	return s.initialized.Load()
}
64
// SetLogLevel stores level as the logging level for this session.
func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) {
	s.loggingLevel.Store(level)
}
68
69func (s *sseSession) GetLogLevel() mcp.LoggingLevel {
70 level := s.loggingLevel.Load()
71 if level == nil {
72 return mcp.LoggingLevelError
73 }
74 return level.(mcp.LoggingLevel)
75}
76
77func (s *sseSession) GetSessionTools() map[string]ServerTool {
78 tools := make(map[string]ServerTool)
79 s.tools.Range(func(key, value any) bool {
80 if tool, ok := value.(ServerTool); ok {
81 tools[key.(string)] = tool
82 }
83 return true
84 })
85 return tools
86}
87
88func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
89 // Clear existing tools
90 s.tools.Clear()
91
92 // Set new tools
93 for name, tool := range tools {
94 s.tools.Store(name, tool)
95 }
96}
97
98func (s *sseSession) GetClientInfo() mcp.Implementation {
99 if value := s.clientInfo.Load(); value != nil {
100 if clientInfo, ok := value.(mcp.Implementation); ok {
101 return clientInfo
102 }
103 }
104 return mcp.Implementation{}
105}
106
// SetClientInfo stores clientInfo as the client implementation info for this
// session; GetClientInfo returns it afterwards.
func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) {
	s.clientInfo.Store(clientInfo)
}
110
// Compile-time checks that *sseSession satisfies every session interface it
// is expected to implement.
var (
	_ ClientSession         = (*sseSession)(nil)
	_ SessionWithTools      = (*sseSession)(nil)
	_ SessionWithLogging    = (*sseSession)(nil)
	_ SessionWithClientInfo = (*sseSession)(nil)
)
117
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
	// server is the MCP server that sessions are registered with and that
	// handles incoming messages.
	server *MCPServer
	// baseURL is prepended to endpoint paths when full URLs are produced.
	baseURL string
	// basePath is the static path prefix joined in front of the endpoints.
	basePath string
	// appendQueryToMessageEndpoint makes handleSSE append the SSE request's
	// raw query string to the advertised message endpoint.
	appendQueryToMessageEndpoint bool
	// useFullURLForMessageEndpoint makes the advertised message endpoint
	// include baseURL (when baseURL is non-empty) instead of the path only.
	useFullURLForMessageEndpoint bool
	messageEndpoint              string
	sseEndpoint                  string
	// sessions maps session ID to *sseSession for every open SSE connection.
	sessions sync.Map
	// srv is the HTTP server used by Start and Shutdown; guarded by mu.
	srv *http.Server
	// contextFunc, when set, customises the context of each message request.
	contextFunc SSEContextFunc
	// dynamicBasePathFunc, when set, computes the base path per request and
	// makes ServeHTTP and the Complete*Endpoint methods report an error.
	dynamicBasePathFunc DynamicBasePathFunc

	// keepAlive enables periodic ping requests on every SSE connection,
	// sent every keepAliveInterval.
	keepAlive         bool
	keepAliveInterval time.Duration

	mu sync.RWMutex
}
138
// SSEOption defines a function type for configuring SSEServer.
// NewSSEServer applies options in the order they are passed, so a later
// option overrides an earlier one that sets the same field.
type SSEOption func(*SSEServer)
141
142// WithBaseURL sets the base URL for the SSE server
143func WithBaseURL(baseURL string) SSEOption {
144 return func(s *SSEServer) {
145 if baseURL != "" {
146 u, err := url.Parse(baseURL)
147 if err != nil {
148 return
149 }
150 if u.Scheme != "http" && u.Scheme != "https" {
151 return
152 }
153 // Check if the host is empty or only contains a port
154 if u.Host == "" || strings.HasPrefix(u.Host, ":") {
155 return
156 }
157 if len(u.Query()) > 0 {
158 return
159 }
160 }
161 s.baseURL = strings.TrimSuffix(baseURL, "/")
162 }
163}
164
165// WithStaticBasePath adds a new option for setting a static base path
166func WithStaticBasePath(basePath string) SSEOption {
167 return func(s *SSEServer) {
168 s.basePath = normalizeURLPath(basePath)
169 }
170}
171
172// WithBasePath adds a new option for setting a static base path.
173//
174// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version.
175//
176//go:deprecated
177func WithBasePath(basePath string) SSEOption {
178 return WithStaticBasePath(basePath)
179}
180
181// WithDynamicBasePath accepts a function for generating the base path. This is
182// useful for cases where the base path is not known at the time of SSE server
183// creation, such as when using a reverse proxy or when the server is mounted
184// at a dynamic path.
185func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption {
186 return func(s *SSEServer) {
187 if fn != nil {
188 s.dynamicBasePathFunc = func(r *http.Request, sid string) string {
189 bp := fn(r, sid)
190 return normalizeURLPath(bp)
191 }
192 }
193 }
194}
195
196// WithMessageEndpoint sets the message endpoint path
197func WithMessageEndpoint(endpoint string) SSEOption {
198 return func(s *SSEServer) {
199 s.messageEndpoint = endpoint
200 }
201}
202
203// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's
204// query parameters to the message endpoint URL that is sent to clients during the SSE connection
205// initialization. This is useful when you need to preserve query parameters from the initial
206// SSE connection request and carry them over to subsequent message requests, maintaining
207// context or authentication details across the communication channel.
208func WithAppendQueryToMessageEndpoint() SSEOption {
209 return func(s *SSEServer) {
210 s.appendQueryToMessageEndpoint = true
211 }
212}
213
214// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL)
215// or just the path portion for the message endpoint. Set to false when clients will concatenate
216// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message".
217func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption {
218 return func(s *SSEServer) {
219 s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint
220 }
221}
222
223// WithSSEEndpoint sets the SSE endpoint path
224func WithSSEEndpoint(endpoint string) SSEOption {
225 return func(s *SSEServer) {
226 s.sseEndpoint = endpoint
227 }
228}
229
230// WithHTTPServer sets the HTTP server instance.
231// NOTE: When providing a custom HTTP server, you must handle routing yourself
232// If routing is not set up, the server will start but won't handle any MCP requests.
233func WithHTTPServer(srv *http.Server) SSEOption {
234 return func(s *SSEServer) {
235 s.srv = srv
236 }
237}
238
239func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
240 return func(s *SSEServer) {
241 s.keepAlive = true
242 s.keepAliveInterval = keepAliveInterval
243 }
244}
245
246func WithKeepAlive(keepAlive bool) SSEOption {
247 return func(s *SSEServer) {
248 s.keepAlive = keepAlive
249 }
250}
251
252// WithSSEContextFunc sets a function that will be called to customise the context
253// to the server using the incoming request.
254func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
255 return func(s *SSEServer) {
256 s.contextFunc = fn
257 }
258}
259
260// NewSSEServer creates a new SSE server instance with the given MCP server and options.
261func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
262 s := &SSEServer{
263 server: server,
264 sseEndpoint: "/sse",
265 messageEndpoint: "/message",
266 useFullURLForMessageEndpoint: true,
267 keepAlive: false,
268 keepAliveInterval: 10 * time.Second,
269 }
270
271 // Apply all options
272 for _, opt := range opts {
273 opt(s)
274 }
275
276 return s
277}
278
279// NewTestServer creates a test server for testing purposes
280func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
281 sseServer := NewSSEServer(server, opts...)
282
283 testServer := httptest.NewServer(sseServer)
284 sseServer.baseURL = testServer.URL
285 return testServer
286}
287
288// Start begins serving SSE connections on the specified address.
289// It sets up HTTP handlers for SSE and message endpoints.
290func (s *SSEServer) Start(addr string) error {
291 s.mu.Lock()
292 if s.srv == nil {
293 s.srv = &http.Server{
294 Addr: addr,
295 Handler: s,
296 }
297 } else {
298 if s.srv.Addr == "" {
299 s.srv.Addr = addr
300 } else if s.srv.Addr != addr {
301 return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr)
302 }
303 }
304 srv := s.srv
305 s.mu.Unlock()
306
307 return srv.ListenAndServe()
308}
309
310// Shutdown gracefully stops the SSE server, closing all active sessions
311// and shutting down the HTTP server.
312func (s *SSEServer) Shutdown(ctx context.Context) error {
313 s.mu.RLock()
314 srv := s.srv
315 s.mu.RUnlock()
316
317 if srv != nil {
318 s.sessions.Range(func(key, value any) bool {
319 if session, ok := value.(*sseSession); ok {
320 close(session.done)
321 }
322 s.sessions.Delete(key)
323 return true
324 })
325
326 return srv.Shutdown(ctx)
327 }
328 return nil
329}
330
331// handleSSE handles incoming SSE connection requests.
332// It sets up appropriate headers and creates a new session for the client.
333func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
334 if r.Method != http.MethodGet {
335 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
336 return
337 }
338
339 w.Header().Set("Content-Type", "text/event-stream")
340 w.Header().Set("Cache-Control", "no-cache")
341 w.Header().Set("Connection", "keep-alive")
342 w.Header().Set("Access-Control-Allow-Origin", "*")
343
344 flusher, ok := w.(http.Flusher)
345 if !ok {
346 http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
347 return
348 }
349
350 sessionID := uuid.New().String()
351 session := &sseSession{
352 done: make(chan struct{}),
353 eventQueue: make(chan string, 100), // Buffer for events
354 sessionID: sessionID,
355 notificationChannel: make(chan mcp.JSONRPCNotification, 100),
356 }
357
358 s.sessions.Store(sessionID, session)
359 defer s.sessions.Delete(sessionID)
360
361 if err := s.server.RegisterSession(r.Context(), session); err != nil {
362 http.Error(
363 w,
364 fmt.Sprintf("Session registration failed: %v", err),
365 http.StatusInternalServerError,
366 )
367 return
368 }
369 defer s.server.UnregisterSession(r.Context(), sessionID)
370
371 // Start notification handler for this session
372 go func() {
373 for {
374 select {
375 case notification := <-session.notificationChannel:
376 eventData, err := json.Marshal(notification)
377 if err == nil {
378 select {
379 case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
380 // Event queued successfully
381 case <-session.done:
382 return
383 }
384 }
385 case <-session.done:
386 return
387 case <-r.Context().Done():
388 return
389 }
390 }
391 }()
392
393 // Start keep alive : ping
394 if s.keepAlive {
395 go func() {
396 ticker := time.NewTicker(s.keepAliveInterval)
397 defer ticker.Stop()
398 for {
399 select {
400 case <-ticker.C:
401 message := mcp.JSONRPCRequest{
402 JSONRPC: "2.0",
403 ID: mcp.NewRequestId(session.requestID.Add(1)),
404 Request: mcp.Request{
405 Method: "ping",
406 },
407 }
408 messageBytes, _ := json.Marshal(message)
409 pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes)
410 select {
411 case session.eventQueue <- pingMsg:
412 // Message sent successfully
413 case <-session.done:
414 return
415 }
416 case <-session.done:
417 return
418 case <-r.Context().Done():
419 return
420 }
421 }
422 }()
423 }
424
425 // Send the initial endpoint event
426 endpoint := s.GetMessageEndpointForClient(r, sessionID)
427 if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 {
428 endpoint += "&" + r.URL.RawQuery
429 }
430 fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint)
431 flusher.Flush()
432
433 // Main event loop - this runs in the HTTP handler goroutine
434 for {
435 select {
436 case event := <-session.eventQueue:
437 // Write the event to the response
438 fmt.Fprint(w, event)
439 flusher.Flush()
440 case <-r.Context().Done():
441 close(session.done)
442 return
443 case <-session.done:
444 return
445 }
446 }
447}
448
449// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID
450// for the given request. This is the canonical way to compute the message endpoint for a client.
451// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag.
452func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string {
453 basePath := s.basePath
454 if s.dynamicBasePathFunc != nil {
455 basePath = s.dynamicBasePathFunc(r, sessionID)
456 }
457
458 endpointPath := normalizeURLPath(basePath, s.messageEndpoint)
459 if s.useFullURLForMessageEndpoint && s.baseURL != "" {
460 endpointPath = s.baseURL + endpointPath
461 }
462
463 return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID)
464}
465
466// handleMessage processes incoming JSON-RPC messages from clients and sends responses
467// back through the SSE connection and 202 code to HTTP response.
468func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
469 if r.Method != http.MethodPost {
470 s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
471 return
472 }
473
474 sessionID := r.URL.Query().Get("sessionId")
475 if sessionID == "" {
476 s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
477 return
478 }
479 sessionI, ok := s.sessions.Load(sessionID)
480 if !ok {
481 s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
482 return
483 }
484 session := sessionI.(*sseSession)
485
486 // Set the client context before handling the message
487 ctx := s.server.WithContext(r.Context(), session)
488 if s.contextFunc != nil {
489 ctx = s.contextFunc(ctx, r)
490 }
491
492 // Parse message as raw JSON
493 var rawMessage json.RawMessage
494 if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
495 s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
496 return
497 }
498
499 // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled.
500 // this is required because the http ctx will be canceled when the client disconnects
501 detachedCtx := context.WithoutCancel(ctx)
502
503 // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE
504 w.WriteHeader(http.StatusAccepted)
505
506 // Create a new context for handling the message that will be canceled when the message handling is done
507 messageCtx, cancel := context.WithCancel(detachedCtx)
508
509 go func(ctx context.Context) {
510 defer cancel()
511 // Use the context that will be canceled when session is done
512 // Process message through MCPServer
513 response := s.server.HandleMessage(ctx, rawMessage)
514 // Only send response if there is one (not for notifications)
515 if response != nil {
516 var message string
517 if eventData, err := json.Marshal(response); err != nil {
518 // If there is an error marshalling the response, send a generic error response
519 log.Printf("failed to marshal response: %v", err)
520 message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
521 } else {
522 message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
523 }
524
525 // Queue the event for sending via SSE
526 select {
527 case session.eventQueue <- message:
528 // Event queued successfully
529 case <-session.done:
530 // Session is closed, don't try to queue
531 default:
532 // Queue is full, log this situation
533 log.Printf("Event queue full for session %s", sessionID)
534 }
535 }
536 }(messageCtx)
537}
538
539// writeJSONRPCError writes a JSON-RPC error response with the given error details.
540func (s *SSEServer) writeJSONRPCError(
541 w http.ResponseWriter,
542 id any,
543 code int,
544 message string,
545) {
546 response := createErrorResponse(id, code, message)
547 w.Header().Set("Content-Type", "application/json")
548 w.WriteHeader(http.StatusBadRequest)
549 if err := json.NewEncoder(w).Encode(response); err != nil {
550 http.Error(
551 w,
552 fmt.Sprintf("Failed to encode response: %v", err),
553 http.StatusInternalServerError,
554 )
555 return
556 }
557}
558
559// SendEventToSession sends an event to a specific SSE session identified by sessionID.
560// Returns an error if the session is not found or closed.
561func (s *SSEServer) SendEventToSession(
562 sessionID string,
563 event any,
564) error {
565 sessionI, ok := s.sessions.Load(sessionID)
566 if !ok {
567 return fmt.Errorf("session not found: %s", sessionID)
568 }
569 session := sessionI.(*sseSession)
570
571 eventData, err := json.Marshal(event)
572 if err != nil {
573 return err
574 }
575
576 // Queue the event for sending via SSE
577 select {
578 case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
579 return nil
580 case <-session.done:
581 return fmt.Errorf("session closed")
582 default:
583 return fmt.Errorf("event queue full")
584 }
585}
586
587func (s *SSEServer) GetUrlPath(input string) (string, error) {
588 parse, err := url.Parse(input)
589 if err != nil {
590 return "", fmt.Errorf("failed to parse URL %s: %w", input, err)
591 }
592 return parse.Path, nil
593}
594
595func (s *SSEServer) CompleteSseEndpoint() (string, error) {
596 if s.dynamicBasePathFunc != nil {
597 return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"}
598 }
599
600 path := normalizeURLPath(s.basePath, s.sseEndpoint)
601 return s.baseURL + path, nil
602}
603
604func (s *SSEServer) CompleteSsePath() string {
605 path, err := s.CompleteSseEndpoint()
606 if err != nil {
607 return normalizeURLPath(s.basePath, s.sseEndpoint)
608 }
609 urlPath, err := s.GetUrlPath(path)
610 if err != nil {
611 return normalizeURLPath(s.basePath, s.sseEndpoint)
612 }
613 return urlPath
614}
615
616func (s *SSEServer) CompleteMessageEndpoint() (string, error) {
617 if s.dynamicBasePathFunc != nil {
618 return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"}
619 }
620 path := normalizeURLPath(s.basePath, s.messageEndpoint)
621 return s.baseURL + path, nil
622}
623
624func (s *SSEServer) CompleteMessagePath() string {
625 path, err := s.CompleteMessageEndpoint()
626 if err != nil {
627 return normalizeURLPath(s.basePath, s.messageEndpoint)
628 }
629 urlPath, err := s.GetUrlPath(path)
630 if err != nil {
631 return normalizeURLPath(s.basePath, s.messageEndpoint)
632 }
633 return urlPath
634}
635
// SSEHandler returns an http.Handler for the SSE endpoint.
//
// This method allows you to mount the SSE handler at any arbitrary path
// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
// intended for advanced scenarios where you want to control the routing or
// support dynamic segments.
//
// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
// you must use the WithDynamicBasePath option to ensure the correct base path
// is communicated to clients.
//
// Example usage:
//
//	// Advanced/dynamic:
//	sseServer := NewSSEServer(mcpServer,
//		WithDynamicBasePath(func(r *http.Request, sessionID string) string {
//			tenant := r.PathValue("tenant")
//			return "/mcp/" + tenant
//		}),
//		WithBaseURL("http://localhost:8080"),
//	)
//	mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
//	mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
//
// For non-dynamic cases, use ServeHTTP method instead.
func (s *SSEServer) SSEHandler() http.Handler {
	return http.HandlerFunc(s.handleSSE)
}
664
// MessageHandler returns an http.Handler for the message endpoint.
//
// This method allows you to mount the message handler at any arbitrary path
// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
// intended for advanced scenarios where you want to control the routing or
// support dynamic segments.
//
// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
// you must use the WithDynamicBasePath option to ensure the correct base path
// is communicated to clients.
//
// Example usage:
//
//	// Advanced/dynamic:
//	sseServer := NewSSEServer(mcpServer,
//		WithDynamicBasePath(func(r *http.Request, sessionID string) string {
//			tenant := r.PathValue("tenant")
//			return "/mcp/" + tenant
//		}),
//		WithBaseURL("http://localhost:8080"),
//	)
//	mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
//	mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
//
// For non-dynamic cases, use ServeHTTP method instead.
func (s *SSEServer) MessageHandler() http.Handler {
	return http.HandlerFunc(s.handleMessage)
}
693
694// ServeHTTP implements the http.Handler interface.
695func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
696 if s.dynamicBasePathFunc != nil {
697 http.Error(
698 w,
699 (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(),
700 http.StatusInternalServerError,
701 )
702 return
703 }
704 path := r.URL.Path
705 // Use exact path matching rather than Contains
706 ssePath := s.CompleteSsePath()
707 if ssePath != "" && path == ssePath {
708 s.handleSSE(w, r)
709 return
710 }
711 messagePath := s.CompleteMessagePath()
712 if messagePath != "" && path == messagePath {
713 s.handleMessage(w, r)
714 return
715 }
716
717 http.NotFound(w, r)
718}
719
// normalizeURLPath joins the given elements the way path.Join does and then
// guarantees that the result begins with "/" and, unless it is exactly "/",
// does not end with "/". With no elements, or only empty ones, the result
// is "/".
func normalizeURLPath(elem ...string) string {
	result := path.Join(elem...)

	// path.Join yields a relative path (or "") when the first non-empty
	// element is relative; anchor it at the root.
	if !strings.HasPrefix(result, "/") {
		result = "/" + result
	}

	// Keep the root as is; strip a single trailing slash from anything else.
	if result != "/" {
		result = strings.TrimSuffix(result, "/")
	}

	return result
}