mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"slices"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/version"
 21	"github.com/mark3labs/mcp-go/client"
 22	"github.com/mark3labs/mcp-go/client/transport"
 23	"github.com/mark3labs/mcp-go/mcp"
 24)
 25
 26// MCPState represents the current state of an MCP client
 27type MCPState int
 28
 29const (
 30	MCPStateDisabled MCPState = iota
 31	MCPStateStarting
 32	MCPStateConnected
 33	MCPStateError
 34)
 35
 36func (s MCPState) String() string {
 37	switch s {
 38	case MCPStateDisabled:
 39		return "disabled"
 40	case MCPStateStarting:
 41		return "starting"
 42	case MCPStateConnected:
 43		return "connected"
 44	case MCPStateError:
 45		return "error"
 46	default:
 47		return "unknown"
 48	}
 49}
 50
// MCPEventType represents the type of MCP event published on the MCP event
// broker.
type MCPEventType string

const (
	// MCPEventStateChanged is published by updateMCPState every time it
	// records a state for an MCP server.
	MCPEventStateChanged     MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when an MCP server sends a
	// "notifications/tools/list_changed" notification.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
 58
// MCPEvent represents an event in the MCP system, as delivered to
// subscribers of SubscribeMCPEvents.
//
// Name is the MCP server's configured name. For MCPEventStateChanged events,
// State, Error and ToolCount mirror the values just recorded for that server.
// MCPEventToolsListChanged events carry only Type and Name; the other fields
// are left at their zero values.
type MCPEvent struct {
	Type      MCPEventType
	Name      string
	State     MCPState
	Error     error
	ToolCount int
}
 67
// MCPClientInfo holds information about an MCP client's state, as last
// recorded by updateMCPState and returned by GetMCPStates / GetMCPState.
//
// Name is the server's configured name. Error holds the failure that led to
// MCPStateError, if any. Within this file, updateMCPState's callers pass a
// non-nil Client only together with MCPStateConnected. ConnectedAt is set
// only when State is MCPStateConnected. ToolCount is whatever count the
// caller of updateMCPState supplied.
type MCPClientInfo struct {
	Name        string
	State       MCPState
	Error       error
	Client      *client.Client
	ToolCount   int
	ConnectedAt time.Time
}
 77
 78var (
 79	mcpToolsOnce    sync.Once
 80	mcpTools                                                  = csync.NewMap[string, tools.BaseTool]()
 81	mcpClient2Tools                                           = csync.NewMap[string, []tools.BaseTool]()
 82	mcpClients                                                = csync.NewMap[string, *client.Client]()
 83	mcpStates                                                 = csync.NewMap[string, MCPClientInfo]()
 84	mcpBroker                                                 = pubsub.NewBroker[MCPEvent]()
 85	toolsMaker      func(string, []mcp.Tool) []tools.BaseTool = nil
 86)
 87
// McpTool wraps a single tool exposed by an MCP server so that it can be
// used as a tools.BaseTool.
//
// mcpName is the server's configured name; it is used to look up the
// server's client and to prefix the tool's public name. tool is the
// definition reported by the server. permissions is asked for approval
// before every call, and workingDir is sent as the Path of that request.
type McpTool struct {
	mcpName     string
	tool        mcp.Tool
	permissions permission.Service
	workingDir  string
}
 94
 95func (b *McpTool) Name() string {
 96	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 97}
 98
 99func (b *McpTool) Info() tools.ToolInfo {
100	required := b.tool.InputSchema.Required
101	if required == nil {
102		required = make([]string, 0)
103	}
104	parameters := b.tool.InputSchema.Properties
105	if parameters == nil {
106		parameters = make(map[string]any)
107	}
108	return tools.ToolInfo{
109		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
110		Description: b.tool.Description,
111		Parameters:  parameters,
112		Required:    required,
113	}
114}
115
// runTool invokes toolName on the MCP server registered as name.
//
// input is the raw JSON arguments object produced by the model. Every
// failure is returned as a text error response with a nil Go error, so the
// model sees the message instead of the call aborting. Failures include:
// unparsable JSON, an unavailable or unrecoverable server, a transport error,
// and an error result reported by the tool itself. Text content items are
// returned verbatim; any other content item is rendered with %v. Items are
// joined with newlines.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	c, err := getOrRenewClient(ctx, name)
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}
	result, err := c.CallTool(ctx, mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      toolName,
			Arguments: args,
		},
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	output := make([]string, 0, len(result.Content))
	for _, content := range result.Content {
		if tc, ok := content.(mcp.TextContent); ok {
			output = append(output, tc.Text)
		} else {
			output = append(output, fmt.Sprintf("%v", content))
		}
	}
	text := strings.Join(output, "\n")

	// MCP reports tool-level failures in-band via IsError rather than as a
	// protocol error. Surface them as error responses so the model does not
	// mistake the failure message for a successful result.
	if result.IsError {
		return tools.NewTextErrorResponse(text), nil
	}
	return tools.NewTextResponse(text), nil
}
146
// getOrRenewClient returns a usable client for the MCP server name.
//
// The registered client is pinged, bounded by mcpTimeout; if the ping
// succeeds, the client is returned unchanged. Otherwise:
//   - the server is marked MCPStateError and the stale client is closed;
//   - a fresh client is created and initialized;
//   - on success, the tools the server had before the failure are restored,
//     the server is marked MCPStateConnected with its previous tool count,
//     and the new client is registered.
//
// An error is returned when no client is registered for name, or when
// reconnecting fails. In the latter case createAndInitializeClient has
// already recorded the error state.
func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
	c, ok := mcpClients.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	m := config.Get().MCP[name]
	state, _ := mcpStates.Get(name)

	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
	defer cancel()
	err := c.Ping(pingCtx)
	if err == nil {
		return c, nil
	}

	// Marking the server as errored unregisters its tools. Remember them so
	// they can be restored once the reconnect succeeds.
	prevTools, _ := mcpClient2Tools.Get(name)
	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)

	// The stale client is no longer tracked in mcpClients, so nothing else
	// (including CloseMCPClients) would ever release its transport.
	// Best effort: it is already unhealthy.
	_ = c.Close()

	c, err = createAndInitializeClient(ctx, name, m)
	if err != nil {
		return nil, err
	}

	updateMcpTools(name, prevTools)
	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
	mcpClients.Set(name, c)
	return c, nil
}
173
// Run executes the MCP tool for a single tool call.
//
// ctx must carry session and message IDs (read via tools.GetContextValues).
// The user is asked, through the permission service, to approve running the
// tool with the raw JSON input. A denied request returns
// permission.ErrorPermissionDenied. Otherwise the call is forwarded to
// runTool, which reports every failure as a text error response with a nil
// error.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for running an MCP tool")
	}

	// Name() is the identifier Info().Name reports; use it directly rather
	// than building the full ToolInfo twice.
	toolName := b.Name()
	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", toolName, params.Input)
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: permissionDescription,
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
197
198func createToolsMaker(permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
199	return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
200		mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
201		for _, tool := range mcpToolsList {
202			mcpTools = append(mcpTools, &McpTool{
203				mcpName:     name,
204				tool:        tool,
205				permissions: permissions,
206				workingDir:  workingDir,
207			})
208		}
209		return mcpTools
210	}
211}
212
// getTools asks client c (registered as name) for its tools and wraps them
// with toolsMaker.
//
// If listing fails, it logs the error, records MCPStateError for name,
// closes c, and returns nil. With the maker installed by createToolsMaker, a
// successful call always yields a non-nil slice, even when it is empty.
func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		slog.Error("error listing tools", "error", err, "name", name)
		updateMCPState(name, MCPStateError, err, nil, 0)
		// Best effort: the client is unusable either way.
		_ = c.Close()
		return nil
	}
	return toolsMaker(name, result.Tools)
}
223
// SubscribeMCPEvents subscribes to the MCP event broker and returns the
// channel on which MCPEvent values (state and tool-list changes) arrive.
// The subscription's lifetime is presumably tied to ctx — confirm in
// pubsub.Broker.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}
228
// GetMCPStates returns a snapshot copy of the last recorded state of every
// known MCP server, keyed by server name. The returned map is never nil.
func GetMCPStates() map[string]MCPClientInfo {
	snapshot := maps.Collect(mcpStates.Seq2())
	return snapshot
}
233
234// GetMCPState returns the state of a specific MCP client
235func GetMCPState(name string) (MCPClientInfo, bool) {
236	return mcpStates.Get(name)
237}
238
239// updateMCPState updates the state of an MCP client and publishes an event
240func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
241	info := MCPClientInfo{
242		Name:      name,
243		State:     state,
244		Error:     err,
245		Client:    client,
246		ToolCount: toolCount,
247	}
248	switch state {
249	case MCPStateConnected:
250		info.ConnectedAt = time.Now()
251	case MCPStateError:
252		updateMcpTools(name, nil)
253		mcpClients.Del(name)
254	}
255	mcpStates.Set(name, info)
256
257	// Publish state change event
258	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
259		Type:      MCPEventStateChanged,
260		Name:      name,
261		State:     state,
262		Error:     err,
263		ToolCount: toolCount,
264	})
265}
266
267// publishMCPEventToolsListChanged publishes a tool list changed event
268func publishMCPEventToolsListChanged(name string) {
269	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
270		Type: MCPEventToolsListChanged,
271		Name: name,
272	})
273}
274
// CloseMCPClients closes every registered MCP client and then shuts down the
// MCP event broker. It should be called during application shutdown.
//
// Shutdown is best-effort: a client that fails to close is logged with its
// server name and does not stop the remaining clients from being closed.
func CloseMCPClients() {
	for name, c := range mcpClients.Seq2() {
		if err := c.Close(); err != nil {
			slog.Warn("error closing mcp client", "name", name, "error", err)
		}
	}
	mcpBroker.Shutdown()
}
282
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It asks for mcp.LATEST_PROTOCOL_VERSION and
// identifies this client as "Crush" at the current build's version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
292
// doGetMCPTools connects to every enabled MCP server in cfg concurrently and
// returns the tools they expose, in no particular order.
//
// Steps:
//   - install toolsMaker for the given permissions and the configured
//     working directory;
//   - record MCPStateDisabled for disabled servers and MCPStateStarting for
//     the rest;
//   - block until every connection attempt has finished.
//
// Each attempt is bounded by mcpTimeout. A server that fails to connect or
// to list its tools stays in MCPStateError and contributes no tools. A panic
// inside an attempt is recovered and recorded the same way.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	var wg sync.WaitGroup
	result := csync.NewSlice[tools.BaseTool]()

	toolsMaker = createToolsMaker(permissions, cfg.WorkingDir())

	// Initialize states for all configured MCPs
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		// Set initial starting state
		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred calls run last-in first-out. wg.Done is registered
			// first, so it runs after the recovery below has recorded any
			// panic; wg.Wait therefore never returns before the final state
			// is stored.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
			defer cancel()
			c, err := createAndInitializeClient(ctx, name, m)
			if err != nil {
				// createAndInitializeClient already recorded MCPStateError.
				return
			}

			mcpClients.Set(name, c)

			mcpToolList := getTools(ctx, name, c)
			if mcpToolList == nil {
				// Listing failed: getTools already recorded MCPStateError and
				// closed c. Do not overwrite that with MCPStateConnected.
				return
			}
			result.Append(mcpToolList...)
			updateMcpTools(name, mcpToolList)
			updateMCPState(name, MCPStateConnected, nil, c, len(mcpToolList))
		}(name, m)
	}
	wg.Wait()

	return slices.Collect(result.Seq())
}
348
349// updateMcpTools updates the global mcpTools and mcpClientTools maps
350func updateMcpTools(mcpName string, tools []tools.BaseTool) {
351	if len(tools) == 0 {
352		mcpClient2Tools.Del(mcpName)
353	} else {
354		mcpClient2Tools.Set(mcpName, tools)
355	}
356	for _, tools := range mcpClient2Tools.Seq2() {
357		for _, t := range tools {
358			mcpTools.Set(t.Name(), t)
359		}
360	}
361}
362
// createAndInitializeClient builds a client for the MCP server described by
// m, starts it, and performs the initialize handshake using ctx.
//
// A handler is registered that turns "notifications/tools/list_changed"
// notifications into MCPEventToolsListChanged events. On any failure the
// server is marked MCPStateError, the error is logged, any client that was
// already created is closed, and the error is returned unchanged.
func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
	// fail records and logs err, closes c when one exists, and yields the
	// (nil, err) pair the caller returns.
	fail := func(msg string, c *client.Client, err error) (*client.Client, error) {
		updateMCPState(name, MCPStateError, err, nil, 0)
		slog.Error(msg, "error", err, "name", name)
		if c != nil {
			_ = c.Close()
		}
		return nil, err
	}

	c, err := createMcpClient(m)
	if err != nil {
		return fail("error creating mcp client", nil, err)
	}

	c.OnNotification(func(n mcp.JSONRPCNotification) {
		slog.Debug("Received MCP notification", "name", name, "notification", n)
		if n.Method == "notifications/tools/list_changed" {
			publishMCPEventToolsListChanged(name)
			return
		}
		slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
	})

	if err := c.Start(ctx); err != nil {
		return fail("error starting mcp client", c, err)
	}
	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
		return fail("error initializing mcp client", c, err)
	}

	slog.Info("Initialized mcp client", "name", name)
	return c, nil
}
398
// createMcpClient constructs, but does not start, a client for the transport
// named by m.Type: stdio (requires Command), streamable HTTP (requires URL),
// or SSE (requires URL). Each transport logs through mcpLogger. A blank
// required field or an unknown type yields an error.
func createMcpClient(m config.MCPConfig) (*client.Client, error) {
	logger := mcpLogger{}
	blank := func(s string) bool { return strings.TrimSpace(s) == "" }

	switch m.Type {
	case config.MCPStdio:
		if blank(m.Command) {
			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
		}
		return client.NewStdioMCPClientWithOptions(
			m.Command,
			m.ResolvedEnv(),
			m.Args,
			transport.WithCommandLogger(logger),
		)

	case config.MCPHttp:
		if blank(m.URL) {
			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
		}
		return client.NewStreamableHttpClient(
			m.URL,
			transport.WithHTTPHeaders(m.ResolvedHeaders()),
			transport.WithHTTPLogger(logger),
		)

	case config.MCPSse:
		if blank(m.URL) {
			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
		}
		return client.NewSSEMCPClient(
			m.URL,
			client.WithHeaders(m.ResolvedHeaders()),
			transport.WithSSELogger(logger),
		)
	}

	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
433
434// for MCP's clients.
435type mcpLogger struct{}
436
437func (l mcpLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) }
438func (l mcpLogger) Infof(format string, v ...any)  { slog.Info(fmt.Sprintf(format, v...)) }
439
// mcpTimeout returns the per-operation timeout for the MCP server described
// by m. This is its configured Timeout in seconds, or 15 seconds when none
// is set.
func mcpTimeout(m config.MCPConfig) time.Duration {
	// Clamp negative values to zero so cmp.Or treats them as "unset". A
	// negative duration would produce a context that is already expired.
	return time.Duration(cmp.Or(max(m.Timeout, 0), 15)) * time.Second
}