mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"slices"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/version"
 21	"github.com/mark3labs/mcp-go/client"
 22	"github.com/mark3labs/mcp-go/client/transport"
 23	"github.com/mark3labs/mcp-go/mcp"
 24)
 25
 26// MCPState represents the current state of an MCP client
 27type MCPState int
 28
 29const (
 30	MCPStateDisabled MCPState = iota
 31	MCPStateStarting
 32	MCPStateConnected
 33	MCPStateError
 34)
 35
 36func (s MCPState) String() string {
 37	switch s {
 38	case MCPStateDisabled:
 39		return "disabled"
 40	case MCPStateStarting:
 41		return "starting"
 42	case MCPStateConnected:
 43		return "connected"
 44	case MCPStateError:
 45		return "error"
 46	default:
 47		return "unknown"
 48	}
 49}
 50
 51// MCPEventType represents the type of MCP event
 52type MCPEventType string
 53
 54const (
 55	MCPEventStateChanged     MCPEventType = "state_changed"
 56	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
 57)
 58
// MCPEvent is the payload published on the MCP event broker. State-change
// events carry every field; tool-list-change events set only Type and Name.
type MCPEvent struct {
	Type      MCPEventType // kind of event being reported
	Name      string       // name of the MCP the event concerns
	State     MCPState     // state recorded for the MCP (state-change events)
	Error     error        // error recorded with the state, if any
	ToolCount int          // tool count recorded with the state
}
 67
// MCPClientInfo is the snapshot stored per MCP name each time its state is
// updated.
type MCPClientInfo struct {
	Name        string         // MCP name, as used for the configuration key
	State       MCPState       // most recently recorded lifecycle state
	Error       error          // error recorded with the state, nil when none
	Client      *client.Client // client passed with the state update; may be nil
	ToolCount   int            // number of tools recorded with the state
	ConnectedAt time.Time      // set only when State is MCPStateConnected; zero otherwise
}
 77
// Package-level registries and event plumbing shared by the MCP helpers in
// this file.
var (
	// mcpToolsOnce is not referenced in the visible part of this file;
	// presumably it guards one-time tool discovery elsewhere — TODO confirm.
	// mcpTools maps a tool's full name (the value of its Name method) to the
	// tool itself.
	mcpToolsOnce sync.Once
	mcpTools     = csync.NewMap[string, tools.BaseTool]()
	// mcpClientTools maps MCP name to the names of the tools it provides.
	// mcpClients maps MCP name to its registered client.
	// mcpStates maps MCP name to the most recently recorded MCPClientInfo.
	// mcpBroker publishes MCPEvent values to subscribers.
	// toolsMaker wraps raw mcp.Tool values as tools.BaseTool values; it is nil
	// until doGetMCPTools assigns it.
	mcpClientTools                                           = csync.NewMap[string, []string]()
	mcpClients                                               = csync.NewMap[string, *client.Client]()
	mcpStates                                                = csync.NewMap[string, MCPClientInfo]()
	mcpBroker                                                = pubsub.NewBroker[MCPEvent]()
	toolsMaker     func(string, []mcp.Tool) []tools.BaseTool = nil
)
 88
// McpTool wraps a single tool advertised by an MCP server so it can be used
// as a tools.BaseTool.
type McpTool struct {
	mcpName     string             // name of the MCP that provides the tool
	tool        mcp.Tool           // tool definition as listed by the server
	permissions permission.Service // asked for approval before each run
	workingDir  string             // sent as the Path of permission requests
}
 95
 96func (b *McpTool) Name() string {
 97	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 98}
 99
100func (b *McpTool) Info() tools.ToolInfo {
101	required := b.tool.InputSchema.Required
102	if required == nil {
103		required = make([]string, 0)
104	}
105	parameters := b.tool.InputSchema.Properties
106	if parameters == nil {
107		parameters = make(map[string]any)
108	}
109	return tools.ToolInfo{
110		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
111		Description: b.tool.Description,
112		Parameters:  parameters,
113		Required:    required,
114	}
115}
116
117func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
118	var args map[string]any
119	if err := json.Unmarshal([]byte(input), &args); err != nil {
120		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
121	}
122
123	c, err := getOrRenewClient(ctx, name)
124	if err != nil {
125		return tools.NewTextErrorResponse(err.Error()), nil
126	}
127	result, err := c.CallTool(ctx, mcp.CallToolRequest{
128		Params: mcp.CallToolParams{
129			Name:      toolName,
130			Arguments: args,
131		},
132	})
133	if err != nil {
134		return tools.NewTextErrorResponse(err.Error()), nil
135	}
136
137	output := make([]string, 0, len(result.Content))
138	for _, v := range result.Content {
139		if v, ok := v.(mcp.TextContent); ok {
140			output = append(output, v.Text)
141		} else {
142			output = append(output, fmt.Sprintf("%v", v))
143		}
144	}
145	return tools.NewTextResponse(strings.Join(output, "\n")), nil
146}
147
148func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
149	c, ok := mcpClients.Get(name)
150	if !ok {
151		return nil, fmt.Errorf("mcp '%s' not available", name)
152	}
153
154	m := config.Get().MCP[name]
155	state, _ := mcpStates.Get(name)
156
157	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
158	defer cancel()
159	err := c.Ping(pingCtx)
160	if err == nil {
161		return c, nil
162	}
163	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
164
165	c, err = createAndInitializeClient(ctx, name, m)
166	if err != nil {
167		return nil, err
168	}
169
170	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
171	mcpClients.Set(name, c)
172	return c, nil
173}
174
175func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
176	sessionID, messageID := tools.GetContextValues(ctx)
177	if sessionID == "" || messageID == "" {
178		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
179	}
180	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
181	p := b.permissions.Request(
182		permission.CreatePermissionRequest{
183			SessionID:   sessionID,
184			ToolCallID:  params.ID,
185			Path:        b.workingDir,
186			ToolName:    b.Info().Name,
187			Action:      "execute",
188			Description: permissionDescription,
189			Params:      params.Input,
190		},
191	)
192	if !p {
193		return tools.ToolResponse{}, permission.ErrorPermissionDenied
194	}
195
196	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
197}
198
199func createToolsMaker(ctx context.Context, permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
200	return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
201		mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
202		for _, tool := range mcpToolsList {
203			mcpTools = append(mcpTools, &McpTool{
204				mcpName:     name,
205				tool:        tool,
206				permissions: permissions,
207				workingDir:  workingDir,
208			})
209		}
210		return mcpTools
211	}
212}
213
214func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
215	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
216	if err != nil {
217		slog.Error("error listing tools", "error", err)
218		updateMCPState(name, MCPStateError, err, nil, 0)
219		c.Close()
220		mcpClients.Del(name)
221		return nil
222	}
223	return toolsMaker(name, result.Tools)
224}
225
// SubscribeMCPEvents subscribes to the MCP event broker and returns the
// channel on which MCP events are delivered. The subscription is created with
// ctx; presumably it ends when ctx is done — confirm against the pubsub
// broker.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	return mcpBroker.Subscribe(ctx)
}
230
// GetMCPStates returns the current state of all MCP clients, keyed by MCP
// name. The result is a newly collected map, so callers may modify it without
// affecting the registry.
func GetMCPStates() map[string]MCPClientInfo {
	return maps.Collect(mcpStates.Seq2())
}
235
// GetMCPState returns the recorded state of the MCP client called name. The
// boolean reports whether a state has been recorded for that name.
func GetMCPState(name string) (MCPClientInfo, bool) {
	return mcpStates.Get(name)
}
240
241// updateMCPState updates the state of an MCP client and publishes an event
242func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
243	info := MCPClientInfo{
244		Name:      name,
245		State:     state,
246		Error:     err,
247		Client:    client,
248		ToolCount: toolCount,
249	}
250	if state == MCPStateConnected {
251		info.ConnectedAt = time.Now()
252	}
253	mcpStates.Set(name, info)
254
255	// Publish state change event
256	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
257		Type:      MCPEventStateChanged,
258		Name:      name,
259		State:     state,
260		Error:     err,
261		ToolCount: toolCount,
262	})
263}
264
265// publishMCPEventToolsListChanged publishes a tool list changed event
266func publishMCPEventToolsListChanged(name string) {
267	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
268		Type: MCPEventToolsListChanged,
269		Name: name,
270	})
271}
272
// CloseMCPClients closes all MCP clients and then shuts down the MCP event
// broker. This should be called during application shutdown.
func CloseMCPClients() {
	for c := range mcpClients.Seq() {
		// Best effort: close errors are deliberately ignored so every
		// registered client gets a close attempt.
		_ = c.Close()
	}
	mcpBroker.Shutdown()
}
280
// mcpInitRequest is the initialize request sent to every MCP server. It
// advertises the latest protocol version known to the mcp package and
// identifies this client as "Crush" with the application's version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
290
291func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
292	var wg sync.WaitGroup
293	result := csync.NewSlice[tools.BaseTool]()
294
295	toolsMaker = createToolsMaker(ctx, permissions, cfg.WorkingDir())
296
297	// Initialize states for all configured MCPs
298	for name, m := range cfg.MCP {
299		if m.Disabled {
300			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
301			slog.Debug("skipping disabled mcp", "name", name)
302			continue
303		}
304
305		// Set initial starting state
306		updateMCPState(name, MCPStateStarting, nil, nil, 0)
307
308		wg.Add(1)
309		go func(name string, m config.MCPConfig) {
310			defer func() {
311				wg.Done()
312				if r := recover(); r != nil {
313					var err error
314					switch v := r.(type) {
315					case error:
316						err = v
317					case string:
318						err = fmt.Errorf("panic: %s", v)
319					default:
320						err = fmt.Errorf("panic: %v", v)
321					}
322					updateMCPState(name, MCPStateError, err, nil, 0)
323					slog.Error("panic in mcp client initialization", "error", err, "name", name)
324				}
325			}()
326
327			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
328			defer cancel()
329			c, err := createAndInitializeClient(ctx, name, m)
330			if err != nil {
331				return
332			}
333
334			mcpClients.Set(name, c)
335
336			tools := getTools(ctx, name, c)
337			result.Append(tools...)
338			updateMcpTools(name, tools)
339			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
340		}(name, m)
341	}
342	wg.Wait()
343
344	return slices.Collect(result.Seq())
345}
346
347// updateMcpTools updates the global mcpTools and mcpClientTools maps
348func updateMcpTools(mcpName string, tools []tools.BaseTool) {
349	toolNames := make([]string, 0, len(tools))
350	for _, tool := range tools {
351		name := tool.Name()
352		if _, ok := mcpTools.Get(name); !ok {
353			slog.Info("Added MCP tool", "name", name, "mcp", mcpName)
354		}
355		mcpTools.Set(name, tool)
356		toolNames = append(toolNames, name)
357	}
358
359	// remove the tools that are no longer available
360	old, ok := mcpClientTools.Get(mcpName)
361	if ok {
362		slices.Sort(toolNames)
363		for _, name := range old {
364			if _, ok := slices.BinarySearch(toolNames, name); !ok {
365				mcpTools.Del(name)
366				slog.Info("Removed MCP tool", "name", name, "mcp", mcpName)
367			}
368		}
369	}
370	mcpClientTools.Set(mcpName, toolNames)
371}
372
373func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
374	c, err := createMcpClient(m)
375	if err != nil {
376		updateMCPState(name, MCPStateError, err, nil, 0)
377		slog.Error("error creating mcp client", "error", err, "name", name)
378		return nil, err
379	}
380
381	c.OnNotification(func(n mcp.JSONRPCNotification) {
382		slog.Debug("Received MCP notification", "name", name, "notification", n)
383		switch n.Method {
384		case "notifications/tools/list_changed":
385			publishMCPEventToolsListChanged(name)
386		default:
387			slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
388		}
389	})
390
391	if err := c.Start(ctx); err != nil {
392		updateMCPState(name, MCPStateError, err, nil, 0)
393		slog.Error("error starting mcp client", "error", err, "name", name)
394		_ = c.Close()
395		return nil, err
396	}
397
398	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
399		updateMCPState(name, MCPStateError, err, nil, 0)
400		slog.Error("error initializing mcp client", "error", err, "name", name)
401		_ = c.Close()
402		return nil, err
403	}
404
405	slog.Info("Initialized mcp client", "name", name)
406	return c, nil
407}
408
409func createMcpClient(m config.MCPConfig) (*client.Client, error) {
410	switch m.Type {
411	case config.MCPStdio:
412		if strings.TrimSpace(m.Command) == "" {
413			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
414		}
415		return client.NewStdioMCPClientWithOptions(
416			m.Command,
417			m.ResolvedEnv(),
418			m.Args,
419			transport.WithCommandLogger(mcpLogger{}),
420		)
421	case config.MCPHttp:
422		if strings.TrimSpace(m.URL) == "" {
423			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
424		}
425		return client.NewStreamableHttpClient(
426			m.URL,
427			transport.WithHTTPHeaders(m.ResolvedHeaders()),
428			transport.WithHTTPLogger(mcpLogger{}),
429		)
430	case config.MCPSse:
431		if strings.TrimSpace(m.URL) == "" {
432			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
433		}
434		return client.NewSSEMCPClient(
435			m.URL,
436			client.WithHeaders(m.ResolvedHeaders()),
437			transport.WithSSELogger(mcpLogger{}),
438		)
439	default:
440		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
441	}
442}
443
// mcpLogger forwards the Errorf/Infof calls made by the MCP client transports
// to the default slog logger.
type mcpLogger struct{}

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
449
450func mcpTimeout(m config.MCPConfig) time.Duration {
451	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
452}