sse.go

  1package client
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"net/http"
 12	"net/url"
 13	"strings"
 14	"sync"
 15	"sync/atomic"
 16	"time"
 17
 18	"github.com/mark3labs/mcp-go/mcp"
 19)
 20
 21// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE).
 22// It maintains a persistent HTTP connection to receive server-pushed events
 23// while sending requests over regular HTTP POST calls. The client handles
 24// automatic reconnection and message routing between requests and responses.
 25type SSEMCPClient struct {
 26	baseURL        *url.URL
 27	endpoint       *url.URL
 28	httpClient     *http.Client
 29	requestID      atomic.Int64
 30	responses      map[int64]chan RPCResponse
 31	mu             sync.RWMutex
 32	done           chan struct{}
 33	initialized    bool
 34	notifications  []func(mcp.JSONRPCNotification)
 35	notifyMu       sync.RWMutex
 36	endpointChan   chan struct{}
 37	capabilities   mcp.ServerCapabilities
 38	headers        map[string]string
 39	sseReadTimeout time.Duration
 40}
 41
 42type ClientOption func(*SSEMCPClient)
 43
 44func WithHeaders(headers map[string]string) ClientOption {
 45	return func(sc *SSEMCPClient) {
 46		sc.headers = headers
 47	}
 48}
 49
 50func WithSSEReadTimeout(timeout time.Duration) ClientOption {
 51	return func(sc *SSEMCPClient) {
 52		sc.sseReadTimeout = timeout
 53	}
 54}
 55
 56// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
 57// Returns an error if the URL is invalid.
 58func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) {
 59	parsedURL, err := url.Parse(baseURL)
 60	if err != nil {
 61		return nil, fmt.Errorf("invalid URL: %w", err)
 62	}
 63
 64	smc := &SSEMCPClient{
 65		baseURL:        parsedURL,
 66		httpClient:     &http.Client{},
 67		responses:      make(map[int64]chan RPCResponse),
 68		done:           make(chan struct{}),
 69		endpointChan:   make(chan struct{}),
 70		sseReadTimeout: 30 * time.Second,
 71		headers:        make(map[string]string),
 72	}
 73
 74	for _, opt := range options {
 75		opt(smc)
 76	}
 77
 78	return smc, nil
 79}
 80
 81// Start initiates the SSE connection to the server and waits for the endpoint information.
 82// Returns an error if the connection fails or times out waiting for the endpoint.
 83func (c *SSEMCPClient) Start(ctx context.Context) error {
 84
 85	req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
 86
 87	if err != nil {
 88
 89		return fmt.Errorf("failed to create request: %w", err)
 90
 91	}
 92
 93	req.Header.Set("Accept", "text/event-stream")
 94	req.Header.Set("Cache-Control", "no-cache")
 95	req.Header.Set("Connection", "keep-alive")
 96
 97	resp, err := c.httpClient.Do(req)
 98	if err != nil {
 99		return fmt.Errorf("failed to connect to SSE stream: %w", err)
100	}
101
102	if resp.StatusCode != http.StatusOK {
103		resp.Body.Close()
104		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
105	}
106
107	go c.readSSE(resp.Body)
108
109	// Wait for the endpoint to be received
110
111	select {
112	case <-c.endpointChan:
113		// Endpoint received, proceed
114	case <-ctx.Done():
115		return fmt.Errorf("context cancelled while waiting for endpoint")
116	case <-time.After(30 * time.Second): // Add a timeout
117		return fmt.Errorf("timeout waiting for endpoint")
118	}
119
120	return nil
121}
122
123// readSSE continuously reads the SSE stream and processes events.
124// It runs until the connection is closed or an error occurs.
125func (c *SSEMCPClient) readSSE(reader io.ReadCloser) {
126	defer reader.Close()
127
128	br := bufio.NewReader(reader)
129	var event, data string
130
131	ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout)
132	defer cancel()
133
134	for {
135		select {
136		case <-ctx.Done():
137			return
138		default:
139			line, err := br.ReadString('\n')
140			if err != nil {
141				if err == io.EOF {
142					// Process any pending event before exit
143					if event != "" && data != "" {
144						c.handleSSEEvent(event, data)
145					}
146					break
147				}
148				select {
149				case <-c.done:
150					return
151				default:
152					fmt.Printf("SSE stream error: %v\n", err)
153					return
154				}
155			}
156
157			// Remove only newline markers
158			line = strings.TrimRight(line, "\r\n")
159			if line == "" {
160				// Empty line means end of event
161				if event != "" && data != "" {
162					c.handleSSEEvent(event, data)
163					event = ""
164					data = ""
165				}
166				continue
167			}
168
169			if strings.HasPrefix(line, "event:") {
170				event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
171			} else if strings.HasPrefix(line, "data:") {
172				data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
173			}
174		}
175	}
176}
177
178// handleSSEEvent processes SSE events based on their type.
179// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
180func (c *SSEMCPClient) handleSSEEvent(event, data string) {
181	switch event {
182	case "endpoint":
183		endpoint, err := c.baseURL.Parse(data)
184		if err != nil {
185			fmt.Printf("Error parsing endpoint URL: %v\n", err)
186			return
187		}
188		if endpoint.Host != c.baseURL.Host {
189			fmt.Printf("Endpoint origin does not match connection origin\n")
190			return
191		}
192		c.endpoint = endpoint
193		close(c.endpointChan)
194
195	case "message":
196		var baseMessage struct {
197			JSONRPC string          `json:"jsonrpc"`
198			ID      *int64          `json:"id,omitempty"`
199			Method  string          `json:"method,omitempty"`
200			Result  json.RawMessage `json:"result,omitempty"`
201			Error   *struct {
202				Code    int    `json:"code"`
203				Message string `json:"message"`
204			} `json:"error,omitempty"`
205		}
206
207		if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
208			fmt.Printf("Error unmarshaling message: %v\n", err)
209			return
210		}
211
212		// Handle notification
213		if baseMessage.ID == nil {
214			var notification mcp.JSONRPCNotification
215			if err := json.Unmarshal([]byte(data), &notification); err != nil {
216				return
217			}
218			c.notifyMu.RLock()
219			for _, handler := range c.notifications {
220				handler(notification)
221			}
222			c.notifyMu.RUnlock()
223			return
224		}
225
226		c.mu.RLock()
227		ch, ok := c.responses[*baseMessage.ID]
228		c.mu.RUnlock()
229
230		if ok {
231			if baseMessage.Error != nil {
232				ch <- RPCResponse{
233					Error: &baseMessage.Error.Message,
234				}
235			} else {
236				ch <- RPCResponse{
237					Response: &baseMessage.Result,
238				}
239			}
240			c.mu.Lock()
241			delete(c.responses, *baseMessage.ID)
242			c.mu.Unlock()
243		}
244	}
245}
246
247// OnNotification registers a handler function to be called when notifications are received.
248// Multiple handlers can be registered and will be called in the order they were added.
249func (c *SSEMCPClient) OnNotification(
250	handler func(notification mcp.JSONRPCNotification),
251) {
252	c.notifyMu.Lock()
253	defer c.notifyMu.Unlock()
254	c.notifications = append(c.notifications, handler)
255}
256
257// sendRequest sends a JSON-RPC request to the server and waits for a response.
258// Returns the raw JSON response message or an error if the request fails.
259func (c *SSEMCPClient) sendRequest(
260	ctx context.Context,
261	method string,
262	params interface{},
263) (*json.RawMessage, error) {
264	if !c.initialized && method != "initialize" {
265		return nil, fmt.Errorf("client not initialized")
266	}
267
268	if c.endpoint == nil {
269		return nil, fmt.Errorf("endpoint not received")
270	}
271
272	id := c.requestID.Add(1)
273
274	request := mcp.JSONRPCRequest{
275		JSONRPC: mcp.JSONRPC_VERSION,
276		ID:      id,
277		Request: mcp.Request{
278			Method: method,
279		},
280		Params: params,
281	}
282
283	requestBytes, err := json.Marshal(request)
284	if err != nil {
285		return nil, fmt.Errorf("failed to marshal request: %w", err)
286	}
287
288	responseChan := make(chan RPCResponse, 1)
289	c.mu.Lock()
290	c.responses[id] = responseChan
291	c.mu.Unlock()
292
293	req, err := http.NewRequestWithContext(
294		ctx,
295		"POST",
296		c.endpoint.String(),
297		bytes.NewReader(requestBytes),
298	)
299	if err != nil {
300		return nil, fmt.Errorf("failed to create request: %w", err)
301	}
302
303	req.Header.Set("Content-Type", "application/json")
304	// set custom http headers
305	for k, v := range c.headers {
306		req.Header.Set(k, v)
307	}
308
309	resp, err := c.httpClient.Do(req)
310	if err != nil {
311		return nil, fmt.Errorf("failed to send request: %w", err)
312	}
313	defer resp.Body.Close()
314
315	if resp.StatusCode != http.StatusOK &&
316		resp.StatusCode != http.StatusAccepted {
317		body, _ := io.ReadAll(resp.Body)
318		return nil, fmt.Errorf(
319			"request failed with status %d: %s",
320			resp.StatusCode,
321			body,
322		)
323	}
324
325	select {
326	case <-ctx.Done():
327		c.mu.Lock()
328		delete(c.responses, id)
329		c.mu.Unlock()
330		return nil, ctx.Err()
331	case response := <-responseChan:
332		if response.Error != nil {
333			return nil, errors.New(*response.Error)
334		}
335		return response.Response, nil
336	}
337}
338
339func (c *SSEMCPClient) Initialize(
340	ctx context.Context,
341	request mcp.InitializeRequest,
342) (*mcp.InitializeResult, error) {
343	// Ensure we send a params object with all required fields
344	params := struct {
345		ProtocolVersion string                 `json:"protocolVersion"`
346		ClientInfo      mcp.Implementation     `json:"clientInfo"`
347		Capabilities    mcp.ClientCapabilities `json:"capabilities"`
348	}{
349		ProtocolVersion: request.Params.ProtocolVersion,
350		ClientInfo:      request.Params.ClientInfo,
351		Capabilities:    request.Params.Capabilities, // Will be empty struct if not set
352	}
353
354	response, err := c.sendRequest(ctx, "initialize", params)
355	if err != nil {
356		return nil, err
357	}
358
359	var result mcp.InitializeResult
360	if err := json.Unmarshal(*response, &result); err != nil {
361		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
362	}
363
364	// Store capabilities
365	c.capabilities = result.Capabilities
366
367	// Send initialized notification
368	notification := mcp.JSONRPCNotification{
369		JSONRPC: mcp.JSONRPC_VERSION,
370		Notification: mcp.Notification{
371			Method: "notifications/initialized",
372		},
373	}
374
375	notificationBytes, err := json.Marshal(notification)
376	if err != nil {
377		return nil, fmt.Errorf(
378			"failed to marshal initialized notification: %w",
379			err,
380		)
381	}
382
383	req, err := http.NewRequestWithContext(
384		ctx,
385		"POST",
386		c.endpoint.String(),
387		bytes.NewReader(notificationBytes),
388	)
389	if err != nil {
390		return nil, fmt.Errorf("failed to create notification request: %w", err)
391	}
392
393	req.Header.Set("Content-Type", "application/json")
394
395	resp, err := c.httpClient.Do(req)
396	if err != nil {
397		return nil, fmt.Errorf(
398			"failed to send initialized notification: %w",
399			err,
400		)
401	}
402	resp.Body.Close()
403
404	c.initialized = true
405	return &result, nil
406}
407
408func (c *SSEMCPClient) Ping(ctx context.Context) error {
409	_, err := c.sendRequest(ctx, "ping", nil)
410	return err
411}
412
413func (c *SSEMCPClient) ListResources(
414	ctx context.Context,
415	request mcp.ListResourcesRequest,
416) (*mcp.ListResourcesResult, error) {
417	response, err := c.sendRequest(ctx, "resources/list", request.Params)
418	if err != nil {
419		return nil, err
420	}
421
422	var result mcp.ListResourcesResult
423	if err := json.Unmarshal(*response, &result); err != nil {
424		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
425	}
426
427	return &result, nil
428}
429
430func (c *SSEMCPClient) ListResourceTemplates(
431	ctx context.Context,
432	request mcp.ListResourceTemplatesRequest,
433) (*mcp.ListResourceTemplatesResult, error) {
434	response, err := c.sendRequest(
435		ctx,
436		"resources/templates/list",
437		request.Params,
438	)
439	if err != nil {
440		return nil, err
441	}
442
443	var result mcp.ListResourceTemplatesResult
444	if err := json.Unmarshal(*response, &result); err != nil {
445		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
446	}
447
448	return &result, nil
449}
450
451func (c *SSEMCPClient) ReadResource(
452	ctx context.Context,
453	request mcp.ReadResourceRequest,
454) (*mcp.ReadResourceResult, error) {
455	response, err := c.sendRequest(ctx, "resources/read", request.Params)
456	if err != nil {
457		return nil, err
458	}
459
460	return mcp.ParseReadResourceResult(response)
461}
462
463func (c *SSEMCPClient) Subscribe(
464	ctx context.Context,
465	request mcp.SubscribeRequest,
466) error {
467	_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
468	return err
469}
470
471func (c *SSEMCPClient) Unsubscribe(
472	ctx context.Context,
473	request mcp.UnsubscribeRequest,
474) error {
475	_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
476	return err
477}
478
479func (c *SSEMCPClient) ListPrompts(
480	ctx context.Context,
481	request mcp.ListPromptsRequest,
482) (*mcp.ListPromptsResult, error) {
483	response, err := c.sendRequest(ctx, "prompts/list", request.Params)
484	if err != nil {
485		return nil, err
486	}
487
488	var result mcp.ListPromptsResult
489	if err := json.Unmarshal(*response, &result); err != nil {
490		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
491	}
492
493	return &result, nil
494}
495
496func (c *SSEMCPClient) GetPrompt(
497	ctx context.Context,
498	request mcp.GetPromptRequest,
499) (*mcp.GetPromptResult, error) {
500	response, err := c.sendRequest(ctx, "prompts/get", request.Params)
501	if err != nil {
502		return nil, err
503	}
504
505	return mcp.ParseGetPromptResult(response)
506}
507
508func (c *SSEMCPClient) ListTools(
509	ctx context.Context,
510	request mcp.ListToolsRequest,
511) (*mcp.ListToolsResult, error) {
512	response, err := c.sendRequest(ctx, "tools/list", request.Params)
513	if err != nil {
514		return nil, err
515	}
516
517	var result mcp.ListToolsResult
518	if err := json.Unmarshal(*response, &result); err != nil {
519		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
520	}
521
522	return &result, nil
523}
524
525func (c *SSEMCPClient) CallTool(
526	ctx context.Context,
527	request mcp.CallToolRequest,
528) (*mcp.CallToolResult, error) {
529	response, err := c.sendRequest(ctx, "tools/call", request.Params)
530	if err != nil {
531		return nil, err
532	}
533
534	return mcp.ParseCallToolResult(response)
535}
536
537func (c *SSEMCPClient) SetLevel(
538	ctx context.Context,
539	request mcp.SetLevelRequest,
540) error {
541	_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
542	return err
543}
544
545func (c *SSEMCPClient) Complete(
546	ctx context.Context,
547	request mcp.CompleteRequest,
548) (*mcp.CompleteResult, error) {
549	response, err := c.sendRequest(ctx, "completion/complete", request.Params)
550	if err != nil {
551		return nil, err
552	}
553
554	var result mcp.CompleteResult
555	if err := json.Unmarshal(*response, &result); err != nil {
556		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
557	}
558
559	return &result, nil
560}
561
562// Helper methods
563
564// GetEndpoint returns the current endpoint URL for the SSE connection.
565func (c *SSEMCPClient) GetEndpoint() *url.URL {
566	return c.endpoint
567}
568
569// Close shuts down the SSE client connection and cleans up any pending responses.
570// Returns an error if the shutdown process fails.
571func (c *SSEMCPClient) Close() error {
572	select {
573	case <-c.done:
574		return nil // Already closed
575	default:
576		close(c.done)
577	}
578
579	// Clean up any pending responses
580	c.mu.Lock()
581	for _, ch := range c.responses {
582		close(ch)
583	}
584	c.responses = make(map[int64]chan RPCResponse)
585	c.mu.Unlock()
586
587	return nil
588}