sse.go

  1package transport
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"net/http"
 12	"net/url"
 13	"strings"
 14	"sync"
 15	"sync/atomic"
 16	"time"
 17
 18	"github.com/mark3labs/mcp-go/mcp"
 19)
 20
 21// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
 22// It maintains a persistent HTTP connection to receive server-pushed events
 23// while sending requests over regular HTTP POST calls. The client handles
 24// automatic reconnection and message routing between requests and responses.
 25type SSE struct {
 26	baseURL        *url.URL
 27	endpoint       *url.URL
 28	httpClient     *http.Client
 29	responses      map[string]chan *JSONRPCResponse
 30	mu             sync.RWMutex
 31	onNotification func(mcp.JSONRPCNotification)
 32	notifyMu       sync.RWMutex
 33	endpointChan   chan struct{}
 34	headers        map[string]string
 35	headerFunc     HTTPHeaderFunc
 36
 37	started         atomic.Bool
 38	closed          atomic.Bool
 39	cancelSSEStream context.CancelFunc
 40
 41	// OAuth support
 42	oauthHandler *OAuthHandler
 43}
 44
 45type ClientOption func(*SSE)
 46
 47func WithHeaders(headers map[string]string) ClientOption {
 48	return func(sc *SSE) {
 49		sc.headers = headers
 50	}
 51}
 52
 53func WithHeaderFunc(headerFunc HTTPHeaderFunc) ClientOption {
 54	return func(sc *SSE) {
 55		sc.headerFunc = headerFunc
 56	}
 57}
 58
 59func WithHTTPClient(httpClient *http.Client) ClientOption {
 60	return func(sc *SSE) {
 61		sc.httpClient = httpClient
 62	}
 63}
 64
 65func WithOAuth(config OAuthConfig) ClientOption {
 66	return func(sc *SSE) {
 67		sc.oauthHandler = NewOAuthHandler(config)
 68	}
 69}
 70
 71// NewSSE creates a new SSE-based MCP client with the given base URL.
 72// Returns an error if the URL is invalid.
 73func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
 74	parsedURL, err := url.Parse(baseURL)
 75	if err != nil {
 76		return nil, fmt.Errorf("invalid URL: %w", err)
 77	}
 78
 79	smc := &SSE{
 80		baseURL:      parsedURL,
 81		httpClient:   &http.Client{},
 82		responses:    make(map[string]chan *JSONRPCResponse),
 83		endpointChan: make(chan struct{}),
 84		headers:      make(map[string]string),
 85	}
 86
 87	for _, opt := range options {
 88		opt(smc)
 89	}
 90
 91	// If OAuth is configured, set the base URL for metadata discovery
 92	if smc.oauthHandler != nil {
 93		// Extract base URL from server URL for metadata discovery
 94		baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
 95		smc.oauthHandler.SetBaseURL(baseURL)
 96	}
 97
 98	return smc, nil
 99}
100
101// Start initiates the SSE connection to the server and waits for the endpoint information.
102// Returns an error if the connection fails or times out waiting for the endpoint.
103func (c *SSE) Start(ctx context.Context) error {
104
105	if c.started.Load() {
106		return fmt.Errorf("has already started")
107	}
108
109	ctx, cancel := context.WithCancel(ctx)
110	c.cancelSSEStream = cancel
111
112	req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
113
114	if err != nil {
115		return fmt.Errorf("failed to create request: %w", err)
116	}
117
118	req.Header.Set("Accept", "text/event-stream")
119	req.Header.Set("Cache-Control", "no-cache")
120	req.Header.Set("Connection", "keep-alive")
121
122	// set custom http headers
123	for k, v := range c.headers {
124		req.Header.Set(k, v)
125	}
126	if c.headerFunc != nil {
127		for k, v := range c.headerFunc(ctx) {
128			req.Header.Set(k, v)
129		}
130	}
131
132	// Add OAuth authorization if configured
133	if c.oauthHandler != nil {
134		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
135		if err != nil {
136			// If we get an authorization error, return a specific error that can be handled by the client
137			if err.Error() == "no valid token available, authorization required" {
138				return &OAuthAuthorizationRequiredError{
139					Handler: c.oauthHandler,
140				}
141			}
142			return fmt.Errorf("failed to get authorization header: %w", err)
143		}
144		req.Header.Set("Authorization", authHeader)
145	}
146
147	resp, err := c.httpClient.Do(req)
148	if err != nil {
149		return fmt.Errorf("failed to connect to SSE stream: %w", err)
150	}
151
152	if resp.StatusCode != http.StatusOK {
153		resp.Body.Close()
154		// Handle OAuth unauthorized error
155		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
156			return &OAuthAuthorizationRequiredError{
157				Handler: c.oauthHandler,
158			}
159		}
160		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
161	}
162
163	go c.readSSE(resp.Body)
164
165	// Wait for the endpoint to be received
166	timeout := time.NewTimer(30 * time.Second)
167	defer timeout.Stop()
168	select {
169	case <-c.endpointChan:
170		// Endpoint received, proceed
171	case <-ctx.Done():
172		return fmt.Errorf("context cancelled while waiting for endpoint")
173	case <-timeout.C: // Add a timeout
174		cancel()
175		return fmt.Errorf("timeout waiting for endpoint")
176	}
177
178	c.started.Store(true)
179	return nil
180}
181
182// readSSE continuously reads the SSE stream and processes events.
183// It runs until the connection is closed or an error occurs.
184func (c *SSE) readSSE(reader io.ReadCloser) {
185	defer reader.Close()
186
187	br := bufio.NewReader(reader)
188	var event, data string
189
190	for {
191		// when close or start's ctx cancel, the reader will be closed
192		// and the for loop will break.
193		line, err := br.ReadString('\n')
194		if err != nil {
195			if err == io.EOF {
196				// Process any pending event before exit
197				if data != "" {
198					// If no event type is specified, use empty string (default event type)
199					if event == "" {
200						event = "message"
201					}
202					c.handleSSEEvent(event, data)
203				}
204				break
205			}
206			if !c.closed.Load() {
207				fmt.Printf("SSE stream error: %v\n", err)
208			}
209			return
210		}
211
212		// Remove only newline markers
213		line = strings.TrimRight(line, "\r\n")
214		if line == "" {
215			// Empty line means end of event
216			if data != "" {
217				// If no event type is specified, use empty string (default event type)
218				if event == "" {
219					event = "message"
220				}
221				c.handleSSEEvent(event, data)
222				event = ""
223				data = ""
224			}
225			continue
226		}
227
228		if strings.HasPrefix(line, "event:") {
229			event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
230		} else if strings.HasPrefix(line, "data:") {
231			data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
232		}
233	}
234}
235
236// handleSSEEvent processes SSE events based on their type.
237// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
238func (c *SSE) handleSSEEvent(event, data string) {
239	switch event {
240	case "endpoint":
241		endpoint, err := c.baseURL.Parse(data)
242		if err != nil {
243			fmt.Printf("Error parsing endpoint URL: %v\n", err)
244			return
245		}
246		if endpoint.Host != c.baseURL.Host {
247			fmt.Printf("Endpoint origin does not match connection origin\n")
248			return
249		}
250		c.endpoint = endpoint
251		close(c.endpointChan)
252
253	case "message":
254		var baseMessage JSONRPCResponse
255		if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
256			fmt.Printf("Error unmarshaling message: %v\n", err)
257			return
258		}
259
260		// Handle notification
261		if baseMessage.ID.IsNil() {
262			var notification mcp.JSONRPCNotification
263			if err := json.Unmarshal([]byte(data), &notification); err != nil {
264				return
265			}
266			c.notifyMu.RLock()
267			if c.onNotification != nil {
268				c.onNotification(notification)
269			}
270			c.notifyMu.RUnlock()
271			return
272		}
273
274		// Create string key for map lookup
275		idKey := baseMessage.ID.String()
276
277		c.mu.RLock()
278		ch, exists := c.responses[idKey]
279		c.mu.RUnlock()
280
281		if exists {
282			ch <- &baseMessage
283			c.mu.Lock()
284			delete(c.responses, idKey)
285			c.mu.Unlock()
286		}
287	}
288}
289
290func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
291	c.notifyMu.Lock()
292	defer c.notifyMu.Unlock()
293	c.onNotification = handler
294}
295
296// SendRequest sends a JSON-RPC request to the server and waits for a response.
297// Returns the raw JSON response message or an error if the request fails.
298func (c *SSE) SendRequest(
299	ctx context.Context,
300	request JSONRPCRequest,
301) (*JSONRPCResponse, error) {
302
303	if !c.started.Load() {
304		return nil, fmt.Errorf("transport not started yet")
305	}
306	if c.closed.Load() {
307		return nil, fmt.Errorf("transport has been closed")
308	}
309	if c.endpoint == nil {
310		return nil, fmt.Errorf("endpoint not received")
311	}
312
313	// Marshal request
314	requestBytes, err := json.Marshal(request)
315	if err != nil {
316		return nil, fmt.Errorf("failed to marshal request: %w", err)
317	}
318
319	// Create HTTP request
320	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint.String(), bytes.NewReader(requestBytes))
321	if err != nil {
322		return nil, fmt.Errorf("failed to create request: %w", err)
323	}
324
325	// Set headers
326	req.Header.Set("Content-Type", "application/json")
327	for k, v := range c.headers {
328		req.Header.Set(k, v)
329	}
330
331	// Add OAuth authorization if configured
332	if c.oauthHandler != nil {
333		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
334		if err != nil {
335			// If we get an authorization error, return a specific error that can be handled by the client
336			if err.Error() == "no valid token available, authorization required" {
337				return nil, &OAuthAuthorizationRequiredError{
338					Handler: c.oauthHandler,
339				}
340			}
341			return nil, fmt.Errorf("failed to get authorization header: %w", err)
342		}
343		req.Header.Set("Authorization", authHeader)
344	}
345
346	if c.headerFunc != nil {
347		for k, v := range c.headerFunc(ctx) {
348			req.Header.Set(k, v)
349		}
350	}
351
352	// Create string key for map lookup
353	idKey := request.ID.String()
354
355	// Register response channel
356	responseChan := make(chan *JSONRPCResponse, 1)
357	c.mu.Lock()
358	c.responses[idKey] = responseChan
359	c.mu.Unlock()
360	deleteResponseChan := func() {
361		c.mu.Lock()
362		delete(c.responses, idKey)
363		c.mu.Unlock()
364	}
365
366	// Send request
367	resp, err := c.httpClient.Do(req)
368	if err != nil {
369		deleteResponseChan()
370		return nil, fmt.Errorf("failed to send request: %w", err)
371	}
372
373	// Drain any outstanding io
374	body, err := io.ReadAll(resp.Body)
375	resp.Body.Close()
376
377	if err != nil {
378		return nil, fmt.Errorf("failed to read response body: %w", err)
379	}
380
381	// Check if we got an error response
382	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
383		deleteResponseChan()
384
385		// Handle OAuth unauthorized error
386		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
387			return nil, &OAuthAuthorizationRequiredError{
388				Handler: c.oauthHandler,
389			}
390		}
391
392		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
393	}
394
395	select {
396	case <-ctx.Done():
397		deleteResponseChan()
398		return nil, ctx.Err()
399	case response, ok := <-responseChan:
400		if ok {
401			return response, nil
402		}
403		return nil, fmt.Errorf("connection has been closed")
404	}
405}
406
407// Close shuts down the SSE client connection and cleans up any pending responses.
408// Returns an error if the shutdown process fails.
409func (c *SSE) Close() error {
410	if !c.closed.CompareAndSwap(false, true) {
411		return nil // Already closed
412	}
413
414	if c.cancelSSEStream != nil {
415		// It could stop the sse stream body, to quit the readSSE loop immediately
416		// Also, it could quit start() immediately if not receiving the endpoint
417		c.cancelSSEStream()
418	}
419
420	// Clean up any pending responses
421	c.mu.Lock()
422	for _, ch := range c.responses {
423		close(ch)
424	}
425	c.responses = make(map[string]chan *JSONRPCResponse)
426	c.mu.Unlock()
427
428	return nil
429}
430
431// SendNotification sends a JSON-RPC notification to the server without expecting a response.
432func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
433	if c.endpoint == nil {
434		return fmt.Errorf("endpoint not received")
435	}
436
437	notificationBytes, err := json.Marshal(notification)
438	if err != nil {
439		return fmt.Errorf("failed to marshal notification: %w", err)
440	}
441
442	req, err := http.NewRequestWithContext(
443		ctx,
444		"POST",
445		c.endpoint.String(),
446		bytes.NewReader(notificationBytes),
447	)
448	if err != nil {
449		return fmt.Errorf("failed to create notification request: %w", err)
450	}
451
452	req.Header.Set("Content-Type", "application/json")
453	// Set custom HTTP headers
454	for k, v := range c.headers {
455		req.Header.Set(k, v)
456	}
457
458	// Add OAuth authorization if configured
459	if c.oauthHandler != nil {
460		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
461		if err != nil {
462			// If we get an authorization error, return a specific error that can be handled by the client
463			if errors.Is(err, ErrOAuthAuthorizationRequired) {
464				return &OAuthAuthorizationRequiredError{
465					Handler: c.oauthHandler,
466				}
467			}
468			return fmt.Errorf("failed to get authorization header: %w", err)
469		}
470		req.Header.Set("Authorization", authHeader)
471	}
472
473	if c.headerFunc != nil {
474		for k, v := range c.headerFunc(ctx) {
475			req.Header.Set(k, v)
476		}
477	}
478
479	resp, err := c.httpClient.Do(req)
480	if err != nil {
481		return fmt.Errorf("failed to send notification: %w", err)
482	}
483	defer resp.Body.Close()
484
485	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
486		// Handle OAuth unauthorized error
487		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
488			return &OAuthAuthorizationRequiredError{
489				Handler: c.oauthHandler,
490			}
491		}
492
493		body, _ := io.ReadAll(resp.Body)
494		return fmt.Errorf(
495			"notification failed with status %d: %s",
496			resp.StatusCode,
497			body,
498		)
499	}
500
501	return nil
502}
503
504// GetEndpoint returns the current endpoint URL for the SSE connection.
505func (c *SSE) GetEndpoint() *url.URL {
506	return c.endpoint
507}
508
509// GetBaseURL returns the base URL set in the SSE constructor.
510func (c *SSE) GetBaseURL() *url.URL {
511	return c.baseURL
512}
513
514// GetOAuthHandler returns the OAuth handler if configured
515func (c *SSE) GetOAuthHandler() *OAuthHandler {
516	return c.oauthHandler
517}
518
519// IsOAuthEnabled returns true if OAuth is enabled
520func (c *SSE) IsOAuthEnabled() bool {
521	return c.oauthHandler != nil
522}