1package transport
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"mime"
 12	"net/http"
 13	"net/url"
 14	"strings"
 15	"sync"
 16	"sync/atomic"
 17	"time"
 18
 19	"github.com/mark3labs/mcp-go/mcp"
 20	"github.com/mark3labs/mcp-go/util"
 21)
 22
// StreamableHTTPCOption configures a StreamableHTTP transport. Options are
// passed to NewStreamableHTTP and applied in the order given; nil options are
// skipped.
type StreamableHTTPCOption func(*StreamableHTTP)
 24
 25// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
 26// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
 27// you should enable this option.
 28//
 29// It will establish a standalone long-live GET HTTP connection to the server.
 30// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
 31// NOTICE: Even enabled, the server may not support this feature.
 32func WithContinuousListening() StreamableHTTPCOption {
 33	return func(sc *StreamableHTTP) {
 34		sc.getListeningEnabled = true
 35	}
 36}
 37
 38// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
 39func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
 40	return func(sc *StreamableHTTP) {
 41		sc.httpClient = client
 42	}
 43}
 44
 45func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
 46	return func(sc *StreamableHTTP) {
 47		sc.headers = headers
 48	}
 49}
 50
 51func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
 52	return func(sc *StreamableHTTP) {
 53		sc.headerFunc = headerFunc
 54	}
 55}
 56
 57// WithHTTPTimeout sets the timeout for a HTTP request and stream.
 58func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
 59	return func(sc *StreamableHTTP) {
 60		sc.httpClient.Timeout = timeout
 61	}
 62}
 63
 64// WithHTTPOAuth enables OAuth authentication for the client.
 65func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
 66	return func(sc *StreamableHTTP) {
 67		sc.oauthHandler = NewOAuthHandler(config)
 68	}
 69}
 70
 71func WithLogger(logger util.Logger) StreamableHTTPCOption {
 72	return func(sc *StreamableHTTP) {
 73		sc.logger = logger
 74	}
 75}
 76
 77// WithSession creates a client with a pre-configured session
 78func WithSession(sessionID string) StreamableHTTPCOption {
 79	return func(sc *StreamableHTTP) {
 80		sc.sessionID.Store(sessionID)
 81	}
 82}
 83
// StreamableHTTP implements Streamable HTTP transport.
//
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
// The HTTP response body can either be a single JSON-RPC response,
// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request.
//
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
//
// The current implementation does not support the following features:
//   - batching
//   - resuming stream
//     (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
//   - server -> client request
//
// Create instances with NewStreamableHTTP. The zero value is not ready for
// use: sessionID holds no string yet and the channels are nil.
type StreamableHTTP struct {
	// Connection configuration, set by NewStreamableHTTP and the options:
	//   - serverURL is the endpoint that every request (POST, GET, DELETE) targets.
	//   - httpClient sends all requests.
	//   - headers are static headers that sendHTTP adds to POST/GET requests.
	//   - headerFunc returns per-request headers; they are applied last and can
	//     override the others.
	//   - getListeningEnabled is set by WithContinuousListening and makes Start
	//     launch the GET listener.
	serverURL           *url.URL
	httpClient          *http.Client
	headers             map[string]string
	headerFunc          HTTPHeaderFunc
	logger              util.Logger
	getListeningEnabled bool

	// sessionID always holds a string ("" means no session). It is set from
	// the initialize response or WithSession, and cleared on 404 and on Close.
	sessionID atomic.Value // string

	// initialized is closed exactly once (guarded by initializedOnce) after
	// the first successful initialize response; Start waits on it before
	// opening the GET listener.
	initialized     chan struct{}
	initializedOnce sync.Once

	// notificationHandler receives server notifications from SSE streams;
	// notifyMu guards it (read-locked while the handler runs).
	notificationHandler func(mcp.JSONRPCNotification)
	notifyMu            sync.RWMutex

	// closed is closed by Close. Contexts derived via
	// contextAwareOfClientClose are cancelled when it is closed.
	closed chan struct{}

	// OAuth support
	oauthHandler *OAuthHandler
}
118
// NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL.
// Returns an error if the URL cannot be parsed.
//
// Defaults are:
//   - a plain http.Client with no timeout,
//   - an empty header map,
//   - util.DefaultLogger(),
//   - no session.
//
// The options are then applied in order, and nil options are skipped. If an
// option leaves the HTTP client or the logger nil, the default is restored so
// later requests and log calls cannot panic. When OAuth is configured, the
// handler's base URL is set to the server's scheme://host for metadata
// discovery.
func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) {
	parsedURL, err := url.Parse(serverURL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}

	smc := &StreamableHTTP{
		serverURL:   parsedURL,
		httpClient:  &http.Client{},
		headers:     make(map[string]string),
		closed:      make(chan struct{}),
		logger:      util.DefaultLogger(),
		initialized: make(chan struct{}),
	}
	smc.sessionID.Store("") // set initial value to simplify later usage

	for _, opt := range options {
		if opt != nil {
			opt(smc)
		}
	}

	// Options such as WithHTTPBasicClient(nil) or WithLogger(nil) could
	// otherwise leave fields that are dereferenced on every request.
	if smc.httpClient == nil {
		smc.httpClient = &http.Client{}
	}
	if smc.logger == nil {
		smc.logger = util.DefaultLogger()
	}

	// If OAuth is configured, set the base URL for metadata discovery
	if smc.oauthHandler != nil {
		// Extract base URL from server URL for metadata discovery
		baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
		smc.oauthHandler.SetBaseURL(baseURL)
	}

	return smc, nil
}
152
// Start initiates the HTTP connection to the server.
//
// Streamable HTTP needs no persistent connection by default, so Start only
// acts when continuous listening was enabled. In that case it spawns a
// goroutine that waits for the first successful initialize. It then runs the
// GET listener until the transport is closed or ctx is done. If the transport
// is closed before initialization, the goroutine simply exits. Start itself
// always returns nil.
func (c *StreamableHTTP) Start(ctx context.Context) error {
	if !c.getListeningEnabled {
		return nil
	}

	go func() {
		select {
		case <-c.closed:
			return
		case <-c.initialized:
		}

		listenCtx, stop := c.contextAwareOfClientClose(ctx)
		defer stop()
		c.listenForever(listenCtx)
	}()

	return nil
}
171
// Close closes all the HTTP connections to the server.
//
// It closes c.closed, which cancels every context derived through
// contextAwareOfClientClose: in-flight requests and the GET listener.
//
// If a session ID is held, it is cleared, and a best-effort DELETE carrying
// the Mcp-Session-Id header is sent from a background goroutine with a
// 5-second timeout. Failures of that request are only logged.
//
// Calling Close again after it has returned is a no-op and returns nil.
//
// NOTE(review): the check of c.closed and the close(c.closed) below are not
// atomic. Two concurrent Close calls can both pass the select and the second
// close would panic; a sync.Once would make this safe.
//
// NOTE(review): the DELETE request carries neither c.headers, the headerFunc
// headers, nor the OAuth Authorization header. Servers that require them may
// reject it; confirm this is intended.
func (c *StreamableHTTP) Close() error {
	// Already closed: nothing to do.
	select {
	case <-c.closed:
		return nil
	default:
	}
	// Cancel all in-flight requests
	close(c.closed)

	sessionId := c.sessionID.Load().(string)
	if sessionId != "" {
		c.sessionID.Store("")

		// notify server session closed
		// Uses a fresh background context because every context derived from
		// c.closed has just been cancelled.
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
			if err != nil {
				c.logger.Errorf("failed to create close request: %v", err)
				return
			}
			req.Header.Set(headerKeySessionID, sessionId)
			res, err := c.httpClient.Do(req)
			if err != nil {
				c.logger.Errorf("failed to send close request: %v", err)
				return
			}
			res.Body.Close()
		}()
	}

	return nil
}
207
// headerKeySessionID is the HTTP header that carries the MCP session ID. The
// client reads it from the initialize response and sends it on later requests.
const headerKeySessionID = "Mcp-Session-Id"
211
var (
	// ErrOAuthAuthorizationRequired is the sentinel error signalling that no
	// valid OAuth token is available and the user must authorize first.
	// OAuthAuthorizationRequiredError unwraps to it, so callers can match it
	// with errors.Is.
	ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required")
)
214
// OAuthAuthorizationRequiredError is returned when OAuth authorization is required.
//
// The transport returns it when OAuth is configured and either no valid token
// is available before sending, or the server answers 401. Handler is the
// transport's OAuth handler, so the caller can drive the authorization flow
// and retry. The error unwraps to ErrOAuthAuthorizationRequired.
type OAuthAuthorizationRequiredError struct {
	Handler *OAuthHandler
}
219
220func (e *OAuthAuthorizationRequiredError) Error() string {
221	return ErrOAuthAuthorizationRequired.Error()
222}
223
224func (e *OAuthAuthorizationRequiredError) Unwrap() error {
225	return ErrOAuthAuthorizationRequired
226}
227
// SendRequest sends a JSON-RPC request to the server and waits for a response.
// Returns the raw JSON response message or an error if the request fails.
//
// The server may reply with a single application/json body or with a
// text/event-stream. For a stream, notifications seen before the response are
// delivered to the notification handler, and the first message carrying an
// RPC id is returned.
//
// For an initialize request, the Mcp-Session-Id header of a successful reply
// (if present) becomes the session ID. The transport is then marked
// initialized, which releases the GET listener started by Start.
//
// Error handling:
//   - 401 with OAuth configured: *OAuthAuthorizationRequiredError.
//   - Any other status than 200/202: the body is returned as a JSON-RPC
//     response if it decodes as one; otherwise an error with the status and
//     body is returned.
//   - 404 on a non-initialize request: an error wrapping ErrSessionTerminated.
//   - 404 on initialize: a plain status error. No session exists yet, so the
//     404 points at the endpoint itself.
func (c *StreamableHTTP) SendRequest(
	ctx context.Context,
	request JSONRPCRequest,
) (*JSONRPCResponse, error) {

	// Marshal request
	requestBody, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	ctx, cancel := c.contextAwareOfClientClose(ctx)
	defer cancel()

	resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
	if err != nil {
		if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
			// If the request is initialize, it should not return a
			// SessionTerminated error: this is a genuine endpoint-routing
			// issue, so report the status code instead.
			// sendHTTP returns a nil response together with
			// ErrSessionTerminated, so we must not fall through and touch
			// resp; doing so dereferenced a nil pointer.
			return nil, fmt.Errorf("request failed with status %d: %s",
				http.StatusNotFound, http.StatusText(http.StatusNotFound))
		}
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// Check if we got an error response
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {

		// Handle OAuth unauthorized error
		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
			return nil, &OAuthAuthorizationRequiredError{
				Handler: c.oauthHandler,
			}
		}

		// The server may still have sent a JSON-RPC error object; prefer it
		// over a generic status error.
		var errResponse JSONRPCResponse
		body, _ := io.ReadAll(resp.Body)
		if err := json.Unmarshal(body, &errResponse); err == nil {
			return &errResponse, nil
		}
		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
	}

	if request.Method == string(mcp.MethodInitialize) {
		// saved the received session ID in the response
		// empty session ID is allowed
		if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
			c.sessionID.Store(sessionID)
		}

		c.initializedOnce.Do(func() {
			close(c.initialized)
		})
	}

	// Handle different response types; parse the media type so parameters
	// like "; charset=utf-8" do not matter.
	mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	switch mediaType {
	case "application/json":
		// Single response
		var response JSONRPCResponse
		if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
			return nil, fmt.Errorf("failed to decode response: %w", err)
		}

		// should not be a notification
		if response.ID.IsNil() {
			return nil, fmt.Errorf("response should contain RPC id: %v", response)
		}

		return &response, nil

	case "text/event-stream":
		// Server is using SSE for streaming responses
		return c.handleSSEResponse(ctx, resp.Body, false)

	default:
		return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
	}
}
312
// sendHTTP builds and sends a single HTTP request to the server URL.
//
// Headers are applied in this order, later ones overriding earlier ones:
//  1. Content-Type and Accept
//  2. the current session ID (if any)
//  3. the static headers
//  4. the OAuth Authorization header (if OAuth is configured)
//  5. the headers returned by headerFunc
//
// It returns *OAuthAuthorizationRequiredError when OAuth is configured but no
// valid token is available.
//
// A 404 response is treated as session termination for any method:
//   - the response body is closed;
//   - the stored session ID is cleared, but only if it is still the one this
//     request used;
//   - ErrSessionTerminated is returned with a nil response.
//
// On success the caller owns resp and must close resp.Body.
func (c *StreamableHTTP) sendHTTP(
	ctx context.Context,
	method string,
	body io.Reader,
	acceptType string,
) (resp *http.Response, err error) {

	// Create HTTP request
	req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Set headers
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", acceptType)
	sessionID := c.sessionID.Load().(string)
	if sessionID != "" {
		req.Header.Set(headerKeySessionID, sessionID)
	}
	for k, v := range c.headers {
		req.Header.Set(k, v)
	}

	// Add OAuth authorization if configured
	if c.oauthHandler != nil {
		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
		if err != nil {
			// If we get an authorization error, return a specific error that can be handled by the client.
			// errors.Is also recognises a wrapped sentinel. The message comparison is kept so that
			// handlers returning an equivalent but distinct error value still match.
			if errors.Is(err, ErrOAuthAuthorizationRequired) || err.Error() == ErrOAuthAuthorizationRequired.Error() {
				return nil, &OAuthAuthorizationRequiredError{
					Handler: c.oauthHandler,
				}
			}
			return nil, fmt.Errorf("failed to get authorization header: %w", err)
		}
		req.Header.Set("Authorization", authHeader)
	}

	if c.headerFunc != nil {
		for k, v := range c.headerFunc(ctx) {
			req.Header.Set(k, v)
		}
	}

	// Send request
	resp, err = c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}

	// universal handling for session terminated
	if resp.StatusCode == http.StatusNotFound {
		// The response is not handed back to the caller, so its body must be
		// closed here or the underlying connection leaks.
		resp.Body.Close()
		c.sessionID.CompareAndSwap(sessionID, "")
		return nil, ErrSessionTerminated
	}

	return resp, nil
}
372
// handleSSEResponse processes an SSE stream for a specific request.
//
// The data of each SSE event is decoded as a JSON-RPC message:
//   - Messages without an RPC id are notifications. They are passed to the
//     registered notification handler, if any.
//   - The first message with an id is returned as the response.
//
// If ignoreResponse is true, it won't return when a response message is
// received; response messages are dropped instead. This is for continuous
// listening.
//
// It returns ctx.Err() if ctx is done first. It returns an
// "unexpected nil response" error if the stream ends without delivering a
// response; with ignoreResponse set, this happens whenever the stream ends.
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {

	// Create a channel for this specific request
	responseChan := make(chan *JSONRPCResponse, 1)

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Start a goroutine to process the SSE stream
	go func() {
		// Only close responseChan after readSSE() has returned. The select
		// below then sees end-of-stream as a nil receive.
		defer close(responseChan)

		c.readSSE(ctx, reader, func(event, data string) {

			// (unsupported: batching)

			var message JSONRPCResponse
			if err := json.Unmarshal([]byte(data), &message); err != nil {
				c.logger.Errorf("failed to unmarshal message: %v", err)
				return
			}

			// Handle notification
			if message.ID.IsNil() {
				var notification mcp.JSONRPCNotification
				if err := json.Unmarshal([]byte(data), &notification); err != nil {
					c.logger.Errorf("failed to unmarshal notification: %v", err)
					return
				}
				c.notifyMu.RLock()
				if c.notificationHandler != nil {
					c.notificationHandler(notification)
				}
				c.notifyMu.RUnlock()
				return
			}

			if ignoreResponse {
				return
			}

			// Never block forever. Once this function has returned, nobody
			// receives from responseChan; this happens when a stream carries
			// more than one response. ctx is cancelled on return, which
			// releases this goroutine.
			select {
			case responseChan <- &message:
			case <-ctx.Done():
			}
		})
	}()

	// Wait for the response or context cancellation
	select {
	case response := <-responseChan:
		if response == nil {
			return nil, fmt.Errorf("unexpected nil response")
		}
		return response, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
431
// readSSE reads the SSE stream (reader) and calls handler once for each
// dispatched event, passing its event type and data.
//
// Parsing rules:
//   - An empty line dispatches the pending event.
//   - "event:" sets the event type. Events without one use "message".
//   - Multiple "data:" lines in one event are joined with "\n", as the SSE
//     format requires.
//   - Other fields ("id:", "retry:") and comment lines (":") are ignored.
//
// Events whose data is empty are not dispatched.
//
// When the stream ends, any pending event is dispatched before returning,
// including one on a final line that lacks a trailing newline. A read error
// other than EOF is logged unless ctx is already done.
//
// ctx is only checked between lines, so a blocked read is interrupted by
// closing the underlying body, not by ctx alone. The reader is always closed
// before returning.
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
	defer reader.Close()

	br := bufio.NewReader(reader)
	var event string
	var dataLines []string

	// dispatch delivers the pending event (if it carries data) and resets
	// the per-event buffers.
	dispatch := func() {
		if data := strings.Join(dataLines, "\n"); data != "" {
			if event == "" {
				// Default SSE event type.
				event = "message"
			}
			handler(event, data)
		}
		event = ""
		dataLines = dataLines[:0]
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		line, err := br.ReadString('\n')
		if err != nil && !errors.Is(err, io.EOF) {
			select {
			case <-ctx.Done():
				// Read errors caused by cancellation are expected; stay quiet.
			default:
				c.logger.Errorf("SSE stream error: %v", err)
			}
			return
		}

		// Remove only newline markers
		line = strings.TrimRight(line, "\r\n")
		switch {
		case line == "":
			// Empty line means end of event
			dispatch()
		case strings.HasPrefix(line, "event:"):
			event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
		case strings.HasPrefix(line, "data:"):
			dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
		}

		if err != nil {
			// EOF: flush an event the server did not terminate with a blank
			// line; the last partial line was processed above.
			dispatch()
			return
		}
	}
}
491
// SendNotification POSTs a JSON-RPC notification to the server.
//
// Any 200 or 202 response counts as success, and its body is ignored. A 401
// with OAuth configured yields *OAuthAuthorizationRequiredError. Every other
// status becomes an error carrying the status code and the response body. The
// request is cancelled if the transport is closed while it is in flight.
func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
	payload, err := json.Marshal(notification)
	if err != nil {
		return fmt.Errorf("failed to marshal notification: %w", err)
	}

	reqCtx, stop := c.contextAwareOfClientClose(ctx)
	defer stop()

	resp, err := c.sendHTTP(reqCtx, http.MethodPost, bytes.NewReader(payload), "application/json, text/event-stream")
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	switch {
	case resp.StatusCode == http.StatusOK, resp.StatusCode == http.StatusAccepted:
		return nil
	case resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil:
		return &OAuthAuthorizationRequiredError{Handler: c.oauthHandler}
	}

	body, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("notification failed with status %d: %s", resp.StatusCode, body)
}
528
529func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
530	c.notifyMu.Lock()
531	defer c.notifyMu.Unlock()
532	c.notificationHandler = handler
533}
534
535func (c *StreamableHTTP) GetSessionId() string {
536	return c.sessionID.Load().(string)
537}
538
539// GetOAuthHandler returns the OAuth handler if configured
540func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
541	return c.oauthHandler
542}
543
544// IsOAuthEnabled returns true if OAuth is enabled
545func (c *StreamableHTTP) IsOAuthEnabled() bool {
546	return c.oauthHandler != nil
547}
548
// listenForever keeps a standalone GET SSE connection to the server open,
// reconnecting after retryInterval whenever it ends or fails.
//
// It returns when ctx is done, or when the server answers 405, meaning it does
// not support listening. The wait between attempts is interrupted by ctx, so
// closing the transport stops the loop promptly instead of sleeping through
// the retry interval.
func (c *StreamableHTTP) listenForever(ctx context.Context) {
	c.logger.Infof("listening to server forever")
	for {
		err := c.createGETConnectionToServer(ctx)
		if errors.Is(err, ErrGetMethodNotAllowed) {
			// server does not support listening
			c.logger.Errorf("server does not support listening")
			return
		}

		if ctx.Err() != nil {
			return
		}

		if err != nil {
			c.logger.Errorf("failed to listen to server. retry in %v: %v", retryInterval, err)
		}

		timer := time.NewTimer(retryInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
571
572var (
573	ErrSessionTerminated   = fmt.Errorf("session terminated (404). need to re-initialize")
574	ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
575
576	retryInterval = 1 * time.Second // a variable is convenient for testing
577)
578
// createGETConnectionToServer opens one standalone GET SSE stream to the
// server and consumes it until it ends or ctx is done.
//
// Notifications received on the stream go to the notification handler, while
// response messages are ignored (see the note below).
//
// Errors:
//   - ErrGetMethodNotAllowed when the server answers 405;
//   - an error for any other status than 200/202;
//   - an error for a content type other than text/event-stream.
func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {

	resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// Check if we got an error response
	if resp.StatusCode == http.StatusMethodNotAllowed {
		return ErrGetMethodNotAllowed
	}

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
	}

	// Handle the SSE response. Compare the parsed media type, as SendRequest
	// does, so that e.g. "text/event-stream; charset=utf-8" is accepted.
	contentType := resp.Header.Get("Content-Type")
	if mediaType, _, err := mime.ParseMediaType(contentType); err != nil || mediaType != "text/event-stream" {
		return fmt.Errorf("unexpected content type: %s", contentType)
	}

	// When ignoreResponse is true, the function never returns a response; it
	// returns only when the stream ends or the context is done.
	// NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response
	// messages. To be more compatible, we should handle this response, however, as the transport layer is message-based,
	// currently, there is no convenient way to handle this response.
	// So we ignore the response here. It's not a bug, but may be not compatible with other SDKs.
	_, err = c.handleSSEResponse(ctx, resp.Body, true)
	if err != nil {
		return fmt.Errorf("failed to handle SSE response: %w", err)
	}

	return nil
}
615
// contextAwareOfClientClose derives a cancellable child of ctx that is also
// cancelled when the transport is closed (c.closed).
//
// A watcher goroutine exits as soon as either the transport closes or the
// child context finishes. Callers must invoke the returned cancel func so the
// watcher terminates.
func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
	derived, stop := context.WithCancel(ctx)
	go func() {
		defer stop()
		select {
		case <-c.closed:
		case <-derived.Done():
			// The original context was canceled
		}
	}()
	return derived, stop
}