1package transport
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "mime"
12 "net/http"
13 "net/url"
14 "strings"
15 "sync"
16 "sync/atomic"
17 "time"
18
19 "github.com/mark3labs/mcp-go/mcp"
20)
21
22type StreamableHTTPCOption func(*StreamableHTTP)
23
24// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
25func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
26 return func(sc *StreamableHTTP) {
27 sc.httpClient = client
28 }
29}
30
31func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
32 return func(sc *StreamableHTTP) {
33 sc.headers = headers
34 }
35}
36
37func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
38 return func(sc *StreamableHTTP) {
39 sc.headerFunc = headerFunc
40 }
41}
42
43// WithHTTPTimeout sets the timeout for a HTTP request and stream.
44func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
45 return func(sc *StreamableHTTP) {
46 sc.httpClient.Timeout = timeout
47 }
48}
49
50// WithHTTPOAuth enables OAuth authentication for the client.
51func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
52 return func(sc *StreamableHTTP) {
53 sc.oauthHandler = NewOAuthHandler(config)
54 }
55}
56
57// StreamableHTTP implements Streamable HTTP transport.
58//
59// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
60// The HTTP response body can either be a single JSON-RPC response,
61// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request.
62//
63// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
64//
65// The current implementation does not support the following features:
66// - batching
67// - continuously listening for server notifications when no request is in flight
68// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
69// - resuming stream
70// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
71// - server -> client request
72type StreamableHTTP struct {
73 serverURL *url.URL
74 httpClient *http.Client
75 headers map[string]string
76 headerFunc HTTPHeaderFunc
77
78 sessionID atomic.Value // string
79
80 notificationHandler func(mcp.JSONRPCNotification)
81 notifyMu sync.RWMutex
82
83 closed chan struct{}
84
85 // OAuth support
86 oauthHandler *OAuthHandler
87}
88
89// NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL.
90// Returns an error if the URL is invalid.
91func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) {
92 parsedURL, err := url.Parse(serverURL)
93 if err != nil {
94 return nil, fmt.Errorf("invalid URL: %w", err)
95 }
96
97 smc := &StreamableHTTP{
98 serverURL: parsedURL,
99 httpClient: &http.Client{},
100 headers: make(map[string]string),
101 closed: make(chan struct{}),
102 }
103 smc.sessionID.Store("") // set initial value to simplify later usage
104
105 for _, opt := range options {
106 opt(smc)
107 }
108
109 // If OAuth is configured, set the base URL for metadata discovery
110 if smc.oauthHandler != nil {
111 // Extract base URL from server URL for metadata discovery
112 baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
113 smc.oauthHandler.SetBaseURL(baseURL)
114 }
115
116 return smc, nil
117}
118
119// Start initiates the HTTP connection to the server.
120func (c *StreamableHTTP) Start(ctx context.Context) error {
121 // For Streamable HTTP, we don't need to establish a persistent connection
122 return nil
123}
124
125// Close closes the all the HTTP connections to the server.
126func (c *StreamableHTTP) Close() error {
127 select {
128 case <-c.closed:
129 return nil
130 default:
131 }
132 // Cancel all in-flight requests
133 close(c.closed)
134
135 sessionId := c.sessionID.Load().(string)
136 if sessionId != "" {
137 c.sessionID.Store("")
138
139 // notify server session closed
140 go func() {
141 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
142 defer cancel()
143 req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
144 if err != nil {
145 fmt.Printf("failed to create close request\n: %v", err)
146 return
147 }
148 req.Header.Set(headerKeySessionID, sessionId)
149 res, err := c.httpClient.Do(req)
150 if err != nil {
151 fmt.Printf("failed to send close request\n: %v", err)
152 return
153 }
154 res.Body.Close()
155 }()
156 }
157
158 return nil
159}
160
// headerKeySessionID is the HTTP header carrying the MCP session ID. The server
// returns it in response to initialize, and the client echoes it on later
// requests.
const headerKeySessionID = "Mcp-Session-Id"
164
165// ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required
166var ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required")
167
168// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
169type OAuthAuthorizationRequiredError struct {
170 Handler *OAuthHandler
171}
172
173func (e *OAuthAuthorizationRequiredError) Error() string {
174 return ErrOAuthAuthorizationRequired.Error()
175}
176
177func (e *OAuthAuthorizationRequiredError) Unwrap() error {
178 return ErrOAuthAuthorizationRequired
179}
180
181// SendRequest sends a JSON-RPC request to the server and waits for a response.
182// Returns the raw JSON response message or an error if the request fails.
183func (c *StreamableHTTP) SendRequest(
184 ctx context.Context,
185 request JSONRPCRequest,
186) (*JSONRPCResponse, error) {
187
188 // Create a combined context that could be canceled when the client is closed
189 newCtx, cancel := context.WithCancel(ctx)
190 defer cancel()
191 go func() {
192 select {
193 case <-c.closed:
194 cancel()
195 case <-newCtx.Done():
196 // The original context was canceled, no need to do anything
197 }
198 }()
199 ctx = newCtx
200
201 // Marshal request
202 requestBody, err := json.Marshal(request)
203 if err != nil {
204 return nil, fmt.Errorf("failed to marshal request: %w", err)
205 }
206
207 // Create HTTP request
208 req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
209 if err != nil {
210 return nil, fmt.Errorf("failed to create request: %w", err)
211 }
212
213 // Set headers
214 req.Header.Set("Content-Type", "application/json")
215 req.Header.Set("Accept", "application/json, text/event-stream")
216 sessionID := c.sessionID.Load()
217 if sessionID != "" {
218 req.Header.Set(headerKeySessionID, sessionID.(string))
219 }
220 for k, v := range c.headers {
221 req.Header.Set(k, v)
222 }
223
224 // Add OAuth authorization if configured
225 if c.oauthHandler != nil {
226 authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
227 if err != nil {
228 // If we get an authorization error, return a specific error that can be handled by the client
229 if err.Error() == "no valid token available, authorization required" {
230 return nil, &OAuthAuthorizationRequiredError{
231 Handler: c.oauthHandler,
232 }
233 }
234 return nil, fmt.Errorf("failed to get authorization header: %w", err)
235 }
236 req.Header.Set("Authorization", authHeader)
237 }
238
239 if c.headerFunc != nil {
240 for k, v := range c.headerFunc(ctx) {
241 req.Header.Set(k, v)
242 }
243 }
244
245 // Send request
246 resp, err := c.httpClient.Do(req)
247 if err != nil {
248 return nil, fmt.Errorf("failed to send request: %w", err)
249 }
250 defer resp.Body.Close()
251
252 // Check if we got an error response
253 if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
254 // handle session closed
255 if resp.StatusCode == http.StatusNotFound {
256 c.sessionID.CompareAndSwap(sessionID, "")
257 return nil, fmt.Errorf("session terminated (404). need to re-initialize")
258 }
259
260 // Handle OAuth unauthorized error
261 if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
262 return nil, &OAuthAuthorizationRequiredError{
263 Handler: c.oauthHandler,
264 }
265 }
266
267 // handle error response
268 var errResponse JSONRPCResponse
269 body, _ := io.ReadAll(resp.Body)
270 if err := json.Unmarshal(body, &errResponse); err == nil {
271 return &errResponse, nil
272 }
273 return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
274 }
275
276 if request.Method == string(mcp.MethodInitialize) {
277 // saved the received session ID in the response
278 // empty session ID is allowed
279 if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
280 c.sessionID.Store(sessionID)
281 }
282 }
283
284 // Handle different response types
285 mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
286 switch mediaType {
287 case "application/json":
288 // Single response
289 var response JSONRPCResponse
290 if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
291 return nil, fmt.Errorf("failed to decode response: %w", err)
292 }
293
294 // should not be a notification
295 if response.ID.IsNil() {
296 return nil, fmt.Errorf("response should contain RPC id: %v", response)
297 }
298
299 return &response, nil
300
301 case "text/event-stream":
302 // Server is using SSE for streaming responses
303 return c.handleSSEResponse(ctx, resp.Body)
304
305 default:
306 return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
307 }
308}
309
310// handleSSEResponse processes an SSE stream for a specific request.
311// It returns the final result for the request once received, or an error.
312func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
313
314 // Create a channel for this specific request
315 responseChan := make(chan *JSONRPCResponse, 1)
316
317 ctx, cancel := context.WithCancel(ctx)
318 defer cancel()
319
320 // Start a goroutine to process the SSE stream
321 go func() {
322 // only close responseChan after readingSSE()
323 defer close(responseChan)
324
325 c.readSSE(ctx, reader, func(event, data string) {
326
327 // (unsupported: batching)
328
329 var message JSONRPCResponse
330 if err := json.Unmarshal([]byte(data), &message); err != nil {
331 fmt.Printf("failed to unmarshal message: %v\n", err)
332 return
333 }
334
335 // Handle notification
336 if message.ID.IsNil() {
337 var notification mcp.JSONRPCNotification
338 if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
339 fmt.Printf("failed to unmarshal notification: %v\n", err)
340 return
341 }
342 c.notifyMu.RLock()
343 if c.notificationHandler != nil {
344 c.notificationHandler(notification)
345 }
346 c.notifyMu.RUnlock()
347 return
348 }
349
350 responseChan <- &message
351 })
352 }()
353
354 // Wait for the response or context cancellation
355 select {
356 case response := <-responseChan:
357 if response == nil {
358 return nil, fmt.Errorf("unexpected nil response")
359 }
360 return response, nil
361 case <-ctx.Done():
362 return nil, ctx.Err()
363 }
364}
365
366// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
367// It will end when the reader is closed (or the context is done).
368func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
369 defer reader.Close()
370
371 br := bufio.NewReader(reader)
372 var event, data string
373
374 for {
375 select {
376 case <-ctx.Done():
377 return
378 default:
379 line, err := br.ReadString('\n')
380 if err != nil {
381 if err == io.EOF {
382 // Process any pending event before exit
383 if data != "" {
384 // If no event type is specified, use empty string (default event type)
385 if event == "" {
386 event = "message"
387 }
388 handler(event, data)
389 }
390 return
391 }
392 select {
393 case <-ctx.Done():
394 return
395 default:
396 fmt.Printf("SSE stream error: %v\n", err)
397 return
398 }
399 }
400
401 // Remove only newline markers
402 line = strings.TrimRight(line, "\r\n")
403 if line == "" {
404 // Empty line means end of event
405 if data != "" {
406 // If no event type is specified, use empty string (default event type)
407 if event == "" {
408 event = "message"
409 }
410 handler(event, data)
411 event = ""
412 data = ""
413 }
414 continue
415 }
416
417 if strings.HasPrefix(line, "event:") {
418 event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
419 } else if strings.HasPrefix(line, "data:") {
420 data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
421 }
422 }
423 }
424}
425
426func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
427
428 // Marshal request
429 requestBody, err := json.Marshal(notification)
430 if err != nil {
431 return fmt.Errorf("failed to marshal notification: %w", err)
432 }
433
434 // Create HTTP request
435 req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
436 if err != nil {
437 return fmt.Errorf("failed to create request: %w", err)
438 }
439
440 // Set headers
441 req.Header.Set("Content-Type", "application/json")
442 req.Header.Set("Accept", "application/json, text/event-stream")
443 if sessionID := c.sessionID.Load(); sessionID != "" {
444 req.Header.Set(headerKeySessionID, sessionID.(string))
445 }
446 for k, v := range c.headers {
447 req.Header.Set(k, v)
448 }
449
450 // Add OAuth authorization if configured
451 if c.oauthHandler != nil {
452 authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
453 if err != nil {
454 // If we get an authorization error, return a specific error that can be handled by the client
455 if errors.Is(err, ErrOAuthAuthorizationRequired) {
456 return &OAuthAuthorizationRequiredError{
457 Handler: c.oauthHandler,
458 }
459 }
460 return fmt.Errorf("failed to get authorization header: %w", err)
461 }
462 req.Header.Set("Authorization", authHeader)
463 }
464
465 if c.headerFunc != nil {
466 for k, v := range c.headerFunc(ctx) {
467 req.Header.Set(k, v)
468 }
469 }
470
471 // Send request
472 resp, err := c.httpClient.Do(req)
473 if err != nil {
474 return fmt.Errorf("failed to send request: %w", err)
475 }
476 defer resp.Body.Close()
477
478 if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
479 // Handle OAuth unauthorized error
480 if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
481 return &OAuthAuthorizationRequiredError{
482 Handler: c.oauthHandler,
483 }
484 }
485
486 body, _ := io.ReadAll(resp.Body)
487 return fmt.Errorf(
488 "notification failed with status %d: %s",
489 resp.StatusCode,
490 body,
491 )
492 }
493
494 return nil
495}
496
497func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
498 c.notifyMu.Lock()
499 defer c.notifyMu.Unlock()
500 c.notificationHandler = handler
501}
502
503func (c *StreamableHTTP) GetSessionId() string {
504 return c.sessionID.Load().(string)
505}
506
507// GetOAuthHandler returns the OAuth handler if configured
508func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
509 return c.oauthHandler
510}
511
512// IsOAuthEnabled returns true if OAuth is enabled
513func (c *StreamableHTTP) IsOAuthEnabled() bool {
514 return c.oauthHandler != nil
515}