client.go

  1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"sync"
  9	"sync/atomic"
 10
 11	"github.com/mark3labs/mcp-go/client/transport"
 12	"github.com/mark3labs/mcp-go/mcp"
 13)
 14
 15// Client implements the MCP client.
 16type Client struct {
 17	transport transport.Interface
 18
 19	initialized        bool
 20	notifications      []func(mcp.JSONRPCNotification)
 21	notifyMu           sync.RWMutex
 22	requestID          atomic.Int64
 23	clientCapabilities mcp.ClientCapabilities
 24	serverCapabilities mcp.ServerCapabilities
 25	samplingHandler    SamplingHandler
 26}
 27
 28type ClientOption func(*Client)
 29
 30// WithClientCapabilities sets the client capabilities for the client.
 31func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
 32	return func(c *Client) {
 33		c.clientCapabilities = capabilities
 34	}
 35}
 36
 37// WithSamplingHandler sets the sampling handler for the client.
 38// When set, the client will declare sampling capability during initialization.
 39func WithSamplingHandler(handler SamplingHandler) ClientOption {
 40	return func(c *Client) {
 41		c.samplingHandler = handler
 42  }
 43}
 44
 45// WithSession assumes a MCP Session has already been initialized
 46func WithSession() ClientOption {
 47	return func(c *Client) {
 48		c.initialized = true
 49	}
 50}
 51
 52// NewClient creates a new MCP client with the given transport.
 53// Usage:
 54//
 55//	stdio := transport.NewStdio("./mcp_server", nil, "xxx")
 56//	client, err := NewClient(stdio)
 57//	if err != nil {
 58//	    log.Fatalf("Failed to create client: %v", err)
 59//	}
 60func NewClient(transport transport.Interface, options ...ClientOption) *Client {
 61	client := &Client{
 62		transport: transport,
 63	}
 64
 65	for _, opt := range options {
 66		opt(client)
 67	}
 68
 69	return client
 70}
 71
 72// Start initiates the connection to the server.
 73// Must be called before using the client.
 74func (c *Client) Start(ctx context.Context) error {
 75	if c.transport == nil {
 76		return fmt.Errorf("transport is nil")
 77	}
 78	err := c.transport.Start(ctx)
 79	if err != nil {
 80		return err
 81	}
 82
 83	c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) {
 84		c.notifyMu.RLock()
 85		defer c.notifyMu.RUnlock()
 86		for _, handler := range c.notifications {
 87			handler(notification)
 88		}
 89	})
 90
 91	// Set up request handler for bidirectional communication (e.g., sampling)
 92	if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
 93		bidirectional.SetRequestHandler(c.handleIncomingRequest)
 94	}
 95
 96	return nil
 97}
 98
 99// Close shuts down the client and closes the transport.
100func (c *Client) Close() error {
101	return c.transport.Close()
102}
103
104// OnNotification registers a handler function to be called when notifications are received.
105// Multiple handlers can be registered and will be called in the order they were added.
106func (c *Client) OnNotification(
107	handler func(notification mcp.JSONRPCNotification),
108) {
109	c.notifyMu.Lock()
110	defer c.notifyMu.Unlock()
111	c.notifications = append(c.notifications, handler)
112}
113
114// sendRequest sends a JSON-RPC request to the server and waits for a response.
115// Returns the raw JSON response message or an error if the request fails.
116func (c *Client) sendRequest(
117	ctx context.Context,
118	method string,
119	params any,
120) (*json.RawMessage, error) {
121	if !c.initialized && method != "initialize" {
122		return nil, fmt.Errorf("client not initialized")
123	}
124
125	id := c.requestID.Add(1)
126
127	request := transport.JSONRPCRequest{
128		JSONRPC: mcp.JSONRPC_VERSION,
129		ID:      mcp.NewRequestId(id),
130		Method:  method,
131		Params:  params,
132	}
133
134	response, err := c.transport.SendRequest(ctx, request)
135	if err != nil {
136		return nil, fmt.Errorf("transport error: %w", err)
137	}
138
139	if response.Error != nil {
140		return nil, errors.New(response.Error.Message)
141	}
142
143	return &response.Result, nil
144}
145
146// Initialize negotiates with the server.
147// Must be called after Start, and before any request methods.
148func (c *Client) Initialize(
149	ctx context.Context,
150	request mcp.InitializeRequest,
151) (*mcp.InitializeResult, error) {
152	// Merge client capabilities with sampling capability if handler is configured
153	capabilities := request.Params.Capabilities
154	if c.samplingHandler != nil {
155		capabilities.Sampling = &struct{}{}
156	}
157
158	// Ensure we send a params object with all required fields
159	params := struct {
160		ProtocolVersion string                 `json:"protocolVersion"`
161		ClientInfo      mcp.Implementation     `json:"clientInfo"`
162		Capabilities    mcp.ClientCapabilities `json:"capabilities"`
163	}{
164		ProtocolVersion: request.Params.ProtocolVersion,
165		ClientInfo:      request.Params.ClientInfo,
166		Capabilities:    capabilities,
167	}
168
169	response, err := c.sendRequest(ctx, "initialize", params)
170	if err != nil {
171		return nil, err
172	}
173
174	var result mcp.InitializeResult
175	if err := json.Unmarshal(*response, &result); err != nil {
176		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
177	}
178
179	// Store serverCapabilities
180	c.serverCapabilities = result.Capabilities
181
182	// Send initialized notification
183	notification := mcp.JSONRPCNotification{
184		JSONRPC: mcp.JSONRPC_VERSION,
185		Notification: mcp.Notification{
186			Method: "notifications/initialized",
187		},
188	}
189
190	err = c.transport.SendNotification(ctx, notification)
191	if err != nil {
192		return nil, fmt.Errorf(
193			"failed to send initialized notification: %w",
194			err,
195		)
196	}
197
198	c.initialized = true
199	return &result, nil
200}
201
202func (c *Client) Ping(ctx context.Context) error {
203	_, err := c.sendRequest(ctx, "ping", nil)
204	return err
205}
206
207// ListResourcesByPage manually list resources by page.
208func (c *Client) ListResourcesByPage(
209	ctx context.Context,
210	request mcp.ListResourcesRequest,
211) (*mcp.ListResourcesResult, error) {
212	result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list")
213	if err != nil {
214		return nil, err
215	}
216	return result, nil
217}
218
219func (c *Client) ListResources(
220	ctx context.Context,
221	request mcp.ListResourcesRequest,
222) (*mcp.ListResourcesResult, error) {
223	result, err := c.ListResourcesByPage(ctx, request)
224	if err != nil {
225		return nil, err
226	}
227	for result.NextCursor != "" {
228		select {
229		case <-ctx.Done():
230			return nil, ctx.Err()
231		default:
232			request.Params.Cursor = result.NextCursor
233			newPageRes, err := c.ListResourcesByPage(ctx, request)
234			if err != nil {
235				return nil, err
236			}
237			result.Resources = append(result.Resources, newPageRes.Resources...)
238			result.NextCursor = newPageRes.NextCursor
239		}
240	}
241	return result, nil
242}
243
244func (c *Client) ListResourceTemplatesByPage(
245	ctx context.Context,
246	request mcp.ListResourceTemplatesRequest,
247) (*mcp.ListResourceTemplatesResult, error) {
248	result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list")
249	if err != nil {
250		return nil, err
251	}
252	return result, nil
253}
254
255func (c *Client) ListResourceTemplates(
256	ctx context.Context,
257	request mcp.ListResourceTemplatesRequest,
258) (*mcp.ListResourceTemplatesResult, error) {
259	result, err := c.ListResourceTemplatesByPage(ctx, request)
260	if err != nil {
261		return nil, err
262	}
263	for result.NextCursor != "" {
264		select {
265		case <-ctx.Done():
266			return nil, ctx.Err()
267		default:
268			request.Params.Cursor = result.NextCursor
269			newPageRes, err := c.ListResourceTemplatesByPage(ctx, request)
270			if err != nil {
271				return nil, err
272			}
273			result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...)
274			result.NextCursor = newPageRes.NextCursor
275		}
276	}
277	return result, nil
278}
279
280func (c *Client) ReadResource(
281	ctx context.Context,
282	request mcp.ReadResourceRequest,
283) (*mcp.ReadResourceResult, error) {
284	response, err := c.sendRequest(ctx, "resources/read", request.Params)
285	if err != nil {
286		return nil, err
287	}
288
289	return mcp.ParseReadResourceResult(response)
290}
291
292func (c *Client) Subscribe(
293	ctx context.Context,
294	request mcp.SubscribeRequest,
295) error {
296	_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
297	return err
298}
299
300func (c *Client) Unsubscribe(
301	ctx context.Context,
302	request mcp.UnsubscribeRequest,
303) error {
304	_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
305	return err
306}
307
308func (c *Client) ListPromptsByPage(
309	ctx context.Context,
310	request mcp.ListPromptsRequest,
311) (*mcp.ListPromptsResult, error) {
312	result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list")
313	if err != nil {
314		return nil, err
315	}
316	return result, nil
317}
318
319func (c *Client) ListPrompts(
320	ctx context.Context,
321	request mcp.ListPromptsRequest,
322) (*mcp.ListPromptsResult, error) {
323	result, err := c.ListPromptsByPage(ctx, request)
324	if err != nil {
325		return nil, err
326	}
327	for result.NextCursor != "" {
328		select {
329		case <-ctx.Done():
330			return nil, ctx.Err()
331		default:
332			request.Params.Cursor = result.NextCursor
333			newPageRes, err := c.ListPromptsByPage(ctx, request)
334			if err != nil {
335				return nil, err
336			}
337			result.Prompts = append(result.Prompts, newPageRes.Prompts...)
338			result.NextCursor = newPageRes.NextCursor
339		}
340	}
341	return result, nil
342}
343
344func (c *Client) GetPrompt(
345	ctx context.Context,
346	request mcp.GetPromptRequest,
347) (*mcp.GetPromptResult, error) {
348	response, err := c.sendRequest(ctx, "prompts/get", request.Params)
349	if err != nil {
350		return nil, err
351	}
352
353	return mcp.ParseGetPromptResult(response)
354}
355
356func (c *Client) ListToolsByPage(
357	ctx context.Context,
358	request mcp.ListToolsRequest,
359) (*mcp.ListToolsResult, error) {
360	result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list")
361	if err != nil {
362		return nil, err
363	}
364	return result, nil
365}
366
367func (c *Client) ListTools(
368	ctx context.Context,
369	request mcp.ListToolsRequest,
370) (*mcp.ListToolsResult, error) {
371	result, err := c.ListToolsByPage(ctx, request)
372	if err != nil {
373		return nil, err
374	}
375	for result.NextCursor != "" {
376		select {
377		case <-ctx.Done():
378			return nil, ctx.Err()
379		default:
380			request.Params.Cursor = result.NextCursor
381			newPageRes, err := c.ListToolsByPage(ctx, request)
382			if err != nil {
383				return nil, err
384			}
385			result.Tools = append(result.Tools, newPageRes.Tools...)
386			result.NextCursor = newPageRes.NextCursor
387		}
388	}
389	return result, nil
390}
391
392func (c *Client) CallTool(
393	ctx context.Context,
394	request mcp.CallToolRequest,
395) (*mcp.CallToolResult, error) {
396	response, err := c.sendRequest(ctx, "tools/call", request.Params)
397	if err != nil {
398		return nil, err
399	}
400
401	return mcp.ParseCallToolResult(response)
402}
403
404func (c *Client) SetLevel(
405	ctx context.Context,
406	request mcp.SetLevelRequest,
407) error {
408	_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
409	return err
410}
411
412func (c *Client) Complete(
413	ctx context.Context,
414	request mcp.CompleteRequest,
415) (*mcp.CompleteResult, error) {
416	response, err := c.sendRequest(ctx, "completion/complete", request.Params)
417	if err != nil {
418		return nil, err
419	}
420
421	var result mcp.CompleteResult
422	if err := json.Unmarshal(*response, &result); err != nil {
423		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
424	}
425
426	return &result, nil
427}
428
429// handleIncomingRequest processes incoming requests from the server.
430// This is the main entry point for server-to-client requests like sampling.
431func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
432	switch request.Method {
433	case string(mcp.MethodSamplingCreateMessage):
434		return c.handleSamplingRequestTransport(ctx, request)
435	default:
436		return nil, fmt.Errorf("unsupported request method: %s", request.Method)
437	}
438}
439
440// handleSamplingRequestTransport handles sampling requests at the transport level.
441func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
442	if c.samplingHandler == nil {
443		return nil, fmt.Errorf("no sampling handler configured")
444	}
445
446	// Parse the request parameters
447	var params mcp.CreateMessageParams
448	if request.Params != nil {
449		paramsBytes, err := json.Marshal(request.Params)
450		if err != nil {
451			return nil, fmt.Errorf("failed to marshal params: %w", err)
452		}
453		if err := json.Unmarshal(paramsBytes, &params); err != nil {
454			return nil, fmt.Errorf("failed to unmarshal params: %w", err)
455		}
456	}
457
458	// Create the MCP request
459	mcpRequest := mcp.CreateMessageRequest{
460		Request: mcp.Request{
461			Method: string(mcp.MethodSamplingCreateMessage),
462		},
463		CreateMessageParams: params,
464	}
465
466	// Call the sampling handler
467	result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
468	if err != nil {
469		return nil, err
470	}
471
472	// Marshal the result
473	resultBytes, err := json.Marshal(result)
474	if err != nil {
475		return nil, fmt.Errorf("failed to marshal result: %w", err)
476	}
477
478	// Create the transport response
479	response := &transport.JSONRPCResponse{
480		JSONRPC: mcp.JSONRPC_VERSION,
481		ID:      request.ID,
482		Result:  json.RawMessage(resultBytes),
483	}
484
485	return response, nil
486}
487func listByPage[T any](
488	ctx context.Context,
489	client *Client,
490	request mcp.PaginatedRequest,
491	method string,
492) (*T, error) {
493	response, err := client.sendRequest(ctx, method, request.Params)
494	if err != nil {
495		return nil, err
496	}
497	var result T
498	if err := json.Unmarshal(*response, &result); err != nil {
499		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
500	}
501	return &result, nil
502}
503
504// Helper methods
505
506// GetTransport gives access to the underlying transport layer.
507// Cast it to the specific transport type and obtain the other helper methods.
508func (c *Client) GetTransport() transport.Interface {
509	return c.transport
510}
511
512// GetServerCapabilities returns the server capabilities.
513func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
514	return c.serverCapabilities
515}
516
517// GetClientCapabilities returns the client capabilities.
518func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
519	return c.clientCapabilities
520}
521
522// GetSessionId returns the session ID of the transport.
523// If the transport does not support sessions, it returns an empty string.
524func (c *Client) GetSessionId() string {
525	if c.transport == nil {
526		return ""
527	}
528	return c.transport.GetSessionId()
529}
530
531// IsInitialized returns true if the client has been initialized.
532func (c *Client) IsInitialized() bool {
533	return c.initialized
534}