mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/home"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/pubsub"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/mark3labs/mcp-go/client"
 23	"github.com/mark3labs/mcp-go/client/transport"
 24	"github.com/mark3labs/mcp-go/mcp"
 25)
 26
 27// MCPState represents the current state of an MCP client
 28type MCPState int
 29
 30const (
 31	MCPStateDisabled MCPState = iota
 32	MCPStateStarting
 33	MCPStateConnected
 34	MCPStateError
 35)
 36
 37func (s MCPState) String() string {
 38	switch s {
 39	case MCPStateDisabled:
 40		return "disabled"
 41	case MCPStateStarting:
 42		return "starting"
 43	case MCPStateConnected:
 44		return "connected"
 45	case MCPStateError:
 46		return "error"
 47	default:
 48		return "unknown"
 49	}
 50}
 51
// MCPEventType represents the type of MCP event.
type MCPEventType string

const (
	// MCPEventStateChanged is the event type published by updateMCPState
	// whenever a client's state is recorded.
	MCPEventStateChanged     MCPEventType = "state_changed"
	// MCPEventToolsListChanged is the event type published when a server
	// sends a "notifications/tools/list_changed" notification.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
 59
// MCPEvent represents an event in the MCP system. It is the payload
// published on mcpBroker.
type MCPEvent struct {
	// Type says which kind of event this is.
	Type      MCPEventType
	// Name is the name of the MCP the event refers to.
	Name      string
	// State, Error and ToolCount are filled in for state-change events;
	// tools-list-changed events are published with only Type and Name set.
	State     MCPState
	Error     error
	ToolCount int
}
 68
// MCPClientInfo holds information about an MCP client's state. Values are
// built by updateMCPState and stored in mcpStates under the MCP's name.
type MCPClientInfo struct {
	// Name is the name of the MCP this record describes.
	Name        string
	// State is the most recently recorded state.
	State       MCPState
	// Error is the error passed along with the state update, if any.
	Error       error
	// Client is the client handle passed along with the state update; callers
	// in this file pass nil for every state other than connected.
	Client      *client.Client
	// ToolCount is the number of tools reported with the state update.
	ToolCount   int
	// ConnectedAt is set to the current time when the recorded state is
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
 78
// Package-level MCP registries.
//
//   - mcpToolsOnce: not referenced in this part of the file; presumably guards
//     one-time tool loading elsewhere (TODO confirm).
//   - mcpTools: tools keyed by their Name() (see updateMcpTools).
//   - mcpClient2Tools: the tool list of each MCP, keyed by MCP name.
//   - mcpClients: live client handles, keyed by MCP name.
//   - mcpStates: last recorded MCPClientInfo, keyed by MCP name.
//   - mcpBroker: publishes MCPEvent values to subscribers.
//   - toolsMaker: converts an MCP's tool descriptions into tools.BaseTool
//     values; it is nil until doGetMCPTools assigns it.
var (
	mcpToolsOnce    sync.Once
	mcpTools                                                  = csync.NewMap[string, tools.BaseTool]()
	mcpClient2Tools                                           = csync.NewMap[string, []tools.BaseTool]()
	mcpClients                                                = csync.NewMap[string, *client.Client]()
	mcpStates                                                 = csync.NewMap[string, MCPClientInfo]()
	mcpBroker                                                 = pubsub.NewBroker[MCPEvent]()
	toolsMaker      func(string, []mcp.Tool) []tools.BaseTool = nil
)
 88
// McpTool wraps a single tool exposed by an MCP server so it can be invoked
// through the Name/Info/Run methods defined on it.
type McpTool struct {
	// mcpName is the name of the MCP that provides the tool.
	mcpName     string
	// tool is the tool description as reported by the MCP server.
	tool        mcp.Tool
	// permissions is asked for approval before every execution.
	permissions permission.Service
	// workingDir is sent as the Path of each permission request.
	workingDir  string
}
 95
 96func (b *McpTool) Name() string {
 97	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 98}
 99
100func (b *McpTool) Info() tools.ToolInfo {
101	required := b.tool.InputSchema.Required
102	if required == nil {
103		required = make([]string, 0)
104	}
105	parameters := b.tool.InputSchema.Properties
106	if parameters == nil {
107		parameters = make(map[string]any)
108	}
109	return tools.ToolInfo{
110		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
111		Description: b.tool.Description,
112		Parameters:  parameters,
113		Required:    required,
114	}
115}
116
117func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
118	var args map[string]any
119	if err := json.Unmarshal([]byte(input), &args); err != nil {
120		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
121	}
122
123	c, err := getOrRenewClient(ctx, name)
124	if err != nil {
125		return tools.NewTextErrorResponse(err.Error()), nil
126	}
127	result, err := c.CallTool(ctx, mcp.CallToolRequest{
128		Params: mcp.CallToolParams{
129			Name:      toolName,
130			Arguments: args,
131		},
132	})
133	if err != nil {
134		return tools.NewTextErrorResponse(err.Error()), nil
135	}
136
137	output := make([]string, 0, len(result.Content))
138	for _, v := range result.Content {
139		if v, ok := v.(mcp.TextContent); ok {
140			output = append(output, v.Text)
141		} else {
142			output = append(output, fmt.Sprintf("%v", v))
143		}
144	}
145	return tools.NewTextResponse(strings.Join(output, "\n")), nil
146}
147
148func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
149	c, ok := mcpClients.Get(name)
150	if !ok {
151		return nil, fmt.Errorf("mcp '%s' not available", name)
152	}
153
154	cfg := config.Get()
155	m := cfg.MCP[name]
156	state, _ := mcpStates.Get(name)
157
158	timeout := mcpTimeout(m)
159	pingCtx, cancel := context.WithTimeout(ctx, timeout)
160	defer cancel()
161	err := c.Ping(pingCtx)
162	if err == nil {
163		return c, nil
164	}
165	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
166
167	c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
168	if err != nil {
169		return nil, err
170	}
171
172	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
173	mcpClients.Set(name, c)
174	return c, nil
175}
176
177func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
178	sessionID, messageID := tools.GetContextValues(ctx)
179	if sessionID == "" || messageID == "" {
180		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
181	}
182	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
183	p := b.permissions.Request(
184		permission.CreatePermissionRequest{
185			SessionID:   sessionID,
186			ToolCallID:  params.ID,
187			Path:        b.workingDir,
188			ToolName:    b.Info().Name,
189			Action:      "execute",
190			Description: permissionDescription,
191			Params:      params.Input,
192		},
193	)
194	if !p {
195		return tools.ToolResponse{}, permission.ErrorPermissionDenied
196	}
197
198	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
199}
200
201func createToolsMaker(permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
202	return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
203		mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
204		for _, tool := range mcpToolsList {
205			mcpTools = append(mcpTools, &McpTool{
206				mcpName:     name,
207				tool:        tool,
208				permissions: permissions,
209				workingDir:  workingDir,
210			})
211		}
212		return mcpTools
213	}
214}
215
216func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
217	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
218	if err != nil {
219		slog.Error("error listing tools", "error", err)
220		updateMCPState(name, MCPStateError, err, nil, 0)
221		c.Close()
222		return nil
223	}
224	return toolsMaker(name, result.Tools)
225}
226
227// SubscribeMCPEvents returns a channel for MCP events
228func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
229	return mcpBroker.Subscribe(ctx)
230}
231
232// GetMCPStates returns the current state of all MCP clients
233func GetMCPStates() map[string]MCPClientInfo {
234	return maps.Collect(mcpStates.Seq2())
235}
236
237// GetMCPState returns the state of a specific MCP client
238func GetMCPState(name string) (MCPClientInfo, bool) {
239	return mcpStates.Get(name)
240}
241
242// updateMCPState updates the state of an MCP client and publishes an event
243func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
244	info := MCPClientInfo{
245		Name:      name,
246		State:     state,
247		Error:     err,
248		Client:    client,
249		ToolCount: toolCount,
250	}
251	switch state {
252	case MCPStateConnected:
253		info.ConnectedAt = time.Now()
254	case MCPStateError:
255		updateMcpTools(name, nil)
256		mcpClients.Del(name)
257	}
258	mcpStates.Set(name, info)
259
260	// Publish state change event
261	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
262		Type:      MCPEventStateChanged,
263		Name:      name,
264		State:     state,
265		Error:     err,
266		ToolCount: toolCount,
267	})
268}
269
270// publishMCPEventToolsListChanged publishes a tool list changed event
271func publishMCPEventToolsListChanged(name string) {
272	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
273		Type: MCPEventToolsListChanged,
274		Name: name,
275	})
276}
277
278// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
279func CloseMCPClients() error {
280	var errs []error
281	for name, c := range mcpClients.Seq2() {
282		if err := c.Close(); err != nil {
283			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
284		}
285	}
286	mcpBroker.Shutdown()
287	return errors.Join(errs...)
288}
289
// mcpInitRequest is the initialize request sent to an MCP server by
// createAndInitializeClient. It announces the latest protocol version known
// to the mcp library and identifies this client as "Crush" with the
// application's version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
299
300func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
301	var wg sync.WaitGroup
302
303	toolsMaker = createToolsMaker(permissions, cfg.WorkingDir())
304
305	// Initialize states for all configured MCPs
306	for name, m := range cfg.MCP {
307		if m.Disabled {
308			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
309			slog.Debug("skipping disabled mcp", "name", name)
310			continue
311		}
312
313		// Set initial starting state
314		updateMCPState(name, MCPStateStarting, nil, nil, 0)
315
316		wg.Add(1)
317		go func(name string, m config.MCPConfig) {
318			defer func() {
319				wg.Done()
320				if r := recover(); r != nil {
321					var err error
322					switch v := r.(type) {
323					case error:
324						err = v
325					case string:
326						err = fmt.Errorf("panic: %s", v)
327					default:
328						err = fmt.Errorf("panic: %v", v)
329					}
330					updateMCPState(name, MCPStateError, err, nil, 0)
331					slog.Error("panic in mcp client initialization", "error", err, "name", name)
332				}
333			}()
334
335			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
336			defer cancel()
337			c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
338			if err != nil {
339				return
340			}
341
342			mcpClients.Set(name, c)
343
344			tools := getTools(ctx, name, c)
345			updateMcpTools(name, tools)
346			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
347		}(name, m)
348	}
349	wg.Wait()
350}
351
352// updateMcpTools updates the global mcpTools and mcpClientTools maps
353func updateMcpTools(mcpName string, tools []tools.BaseTool) {
354	if len(tools) == 0 {
355		mcpClient2Tools.Del(mcpName)
356	} else {
357		mcpClient2Tools.Set(mcpName, tools)
358	}
359	for _, tools := range mcpClient2Tools.Seq2() {
360		for _, t := range tools {
361			mcpTools.Set(t.Name(), t)
362		}
363	}
364}
365
366func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
367	c, err := createMcpClient(name, m, resolver)
368	if err != nil {
369		updateMCPState(name, MCPStateError, err, nil, 0)
370		slog.Error("error creating mcp client", "error", err, "name", name)
371		return nil, err
372	}
373
374	c.OnNotification(func(n mcp.JSONRPCNotification) {
375		slog.Debug("Received MCP notification", "name", name, "notification", n)
376		switch n.Method {
377		case "notifications/tools/list_changed":
378			publishMCPEventToolsListChanged(name)
379		default:
380			slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
381		}
382	})
383
384	timeout := mcpTimeout(m)
385	initCtx, cancel := context.WithTimeout(ctx, timeout)
386	defer cancel()
387
388	if err := c.Start(ctx); err != nil {
389		updateMCPState(name, MCPStateError, err, nil, 0)
390		slog.Error("error starting mcp client", "error", err, "name", name)
391		_ = c.Close()
392		return nil, err
393	}
394
395	if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
396		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
397		slog.Error("error initializing mcp client", "error", err, "name", name)
398		_ = c.Close()
399		return nil, err
400	}
401
402	slog.Info("Initialized mcp client", "name", name)
403	return c, nil
404}
405
// maybeTimeoutErr replaces an error that is (or wraps)
// context.DeadlineExceeded with a "timed out after <timeout>" error. Any
// other error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
412
413func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
414	switch m.Type {
415	case config.MCPStdio:
416		command, err := resolver.ResolveValue(m.Command)
417		if err != nil {
418			return nil, fmt.Errorf("invalid mcp command: %w", err)
419		}
420		if strings.TrimSpace(command) == "" {
421			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
422		}
423		return client.NewStdioMCPClientWithOptions(
424			home.Long(command),
425			m.ResolvedEnv(),
426			m.Args,
427			transport.WithCommandLogger(mcpLogger{name: name}),
428		)
429	case config.MCPHttp:
430		if strings.TrimSpace(m.URL) == "" {
431			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
432		}
433		return client.NewStreamableHttpClient(
434			m.URL,
435			transport.WithHTTPHeaders(m.ResolvedHeaders()),
436			transport.WithHTTPLogger(mcpLogger{name: name}),
437		)
438	case config.MCPSse:
439		if strings.TrimSpace(m.URL) == "" {
440			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
441		}
442		return client.NewSSEMCPClient(
443			m.URL,
444			client.WithHeaders(m.ResolvedHeaders()),
445			transport.WithSSELogger(mcpLogger{name: name}),
446		)
447	default:
448		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
449	}
450}
451
// mcpLogger forwards printf-style log calls to slog. It is handed to the MCP
// client transports as their logger and tags every record with the name of
// the MCP it belongs to.
type mcpLogger struct {
	// name is attached to each record under the "name" key.
	name string
}

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
462
463func mcpTimeout(m config.MCPConfig) time.Duration {
464	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
465}