1package client
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "sync"
9 "sync/atomic"
10
11 "github.com/mark3labs/mcp-go/client/transport"
12 "github.com/mark3labs/mcp-go/mcp"
13)
14
// Client implements the MCP client. It sends requests and notifications over a
// transport.Interface and tracks whether the initialize handshake has
// completed. It also fans incoming notifications out to handlers registered
// with OnNotification.
type Client struct {
	// transport carries all requests and notifications to and from the server.
	transport transport.Interface

	// initialized is set to true by Initialize once the handshake finishes;
	// until then sendRequest refuses every method except "initialize".
	// NOTE(review): read and written without synchronization.
	initialized bool
	// notifications holds the handlers registered via OnNotification; it is
	// guarded by notifyMu.
	notifications []func(mcp.JSONRPCNotification)
	notifyMu      sync.RWMutex
	// requestID is incremented atomically to produce a fresh JSON-RPC ID for
	// each request.
	requestID atomic.Int64
	// clientCapabilities is the value stored by WithClientCapabilities.
	clientCapabilities mcp.ClientCapabilities
	// serverCapabilities is recorded from the server's initialize result.
	serverCapabilities mcp.ServerCapabilities
}
26
// ClientOption configures a Client. NewClient applies each option, in the
// order given, after the transport has been set.
type ClientOption func(*Client)
28
29// WithClientCapabilities sets the client capabilities for the client.
30func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
31 return func(c *Client) {
32 c.clientCapabilities = capabilities
33 }
34}
35
36// NewClient creates a new MCP client with the given transport.
37// Usage:
38//
39// stdio := transport.NewStdio("./mcp_server", nil, "xxx")
40// client, err := NewClient(stdio)
41// if err != nil {
42// log.Fatalf("Failed to create client: %v", err)
43// }
44func NewClient(transport transport.Interface, options ...ClientOption) *Client {
45 client := &Client{
46 transport: transport,
47 }
48
49 for _, opt := range options {
50 opt(client)
51 }
52
53 return client
54}
55
56// Start initiates the connection to the server.
57// Must be called before using the client.
58func (c *Client) Start(ctx context.Context) error {
59 if c.transport == nil {
60 return fmt.Errorf("transport is nil")
61 }
62 err := c.transport.Start(ctx)
63 if err != nil {
64 return err
65 }
66
67 c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) {
68 c.notifyMu.RLock()
69 defer c.notifyMu.RUnlock()
70 for _, handler := range c.notifications {
71 handler(notification)
72 }
73 })
74 return nil
75}
76
77// Close shuts down the client and closes the transport.
78func (c *Client) Close() error {
79 return c.transport.Close()
80}
81
82// OnNotification registers a handler function to be called when notifications are received.
83// Multiple handlers can be registered and will be called in the order they were added.
84func (c *Client) OnNotification(
85 handler func(notification mcp.JSONRPCNotification),
86) {
87 c.notifyMu.Lock()
88 defer c.notifyMu.Unlock()
89 c.notifications = append(c.notifications, handler)
90}
91
92// sendRequest sends a JSON-RPC request to the server and waits for a response.
93// Returns the raw JSON response message or an error if the request fails.
94func (c *Client) sendRequest(
95 ctx context.Context,
96 method string,
97 params any,
98) (*json.RawMessage, error) {
99 if !c.initialized && method != "initialize" {
100 return nil, fmt.Errorf("client not initialized")
101 }
102
103 id := c.requestID.Add(1)
104
105 request := transport.JSONRPCRequest{
106 JSONRPC: mcp.JSONRPC_VERSION,
107 ID: mcp.NewRequestId(id),
108 Method: method,
109 Params: params,
110 }
111
112 response, err := c.transport.SendRequest(ctx, request)
113 if err != nil {
114 return nil, fmt.Errorf("transport error: %w", err)
115 }
116
117 if response.Error != nil {
118 return nil, errors.New(response.Error.Message)
119 }
120
121 return &response.Result, nil
122}
123
124// Initialize negotiates with the server.
125// Must be called after Start, and before any request methods.
126func (c *Client) Initialize(
127 ctx context.Context,
128 request mcp.InitializeRequest,
129) (*mcp.InitializeResult, error) {
130 // Ensure we send a params object with all required fields
131 params := struct {
132 ProtocolVersion string `json:"protocolVersion"`
133 ClientInfo mcp.Implementation `json:"clientInfo"`
134 Capabilities mcp.ClientCapabilities `json:"capabilities"`
135 }{
136 ProtocolVersion: request.Params.ProtocolVersion,
137 ClientInfo: request.Params.ClientInfo,
138 Capabilities: request.Params.Capabilities, // Will be empty struct if not set
139 }
140
141 response, err := c.sendRequest(ctx, "initialize", params)
142 if err != nil {
143 return nil, err
144 }
145
146 var result mcp.InitializeResult
147 if err := json.Unmarshal(*response, &result); err != nil {
148 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
149 }
150
151 // Store serverCapabilities
152 c.serverCapabilities = result.Capabilities
153
154 // Send initialized notification
155 notification := mcp.JSONRPCNotification{
156 JSONRPC: mcp.JSONRPC_VERSION,
157 Notification: mcp.Notification{
158 Method: "notifications/initialized",
159 },
160 }
161
162 err = c.transport.SendNotification(ctx, notification)
163 if err != nil {
164 return nil, fmt.Errorf(
165 "failed to send initialized notification: %w",
166 err,
167 )
168 }
169
170 c.initialized = true
171 return &result, nil
172}
173
174func (c *Client) Ping(ctx context.Context) error {
175 _, err := c.sendRequest(ctx, "ping", nil)
176 return err
177}
178
179// ListResourcesByPage manually list resources by page.
180func (c *Client) ListResourcesByPage(
181 ctx context.Context,
182 request mcp.ListResourcesRequest,
183) (*mcp.ListResourcesResult, error) {
184 result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list")
185 if err != nil {
186 return nil, err
187 }
188 return result, nil
189}
190
191func (c *Client) ListResources(
192 ctx context.Context,
193 request mcp.ListResourcesRequest,
194) (*mcp.ListResourcesResult, error) {
195 result, err := c.ListResourcesByPage(ctx, request)
196 if err != nil {
197 return nil, err
198 }
199 for result.NextCursor != "" {
200 select {
201 case <-ctx.Done():
202 return nil, ctx.Err()
203 default:
204 request.Params.Cursor = result.NextCursor
205 newPageRes, err := c.ListResourcesByPage(ctx, request)
206 if err != nil {
207 return nil, err
208 }
209 result.Resources = append(result.Resources, newPageRes.Resources...)
210 result.NextCursor = newPageRes.NextCursor
211 }
212 }
213 return result, nil
214}
215
216func (c *Client) ListResourceTemplatesByPage(
217 ctx context.Context,
218 request mcp.ListResourceTemplatesRequest,
219) (*mcp.ListResourceTemplatesResult, error) {
220 result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list")
221 if err != nil {
222 return nil, err
223 }
224 return result, nil
225}
226
227func (c *Client) ListResourceTemplates(
228 ctx context.Context,
229 request mcp.ListResourceTemplatesRequest,
230) (*mcp.ListResourceTemplatesResult, error) {
231 result, err := c.ListResourceTemplatesByPage(ctx, request)
232 if err != nil {
233 return nil, err
234 }
235 for result.NextCursor != "" {
236 select {
237 case <-ctx.Done():
238 return nil, ctx.Err()
239 default:
240 request.Params.Cursor = result.NextCursor
241 newPageRes, err := c.ListResourceTemplatesByPage(ctx, request)
242 if err != nil {
243 return nil, err
244 }
245 result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...)
246 result.NextCursor = newPageRes.NextCursor
247 }
248 }
249 return result, nil
250}
251
252func (c *Client) ReadResource(
253 ctx context.Context,
254 request mcp.ReadResourceRequest,
255) (*mcp.ReadResourceResult, error) {
256 response, err := c.sendRequest(ctx, "resources/read", request.Params)
257 if err != nil {
258 return nil, err
259 }
260
261 return mcp.ParseReadResourceResult(response)
262}
263
264func (c *Client) Subscribe(
265 ctx context.Context,
266 request mcp.SubscribeRequest,
267) error {
268 _, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
269 return err
270}
271
272func (c *Client) Unsubscribe(
273 ctx context.Context,
274 request mcp.UnsubscribeRequest,
275) error {
276 _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
277 return err
278}
279
280func (c *Client) ListPromptsByPage(
281 ctx context.Context,
282 request mcp.ListPromptsRequest,
283) (*mcp.ListPromptsResult, error) {
284 result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list")
285 if err != nil {
286 return nil, err
287 }
288 return result, nil
289}
290
291func (c *Client) ListPrompts(
292 ctx context.Context,
293 request mcp.ListPromptsRequest,
294) (*mcp.ListPromptsResult, error) {
295 result, err := c.ListPromptsByPage(ctx, request)
296 if err != nil {
297 return nil, err
298 }
299 for result.NextCursor != "" {
300 select {
301 case <-ctx.Done():
302 return nil, ctx.Err()
303 default:
304 request.Params.Cursor = result.NextCursor
305 newPageRes, err := c.ListPromptsByPage(ctx, request)
306 if err != nil {
307 return nil, err
308 }
309 result.Prompts = append(result.Prompts, newPageRes.Prompts...)
310 result.NextCursor = newPageRes.NextCursor
311 }
312 }
313 return result, nil
314}
315
316func (c *Client) GetPrompt(
317 ctx context.Context,
318 request mcp.GetPromptRequest,
319) (*mcp.GetPromptResult, error) {
320 response, err := c.sendRequest(ctx, "prompts/get", request.Params)
321 if err != nil {
322 return nil, err
323 }
324
325 return mcp.ParseGetPromptResult(response)
326}
327
328func (c *Client) ListToolsByPage(
329 ctx context.Context,
330 request mcp.ListToolsRequest,
331) (*mcp.ListToolsResult, error) {
332 result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list")
333 if err != nil {
334 return nil, err
335 }
336 return result, nil
337}
338
339func (c *Client) ListTools(
340 ctx context.Context,
341 request mcp.ListToolsRequest,
342) (*mcp.ListToolsResult, error) {
343 result, err := c.ListToolsByPage(ctx, request)
344 if err != nil {
345 return nil, err
346 }
347 for result.NextCursor != "" {
348 select {
349 case <-ctx.Done():
350 return nil, ctx.Err()
351 default:
352 request.Params.Cursor = result.NextCursor
353 newPageRes, err := c.ListToolsByPage(ctx, request)
354 if err != nil {
355 return nil, err
356 }
357 result.Tools = append(result.Tools, newPageRes.Tools...)
358 result.NextCursor = newPageRes.NextCursor
359 }
360 }
361 return result, nil
362}
363
364func (c *Client) CallTool(
365 ctx context.Context,
366 request mcp.CallToolRequest,
367) (*mcp.CallToolResult, error) {
368 response, err := c.sendRequest(ctx, "tools/call", request.Params)
369 if err != nil {
370 return nil, err
371 }
372
373 return mcp.ParseCallToolResult(response)
374}
375
376func (c *Client) SetLevel(
377 ctx context.Context,
378 request mcp.SetLevelRequest,
379) error {
380 _, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
381 return err
382}
383
384func (c *Client) Complete(
385 ctx context.Context,
386 request mcp.CompleteRequest,
387) (*mcp.CompleteResult, error) {
388 response, err := c.sendRequest(ctx, "completion/complete", request.Params)
389 if err != nil {
390 return nil, err
391 }
392
393 var result mcp.CompleteResult
394 if err := json.Unmarshal(*response, &result); err != nil {
395 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
396 }
397
398 return &result, nil
399}
400
401func listByPage[T any](
402 ctx context.Context,
403 client *Client,
404 request mcp.PaginatedRequest,
405 method string,
406) (*T, error) {
407 response, err := client.sendRequest(ctx, method, request.Params)
408 if err != nil {
409 return nil, err
410 }
411 var result T
412 if err := json.Unmarshal(*response, &result); err != nil {
413 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
414 }
415 return &result, nil
416}
417
418// Helper methods
419
420// GetTransport gives access to the underlying transport layer.
421// Cast it to the specific transport type and obtain the other helper methods.
422func (c *Client) GetTransport() transport.Interface {
423 return c.transport
424}
425
426// GetServerCapabilities returns the server capabilities.
427func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
428 return c.serverCapabilities
429}
430
431// GetClientCapabilities returns the client capabilities.
432func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
433 return c.clientCapabilities
434}