client.go

  1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"sync"
  9	"sync/atomic"
 10
 11	"github.com/mark3labs/mcp-go/client/transport"
 12	"github.com/mark3labs/mcp-go/mcp"
 13)
 14
 15// Client implements the MCP client.
 16type Client struct {
 17	transport transport.Interface
 18
 19	initialized        bool
 20	notifications      []func(mcp.JSONRPCNotification)
 21	notifyMu           sync.RWMutex
 22	requestID          atomic.Int64
 23	clientCapabilities mcp.ClientCapabilities
 24	serverCapabilities mcp.ServerCapabilities
 25}
 26
 27type ClientOption func(*Client)
 28
 29// WithClientCapabilities sets the client capabilities for the client.
 30func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
 31	return func(c *Client) {
 32		c.clientCapabilities = capabilities
 33	}
 34}
 35
 36// NewClient creates a new MCP client with the given transport.
 37// Usage:
 38//
 39//	stdio := transport.NewStdio("./mcp_server", nil, "xxx")
 40//	client, err := NewClient(stdio)
 41//	if err != nil {
 42//	    log.Fatalf("Failed to create client: %v", err)
 43//	}
 44func NewClient(transport transport.Interface, options ...ClientOption) *Client {
 45	client := &Client{
 46		transport: transport,
 47	}
 48
 49	for _, opt := range options {
 50		opt(client)
 51	}
 52
 53	return client
 54}
 55
 56// Start initiates the connection to the server.
 57// Must be called before using the client.
 58func (c *Client) Start(ctx context.Context) error {
 59	if c.transport == nil {
 60		return fmt.Errorf("transport is nil")
 61	}
 62	err := c.transport.Start(ctx)
 63	if err != nil {
 64		return err
 65	}
 66
 67	c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) {
 68		c.notifyMu.RLock()
 69		defer c.notifyMu.RUnlock()
 70		for _, handler := range c.notifications {
 71			handler(notification)
 72		}
 73	})
 74	return nil
 75}
 76
 77// Close shuts down the client and closes the transport.
 78func (c *Client) Close() error {
 79	return c.transport.Close()
 80}
 81
 82// OnNotification registers a handler function to be called when notifications are received.
 83// Multiple handlers can be registered and will be called in the order they were added.
 84func (c *Client) OnNotification(
 85	handler func(notification mcp.JSONRPCNotification),
 86) {
 87	c.notifyMu.Lock()
 88	defer c.notifyMu.Unlock()
 89	c.notifications = append(c.notifications, handler)
 90}
 91
 92// sendRequest sends a JSON-RPC request to the server and waits for a response.
 93// Returns the raw JSON response message or an error if the request fails.
 94func (c *Client) sendRequest(
 95	ctx context.Context,
 96	method string,
 97	params any,
 98) (*json.RawMessage, error) {
 99	if !c.initialized && method != "initialize" {
100		return nil, fmt.Errorf("client not initialized")
101	}
102
103	id := c.requestID.Add(1)
104
105	request := transport.JSONRPCRequest{
106		JSONRPC: mcp.JSONRPC_VERSION,
107		ID:      mcp.NewRequestId(id),
108		Method:  method,
109		Params:  params,
110	}
111
112	response, err := c.transport.SendRequest(ctx, request)
113	if err != nil {
114		return nil, fmt.Errorf("transport error: %w", err)
115	}
116
117	if response.Error != nil {
118		return nil, errors.New(response.Error.Message)
119	}
120
121	return &response.Result, nil
122}
123
124// Initialize negotiates with the server.
125// Must be called after Start, and before any request methods.
126func (c *Client) Initialize(
127	ctx context.Context,
128	request mcp.InitializeRequest,
129) (*mcp.InitializeResult, error) {
130	// Ensure we send a params object with all required fields
131	params := struct {
132		ProtocolVersion string                 `json:"protocolVersion"`
133		ClientInfo      mcp.Implementation     `json:"clientInfo"`
134		Capabilities    mcp.ClientCapabilities `json:"capabilities"`
135	}{
136		ProtocolVersion: request.Params.ProtocolVersion,
137		ClientInfo:      request.Params.ClientInfo,
138		Capabilities:    request.Params.Capabilities, // Will be empty struct if not set
139	}
140
141	response, err := c.sendRequest(ctx, "initialize", params)
142	if err != nil {
143		return nil, err
144	}
145
146	var result mcp.InitializeResult
147	if err := json.Unmarshal(*response, &result); err != nil {
148		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
149	}
150
151	// Store serverCapabilities
152	c.serverCapabilities = result.Capabilities
153
154	// Send initialized notification
155	notification := mcp.JSONRPCNotification{
156		JSONRPC: mcp.JSONRPC_VERSION,
157		Notification: mcp.Notification{
158			Method: "notifications/initialized",
159		},
160	}
161
162	err = c.transport.SendNotification(ctx, notification)
163	if err != nil {
164		return nil, fmt.Errorf(
165			"failed to send initialized notification: %w",
166			err,
167		)
168	}
169
170	c.initialized = true
171	return &result, nil
172}
173
174func (c *Client) Ping(ctx context.Context) error {
175	_, err := c.sendRequest(ctx, "ping", nil)
176	return err
177}
178
179// ListResourcesByPage manually list resources by page.
180func (c *Client) ListResourcesByPage(
181	ctx context.Context,
182	request mcp.ListResourcesRequest,
183) (*mcp.ListResourcesResult, error) {
184	result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list")
185	if err != nil {
186		return nil, err
187	}
188	return result, nil
189}
190
191func (c *Client) ListResources(
192	ctx context.Context,
193	request mcp.ListResourcesRequest,
194) (*mcp.ListResourcesResult, error) {
195	result, err := c.ListResourcesByPage(ctx, request)
196	if err != nil {
197		return nil, err
198	}
199	for result.NextCursor != "" {
200		select {
201		case <-ctx.Done():
202			return nil, ctx.Err()
203		default:
204			request.Params.Cursor = result.NextCursor
205			newPageRes, err := c.ListResourcesByPage(ctx, request)
206			if err != nil {
207				return nil, err
208			}
209			result.Resources = append(result.Resources, newPageRes.Resources...)
210			result.NextCursor = newPageRes.NextCursor
211		}
212	}
213	return result, nil
214}
215
216func (c *Client) ListResourceTemplatesByPage(
217	ctx context.Context,
218	request mcp.ListResourceTemplatesRequest,
219) (*mcp.ListResourceTemplatesResult, error) {
220	result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list")
221	if err != nil {
222		return nil, err
223	}
224	return result, nil
225}
226
227func (c *Client) ListResourceTemplates(
228	ctx context.Context,
229	request mcp.ListResourceTemplatesRequest,
230) (*mcp.ListResourceTemplatesResult, error) {
231	result, err := c.ListResourceTemplatesByPage(ctx, request)
232	if err != nil {
233		return nil, err
234	}
235	for result.NextCursor != "" {
236		select {
237		case <-ctx.Done():
238			return nil, ctx.Err()
239		default:
240			request.Params.Cursor = result.NextCursor
241			newPageRes, err := c.ListResourceTemplatesByPage(ctx, request)
242			if err != nil {
243				return nil, err
244			}
245			result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...)
246			result.NextCursor = newPageRes.NextCursor
247		}
248	}
249	return result, nil
250}
251
252func (c *Client) ReadResource(
253	ctx context.Context,
254	request mcp.ReadResourceRequest,
255) (*mcp.ReadResourceResult, error) {
256	response, err := c.sendRequest(ctx, "resources/read", request.Params)
257	if err != nil {
258		return nil, err
259	}
260
261	return mcp.ParseReadResourceResult(response)
262}
263
264func (c *Client) Subscribe(
265	ctx context.Context,
266	request mcp.SubscribeRequest,
267) error {
268	_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
269	return err
270}
271
272func (c *Client) Unsubscribe(
273	ctx context.Context,
274	request mcp.UnsubscribeRequest,
275) error {
276	_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
277	return err
278}
279
280func (c *Client) ListPromptsByPage(
281	ctx context.Context,
282	request mcp.ListPromptsRequest,
283) (*mcp.ListPromptsResult, error) {
284	result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list")
285	if err != nil {
286		return nil, err
287	}
288	return result, nil
289}
290
291func (c *Client) ListPrompts(
292	ctx context.Context,
293	request mcp.ListPromptsRequest,
294) (*mcp.ListPromptsResult, error) {
295	result, err := c.ListPromptsByPage(ctx, request)
296	if err != nil {
297		return nil, err
298	}
299	for result.NextCursor != "" {
300		select {
301		case <-ctx.Done():
302			return nil, ctx.Err()
303		default:
304			request.Params.Cursor = result.NextCursor
305			newPageRes, err := c.ListPromptsByPage(ctx, request)
306			if err != nil {
307				return nil, err
308			}
309			result.Prompts = append(result.Prompts, newPageRes.Prompts...)
310			result.NextCursor = newPageRes.NextCursor
311		}
312	}
313	return result, nil
314}
315
316func (c *Client) GetPrompt(
317	ctx context.Context,
318	request mcp.GetPromptRequest,
319) (*mcp.GetPromptResult, error) {
320	response, err := c.sendRequest(ctx, "prompts/get", request.Params)
321	if err != nil {
322		return nil, err
323	}
324
325	return mcp.ParseGetPromptResult(response)
326}
327
328func (c *Client) ListToolsByPage(
329	ctx context.Context,
330	request mcp.ListToolsRequest,
331) (*mcp.ListToolsResult, error) {
332	result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list")
333	if err != nil {
334		return nil, err
335	}
336	return result, nil
337}
338
339func (c *Client) ListTools(
340	ctx context.Context,
341	request mcp.ListToolsRequest,
342) (*mcp.ListToolsResult, error) {
343	result, err := c.ListToolsByPage(ctx, request)
344	if err != nil {
345		return nil, err
346	}
347	for result.NextCursor != "" {
348		select {
349		case <-ctx.Done():
350			return nil, ctx.Err()
351		default:
352			request.Params.Cursor = result.NextCursor
353			newPageRes, err := c.ListToolsByPage(ctx, request)
354			if err != nil {
355				return nil, err
356			}
357			result.Tools = append(result.Tools, newPageRes.Tools...)
358			result.NextCursor = newPageRes.NextCursor
359		}
360	}
361	return result, nil
362}
363
364func (c *Client) CallTool(
365	ctx context.Context,
366	request mcp.CallToolRequest,
367) (*mcp.CallToolResult, error) {
368	response, err := c.sendRequest(ctx, "tools/call", request.Params)
369	if err != nil {
370		return nil, err
371	}
372
373	return mcp.ParseCallToolResult(response)
374}
375
376func (c *Client) SetLevel(
377	ctx context.Context,
378	request mcp.SetLevelRequest,
379) error {
380	_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
381	return err
382}
383
384func (c *Client) Complete(
385	ctx context.Context,
386	request mcp.CompleteRequest,
387) (*mcp.CompleteResult, error) {
388	response, err := c.sendRequest(ctx, "completion/complete", request.Params)
389	if err != nil {
390		return nil, err
391	}
392
393	var result mcp.CompleteResult
394	if err := json.Unmarshal(*response, &result); err != nil {
395		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
396	}
397
398	return &result, nil
399}
400
401func listByPage[T any](
402	ctx context.Context,
403	client *Client,
404	request mcp.PaginatedRequest,
405	method string,
406) (*T, error) {
407	response, err := client.sendRequest(ctx, method, request.Params)
408	if err != nil {
409		return nil, err
410	}
411	var result T
412	if err := json.Unmarshal(*response, &result); err != nil {
413		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
414	}
415	return &result, nil
416}
417
418// Helper methods
419
420// GetTransport gives access to the underlying transport layer.
421// Cast it to the specific transport type and obtain the other helper methods.
422func (c *Client) GetTransport() transport.Interface {
423	return c.transport
424}
425
426// GetServerCapabilities returns the server capabilities.
427func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
428	return c.serverCapabilities
429}
430
431// GetClientCapabilities returns the client capabilities.
432func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
433	return c.clientCapabilities
434}