1package transport
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "net/url"
13 "strings"
14 "sync"
15 "sync/atomic"
16 "time"
17
18 "github.com/mark3labs/mcp-go/mcp"
19)
20
21// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
22// It maintains a persistent HTTP connection to receive server-pushed events
23// while sending requests over regular HTTP POST calls. The client handles
24// automatic reconnection and message routing between requests and responses.
25type SSE struct {
26 baseURL *url.URL
27 endpoint *url.URL
28 httpClient *http.Client
29 responses map[string]chan *JSONRPCResponse
30 mu sync.RWMutex
31 onNotification func(mcp.JSONRPCNotification)
32 notifyMu sync.RWMutex
33 endpointChan chan struct{}
34 headers map[string]string
35 headerFunc HTTPHeaderFunc
36
37 started atomic.Bool
38 closed atomic.Bool
39 cancelSSEStream context.CancelFunc
40
41 // OAuth support
42 oauthHandler *OAuthHandler
43}
44
45type ClientOption func(*SSE)
46
47func WithHeaders(headers map[string]string) ClientOption {
48 return func(sc *SSE) {
49 sc.headers = headers
50 }
51}
52
53func WithHeaderFunc(headerFunc HTTPHeaderFunc) ClientOption {
54 return func(sc *SSE) {
55 sc.headerFunc = headerFunc
56 }
57}
58
59func WithHTTPClient(httpClient *http.Client) ClientOption {
60 return func(sc *SSE) {
61 sc.httpClient = httpClient
62 }
63}
64
65func WithOAuth(config OAuthConfig) ClientOption {
66 return func(sc *SSE) {
67 sc.oauthHandler = NewOAuthHandler(config)
68 }
69}
70
71// NewSSE creates a new SSE-based MCP client with the given base URL.
72// Returns an error if the URL is invalid.
73func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
74 parsedURL, err := url.Parse(baseURL)
75 if err != nil {
76 return nil, fmt.Errorf("invalid URL: %w", err)
77 }
78
79 smc := &SSE{
80 baseURL: parsedURL,
81 httpClient: &http.Client{},
82 responses: make(map[string]chan *JSONRPCResponse),
83 endpointChan: make(chan struct{}),
84 headers: make(map[string]string),
85 }
86
87 for _, opt := range options {
88 opt(smc)
89 }
90
91 // If OAuth is configured, set the base URL for metadata discovery
92 if smc.oauthHandler != nil {
93 // Extract base URL from server URL for metadata discovery
94 baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
95 smc.oauthHandler.SetBaseURL(baseURL)
96 }
97
98 return smc, nil
99}
100
101// Start initiates the SSE connection to the server and waits for the endpoint information.
102// Returns an error if the connection fails or times out waiting for the endpoint.
103func (c *SSE) Start(ctx context.Context) error {
104
105 if c.started.Load() {
106 return fmt.Errorf("has already started")
107 }
108
109 ctx, cancel := context.WithCancel(ctx)
110 c.cancelSSEStream = cancel
111
112 req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
113
114 if err != nil {
115 return fmt.Errorf("failed to create request: %w", err)
116 }
117
118 req.Header.Set("Accept", "text/event-stream")
119 req.Header.Set("Cache-Control", "no-cache")
120 req.Header.Set("Connection", "keep-alive")
121
122 // set custom http headers
123 for k, v := range c.headers {
124 req.Header.Set(k, v)
125 }
126 if c.headerFunc != nil {
127 for k, v := range c.headerFunc(ctx) {
128 req.Header.Set(k, v)
129 }
130 }
131
132 // Add OAuth authorization if configured
133 if c.oauthHandler != nil {
134 authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
135 if err != nil {
136 // If we get an authorization error, return a specific error that can be handled by the client
137 if err.Error() == "no valid token available, authorization required" {
138 return &OAuthAuthorizationRequiredError{
139 Handler: c.oauthHandler,
140 }
141 }
142 return fmt.Errorf("failed to get authorization header: %w", err)
143 }
144 req.Header.Set("Authorization", authHeader)
145 }
146
147 resp, err := c.httpClient.Do(req)
148 if err != nil {
149 return fmt.Errorf("failed to connect to SSE stream: %w", err)
150 }
151
152 if resp.StatusCode != http.StatusOK {
153 resp.Body.Close()
154 // Handle OAuth unauthorized error
155 if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
156 return &OAuthAuthorizationRequiredError{
157 Handler: c.oauthHandler,
158 }
159 }
160 return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
161 }
162
163 go c.readSSE(resp.Body)
164
165 // Wait for the endpoint to be received
166 timeout := time.NewTimer(30 * time.Second)
167 defer timeout.Stop()
168 select {
169 case <-c.endpointChan:
170 // Endpoint received, proceed
171 case <-ctx.Done():
172 return fmt.Errorf("context cancelled while waiting for endpoint")
173 case <-timeout.C: // Add a timeout
174 cancel()
175 return fmt.Errorf("timeout waiting for endpoint")
176 }
177
178 c.started.Store(true)
179 return nil
180}
181
182// readSSE continuously reads the SSE stream and processes events.
183// It runs until the connection is closed or an error occurs.
184func (c *SSE) readSSE(reader io.ReadCloser) {
185 defer reader.Close()
186
187 br := bufio.NewReader(reader)
188 var event, data string
189
190 for {
191 // when close or start's ctx cancel, the reader will be closed
192 // and the for loop will break.
193 line, err := br.ReadString('\n')
194 if err != nil {
195 if err == io.EOF {
196 // Process any pending event before exit
197 if data != "" {
198 // If no event type is specified, use empty string (default event type)
199 if event == "" {
200 event = "message"
201 }
202 c.handleSSEEvent(event, data)
203 }
204 break
205 }
206 if !c.closed.Load() {
207 fmt.Printf("SSE stream error: %v\n", err)
208 }
209 return
210 }
211
212 // Remove only newline markers
213 line = strings.TrimRight(line, "\r\n")
214 if line == "" {
215 // Empty line means end of event
216 if data != "" {
217 // If no event type is specified, use empty string (default event type)
218 if event == "" {
219 event = "message"
220 }
221 c.handleSSEEvent(event, data)
222 event = ""
223 data = ""
224 }
225 continue
226 }
227
228 if strings.HasPrefix(line, "event:") {
229 event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
230 } else if strings.HasPrefix(line, "data:") {
231 data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
232 }
233 }
234}
235
236// handleSSEEvent processes SSE events based on their type.
237// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
238func (c *SSE) handleSSEEvent(event, data string) {
239 switch event {
240 case "endpoint":
241 endpoint, err := c.baseURL.Parse(data)
242 if err != nil {
243 fmt.Printf("Error parsing endpoint URL: %v\n", err)
244 return
245 }
246 if endpoint.Host != c.baseURL.Host {
247 fmt.Printf("Endpoint origin does not match connection origin\n")
248 return
249 }
250 c.endpoint = endpoint
251 close(c.endpointChan)
252
253 case "message":
254 var baseMessage JSONRPCResponse
255 if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
256 fmt.Printf("Error unmarshaling message: %v\n", err)
257 return
258 }
259
260 // Handle notification
261 if baseMessage.ID.IsNil() {
262 var notification mcp.JSONRPCNotification
263 if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
264 return
265 }
266 c.notifyMu.RLock()
267 if c.onNotification != nil {
268 c.onNotification(notification)
269 }
270 c.notifyMu.RUnlock()
271 return
272 }
273
274 // Create string key for map lookup
275 idKey := baseMessage.ID.String()
276
277 c.mu.RLock()
278 ch, exists := c.responses[idKey]
279 c.mu.RUnlock()
280
281 if exists {
282 ch <- &baseMessage
283 c.mu.Lock()
284 delete(c.responses, idKey)
285 c.mu.Unlock()
286 }
287 }
288}
289
290func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
291 c.notifyMu.Lock()
292 defer c.notifyMu.Unlock()
293 c.onNotification = handler
294}
295
296// SendRequest sends a JSON-RPC request to the server and waits for a response.
297// Returns the raw JSON response message or an error if the request fails.
298func (c *SSE) SendRequest(
299 ctx context.Context,
300 request JSONRPCRequest,
301) (*JSONRPCResponse, error) {
302
303 if !c.started.Load() {
304 return nil, fmt.Errorf("transport not started yet")
305 }
306 if c.closed.Load() {
307 return nil, fmt.Errorf("transport has been closed")
308 }
309 if c.endpoint == nil {
310 return nil, fmt.Errorf("endpoint not received")
311 }
312
313 // Marshal request
314 requestBytes, err := json.Marshal(request)
315 if err != nil {
316 return nil, fmt.Errorf("failed to marshal request: %w", err)
317 }
318
319 // Create HTTP request
320 req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint.String(), bytes.NewReader(requestBytes))
321 if err != nil {
322 return nil, fmt.Errorf("failed to create request: %w", err)
323 }
324
325 // Set headers
326 req.Header.Set("Content-Type", "application/json")
327 for k, v := range c.headers {
328 req.Header.Set(k, v)
329 }
330
331 // Add OAuth authorization if configured
332 if c.oauthHandler != nil {
333 authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
334 if err != nil {
335 // If we get an authorization error, return a specific error that can be handled by the client
336 if err.Error() == "no valid token available, authorization required" {
337 return nil, &OAuthAuthorizationRequiredError{
338 Handler: c.oauthHandler,
339 }
340 }
341 return nil, fmt.Errorf("failed to get authorization header: %w", err)
342 }
343 req.Header.Set("Authorization", authHeader)
344 }
345
346 if c.headerFunc != nil {
347 for k, v := range c.headerFunc(ctx) {
348 req.Header.Set(k, v)
349 }
350 }
351
352 // Create string key for map lookup
353 idKey := request.ID.String()
354
355 // Register response channel
356 responseChan := make(chan *JSONRPCResponse, 1)
357 c.mu.Lock()
358 c.responses[idKey] = responseChan
359 c.mu.Unlock()
360 deleteResponseChan := func() {
361 c.mu.Lock()
362 delete(c.responses, idKey)
363 c.mu.Unlock()
364 }
365
366 // Send request
367 resp, err := c.httpClient.Do(req)
368 if err != nil {
369 deleteResponseChan()
370 return nil, fmt.Errorf("failed to send request: %w", err)
371 }
372
373 // Drain any outstanding io
374 body, err := io.ReadAll(resp.Body)
375 resp.Body.Close()
376
377 if err != nil {
378 return nil, fmt.Errorf("failed to read response body: %w", err)
379 }
380
381 // Check if we got an error response
382 if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
383 deleteResponseChan()
384
385 // Handle OAuth unauthorized error
386 if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
387 return nil, &OAuthAuthorizationRequiredError{
388 Handler: c.oauthHandler,
389 }
390 }
391
392 return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
393 }
394
395 select {
396 case <-ctx.Done():
397 deleteResponseChan()
398 return nil, ctx.Err()
399 case response, ok := <-responseChan:
400 if ok {
401 return response, nil
402 }
403 return nil, fmt.Errorf("connection has been closed")
404 }
405}
406
407// Close shuts down the SSE client connection and cleans up any pending responses.
408// Returns an error if the shutdown process fails.
409func (c *SSE) Close() error {
410 if !c.closed.CompareAndSwap(false, true) {
411 return nil // Already closed
412 }
413
414 if c.cancelSSEStream != nil {
415 // It could stop the sse stream body, to quit the readSSE loop immediately
416 // Also, it could quit start() immediately if not receiving the endpoint
417 c.cancelSSEStream()
418 }
419
420 // Clean up any pending responses
421 c.mu.Lock()
422 for _, ch := range c.responses {
423 close(ch)
424 }
425 c.responses = make(map[string]chan *JSONRPCResponse)
426 c.mu.Unlock()
427
428 return nil
429}
430
431// GetSessionId returns the session ID of the transport.
432// Since SSE does not maintain a session ID, it returns an empty string.
433func (c *SSE) GetSessionId() string {
434 return ""
435}
436
437// SendNotification sends a JSON-RPC notification to the server without expecting a response.
438func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
439 if c.endpoint == nil {
440 return fmt.Errorf("endpoint not received")
441 }
442
443 notificationBytes, err := json.Marshal(notification)
444 if err != nil {
445 return fmt.Errorf("failed to marshal notification: %w", err)
446 }
447
448 req, err := http.NewRequestWithContext(
449 ctx,
450 "POST",
451 c.endpoint.String(),
452 bytes.NewReader(notificationBytes),
453 )
454 if err != nil {
455 return fmt.Errorf("failed to create notification request: %w", err)
456 }
457
458 req.Header.Set("Content-Type", "application/json")
459 // Set custom HTTP headers
460 for k, v := range c.headers {
461 req.Header.Set(k, v)
462 }
463
464 // Add OAuth authorization if configured
465 if c.oauthHandler != nil {
466 authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
467 if err != nil {
468 // If we get an authorization error, return a specific error that can be handled by the client
469 if errors.Is(err, ErrOAuthAuthorizationRequired) {
470 return &OAuthAuthorizationRequiredError{
471 Handler: c.oauthHandler,
472 }
473 }
474 return fmt.Errorf("failed to get authorization header: %w", err)
475 }
476 req.Header.Set("Authorization", authHeader)
477 }
478
479 if c.headerFunc != nil {
480 for k, v := range c.headerFunc(ctx) {
481 req.Header.Set(k, v)
482 }
483 }
484
485 resp, err := c.httpClient.Do(req)
486 if err != nil {
487 return fmt.Errorf("failed to send notification: %w", err)
488 }
489 defer resp.Body.Close()
490
491 if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
492 // Handle OAuth unauthorized error
493 if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
494 return &OAuthAuthorizationRequiredError{
495 Handler: c.oauthHandler,
496 }
497 }
498
499 body, _ := io.ReadAll(resp.Body)
500 return fmt.Errorf(
501 "notification failed with status %d: %s",
502 resp.StatusCode,
503 body,
504 )
505 }
506
507 return nil
508}
509
510// GetEndpoint returns the current endpoint URL for the SSE connection.
511func (c *SSE) GetEndpoint() *url.URL {
512 return c.endpoint
513}
514
515// GetBaseURL returns the base URL set in the SSE constructor.
516func (c *SSE) GetBaseURL() *url.URL {
517 return c.baseURL
518}
519
520// GetOAuthHandler returns the OAuth handler if configured
521func (c *SSE) GetOAuthHandler() *OAuthHandler {
522 return c.oauthHandler
523}
524
525// IsOAuthEnabled returns true if OAuth is enabled
526func (c *SSE) IsOAuthEnabled() bool {
527 return c.oauthHandler != nil
528}