mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/home"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/pubsub"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/mark3labs/mcp-go/client"
 23	"github.com/mark3labs/mcp-go/client/transport"
 24	"github.com/mark3labs/mcp-go/mcp"
 25)
 26
 27// MCPState represents the current state of an MCP client
 28type MCPState int
 29
 30const (
 31	MCPStateDisabled MCPState = iota
 32	MCPStateStarting
 33	MCPStateConnected
 34	MCPStateError
 35)
 36
 37func (s MCPState) String() string {
 38	switch s {
 39	case MCPStateDisabled:
 40		return "disabled"
 41	case MCPStateStarting:
 42		return "starting"
 43	case MCPStateConnected:
 44		return "connected"
 45	case MCPStateError:
 46		return "error"
 47	default:
 48		return "unknown"
 49	}
 50}
 51
 52// MCPEventType represents the type of MCP event
 53type MCPEventType string
 54
 55const (
 56	MCPEventStateChanged     MCPEventType = "state_changed"
 57	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
 58)
 59
 60// MCPEvent represents an event in the MCP system
 61type MCPEvent struct {
 62	Type      MCPEventType
 63	Name      string
 64	State     MCPState
 65	Error     error
 66	ToolCount int
 67}
 68
 69// MCPClientInfo holds information about an MCP client's state
 70type MCPClientInfo struct {
 71	Name        string
 72	State       MCPState
 73	Error       error
 74	Client      *client.Client
 75	ToolCount   int
 76	ConnectedAt time.Time
 77}
 78
 79var (
 80	mcpToolsOnce    sync.Once
 81	mcpTools        = csync.NewMap[string, tools.BaseTool]()
 82	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
 83	mcpClients      = csync.NewMap[string, *client.Client]()
 84	mcpStates       = csync.NewMap[string, MCPClientInfo]()
 85	mcpBroker       = pubsub.NewBroker[MCPEvent]()
 86)
 87
 88type McpTool struct {
 89	mcpName     string
 90	tool        mcp.Tool
 91	permissions permission.Service
 92	workingDir  string
 93}
 94
 95func (b *McpTool) Name() string {
 96	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 97}
 98
 99func (b *McpTool) Info() tools.ToolInfo {
100	required := b.tool.InputSchema.Required
101	if required == nil {
102		required = make([]string, 0)
103	}
104	parameters := b.tool.InputSchema.Properties
105	if parameters == nil {
106		parameters = make(map[string]any)
107	}
108	return tools.ToolInfo{
109		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
110		Description: b.tool.Description,
111		Parameters:  parameters,
112		Required:    required,
113	}
114}
115
116func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
117	var args map[string]any
118	if err := json.Unmarshal([]byte(input), &args); err != nil {
119		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
120	}
121
122	c, err := getOrRenewClient(ctx, name)
123	if err != nil {
124		return tools.NewTextErrorResponse(err.Error()), nil
125	}
126	result, err := c.CallTool(ctx, mcp.CallToolRequest{
127		Params: mcp.CallToolParams{
128			Name:      toolName,
129			Arguments: args,
130		},
131	})
132	if err != nil {
133		return tools.NewTextErrorResponse(err.Error()), nil
134	}
135
136	output := make([]string, 0, len(result.Content))
137	for _, v := range result.Content {
138		if v, ok := v.(mcp.TextContent); ok {
139			output = append(output, v.Text)
140		} else {
141			output = append(output, fmt.Sprintf("%v", v))
142		}
143	}
144	return tools.NewTextResponse(strings.Join(output, "\n")), nil
145}
146
147func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
148	c, ok := mcpClients.Get(name)
149	if !ok {
150		return nil, fmt.Errorf("mcp '%s' not available", name)
151	}
152
153	cfg := config.Get()
154	m := cfg.MCP[name]
155	state, _ := mcpStates.Get(name)
156
157	timeout := mcpTimeout(m)
158	pingCtx, cancel := context.WithTimeout(ctx, timeout)
159	defer cancel()
160	err := c.Ping(pingCtx)
161	if err == nil {
162		return c, nil
163	}
164	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
165
166	c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
167	if err != nil {
168		return nil, err
169	}
170
171	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
172	mcpClients.Set(name, c)
173	return c, nil
174}
175
176func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
177	sessionID, messageID := tools.GetContextValues(ctx)
178	if sessionID == "" || messageID == "" {
179		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
180	}
181	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
182	p := b.permissions.Request(
183		permission.CreatePermissionRequest{
184			SessionID:   sessionID,
185			ToolCallID:  params.ID,
186			Path:        b.workingDir,
187			ToolName:    b.Info().Name,
188			Action:      "execute",
189			Description: permissionDescription,
190			Params:      params.Input,
191		},
192	)
193	if !p {
194		return tools.ToolResponse{}, permission.ErrorPermissionDenied
195	}
196
197	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
198}
199
200func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) {
201	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
202	if err != nil {
203		return nil, err
204	}
205	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
206	for _, tool := range result.Tools {
207		mcpTools = append(mcpTools, &McpTool{
208			mcpName:     name,
209			tool:        tool,
210			permissions: permissions,
211			workingDir:  workingDir,
212		})
213	}
214	return mcpTools, nil
215}
216
217// SubscribeMCPEvents returns a channel for MCP events
218func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
219	return mcpBroker.Subscribe(ctx)
220}
221
222// GetMCPStates returns the current state of all MCP clients
223func GetMCPStates() map[string]MCPClientInfo {
224	return maps.Collect(mcpStates.Seq2())
225}
226
227// GetMCPState returns the state of a specific MCP client
228func GetMCPState(name string) (MCPClientInfo, bool) {
229	return mcpStates.Get(name)
230}
231
232// updateMCPState updates the state of an MCP client and publishes an event
233func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
234	info := MCPClientInfo{
235		Name:      name,
236		State:     state,
237		Error:     err,
238		Client:    client,
239		ToolCount: toolCount,
240	}
241	switch state {
242	case MCPStateConnected:
243		info.ConnectedAt = time.Now()
244	case MCPStateError:
245		updateMcpTools(name, nil)
246		mcpClients.Del(name)
247	}
248	mcpStates.Set(name, info)
249
250	// Publish state change event
251	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
252		Type:      MCPEventStateChanged,
253		Name:      name,
254		State:     state,
255		Error:     err,
256		ToolCount: toolCount,
257	})
258}
259
260// publishMCPEventToolsListChanged publishes a tool list changed event
261func publishMCPEventToolsListChanged(name string) {
262	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
263		Type: MCPEventToolsListChanged,
264		Name: name,
265	})
266}
267
268// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
269func CloseMCPClients() error {
270	var errs []error
271	for name, c := range mcpClients.Seq2() {
272		if err := c.Close(); err != nil {
273			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
274		}
275	}
276	mcpBroker.Shutdown()
277	return errors.Join(errs...)
278}
279
280var mcpInitRequest = mcp.InitializeRequest{
281	Params: mcp.InitializeParams{
282		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
283		ClientInfo: mcp.Implementation{
284			Name:    "Crush",
285			Version: version.Version,
286		},
287	},
288}
289
290func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
291	var wg sync.WaitGroup
292	// Initialize states for all configured MCPs
293	for name, m := range cfg.MCP {
294		if m.Disabled {
295			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
296			slog.Debug("skipping disabled mcp", "name", name)
297			continue
298		}
299
300		// Set initial starting state
301		updateMCPState(name, MCPStateStarting, nil, nil, 0)
302
303		wg.Add(1)
304		go func(name string, m config.MCPConfig) {
305			defer func() {
306				wg.Done()
307				if r := recover(); r != nil {
308					var err error
309					switch v := r.(type) {
310					case error:
311						err = v
312					case string:
313						err = fmt.Errorf("panic: %s", v)
314					default:
315						err = fmt.Errorf("panic: %v", v)
316					}
317					updateMCPState(name, MCPStateError, err, nil, 0)
318					slog.Error("panic in mcp client initialization", "error", err, "name", name)
319				}
320			}()
321
322			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
323			defer cancel()
324
325			c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
326			if err != nil {
327				return
328			}
329
330			mcpClients.Set(name, c)
331
332			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
333			if err != nil {
334				slog.Error("error listing tools", "error", err)
335				updateMCPState(name, MCPStateError, err, nil, 0)
336				c.Close()
337				return
338			}
339
340			updateMcpTools(name, tools)
341			mcpClients.Set(name, c)
342			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
343		}(name, m)
344	}
345	wg.Wait()
346}
347
348// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
349func updateMcpTools(mcpName string, tools []tools.BaseTool) {
350	if len(tools) == 0 {
351		mcpClient2Tools.Del(mcpName)
352	} else {
353		mcpClient2Tools.Set(mcpName, tools)
354	}
355	for _, tools := range mcpClient2Tools.Seq2() {
356		for _, t := range tools {
357			mcpTools.Set(t.Name(), t)
358		}
359	}
360}
361
362func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
363	c, err := createMcpClient(name, m, resolver)
364	if err != nil {
365		updateMCPState(name, MCPStateError, err, nil, 0)
366		slog.Error("error creating mcp client", "error", err, "name", name)
367		return nil, err
368	}
369
370	c.OnNotification(func(n mcp.JSONRPCNotification) {
371		slog.Debug("Received MCP notification", "name", name, "notification", n)
372		switch n.Method {
373		case "notifications/tools/list_changed":
374			publishMCPEventToolsListChanged(name)
375		default:
376			slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
377		}
378	})
379
380	timeout := mcpTimeout(m)
381	initCtx, cancel := context.WithTimeout(ctx, timeout)
382	defer cancel()
383
384	if err := c.Start(initCtx); err != nil {
385		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
386		slog.Error("error starting mcp client", "error", err, "name", name)
387		_ = c.Close()
388		return nil, err
389	}
390
391	if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
392		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
393		slog.Error("error initializing mcp client", "error", err, "name", name)
394		_ = c.Close()
395		return nil, err
396	}
397
398	slog.Info("Initialized mcp client", "name", name)
399	return c, nil
400}
401
// mcpTimeoutError reports that an MCP operation exceeded its deadline. Its
// message is just "timed out after <timeout>". The original error stays
// reachable through Unwrap, so errors.Is(err, context.DeadlineExceeded) still
// holds for callers.
type mcpTimeoutError struct {
	timeout time.Duration
	err     error
}

// Error returns the concise, user-facing timeout message.
func (e *mcpTimeoutError) Error() string {
	return fmt.Sprintf("timed out after %s", e.timeout)
}

// Unwrap exposes the underlying deadline error.
func (e *mcpTimeoutError) Unwrap() error { return e.err }

// maybeTimeoutErr rewrites err into a "timed out after <timeout>" error when
// it stems from context.DeadlineExceeded. Any other err, including nil, is
// returned unchanged. Unlike a plain fmt.Errorf without %w, the rewritten
// error still wraps err.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return &mcpTimeoutError{timeout: timeout, err: err}
}
408
409func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
410	switch m.Type {
411	case config.MCPStdio:
412		command, err := resolver.ResolveValue(m.Command)
413		if err != nil {
414			return nil, fmt.Errorf("invalid mcp command: %w", err)
415		}
416		if strings.TrimSpace(command) == "" {
417			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
418		}
419		return client.NewStdioMCPClientWithOptions(
420			home.Long(command),
421			m.ResolvedEnv(),
422			m.Args,
423			transport.WithCommandLogger(mcpLogger{name: name}),
424		)
425	case config.MCPHttp:
426		if strings.TrimSpace(m.URL) == "" {
427			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
428		}
429		return client.NewStreamableHttpClient(
430			m.URL,
431			transport.WithHTTPHeaders(m.ResolvedHeaders()),
432			transport.WithHTTPLogger(mcpLogger{name: name}),
433		)
434	case config.MCPSse:
435		if strings.TrimSpace(m.URL) == "" {
436			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
437		}
438		return client.NewSSEMCPClient(
439			m.URL,
440			client.WithHeaders(m.ResolvedHeaders()),
441			transport.WithSSELogger(mcpLogger{name: name}),
442		)
443	default:
444		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
445	}
446}
447
// mcpLogger adapts log/slog to the logger option accepted by the mcp-go
// transports (see createMcpClient). Every record is tagged with the
// configured MCP server name.
type mcpLogger struct{ name string }

// Errorf formats the message with fmt.Sprintf and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	slog.Error(fmt.Sprintf(format, v...), "name", l.name)
}

// Infof formats the message with fmt.Sprintf and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	slog.Info(fmt.Sprintf(format, v...), "name", l.name)
}
458
459func mcpTimeout(m config.MCPConfig) time.Duration {
460	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
461}