mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"net/http"
 12	"os/exec"
 13	"strings"
 14	"sync"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/home"
 20	"github.com/charmbracelet/crush/internal/llm/tools"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/modelcontextprotocol/go-sdk/mcp"
 25)
 26
 27// MCPState represents the current state of an MCP client
 28type MCPState int
 29
 30const (
 31	MCPStateDisabled MCPState = iota
 32	MCPStateStarting
 33	MCPStateConnected
 34	MCPStateError
 35)
 36
 37func (s MCPState) String() string {
 38	switch s {
 39	case MCPStateDisabled:
 40		return "disabled"
 41	case MCPStateStarting:
 42		return "starting"
 43	case MCPStateConnected:
 44		return "connected"
 45	case MCPStateError:
 46		return "error"
 47	default:
 48		return "unknown"
 49	}
 50}
 51
// MCPEventType identifies the kind of event published on the MCP broker.
type MCPEventType string

const (
	// MCPEventStateChanged is published whenever a client's state is updated.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when a server notifies that its
	// tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
	// MCPEventPromptsListChanged is published when a server notifies that its
	// prompt list changed.
	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)
 60
// MCPEvent is the payload published on the MCP broker.
//
// Type and Name are always set. State, Error and Counts are only populated
// for MCPEventStateChanged events; list-changed events leave them zero.
type MCPEvent struct {
	Type   MCPEventType // kind of event
	Name   string       // name of the MCP the event refers to
	State  MCPState     // new state of the client
	Error  error        // error associated with the state, if any
	Counts MCPCounts    // tool/prompt counts reported with the state
}
 69
// MCPCounts holds the number of tools and prompts available from an MCP
// client.
type MCPCounts struct {
	Tools   int // number of tools listed by the server
	Prompts int // number of prompts listed by the server
}
 75
// MCPClientInfo is a snapshot of a single MCP client's state as stored in
// mcpStates.
type MCPClientInfo struct {
	Name   string             // name of the MCP this entry describes
	State  MCPState           // current lifecycle state
	Error  error              // error recorded with the state, if any
	Client *mcp.ClientSession // session passed with the state update; may be nil
	Counts MCPCounts          // tool/prompt counts recorded with the state
	// ConnectedAt is set to the current time when the state becomes
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
 85
// Package-level registries shared by all MCP clients.
var (
	// mcpToolsOnce is not used in the code visible here.
	// NOTE(review): presumably guards one-time tool discovery — confirm at
	// the call site.
	mcpToolsOnce sync.Once
	// mcpTools is a flat index of tools keyed by the tool's Name().
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP name to the tools that MCP provides.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP name to its active session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP name to its most recently recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent notifications to subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts is a flat index of prompts keyed by "<mcp name>:<prompt name>".
	mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts maps an MCP name to the prompts that MCP provides.
	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)
 96
// McpTool exposes a single tool offered by an MCP server, gating each
// execution behind a permission request.
type McpTool struct {
	mcpName     string             // name of the MCP that provides the tool
	tool        *mcp.Tool          // tool definition as listed by the server
	permissions permission.Service // asked for approval before each run
	workingDir  string             // sent as the Path of the permission request
}
103
104func (b *McpTool) Name() string {
105	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
106}
107
108func (b *McpTool) Info() tools.ToolInfo {
109	input := b.tool.InputSchema.(map[string]any)
110	required, _ := input["required"].([]string)
111	parameters, _ := input["properties"].(map[string]any)
112	return tools.ToolInfo{
113		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
114		Description: b.tool.Description,
115		Parameters:  parameters,
116		Required:    required,
117	}
118}
119
120func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
121	var args map[string]any
122	if err := json.Unmarshal([]byte(input), &args); err != nil {
123		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
124	}
125
126	c, err := getOrRenewClient(ctx, name)
127	if err != nil {
128		return tools.NewTextErrorResponse(err.Error()), nil
129	}
130	result, err := c.CallTool(ctx, &mcp.CallToolParams{
131		Name:      toolName,
132		Arguments: args,
133	})
134	if err != nil {
135		return tools.NewTextErrorResponse(err.Error()), nil
136	}
137
138	output := make([]string, 0, len(result.Content))
139	for _, v := range result.Content {
140		if vv, ok := v.(*mcp.TextContent); ok {
141			output = append(output, vv.Text)
142		} else {
143			output = append(output, fmt.Sprintf("%v", v))
144		}
145	}
146	return tools.NewTextResponse(strings.Join(output, "\n")), nil
147}
148
149func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
150	sess, ok := mcpClients.Get(name)
151	if !ok {
152		return nil, fmt.Errorf("mcp '%s' not available", name)
153	}
154
155	cfg := config.Get()
156	m := cfg.MCP[name]
157	state, _ := mcpStates.Get(name)
158
159	timeout := mcpTimeout(m)
160	pingCtx, cancel := context.WithTimeout(ctx, timeout)
161	defer cancel()
162	err := sess.Ping(pingCtx, nil)
163	if err == nil {
164		return sess, nil
165	}
166	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
167
168	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
169	if err != nil {
170		return nil, err
171	}
172
173	updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
174	mcpClients.Set(name, sess)
175	return sess, nil
176}
177
178func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
179	sessionID, messageID := tools.GetContextValues(ctx)
180	if sessionID == "" || messageID == "" {
181		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
182	}
183	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
184	p := b.permissions.Request(
185		permission.CreatePermissionRequest{
186			SessionID:   sessionID,
187			ToolCallID:  params.ID,
188			Path:        b.workingDir,
189			ToolName:    b.Info().Name,
190			Action:      "execute",
191			Description: permissionDescription,
192			Params:      params.Input,
193		},
194	)
195	if !p {
196		return tools.ToolResponse{}, permission.ErrorPermissionDenied
197	}
198
199	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
200}
201
202func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
203	if c.InitializeResult().Capabilities.Tools == nil {
204		return nil, nil
205	}
206	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
207	if err != nil {
208		return nil, err
209	}
210	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
211	for _, tool := range result.Tools {
212		mcpTools = append(mcpTools, &McpTool{
213			mcpName:     name,
214			tool:        tool,
215			permissions: permissions,
216			workingDir:  workingDir,
217		})
218	}
219	return mcpTools, nil
220}
221
// SubscribeMCPEvents subscribes to the MCP broker using ctx and returns the
// resulting channel of MCP events.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	return mcpBroker.Subscribe(ctx)
}
226
// GetMCPStates returns the current state of all MCP clients, keyed by MCP
// name. The result is a newly collected map.
func GetMCPStates() map[string]MCPClientInfo {
	return maps.Collect(mcpStates.Seq2())
}
231
// GetMCPState returns the recorded state of the MCP client called name. The
// boolean is the result of the lookup in mcpStates.
func GetMCPState(name string) (MCPClientInfo, bool) {
	return mcpStates.Get(name)
}
236
237// updateMCPState updates the state of an MCP client and publishes an event
238func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
239	info := MCPClientInfo{
240		Name:   name,
241		State:  state,
242		Error:  err,
243		Client: client,
244		Counts: counts,
245	}
246	switch state {
247	case MCPStateConnected:
248		info.ConnectedAt = time.Now()
249	case MCPStateError:
250		updateMcpTools(name, nil)
251		updateMcpPrompts(name, nil)
252		mcpClients.Del(name)
253	}
254	mcpStates.Set(name, info)
255
256	// Publish state change event
257	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
258		Type:   MCPEventStateChanged,
259		Name:   name,
260		State:  state,
261		Error:  err,
262		Counts: counts,
263	})
264}
265
266// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
267func CloseMCPClients() error {
268	var errs []error
269	for name, c := range mcpClients.Seq2() {
270		if err := c.Close(); err != nil {
271			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
272		}
273	}
274	mcpBroker.Shutdown()
275	return errors.Join(errs...)
276}
277
278func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
279	var wg sync.WaitGroup
280	// Initialize states for all configured MCPs
281	for name, m := range cfg.MCP {
282		if m.Disabled {
283			updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
284			slog.Debug("skipping disabled mcp", "name", name)
285			continue
286		}
287
288		// Set initial starting state
289		updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
290
291		wg.Add(1)
292		go func(name string, m config.MCPConfig) {
293			defer func() {
294				wg.Done()
295				if r := recover(); r != nil {
296					var err error
297					switch v := r.(type) {
298					case error:
299						err = v
300					case string:
301						err = fmt.Errorf("panic: %s", v)
302					default:
303						err = fmt.Errorf("panic: %v", v)
304					}
305					updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
306					slog.Error("panic in mcp client initialization", "error", err, "name", name)
307				}
308			}()
309
310			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
311			defer cancel()
312
313			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
314			if err != nil {
315				return
316			}
317
318			mcpClients.Set(name, c)
319
320			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
321			if err != nil {
322				slog.Error("error listing tools", "error", err)
323				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
324				c.Close()
325				return
326			}
327
328			prompts, err := getPrompts(ctx, c)
329			if err != nil {
330				slog.Error("error listing prompts", "error", err)
331				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
332				c.Close()
333				return
334			}
335
336			updateMcpTools(name, tools)
337			updateMcpPrompts(name, prompts)
338			mcpClients.Set(name, c)
339			updateMCPState(
340				name, MCPStateConnected, nil, c,
341				MCPCounts{
342					Tools:   len(tools),
343					Prompts: len(prompts),
344				},
345			)
346		}(name, m)
347	}
348	wg.Wait()
349}
350
351// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
352func updateMcpTools(mcpName string, tools []tools.BaseTool) {
353	if len(tools) == 0 {
354		mcpClient2Tools.Del(mcpName)
355	} else {
356		mcpClient2Tools.Set(mcpName, tools)
357	}
358	for _, tools := range mcpClient2Tools.Seq2() {
359		for _, t := range tools {
360			mcpTools.Set(t.Name(), t)
361		}
362	}
363}
364
365func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
366	timeout := mcpTimeout(m)
367	mcpCtx, cancel := context.WithCancel(ctx)
368	cancelTimer := time.AfterFunc(timeout, cancel)
369
370	transport, err := createMCPTransport(mcpCtx, m, resolver)
371	if err != nil {
372		updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
373		slog.Error("error creating mcp client", "error", err, "name", name)
374		return nil, err
375	}
376
377	client := mcp.NewClient(
378		&mcp.Implementation{
379			Name:    "crush",
380			Version: version.Version,
381			Title:   "Crush",
382		},
383		&mcp.ClientOptions{
384			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
385				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
386					Type: MCPEventToolsListChanged,
387					Name: name,
388				})
389			},
390			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
391				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
392					Type: MCPEventPromptsListChanged,
393					Name: name,
394				})
395			},
396			KeepAlive: time.Minute * 10,
397		},
398	)
399
400	session, err := client.Connect(mcpCtx, transport, nil)
401	if err != nil {
402		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
403		slog.Error("error starting mcp client", "error", err, "name", name)
404		_ = session.Close()
405		cancel()
406		return nil, err
407	}
408
409	cancelTimer.Stop()
410	slog.Info("Initialized mcp client", "name", name)
411	return session, nil
412}
413
// maybeTimeoutErr translates a context cancellation into a "timed out"
// error mentioning timeout. Any other error (including nil) is returned
// unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.Canceled) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
420
421func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
422	switch m.Type {
423	case config.MCPStdio:
424		command, err := resolver.ResolveValue(m.Command)
425		if err != nil {
426			return nil, fmt.Errorf("invalid mcp command: %w", err)
427		}
428		if strings.TrimSpace(command) == "" {
429			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
430		}
431		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
432		cmd.Env = m.ResolvedEnv()
433		return &mcp.CommandTransport{
434			Command: cmd,
435		}, nil
436	case config.MCPHttp:
437		if strings.TrimSpace(m.URL) == "" {
438			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
439		}
440		client := &http.Client{
441			Transport: &headerRoundTripper{
442				headers: m.ResolvedHeaders(),
443			},
444		}
445		return &mcp.StreamableClientTransport{
446			Endpoint:   m.URL,
447			HTTPClient: client,
448		}, nil
449	case config.MCPSSE:
450		if strings.TrimSpace(m.URL) == "" {
451			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
452		}
453		client := &http.Client{
454			Transport: &headerRoundTripper{
455				headers: m.ResolvedHeaders(),
456			},
457		}
458		return &mcp.SSEClientTransport{
459			Endpoint:   m.URL,
460			HTTPClient: client,
461		}, nil
462	default:
463		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
464	}
465}
466
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string // header name -> value, set on each request
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, overriding any existing values for those names.
//
// The headers are applied to a clone of req: the http.RoundTripper contract
// forbids modifying the caller's request.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	out := req.Clone(req.Context())
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
477
// mcpTimeout returns the timeout for an MCP, interpreting m.Timeout as a
// number of seconds and falling back to 15 seconds when it is zero.
func mcpTimeout(m config.MCPConfig) time.Duration {
	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
}
481
482func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
483	if c.InitializeResult().Capabilities.Prompts == nil {
484		return nil, nil
485	}
486	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
487	if err != nil {
488		return nil, err
489	}
490	return result.Prompts, nil
491}
492
493// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
494func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
495	if len(prompts) == 0 {
496		mcpClient2Prompts.Del(mcpName)
497	} else {
498		mcpClient2Prompts.Set(mcpName, prompts)
499	}
500	for clientName, prompts := range mcpClient2Prompts.Seq2() {
501		for _, p := range prompts {
502			key := clientName + ":" + p.Name
503			mcpPrompts.Set(key, p)
504		}
505	}
506}
507
// GetMCPPrompts returns all available MCP prompts as a newly collected map.
// Keys have the form "<mcp name>:<prompt name>".
func GetMCPPrompts() map[string]*mcp.Prompt {
	return maps.Collect(mcpPrompts.Seq2())
}
512
// GetMCPPrompt looks up a prompt in the flat prompt index. Entries in that
// index are stored under "<mcp name>:<prompt name>", so name must use that
// form. The boolean is the result of the lookup.
func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
	return mcpPrompts.Get(name)
}
517
// GetMCPPromptsByClient returns the prompts registered for the MCP client
// called clientName. The boolean is the result of the lookup; clients
// registered with no prompts have no entry.
func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
	return mcpClient2Prompts.Get(clientName)
}
522
523// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
524func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
525	c, err := getOrRenewClient(ctx, clientName)
526	if err != nil {
527		return nil, err
528	}
529
530	return c.GetPrompt(ctx, &mcp.GetPromptParams{
531		Name:      promptName,
532		Arguments: args,
533	})
534}