1package transport
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "mime"
12 "net/http"
13 "net/url"
14 "strings"
15 "sync"
16 "sync/atomic"
17 "time"
18
19 "github.com/mark3labs/mcp-go/mcp"
20 "github.com/mark3labs/mcp-go/util"
21)
22
23type StreamableHTTPCOption func(*StreamableHTTP)
24
25// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
26// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
27// you should enable this option.
28//
29// It will establish a standalone long-live GET HTTP connection to the server.
30// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
31// NOTICE: Even enabled, the server may not support this feature.
32func WithContinuousListening() StreamableHTTPCOption {
33 return func(sc *StreamableHTTP) {
34 sc.getListeningEnabled = true
35 }
36}
37
38// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
39func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
40 return func(sc *StreamableHTTP) {
41 sc.httpClient = client
42 }
43}
44
45func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
46 return func(sc *StreamableHTTP) {
47 sc.headers = headers
48 }
49}
50
51func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
52 return func(sc *StreamableHTTP) {
53 sc.headerFunc = headerFunc
54 }
55}
56
57// WithHTTPTimeout sets the timeout for a HTTP request and stream.
58func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
59 return func(sc *StreamableHTTP) {
60 sc.httpClient.Timeout = timeout
61 }
62}
63
64// WithHTTPOAuth enables OAuth authentication for the client.
65func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
66 return func(sc *StreamableHTTP) {
67 sc.oauthHandler = NewOAuthHandler(config)
68 }
69}
70
71func WithLogger(logger util.Logger) StreamableHTTPCOption {
72 return func(sc *StreamableHTTP) {
73 sc.logger = logger
74 }
75}
76
77// WithSession creates a client with a pre-configured session
78func WithSession(sessionID string) StreamableHTTPCOption {
79 return func(sc *StreamableHTTP) {
80 sc.sessionID.Store(sessionID)
81 }
82}
83
84// StreamableHTTP implements Streamable HTTP transport.
85//
86// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
87// The HTTP response body can either be a single JSON-RPC response,
88// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request.
89//
90// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
91//
92// The current implementation does not support the following features:
93// - batching
94// - resuming stream
95// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
96// - server -> client request
97type StreamableHTTP struct {
98 serverURL *url.URL
99 httpClient *http.Client
100 headers map[string]string
101 headerFunc HTTPHeaderFunc
102 logger util.Logger
103 getListeningEnabled bool
104
105 sessionID atomic.Value // string
106
107 initialized chan struct{}
108 initializedOnce sync.Once
109
110 notificationHandler func(mcp.JSONRPCNotification)
111 notifyMu sync.RWMutex
112
113 closed chan struct{}
114
115 // OAuth support
116 oauthHandler *OAuthHandler
117}
118
119// NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL.
120// Returns an error if the URL is invalid.
121func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) {
122 parsedURL, err := url.Parse(serverURL)
123 if err != nil {
124 return nil, fmt.Errorf("invalid URL: %w", err)
125 }
126
127 smc := &StreamableHTTP{
128 serverURL: parsedURL,
129 httpClient: &http.Client{},
130 headers: make(map[string]string),
131 closed: make(chan struct{}),
132 logger: util.DefaultLogger(),
133 initialized: make(chan struct{}),
134 }
135 smc.sessionID.Store("") // set initial value to simplify later usage
136
137 for _, opt := range options {
138 if opt != nil {
139 opt(smc)
140 }
141 }
142
143 // If OAuth is configured, set the base URL for metadata discovery
144 if smc.oauthHandler != nil {
145 // Extract base URL from server URL for metadata discovery
146 baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
147 smc.oauthHandler.SetBaseURL(baseURL)
148 }
149
150 return smc, nil
151}
152
153// Start initiates the HTTP connection to the server.
154func (c *StreamableHTTP) Start(ctx context.Context) error {
155 // For Streamable HTTP, we don't need to establish a persistent connection by default
156 if c.getListeningEnabled {
157 go func() {
158 select {
159 case <-c.initialized:
160 ctx, cancel := c.contextAwareOfClientClose(ctx)
161 defer cancel()
162 c.listenForever(ctx)
163 case <-c.closed:
164 return
165 }
166 }()
167 }
168
169 return nil
170}
171
172// Close closes the all the HTTP connections to the server.
173func (c *StreamableHTTP) Close() error {
174 select {
175 case <-c.closed:
176 return nil
177 default:
178 }
179 // Cancel all in-flight requests
180 close(c.closed)
181
182 sessionId := c.sessionID.Load().(string)
183 if sessionId != "" {
184 c.sessionID.Store("")
185
186 // notify server session closed
187 go func() {
188 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
189 defer cancel()
190 req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
191 if err != nil {
192 c.logger.Errorf("failed to create close request: %v", err)
193 return
194 }
195 req.Header.Set(headerKeySessionID, sessionId)
196 res, err := c.httpClient.Do(req)
197 if err != nil {
198 c.logger.Errorf("failed to send close request: %v", err)
199 return
200 }
201 res.Body.Close()
202 }()
203 }
204
205 return nil
206}
207
// headerKeySessionID is the name of the HTTP header that carries the session ID.
const headerKeySessionID = "Mcp-Session-Id"
211
var (
	// ErrOAuthAuthorizationRequired is the sentinel error reported when no valid
	// OAuth token is available and the user has to authorize.
	ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required")
)
214
215// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
216type OAuthAuthorizationRequiredError struct {
217 Handler *OAuthHandler
218}
219
220func (e *OAuthAuthorizationRequiredError) Error() string {
221 return ErrOAuthAuthorizationRequired.Error()
222}
223
224func (e *OAuthAuthorizationRequiredError) Unwrap() error {
225 return ErrOAuthAuthorizationRequired
226}
227
228// SendRequest sends a JSON-RPC request to the server and waits for a response.
229// Returns the raw JSON response message or an error if the request fails.
230func (c *StreamableHTTP) SendRequest(
231 ctx context.Context,
232 request JSONRPCRequest,
233) (*JSONRPCResponse, error) {
234
235 // Marshal request
236 requestBody, err := json.Marshal(request)
237 if err != nil {
238 return nil, fmt.Errorf("failed to marshal request: %w", err)
239 }
240
241 ctx, cancel := c.contextAwareOfClientClose(ctx)
242 defer cancel()
243
244 resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
245 if err != nil {
246 if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
247 // If the request is initialize, should not return a SessionTerminated error
248 // It should be a genuine endpoint-routing issue.
249 // ( Fall through to return StatusCode checking. )
250 } else {
251 return nil, fmt.Errorf("failed to send request: %w", err)
252 }
253 }
254 defer resp.Body.Close()
255
256 // Check if we got an error response
257 if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
258
259 // Handle OAuth unauthorized error
260 if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
261 return nil, &OAuthAuthorizationRequiredError{
262 Handler: c.oauthHandler,
263 }
264 }
265
266 // handle error response
267 var errResponse JSONRPCResponse
268 body, _ := io.ReadAll(resp.Body)
269 if err := json.Unmarshal(body, &errResponse); err == nil {
270 return &errResponse, nil
271 }
272 return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
273 }
274
275 if request.Method == string(mcp.MethodInitialize) {
276 // saved the received session ID in the response
277 // empty session ID is allowed
278 if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
279 c.sessionID.Store(sessionID)
280 }
281
282 c.initializedOnce.Do(func() {
283 close(c.initialized)
284 })
285 }
286
287 // Handle different response types
288 mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
289 switch mediaType {
290 case "application/json":
291 // Single response
292 var response JSONRPCResponse
293 if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
294 return nil, fmt.Errorf("failed to decode response: %w", err)
295 }
296
297 // should not be a notification
298 if response.ID.IsNil() {
299 return nil, fmt.Errorf("response should contain RPC id: %v", response)
300 }
301
302 return &response, nil
303
304 case "text/event-stream":
305 // Server is using SSE for streaming responses
306 return c.handleSSEResponse(ctx, resp.Body, false)
307
308 default:
309 return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
310 }
311}
312
313func (c *StreamableHTTP) sendHTTP(
314 ctx context.Context,
315 method string,
316 body io.Reader,
317 acceptType string,
318) (resp *http.Response, err error) {
319
320 // Create HTTP request
321 req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
322 if err != nil {
323 return nil, fmt.Errorf("failed to create request: %w", err)
324 }
325
326 // Set headers
327 req.Header.Set("Content-Type", "application/json")
328 req.Header.Set("Accept", acceptType)
329 sessionID := c.sessionID.Load().(string)
330 if sessionID != "" {
331 req.Header.Set(headerKeySessionID, sessionID)
332 }
333 for k, v := range c.headers {
334 req.Header.Set(k, v)
335 }
336
337 // Add OAuth authorization if configured
338 if c.oauthHandler != nil {
339 authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
340 if err != nil {
341 // If we get an authorization error, return a specific error that can be handled by the client
342 if err.Error() == "no valid token available, authorization required" {
343 return nil, &OAuthAuthorizationRequiredError{
344 Handler: c.oauthHandler,
345 }
346 }
347 return nil, fmt.Errorf("failed to get authorization header: %w", err)
348 }
349 req.Header.Set("Authorization", authHeader)
350 }
351
352 if c.headerFunc != nil {
353 for k, v := range c.headerFunc(ctx) {
354 req.Header.Set(k, v)
355 }
356 }
357
358 // Send request
359 resp, err = c.httpClient.Do(req)
360 if err != nil {
361 return nil, fmt.Errorf("failed to send request: %w", err)
362 }
363
364 // universal handling for session terminated
365 if resp.StatusCode == http.StatusNotFound {
366 c.sessionID.CompareAndSwap(sessionID, "")
367 return nil, ErrSessionTerminated
368 }
369
370 return resp, nil
371}
372
373// handleSSEResponse processes an SSE stream for a specific request.
374// It returns the final result for the request once received, or an error.
375// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening.
376func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {
377
378 // Create a channel for this specific request
379 responseChan := make(chan *JSONRPCResponse, 1)
380
381 ctx, cancel := context.WithCancel(ctx)
382 defer cancel()
383
384 // Start a goroutine to process the SSE stream
385 go func() {
386 // only close responseChan after readingSSE()
387 defer close(responseChan)
388
389 c.readSSE(ctx, reader, func(event, data string) {
390
391 // (unsupported: batching)
392
393 var message JSONRPCResponse
394 if err := json.Unmarshal([]byte(data), &message); err != nil {
395 c.logger.Errorf("failed to unmarshal message: %v", err)
396 return
397 }
398
399 // Handle notification
400 if message.ID.IsNil() {
401 var notification mcp.JSONRPCNotification
402 if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
403 c.logger.Errorf("failed to unmarshal notification: %v", err)
404 return
405 }
406 c.notifyMu.RLock()
407 if c.notificationHandler != nil {
408 c.notificationHandler(notification)
409 }
410 c.notifyMu.RUnlock()
411 return
412 }
413
414 if !ignoreResponse {
415 responseChan <- &message
416 }
417 })
418 }()
419
420 // Wait for the response or context cancellation
421 select {
422 case response := <-responseChan:
423 if response == nil {
424 return nil, fmt.Errorf("unexpected nil response")
425 }
426 return response, nil
427 case <-ctx.Done():
428 return nil, ctx.Err()
429 }
430}
431
432// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
433// It will end when the reader is closed (or the context is done).
434func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
435 defer reader.Close()
436
437 br := bufio.NewReader(reader)
438 var event, data string
439
440 for {
441 select {
442 case <-ctx.Done():
443 return
444 default:
445 line, err := br.ReadString('\n')
446 if err != nil {
447 if err == io.EOF {
448 // Process any pending event before exit
449 if data != "" {
450 // If no event type is specified, use empty string (default event type)
451 if event == "" {
452 event = "message"
453 }
454 handler(event, data)
455 }
456 return
457 }
458 select {
459 case <-ctx.Done():
460 return
461 default:
462 c.logger.Errorf("SSE stream error: %v", err)
463 return
464 }
465 }
466
467 // Remove only newline markers
468 line = strings.TrimRight(line, "\r\n")
469 if line == "" {
470 // Empty line means end of event
471 if data != "" {
472 // If no event type is specified, use empty string (default event type)
473 if event == "" {
474 event = "message"
475 }
476 handler(event, data)
477 event = ""
478 data = ""
479 }
480 continue
481 }
482
483 if strings.HasPrefix(line, "event:") {
484 event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
485 } else if strings.HasPrefix(line, "data:") {
486 data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
487 }
488 }
489 }
490}
491
492func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
493
494 // Marshal request
495 requestBody, err := json.Marshal(notification)
496 if err != nil {
497 return fmt.Errorf("failed to marshal notification: %w", err)
498 }
499
500 // Create HTTP request
501 ctx, cancel := c.contextAwareOfClientClose(ctx)
502 defer cancel()
503
504 resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
505 if err != nil {
506 return fmt.Errorf("failed to send request: %w", err)
507 }
508 defer resp.Body.Close()
509
510 if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
511 // Handle OAuth unauthorized error
512 if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
513 return &OAuthAuthorizationRequiredError{
514 Handler: c.oauthHandler,
515 }
516 }
517
518 body, _ := io.ReadAll(resp.Body)
519 return fmt.Errorf(
520 "notification failed with status %d: %s",
521 resp.StatusCode,
522 body,
523 )
524 }
525
526 return nil
527}
528
529func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
530 c.notifyMu.Lock()
531 defer c.notifyMu.Unlock()
532 c.notificationHandler = handler
533}
534
535func (c *StreamableHTTP) GetSessionId() string {
536 return c.sessionID.Load().(string)
537}
538
539// GetOAuthHandler returns the OAuth handler if configured
540func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
541 return c.oauthHandler
542}
543
544// IsOAuthEnabled returns true if OAuth is enabled
545func (c *StreamableHTTP) IsOAuthEnabled() bool {
546 return c.oauthHandler != nil
547}
548
549func (c *StreamableHTTP) listenForever(ctx context.Context) {
550 c.logger.Infof("listening to server forever")
551 for {
552 err := c.createGETConnectionToServer(ctx)
553 if errors.Is(err, ErrGetMethodNotAllowed) {
554 // server does not support listening
555 c.logger.Errorf("server does not support listening")
556 return
557 }
558
559 select {
560 case <-ctx.Done():
561 return
562 default:
563 }
564
565 if err != nil {
566 c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
567 }
568 time.Sleep(retryInterval)
569 }
570}
571
var (
	// ErrSessionTerminated is returned when the server answers 404, which
	// means the session is gone and the client has to initialize again.
	ErrSessionTerminated = errors.New("session terminated (404). need to re-initialize")
	// ErrGetMethodNotAllowed is returned when the server rejects the GET
	// listening connection with 405 Method Not Allowed.
	ErrGetMethodNotAllowed = errors.New("GET method not allowed")

	// retryInterval is the pause between attempts to re-establish the
	// listening connection.
	retryInterval = 1 * time.Second // a variable is convenient for testing
)
578
579func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
580
581 resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
582 if err != nil {
583 return fmt.Errorf("failed to send request: %w", err)
584 }
585 defer resp.Body.Close()
586
587 // Check if we got an error response
588 if resp.StatusCode == http.StatusMethodNotAllowed {
589 return ErrGetMethodNotAllowed
590 }
591
592 if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
593 body, _ := io.ReadAll(resp.Body)
594 return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
595 }
596
597 // handle SSE response
598 contentType := resp.Header.Get("Content-Type")
599 if contentType != "text/event-stream" {
600 return fmt.Errorf("unexpected content type: %s", contentType)
601 }
602
603 // When ignoreResponse is true, the function will never return expect context is done.
604 // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response
605 // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based,
606 // currently, there is no convenient way to handle this response.
607 // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs.
608 _, err = c.handleSSEResponse(ctx, resp.Body, true)
609 if err != nil {
610 return fmt.Errorf("failed to handle SSE response: %w", err)
611 }
612
613 return nil
614}
615
616func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
617 newCtx, cancel := context.WithCancel(ctx)
618 go func() {
619 select {
620 case <-c.closed:
621 cancel()
622 case <-newCtx.Done():
623 // The original context was canceled
624 cancel()
625 }
626 }()
627 return newCtx, cancel
628}