mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os/exec"
 14	"strings"
 15	"sync"
 16	"time"
 17
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/home"
 21	"github.com/charmbracelet/crush/internal/llm/tools"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
 53// MCPEventType represents the type of MCP event
 54type MCPEventType string
 55
 56const (
 57	MCPEventToolsListChanged   MCPEventType = "tools_list_changed"
 58	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
 59)
 60
 61// MCPEvent represents an event in the MCP system
 62type MCPEvent struct {
 63	Type   MCPEventType
 64	Name   string
 65	State  MCPState
 66	Error  error
 67	Counts MCPCounts
 68}
 69
 70// MCPCounts number of available tools, prompts, etc.
 71type MCPCounts struct {
 72	Tools   int
 73	Prompts int
 74}
 75
 76// MCPClientInfo holds information about an MCP client's state
 77type MCPClientInfo struct {
 78	Name        string
 79	State       MCPState
 80	Error       error
 81	Client      *mcp.ClientSession
 82	Counts      MCPCounts
 83	ConnectedAt time.Time
 84}
 85
// Package-level registries shared by the MCP functions in this file. The
// goroutines started by doGetMCPTools update them concurrently, so they use
// csync maps rather than plain maps.
var (
	// mcpToolsOnce is not referenced in this part of the file; presumably it
	// guards one-time MCP initialization elsewhere — TODO confirm.
	mcpToolsOnce      sync.Once
	// mcpTools indexes registered MCP tools by their agent-facing name
	// (McpTool.Name, "mcp_<server>_<tool>").
	mcpTools          = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP server name to the tools it exposes.
	mcpClient2Tools   = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients caches the most recent session for each MCP server.
	mcpClients        = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates holds the latest recorded state of each configured MCP server.
	mcpStates         = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker fans MCP events out to SubscribeMCPEvents subscribers.
	mcpBroker         = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts indexes registered prompts by "<server>:<prompt name>".
	mcpPrompts        = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts maps an MCP server name to the prompts it exposes.
	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)
 96
 97type McpTool struct {
 98	mcpName     string
 99	tool        *mcp.Tool
100	permissions permission.Service
101	workingDir  string
102}
103
104func (b *McpTool) Name() string {
105	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
106}
107
108func (b *McpTool) Info() tools.ToolInfo {
109	var parameters map[string]any
110	var required []string
111
112	if input, ok := b.tool.InputSchema.(map[string]any); ok {
113		if props, ok := input["properties"].(map[string]any); ok {
114			parameters = props
115		}
116		if req, ok := input["required"].([]any); ok {
117			// Convert []any -> []string when elements are strings
118			for _, v := range req {
119				if s, ok := v.(string); ok {
120					required = append(required, s)
121				}
122			}
123		} else if reqStr, ok := input["required"].([]string); ok {
124			// Handle case where it's already []string
125			required = reqStr
126		}
127	}
128
129	return tools.ToolInfo{
130		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
131		Description: b.tool.Description,
132		Parameters:  parameters,
133		Required:    required,
134	}
135}
136
137func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
138	var args map[string]any
139	if err := json.Unmarshal([]byte(input), &args); err != nil {
140		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
141	}
142
143	c, err := getOrRenewClient(ctx, name)
144	if err != nil {
145		return tools.NewTextErrorResponse(err.Error()), nil
146	}
147	result, err := c.CallTool(ctx, &mcp.CallToolParams{
148		Name:      toolName,
149		Arguments: args,
150	})
151	if err != nil {
152		return tools.NewTextErrorResponse(err.Error()), nil
153	}
154
155	output := make([]string, 0, len(result.Content))
156	for _, v := range result.Content {
157		if vv, ok := v.(*mcp.TextContent); ok {
158			output = append(output, vv.Text)
159		} else {
160			output = append(output, fmt.Sprintf("%v", v))
161		}
162	}
163	return tools.NewTextResponse(strings.Join(output, "\n")), nil
164}
165
166func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
167	sess, ok := mcpClients.Get(name)
168	if !ok {
169		return nil, fmt.Errorf("mcp '%s' not available", name)
170	}
171
172	cfg := config.Get()
173	m := cfg.MCP[name]
174	state, _ := mcpStates.Get(name)
175
176	timeout := mcpTimeout(m)
177	pingCtx, cancel := context.WithTimeout(ctx, timeout)
178	defer cancel()
179	err := sess.Ping(pingCtx, nil)
180	if err == nil {
181		return sess, nil
182	}
183	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
184
185	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
186	if err != nil {
187		return nil, err
188	}
189
190	updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
191	mcpClients.Set(name, sess)
192	return sess, nil
193}
194
195func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
196	sessionID, messageID := tools.GetContextValues(ctx)
197	if sessionID == "" || messageID == "" {
198		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
199	}
200	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
201	p := b.permissions.Request(
202		permission.CreatePermissionRequest{
203			SessionID:   sessionID,
204			ToolCallID:  params.ID,
205			Path:        b.workingDir,
206			ToolName:    b.Info().Name,
207			Action:      "execute",
208			Description: permissionDescription,
209			Params:      params.Input,
210		},
211	)
212	if !p {
213		return tools.ToolResponse{}, permission.ErrorPermissionDenied
214	}
215
216	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
217}
218
219func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
220	if c.InitializeResult().Capabilities.Tools == nil {
221		return nil, nil
222	}
223	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
224	if err != nil {
225		return nil, err
226	}
227	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
228	for _, tool := range result.Tools {
229		mcpTools = append(mcpTools, &McpTool{
230			mcpName:     name,
231			tool:        tool,
232			permissions: permissions,
233			workingDir:  workingDir,
234		})
235	}
236	return mcpTools, nil
237}
238
239// SubscribeMCPEvents returns a channel for MCP events
240func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
241	return mcpBroker.Subscribe(ctx)
242}
243
244// GetMCPStates returns the current state of all MCP clients
245func GetMCPStates() map[string]MCPClientInfo {
246	return maps.Collect(mcpStates.Seq2())
247}
248
249// GetMCPState returns the state of a specific MCP client
250func GetMCPState(name string) (MCPClientInfo, bool) {
251	return mcpStates.Get(name)
252}
253
254// updateMCPState updates the state of an MCP client and publishes an event
255func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
256	info := MCPClientInfo{
257		Name:   name,
258		State:  state,
259		Error:  err,
260		Client: client,
261		Counts: counts,
262	}
263	switch state {
264	case MCPStateConnected:
265		info.ConnectedAt = time.Now()
266	case MCPStateError:
267		updateMcpTools(name, nil)
268		updateMcpPrompts(name, nil)
269		mcpClients.Del(name)
270	}
271	mcpStates.Set(name, info)
272}
273
274// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
275func CloseMCPClients() error {
276	var errs []error
277	for name, c := range mcpClients.Seq2() {
278		if err := c.Close(); err != nil &&
279			!errors.Is(err, io.EOF) &&
280			!errors.Is(err, context.Canceled) &&
281			err.Error() != "signal: killed" {
282			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
283		}
284	}
285	mcpBroker.Shutdown()
286	return errors.Join(errs...)
287}
288
289func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
290	var wg sync.WaitGroup
291	// Initialize states for all configured MCPs
292	for name, m := range cfg.MCP {
293		if m.Disabled {
294			updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
295			slog.Debug("skipping disabled mcp", "name", name)
296			continue
297		}
298
299		// Set initial starting state
300		updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
301
302		wg.Add(1)
303		go func(name string, m config.MCPConfig) {
304			defer func() {
305				wg.Done()
306				if r := recover(); r != nil {
307					var err error
308					switch v := r.(type) {
309					case error:
310						err = v
311					case string:
312						err = fmt.Errorf("panic: %s", v)
313					default:
314						err = fmt.Errorf("panic: %v", v)
315					}
316					updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
317					slog.Error("panic in mcp client initialization", "error", err, "name", name)
318				}
319			}()
320
321			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
322			defer cancel()
323
324			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
325			if err != nil {
326				return
327			}
328
329			mcpClients.Set(name, c)
330
331			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
332			if err != nil {
333				slog.Error("error listing tools", "error", err)
334				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
335				c.Close()
336				return
337			}
338
339			prompts, err := getPrompts(ctx, c)
340			if err != nil {
341				slog.Error("error listing prompts", "error", err)
342				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
343				c.Close()
344				return
345			}
346
347			updateMcpTools(name, tools)
348			updateMcpPrompts(name, prompts)
349			mcpClients.Set(name, c)
350			counts := MCPCounts{
351				Tools:   len(tools),
352				Prompts: len(prompts),
353			}
354			updateMCPState(name, MCPStateConnected, nil, c, counts)
355		}(name, m)
356	}
357	wg.Wait()
358}
359
360// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
361func updateMcpTools(mcpName string, tools []tools.BaseTool) {
362	if len(tools) == 0 {
363		mcpClient2Tools.Del(mcpName)
364	} else {
365		mcpClient2Tools.Set(mcpName, tools)
366	}
367	for _, tools := range mcpClient2Tools.Seq2() {
368		for _, t := range tools {
369			mcpTools.Set(t.Name(), t)
370		}
371	}
372}
373
374func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
375	timeout := mcpTimeout(m)
376	mcpCtx, cancel := context.WithCancel(ctx)
377	cancelTimer := time.AfterFunc(timeout, cancel)
378
379	transport, err := createMCPTransport(mcpCtx, m, resolver)
380	if err != nil {
381		updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
382		slog.Error("error creating mcp client", "error", err, "name", name)
383		return nil, err
384	}
385
386	client := mcp.NewClient(
387		&mcp.Implementation{
388			Name:    "crush",
389			Version: version.Version,
390			Title:   "Crush",
391		},
392		&mcp.ClientOptions{
393			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
394				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
395					Type: MCPEventToolsListChanged,
396					Name: name,
397				})
398			},
399			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
400				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
401					Type: MCPEventPromptsListChanged,
402					Name: name,
403				})
404			},
405			KeepAlive: time.Minute * 10,
406		},
407	)
408
409	session, err := client.Connect(mcpCtx, transport, nil)
410	if err != nil {
411		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
412		slog.Error("error starting mcp client", "error", err, "name", name)
413		cancel()
414		return nil, err
415	}
416
417	cancelTimer.Stop()
418	slog.Info("Initialized mcp client", "name", name)
419	return session, nil
420}
421
// maybeTimeoutErr rewrites context-expiry errors as a readable
// "timed out after <timeout>" error and returns any other error (including
// nil) unchanged.
//
// Both context.Canceled (how the connect timer in createMCPSession aborts a
// session) and context.DeadlineExceeded (how a context.WithTimeout such as
// the ping in getOrRenewClient expires) are treated as timeouts.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
428
429func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
430	switch m.Type {
431	case config.MCPStdio:
432		command, err := resolver.ResolveValue(m.Command)
433		if err != nil {
434			return nil, fmt.Errorf("invalid mcp command: %w", err)
435		}
436		if strings.TrimSpace(command) == "" {
437			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
438		}
439		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
440		cmd.Env = m.ResolvedEnv()
441		return &mcp.CommandTransport{
442			Command: cmd,
443		}, nil
444	case config.MCPHttp:
445		if strings.TrimSpace(m.URL) == "" {
446			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
447		}
448		client := &http.Client{
449			Transport: &headerRoundTripper{
450				headers: m.ResolvedHeaders(),
451			},
452		}
453		return &mcp.StreamableClientTransport{
454			Endpoint:   m.URL,
455			HTTPClient: client,
456		}, nil
457	case config.MCPSSE:
458		if strings.TrimSpace(m.URL) == "" {
459			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
460		}
461		client := &http.Client{
462			Transport: &headerRoundTripper{
463				headers: m.ResolvedHeaders(),
464			},
465		}
466		return &mcp.SSEClientTransport{
467			Endpoint:   m.URL,
468			HTTPClient: client,
469		}, nil
470	default:
471		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
472	}
473}
474
// headerRoundTripper is an http.RoundTripper that adds a fixed set of
// headers to every outgoing request before delegating to
// http.DefaultTransport. Configured headers replace any existing values for
// the same keys.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip implements http.RoundTripper. The http.RoundTripper contract
// forbids modifying the caller's request, so the headers are applied to a
// clone (which shares the original body).
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
485
486func mcpTimeout(m config.MCPConfig) time.Duration {
487	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
488}
489
490func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
491	if c.InitializeResult().Capabilities.Prompts == nil {
492		return nil, nil
493	}
494	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
495	if err != nil {
496		return nil, err
497	}
498	return result.Prompts, nil
499}
500
501// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
502func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
503	if len(prompts) == 0 {
504		mcpClient2Prompts.Del(mcpName)
505	} else {
506		mcpClient2Prompts.Set(mcpName, prompts)
507	}
508	for clientName, prompts := range mcpClient2Prompts.Seq2() {
509		for _, p := range prompts {
510			key := clientName + ":" + p.Name
511			mcpPrompts.Set(key, p)
512		}
513	}
514}
515
516// GetMCPPrompts returns all available MCP prompts.
517func GetMCPPrompts() map[string]*mcp.Prompt {
518	return maps.Collect(mcpPrompts.Seq2())
519}
520
521// GetMCPPrompt returns a specific MCP prompt by name.
522func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
523	return mcpPrompts.Get(name)
524}
525
526// GetMCPPromptsByClient returns all prompts for a specific MCP client.
527func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
528	return mcpClient2Prompts.Get(clientName)
529}
530
531// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
532func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
533	c, err := getOrRenewClient(ctx, clientName)
534	if err != nil {
535		return nil, err
536	}
537
538	return c.GetPrompt(ctx, &mcp.GetPromptParams{
539		Name:      promptName,
540		Arguments: args,
541	})
542}