1package server
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 "github.com/google/uuid"
16 "github.com/mark3labs/mcp-go/mcp"
17 "github.com/mark3labs/mcp-go/util"
18)
19
20// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer
21type StreamableHTTPOption func(*StreamableHTTPServer)
22
23// WithEndpointPath sets the endpoint path for the server.
24// The default is "/mcp".
25// It's only works for `Start` method. When used as a http.Handler, it has no effect.
26func WithEndpointPath(endpointPath string) StreamableHTTPOption {
27 return func(s *StreamableHTTPServer) {
28 // Normalize the endpoint path to ensure it starts with a slash and doesn't end with one
29 normalizedPath := "/" + strings.Trim(endpointPath, "/")
30 s.endpointPath = normalizedPath
31 }
32}
33
34// WithStateLess sets the server to stateless mode.
35// If true, the server will manage no session information. Every request will be treated
36// as a new session. No session id returned to the client.
37// The default is false.
38//
39// Notice: This is a convenience method. It's identical to set WithSessionIdManager option
40// to StatelessSessionIdManager.
41func WithStateLess(stateLess bool) StreamableHTTPOption {
42 return func(s *StreamableHTTPServer) {
43 s.sessionIdManager = &StatelessSessionIdManager{}
44 }
45}
46
47// WithSessionIdManager sets a custom session id generator for the server.
48// By default, the server will use SimpleStatefulSessionIdGenerator, which generates
49// session ids with uuid, and it's insecure.
50// Notice: it will override the WithStateLess option.
51func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
52 return func(s *StreamableHTTPServer) {
53 s.sessionIdManager = manager
54 }
55}
56
57// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the
58// server will send a heartbeat to the client through the GET connection, to keep
59// the connection alive from being closed by the network infrastructure (e.g.
60// gateways). If the client does not establish a GET connection, it has no
61// effect. The default is not to send heartbeats.
62func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
63 return func(s *StreamableHTTPServer) {
64 s.listenHeartbeatInterval = interval
65 }
66}
67
68// WithHTTPContextFunc sets a function that will be called to customise the context
69// to the server using the incoming request.
70// This can be used to inject context values from headers, for example.
71func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
72 return func(s *StreamableHTTPServer) {
73 s.contextFunc = fn
74 }
75}
76
77// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer.
78// NOTE: When providing a custom HTTP server, you must handle routing yourself
79// If routing is not set up, the server will start but won't handle any MCP requests.
80func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
81 return func(s *StreamableHTTPServer) {
82 s.httpServer = srv
83 }
84}
85
86// WithLogger sets the logger for the server
87func WithLogger(logger util.Logger) StreamableHTTPOption {
88 return func(s *StreamableHTTPServer) {
89 s.logger = logger
90 }
91}
92
93// StreamableHTTPServer implements a Streamable-http based MCP server.
94// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
95// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
96//
97// Usage:
98//
99// server := NewStreamableHTTPServer(mcpServer)
100// server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default
101//
102// or the server itself can be used as a http.Handler, which is convenient to
103// integrate with existing http servers, or advanced usage:
104//
105// handler := NewStreamableHTTPServer(mcpServer)
106// http.Handle("/streamable-http", handler)
107// http.ListenAndServe(":8080", nil)
108//
109// Notice:
110// Except for the GET handlers(listening), the POST handlers(request/notification) will
111// not trigger the session registration. So the methods like `SendNotificationToSpecificClient`
112// or `hooks.onRegisterSession` will not be triggered for POST messages.
113//
114// The current implementation does not support the following features from the specification:
115// - Batching of requests/notifications/responses in arrays.
116// - Stream Resumability
117type StreamableHTTPServer struct {
118 server *MCPServer
119 sessionTools *sessionToolsStore
120 sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64)
121
122 httpServer *http.Server
123 mu sync.RWMutex
124
125 endpointPath string
126 contextFunc HTTPContextFunc
127 sessionIdManager SessionIdManager
128 listenHeartbeatInterval time.Duration
129 logger util.Logger
130}
131
132// NewStreamableHTTPServer creates a new streamable-http server instance
133func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
134 s := &StreamableHTTPServer{
135 server: server,
136 sessionTools: newSessionToolsStore(),
137 endpointPath: "/mcp",
138 sessionIdManager: &InsecureStatefulSessionIdManager{},
139 logger: util.DefaultLogger(),
140 }
141
142 // Apply all options
143 for _, opt := range opts {
144 opt(s)
145 }
146 return s
147}
148
149// ServeHTTP implements the http.Handler interface.
150func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
151 switch r.Method {
152 case http.MethodPost:
153 s.handlePost(w, r)
154 case http.MethodGet:
155 s.handleGet(w, r)
156 case http.MethodDelete:
157 s.handleDelete(w, r)
158 default:
159 http.NotFound(w, r)
160 }
161}
162
163// Start begins serving the http server on the specified address and path
164// (endpointPath). like:
165//
166// s.Start(":8080")
167func (s *StreamableHTTPServer) Start(addr string) error {
168 s.mu.Lock()
169 if s.httpServer == nil {
170 mux := http.NewServeMux()
171 mux.Handle(s.endpointPath, s)
172 s.httpServer = &http.Server{
173 Addr: addr,
174 Handler: mux,
175 }
176 } else {
177 if s.httpServer.Addr == "" {
178 s.httpServer.Addr = addr
179 } else if s.httpServer.Addr != addr {
180 return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
181 }
182 }
183 srv := s.httpServer
184 s.mu.Unlock()
185
186 return srv.ListenAndServe()
187}
188
189// Shutdown gracefully stops the server, closing all active sessions
190// and shutting down the HTTP server.
191func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
192
193 // shutdown the server if needed (may use as a http.Handler)
194 s.mu.RLock()
195 srv := s.httpServer
196 s.mu.RUnlock()
197 if srv != nil {
198 return srv.Shutdown(ctx)
199 }
200 return nil
201}
202
203// --- internal methods ---
204
// headerKeySessionID is the HTTP header that carries the MCP session ID between
// client and server.
const headerKeySessionID = "Mcp-Session-Id"
208
209func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
210 // post request carry request/notification message
211
212 // Check content type
213 contentType := r.Header.Get("Content-Type")
214 if contentType != "application/json" {
215 http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
216 return
217 }
218
219 // Check the request body is valid json, meanwhile, get the request Method
220 rawData, err := io.ReadAll(r.Body)
221 if err != nil {
222 s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err))
223 return
224 }
225 var baseMessage struct {
226 Method mcp.MCPMethod `json:"method"`
227 }
228 if err := json.Unmarshal(rawData, &baseMessage); err != nil {
229 s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
230 return
231 }
232 isInitializeRequest := baseMessage.Method == mcp.MethodInitialize
233
234 // Prepare the session for the mcp server
235 // The session is ephemeral. Its life is the same as the request. It's only created
236 // for interaction with the mcp server.
237 var sessionID string
238 if isInitializeRequest {
239 // generate a new one for initialize request
240 sessionID = s.sessionIdManager.Generate()
241 } else {
242 // Get session ID from header.
243 // Stateful servers need the client to carry the session ID.
244 sessionID = r.Header.Get(headerKeySessionID)
245 isTerminated, err := s.sessionIdManager.Validate(sessionID)
246 if err != nil {
247 http.Error(w, "Invalid session ID", http.StatusBadRequest)
248 return
249 }
250 if isTerminated {
251 http.Error(w, "Session terminated", http.StatusNotFound)
252 return
253 }
254 }
255
256 session := newStreamableHttpSession(sessionID, s.sessionTools)
257
258 // Set the client context before handling the message
259 ctx := s.server.WithContext(r.Context(), session)
260 if s.contextFunc != nil {
261 ctx = s.contextFunc(ctx, r)
262 }
263
264 // handle potential notifications
265 mu := sync.Mutex{}
266 upgradedHeader := false
267 done := make(chan struct{})
268
269 go func() {
270 for {
271 select {
272 case nt := <-session.notificationChannel:
273 func() {
274 mu.Lock()
275 defer mu.Unlock()
276 // if the done chan is closed, as the request is terminated, just return
277 select {
278 case <-done:
279 return
280 default:
281 }
282 defer func() {
283 flusher, ok := w.(http.Flusher)
284 if ok {
285 flusher.Flush()
286 }
287 }()
288
289 // if there's notifications, upgradedHeader to SSE response
290 if !upgradedHeader {
291 w.Header().Set("Content-Type", "text/event-stream")
292 w.Header().Set("Connection", "keep-alive")
293 w.Header().Set("Cache-Control", "no-cache")
294 w.WriteHeader(http.StatusAccepted)
295 upgradedHeader = true
296 }
297 err := writeSSEEvent(w, nt)
298 if err != nil {
299 s.logger.Errorf("Failed to write SSE event: %v", err)
300 return
301 }
302 }()
303 case <-done:
304 return
305 case <-ctx.Done():
306 return
307 }
308 }
309 }()
310
311 // Process message through MCPServer
312 response := s.server.HandleMessage(ctx, rawData)
313 if response == nil {
314 // For notifications, just send 202 Accepted with no body
315 w.WriteHeader(http.StatusAccepted)
316 return
317 }
318
319 // Write response
320 mu.Lock()
321 defer mu.Unlock()
322 // close the done chan before unlock
323 defer close(done)
324 if ctx.Err() != nil {
325 return
326 }
327 // If client-server communication already upgraded to SSE stream
328 if session.upgradeToSSE.Load() {
329 if !upgradedHeader {
330 w.Header().Set("Content-Type", "text/event-stream")
331 w.Header().Set("Connection", "keep-alive")
332 w.Header().Set("Cache-Control", "no-cache")
333 w.WriteHeader(http.StatusAccepted)
334 upgradedHeader = true
335 }
336 if err := writeSSEEvent(w, response); err != nil {
337 s.logger.Errorf("Failed to write final SSE response event: %v", err)
338 }
339 } else {
340 w.Header().Set("Content-Type", "application/json")
341 if isInitializeRequest && sessionID != "" {
342 // send the session ID back to the client
343 w.Header().Set(headerKeySessionID, sessionID)
344 }
345 w.WriteHeader(http.StatusOK)
346 err := json.NewEncoder(w).Encode(response)
347 if err != nil {
348 s.logger.Errorf("Failed to write response: %v", err)
349 }
350 }
351}
352
353func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
354 // get request is for listening to notifications
355 // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
356
357 sessionID := r.Header.Get(headerKeySessionID)
358 // the specification didn't say we should validate the session id
359
360 if sessionID == "" {
361 // It's a stateless server,
362 // but the MCP server requires a unique ID for registering, so we use a random one
363 sessionID = uuid.New().String()
364 }
365
366 session := newStreamableHttpSession(sessionID, s.sessionTools)
367 if err := s.server.RegisterSession(r.Context(), session); err != nil {
368 http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
369 return
370 }
371 defer s.server.UnregisterSession(r.Context(), sessionID)
372
373 // Set the client context before handling the message
374 w.Header().Set("Content-Type", "text/event-stream")
375 w.Header().Set("Cache-Control", "no-cache")
376 w.Header().Set("Connection", "keep-alive")
377 w.WriteHeader(http.StatusAccepted)
378
379 flusher, ok := w.(http.Flusher)
380 if !ok {
381 http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
382 return
383 }
384 flusher.Flush()
385
386 // Start notification handler for this session
387 done := make(chan struct{})
388 defer close(done)
389 writeChan := make(chan any, 16)
390
391 go func() {
392 for {
393 select {
394 case nt := <-session.notificationChannel:
395 select {
396 case writeChan <- &nt:
397 case <-done:
398 return
399 }
400 case <-done:
401 return
402 }
403 }
404 }()
405
406 if s.listenHeartbeatInterval > 0 {
407 // heartbeat to keep the connection alive
408 go func() {
409 ticker := time.NewTicker(s.listenHeartbeatInterval)
410 defer ticker.Stop()
411 for {
412 select {
413 case <-ticker.C:
414 message := mcp.JSONRPCRequest{
415 JSONRPC: "2.0",
416 ID: mcp.NewRequestId(s.nextRequestID(sessionID)),
417 Request: mcp.Request{
418 Method: "ping",
419 },
420 }
421 select {
422 case writeChan <- message:
423 case <-done:
424 return
425 }
426 case <-done:
427 return
428 }
429 }
430 }()
431 }
432
433 // Keep the connection open until the client disconnects
434 //
435 // There's will a Available() check when handler ends, and it maybe race with Flush(),
436 // so we use a separate channel to send the data, inteading of flushing directly in other goroutine.
437 for {
438 select {
439 case data := <-writeChan:
440 if data == nil {
441 continue
442 }
443 if err := writeSSEEvent(w, data); err != nil {
444 s.logger.Errorf("Failed to write SSE event: %v", err)
445 return
446 }
447 flusher.Flush()
448 case <-r.Context().Done():
449 return
450 }
451 }
452}
453
454func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
455 // delete request terminate the session
456 sessionID := r.Header.Get(headerKeySessionID)
457 notAllowed, err := s.sessionIdManager.Terminate(sessionID)
458 if err != nil {
459 http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError)
460 return
461 }
462 if notAllowed {
463 http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed)
464 return
465 }
466
467 // remove the session relateddata from the sessionToolsStore
468 s.sessionTools.delete(sessionID)
469
470 // remove current session's requstID information
471 s.sessionRequestIDs.Delete(sessionID)
472
473 w.WriteHeader(http.StatusOK)
474}
475
// writeSSEEvent encodes data as JSON and writes it to w as a single SSE event
// of type "message". Both encoding and write failures are returned, wrapped
// with %w.
func writeSSEEvent(w io.Writer, data any) error {
	payload, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal data: %w", marshalErr)
	}
	if _, writeErr := fmt.Fprintf(w, "event: message\ndata: %s\n\n", payload); writeErr != nil {
		return fmt.Errorf("failed to write SSE event: %w", writeErr)
	}
	return nil
}
487
488// writeJSONRPCError writes a JSON-RPC error response with the given error details.
489func (s *StreamableHTTPServer) writeJSONRPCError(
490 w http.ResponseWriter,
491 id any,
492 code int,
493 message string,
494) {
495 response := createErrorResponse(id, code, message)
496 w.Header().Set("Content-Type", "application/json")
497 w.WriteHeader(http.StatusBadRequest)
498 err := json.NewEncoder(w).Encode(response)
499 if err != nil {
500 s.logger.Errorf("Failed to write JSONRPCError: %v", err)
501 }
502}
503
504// nextRequestID gets the next incrementing requestID for the current session
505func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
506 actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64))
507 counter := actual.(*atomic.Int64)
508 return counter.Add(1)
509}
510
511// --- session ---
512
513type sessionToolsStore struct {
514 mu sync.RWMutex
515 tools map[string]map[string]ServerTool // sessionID -> toolName -> tool
516}
517
518func newSessionToolsStore() *sessionToolsStore {
519 return &sessionToolsStore{
520 tools: make(map[string]map[string]ServerTool),
521 }
522}
523
524func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool {
525 s.mu.RLock()
526 defer s.mu.RUnlock()
527 return s.tools[sessionID]
528}
529
530func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) {
531 s.mu.Lock()
532 defer s.mu.Unlock()
533 s.tools[sessionID] = tools
534}
535
536func (s *sessionToolsStore) delete(sessionID string) {
537 s.mu.Lock()
538 defer s.mu.Unlock()
539 delete(s.tools, sessionID)
540}
541
542// streamableHttpSession is a session for streamable-http transport
543// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
544// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
545type streamableHttpSession struct {
546 sessionID string
547 notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
548 tools *sessionToolsStore
549 upgradeToSSE atomic.Bool
550}
551
552func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
553 return &streamableHttpSession{
554 sessionID: sessionID,
555 notificationChannel: make(chan mcp.JSONRPCNotification, 100),
556 tools: toolStore,
557 }
558}
559
560func (s *streamableHttpSession) SessionID() string {
561 return s.sessionID
562}
563
564func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
565 return s.notificationChannel
566}
567
568func (s *streamableHttpSession) Initialize() {
569 // do nothing
570 // the session is ephemeral, no real initialized action needed
571}
572
573func (s *streamableHttpSession) Initialized() bool {
574 // the session is ephemeral, no real initialized action needed
575 return true
576}
577
578var _ ClientSession = (*streamableHttpSession)(nil)
579
580func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool {
581 return s.tools.get(s.sessionID)
582}
583
584func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
585 s.tools.set(s.sessionID, tools)
586}
587
588var _ SessionWithTools = (*streamableHttpSession)(nil)
589
590func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
591 s.upgradeToSSE.Store(true)
592}
593
594var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
595
596// --- session id manager ---
597
// SessionIdManager controls how the streamable-HTTP server issues, checks and
// terminates session IDs.
//
// The server shares one manager across all HTTP handlers. net/http runs those
// handlers concurrently, so implementations should be safe for concurrent use.
type SessionIdManager interface {
	// Generate returns a new session ID for an initialize request. An empty
	// string means no session ID is sent to the client (stateless operation).
	Generate() string
	// Validate checks if a session ID is valid and not terminated.
	// Returns isTerminated=true if the ID is valid but belongs to a terminated session.
	// Returns err!=nil if the ID format is invalid or lookup failed.
	Validate(sessionID string) (isTerminated bool, err error)
	// Terminate marks a session ID as terminated.
	// Returns isNotAllowed=true if the server policy prevents client termination.
	// Returns err!=nil if the ID is invalid or termination failed.
	Terminate(sessionID string) (isNotAllowed bool, err error)
}
609
// StatelessSessionIdManager is the SessionIdManager for stateless servers.
// It issues no session IDs, accepts any session ID without inspecting it, and
// allows every termination request as a no-op.
type StatelessSessionIdManager struct{}

// Generate returns "", so no session ID is ever sent to clients.
func (*StatelessSessionIdManager) Generate() string { return "" }

// Validate ignores sessionID entirely. In stateless mode no ID is ever
// rejected or reported as terminated.
func (*StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
	return false, nil
}

// Terminate always allows termination and has nothing to clean up.
func (*StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
	return false, nil
}
623
624// InsecureStatefulSessionIdManager generate id with uuid
625// It won't validate the id indeed, so it could be fake.
626// For more secure session id, use a more complex generator, like a JWT.
627type InsecureStatefulSessionIdManager struct{}
628
629const idPrefix = "mcp-session-"
630
631func (s *InsecureStatefulSessionIdManager) Generate() string {
632 return idPrefix + uuid.New().String()
633}
634func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
635 // validate the session id is a valid uuid
636 if !strings.HasPrefix(sessionID, idPrefix) {
637 return false, fmt.Errorf("invalid session id: %s", sessionID)
638 }
639 if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
640 return false, fmt.Errorf("invalid session id: %s", sessionID)
641 }
642 return false, nil
643}
644func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
645 return false, nil
646}
647
648// NewTestStreamableHTTPServer creates a test server for testing purposes
649func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server {
650 sseServer := NewStreamableHTTPServer(server, opts...)
651 testServer := httptest.NewServer(sseServer)
652 return testServer
653}