1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"net/http/httptest"
 10	"strings"
 11	"sync"
 12	"sync/atomic"
 13	"time"
 14
 15	"github.com/google/uuid"
 16	"github.com/mark3labs/mcp-go/mcp"
 17	"github.com/mark3labs/mcp-go/util"
 18)
 19
// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer.
// NewStreamableHTTPServer applies the options in the order they are passed, after
// its defaults have been set, so a later option overrides an earlier one that sets
// the same field.
type StreamableHTTPOption func(*StreamableHTTPServer)
 22
 23// WithEndpointPath sets the endpoint path for the server.
 24// The default is "/mcp".
 25// It's only works for `Start` method. When used as a http.Handler, it has no effect.
 26func WithEndpointPath(endpointPath string) StreamableHTTPOption {
 27	return func(s *StreamableHTTPServer) {
 28		// Normalize the endpoint path to ensure it starts with a slash and doesn't end with one
 29		normalizedPath := "/" + strings.Trim(endpointPath, "/")
 30		s.endpointPath = normalizedPath
 31	}
 32}
 33
 34// WithStateLess sets the server to stateless mode.
 35// If true, the server will manage no session information. Every request will be treated
 36// as a new session. No session id returned to the client.
 37// The default is false.
 38//
 39// Notice: This is a convenience method. It's identical to set WithSessionIdManager option
 40// to StatelessSessionIdManager.
 41func WithStateLess(stateLess bool) StreamableHTTPOption {
 42	return func(s *StreamableHTTPServer) {
 43		s.sessionIdManager = &StatelessSessionIdManager{}
 44	}
 45}
 46
 47// WithSessionIdManager sets a custom session id generator for the server.
 48// By default, the server will use SimpleStatefulSessionIdGenerator, which generates
 49// session ids with uuid, and it's insecure.
 50// Notice: it will override the WithStateLess option.
 51func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
 52	return func(s *StreamableHTTPServer) {
 53		s.sessionIdManager = manager
 54	}
 55}
 56
 57// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the
 58// server will send a heartbeat to the client through the GET connection, to keep
 59// the connection alive from being closed by the network infrastructure (e.g.
 60// gateways). If the client does not establish a GET connection, it has no
 61// effect. The default is not to send heartbeats.
 62func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
 63	return func(s *StreamableHTTPServer) {
 64		s.listenHeartbeatInterval = interval
 65	}
 66}
 67
 68// WithHTTPContextFunc sets a function that will be called to customise the context
 69// to the server using the incoming request.
 70// This can be used to inject context values from headers, for example.
 71func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
 72	return func(s *StreamableHTTPServer) {
 73		s.contextFunc = fn
 74	}
 75}
 76
 77// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer.
 78// NOTE: When providing a custom HTTP server, you must handle routing yourself
 79// If routing is not set up, the server will start but won't handle any MCP requests.
 80func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
 81	return func(s *StreamableHTTPServer) {
 82		s.httpServer = srv
 83	}
 84}
 85
 86// WithLogger sets the logger for the server
 87func WithLogger(logger util.Logger) StreamableHTTPOption {
 88	return func(s *StreamableHTTPServer) {
 89		s.logger = logger
 90	}
 91}
 92
// StreamableHTTPServer implements a Streamable-HTTP based MCP server.
// It communicates with clients over HTTP: each POST is answered either with a
// direct JSON response or with an SSE stream, and GET requests open SSE streams on
// which the server can push messages to the client.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
//
// Usage:
//
//	server := NewStreamableHTTPServer(mcpServer)
//	server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default
//
// or the server itself can be used as an http.Handler, which is convenient for
// integrating with existing HTTP servers, or for advanced usage:
//
//	handler := NewStreamableHTTPServer(mcpServer)
//	http.Handle("/streamable-http", handler)
//	http.ListenAndServe(":8080", nil)
//
// Notice:
// Only the GET handlers (listening) register their session with the MCP server.
// The POST handlers (request/notification) use an ephemeral session that is never
// registered. So methods like `SendNotificationToSpecificClient` or
// `hooks.onRegisterSession` will not be triggered for POST messages.
//
// The current implementation does not support the following features from the specification:
//   - Batching of requests/notifications/responses in arrays.
//   - Stream Resumability
type StreamableHTTPServer struct {
	// server handles every decoded JSON-RPC message.
	server            *MCPServer
	// sessionTools stores per-session tool sets; every session created by this
	// transport (POST or GET) is given this same store.
	sessionTools      *sessionToolsStore
	// sessionRequestIDs backs nextRequestID, which numbers the heartbeat pings
	// sent on GET streams.
	sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64)

	// httpServer is the server run by Start and stopped by Shutdown. It is either
	// supplied via WithStreamableHTTPServer or created lazily by Start; nil when
	// the transport is only used as an http.Handler.
	httpServer *http.Server
	// mu guards httpServer (Start writes it, Shutdown reads it).
	mu         sync.RWMutex

	// endpointPath is the route Start mounts the transport on (default "/mcp").
	endpointPath            string
	// contextFunc, if non-nil, customises the context used for POST handling.
	contextFunc             HTTPContextFunc
	// sessionIdManager generates, validates and terminates session IDs.
	sessionIdManager        SessionIdManager
	// listenHeartbeatInterval enables ping heartbeats on GET streams when > 0.
	listenHeartbeatInterval time.Duration
	// logger receives errors raised while writing responses.
	logger                  util.Logger
}
131
132// NewStreamableHTTPServer creates a new streamable-http server instance
133func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
134	s := &StreamableHTTPServer{
135		server:           server,
136		sessionTools:     newSessionToolsStore(),
137		endpointPath:     "/mcp",
138		sessionIdManager: &InsecureStatefulSessionIdManager{},
139		logger:           util.DefaultLogger(),
140	}
141
142	// Apply all options
143	for _, opt := range opts {
144		opt(s)
145	}
146	return s
147}
148
149// ServeHTTP implements the http.Handler interface.
150func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
151	switch r.Method {
152	case http.MethodPost:
153		s.handlePost(w, r)
154	case http.MethodGet:
155		s.handleGet(w, r)
156	case http.MethodDelete:
157		s.handleDelete(w, r)
158	default:
159		http.NotFound(w, r)
160	}
161}
162
// Start begins serving the transport on addr at the configured endpoint path
// (see WithEndpointPath), like:
//
//	s.Start(":8080")
//
// Without WithStreamableHTTPServer, Start creates an http.Server. That server's
// mux routes only the endpoint path to s. A server supplied via
// WithStreamableHTTPServer is used with its own Handler:
//   - If its Addr is empty, it is set to addr.
//   - If its Addr is a different non-empty value, Start returns an error without
//     listening.
//
// Start blocks until the server stops and returns the error from
// http.Server.ListenAndServe.
func (s *StreamableHTTPServer) Start(addr string) error {
	s.mu.Lock()
	switch {
	case s.httpServer == nil:
		mux := http.NewServeMux()
		mux.Handle(s.endpointPath, s)
		s.httpServer = &http.Server{
			Addr:    addr,
			Handler: mux,
		}
	case s.httpServer.Addr == "":
		s.httpServer.Addr = addr
	case s.httpServer.Addr != addr:
		err := fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
		// The lock must be released on this early return too; keeping it
		// would deadlock any later Start or Shutdown.
		s.mu.Unlock()
		return err
	}
	srv := s.httpServer
	s.mu.Unlock()

	// Serve outside the lock: ListenAndServe blocks until the server stops.
	return srv.ListenAndServe()
}
188
// Shutdown gracefully stops the HTTP server run by Start (or supplied via
// WithStreamableHTTPServer) by calling http.Server.Shutdown with ctx and returning
// its error. It does not itself close or unregister MCP sessions.
// When there is no http.Server, which is the case when s is only used as an
// http.Handler, Shutdown does nothing and returns nil.
func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
	s.mu.RLock()
	srv := s.httpServer
	s.mu.RUnlock()

	if srv == nil {
		// Used purely as an http.Handler: the embedding server owns shutdown.
		return nil
	}
	return srv.Shutdown(ctx)
}
202
203// --- internal methods ---
204
// headerKeySessionID is the HTTP header that carries the MCP session ID in both
// directions.
const headerKeySessionID = "Mcp-Session-Id"
208
// handlePost processes a POST request carrying one JSON-RPC request or
// notification.
//
// Flow:
//  1. Reject bodies not declared as application/json.
//  2. Read the body and peek at the JSON-RPC "method" to detect `initialize`.
//  3. Generate a session ID (initialize) or validate the client-supplied one.
//  4. Create an ephemeral session that lives only as long as this request, and
//     hand the message to the MCP server.
//  5. While the message is being handled, a goroutine forwards notifications
//     sent to the session. The first one switches the response to an SSE
//     stream, and the final JSON-RPC response is then written as the last SSE
//     event. Otherwise the response is a plain JSON body.
//
// Every write to w happens while holding mu. Before the handler returns, and
// while mu is still held, done is closed, so the forwarding goroutine never
// touches w after the handler has finished with it.
func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
	// Check content type. Media types are case-insensitive and may carry
	// parameters such as "; charset=utf-8", which are accepted.
	mediaType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
	if !strings.EqualFold(strings.TrimSpace(mediaType), "application/json") {
		http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
		return
	}

	// Check the request body is valid json, meanwhile, get the request Method.
	rawData, err := io.ReadAll(r.Body)
	if err != nil {
		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err))
		return
	}
	var baseMessage struct {
		Method mcp.MCPMethod `json:"method"`
	}
	if err := json.Unmarshal(rawData, &baseMessage); err != nil {
		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
		return
	}
	isInitializeRequest := baseMessage.Method == mcp.MethodInitialize

	// An initialize request gets a freshly generated session ID. Every other
	// request must carry one that the session ID manager accepts.
	var sessionID string
	if isInitializeRequest {
		sessionID = s.sessionIdManager.Generate()
	} else {
		sessionID = r.Header.Get(headerKeySessionID)
		isTerminated, err := s.sessionIdManager.Validate(sessionID)
		if err != nil {
			http.Error(w, "Invalid session ID", http.StatusBadRequest)
			return
		}
		if isTerminated {
			http.Error(w, "Session terminated", http.StatusNotFound)
			return
		}
	}

	// The session is ephemeral. Its life is the same as the request; it only
	// exists so the MCP server has a client to interact with.
	session := newStreamableHttpSession(sessionID, s.sessionTools)

	// Set the client context before handling the message.
	ctx := s.server.WithContext(r.Context(), session)
	if s.contextFunc != nil {
		ctx = s.contextFunc(ctx, r)
	}

	// mu serializes all writes to w between this handler and the notification
	// goroutine. upgradedHeader records whether SSE headers were already sent.
	var mu sync.Mutex
	upgradedHeader := false
	done := make(chan struct{})

	// startSSE commits the response as an SSE stream the first time it is
	// called; later calls are no-ops. The caller must hold mu.
	startSSE := func() {
		if upgradedHeader {
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("Cache-Control", "no-cache")
		w.WriteHeader(http.StatusAccepted)
		upgradedHeader = true
	}

	// Forward notifications emitted while the message is being handled.
	go func() {
		for {
			select {
			case nt := <-session.notificationChannel:
				func() {
					mu.Lock()
					defer mu.Unlock()
					// The handler has already finished writing; w is off limits.
					select {
					case <-done:
						return
					default:
					}
					defer func() {
						if flusher, ok := w.(http.Flusher); ok {
							flusher.Flush()
						}
					}()

					startSSE()
					if err := writeSSEEvent(w, nt); err != nil {
						s.logger.Errorf("Failed to write SSE event: %v", err)
					}
				}()
			case <-done:
				return
			case <-ctx.Done():
				return
			}
		}
	}()

	// Process message through MCPServer.
	response := s.server.HandleMessage(ctx, rawData)

	// From here on the handler owns w. Deferred calls run LIFO, so done is
	// closed before mu is released and the goroutine observes it immediately.
	mu.Lock()
	defer mu.Unlock()
	defer close(done)

	if response == nil {
		// Notification: acknowledge with 202 Accepted and no body, unless a
		// forwarded notification already committed the SSE headers.
		if !upgradedHeader {
			w.WriteHeader(http.StatusAccepted)
		}
		return
	}

	if ctx.Err() != nil {
		return
	}

	// If client-server communication was already upgraded to an SSE stream
	// (either flagged by the MCP server or because headers were sent), the
	// response must go out as the final SSE event.
	if session.upgradeToSSE.Load() || upgradedHeader {
		startSSE()
		if err := writeSSEEvent(w, response); err != nil {
			s.logger.Errorf("Failed to write final SSE response event: %v", err)
		}
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if isInitializeRequest && sessionID != "" {
		// Send the session ID back to the client.
		w.Header().Set(headerKeySessionID, sessionID)
	}
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		s.logger.Errorf("Failed to write response: %v", err)
	}
}
352
353func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
354	// get request is for listening to notifications
355	// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
356
357	sessionID := r.Header.Get(headerKeySessionID)
358	// the specification didn't say we should validate the session id
359
360	if sessionID == "" {
361		// It's a stateless server,
362		// but the MCP server requires a unique ID for registering, so we use a random one
363		sessionID = uuid.New().String()
364	}
365
366	session := newStreamableHttpSession(sessionID, s.sessionTools)
367	if err := s.server.RegisterSession(r.Context(), session); err != nil {
368		http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
369		return
370	}
371	defer s.server.UnregisterSession(r.Context(), sessionID)
372
373	// Set the client context before handling the message
374	w.Header().Set("Content-Type", "text/event-stream")
375	w.Header().Set("Cache-Control", "no-cache")
376	w.Header().Set("Connection", "keep-alive")
377	w.WriteHeader(http.StatusAccepted)
378
379	flusher, ok := w.(http.Flusher)
380	if !ok {
381		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
382		return
383	}
384	flusher.Flush()
385
386	// Start notification handler for this session
387	done := make(chan struct{})
388	defer close(done)
389	writeChan := make(chan any, 16)
390
391	go func() {
392		for {
393			select {
394			case nt := <-session.notificationChannel:
395				select {
396				case writeChan <- &nt:
397				case <-done:
398					return
399				}
400			case <-done:
401				return
402			}
403		}
404	}()
405
406	if s.listenHeartbeatInterval > 0 {
407		// heartbeat to keep the connection alive
408		go func() {
409			ticker := time.NewTicker(s.listenHeartbeatInterval)
410			defer ticker.Stop()
411			for {
412				select {
413				case <-ticker.C:
414					message := mcp.JSONRPCRequest{
415						JSONRPC: "2.0",
416						ID:      mcp.NewRequestId(s.nextRequestID(sessionID)),
417						Request: mcp.Request{
418							Method: "ping",
419						},
420					}
421					select {
422					case writeChan <- message:
423					case <-done:
424						return
425					}
426				case <-done:
427					return
428				}
429			}
430		}()
431	}
432
433	// Keep the connection open until the client disconnects
434	//
435	// There's will a Available() check when handler ends, and it maybe race with Flush(),
436	// so we use a separate channel to send the data, inteading of flushing directly in other goroutine.
437	for {
438		select {
439		case data := <-writeChan:
440			if data == nil {
441				continue
442			}
443			if err := writeSSEEvent(w, data); err != nil {
444				s.logger.Errorf("Failed to write SSE event: %v", err)
445				return
446			}
447			flusher.Flush()
448		case <-r.Context().Done():
449			return
450		}
451	}
452}
453
// handleDelete serves a DELETE request, which asks to terminate the session named
// by the Mcp-Session-Id header.
//   - If the session ID manager fails, the client gets 500.
//   - If the manager refuses, the client gets 405.
//   - Otherwise the session's tools and heartbeat request-ID counter are dropped
//     and the client gets 200 OK.
func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
	id := r.Header.Get(headerKeySessionID)

	refused, termErr := s.sessionIdManager.Terminate(id)
	switch {
	case termErr != nil:
		http.Error(w, fmt.Sprintf("Session termination failed: %v", termErr), http.StatusInternalServerError)
		return
	case refused:
		http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Forget the session's per-session tools and its request-ID counter.
	s.sessionTools.delete(id)
	s.sessionRequestIDs.Delete(id)

	w.WriteHeader(http.StatusOK)
}
475
// writeSSEEvent JSON-encodes data and writes it to w as one SSE event of type
// "message":
//
//	event: message
//	data: <json>
//
// The JSON goes on a single data line, which works because json.Marshal output
// contains no raw newlines. Marshal and write failures are wrapped and returned.
func writeSSEEvent(w io.Writer, data any) error {
	payload, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal data: %w", marshalErr)
	}
	if _, writeErr := fmt.Fprintf(w, "event: message\ndata: %s\n\n", payload); writeErr != nil {
		return fmt.Errorf("failed to write SSE event: %w", writeErr)
	}
	return nil
}
487
// writeJSONRPCError answers with HTTP 400 Bad Request. The JSON body is the
// error response built by createErrorResponse from the JSON-RPC id, error code
// and message. An encoding failure is only logged, because the status line has
// already been sent.
func (s *StreamableHTTPServer) writeJSONRPCError(
	w http.ResponseWriter,
	id any,
	code int,
	message string,
) {
	payload := createErrorResponse(id, code, message)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)

	if encErr := json.NewEncoder(w).Encode(payload); encErr != nil {
		s.logger.Errorf("Failed to write JSONRPCError: %v", encErr)
	}
}
503
// nextRequestID returns the next request ID for sessionID. IDs are counted per
// session, starting at 1. The session's counter is created on first use and is
// safe to call concurrently.
func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
	stored, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64))
	return stored.(*atomic.Int64).Add(1)
}
510
511// --- session ---
512
513type sessionToolsStore struct {
514	mu    sync.RWMutex
515	tools map[string]map[string]ServerTool // sessionID -> toolName -> tool
516}
517
518func newSessionToolsStore() *sessionToolsStore {
519	return &sessionToolsStore{
520		tools: make(map[string]map[string]ServerTool),
521	}
522}
523
524func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool {
525	s.mu.RLock()
526	defer s.mu.RUnlock()
527	return s.tools[sessionID]
528}
529
530func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) {
531	s.mu.Lock()
532	defer s.mu.Unlock()
533	s.tools[sessionID] = tools
534}
535
536func (s *sessionToolsStore) delete(sessionID string) {
537	s.mu.Lock()
538	defer s.mu.Unlock()
539	delete(s.tools, sessionID)
540}
541
542// streamableHttpSession is a session for streamable-http transport
543// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
544// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
545type streamableHttpSession struct {
546	sessionID           string
547	notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
548	tools               *sessionToolsStore
549	upgradeToSSE        atomic.Bool
550}
551
552func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
553	return &streamableHttpSession{
554		sessionID:           sessionID,
555		notificationChannel: make(chan mcp.JSONRPCNotification, 100),
556		tools:               toolStore,
557	}
558}
559
560func (s *streamableHttpSession) SessionID() string {
561	return s.sessionID
562}
563
564func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
565	return s.notificationChannel
566}
567
568func (s *streamableHttpSession) Initialize() {
569	// do nothing
570	// the session is ephemeral, no real initialized action needed
571}
572
573func (s *streamableHttpSession) Initialized() bool {
574	// the session is ephemeral, no real initialized action needed
575	return true
576}
577
578var _ ClientSession = (*streamableHttpSession)(nil)
579
580func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool {
581	return s.tools.get(s.sessionID)
582}
583
584func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
585	s.tools.set(s.sessionID, tools)
586}
587
588var _ SessionWithTools = (*streamableHttpSession)(nil)
589
590func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
591	s.upgradeToSSE.Store(true)
592}
593
594var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
595
596// --- session id manager ---
597
// SessionIdManager issues, validates and terminates the session IDs exchanged in
// the Mcp-Session-Id header.
type SessionIdManager interface {
	// Generate returns the ID for a new session, created by an initialize
	// request. An empty string means no session ID is sent to the client.
	Generate() string

	// Validate checks that sessionID is well-formed and still live.
	// It reports isTerminated=true for a recognized ID whose session has been
	// terminated. It returns a non-nil err when the ID is malformed or the
	// lookup fails.
	Validate(sessionID string) (isTerminated bool, err error)

	// Terminate ends the session identified by sessionID.
	// It reports isNotAllowed=true when server policy forbids clients from
	// terminating sessions. It returns a non-nil err when the ID is invalid or
	// termination fails.
	Terminate(sessionID string) (isNotAllowed bool, err error)
}
609
// StatelessSessionIdManager is the SessionIdManager for stateless mode.
// It issues no session IDs, accepts whatever ID a client sends (including none),
// and never refuses termination.
type StatelessSessionIdManager struct{}

// Generate returns "", so no session ID is handed to the client.
func (*StatelessSessionIdManager) Generate() string { return "" }

// Validate accepts every session ID; in stateless mode IDs are ignored entirely.
func (*StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
	return false, nil
}

// Terminate always succeeds; there is no session state to discard.
func (*StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
	return false, nil
}
623
624// InsecureStatefulSessionIdManager generate id with uuid
625// It won't validate the id indeed, so it could be fake.
626// For more secure session id, use a more complex generator, like a JWT.
627type InsecureStatefulSessionIdManager struct{}
628
629const idPrefix = "mcp-session-"
630
631func (s *InsecureStatefulSessionIdManager) Generate() string {
632	return idPrefix + uuid.New().String()
633}
634func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
635	// validate the session id is a valid uuid
636	if !strings.HasPrefix(sessionID, idPrefix) {
637		return false, fmt.Errorf("invalid session id: %s", sessionID)
638	}
639	if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
640		return false, fmt.Errorf("invalid session id: %s", sessionID)
641	}
642	return false, nil
643}
644func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
645	return false, nil
646}
647
// NewTestStreamableHTTPServer starts an httptest.Server whose handler is a
// StreamableHTTPServer built from server and opts. The transport is the server's
// handler directly, so it answers on every path. The caller should Close the
// returned server when done.
func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server {
	return httptest.NewServer(NewStreamableHTTPServer(server, opts...))
}