mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"net/http"
 12	"os/exec"
 13	"strings"
 14	"sync"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/home"
 20	"github.com/charmbracelet/crush/internal/llm/tools"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/modelcontextprotocol/go-sdk/mcp"
 25)
 26
 27// MCPState represents the current state of an MCP client
 28type MCPState int
 29
 30const (
 31	MCPStateDisabled MCPState = iota
 32	MCPStateStarting
 33	MCPStateConnected
 34	MCPStateError
 35)
 36
 37func (s MCPState) String() string {
 38	switch s {
 39	case MCPStateDisabled:
 40		return "disabled"
 41	case MCPStateStarting:
 42		return "starting"
 43	case MCPStateConnected:
 44		return "connected"
 45	case MCPStateError:
 46		return "error"
 47	default:
 48		return "unknown"
 49	}
 50}
 51
 52// MCPEventType represents the type of MCP event
 53type MCPEventType string
 54
 55const (
 56	MCPEventStateChanged       MCPEventType = "state_changed"
 57	MCPEventToolsListChanged   MCPEventType = "tools_list_changed"
 58	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
 59)
 60
 61// MCPEvent represents an event in the MCP system
 62type MCPEvent struct {
 63	Type        MCPEventType
 64	Name        string
 65	State       MCPState
 66	Error       error
 67	ToolCount   int
 68	PromptCount int
 69}
 70
 71// MCPClientInfo holds information about an MCP client's state
 72type MCPClientInfo struct {
 73	Name        string
 74	State       MCPState
 75	Error       error
 76	Client      *mcp.ClientSession
 77	ToolCount   int
 78	PromptCount int
 79	ConnectedAt time.Time
 80}
 81
 82var (
 83	mcpToolsOnce      sync.Once
 84	mcpTools          = csync.NewMap[string, tools.BaseTool]()
 85	mcpClient2Tools   = csync.NewMap[string, []tools.BaseTool]()
 86	mcpClients        = csync.NewMap[string, *mcp.ClientSession]()
 87	mcpStates         = csync.NewMap[string, MCPClientInfo]()
 88	mcpBroker         = pubsub.NewBroker[MCPEvent]()
 89	mcpPrompts        = csync.NewMap[string, *mcp.Prompt]()
 90	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
 91)
 92
 93type McpTool struct {
 94	mcpName     string
 95	tool        *mcp.Tool
 96	permissions permission.Service
 97	workingDir  string
 98}
 99
100func (b *McpTool) Name() string {
101	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
102}
103
104func (b *McpTool) Info() tools.ToolInfo {
105	input := b.tool.InputSchema.(map[string]any)
106	required, _ := input["required"].([]string)
107	parameters, _ := input["properties"].(map[string]any)
108	return tools.ToolInfo{
109		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
110		Description: b.tool.Description,
111		Parameters:  parameters,
112		Required:    required,
113	}
114}
115
116func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
117	var args map[string]any
118	if err := json.Unmarshal([]byte(input), &args); err != nil {
119		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
120	}
121
122	c, err := getOrRenewClient(ctx, name)
123	if err != nil {
124		return tools.NewTextErrorResponse(err.Error()), nil
125	}
126	result, err := c.CallTool(ctx, &mcp.CallToolParams{
127		Name:      toolName,
128		Arguments: args,
129	})
130	if err != nil {
131		return tools.NewTextErrorResponse(err.Error()), nil
132	}
133
134	output := make([]string, 0, len(result.Content))
135	for _, v := range result.Content {
136		if vv, ok := v.(*mcp.TextContent); ok {
137			output = append(output, vv.Text)
138		} else {
139			output = append(output, fmt.Sprintf("%v", v))
140		}
141	}
142	return tools.NewTextResponse(strings.Join(output, "\n")), nil
143}
144
// getOrRenewClient returns the live session for the MCP client name. It pings
// the session first. If the ping fails, it marks the client as errored,
// closes the dead session and connects a fresh one. The tools and prompts the
// client exposed before the failure are restored once the new session is up.
//
// It returns an error if the client was never connected, if ctx is done, or
// if reconnecting fails.
func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
	sess, ok := mcpClients.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	cfg := config.Get()
	m := cfg.MCP[name]
	state, _ := mcpStates.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := sess.Ping(pingCtx, nil)
	if err == nil {
		return sess, nil
	}

	// A ping that failed because the caller gave up says nothing about the
	// session's health; do not tear the session down for that.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return nil, ctxErr
	}

	// Marking the client as errored clears its per-client tool and prompt
	// lists, so remember them for after the reconnect.
	knownTools, _ := mcpClient2Tools.Get(name)
	knownPrompts, _ := mcpClient2Prompts.Get(name)

	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount, state.PromptCount)

	// Best-effort release of the dead session and its transport. This runs in
	// the background because closing a stdio transport may wait for the
	// process to exit. The session is passed as an argument so the goroutine
	// closes the old session, not the one assigned below.
	go func(old *mcp.ClientSession) { _ = old.Close() }(sess)

	// The new session must outlive this call. For stdio servers the child
	// process is bound to the context given to createMCPSession, so detach it
	// from ctx's cancellation. CloseMCPClients closes it explicitly.
	sess, err = createMCPSession(context.WithoutCancel(ctx), name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	mcpClients.Set(name, sess)
	updateMcpTools(name, knownTools)
	updateMcpPrompts(name, knownPrompts)
	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount, state.PromptCount)
	return sess, nil
}
173
// Run asks the permission service for approval and then executes the tool on
// its MCP server with params.Input as the arguments.
//
// It returns an error if ctx lacks a session or message ID, and
// permission.ErrorPermissionDenied if the request is refused.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, errors.New("session ID and message ID are required for running an MCP tool")
	}

	toolName := b.Name()
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters:", toolName),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
197
198func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
199	if c.InitializeResult().Capabilities.Tools == nil {
200		return nil, nil
201	}
202	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
203	if err != nil {
204		return nil, err
205	}
206	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
207	for _, tool := range result.Tools {
208		mcpTools = append(mcpTools, &McpTool{
209			mcpName:     name,
210			tool:        tool,
211			permissions: permissions,
212			workingDir:  workingDir,
213		})
214	}
215	return mcpTools, nil
216}
217
// SubscribeMCPEvents subscribes to the MCP event broker and returns the
// channel on which MCPEvents are delivered. The subscription's lifetime is
// defined by the broker; presumably it ends when ctx is done.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}
222
223// GetMCPStates returns the current state of all MCP clients
224func GetMCPStates() map[string]MCPClientInfo {
225	return maps.Collect(mcpStates.Seq2())
226}
227
228// GetMCPState returns the state of a specific MCP client
229func GetMCPState(name string) (MCPClientInfo, bool) {
230	return mcpStates.Get(name)
231}
232
233// updateMCPState updates the state of an MCP client and publishes an event
234func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount, promptCount int) {
235	info := MCPClientInfo{
236		Name:        name,
237		State:       state,
238		Error:       err,
239		Client:      client,
240		ToolCount:   toolCount,
241		PromptCount: promptCount,
242	}
243	switch state {
244	case MCPStateConnected:
245		info.ConnectedAt = time.Now()
246	case MCPStateError:
247		updateMcpTools(name, nil)
248		updateMcpPrompts(name, nil)
249		mcpClients.Del(name)
250	}
251	mcpStates.Set(name, info)
252
253	// Publish state change event
254	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
255		Type:        MCPEventStateChanged,
256		Name:        name,
257		State:       state,
258		Error:       err,
259		ToolCount:   toolCount,
260		PromptCount: promptCount,
261	})
262}
263
// CloseMCPClients closes every registered MCP session and then shuts down the
// event broker. Call it during application shutdown. Failures to close
// individual sessions are collected and returned together.
func CloseMCPClients() error {
	var errs []error
	for name, session := range mcpClients.Seq2() {
		err := session.Close()
		if err == nil {
			continue
		}
		errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
	}
	mcpBroker.Shutdown()
	return errors.Join(errs...)
}
275
// doGetMCPTools connects concurrently to every enabled MCP server in cfg and
// registers the tools and prompts each one exposes. Every configured server
// gets a recorded state: disabled, or starting and then connected or error.
// The function returns once every connection attempt has finished.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
	var wg sync.WaitGroup
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		updateMCPState(name, MCPStateStarting, nil, nil, 0, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred first so it runs last: a recovered panic must be
			// recorded as an error state before wg.Wait returns.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			// createMCPSession enforces the connect timeout itself. Do not
			// pass it a context that is cancelled when this goroutine returns:
			// the session lives on that context, and a stdio server's process
			// (started with exec.CommandContext) would be killed with it.
			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
			if err != nil {
				// createMCPSession has already recorded the error state.
				return
			}

			// Only the listing calls are bounded by this timeout.
			listCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
			defer cancel()

			serverTools, err := getTools(listCtx, name, permissions, c, cfg.WorkingDir())
			if err != nil {
				slog.Error("error listing tools", "error", err, "name", name)
				updateMCPState(name, MCPStateError, err, nil, 0, 0)
				_ = c.Close() // best effort; the listing error is what gets reported
				return
			}

			prompts, err := getPrompts(listCtx, c)
			if err != nil {
				slog.Error("error listing prompts", "error", err, "name", name)
				updateMCPState(name, MCPStateError, err, nil, 0, 0)
				_ = c.Close() // best effort; the listing error is what gets reported
				return
			}

			updateMcpTools(name, serverTools)
			updateMcpPrompts(name, prompts)
			mcpClients.Set(name, c)
			updateMCPState(name, MCPStateConnected, nil, c, len(serverTools), len(prompts))
		}(name, m)
	}
	wg.Wait()
}
342
343// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
344func updateMcpTools(mcpName string, tools []tools.BaseTool) {
345	if len(tools) == 0 {
346		mcpClient2Tools.Del(mcpName)
347	} else {
348		mcpClient2Tools.Set(mcpName, tools)
349	}
350	for _, tools := range mcpClient2Tools.Seq2() {
351		for _, t := range tools {
352			mcpTools.Set(t.Name(), t)
353		}
354	}
355}
356
// createMCPSession connects to the MCP server described by m and returns the
// initialized session. Transport setup plus the MCP handshake must finish
// within mcpTimeout(m). On failure the client's state is set to
// MCPStateError and the error is returned.
//
// The session lives on a context derived from ctx. For stdio servers,
// cancelling ctx later kills the child process, because the command is
// created with exec.CommandContext.
func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
	timeout := mcpTimeout(m)
	// mcpCtx outlives this call on success. The timer cancels it only if the
	// connection is not established in time, and is stopped once connected.
	mcpCtx, cancel := context.WithCancel(ctx)
	cancelTimer := time.AfterFunc(timeout, cancel)
	// abort releases the connect-phase timer and context on failure.
	abort := func() {
		cancelTimer.Stop()
		cancel()
	}

	transport, err := createMCPTransport(mcpCtx, m, resolver)
	if err != nil {
		abort()
		updateMCPState(name, MCPStateError, err, nil, 0, 0)
		slog.Error("error creating mcp client", "error", err, "name", name)
		return nil, err
	}

	client := mcp.NewClient(
		&mcp.Implementation{
			Name:    "crush",
			Version: version.Version,
			Title:   "Crush",
		},
		&mcp.ClientOptions{
			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
					Type: MCPEventToolsListChanged,
					Name: name,
				})
			},
			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
					Type: MCPEventPromptsListChanged,
					Name: name,
				})
			},
			KeepAlive: time.Minute * 10,
		},
	)

	session, err := client.Connect(mcpCtx, transport, nil)
	if err != nil {
		// Connect does not hand back a usable session on error, so there is
		// nothing to Close (the old code dereferenced a nil session here).
		// Cancelling mcpCtx releases whatever the transport started.
		abort()
		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0, 0)
		slog.Error("error starting mcp client", "error", err, "name", name)
		return nil, err
	}

	cancelTimer.Stop()
	slog.Info("Initialized mcp client", "name", name)
	return session, nil
}
405
// maybeTimeoutErr turns a context cancellation or deadline error into a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// Both sentinels are checked. createMCPSession cancels its context from a
// timer (context.Canceled), whereas operations bounded by context.WithTimeout,
// such as the ping in getOrRenewClient, fail with context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
412
// createMCPTransport builds the transport described by m:
//   - stdio: spawns the resolved command, bound to ctx.
//   - http: uses the streamable HTTP transport.
//   - sse: uses the SSE transport.
//
// Both HTTP-based transports inject the configured headers. It returns an
// error for a missing command or URL, or for an unsupported type.
func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
	// headerClient builds an HTTP client that adds m's resolved headers.
	headerClient := func() *http.Client {
		return &http.Client{
			Transport: &headerRoundTripper{headers: m.ResolvedHeaders()},
		}
	}

	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		if err != nil {
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		}
		if strings.TrimSpace(command) == "" {
			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
		}
		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
		cmd.Env = m.ResolvedEnv()
		return &mcp.CommandTransport{Command: cmd}, nil

	case config.MCPHttp:
		if strings.TrimSpace(m.URL) == "" {
			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
		}
		return &mcp.StreamableClientTransport{Endpoint: m.URL, HTTPClient: headerClient()}, nil

	case config.MCPSSE:
		if strings.TrimSpace(m.URL) == "" {
			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
		}
		return &mcp.SSEClientTransport{Endpoint: m.URL, HTTPClient: headerClient()}, nil
	}

	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
458
459type headerRoundTripper struct {
460	headers map[string]string
461}
462
463func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
464	for k, v := range rt.headers {
465		req.Header.Set(k, v)
466	}
467	return http.DefaultTransport.RoundTrip(req)
468}
469
// mcpTimeout returns the configured per-server timeout (in seconds) as a
// duration, defaulting to 15 seconds when the config leaves it at zero.
func mcpTimeout(m config.MCPConfig) time.Duration {
	seconds := m.Timeout
	if seconds == 0 {
		seconds = 15
	}
	return time.Duration(seconds) * time.Second
}
473
474func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
475	if c.InitializeResult().Capabilities.Prompts == nil {
476		return nil, nil
477	}
478	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
479	if err != nil {
480		return nil, err
481	}
482	return result.Prompts, nil
483}
484
485// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
486func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
487	if len(prompts) == 0 {
488		mcpClient2Prompts.Del(mcpName)
489	} else {
490		mcpClient2Prompts.Set(mcpName, prompts)
491	}
492	for clientName, prompts := range mcpClient2Prompts.Seq2() {
493		for _, p := range prompts {
494			key := clientName + ":" + p.Name
495			mcpPrompts.Set(key, p)
496		}
497	}
498}
499
500// GetMCPPrompts returns all available MCP prompts.
501func GetMCPPrompts() map[string]*mcp.Prompt {
502	return maps.Collect(mcpPrompts.Seq2())
503}
504
505// GetMCPPrompt returns a specific MCP prompt by name.
506func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
507	return mcpPrompts.Get(name)
508}
509
510// GetMCPPromptsByClient returns all prompts for a specific MCP client.
511func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
512	return mcpClient2Prompts.Get(clientName)
513}
514
// GetMCPPromptContent fetches the prompt promptName from the MCP client
// clientName, rendered with args. The session is revalidated, and renewed if
// needed, by getOrRenewClient first.
func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
	session, err := getOrRenewClient(ctx, clientName)
	if err != nil {
		return nil, err
	}

	params := &mcp.GetPromptParams{
		Name:      promptName,
		Arguments: args,
	}
	return session.GetPrompt(ctx, params)
}