mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/llm/tools"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/modelcontextprotocol/go-sdk/mcp"
 27)
 28
 29// MCPState represents the current state of an MCP client
 30type MCPState int
 31
 32const (
 33	MCPStateDisabled MCPState = iota
 34	MCPStateStarting
 35	MCPStateConnected
 36	MCPStateError
 37)
 38
 39func (s MCPState) String() string {
 40	switch s {
 41	case MCPStateDisabled:
 42		return "disabled"
 43	case MCPStateStarting:
 44		return "starting"
 45	case MCPStateConnected:
 46		return "connected"
 47	case MCPStateError:
 48		return "error"
 49	default:
 50		return "unknown"
 51	}
 52}
 53
 54// MCPEventType represents the type of MCP event
 55type MCPEventType string
 56
 57const (
 58	MCPEventToolsListChanged   MCPEventType = "tools_list_changed"
 59	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
 60)
 61
 62// MCPEvent represents an event in the MCP system
 63type MCPEvent struct {
 64	Type   MCPEventType
 65	Name   string
 66	State  MCPState
 67	Error  error
 68	Counts MCPCounts
 69}
 70
 71// MCPCounts number of available tools, prompts, etc.
 72type MCPCounts struct {
 73	Tools   int
 74	Prompts int
 75}
 76
 77// MCPClientInfo holds information about an MCP client's state
 78type MCPClientInfo struct {
 79	Name        string
 80	State       MCPState
 81	Error       error
 82	Client      *mcp.ClientSession
 83	Counts      MCPCounts
 84	ConnectedAt time.Time
 85}
 86
 87var (
 88	mcpToolsOnce      sync.Once
 89	mcpTools          = csync.NewMap[string, tools.BaseTool]()
 90	mcpClient2Tools   = csync.NewMap[string, []tools.BaseTool]()
 91	mcpClients        = csync.NewMap[string, *mcp.ClientSession]()
 92	mcpStates         = csync.NewMap[string, MCPClientInfo]()
 93	mcpBroker         = pubsub.NewBroker[MCPEvent]()
 94	mcpPrompts        = csync.NewMap[string, *mcp.Prompt]()
 95	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
 96)
 97
 98type McpTool struct {
 99	mcpName     string
100	tool        *mcp.Tool
101	permissions permission.Service
102	workingDir  string
103}
104
105func (b *McpTool) Name() string {
106	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
107}
108
109func (b *McpTool) Info() tools.ToolInfo {
110	parameters := make(map[string]any)
111	required := make([]string, 0)
112
113	if input, ok := b.tool.InputSchema.(map[string]any); ok {
114		if props, ok := input["properties"].(map[string]any); ok {
115			parameters = props
116		}
117		if req, ok := input["required"].([]any); ok {
118			// Convert []any -> []string when elements are strings
119			for _, v := range req {
120				if s, ok := v.(string); ok {
121					required = append(required, s)
122				}
123			}
124		} else if reqStr, ok := input["required"].([]string); ok {
125			// Handle case where it's already []string
126			required = reqStr
127		}
128	}
129
130	return tools.ToolInfo{
131		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
132		Description: b.tool.Description,
133		Parameters:  parameters,
134		Required:    required,
135	}
136}
137
138func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
139	var args map[string]any
140	if err := json.Unmarshal([]byte(input), &args); err != nil {
141		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
142	}
143
144	c, err := getOrRenewClient(ctx, name)
145	if err != nil {
146		return tools.NewTextErrorResponse(err.Error()), nil
147	}
148	result, err := c.CallTool(ctx, &mcp.CallToolParams{
149		Name:      toolName,
150		Arguments: args,
151	})
152	if err != nil {
153		return tools.NewTextErrorResponse(err.Error()), nil
154	}
155
156	output := make([]string, 0, len(result.Content))
157	for _, v := range result.Content {
158		if vv, ok := v.(*mcp.TextContent); ok {
159			output = append(output, vv.Text)
160		} else {
161			output = append(output, fmt.Sprintf("%v", v))
162		}
163	}
164	return tools.NewTextResponse(strings.Join(output, "\n")), nil
165}
166
167func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
168	sess, ok := mcpClients.Get(name)
169	if !ok {
170		return nil, fmt.Errorf("mcp '%s' not available", name)
171	}
172
173	cfg := config.Get()
174	m := cfg.MCP[name]
175	state, _ := mcpStates.Get(name)
176
177	timeout := mcpTimeout(m)
178	pingCtx, cancel := context.WithTimeout(ctx, timeout)
179	defer cancel()
180	err := sess.Ping(pingCtx, nil)
181	if err == nil {
182		return sess, nil
183	}
184	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
185
186	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
187	if err != nil {
188		return nil, err
189	}
190
191	updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
192	mcpClients.Set(name, sess)
193	return sess, nil
194}
195
196func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
197	sessionID, messageID := tools.GetContextValues(ctx)
198	if sessionID == "" || messageID == "" {
199		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
200	}
201	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
202	p := b.permissions.Request(
203		permission.CreatePermissionRequest{
204			SessionID:   sessionID,
205			ToolCallID:  params.ID,
206			Path:        b.workingDir,
207			ToolName:    b.Info().Name,
208			Action:      "execute",
209			Description: permissionDescription,
210			Params:      params.Input,
211		},
212	)
213	if !p {
214		return tools.ToolResponse{}, permission.ErrorPermissionDenied
215	}
216
217	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
218}
219
220func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
221	if c.InitializeResult().Capabilities.Tools == nil {
222		return nil, nil
223	}
224	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
225	if err != nil {
226		return nil, err
227	}
228	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
229	for _, tool := range result.Tools {
230		mcpTools = append(mcpTools, &McpTool{
231			mcpName:     name,
232			tool:        tool,
233			permissions: permissions,
234			workingDir:  workingDir,
235		})
236	}
237	return mcpTools, nil
238}
239
240// SubscribeMCPEvents returns a channel for MCP events
241func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
242	return mcpBroker.Subscribe(ctx)
243}
244
245// GetMCPStates returns the current state of all MCP clients
246func GetMCPStates() map[string]MCPClientInfo {
247	return maps.Collect(mcpStates.Seq2())
248}
249
250// GetMCPState returns the state of a specific MCP client
251func GetMCPState(name string) (MCPClientInfo, bool) {
252	return mcpStates.Get(name)
253}
254
255// updateMCPState updates the state of an MCP client and publishes an event
256func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
257	info := MCPClientInfo{
258		Name:   name,
259		State:  state,
260		Error:  err,
261		Client: client,
262		Counts: counts,
263	}
264	switch state {
265	case MCPStateConnected:
266		info.ConnectedAt = time.Now()
267	case MCPStateError:
268		updateMcpTools(name, nil)
269		updateMcpPrompts(name, nil)
270		mcpClients.Del(name)
271	}
272	mcpStates.Set(name, info)
273}
274
275// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
276func CloseMCPClients() error {
277	var errs []error
278	for name, c := range mcpClients.Seq2() {
279		if err := c.Close(); err != nil &&
280			!errors.Is(err, io.EOF) &&
281			!errors.Is(err, context.Canceled) &&
282			err.Error() != "signal: killed" {
283			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
284		}
285	}
286	mcpBroker.Shutdown()
287	return errors.Join(errs...)
288}
289
290func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
291	var wg sync.WaitGroup
292	// Initialize states for all configured MCPs
293	for name, m := range cfg.MCP {
294		if m.Disabled {
295			updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
296			slog.Debug("skipping disabled mcp", "name", name)
297			continue
298		}
299
300		// Set initial starting state
301		updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
302
303		wg.Add(1)
304		go func(name string, m config.MCPConfig) {
305			defer func() {
306				wg.Done()
307				if r := recover(); r != nil {
308					var err error
309					switch v := r.(type) {
310					case error:
311						err = v
312					case string:
313						err = fmt.Errorf("panic: %s", v)
314					default:
315						err = fmt.Errorf("panic: %v", v)
316					}
317					updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
318					slog.Error("panic in mcp client initialization", "error", err, "name", name)
319				}
320			}()
321
322			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
323			defer cancel()
324
325			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
326			if err != nil {
327				return
328			}
329
330			mcpClients.Set(name, c)
331
332			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
333			if err != nil {
334				slog.Error("error listing tools", "error", err)
335				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
336				c.Close()
337				return
338			}
339
340			prompts, err := getPrompts(ctx, c)
341			if err != nil {
342				slog.Error("error listing prompts", "error", err)
343				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
344				c.Close()
345				return
346			}
347
348			updateMcpTools(name, tools)
349			updateMcpPrompts(name, prompts)
350			mcpClients.Set(name, c)
351			counts := MCPCounts{
352				Tools:   len(tools),
353				Prompts: len(prompts),
354			}
355			updateMCPState(name, MCPStateConnected, nil, c, counts)
356		}(name, m)
357	}
358	wg.Wait()
359}
360
361// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
362func updateMcpTools(mcpName string, tools []tools.BaseTool) {
363	if len(tools) == 0 {
364		mcpClient2Tools.Del(mcpName)
365	} else {
366		mcpClient2Tools.Set(mcpName, tools)
367	}
368	for _, tools := range mcpClient2Tools.Seq2() {
369		for _, t := range tools {
370			mcpTools.Set(t.Name(), t)
371		}
372	}
373}
374
375func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
376	timeout := mcpTimeout(m)
377	mcpCtx, cancel := context.WithCancel(ctx)
378	cancelTimer := time.AfterFunc(timeout, cancel)
379
380	transport, err := createMCPTransport(mcpCtx, m, resolver)
381	if err != nil {
382		updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
383		slog.Error("error creating mcp client", "error", err, "name", name)
384		cancel()
385		cancelTimer.Stop()
386		return nil, err
387	}
388
389	client := mcp.NewClient(
390		&mcp.Implementation{
391			Name:    "crush",
392			Version: version.Version,
393			Title:   "Crush",
394		},
395		&mcp.ClientOptions{
396			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
397				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
398					Type: MCPEventToolsListChanged,
399					Name: name,
400				})
401			},
402			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
403				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
404					Type: MCPEventPromptsListChanged,
405					Name: name,
406				})
407			},
408			KeepAlive: time.Minute * 10,
409		},
410	)
411
412	session, err := client.Connect(mcpCtx, transport, nil)
413	if err != nil {
414		err = maybeStdioErr(err, transport)
415		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
416		slog.Error("error starting mcp client", "error", err, "name", name)
417		cancel()
418		cancelTimer.Stop()
419		return nil, err
420	}
421
422	cancelTimer.Stop()
423	slog.Info("Initialized mcp client", "name", name)
424	return session, nil
425}
426
// maybeTimeoutErr rewrites a context cancellation or deadline error (checked
// with errors.Is) into a readable "timed out after <timeout>" error. Any
// other error, including nil, is returned unchanged.
//
// Both sentinels are needed:
//   - createMCPSession aborts slow initialization by cancelling its context,
//     which yields context.Canceled;
//   - the health-check ping in getOrRenewClient runs under
//     context.WithTimeout, which yields context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
433
434func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
435	switch m.Type {
436	case config.MCPStdio:
437		command, err := resolver.ResolveValue(m.Command)
438		if err != nil {
439			return nil, fmt.Errorf("invalid mcp command: %w", err)
440		}
441		if strings.TrimSpace(command) == "" {
442			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
443		}
444		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
445		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
446		return &mcp.CommandTransport{
447			Command: cmd,
448		}, nil
449	case config.MCPHttp:
450		if strings.TrimSpace(m.URL) == "" {
451			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
452		}
453		client := &http.Client{
454			Transport: &headerRoundTripper{
455				headers: m.ResolvedHeaders(),
456			},
457		}
458		return &mcp.StreamableClientTransport{
459			Endpoint:   m.URL,
460			HTTPClient: client,
461		}, nil
462	case config.MCPSSE:
463		if strings.TrimSpace(m.URL) == "" {
464			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
465		}
466		client := &http.Client{
467			Transport: &headerRoundTripper{
468				headers: m.ResolvedHeaders(),
469			},
470		}
471		return &mcp.SSEClientTransport{
472			Endpoint:   m.URL,
473			HTTPClient: client,
474		}, nil
475	default:
476		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
477	}
478}
479
// headerRoundTripper is an http.RoundTripper that sets a fixed set of
// headers on every outgoing request, then delegates to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers set, overwriting any existing values
// for those keys. The http.RoundTripper contract forbids modifying the
// caller's request, so the headers are applied to a clone. Clone deep-copies
// Header but shares Body.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
490
491func mcpTimeout(m config.MCPConfig) time.Duration {
492	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
493}
494
495func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
496	if c.InitializeResult().Capabilities.Prompts == nil {
497		return nil, nil
498	}
499	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
500	if err != nil {
501		return nil, err
502	}
503	return result.Prompts, nil
504}
505
506// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
507func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
508	if len(prompts) == 0 {
509		mcpClient2Prompts.Del(mcpName)
510	} else {
511		mcpClient2Prompts.Set(mcpName, prompts)
512	}
513	for clientName, prompts := range mcpClient2Prompts.Seq2() {
514		for _, p := range prompts {
515			key := clientName + ":" + p.Name
516			mcpPrompts.Set(key, p)
517		}
518	}
519}
520
521// GetMCPPrompts returns all available MCP prompts.
522func GetMCPPrompts() map[string]*mcp.Prompt {
523	return maps.Collect(mcpPrompts.Seq2())
524}
525
526// GetMCPPrompt returns a specific MCP prompt by name.
527func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
528	return mcpPrompts.Get(name)
529}
530
531// GetMCPPromptsByClient returns all prompts for a specific MCP client.
532func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
533	return mcpClient2Prompts.Get(clientName)
534}
535
536// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
537func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
538	c, err := getOrRenewClient(ctx, clientName)
539	if err != nil {
540		return nil, err
541	}
542
543	return c.GetPrompt(ctx, &mcp.GetPromptParams{
544		Name:      promptName,
545		Arguments: args,
546	})
547}
548
549// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
550// to parse, and the cli will then close it, causing the EOF error.
551// so, if we got an EOF err, and the transport is STDIO, we try to exec it
552// again with a timeout and collect the output so we can add details to the
553// error.
554// this happens particularly when starting things with npx, e.g. if node can't
555// be found or some other error like that.
556func maybeStdioErr(err error, transport mcp.Transport) error {
557	if !errors.Is(err, io.EOF) {
558		return err
559	}
560	ct, ok := transport.(*mcp.CommandTransport)
561	if !ok {
562		return err
563	}
564	if err2 := stdioMCPCheck(ct.Command); err2 != nil {
565		err = errors.Join(err, err2)
566	}
567	return err
568}
569
// stdioMCPCheck re-runs the command described by old with a 5 second
// timeout. The re-run uses the same path, argv, environment and working
// directory.
//
// If the command fails, stdioMCPCheck returns the command's error together
// with its combined stdout/stderr output. It returns nil if the command
// succeeds or is still running when the timeout expires.
func stdioMCPCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// old.Args is the full argv, including the command as Args[0]. Passing it
	// as variadic arguments would make the command run as "path name args...",
	// so copy argv as-is instead. An empty Args makes exec use Path as argv[0].
	cmd := exec.CommandContext(ctx, old.Path)
	cmd.Args = append([]string(nil), old.Args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}