1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"net/http/httptest"
 10	"strings"
 11	"sync"
 12	"sync/atomic"
 13	"time"
 14
 15	"github.com/google/uuid"
 16	"github.com/mark3labs/mcp-go/mcp"
 17	"github.com/mark3labs/mcp-go/util"
 18)
 19
// StreamableHTTPOption is a functional option that configures a
// StreamableHTTPServer. Options are passed to NewStreamableHTTPServer, which
// applies them in the order they are given.
type StreamableHTTPOption func(*StreamableHTTPServer)
 22
 23// WithEndpointPath sets the endpoint path for the server.
 24// The default is "/mcp".
 25// It's only works for `Start` method. When used as a http.Handler, it has no effect.
 26func WithEndpointPath(endpointPath string) StreamableHTTPOption {
 27	return func(s *StreamableHTTPServer) {
 28		// Normalize the endpoint path to ensure it starts with a slash and doesn't end with one
 29		normalizedPath := "/" + strings.Trim(endpointPath, "/")
 30		s.endpointPath = normalizedPath
 31	}
 32}
 33
 34// WithStateLess sets the server to stateless mode.
 35// If true, the server will manage no session information. Every request will be treated
 36// as a new session. No session id returned to the client.
 37// The default is false.
 38//
 39// Notice: This is a convenience method. It's identical to set WithSessionIdManager option
 40// to StatelessSessionIdManager.
 41func WithStateLess(stateLess bool) StreamableHTTPOption {
 42	return func(s *StreamableHTTPServer) {
 43		if stateLess {
 44			s.sessionIdManager = &StatelessSessionIdManager{}
 45		}
 46	}
 47}
 48
 49// WithSessionIdManager sets a custom session id generator for the server.
 50// By default, the server will use SimpleStatefulSessionIdGenerator, which generates
 51// session ids with uuid, and it's insecure.
 52// Notice: it will override the WithStateLess option.
 53func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
 54	return func(s *StreamableHTTPServer) {
 55		s.sessionIdManager = manager
 56	}
 57}
 58
 59// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the
 60// server will send a heartbeat to the client through the GET connection, to keep
 61// the connection alive from being closed by the network infrastructure (e.g.
 62// gateways). If the client does not establish a GET connection, it has no
 63// effect. The default is not to send heartbeats.
 64func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
 65	return func(s *StreamableHTTPServer) {
 66		s.listenHeartbeatInterval = interval
 67	}
 68}
 69
 70// WithHTTPContextFunc sets a function that will be called to customise the context
 71// to the server using the incoming request.
 72// This can be used to inject context values from headers, for example.
 73func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
 74	return func(s *StreamableHTTPServer) {
 75		s.contextFunc = fn
 76	}
 77}
 78
 79// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer.
 80// NOTE: When providing a custom HTTP server, you must handle routing yourself
 81// If routing is not set up, the server will start but won't handle any MCP requests.
 82func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
 83	return func(s *StreamableHTTPServer) {
 84		s.httpServer = srv
 85	}
 86}
 87
 88// WithLogger sets the logger for the server
 89func WithLogger(logger util.Logger) StreamableHTTPOption {
 90	return func(s *StreamableHTTPServer) {
 91		s.logger = logger
 92	}
 93}
 94
// StreamableHTTPServer implements a Streamable-HTTP based MCP server.
// It communicates with clients over HTTP, supporting both direct HTTP
// responses and SSE streams.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
//
// Usage:
//
//	server := NewStreamableHTTPServer(mcpServer)
//	server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default
//
// or the server itself can be used as an http.Handler, which is convenient for
// integrating with existing HTTP servers, or for advanced usage:
//
//	handler := NewStreamableHTTPServer(mcpServer)
//	http.Handle("/streamable-http", handler)
//	http.ListenAndServe(":8080", nil)
//
// Notice:
// Only the GET handler (listening) registers a session with the MCP server.
// The POST handler (request/notification) does not, so methods like
// `SendNotificationToSpecificClient` or `hooks.onRegisterSession` will not be
// triggered for POST messages.
//
// The current implementation does not support the following features from the specification:
//   - Batching of requests/notifications/responses in arrays.
//   - Stream Resumability
type StreamableHTTPServer struct {
	// server is the MCP server that processes the JSON-RPC messages.
	server *MCPServer
	// sessionTools holds per-session tools, keyed by session ID.
	sessionTools      *sessionToolsStore
	sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64)

	// httpServer is set by WithStreamableHTTPServer or created by Start.
	// It stays nil when the server is only used as an http.Handler.
	httpServer *http.Server
	// mu guards httpServer in Start and Shutdown.
	mu sync.RWMutex

	// endpointPath is the route registered by Start (default "/mcp").
	endpointPath string
	// contextFunc, if set, customizes the context of POST requests.
	contextFunc HTTPContextFunc
	// sessionIdManager generates, validates and terminates session IDs.
	sessionIdManager SessionIdManager
	// listenHeartbeatInterval is the ping interval on GET streams; values
	// that are not positive disable heartbeats.
	listenHeartbeatInterval time.Duration
	logger                  util.Logger
}
133
134// NewStreamableHTTPServer creates a new streamable-http server instance
135func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
136	s := &StreamableHTTPServer{
137		server:           server,
138		sessionTools:     newSessionToolsStore(),
139		endpointPath:     "/mcp",
140		sessionIdManager: &InsecureStatefulSessionIdManager{},
141		logger:           util.DefaultLogger(),
142	}
143
144	// Apply all options
145	for _, opt := range opts {
146		opt(s)
147	}
148	return s
149}
150
151// ServeHTTP implements the http.Handler interface.
152func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
153	switch r.Method {
154	case http.MethodPost:
155		s.handlePost(w, r)
156	case http.MethodGet:
157		s.handleGet(w, r)
158	case http.MethodDelete:
159		s.handleDelete(w, r)
160	default:
161		http.NotFound(w, r)
162	}
163}
164
165// Start begins serving the http server on the specified address and path
166// (endpointPath). like:
167//
168//	s.Start(":8080")
169func (s *StreamableHTTPServer) Start(addr string) error {
170	s.mu.Lock()
171	if s.httpServer == nil {
172		mux := http.NewServeMux()
173		mux.Handle(s.endpointPath, s)
174		s.httpServer = &http.Server{
175			Addr:    addr,
176			Handler: mux,
177		}
178	} else {
179		if s.httpServer.Addr == "" {
180			s.httpServer.Addr = addr
181		} else if s.httpServer.Addr != addr {
182			return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
183		}
184	}
185	srv := s.httpServer
186	s.mu.Unlock()
187
188	return srv.ListenAndServe()
189}
190
191// Shutdown gracefully stops the server, closing all active sessions
192// and shutting down the HTTP server.
193func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
194
195	// shutdown the server if needed (may use as a http.Handler)
196	s.mu.RLock()
197	srv := s.httpServer
198	s.mu.RUnlock()
199	if srv != nil {
200		return srv.Shutdown(ctx)
201	}
202	return nil
203}
204
205// --- internal methods ---
206
const (
	// headerKeySessionID is the HTTP header that carries the session ID, both
	// in client requests and in the response to an initialize request.
	headerKeySessionID = "Mcp-Session-Id"
)
210
211func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
212	// post request carry request/notification message
213
214	// Check content type
215	contentType := r.Header.Get("Content-Type")
216	if contentType != "application/json" {
217		http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
218		return
219	}
220
221	// Check the request body is valid json, meanwhile, get the request Method
222	rawData, err := io.ReadAll(r.Body)
223	if err != nil {
224		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err))
225		return
226	}
227	var baseMessage struct {
228		Method mcp.MCPMethod `json:"method"`
229	}
230	if err := json.Unmarshal(rawData, &baseMessage); err != nil {
231		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
232		return
233	}
234	isInitializeRequest := baseMessage.Method == mcp.MethodInitialize
235
236	// Prepare the session for the mcp server
237	// The session is ephemeral. Its life is the same as the request. It's only created
238	// for interaction with the mcp server.
239	var sessionID string
240	if isInitializeRequest {
241		// generate a new one for initialize request
242		sessionID = s.sessionIdManager.Generate()
243	} else {
244		// Get session ID from header.
245		// Stateful servers need the client to carry the session ID.
246		sessionID = r.Header.Get(headerKeySessionID)
247		isTerminated, err := s.sessionIdManager.Validate(sessionID)
248		if err != nil {
249			http.Error(w, "Invalid session ID", http.StatusBadRequest)
250			return
251		}
252		if isTerminated {
253			http.Error(w, "Session terminated", http.StatusNotFound)
254			return
255		}
256	}
257
258	session := newStreamableHttpSession(sessionID, s.sessionTools)
259
260	// Set the client context before handling the message
261	ctx := s.server.WithContext(r.Context(), session)
262	if s.contextFunc != nil {
263		ctx = s.contextFunc(ctx, r)
264	}
265
266	// handle potential notifications
267	mu := sync.Mutex{}
268	upgradedHeader := false
269	done := make(chan struct{})
270
271	go func() {
272		for {
273			select {
274			case nt := <-session.notificationChannel:
275				func() {
276					mu.Lock()
277					defer mu.Unlock()
278					// if the done chan is closed, as the request is terminated, just return
279					select {
280					case <-done:
281						return
282					default:
283					}
284					defer func() {
285						flusher, ok := w.(http.Flusher)
286						if ok {
287							flusher.Flush()
288						}
289					}()
290
291					// if there's notifications, upgradedHeader to SSE response
292					if !upgradedHeader {
293						w.Header().Set("Content-Type", "text/event-stream")
294						w.Header().Set("Connection", "keep-alive")
295						w.Header().Set("Cache-Control", "no-cache")
296						w.WriteHeader(http.StatusAccepted)
297						upgradedHeader = true
298					}
299					err := writeSSEEvent(w, nt)
300					if err != nil {
301						s.logger.Errorf("Failed to write SSE event: %v", err)
302						return
303					}
304				}()
305			case <-done:
306				return
307			case <-ctx.Done():
308				return
309			}
310		}
311	}()
312
313	// Process message through MCPServer
314	response := s.server.HandleMessage(ctx, rawData)
315	if response == nil {
316		// For notifications, just send 202 Accepted with no body
317		w.WriteHeader(http.StatusAccepted)
318		return
319	}
320
321	// Write response
322	mu.Lock()
323	defer mu.Unlock()
324	// close the done chan before unlock
325	defer close(done)
326	if ctx.Err() != nil {
327		return
328	}
329	// If client-server communication already upgraded to SSE stream
330	if session.upgradeToSSE.Load() {
331		if !upgradedHeader {
332			w.Header().Set("Content-Type", "text/event-stream")
333			w.Header().Set("Connection", "keep-alive")
334			w.Header().Set("Cache-Control", "no-cache")
335			w.WriteHeader(http.StatusAccepted)
336			upgradedHeader = true
337		}
338		if err := writeSSEEvent(w, response); err != nil {
339			s.logger.Errorf("Failed to write final SSE response event: %v", err)
340		}
341	} else {
342		w.Header().Set("Content-Type", "application/json")
343		if isInitializeRequest && sessionID != "" {
344			// send the session ID back to the client
345			w.Header().Set(headerKeySessionID, sessionID)
346		}
347		w.WriteHeader(http.StatusOK)
348		err := json.NewEncoder(w).Encode(response)
349		if err != nil {
350			s.logger.Errorf("Failed to write response: %v", err)
351		}
352	}
353}
354
355func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
356	// get request is for listening to notifications
357	// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
358
359	sessionID := r.Header.Get(headerKeySessionID)
360	// the specification didn't say we should validate the session id
361
362	if sessionID == "" {
363		// It's a stateless server,
364		// but the MCP server requires a unique ID for registering, so we use a random one
365		sessionID = uuid.New().String()
366	}
367
368	session := newStreamableHttpSession(sessionID, s.sessionTools)
369	if err := s.server.RegisterSession(r.Context(), session); err != nil {
370		http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
371		return
372	}
373	defer s.server.UnregisterSession(r.Context(), sessionID)
374
375	// Set the client context before handling the message
376	w.Header().Set("Content-Type", "text/event-stream")
377	w.Header().Set("Cache-Control", "no-cache")
378	w.Header().Set("Connection", "keep-alive")
379	w.WriteHeader(http.StatusOK)
380
381	flusher, ok := w.(http.Flusher)
382	if !ok {
383		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
384		return
385	}
386	flusher.Flush()
387
388	// Start notification handler for this session
389	done := make(chan struct{})
390	defer close(done)
391	writeChan := make(chan any, 16)
392
393	go func() {
394		for {
395			select {
396			case nt := <-session.notificationChannel:
397				select {
398				case writeChan <- &nt:
399				case <-done:
400					return
401				}
402			case <-done:
403				return
404			}
405		}
406	}()
407
408	if s.listenHeartbeatInterval > 0 {
409		// heartbeat to keep the connection alive
410		go func() {
411			ticker := time.NewTicker(s.listenHeartbeatInterval)
412			defer ticker.Stop()
413			for {
414				select {
415				case <-ticker.C:
416					message := mcp.JSONRPCRequest{
417						JSONRPC: "2.0",
418						ID:      mcp.NewRequestId(s.nextRequestID(sessionID)),
419						Request: mcp.Request{
420							Method: "ping",
421						},
422					}
423					select {
424					case writeChan <- message:
425					case <-done:
426						return
427					}
428				case <-done:
429					return
430				}
431			}
432		}()
433	}
434
435	// Keep the connection open until the client disconnects
436	//
437	// There's will a Available() check when handler ends, and it maybe race with Flush(),
438	// so we use a separate channel to send the data, inteading of flushing directly in other goroutine.
439	for {
440		select {
441		case data := <-writeChan:
442			if data == nil {
443				continue
444			}
445			if err := writeSSEEvent(w, data); err != nil {
446				s.logger.Errorf("Failed to write SSE event: %v", err)
447				return
448			}
449			flusher.Flush()
450		case <-r.Context().Done():
451			return
452		}
453	}
454}
455
456func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
457	// delete request terminate the session
458	sessionID := r.Header.Get(headerKeySessionID)
459	notAllowed, err := s.sessionIdManager.Terminate(sessionID)
460	if err != nil {
461		http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError)
462		return
463	}
464	if notAllowed {
465		http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed)
466		return
467	}
468
469	// remove the session relateddata from the sessionToolsStore
470	s.sessionTools.delete(sessionID)
471
472	// remove current session's requstID information
473	s.sessionRequestIDs.Delete(sessionID)
474
475	w.WriteHeader(http.StatusOK)
476}
477
// writeSSEEvent encodes data as JSON and writes it to w as one SSE event of
// type "message". It does not flush w.
//
// It returns an error if data cannot be marshaled, in which case nothing is
// written, or if the write to w fails.
func writeSSEEvent(w io.Writer, data any) error {
	payload, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal data: %w", marshalErr)
	}
	if _, writeErr := fmt.Fprintf(w, "event: message\ndata: %s\n\n", payload); writeErr != nil {
		return fmt.Errorf("failed to write SSE event: %w", writeErr)
	}
	return nil
}
489
490// writeJSONRPCError writes a JSON-RPC error response with the given error details.
491func (s *StreamableHTTPServer) writeJSONRPCError(
492	w http.ResponseWriter,
493	id any,
494	code int,
495	message string,
496) {
497	response := createErrorResponse(id, code, message)
498	w.Header().Set("Content-Type", "application/json")
499	w.WriteHeader(http.StatusBadRequest)
500	err := json.NewEncoder(w).Encode(response)
501	if err != nil {
502		s.logger.Errorf("Failed to write JSONRPCError: %v", err)
503	}
504}
505
506// nextRequestID gets the next incrementing requestID for the current session
507func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
508	actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64))
509	counter := actual.(*atomic.Int64)
510	return counter.Add(1)
511}
512
513// --- session ---
514
515type sessionToolsStore struct {
516	mu    sync.RWMutex
517	tools map[string]map[string]ServerTool // sessionID -> toolName -> tool
518}
519
520func newSessionToolsStore() *sessionToolsStore {
521	return &sessionToolsStore{
522		tools: make(map[string]map[string]ServerTool),
523	}
524}
525
526func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool {
527	s.mu.RLock()
528	defer s.mu.RUnlock()
529	return s.tools[sessionID]
530}
531
532func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) {
533	s.mu.Lock()
534	defer s.mu.Unlock()
535	s.tools[sessionID] = tools
536}
537
538func (s *sessionToolsStore) delete(sessionID string) {
539	s.mu.Lock()
540	defer s.mu.Unlock()
541	delete(s.tools, sessionID)
542}
543
544// streamableHttpSession is a session for streamable-http transport
545// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
546// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
547type streamableHttpSession struct {
548	sessionID           string
549	notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
550	tools               *sessionToolsStore
551	upgradeToSSE        atomic.Bool
552}
553
554func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
555	return &streamableHttpSession{
556		sessionID:           sessionID,
557		notificationChannel: make(chan mcp.JSONRPCNotification, 100),
558		tools:               toolStore,
559	}
560}
561
562func (s *streamableHttpSession) SessionID() string {
563	return s.sessionID
564}
565
566func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
567	return s.notificationChannel
568}
569
570func (s *streamableHttpSession) Initialize() {
571	// do nothing
572	// the session is ephemeral, no real initialized action needed
573}
574
575func (s *streamableHttpSession) Initialized() bool {
576	// the session is ephemeral, no real initialized action needed
577	return true
578}
579
580var _ ClientSession = (*streamableHttpSession)(nil)
581
582func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool {
583	return s.tools.get(s.sessionID)
584}
585
586func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
587	s.tools.set(s.sessionID, tools)
588}
589
590var _ SessionWithTools = (*streamableHttpSession)(nil)
591
592func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
593	s.upgradeToSSE.Store(true)
594}
595
596var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
597
598// --- session id manager ---
599
// SessionIdManager generates, validates and terminates the session IDs used
// by StreamableHTTPServer. Install an implementation with WithSessionIdManager.
type SessionIdManager interface {
	// Generate returns the session ID for a new session. The server calls it
	// for every initialize request. An empty string means no session ID is
	// sent back to the client.
	Generate() string
	// Validate checks if a session ID is valid and not terminated.
	// Returns isTerminated=true if the ID is valid but belongs to a terminated session.
	// Returns err!=nil if the ID format is invalid or lookup failed.
	Validate(sessionID string) (isTerminated bool, err error)
	// Terminate marks a session ID as terminated.
	// Returns isNotAllowed=true if the server policy prevents client termination.
	// Returns err!=nil if the ID is invalid or termination failed.
	Terminate(sessionID string) (isNotAllowed bool, err error)
}
611
// StatelessSessionIdManager is the SessionIdManager for stateless servers: it
// performs no session management at all.
type StatelessSessionIdManager struct{}

// Generate returns the empty string, i.e. no session ID is assigned.
func (s *StatelessSessionIdManager) Generate() string {
	const noSessionID = ""
	return noSessionID
}

// Validate accepts every session ID. In stateless mode session IDs are
// ignored completely; they are neither validated nor rejected.
func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
	// Both results keep their zero values: not terminated, no error.
	return
}

// Terminate always succeeds and never refuses the termination.
func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
	// Both results keep their zero values: allowed, no error.
	return
}
625
626// InsecureStatefulSessionIdManager generate id with uuid
627// It won't validate the id indeed, so it could be fake.
628// For more secure session id, use a more complex generator, like a JWT.
629type InsecureStatefulSessionIdManager struct{}
630
631const idPrefix = "mcp-session-"
632
633func (s *InsecureStatefulSessionIdManager) Generate() string {
634	return idPrefix + uuid.New().String()
635}
636func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
637	// validate the session id is a valid uuid
638	if !strings.HasPrefix(sessionID, idPrefix) {
639		return false, fmt.Errorf("invalid session id: %s", sessionID)
640	}
641	if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
642		return false, fmt.Errorf("invalid session id: %s", sessionID)
643	}
644	return false, nil
645}
646func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
647	return false, nil
648}
649
650// NewTestStreamableHTTPServer creates a test server for testing purposes
651func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server {
652	sseServer := NewStreamableHTTPServer(server, opts...)
653	testServer := httptest.NewServer(sseServer)
654	return testServer
655}