1package client
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "sync"
9 "sync/atomic"
10
11 "github.com/mark3labs/mcp-go/client/transport"
12 "github.com/mark3labs/mcp-go/mcp"
13)
14
// Client implements the MCP client.
//
// It layers MCP request methods (Initialize, ListTools, CallTool, ...) on top
// of a transport.Interface. The expected sequence is NewClient, then Start,
// then Initialize, then any other request method.
type Client struct {
	// transport carries all JSON-RPC traffic to and from the server.
	transport transport.Interface

	// initialized gates sendRequest: until it is true, every method other
	// than "initialize" is rejected. It is set by Initialize on success, or
	// up front by the WithSession option.
	// NOTE(review): read and written without synchronization; confirm that
	// Initialize never runs concurrently with other requests.
	initialized        bool
	// notifications are the handlers registered via OnNotification, in
	// registration order; guarded by notifyMu.
	notifications      []func(mcp.JSONRPCNotification)
	notifyMu           sync.RWMutex
	// requestID is incremented atomically to produce each request's ID.
	requestID          atomic.Int64
	// clientCapabilities is set by WithClientCapabilities and reported by
	// GetClientCapabilities. Initialize sends the capabilities taken from its
	// InitializeRequest argument, not this field.
	clientCapabilities mcp.ClientCapabilities
	// serverCapabilities holds the capabilities from the server's
	// InitializeResult, recorded by Initialize.
	serverCapabilities mcp.ServerCapabilities
	// samplingHandler, when non-nil, makes Initialize advertise the sampling
	// capability and serves server-initiated sampling requests.
	samplingHandler    SamplingHandler
}
27
// ClientOption configures a Client. NewClient applies options in the order
// they are passed.
type ClientOption func(*Client)
29
30// WithClientCapabilities sets the client capabilities for the client.
31func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
32 return func(c *Client) {
33 c.clientCapabilities = capabilities
34 }
35}
36
37// WithSamplingHandler sets the sampling handler for the client.
38// When set, the client will declare sampling capability during initialization.
39func WithSamplingHandler(handler SamplingHandler) ClientOption {
40 return func(c *Client) {
41 c.samplingHandler = handler
42 }
43}
44
45// WithSession assumes a MCP Session has already been initialized
46func WithSession() ClientOption {
47 return func(c *Client) {
48 c.initialized = true
49 }
50}
51
52// NewClient creates a new MCP client with the given transport.
53// Usage:
54//
55// stdio := transport.NewStdio("./mcp_server", nil, "xxx")
56// client, err := NewClient(stdio)
57// if err != nil {
58// log.Fatalf("Failed to create client: %v", err)
59// }
60func NewClient(transport transport.Interface, options ...ClientOption) *Client {
61 client := &Client{
62 transport: transport,
63 }
64
65 for _, opt := range options {
66 opt(client)
67 }
68
69 return client
70}
71
72// Start initiates the connection to the server.
73// Must be called before using the client.
74func (c *Client) Start(ctx context.Context) error {
75 if c.transport == nil {
76 return fmt.Errorf("transport is nil")
77 }
78 err := c.transport.Start(ctx)
79 if err != nil {
80 return err
81 }
82
83 c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) {
84 c.notifyMu.RLock()
85 defer c.notifyMu.RUnlock()
86 for _, handler := range c.notifications {
87 handler(notification)
88 }
89 })
90
91 // Set up request handler for bidirectional communication (e.g., sampling)
92 if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
93 bidirectional.SetRequestHandler(c.handleIncomingRequest)
94 }
95
96 return nil
97}
98
99// Close shuts down the client and closes the transport.
100func (c *Client) Close() error {
101 return c.transport.Close()
102}
103
104// OnNotification registers a handler function to be called when notifications are received.
105// Multiple handlers can be registered and will be called in the order they were added.
106func (c *Client) OnNotification(
107 handler func(notification mcp.JSONRPCNotification),
108) {
109 c.notifyMu.Lock()
110 defer c.notifyMu.Unlock()
111 c.notifications = append(c.notifications, handler)
112}
113
114// sendRequest sends a JSON-RPC request to the server and waits for a response.
115// Returns the raw JSON response message or an error if the request fails.
116func (c *Client) sendRequest(
117 ctx context.Context,
118 method string,
119 params any,
120) (*json.RawMessage, error) {
121 if !c.initialized && method != "initialize" {
122 return nil, fmt.Errorf("client not initialized")
123 }
124
125 id := c.requestID.Add(1)
126
127 request := transport.JSONRPCRequest{
128 JSONRPC: mcp.JSONRPC_VERSION,
129 ID: mcp.NewRequestId(id),
130 Method: method,
131 Params: params,
132 }
133
134 response, err := c.transport.SendRequest(ctx, request)
135 if err != nil {
136 return nil, fmt.Errorf("transport error: %w", err)
137 }
138
139 if response.Error != nil {
140 return nil, errors.New(response.Error.Message)
141 }
142
143 return &response.Result, nil
144}
145
146// Initialize negotiates with the server.
147// Must be called after Start, and before any request methods.
148func (c *Client) Initialize(
149 ctx context.Context,
150 request mcp.InitializeRequest,
151) (*mcp.InitializeResult, error) {
152 // Merge client capabilities with sampling capability if handler is configured
153 capabilities := request.Params.Capabilities
154 if c.samplingHandler != nil {
155 capabilities.Sampling = &struct{}{}
156 }
157
158 // Ensure we send a params object with all required fields
159 params := struct {
160 ProtocolVersion string `json:"protocolVersion"`
161 ClientInfo mcp.Implementation `json:"clientInfo"`
162 Capabilities mcp.ClientCapabilities `json:"capabilities"`
163 }{
164 ProtocolVersion: request.Params.ProtocolVersion,
165 ClientInfo: request.Params.ClientInfo,
166 Capabilities: capabilities,
167 }
168
169 response, err := c.sendRequest(ctx, "initialize", params)
170 if err != nil {
171 return nil, err
172 }
173
174 var result mcp.InitializeResult
175 if err := json.Unmarshal(*response, &result); err != nil {
176 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
177 }
178
179 // Store serverCapabilities
180 c.serverCapabilities = result.Capabilities
181
182 // Send initialized notification
183 notification := mcp.JSONRPCNotification{
184 JSONRPC: mcp.JSONRPC_VERSION,
185 Notification: mcp.Notification{
186 Method: "notifications/initialized",
187 },
188 }
189
190 err = c.transport.SendNotification(ctx, notification)
191 if err != nil {
192 return nil, fmt.Errorf(
193 "failed to send initialized notification: %w",
194 err,
195 )
196 }
197
198 c.initialized = true
199 return &result, nil
200}
201
202func (c *Client) Ping(ctx context.Context) error {
203 _, err := c.sendRequest(ctx, "ping", nil)
204 return err
205}
206
207// ListResourcesByPage manually list resources by page.
208func (c *Client) ListResourcesByPage(
209 ctx context.Context,
210 request mcp.ListResourcesRequest,
211) (*mcp.ListResourcesResult, error) {
212 result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list")
213 if err != nil {
214 return nil, err
215 }
216 return result, nil
217}
218
219func (c *Client) ListResources(
220 ctx context.Context,
221 request mcp.ListResourcesRequest,
222) (*mcp.ListResourcesResult, error) {
223 result, err := c.ListResourcesByPage(ctx, request)
224 if err != nil {
225 return nil, err
226 }
227 for result.NextCursor != "" {
228 select {
229 case <-ctx.Done():
230 return nil, ctx.Err()
231 default:
232 request.Params.Cursor = result.NextCursor
233 newPageRes, err := c.ListResourcesByPage(ctx, request)
234 if err != nil {
235 return nil, err
236 }
237 result.Resources = append(result.Resources, newPageRes.Resources...)
238 result.NextCursor = newPageRes.NextCursor
239 }
240 }
241 return result, nil
242}
243
244func (c *Client) ListResourceTemplatesByPage(
245 ctx context.Context,
246 request mcp.ListResourceTemplatesRequest,
247) (*mcp.ListResourceTemplatesResult, error) {
248 result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list")
249 if err != nil {
250 return nil, err
251 }
252 return result, nil
253}
254
255func (c *Client) ListResourceTemplates(
256 ctx context.Context,
257 request mcp.ListResourceTemplatesRequest,
258) (*mcp.ListResourceTemplatesResult, error) {
259 result, err := c.ListResourceTemplatesByPage(ctx, request)
260 if err != nil {
261 return nil, err
262 }
263 for result.NextCursor != "" {
264 select {
265 case <-ctx.Done():
266 return nil, ctx.Err()
267 default:
268 request.Params.Cursor = result.NextCursor
269 newPageRes, err := c.ListResourceTemplatesByPage(ctx, request)
270 if err != nil {
271 return nil, err
272 }
273 result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...)
274 result.NextCursor = newPageRes.NextCursor
275 }
276 }
277 return result, nil
278}
279
280func (c *Client) ReadResource(
281 ctx context.Context,
282 request mcp.ReadResourceRequest,
283) (*mcp.ReadResourceResult, error) {
284 response, err := c.sendRequest(ctx, "resources/read", request.Params)
285 if err != nil {
286 return nil, err
287 }
288
289 return mcp.ParseReadResourceResult(response)
290}
291
292func (c *Client) Subscribe(
293 ctx context.Context,
294 request mcp.SubscribeRequest,
295) error {
296 _, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
297 return err
298}
299
300func (c *Client) Unsubscribe(
301 ctx context.Context,
302 request mcp.UnsubscribeRequest,
303) error {
304 _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
305 return err
306}
307
308func (c *Client) ListPromptsByPage(
309 ctx context.Context,
310 request mcp.ListPromptsRequest,
311) (*mcp.ListPromptsResult, error) {
312 result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list")
313 if err != nil {
314 return nil, err
315 }
316 return result, nil
317}
318
319func (c *Client) ListPrompts(
320 ctx context.Context,
321 request mcp.ListPromptsRequest,
322) (*mcp.ListPromptsResult, error) {
323 result, err := c.ListPromptsByPage(ctx, request)
324 if err != nil {
325 return nil, err
326 }
327 for result.NextCursor != "" {
328 select {
329 case <-ctx.Done():
330 return nil, ctx.Err()
331 default:
332 request.Params.Cursor = result.NextCursor
333 newPageRes, err := c.ListPromptsByPage(ctx, request)
334 if err != nil {
335 return nil, err
336 }
337 result.Prompts = append(result.Prompts, newPageRes.Prompts...)
338 result.NextCursor = newPageRes.NextCursor
339 }
340 }
341 return result, nil
342}
343
344func (c *Client) GetPrompt(
345 ctx context.Context,
346 request mcp.GetPromptRequest,
347) (*mcp.GetPromptResult, error) {
348 response, err := c.sendRequest(ctx, "prompts/get", request.Params)
349 if err != nil {
350 return nil, err
351 }
352
353 return mcp.ParseGetPromptResult(response)
354}
355
356func (c *Client) ListToolsByPage(
357 ctx context.Context,
358 request mcp.ListToolsRequest,
359) (*mcp.ListToolsResult, error) {
360 result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list")
361 if err != nil {
362 return nil, err
363 }
364 return result, nil
365}
366
367func (c *Client) ListTools(
368 ctx context.Context,
369 request mcp.ListToolsRequest,
370) (*mcp.ListToolsResult, error) {
371 result, err := c.ListToolsByPage(ctx, request)
372 if err != nil {
373 return nil, err
374 }
375 for result.NextCursor != "" {
376 select {
377 case <-ctx.Done():
378 return nil, ctx.Err()
379 default:
380 request.Params.Cursor = result.NextCursor
381 newPageRes, err := c.ListToolsByPage(ctx, request)
382 if err != nil {
383 return nil, err
384 }
385 result.Tools = append(result.Tools, newPageRes.Tools...)
386 result.NextCursor = newPageRes.NextCursor
387 }
388 }
389 return result, nil
390}
391
392func (c *Client) CallTool(
393 ctx context.Context,
394 request mcp.CallToolRequest,
395) (*mcp.CallToolResult, error) {
396 response, err := c.sendRequest(ctx, "tools/call", request.Params)
397 if err != nil {
398 return nil, err
399 }
400
401 return mcp.ParseCallToolResult(response)
402}
403
404func (c *Client) SetLevel(
405 ctx context.Context,
406 request mcp.SetLevelRequest,
407) error {
408 _, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
409 return err
410}
411
412func (c *Client) Complete(
413 ctx context.Context,
414 request mcp.CompleteRequest,
415) (*mcp.CompleteResult, error) {
416 response, err := c.sendRequest(ctx, "completion/complete", request.Params)
417 if err != nil {
418 return nil, err
419 }
420
421 var result mcp.CompleteResult
422 if err := json.Unmarshal(*response, &result); err != nil {
423 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
424 }
425
426 return &result, nil
427}
428
429// handleIncomingRequest processes incoming requests from the server.
430// This is the main entry point for server-to-client requests like sampling.
431func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
432 switch request.Method {
433 case string(mcp.MethodSamplingCreateMessage):
434 return c.handleSamplingRequestTransport(ctx, request)
435 default:
436 return nil, fmt.Errorf("unsupported request method: %s", request.Method)
437 }
438}
439
440// handleSamplingRequestTransport handles sampling requests at the transport level.
441func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
442 if c.samplingHandler == nil {
443 return nil, fmt.Errorf("no sampling handler configured")
444 }
445
446 // Parse the request parameters
447 var params mcp.CreateMessageParams
448 if request.Params != nil {
449 paramsBytes, err := json.Marshal(request.Params)
450 if err != nil {
451 return nil, fmt.Errorf("failed to marshal params: %w", err)
452 }
453 if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
454 return nil, fmt.Errorf("failed to unmarshal params: %w", err)
455 }
456 }
457
458 // Create the MCP request
459 mcpRequest := mcp.CreateMessageRequest{
460 Request: mcp.Request{
461 Method: string(mcp.MethodSamplingCreateMessage),
462 },
463 CreateMessageParams: params,
464 }
465
466 // Call the sampling handler
467 result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
468 if err != nil {
469 return nil, err
470 }
471
472 // Marshal the result
473 resultBytes, err := json.Marshal(result)
474 if err != nil {
475 return nil, fmt.Errorf("failed to marshal result: %w", err)
476 }
477
478 // Create the transport response
479 response := &transport.JSONRPCResponse{
480 JSONRPC: mcp.JSONRPC_VERSION,
481 ID: request.ID,
482 Result: json.RawMessage(resultBytes),
483 }
484
485 return response, nil
486}
487func listByPage[T any](
488 ctx context.Context,
489 client *Client,
490 request mcp.PaginatedRequest,
491 method string,
492) (*T, error) {
493 response, err := client.sendRequest(ctx, method, request.Params)
494 if err != nil {
495 return nil, err
496 }
497 var result T
498 if err := json.Unmarshal(*response, &result); err != nil {
499 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
500 }
501 return &result, nil
502}
503
504// Helper methods
505
506// GetTransport gives access to the underlying transport layer.
507// Cast it to the specific transport type and obtain the other helper methods.
508func (c *Client) GetTransport() transport.Interface {
509 return c.transport
510}
511
512// GetServerCapabilities returns the server capabilities.
513func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
514 return c.serverCapabilities
515}
516
517// GetClientCapabilities returns the client capabilities.
518func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
519 return c.clientCapabilities
520}
521
522// GetSessionId returns the session ID of the transport.
523// If the transport does not support sessions, it returns an empty string.
524func (c *Client) GetSessionId() string {
525 if c.transport == nil {
526 return ""
527 }
528 return c.transport.GetSessionId()
529}
530
531// IsInitialized returns true if the client has been initialized.
532func (c *Client) IsInitialized() bool {
533 return c.initialized
534}