1package client
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "net/url"
13 "strings"
14 "sync"
15 "sync/atomic"
16 "time"
17
18 "github.com/mark3labs/mcp-go/mcp"
19)
20
21// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE).
22// It maintains a persistent HTTP connection to receive server-pushed events
23// while sending requests over regular HTTP POST calls. The client handles
24// automatic reconnection and message routing between requests and responses.
25type SSEMCPClient struct {
26 baseURL *url.URL
27 endpoint *url.URL
28 httpClient *http.Client
29 requestID atomic.Int64
30 responses map[int64]chan RPCResponse
31 mu sync.RWMutex
32 done chan struct{}
33 initialized bool
34 notifications []func(mcp.JSONRPCNotification)
35 notifyMu sync.RWMutex
36 endpointChan chan struct{}
37 capabilities mcp.ServerCapabilities
38 headers map[string]string
39 sseReadTimeout time.Duration
40}
41
42type ClientOption func(*SSEMCPClient)
43
44func WithHeaders(headers map[string]string) ClientOption {
45 return func(sc *SSEMCPClient) {
46 sc.headers = headers
47 }
48}
49
50func WithSSEReadTimeout(timeout time.Duration) ClientOption {
51 return func(sc *SSEMCPClient) {
52 sc.sseReadTimeout = timeout
53 }
54}
55
56// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
57// Returns an error if the URL is invalid.
58func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) {
59 parsedURL, err := url.Parse(baseURL)
60 if err != nil {
61 return nil, fmt.Errorf("invalid URL: %w", err)
62 }
63
64 smc := &SSEMCPClient{
65 baseURL: parsedURL,
66 httpClient: &http.Client{},
67 responses: make(map[int64]chan RPCResponse),
68 done: make(chan struct{}),
69 endpointChan: make(chan struct{}),
70 sseReadTimeout: 30 * time.Second,
71 headers: make(map[string]string),
72 }
73
74 for _, opt := range options {
75 opt(smc)
76 }
77
78 return smc, nil
79}
80
81// Start initiates the SSE connection to the server and waits for the endpoint information.
82// Returns an error if the connection fails or times out waiting for the endpoint.
83func (c *SSEMCPClient) Start(ctx context.Context) error {
84
85 req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
86
87 if err != nil {
88
89 return fmt.Errorf("failed to create request: %w", err)
90
91 }
92
93 req.Header.Set("Accept", "text/event-stream")
94 req.Header.Set("Cache-Control", "no-cache")
95 req.Header.Set("Connection", "keep-alive")
96
97 resp, err := c.httpClient.Do(req)
98 if err != nil {
99 return fmt.Errorf("failed to connect to SSE stream: %w", err)
100 }
101
102 if resp.StatusCode != http.StatusOK {
103 resp.Body.Close()
104 return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
105 }
106
107 go c.readSSE(resp.Body)
108
109 // Wait for the endpoint to be received
110
111 select {
112 case <-c.endpointChan:
113 // Endpoint received, proceed
114 case <-ctx.Done():
115 return fmt.Errorf("context cancelled while waiting for endpoint")
116 case <-time.After(30 * time.Second): // Add a timeout
117 return fmt.Errorf("timeout waiting for endpoint")
118 }
119
120 return nil
121}
122
123// readSSE continuously reads the SSE stream and processes events.
124// It runs until the connection is closed or an error occurs.
125func (c *SSEMCPClient) readSSE(reader io.ReadCloser) {
126 defer reader.Close()
127
128 br := bufio.NewReader(reader)
129 var event, data string
130
131 ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout)
132 defer cancel()
133
134 for {
135 select {
136 case <-ctx.Done():
137 return
138 default:
139 line, err := br.ReadString('\n')
140 if err != nil {
141 if err == io.EOF {
142 // Process any pending event before exit
143 if event != "" && data != "" {
144 c.handleSSEEvent(event, data)
145 }
146 break
147 }
148 select {
149 case <-c.done:
150 return
151 default:
152 fmt.Printf("SSE stream error: %v\n", err)
153 return
154 }
155 }
156
157 // Remove only newline markers
158 line = strings.TrimRight(line, "\r\n")
159 if line == "" {
160 // Empty line means end of event
161 if event != "" && data != "" {
162 c.handleSSEEvent(event, data)
163 event = ""
164 data = ""
165 }
166 continue
167 }
168
169 if strings.HasPrefix(line, "event:") {
170 event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
171 } else if strings.HasPrefix(line, "data:") {
172 data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
173 }
174 }
175 }
176}
177
178// handleSSEEvent processes SSE events based on their type.
179// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
180func (c *SSEMCPClient) handleSSEEvent(event, data string) {
181 switch event {
182 case "endpoint":
183 endpoint, err := c.baseURL.Parse(data)
184 if err != nil {
185 fmt.Printf("Error parsing endpoint URL: %v\n", err)
186 return
187 }
188 if endpoint.Host != c.baseURL.Host {
189 fmt.Printf("Endpoint origin does not match connection origin\n")
190 return
191 }
192 c.endpoint = endpoint
193 close(c.endpointChan)
194
195 case "message":
196 var baseMessage struct {
197 JSONRPC string `json:"jsonrpc"`
198 ID *int64 `json:"id,omitempty"`
199 Method string `json:"method,omitempty"`
200 Result json.RawMessage `json:"result,omitempty"`
201 Error *struct {
202 Code int `json:"code"`
203 Message string `json:"message"`
204 } `json:"error,omitempty"`
205 }
206
207 if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
208 fmt.Printf("Error unmarshaling message: %v\n", err)
209 return
210 }
211
212 // Handle notification
213 if baseMessage.ID == nil {
214 var notification mcp.JSONRPCNotification
215 if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
216 return
217 }
218 c.notifyMu.RLock()
219 for _, handler := range c.notifications {
220 handler(notification)
221 }
222 c.notifyMu.RUnlock()
223 return
224 }
225
226 c.mu.RLock()
227 ch, ok := c.responses[*baseMessage.ID]
228 c.mu.RUnlock()
229
230 if ok {
231 if baseMessage.Error != nil {
232 ch <- RPCResponse{
233 Error: &baseMessage.Error.Message,
234 }
235 } else {
236 ch <- RPCResponse{
237 Response: &baseMessage.Result,
238 }
239 }
240 c.mu.Lock()
241 delete(c.responses, *baseMessage.ID)
242 c.mu.Unlock()
243 }
244 }
245}
246
247// OnNotification registers a handler function to be called when notifications are received.
248// Multiple handlers can be registered and will be called in the order they were added.
249func (c *SSEMCPClient) OnNotification(
250 handler func(notification mcp.JSONRPCNotification),
251) {
252 c.notifyMu.Lock()
253 defer c.notifyMu.Unlock()
254 c.notifications = append(c.notifications, handler)
255}
256
257// sendRequest sends a JSON-RPC request to the server and waits for a response.
258// Returns the raw JSON response message or an error if the request fails.
259func (c *SSEMCPClient) sendRequest(
260 ctx context.Context,
261 method string,
262 params interface{},
263) (*json.RawMessage, error) {
264 if !c.initialized && method != "initialize" {
265 return nil, fmt.Errorf("client not initialized")
266 }
267
268 if c.endpoint == nil {
269 return nil, fmt.Errorf("endpoint not received")
270 }
271
272 id := c.requestID.Add(1)
273
274 request := mcp.JSONRPCRequest{
275 JSONRPC: mcp.JSONRPC_VERSION,
276 ID: id,
277 Request: mcp.Request{
278 Method: method,
279 },
280 Params: params,
281 }
282
283 requestBytes, err := json.Marshal(request)
284 if err != nil {
285 return nil, fmt.Errorf("failed to marshal request: %w", err)
286 }
287
288 responseChan := make(chan RPCResponse, 1)
289 c.mu.Lock()
290 c.responses[id] = responseChan
291 c.mu.Unlock()
292
293 req, err := http.NewRequestWithContext(
294 ctx,
295 "POST",
296 c.endpoint.String(),
297 bytes.NewReader(requestBytes),
298 )
299 if err != nil {
300 return nil, fmt.Errorf("failed to create request: %w", err)
301 }
302
303 req.Header.Set("Content-Type", "application/json")
304 // set custom http headers
305 for k, v := range c.headers {
306 req.Header.Set(k, v)
307 }
308
309 resp, err := c.httpClient.Do(req)
310 if err != nil {
311 return nil, fmt.Errorf("failed to send request: %w", err)
312 }
313 defer resp.Body.Close()
314
315 if resp.StatusCode != http.StatusOK &&
316 resp.StatusCode != http.StatusAccepted {
317 body, _ := io.ReadAll(resp.Body)
318 return nil, fmt.Errorf(
319 "request failed with status %d: %s",
320 resp.StatusCode,
321 body,
322 )
323 }
324
325 select {
326 case <-ctx.Done():
327 c.mu.Lock()
328 delete(c.responses, id)
329 c.mu.Unlock()
330 return nil, ctx.Err()
331 case response := <-responseChan:
332 if response.Error != nil {
333 return nil, errors.New(*response.Error)
334 }
335 return response.Response, nil
336 }
337}
338
339func (c *SSEMCPClient) Initialize(
340 ctx context.Context,
341 request mcp.InitializeRequest,
342) (*mcp.InitializeResult, error) {
343 // Ensure we send a params object with all required fields
344 params := struct {
345 ProtocolVersion string `json:"protocolVersion"`
346 ClientInfo mcp.Implementation `json:"clientInfo"`
347 Capabilities mcp.ClientCapabilities `json:"capabilities"`
348 }{
349 ProtocolVersion: request.Params.ProtocolVersion,
350 ClientInfo: request.Params.ClientInfo,
351 Capabilities: request.Params.Capabilities, // Will be empty struct if not set
352 }
353
354 response, err := c.sendRequest(ctx, "initialize", params)
355 if err != nil {
356 return nil, err
357 }
358
359 var result mcp.InitializeResult
360 if err := json.Unmarshal(*response, &result); err != nil {
361 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
362 }
363
364 // Store capabilities
365 c.capabilities = result.Capabilities
366
367 // Send initialized notification
368 notification := mcp.JSONRPCNotification{
369 JSONRPC: mcp.JSONRPC_VERSION,
370 Notification: mcp.Notification{
371 Method: "notifications/initialized",
372 },
373 }
374
375 notificationBytes, err := json.Marshal(notification)
376 if err != nil {
377 return nil, fmt.Errorf(
378 "failed to marshal initialized notification: %w",
379 err,
380 )
381 }
382
383 req, err := http.NewRequestWithContext(
384 ctx,
385 "POST",
386 c.endpoint.String(),
387 bytes.NewReader(notificationBytes),
388 )
389 if err != nil {
390 return nil, fmt.Errorf("failed to create notification request: %w", err)
391 }
392
393 req.Header.Set("Content-Type", "application/json")
394
395 resp, err := c.httpClient.Do(req)
396 if err != nil {
397 return nil, fmt.Errorf(
398 "failed to send initialized notification: %w",
399 err,
400 )
401 }
402 resp.Body.Close()
403
404 c.initialized = true
405 return &result, nil
406}
407
408func (c *SSEMCPClient) Ping(ctx context.Context) error {
409 _, err := c.sendRequest(ctx, "ping", nil)
410 return err
411}
412
413func (c *SSEMCPClient) ListResources(
414 ctx context.Context,
415 request mcp.ListResourcesRequest,
416) (*mcp.ListResourcesResult, error) {
417 response, err := c.sendRequest(ctx, "resources/list", request.Params)
418 if err != nil {
419 return nil, err
420 }
421
422 var result mcp.ListResourcesResult
423 if err := json.Unmarshal(*response, &result); err != nil {
424 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
425 }
426
427 return &result, nil
428}
429
430func (c *SSEMCPClient) ListResourceTemplates(
431 ctx context.Context,
432 request mcp.ListResourceTemplatesRequest,
433) (*mcp.ListResourceTemplatesResult, error) {
434 response, err := c.sendRequest(
435 ctx,
436 "resources/templates/list",
437 request.Params,
438 )
439 if err != nil {
440 return nil, err
441 }
442
443 var result mcp.ListResourceTemplatesResult
444 if err := json.Unmarshal(*response, &result); err != nil {
445 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
446 }
447
448 return &result, nil
449}
450
451func (c *SSEMCPClient) ReadResource(
452 ctx context.Context,
453 request mcp.ReadResourceRequest,
454) (*mcp.ReadResourceResult, error) {
455 response, err := c.sendRequest(ctx, "resources/read", request.Params)
456 if err != nil {
457 return nil, err
458 }
459
460 return mcp.ParseReadResourceResult(response)
461}
462
463func (c *SSEMCPClient) Subscribe(
464 ctx context.Context,
465 request mcp.SubscribeRequest,
466) error {
467 _, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
468 return err
469}
470
471func (c *SSEMCPClient) Unsubscribe(
472 ctx context.Context,
473 request mcp.UnsubscribeRequest,
474) error {
475 _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
476 return err
477}
478
479func (c *SSEMCPClient) ListPrompts(
480 ctx context.Context,
481 request mcp.ListPromptsRequest,
482) (*mcp.ListPromptsResult, error) {
483 response, err := c.sendRequest(ctx, "prompts/list", request.Params)
484 if err != nil {
485 return nil, err
486 }
487
488 var result mcp.ListPromptsResult
489 if err := json.Unmarshal(*response, &result); err != nil {
490 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
491 }
492
493 return &result, nil
494}
495
496func (c *SSEMCPClient) GetPrompt(
497 ctx context.Context,
498 request mcp.GetPromptRequest,
499) (*mcp.GetPromptResult, error) {
500 response, err := c.sendRequest(ctx, "prompts/get", request.Params)
501 if err != nil {
502 return nil, err
503 }
504
505 return mcp.ParseGetPromptResult(response)
506}
507
508func (c *SSEMCPClient) ListTools(
509 ctx context.Context,
510 request mcp.ListToolsRequest,
511) (*mcp.ListToolsResult, error) {
512 response, err := c.sendRequest(ctx, "tools/list", request.Params)
513 if err != nil {
514 return nil, err
515 }
516
517 var result mcp.ListToolsResult
518 if err := json.Unmarshal(*response, &result); err != nil {
519 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
520 }
521
522 return &result, nil
523}
524
525func (c *SSEMCPClient) CallTool(
526 ctx context.Context,
527 request mcp.CallToolRequest,
528) (*mcp.CallToolResult, error) {
529 response, err := c.sendRequest(ctx, "tools/call", request.Params)
530 if err != nil {
531 return nil, err
532 }
533
534 return mcp.ParseCallToolResult(response)
535}
536
537func (c *SSEMCPClient) SetLevel(
538 ctx context.Context,
539 request mcp.SetLevelRequest,
540) error {
541 _, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
542 return err
543}
544
545func (c *SSEMCPClient) Complete(
546 ctx context.Context,
547 request mcp.CompleteRequest,
548) (*mcp.CompleteResult, error) {
549 response, err := c.sendRequest(ctx, "completion/complete", request.Params)
550 if err != nil {
551 return nil, err
552 }
553
554 var result mcp.CompleteResult
555 if err := json.Unmarshal(*response, &result); err != nil {
556 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
557 }
558
559 return &result, nil
560}
561
562// Helper methods
563
564// GetEndpoint returns the current endpoint URL for the SSE connection.
565func (c *SSEMCPClient) GetEndpoint() *url.URL {
566 return c.endpoint
567}
568
569// Close shuts down the SSE client connection and cleans up any pending responses.
570// Returns an error if the shutdown process fails.
571func (c *SSEMCPClient) Close() error {
572 select {
573 case <-c.done:
574 return nil // Already closed
575 default:
576 close(c.done)
577 }
578
579 // Clean up any pending responses
580 c.mu.Lock()
581 for _, ch := range c.responses {
582 close(ch)
583 }
584 c.responses = make(map[int64]chan RPCResponse)
585 c.mu.Unlock()
586
587 return nil
588}