sse.go

  1package transport
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"net/http"
 12	"net/url"
 13	"strings"
 14	"sync"
 15	"sync/atomic"
 16	"time"
 17
 18	"github.com/mark3labs/mcp-go/mcp"
 19)
 20
 21// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
 22// It maintains a persistent HTTP connection to receive server-pushed events
 23// while sending requests over regular HTTP POST calls. The client handles
 24// automatic reconnection and message routing between requests and responses.
 25type SSE struct {
 26	baseURL        *url.URL
 27	endpoint       *url.URL
 28	httpClient     *http.Client
 29	responses      map[string]chan *JSONRPCResponse
 30	mu             sync.RWMutex
 31	onNotification func(mcp.JSONRPCNotification)
 32	notifyMu       sync.RWMutex
 33	endpointChan   chan struct{}
 34	headers        map[string]string
 35	headerFunc     HTTPHeaderFunc
 36
 37	started         atomic.Bool
 38	closed          atomic.Bool
 39	cancelSSEStream context.CancelFunc
 40
 41	// OAuth support
 42	oauthHandler *OAuthHandler
 43}
 44
 45type ClientOption func(*SSE)
 46
 47func WithHeaders(headers map[string]string) ClientOption {
 48	return func(sc *SSE) {
 49		sc.headers = headers
 50	}
 51}
 52
 53func WithHeaderFunc(headerFunc HTTPHeaderFunc) ClientOption {
 54	return func(sc *SSE) {
 55		sc.headerFunc = headerFunc
 56	}
 57}
 58
 59func WithHTTPClient(httpClient *http.Client) ClientOption {
 60	return func(sc *SSE) {
 61		sc.httpClient = httpClient
 62	}
 63}
 64
 65func WithOAuth(config OAuthConfig) ClientOption {
 66	return func(sc *SSE) {
 67		sc.oauthHandler = NewOAuthHandler(config)
 68	}
 69}
 70
 71// NewSSE creates a new SSE-based MCP client with the given base URL.
 72// Returns an error if the URL is invalid.
 73func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
 74	parsedURL, err := url.Parse(baseURL)
 75	if err != nil {
 76		return nil, fmt.Errorf("invalid URL: %w", err)
 77	}
 78
 79	smc := &SSE{
 80		baseURL:      parsedURL,
 81		httpClient:   &http.Client{},
 82		responses:    make(map[string]chan *JSONRPCResponse),
 83		endpointChan: make(chan struct{}),
 84		headers:      make(map[string]string),
 85	}
 86
 87	for _, opt := range options {
 88		opt(smc)
 89	}
 90
 91	// If OAuth is configured, set the base URL for metadata discovery
 92	if smc.oauthHandler != nil {
 93		// Extract base URL from server URL for metadata discovery
 94		baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
 95		smc.oauthHandler.SetBaseURL(baseURL)
 96	}
 97
 98	return smc, nil
 99}
100
101// Start initiates the SSE connection to the server and waits for the endpoint information.
102// Returns an error if the connection fails or times out waiting for the endpoint.
103func (c *SSE) Start(ctx context.Context) error {
104
105	if c.started.Load() {
106		return fmt.Errorf("has already started")
107	}
108
109	ctx, cancel := context.WithCancel(ctx)
110	c.cancelSSEStream = cancel
111
112	req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
113
114	if err != nil {
115		return fmt.Errorf("failed to create request: %w", err)
116	}
117
118	req.Header.Set("Accept", "text/event-stream")
119	req.Header.Set("Cache-Control", "no-cache")
120	req.Header.Set("Connection", "keep-alive")
121
122	// set custom http headers
123	for k, v := range c.headers {
124		req.Header.Set(k, v)
125	}
126	if c.headerFunc != nil {
127		for k, v := range c.headerFunc(ctx) {
128			req.Header.Set(k, v)
129		}
130	}
131
132	// Add OAuth authorization if configured
133	if c.oauthHandler != nil {
134		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
135		if err != nil {
136			// If we get an authorization error, return a specific error that can be handled by the client
137			if err.Error() == "no valid token available, authorization required" {
138				return &OAuthAuthorizationRequiredError{
139					Handler: c.oauthHandler,
140				}
141			}
142			return fmt.Errorf("failed to get authorization header: %w", err)
143		}
144		req.Header.Set("Authorization", authHeader)
145	}
146
147	resp, err := c.httpClient.Do(req)
148	if err != nil {
149		return fmt.Errorf("failed to connect to SSE stream: %w", err)
150	}
151
152	if resp.StatusCode != http.StatusOK {
153		resp.Body.Close()
154		// Handle OAuth unauthorized error
155		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
156			return &OAuthAuthorizationRequiredError{
157				Handler: c.oauthHandler,
158			}
159		}
160		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
161	}
162
163	go c.readSSE(resp.Body)
164
165	// Wait for the endpoint to be received
166	timeout := time.NewTimer(30 * time.Second)
167	defer timeout.Stop()
168	select {
169	case <-c.endpointChan:
170		// Endpoint received, proceed
171	case <-ctx.Done():
172		return fmt.Errorf("context cancelled while waiting for endpoint")
173	case <-timeout.C: // Add a timeout
174		cancel()
175		return fmt.Errorf("timeout waiting for endpoint")
176	}
177
178	c.started.Store(true)
179	return nil
180}
181
182// readSSE continuously reads the SSE stream and processes events.
183// It runs until the connection is closed or an error occurs.
184func (c *SSE) readSSE(reader io.ReadCloser) {
185	defer reader.Close()
186
187	br := bufio.NewReader(reader)
188	var event, data string
189
190	for {
191		// when close or start's ctx cancel, the reader will be closed
192		// and the for loop will break.
193		line, err := br.ReadString('\n')
194		if err != nil {
195			if err == io.EOF {
196				// Process any pending event before exit
197				if data != "" {
198					// If no event type is specified, use empty string (default event type)
199					if event == "" {
200						event = "message"
201					}
202					c.handleSSEEvent(event, data)
203				}
204				break
205			}
206			if !c.closed.Load() {
207				fmt.Printf("SSE stream error: %v\n", err)
208			}
209			return
210		}
211
212		// Remove only newline markers
213		line = strings.TrimRight(line, "\r\n")
214		if line == "" {
215			// Empty line means end of event
216			if data != "" {
217				// If no event type is specified, use empty string (default event type)
218				if event == "" {
219					event = "message"
220				}
221				c.handleSSEEvent(event, data)
222				event = ""
223				data = ""
224			}
225			continue
226		}
227
228		if strings.HasPrefix(line, "event:") {
229			event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
230		} else if strings.HasPrefix(line, "data:") {
231			data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
232		}
233	}
234}
235
236// handleSSEEvent processes SSE events based on their type.
237// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
238func (c *SSE) handleSSEEvent(event, data string) {
239	switch event {
240	case "endpoint":
241		endpoint, err := c.baseURL.Parse(data)
242		if err != nil {
243			fmt.Printf("Error parsing endpoint URL: %v\n", err)
244			return
245		}
246		if endpoint.Host != c.baseURL.Host {
247			fmt.Printf("Endpoint origin does not match connection origin\n")
248			return
249		}
250		c.endpoint = endpoint
251		close(c.endpointChan)
252
253	case "message":
254		var baseMessage JSONRPCResponse
255		if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
256			fmt.Printf("Error unmarshaling message: %v\n", err)
257			return
258		}
259
260		// Handle notification
261		if baseMessage.ID.IsNil() {
262			var notification mcp.JSONRPCNotification
263			if err := json.Unmarshal([]byte(data), &notification); err != nil {
264				return
265			}
266			c.notifyMu.RLock()
267			if c.onNotification != nil {
268				c.onNotification(notification)
269			}
270			c.notifyMu.RUnlock()
271			return
272		}
273
274		// Create string key for map lookup
275		idKey := baseMessage.ID.String()
276
277		c.mu.RLock()
278		ch, exists := c.responses[idKey]
279		c.mu.RUnlock()
280
281		if exists {
282			ch <- &baseMessage
283			c.mu.Lock()
284			delete(c.responses, idKey)
285			c.mu.Unlock()
286		}
287	}
288}
289
290func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
291	c.notifyMu.Lock()
292	defer c.notifyMu.Unlock()
293	c.onNotification = handler
294}
295
296// SendRequest sends a JSON-RPC request to the server and waits for a response.
297// Returns the raw JSON response message or an error if the request fails.
298func (c *SSE) SendRequest(
299	ctx context.Context,
300	request JSONRPCRequest,
301) (*JSONRPCResponse, error) {
302
303	if !c.started.Load() {
304		return nil, fmt.Errorf("transport not started yet")
305	}
306	if c.closed.Load() {
307		return nil, fmt.Errorf("transport has been closed")
308	}
309	if c.endpoint == nil {
310		return nil, fmt.Errorf("endpoint not received")
311	}
312
313	// Marshal request
314	requestBytes, err := json.Marshal(request)
315	if err != nil {
316		return nil, fmt.Errorf("failed to marshal request: %w", err)
317	}
318
319	// Create HTTP request
320	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint.String(), bytes.NewReader(requestBytes))
321	if err != nil {
322		return nil, fmt.Errorf("failed to create request: %w", err)
323	}
324
325	// Set headers
326	req.Header.Set("Content-Type", "application/json")
327	for k, v := range c.headers {
328		req.Header.Set(k, v)
329	}
330
331	// Add OAuth authorization if configured
332	if c.oauthHandler != nil {
333		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
334		if err != nil {
335			// If we get an authorization error, return a specific error that can be handled by the client
336			if err.Error() == "no valid token available, authorization required" {
337				return nil, &OAuthAuthorizationRequiredError{
338					Handler: c.oauthHandler,
339				}
340			}
341			return nil, fmt.Errorf("failed to get authorization header: %w", err)
342		}
343		req.Header.Set("Authorization", authHeader)
344	}
345
346	if c.headerFunc != nil {
347		for k, v := range c.headerFunc(ctx) {
348			req.Header.Set(k, v)
349		}
350	}
351
352	// Create string key for map lookup
353	idKey := request.ID.String()
354
355	// Register response channel
356	responseChan := make(chan *JSONRPCResponse, 1)
357	c.mu.Lock()
358	c.responses[idKey] = responseChan
359	c.mu.Unlock()
360	deleteResponseChan := func() {
361		c.mu.Lock()
362		delete(c.responses, idKey)
363		c.mu.Unlock()
364	}
365
366	// Send request
367	resp, err := c.httpClient.Do(req)
368	if err != nil {
369		deleteResponseChan()
370		return nil, fmt.Errorf("failed to send request: %w", err)
371	}
372
373	// Drain any outstanding io
374	body, err := io.ReadAll(resp.Body)
375	resp.Body.Close()
376
377	if err != nil {
378		return nil, fmt.Errorf("failed to read response body: %w", err)
379	}
380
381	// Check if we got an error response
382	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
383		deleteResponseChan()
384
385		// Handle OAuth unauthorized error
386		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
387			return nil, &OAuthAuthorizationRequiredError{
388				Handler: c.oauthHandler,
389			}
390		}
391
392		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
393	}
394
395	select {
396	case <-ctx.Done():
397		deleteResponseChan()
398		return nil, ctx.Err()
399	case response, ok := <-responseChan:
400		if ok {
401			return response, nil
402		}
403		return nil, fmt.Errorf("connection has been closed")
404	}
405}
406
407// Close shuts down the SSE client connection and cleans up any pending responses.
408// Returns an error if the shutdown process fails.
409func (c *SSE) Close() error {
410	if !c.closed.CompareAndSwap(false, true) {
411		return nil // Already closed
412	}
413
414	if c.cancelSSEStream != nil {
415		// It could stop the sse stream body, to quit the readSSE loop immediately
416		// Also, it could quit start() immediately if not receiving the endpoint
417		c.cancelSSEStream()
418	}
419
420	// Clean up any pending responses
421	c.mu.Lock()
422	for _, ch := range c.responses {
423		close(ch)
424	}
425	c.responses = make(map[string]chan *JSONRPCResponse)
426	c.mu.Unlock()
427
428	return nil
429}
430
431// GetSessionId returns the session ID of the transport.
432// Since SSE does not maintain a session ID, it returns an empty string.
433func (c *SSE) GetSessionId() string {
434	return ""
435}
436
437// SendNotification sends a JSON-RPC notification to the server without expecting a response.
438func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
439	if c.endpoint == nil {
440		return fmt.Errorf("endpoint not received")
441	}
442
443	notificationBytes, err := json.Marshal(notification)
444	if err != nil {
445		return fmt.Errorf("failed to marshal notification: %w", err)
446	}
447
448	req, err := http.NewRequestWithContext(
449		ctx,
450		"POST",
451		c.endpoint.String(),
452		bytes.NewReader(notificationBytes),
453	)
454	if err != nil {
455		return fmt.Errorf("failed to create notification request: %w", err)
456	}
457
458	req.Header.Set("Content-Type", "application/json")
459	// Set custom HTTP headers
460	for k, v := range c.headers {
461		req.Header.Set(k, v)
462	}
463
464	// Add OAuth authorization if configured
465	if c.oauthHandler != nil {
466		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
467		if err != nil {
468			// If we get an authorization error, return a specific error that can be handled by the client
469			if errors.Is(err, ErrOAuthAuthorizationRequired) {
470				return &OAuthAuthorizationRequiredError{
471					Handler: c.oauthHandler,
472				}
473			}
474			return fmt.Errorf("failed to get authorization header: %w", err)
475		}
476		req.Header.Set("Authorization", authHeader)
477	}
478
479	if c.headerFunc != nil {
480		for k, v := range c.headerFunc(ctx) {
481			req.Header.Set(k, v)
482		}
483	}
484
485	resp, err := c.httpClient.Do(req)
486	if err != nil {
487		return fmt.Errorf("failed to send notification: %w", err)
488	}
489	defer resp.Body.Close()
490
491	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
492		// Handle OAuth unauthorized error
493		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
494			return &OAuthAuthorizationRequiredError{
495				Handler: c.oauthHandler,
496			}
497		}
498
499		body, _ := io.ReadAll(resp.Body)
500		return fmt.Errorf(
501			"notification failed with status %d: %s",
502			resp.StatusCode,
503			body,
504		)
505	}
506
507	return nil
508}
509
510// GetEndpoint returns the current endpoint URL for the SSE connection.
511func (c *SSE) GetEndpoint() *url.URL {
512	return c.endpoint
513}
514
515// GetBaseURL returns the base URL set in the SSE constructor.
516func (c *SSE) GetBaseURL() *url.URL {
517	return c.baseURL
518}
519
520// GetOAuthHandler returns the OAuth handler if configured
521func (c *SSE) GetOAuthHandler() *OAuthHandler {
522	return c.oauthHandler
523}
524
525// IsOAuthEnabled returns true if OAuth is enabled
526func (c *SSE) IsOAuthEnabled() bool {
527	return c.oauthHandler != nil
528}