1package server
2
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/util"
)
19
// StreamableHTTPOption is a functional option that configures a
// StreamableHTTPServer. NewStreamableHTTPServer applies the options in the
// order they are passed, after the defaults have been set.
type StreamableHTTPOption func(*StreamableHTTPServer)
22
23// WithEndpointPath sets the endpoint path for the server.
24// The default is "/mcp".
25// It's only works for `Start` method. When used as a http.Handler, it has no effect.
26func WithEndpointPath(endpointPath string) StreamableHTTPOption {
27 return func(s *StreamableHTTPServer) {
28 // Normalize the endpoint path to ensure it starts with a slash and doesn't end with one
29 normalizedPath := "/" + strings.Trim(endpointPath, "/")
30 s.endpointPath = normalizedPath
31 }
32}
33
34// WithStateLess sets the server to stateless mode.
35// If true, the server will manage no session information. Every request will be treated
36// as a new session. No session id returned to the client.
37// The default is false.
38//
39// Notice: This is a convenience method. It's identical to set WithSessionIdManager option
40// to StatelessSessionIdManager.
41func WithStateLess(stateLess bool) StreamableHTTPOption {
42 return func(s *StreamableHTTPServer) {
43 if stateLess {
44 s.sessionIdManager = &StatelessSessionIdManager{}
45 }
46 }
47}
48
49// WithSessionIdManager sets a custom session id generator for the server.
50// By default, the server will use SimpleStatefulSessionIdGenerator, which generates
51// session ids with uuid, and it's insecure.
52// Notice: it will override the WithStateLess option.
53func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
54 return func(s *StreamableHTTPServer) {
55 s.sessionIdManager = manager
56 }
57}
58
59// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the
60// server will send a heartbeat to the client through the GET connection, to keep
61// the connection alive from being closed by the network infrastructure (e.g.
62// gateways). If the client does not establish a GET connection, it has no
63// effect. The default is not to send heartbeats.
64func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
65 return func(s *StreamableHTTPServer) {
66 s.listenHeartbeatInterval = interval
67 }
68}
69
70// WithHTTPContextFunc sets a function that will be called to customise the context
71// to the server using the incoming request.
72// This can be used to inject context values from headers, for example.
73func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
74 return func(s *StreamableHTTPServer) {
75 s.contextFunc = fn
76 }
77}
78
79// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer.
80// NOTE: When providing a custom HTTP server, you must handle routing yourself
81// If routing is not set up, the server will start but won't handle any MCP requests.
82func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
83 return func(s *StreamableHTTPServer) {
84 s.httpServer = srv
85 }
86}
87
88// WithLogger sets the logger for the server
89func WithLogger(logger util.Logger) StreamableHTTPOption {
90 return func(s *StreamableHTTPServer) {
91 s.logger = logger
92 }
93}
94
// StreamableHTTPServer implements a Streamable-HTTP based MCP server.
// It communicates with clients over HTTP, supporting both direct HTTP
// responses and SSE streams.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
//
// Usage:
//
//	server := NewStreamableHTTPServer(mcpServer)
//	server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default
//
// or the server itself can be used as an http.Handler, which is convenient to
// integrate with existing http servers, or for advanced usage:
//
//	handler := NewStreamableHTTPServer(mcpServer)
//	http.Handle("/streamable-http", handler)
//	http.ListenAndServe(":8080", nil)
//
// Notice:
// Only the GET handler (listening) registers a session with the MCP server.
// The POST handler (request/notification) uses an ephemeral session and does
// not trigger session registration, so methods like
// `SendNotificationToSpecificClient` or `hooks.onRegisterSession` will not be
// triggered for POST messages.
//
// The current implementation does not support the following features from the specification:
//   - Batching of requests/notifications/responses in arrays.
//   - Stream Resumability
type StreamableHTTPServer struct {
	// server is the MCP server that handles the decoded messages.
	server *MCPServer
	// sessionTools holds per-session tool sets, shared by all sessions
	// created by this transport.
	sessionTools *sessionToolsStore
	// sessionRequestIDs maps sessionId --> last requestID (*atomic.Int64).
	// It supplies the IDs of the heartbeat ping requests.
	sessionRequestIDs sync.Map

	// httpServer is created by Start or supplied via WithStreamableHTTPServer;
	// it stays nil when the server is only used as an http.Handler.
	httpServer *http.Server
	// mu is held by Start and Shutdown while they access httpServer.
	mu sync.RWMutex

	// endpointPath is the path Start mounts the handler on ("/mcp" by default).
	endpointPath string
	// contextFunc, when non-nil, customizes the per-request context in POST handling.
	contextFunc HTTPContextFunc
	// sessionIdManager generates, validates and terminates session IDs.
	sessionIdManager SessionIdManager
	// listenHeartbeatInterval enables heartbeat pings on GET streams when positive.
	listenHeartbeatInterval time.Duration
	// logger receives error reports from the handlers.
	logger util.Logger
}
133
134// NewStreamableHTTPServer creates a new streamable-http server instance
135func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
136 s := &StreamableHTTPServer{
137 server: server,
138 sessionTools: newSessionToolsStore(),
139 endpointPath: "/mcp",
140 sessionIdManager: &InsecureStatefulSessionIdManager{},
141 logger: util.DefaultLogger(),
142 }
143
144 // Apply all options
145 for _, opt := range opts {
146 opt(s)
147 }
148 return s
149}
150
151// ServeHTTP implements the http.Handler interface.
152func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
153 switch r.Method {
154 case http.MethodPost:
155 s.handlePost(w, r)
156 case http.MethodGet:
157 s.handleGet(w, r)
158 case http.MethodDelete:
159 s.handleDelete(w, r)
160 default:
161 http.NotFound(w, r)
162 }
163}
164
165// Start begins serving the http server on the specified address and path
166// (endpointPath). like:
167//
168// s.Start(":8080")
169func (s *StreamableHTTPServer) Start(addr string) error {
170 s.mu.Lock()
171 if s.httpServer == nil {
172 mux := http.NewServeMux()
173 mux.Handle(s.endpointPath, s)
174 s.httpServer = &http.Server{
175 Addr: addr,
176 Handler: mux,
177 }
178 } else {
179 if s.httpServer.Addr == "" {
180 s.httpServer.Addr = addr
181 } else if s.httpServer.Addr != addr {
182 return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
183 }
184 }
185 srv := s.httpServer
186 s.mu.Unlock()
187
188 return srv.ListenAndServe()
189}
190
191// Shutdown gracefully stops the server, closing all active sessions
192// and shutting down the HTTP server.
193func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
194
195 // shutdown the server if needed (may use as a http.Handler)
196 s.mu.RLock()
197 srv := s.httpServer
198 s.mu.RUnlock()
199 if srv != nil {
200 return srv.Shutdown(ctx)
201 }
202 return nil
203}
204
205// --- internal methods ---
206
// headerKeySessionID is the HTTP header that carries the MCP session ID in
// both directions.
const headerKeySessionID = "Mcp-Session-Id"
210
211func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
212 // post request carry request/notification message
213
214 // Check content type
215 contentType := r.Header.Get("Content-Type")
216 if contentType != "application/json" {
217 http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
218 return
219 }
220
221 // Check the request body is valid json, meanwhile, get the request Method
222 rawData, err := io.ReadAll(r.Body)
223 if err != nil {
224 s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err))
225 return
226 }
227 var baseMessage struct {
228 Method mcp.MCPMethod `json:"method"`
229 }
230 if err := json.Unmarshal(rawData, &baseMessage); err != nil {
231 s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
232 return
233 }
234 isInitializeRequest := baseMessage.Method == mcp.MethodInitialize
235
236 // Prepare the session for the mcp server
237 // The session is ephemeral. Its life is the same as the request. It's only created
238 // for interaction with the mcp server.
239 var sessionID string
240 if isInitializeRequest {
241 // generate a new one for initialize request
242 sessionID = s.sessionIdManager.Generate()
243 } else {
244 // Get session ID from header.
245 // Stateful servers need the client to carry the session ID.
246 sessionID = r.Header.Get(headerKeySessionID)
247 isTerminated, err := s.sessionIdManager.Validate(sessionID)
248 if err != nil {
249 http.Error(w, "Invalid session ID", http.StatusBadRequest)
250 return
251 }
252 if isTerminated {
253 http.Error(w, "Session terminated", http.StatusNotFound)
254 return
255 }
256 }
257
258 session := newStreamableHttpSession(sessionID, s.sessionTools)
259
260 // Set the client context before handling the message
261 ctx := s.server.WithContext(r.Context(), session)
262 if s.contextFunc != nil {
263 ctx = s.contextFunc(ctx, r)
264 }
265
266 // handle potential notifications
267 mu := sync.Mutex{}
268 upgradedHeader := false
269 done := make(chan struct{})
270
271 go func() {
272 for {
273 select {
274 case nt := <-session.notificationChannel:
275 func() {
276 mu.Lock()
277 defer mu.Unlock()
278 // if the done chan is closed, as the request is terminated, just return
279 select {
280 case <-done:
281 return
282 default:
283 }
284 defer func() {
285 flusher, ok := w.(http.Flusher)
286 if ok {
287 flusher.Flush()
288 }
289 }()
290
291 // if there's notifications, upgradedHeader to SSE response
292 if !upgradedHeader {
293 w.Header().Set("Content-Type", "text/event-stream")
294 w.Header().Set("Connection", "keep-alive")
295 w.Header().Set("Cache-Control", "no-cache")
296 w.WriteHeader(http.StatusAccepted)
297 upgradedHeader = true
298 }
299 err := writeSSEEvent(w, nt)
300 if err != nil {
301 s.logger.Errorf("Failed to write SSE event: %v", err)
302 return
303 }
304 }()
305 case <-done:
306 return
307 case <-ctx.Done():
308 return
309 }
310 }
311 }()
312
313 // Process message through MCPServer
314 response := s.server.HandleMessage(ctx, rawData)
315 if response == nil {
316 // For notifications, just send 202 Accepted with no body
317 w.WriteHeader(http.StatusAccepted)
318 return
319 }
320
321 // Write response
322 mu.Lock()
323 defer mu.Unlock()
324 // close the done chan before unlock
325 defer close(done)
326 if ctx.Err() != nil {
327 return
328 }
329 // If client-server communication already upgraded to SSE stream
330 if session.upgradeToSSE.Load() {
331 if !upgradedHeader {
332 w.Header().Set("Content-Type", "text/event-stream")
333 w.Header().Set("Connection", "keep-alive")
334 w.Header().Set("Cache-Control", "no-cache")
335 w.WriteHeader(http.StatusAccepted)
336 upgradedHeader = true
337 }
338 if err := writeSSEEvent(w, response); err != nil {
339 s.logger.Errorf("Failed to write final SSE response event: %v", err)
340 }
341 } else {
342 w.Header().Set("Content-Type", "application/json")
343 if isInitializeRequest && sessionID != "" {
344 // send the session ID back to the client
345 w.Header().Set(headerKeySessionID, sessionID)
346 }
347 w.WriteHeader(http.StatusOK)
348 err := json.NewEncoder(w).Encode(response)
349 if err != nil {
350 s.logger.Errorf("Failed to write response: %v", err)
351 }
352 }
353}
354
355func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
356 // get request is for listening to notifications
357 // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
358
359 sessionID := r.Header.Get(headerKeySessionID)
360 // the specification didn't say we should validate the session id
361
362 if sessionID == "" {
363 // It's a stateless server,
364 // but the MCP server requires a unique ID for registering, so we use a random one
365 sessionID = uuid.New().String()
366 }
367
368 session := newStreamableHttpSession(sessionID, s.sessionTools)
369 if err := s.server.RegisterSession(r.Context(), session); err != nil {
370 http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
371 return
372 }
373 defer s.server.UnregisterSession(r.Context(), sessionID)
374
375 // Set the client context before handling the message
376 w.Header().Set("Content-Type", "text/event-stream")
377 w.Header().Set("Cache-Control", "no-cache")
378 w.Header().Set("Connection", "keep-alive")
379 w.WriteHeader(http.StatusOK)
380
381 flusher, ok := w.(http.Flusher)
382 if !ok {
383 http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
384 return
385 }
386 flusher.Flush()
387
388 // Start notification handler for this session
389 done := make(chan struct{})
390 defer close(done)
391 writeChan := make(chan any, 16)
392
393 go func() {
394 for {
395 select {
396 case nt := <-session.notificationChannel:
397 select {
398 case writeChan <- &nt:
399 case <-done:
400 return
401 }
402 case <-done:
403 return
404 }
405 }
406 }()
407
408 if s.listenHeartbeatInterval > 0 {
409 // heartbeat to keep the connection alive
410 go func() {
411 ticker := time.NewTicker(s.listenHeartbeatInterval)
412 defer ticker.Stop()
413 for {
414 select {
415 case <-ticker.C:
416 message := mcp.JSONRPCRequest{
417 JSONRPC: "2.0",
418 ID: mcp.NewRequestId(s.nextRequestID(sessionID)),
419 Request: mcp.Request{
420 Method: "ping",
421 },
422 }
423 select {
424 case writeChan <- message:
425 case <-done:
426 return
427 }
428 case <-done:
429 return
430 }
431 }
432 }()
433 }
434
435 // Keep the connection open until the client disconnects
436 //
437 // There's will a Available() check when handler ends, and it maybe race with Flush(),
438 // so we use a separate channel to send the data, inteading of flushing directly in other goroutine.
439 for {
440 select {
441 case data := <-writeChan:
442 if data == nil {
443 continue
444 }
445 if err := writeSSEEvent(w, data); err != nil {
446 s.logger.Errorf("Failed to write SSE event: %v", err)
447 return
448 }
449 flusher.Flush()
450 case <-r.Context().Done():
451 return
452 }
453 }
454}
455
456func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
457 // delete request terminate the session
458 sessionID := r.Header.Get(headerKeySessionID)
459 notAllowed, err := s.sessionIdManager.Terminate(sessionID)
460 if err != nil {
461 http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError)
462 return
463 }
464 if notAllowed {
465 http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed)
466 return
467 }
468
469 // remove the session relateddata from the sessionToolsStore
470 s.sessionTools.delete(sessionID)
471
472 // remove current session's requstID information
473 s.sessionRequestIDs.Delete(sessionID)
474
475 w.WriteHeader(http.StatusOK)
476}
477
// writeSSEEvent encodes data as JSON and writes it to w as a single
// server-sent event of type "message":
//
//	event: message
//	data: <json>
//	<blank line>
//
// It returns a wrapped error if the value cannot be marshaled or the write
// fails.
func writeSSEEvent(w io.Writer, data any) error {
	payload, err := json.Marshal(data)
	if err != nil {
		return fmt.Errorf("failed to marshal data: %w", err)
	}

	// Assemble the whole frame first so it reaches w in one Write call.
	const head, tail = "event: message\ndata: ", "\n\n"
	frame := make([]byte, 0, len(head)+len(payload)+len(tail))
	frame = append(frame, head...)
	frame = append(frame, payload...)
	frame = append(frame, tail...)

	if _, err := w.Write(frame); err != nil {
		return fmt.Errorf("failed to write SSE event: %w", err)
	}
	return nil
}
489
490// writeJSONRPCError writes a JSON-RPC error response with the given error details.
491func (s *StreamableHTTPServer) writeJSONRPCError(
492 w http.ResponseWriter,
493 id any,
494 code int,
495 message string,
496) {
497 response := createErrorResponse(id, code, message)
498 w.Header().Set("Content-Type", "application/json")
499 w.WriteHeader(http.StatusBadRequest)
500 err := json.NewEncoder(w).Encode(response)
501 if err != nil {
502 s.logger.Errorf("Failed to write JSONRPCError: %v", err)
503 }
504}
505
506// nextRequestID gets the next incrementing requestID for the current session
507func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
508 actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64))
509 counter := actual.(*atomic.Int64)
510 return counter.Add(1)
511}
512
513// --- session ---
514
515type sessionToolsStore struct {
516 mu sync.RWMutex
517 tools map[string]map[string]ServerTool // sessionID -> toolName -> tool
518}
519
520func newSessionToolsStore() *sessionToolsStore {
521 return &sessionToolsStore{
522 tools: make(map[string]map[string]ServerTool),
523 }
524}
525
526func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool {
527 s.mu.RLock()
528 defer s.mu.RUnlock()
529 return s.tools[sessionID]
530}
531
532func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) {
533 s.mu.Lock()
534 defer s.mu.Unlock()
535 s.tools[sessionID] = tools
536}
537
538func (s *sessionToolsStore) delete(sessionID string) {
539 s.mu.Lock()
540 defer s.mu.Unlock()
541 delete(s.tools, sessionID)
542}
543
544// streamableHttpSession is a session for streamable-http transport
545// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
546// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
547type streamableHttpSession struct {
548 sessionID string
549 notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
550 tools *sessionToolsStore
551 upgradeToSSE atomic.Bool
552}
553
554func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
555 return &streamableHttpSession{
556 sessionID: sessionID,
557 notificationChannel: make(chan mcp.JSONRPCNotification, 100),
558 tools: toolStore,
559 }
560}
561
562func (s *streamableHttpSession) SessionID() string {
563 return s.sessionID
564}
565
566func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
567 return s.notificationChannel
568}
569
570func (s *streamableHttpSession) Initialize() {
571 // do nothing
572 // the session is ephemeral, no real initialized action needed
573}
574
575func (s *streamableHttpSession) Initialized() bool {
576 // the session is ephemeral, no real initialized action needed
577 return true
578}
579
580var _ ClientSession = (*streamableHttpSession)(nil)
581
582func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool {
583 return s.tools.get(s.sessionID)
584}
585
586func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
587 s.tools.set(s.sessionID, tools)
588}
589
590var _ SessionWithTools = (*streamableHttpSession)(nil)
591
592func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
593 s.upgradeToSSE.Store(true)
594}
595
596var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
597
598// --- session id manager ---
599
// SessionIdManager controls the life cycle of session IDs for the
// streamable-http transport: it issues them, checks the ones clients send
// back, and handles client-requested termination.
type SessionIdManager interface {
	// Generate returns the session ID for a new session. The POST handler
	// calls it for every initialize request; an empty string is not sent
	// back to the client.
	Generate() string
	// Validate checks if a session ID is valid and not terminated.
	// Returns isTerminated=true if the ID is valid but belongs to a terminated session.
	// Returns err!=nil if the ID format is invalid or lookup failed.
	Validate(sessionID string) (isTerminated bool, err error)
	// Terminate marks a session ID as terminated.
	// Returns isNotAllowed=true if the server policy prevents client termination.
	// Returns err!=nil if the ID is invalid or termination failed.
	Terminate(sessionID string) (isNotAllowed bool, err error)
}
611
// StatelessSessionIdManager is the SessionIdManager for stateless servers:
// it performs no session management at all.
type StatelessSessionIdManager struct{}

// Generate returns the empty string, i.e. no session ID is issued.
func (s *StatelessSessionIdManager) Generate() string {
	return ""
}

// Validate accepts every session ID. In stateless mode session IDs are
// ignored completely: they are neither validated nor rejected.
func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
	return false, nil
}

// Terminate always succeeds and never forbids termination; there is no state
// to clean up.
func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
	return false, nil
}
625
626// InsecureStatefulSessionIdManager generate id with uuid
627// It won't validate the id indeed, so it could be fake.
628// For more secure session id, use a more complex generator, like a JWT.
629type InsecureStatefulSessionIdManager struct{}
630
631const idPrefix = "mcp-session-"
632
633func (s *InsecureStatefulSessionIdManager) Generate() string {
634 return idPrefix + uuid.New().String()
635}
636func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
637 // validate the session id is a valid uuid
638 if !strings.HasPrefix(sessionID, idPrefix) {
639 return false, fmt.Errorf("invalid session id: %s", sessionID)
640 }
641 if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
642 return false, fmt.Errorf("invalid session id: %s", sessionID)
643 }
644 return false, nil
645}
646func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
647 return false, nil
648}
649
650// NewTestStreamableHTTPServer creates a test server for testing purposes
651func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server {
652 sseServer := NewStreamableHTTPServer(server, opts...)
653 testServer := httptest.NewServer(sseServer)
654 return testServer
655}