mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os/exec"
 14	"strings"
 15	"sync"
 16	"time"
 17
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/home"
 21	"github.com/charmbracelet/crush/internal/llm/tools"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
// MCPEventType identifies the kind of notification published on the MCP
// event broker.
type MCPEventType string

const (
	// MCPEventToolsListChanged is published when a server notifies the client
	// that its tool list changed.
	MCPEventToolsListChanged   MCPEventType = "tools_list_changed"
	// MCPEventPromptsListChanged is published when a server notifies the
	// client that its prompt list changed.
	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)

// MCPEvent is the payload published on the MCP event broker.
type MCPEvent struct {
	// Type is the kind of event.
	Type   MCPEventType
	// Name is the configured name of the MCP client the event concerns.
	Name   string
	// State, Error and Counts are not set by the list-changed events
	// published in this file; NOTE(review): confirm whether any publisher
	// elsewhere fills them.
	State  MCPState
	Error  error
	Counts MCPCounts
}

// MCPCounts holds the number of tools, prompts, and resources an MCP client
// exposes.
type MCPCounts struct {
	Tools     int
	Prompts   int
	Resources int
}

// MCPClientInfo is the recorded state of one MCP client.
type MCPClientInfo struct {
	// Name is the client's configured name.
	Name        string
	// State is the client's current lifecycle state.
	State       MCPState
	// Error is the error that caused MCPStateError; the state transitions in
	// this file pass nil for every other state.
	Error       error
	// Client is the live session when connected; nil for the disabled,
	// starting, and error states recorded in this file.
	Client      *mcp.ClientSession
	// Counts is what the client exposed when it was last recorded.
	Counts      MCPCounts
	// ConnectedAt is set when the client enters MCPStateConnected and is the
	// zero time otherwise.
	ConnectedAt time.Time
}
 86
// Package-level registries shared by the MCP helpers in this file. The maps
// are csync maps, presumably safe for concurrent use by the initialization
// goroutines in doGetMCPTools — confirm against internal/csync.
var (
	// mcpToolsOnce is not referenced in this part of the file; presumably it
	// guards a one-time call to doGetMCPTools — confirm at its call site.
	mcpToolsOnce        sync.Once
	// mcpTools indexes every registered MCP tool by its prefixed Name().
	mcpTools            = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP client name to the tools it exposes.
	mcpClient2Tools     = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP client name to its live session.
	mcpClients          = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP client name to its recorded state.
	mcpStates           = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent notifications to subscribers.
	mcpBroker           = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts indexes every registered prompt by "client:prompt".
	mcpPrompts          = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts maps an MCP client name to the prompts it exposes.
	mcpClient2Prompts   = csync.NewMap[string, []*mcp.Prompt]()
	// mcpResources indexes every registered resource by "client:resource"
	// (resource Name, not URI).
	mcpResources        = csync.NewMap[string, *mcp.Resource]()
	// mcpClient2Resources maps an MCP client name to the resources it exposes.
	mcpClient2Resources = csync.NewMap[string, []*mcp.Resource]()
)
 99
// McpTool exposes a single tool of an MCP server as a tools.BaseTool.
type McpTool struct {
	// mcpName is the MCP client's configured name; it prefixes Name() and
	// selects the session used by runTool.
	mcpName     string
	// tool is the tool definition reported by the server's tool listing.
	tool        *mcp.Tool
	// permissions is asked to approve every execution in Run.
	permissions permission.Service
	// workingDir is reported as the Path of each permission request.
	workingDir  string
}
106
107func (b *McpTool) Name() string {
108	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
109}
110
111func (b *McpTool) Info() tools.ToolInfo {
112	input := b.tool.InputSchema.(map[string]any)
113	required, _ := input["required"].([]string)
114	parameters, _ := input["properties"].(map[string]any)
115	return tools.ToolInfo{
116		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
117		Description: b.tool.Description,
118		Parameters:  parameters,
119		Required:    required,
120	}
121}
122
123func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
124	var args map[string]any
125	if err := json.Unmarshal([]byte(input), &args); err != nil {
126		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
127	}
128
129	c, err := getOrRenewClient(ctx, name)
130	if err != nil {
131		return tools.NewTextErrorResponse(err.Error()), nil
132	}
133	result, err := c.CallTool(ctx, &mcp.CallToolParams{
134		Name:      toolName,
135		Arguments: args,
136	})
137	if err != nil {
138		return tools.NewTextErrorResponse(err.Error()), nil
139	}
140
141	output := make([]string, 0, len(result.Content))
142	for _, v := range result.Content {
143		if vv, ok := v.(*mcp.TextContent); ok {
144			output = append(output, vv.Text)
145		} else {
146			output = append(output, fmt.Sprintf("%v", v))
147		}
148	}
149	return tools.NewTextResponse(strings.Join(output, "\n")), nil
150}
151
152func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
153	sess, ok := mcpClients.Get(name)
154	if !ok {
155		return nil, fmt.Errorf("mcp '%s' not available", name)
156	}
157
158	cfg := config.Get()
159	m := cfg.MCP[name]
160	state, _ := mcpStates.Get(name)
161
162	timeout := mcpTimeout(m)
163	pingCtx, cancel := context.WithTimeout(ctx, timeout)
164	defer cancel()
165	err := sess.Ping(pingCtx, nil)
166	if err == nil {
167		return sess, nil
168	}
169	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
170
171	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
172	if err != nil {
173		return nil, err
174	}
175
176	updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
177	mcpClients.Set(name, sess)
178	return sess, nil
179}
180
181func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
182	sessionID, messageID := tools.GetContextValues(ctx)
183	if sessionID == "" || messageID == "" {
184		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
185	}
186	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
187	p := b.permissions.Request(
188		permission.CreatePermissionRequest{
189			SessionID:   sessionID,
190			ToolCallID:  params.ID,
191			Path:        b.workingDir,
192			ToolName:    b.Info().Name,
193			Action:      "execute",
194			Description: permissionDescription,
195			Params:      params.Input,
196		},
197	)
198	if !p {
199		return tools.ToolResponse{}, permission.ErrorPermissionDenied
200	}
201
202	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
203}
204
205func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
206	if c.InitializeResult().Capabilities.Tools == nil {
207		return nil, nil
208	}
209	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
210	if err != nil {
211		return nil, err
212	}
213	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
214	for _, tool := range result.Tools {
215		mcpTools = append(mcpTools, &McpTool{
216			mcpName:     name,
217			tool:        tool,
218			permissions: permissions,
219			workingDir:  workingDir,
220		})
221	}
222	return mcpTools, nil
223}
224
225// SubscribeMCPEvents returns a channel for MCP events
226func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
227	return mcpBroker.Subscribe(ctx)
228}
229
230// GetMCPStates returns the current state of all MCP clients
231func GetMCPStates() map[string]MCPClientInfo {
232	return maps.Collect(mcpStates.Seq2())
233}
234
235// GetMCPState returns the state of a specific MCP client
236func GetMCPState(name string) (MCPClientInfo, bool) {
237	return mcpStates.Get(name)
238}
239
240// updateMCPState updates the state of an MCP client and publishes an event
241func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
242	info := MCPClientInfo{
243		Name:   name,
244		State:  state,
245		Error:  err,
246		Client: client,
247		Counts: counts,
248	}
249	switch state {
250	case MCPStateConnected:
251		info.ConnectedAt = time.Now()
252	case MCPStateError:
253		updateMcpTools(name, nil)
254		updateMcpPrompts(name, nil)
255		mcpClients.Del(name)
256	}
257	mcpStates.Set(name, info)
258}
259
260// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
261func CloseMCPClients() error {
262	var errs []error
263	for name, c := range mcpClients.Seq2() {
264		if err := c.Close(); err != nil &&
265			!errors.Is(err, io.EOF) &&
266			!errors.Is(err, context.Canceled) &&
267			err.Error() != "signal: killed" {
268			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
269		}
270	}
271	mcpBroker.Shutdown()
272	return errors.Join(errs...)
273}
274
275func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
276	var wg sync.WaitGroup
277	// Initialize states for all configured MCPs
278	for name, m := range cfg.MCP {
279		if m.Disabled {
280			updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
281			slog.Debug("skipping disabled mcp", "name", name)
282			continue
283		}
284
285		// Set initial starting state
286		updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
287
288		wg.Add(1)
289		go func(name string, m config.MCPConfig) {
290			defer func() {
291				wg.Done()
292				if r := recover(); r != nil {
293					var err error
294					switch v := r.(type) {
295					case error:
296						err = v
297					case string:
298						err = fmt.Errorf("panic: %s", v)
299					default:
300						err = fmt.Errorf("panic: %v", v)
301					}
302					updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
303					slog.Error("panic in mcp client initialization", "error", err, "name", name)
304				}
305			}()
306
307			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
308			defer cancel()
309
310			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
311			if err != nil {
312				return
313			}
314
315			mcpClients.Set(name, c)
316
317			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
318			if err != nil {
319				slog.Error("error listing tools", "error", err)
320				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
321				c.Close()
322				return
323			}
324
325			prompts, err := getPrompts(ctx, c)
326			if err != nil {
327				slog.Error("error listing prompts", "error", err)
328				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
329				c.Close()
330				return
331			}
332
333			resources, err := getResources(ctx, c)
334			if err != nil {
335				slog.Error("error listing resources", "error", err)
336				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
337				c.Close()
338				return
339			}
340
341			updateMcpTools(name, tools)
342			updateMcpPrompts(name, prompts)
343			updateMcpResources(name, resources)
344			mcpClients.Set(name, c)
345			counts := MCPCounts{
346				Tools:     len(tools),
347				Prompts:   len(prompts),
348				Resources: len(resources),
349			}
350			updateMCPState(name, MCPStateConnected, nil, c, counts)
351		}(name, m)
352	}
353	wg.Wait()
354}
355
356// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
357func updateMcpTools(mcpName string, tools []tools.BaseTool) {
358	if len(tools) == 0 {
359		mcpClient2Tools.Del(mcpName)
360	} else {
361		mcpClient2Tools.Set(mcpName, tools)
362	}
363	for _, tools := range mcpClient2Tools.Seq2() {
364		for _, t := range tools {
365			mcpTools.Set(t.Name(), t)
366		}
367	}
368}
369
370func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
371	timeout := mcpTimeout(m)
372	mcpCtx, cancel := context.WithCancel(ctx)
373	cancelTimer := time.AfterFunc(timeout, cancel)
374
375	transport, err := createMCPTransport(mcpCtx, m, resolver)
376	if err != nil {
377		updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
378		slog.Error("error creating mcp client", "error", err, "name", name)
379		return nil, err
380	}
381
382	client := mcp.NewClient(
383		&mcp.Implementation{
384			Name:    "crush",
385			Version: version.Version,
386			Title:   "Crush",
387		},
388		&mcp.ClientOptions{
389			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
390				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
391					Type: MCPEventToolsListChanged,
392					Name: name,
393				})
394			},
395			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
396				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
397					Type: MCPEventPromptsListChanged,
398					Name: name,
399				})
400			},
401			KeepAlive: time.Minute * 10,
402		},
403	)
404
405	session, err := client.Connect(mcpCtx, transport, nil)
406	if err != nil {
407		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
408		slog.Error("error starting mcp client", "error", err, "name", name)
409		cancel()
410		return nil, err
411	}
412
413	cancelTimer.Stop()
414	slog.Info("Initialized mcp client", "name", name)
415	return session, nil
416}
417
// maybeTimeoutErr rewrites a context cancellation or deadline error as a
// readable "timed out after <timeout>" error.
//
// Both forms occur here: session setup aborts via cancel (context.Canceled),
// while context.WithTimeout-bounded calls such as Ping fail with
// context.DeadlineExceeded. Any other error, including nil, is returned
// unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
424
425func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
426	switch m.Type {
427	case config.MCPStdio:
428		command, err := resolver.ResolveValue(m.Command)
429		if err != nil {
430			return nil, fmt.Errorf("invalid mcp command: %w", err)
431		}
432		if strings.TrimSpace(command) == "" {
433			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
434		}
435		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
436		cmd.Env = m.ResolvedEnv()
437		return &mcp.CommandTransport{
438			Command: cmd,
439		}, nil
440	case config.MCPHttp:
441		if strings.TrimSpace(m.URL) == "" {
442			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
443		}
444		client := &http.Client{
445			Transport: &headerRoundTripper{
446				headers: m.ResolvedHeaders(),
447			},
448		}
449		return &mcp.StreamableClientTransport{
450			Endpoint:   m.URL,
451			HTTPClient: client,
452		}, nil
453	case config.MCPSSE:
454		if strings.TrimSpace(m.URL) == "" {
455			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
456		}
457		client := &http.Client{
458			Transport: &headerRoundTripper{
459				headers: m.ResolvedHeaders(),
460			},
461		}
462		return &mcp.SSEClientTransport{
463			Endpoint:   m.URL,
464			HTTPClient: client,
465		}, nil
466	default:
467		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
468	}
469}
470
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	// headers are set on each outgoing request, replacing any existing values
	// for the same keys.
	headers map[string]string
}

// RoundTrip sends req with rt.headers applied. The http.RoundTripper
// contract forbids modifying the caller's request, so the headers are set on
// a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
481
482func mcpTimeout(m config.MCPConfig) time.Duration {
483	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
484}
485
486func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
487	if c.InitializeResult().Capabilities.Prompts == nil {
488		return nil, nil
489	}
490	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
491	if err != nil {
492		return nil, err
493	}
494	return result.Prompts, nil
495}
496
497// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
498func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
499	if len(prompts) == 0 {
500		mcpClient2Prompts.Del(mcpName)
501	} else {
502		mcpClient2Prompts.Set(mcpName, prompts)
503	}
504	for clientName, prompts := range mcpClient2Prompts.Seq2() {
505		for _, p := range prompts {
506			key := clientName + ":" + p.Name
507			mcpPrompts.Set(key, p)
508		}
509	}
510}
511
512func getResources(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Resource, error) {
513	if c.InitializeResult().Capabilities.Resources == nil {
514		return nil, nil
515	}
516	result, err := c.ListResources(ctx, &mcp.ListResourcesParams{})
517	if err != nil {
518		return nil, err
519	}
520	return result.Resources, nil
521}
522
523// updateMcpResources updates the global mcpResources and mcpClient2Resources maps.
524func updateMcpResources(mcpName string, resources []*mcp.Resource) {
525	if len(resources) == 0 {
526		mcpClient2Resources.Del(mcpName)
527	} else {
528		mcpClient2Resources.Set(mcpName, resources)
529	}
530	for clientName, resources := range mcpClient2Resources.Seq2() {
531		for _, p := range resources {
532			key := clientName + ":" + p.Name
533			mcpResources.Set(key, p)
534		}
535	}
536}
537
538// GetMCPPrompts returns all available MCP prompts.
539func GetMCPPrompts() map[string]*mcp.Prompt {
540	return maps.Collect(mcpPrompts.Seq2())
541}
542
543// GetMCPPrompt returns a specific MCP prompt by name.
544func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
545	return mcpPrompts.Get(name)
546}
547
548// GetMCPPromptsByClient returns all prompts for a specific MCP client.
549func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
550	return mcpClient2Prompts.Get(clientName)
551}
552
553// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
554func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
555	c, err := getOrRenewClient(ctx, clientName)
556	if err != nil {
557		return nil, err
558	}
559
560	return c.GetPrompt(ctx, &mcp.GetPromptParams{
561		Name:      promptName,
562		Arguments: args,
563	})
564}