streamable_http.go

  1package transport
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"mime"
 12	"net/http"
 13	"net/url"
 14	"strings"
 15	"sync"
 16	"sync/atomic"
 17	"time"
 18
 19	"github.com/mark3labs/mcp-go/mcp"
 20)
 21
 22type StreamableHTTPCOption func(*StreamableHTTP)
 23
 // WithHTTPBasicClient sets a custom HTTP client on the StreamableHTTP transport.
//
// A nil client is ignored, so the transport keeps its current client instead of
// panicking on the first request.
//
// Options run in order: pass WithHTTPBasicClient before WithHTTPTimeout, since a
// timeout applied earlier lives on the client that this option replaces.
func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
	return func(sc *StreamableHTTP) {
		if client == nil {
			return
		}
		sc.httpClient = client
	}
}
 30
 31func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
 32	return func(sc *StreamableHTTP) {
 33		sc.headers = headers
 34	}
 35}
 36
 37func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
 38	return func(sc *StreamableHTTP) {
 39		sc.headerFunc = headerFunc
 40	}
 41}
 42
 43// WithHTTPTimeout sets the timeout for a HTTP request and stream.
 44func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
 45	return func(sc *StreamableHTTP) {
 46		sc.httpClient.Timeout = timeout
 47	}
 48}
 49
 50// WithHTTPOAuth enables OAuth authentication for the client.
 51func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
 52	return func(sc *StreamableHTTP) {
 53		sc.oauthHandler = NewOAuthHandler(config)
 54	}
 55}
 56
 57// StreamableHTTP implements Streamable HTTP transport.
 58//
 59// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
 60// The HTTP response body can either be a single JSON-RPC response,
 61// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request.
 62//
 63// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
 64//
 65// The current implementation does not support the following features:
 66//   - batching
 67//   - continuously listening for server notifications when no request is in flight
 68//     (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
 69//   - resuming stream
 70//     (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
 71//   - server -> client request
 72type StreamableHTTP struct {
 73	serverURL  *url.URL
 74	httpClient *http.Client
 75	headers    map[string]string
 76	headerFunc HTTPHeaderFunc
 77
 78	sessionID atomic.Value // string
 79
 80	notificationHandler func(mcp.JSONRPCNotification)
 81	notifyMu            sync.RWMutex
 82
 83	closed chan struct{}
 84
 85	// OAuth support
 86	oauthHandler *OAuthHandler
 87}
 88
 // NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL.
// Returns an error if the URL cannot be parsed by url.Parse.
//
// Options are applied in order; nil options are skipped. If the options leave
// the transport without an HTTP client (e.g. WithHTTPBasicClient(nil) in older
// code paths), a default client is installed so requests never dereference a
// nil *http.Client.
func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) {
	parsedURL, err := url.Parse(serverURL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}

	smc := &StreamableHTTP{
		serverURL:  parsedURL,
		httpClient: &http.Client{},
		headers:    make(map[string]string),
		closed:     make(chan struct{}),
	}
	// Store an initial empty session ID so later Load().(string) calls never see nil.
	smc.sessionID.Store("")

	for _, opt := range options {
		if opt == nil {
			continue
		}
		opt(smc)
	}
	if smc.httpClient == nil {
		smc.httpClient = &http.Client{}
	}

	// If OAuth is configured, point the handler at the server's origin
	// (scheme://host) for metadata discovery.
	if smc.oauthHandler != nil {
		baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
		smc.oauthHandler.SetBaseURL(baseURL)
	}

	return smc, nil
}
118
// Start is a no-op for Streamable HTTP: there is no persistent connection to
// establish, because every message travels in its own HTTP request. It always
// returns nil.
func (c *StreamableHTTP) Start(ctx context.Context) error {
	_ = ctx // nothing to set up, so the context is not needed
	return nil
}
124
// Close shuts the transport down. It cancels in-flight SendRequest calls and,
// if a session is active, clears it and asks the server to terminate it.
//
// Session termination is best effort: an HTTP DELETE carrying the session ID
// (plus the same custom/OAuth headers as normal requests) is sent from a
// background goroutine with a 5s timeout, and Close does not wait for it.
// Failures are only printed. Calling Close again returns nil.
//
// NOTE(review): the check-then-close on c.closed is not atomic, so two
// concurrent Close calls can still race to close the channel; a sync.Once
// field on the struct would make this fully safe.
func (c *StreamableHTTP) Close() error {
	select {
	case <-c.closed:
		return nil
	default:
	}
	// Cancel all in-flight requests.
	close(c.closed)

	sessionID := c.sessionID.Load().(string)
	if sessionID == "" {
		return nil
	}
	c.sessionID.Store("")

	// Notify the server that the session is closed.
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
		if err != nil {
			fmt.Printf("failed to create close request: %v\n", err)
			return
		}
		req.Header.Set(headerKeySessionID, sessionID)

		// Send the same custom and auth headers as regular requests; a server
		// that requires them would otherwise reject the termination request.
		for k, v := range c.headers {
			req.Header.Set(k, v)
		}
		if c.oauthHandler != nil {
			// Best effort: without a token the DELETE is still attempted.
			if authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx); err == nil {
				req.Header.Set("Authorization", authHeader)
			}
		}
		if c.headerFunc != nil {
			for k, v := range c.headerFunc(ctx) {
				req.Header.Set(k, v)
			}
		}

		res, err := c.httpClient.Do(req)
		if err != nil {
			fmt.Printf("failed to send close request: %v\n", err)
			return
		}
		res.Body.Close()
	}()

	return nil
}
160
// headerKeySessionID is the HTTP header that carries the MCP session ID
// between client and server.
const headerKeySessionID = "Mcp-Session-Id"
164
var (
	// ErrOAuthAuthorizationRequired is the sentinel reported when no usable
	// OAuth token is available and authorization must be performed first.
	// Match it with errors.Is.
	ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required")
)
167
168// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
169type OAuthAuthorizationRequiredError struct {
170	Handler *OAuthHandler
171}
172
173func (e *OAuthAuthorizationRequiredError) Error() string {
174	return ErrOAuthAuthorizationRequired.Error()
175}
176
177func (e *OAuthAuthorizationRequiredError) Unwrap() error {
178	return ErrOAuthAuthorizationRequired
179}
180
// SendRequest sends a JSON-RPC request to the server and waits for a response.
//
// The request is POSTed to the server URL. The reply is either a single
// application/json JSON-RPC response or a text/event-stream whose
// notifications are forwarded to the notification handler until the response
// arrives. The call is cancelled when ctx is done or the transport is closed.
// For an initialize request, a non-empty session ID returned by the server is
// stored and sent on subsequent requests.
//
// Errors:
//   - *OAuthAuthorizationRequiredError when OAuth is enabled and either no
//     token is available or the server answers 401 (matches
//     errors.Is(err, ErrOAuthAuthorizationRequired));
//   - a "session terminated" error on 404, after clearing the stored session;
//   - for any other status besides 200/202, a body that decodes as a JSON-RPC
//     response is returned as the response, otherwise an error with the status
//     and body.
func (c *StreamableHTTP) SendRequest(
	ctx context.Context,
	request JSONRPCRequest,
) (*JSONRPCResponse, error) {
	// Derive a context that is also cancelled when the transport is closed.
	newCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		select {
		case <-c.closed:
			cancel()
		case <-newCtx.Done():
			// Request finished or the caller cancelled; nothing to do.
		}
	}()
	ctx = newCtx

	requestBody, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Headers in increasing precedence (later Set calls overwrite earlier ones):
	// protocol headers, static headers, OAuth, then headerFunc.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json, text/event-stream")
	sessionID := c.sessionID.Load()
	if sessionID != "" {
		req.Header.Set(headerKeySessionID, sessionID.(string))
	}
	for k, v := range c.headers {
		req.Header.Set(k, v)
	}

	if c.oauthHandler != nil {
		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
		if err != nil {
			// Match the sentinel with errors.Is (as SendNotification does) rather
			// than comparing message text, so wrapped sentinels are recognised.
			if errors.Is(err, ErrOAuthAuthorizationRequired) {
				return nil, &OAuthAuthorizationRequiredError{
					Handler: c.oauthHandler,
				}
			}
			return nil, fmt.Errorf("failed to get authorization header: %w", err)
		}
		req.Header.Set("Authorization", authHeader)
	}

	if c.headerFunc != nil {
		for k, v := range c.headerFunc(ctx) {
			req.Header.Set(k, v)
		}
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
		// The server no longer knows this session; forget it (unless another
		// goroutine already replaced it) so the caller can re-initialize.
		if resp.StatusCode == http.StatusNotFound {
			c.sessionID.CompareAndSwap(sessionID, "")
			return nil, fmt.Errorf("session terminated (404). need to re-initialize")
		}

		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
			return nil, &OAuthAuthorizationRequiredError{
				Handler: c.oauthHandler,
			}
		}

		// Prefer a JSON-RPC error body; fall back to the raw status and body.
		// The read error is ignored: a partial body is still useful in the message.
		var errResponse JSONRPCResponse
		body, _ := io.ReadAll(resp.Body)
		if err := json.Unmarshal(body, &errResponse); err == nil {
			return &errResponse, nil
		}
		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
	}

	if request.Method == string(mcp.MethodInitialize) {
		// Save the session ID issued by the server; an empty one is allowed.
		if newSessionID := resp.Header.Get(headerKeySessionID); newSessionID != "" {
			c.sessionID.Store(newSessionID)
		}
	}

	mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	switch mediaType {
	case "application/json":
		var response JSONRPCResponse
		if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
			return nil, fmt.Errorf("failed to decode response: %w", err)
		}

		// A reply without an ID would be a notification, not our response.
		if response.ID.IsNil() {
			return nil, fmt.Errorf("response should contain RPC id: %v", response)
		}

		return &response, nil

	case "text/event-stream":
		// The server upgraded the reply to an SSE stream.
		return c.handleSSEResponse(ctx, resp.Body)

	default:
		return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
	}
}
309
// handleSSEResponse processes an SSE stream for a specific request.
//
// It returns the first JSON-RPC message carrying an ID (the response). It
// returns an error if the stream ends without one or if ctx is cancelled.
// Messages without an ID are decoded as notifications and passed to the
// registered notification handler. readSSE closes the reader.
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
	// Buffered so the reader goroutine can hand off the response without waiting.
	responseChan := make(chan *JSONRPCResponse, 1)

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	go func() {
		// Closing after readSSE returns turns "stream ended without a
		// response" into a nil receive below.
		defer close(responseChan)

		c.readSSE(ctx, reader, func(event, data string) {
			// (unsupported: batching)

			var message JSONRPCResponse
			if err := json.Unmarshal([]byte(data), &message); err != nil {
				fmt.Printf("failed to unmarshal message: %v\n", err)
				return
			}

			// No ID: this is a notification, not the response we wait for.
			if message.ID.IsNil() {
				var notification mcp.JSONRPCNotification
				if err := json.Unmarshal([]byte(data), &notification); err != nil {
					fmt.Printf("failed to unmarshal notification: %v\n", err)
					return
				}
				// Snapshot the handler and call it without holding the lock, so a
				// handler that calls SetNotificationHandler cannot deadlock.
				c.notifyMu.RLock()
				handler := c.notificationHandler
				c.notifyMu.RUnlock()
				if handler != nil {
					handler(notification)
				}
				return
			}

			// Never block forever: once the caller has returned (ctx done),
			// drop any extra responses instead of leaking this goroutine.
			select {
			case responseChan <- &message:
			case <-ctx.Done():
			}
		})
	}()

	select {
	case response := <-responseChan:
		if response == nil {
			return nil, fmt.Errorf("unexpected nil response")
		}
		return response, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
365
// readSSE reads the SSE stream (reader) and calls handler with the type and
// data of each dispatched event. It returns when the stream ends, a read error
// occurs, or ctx is done (checked between reads), and always closes reader.
//
// Parsing covers the fields this transport uses:
//   - several "data:" lines in one event are joined with "\n";
//   - one space directly after the field colon is removed from data values;
//   - a blank line dispatches the event, and events with empty data are skipped;
//   - the event type defaults to "message" and is reset after each blank line;
//   - other fields ("id:", "retry:") and ":" comment lines are ignored.
//
// Leniently, an event still pending at EOF is dispatched rather than discarded.
// This includes a final line with no trailing newline.
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
	defer reader.Close()

	br := bufio.NewReader(reader)
	var event string
	var dataLines []string

	// dispatch delivers the buffered event (if it has data) and resets the buffers.
	dispatch := func() {
		data := strings.Join(dataLines, "\n")
		if data != "" {
			if event == "" {
				event = "message"
			}
			handler(event, data)
		}
		event = ""
		dataLines = dataLines[:0]
	}

	// processLine applies one line of the stream to the event buffers.
	processLine := func(line string) {
		line = strings.TrimRight(line, "\r\n")
		switch {
		case line == "":
			dispatch()
		case strings.HasPrefix(line, "event:"):
			event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
		case strings.HasPrefix(line, "data:"):
			value := strings.TrimPrefix(line, "data:")
			dataLines = append(dataLines, strings.TrimPrefix(value, " "))
		}
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		line, err := br.ReadString('\n')
		if err == nil {
			processLine(line)
			continue
		}

		if errors.Is(err, io.EOF) {
			// ReadString returns the unterminated tail together with io.EOF.
			if line != "" {
				processLine(line)
			}
			dispatch()
			return
		}

		select {
		case <-ctx.Done():
			// The read failed because the request was cancelled; exit quietly.
		default:
			fmt.Printf("SSE stream error: %v\n", err)
		}
		return
	}
}
425
// SendNotification POSTs a JSON-RPC notification to the server. A 200 or 202
// status counts as success, and the response body is not inspected then.
//
// Like SendRequest, the request is cancelled when ctx is done or the transport
// is closed, and a 404 clears the stored session ID.
//
// Errors: *OAuthAuthorizationRequiredError when OAuth is enabled and either no
// token is available or the server answers 401; a "session terminated" error
// on 404; otherwise an error carrying the status code and response body.
func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
	// Tie the request to the transport's lifetime, mirroring SendRequest.
	newCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		select {
		case <-c.closed:
			cancel()
		case <-newCtx.Done():
			// Notification finished or the caller cancelled; nothing to do.
		}
	}()
	ctx = newCtx

	requestBody, err := json.Marshal(notification)
	if err != nil {
		return fmt.Errorf("failed to marshal notification: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	// Headers in increasing precedence: protocol, static, OAuth, headerFunc.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json, text/event-stream")
	sessionID := c.sessionID.Load()
	if sessionID != "" {
		req.Header.Set(headerKeySessionID, sessionID.(string))
	}
	for k, v := range c.headers {
		req.Header.Set(k, v)
	}

	if c.oauthHandler != nil {
		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
		if err != nil {
			// Surface a typed error the client can react to by authorizing.
			if errors.Is(err, ErrOAuthAuthorizationRequired) {
				return &OAuthAuthorizationRequiredError{
					Handler: c.oauthHandler,
				}
			}
			return fmt.Errorf("failed to get authorization header: %w", err)
		}
		req.Header.Set("Authorization", authHeader)
	}

	if c.headerFunc != nil {
		for k, v := range c.headerFunc(ctx) {
			req.Header.Set(k, v)
		}
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
		// Same session handling as SendRequest: the server dropped the session.
		if resp.StatusCode == http.StatusNotFound {
			c.sessionID.CompareAndSwap(sessionID, "")
			return fmt.Errorf("session terminated (404). need to re-initialize")
		}

		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
			return &OAuthAuthorizationRequiredError{
				Handler: c.oauthHandler,
			}
		}

		// The read error is ignored: a partial body is still useful in the message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf(
			"notification failed with status %d: %s",
			resp.StatusCode,
			body,
		)
	}

	return nil
}
496
497func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
498	c.notifyMu.Lock()
499	defer c.notifyMu.Unlock()
500	c.notificationHandler = handler
501}
502
503func (c *StreamableHTTP) GetSessionId() string {
504	return c.sessionID.Load().(string)
505}
506
507// GetOAuthHandler returns the OAuth handler if configured
508func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
509	return c.oauthHandler
510}
511
512// IsOAuthEnabled returns true if OAuth is enabled
513func (c *StreamableHTTP) IsOAuthEnabled() bool {
514	return c.oauthHandler != nil
515}