mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/mark3labs/mcp-go/client"
 24	"github.com/mark3labs/mcp-go/client/transport"
 25	"github.com/mark3labs/mcp-go/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
 53// MCPEventType represents the type of MCP event
 54type MCPEventType string
 55
 56const (
 57	MCPEventStateChanged MCPEventType = "state_changed"
 58)
 59
 60// MCPEvent represents an event in the MCP system
 61type MCPEvent struct {
 62	Type      MCPEventType
 63	Name      string
 64	State     MCPState
 65	Error     error
 66	ToolCount int
 67}
 68
 69// MCPClientInfo holds information about an MCP client's state
 70type MCPClientInfo struct {
 71	Name        string
 72	State       MCPState
 73	Error       error
 74	Client      *client.Client
 75	ToolCount   int
 76	ConnectedAt time.Time
 77}
 78
 79var (
 80	mcpToolsOnce sync.Once
 81	mcpTools     []tools.BaseTool
 82	mcpClients   = csync.NewMap[string, *client.Client]()
 83	mcpStates    = csync.NewMap[string, MCPClientInfo]()
 84	mcpBroker    = pubsub.NewBroker[MCPEvent]()
 85)
 86
 87type McpTool struct {
 88	mcpName     string
 89	tool        mcp.Tool
 90	permissions permission.Service
 91	workingDir  string
 92}
 93
 94func (b *McpTool) Name() string {
 95	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 96}
 97
 98func (b *McpTool) Info() tools.ToolInfo {
 99	required := b.tool.InputSchema.Required
100	if required == nil {
101		required = make([]string, 0)
102	}
103	parameters := b.tool.InputSchema.Properties
104	if parameters == nil {
105		parameters = make(map[string]any)
106	}
107	return tools.ToolInfo{
108		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
109		Description: b.tool.Description,
110		Parameters:  parameters,
111		Required:    required,
112	}
113}
114
115func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
116	var args map[string]any
117	if err := json.Unmarshal([]byte(input), &args); err != nil {
118		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
119	}
120
121	c, err := getOrRenewClient(ctx, name)
122	if err != nil {
123		return tools.NewTextErrorResponse(err.Error()), nil
124	}
125	result, err := c.CallTool(ctx, mcp.CallToolRequest{
126		Params: mcp.CallToolParams{
127			Name:      toolName,
128			Arguments: args,
129		},
130	})
131	if err != nil {
132		return tools.NewTextErrorResponse(err.Error()), nil
133	}
134
135	output := make([]string, 0, len(result.Content))
136	for _, v := range result.Content {
137		if v, ok := v.(mcp.TextContent); ok {
138			output = append(output, v.Text)
139		} else {
140			output = append(output, fmt.Sprintf("%v", v))
141		}
142	}
143	return tools.NewTextResponse(strings.Join(output, "\n")), nil
144}
145
146func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
147	c, ok := mcpClients.Get(name)
148	if !ok {
149		return nil, fmt.Errorf("mcp '%s' not available", name)
150	}
151
152	m := config.Get().MCP[name]
153	state, _ := mcpStates.Get(name)
154
155	timeout := mcpTimeout(m)
156	pingCtx, cancel := context.WithTimeout(ctx, timeout)
157	defer cancel()
158	err := c.Ping(pingCtx)
159	if err == nil {
160		return c, nil
161	}
162	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
163
164	c, err = createAndInitializeClient(ctx, name, m)
165	if err != nil {
166		return nil, err
167	}
168
169	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
170	mcpClients.Set(name, c)
171	return c, nil
172}
173
174func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
175	sessionID, messageID := tools.GetContextValues(ctx)
176	if sessionID == "" || messageID == "" {
177		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
178	}
179	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
180	p := b.permissions.Request(
181		permission.CreatePermissionRequest{
182			SessionID:   sessionID,
183			ToolCallID:  params.ID,
184			Path:        b.workingDir,
185			ToolName:    b.Info().Name,
186			Action:      "execute",
187			Description: permissionDescription,
188			Params:      params.Input,
189		},
190	)
191	if !p {
192		return tools.ToolResponse{}, permission.ErrorPermissionDenied
193	}
194
195	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
196}
197
198func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
199	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
200	if err != nil {
201		slog.Error("error listing tools", "error", err)
202		updateMCPState(name, MCPStateError, err, nil, 0)
203		c.Close()
204		mcpClients.Del(name)
205		return nil
206	}
207	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
208	for _, tool := range result.Tools {
209		mcpTools = append(mcpTools, &McpTool{
210			mcpName:     name,
211			tool:        tool,
212			permissions: permissions,
213			workingDir:  workingDir,
214		})
215	}
216	return mcpTools
217}
218
219// SubscribeMCPEvents returns a channel for MCP events
220func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
221	return mcpBroker.Subscribe(ctx)
222}
223
224// GetMCPStates returns the current state of all MCP clients
225func GetMCPStates() map[string]MCPClientInfo {
226	return maps.Collect(mcpStates.Seq2())
227}
228
229// GetMCPState returns the state of a specific MCP client
230func GetMCPState(name string) (MCPClientInfo, bool) {
231	return mcpStates.Get(name)
232}
233
234// updateMCPState updates the state of an MCP client and publishes an event
235func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
236	info := MCPClientInfo{
237		Name:      name,
238		State:     state,
239		Error:     err,
240		Client:    client,
241		ToolCount: toolCount,
242	}
243	if state == MCPStateConnected {
244		info.ConnectedAt = time.Now()
245	}
246	mcpStates.Set(name, info)
247
248	// Publish state change event
249	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
250		Type:      MCPEventStateChanged,
251		Name:      name,
252		State:     state,
253		Error:     err,
254		ToolCount: toolCount,
255	})
256}
257
258// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
259func CloseMCPClients() error {
260	var errs []error
261	for c := range mcpClients.Seq() {
262		if err := c.Close(); err != nil {
263			errs = append(errs, err)
264		}
265	}
266	mcpBroker.Shutdown()
267	return errors.Join(errs...)
268}
269
270var mcpInitRequest = mcp.InitializeRequest{
271	Params: mcp.InitializeParams{
272		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
273		ClientInfo: mcp.Implementation{
274			Name:    "Crush",
275			Version: version.Version,
276		},
277	},
278}
279
280func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
281	var wg sync.WaitGroup
282	result := csync.NewSlice[tools.BaseTool]()
283
284	// Initialize states for all configured MCPs
285	for name, m := range cfg.MCP {
286		if m.Disabled {
287			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
288			slog.Debug("skipping disabled mcp", "name", name)
289			continue
290		}
291
292		// Set initial starting state
293		updateMCPState(name, MCPStateStarting, nil, nil, 0)
294
295		wg.Add(1)
296		go func(name string, m config.MCPConfig) {
297			defer func() {
298				wg.Done()
299				if r := recover(); r != nil {
300					var err error
301					switch v := r.(type) {
302					case error:
303						err = v
304					case string:
305						err = fmt.Errorf("panic: %s", v)
306					default:
307						err = fmt.Errorf("panic: %v", v)
308					}
309					updateMCPState(name, MCPStateError, err, nil, 0)
310					slog.Error("panic in mcp client initialization", "error", err, "name", name)
311				}
312			}()
313
314			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
315			defer cancel()
316			c, err := createAndInitializeClient(ctx, name, m)
317			if err != nil {
318				return
319			}
320			mcpClients.Set(name, c)
321
322			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
323			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
324			result.Append(tools...)
325		}(name, m)
326	}
327	wg.Wait()
328	return slices.Collect(result.Seq())
329}
330
331func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
332	c, err := createMcpClient(name, m)
333	if err != nil {
334		updateMCPState(name, MCPStateError, err, nil, 0)
335		slog.Error("error creating mcp client", "error", err, "name", name)
336		return nil, err
337	}
338
339	timeout := mcpTimeout(m)
340	initCtx, cancel := context.WithTimeout(ctx, timeout)
341	defer cancel()
342
343	// Only call Start() for non-stdio clients, as stdio clients auto-start
344	if m.Type != config.MCPStdio {
345		if err := c.Start(initCtx); err != nil {
346			updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
347			slog.Error("error starting mcp client", "error", err, "name", name)
348			_ = c.Close()
349			return nil, err
350		}
351	}
352	if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
353		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
354		slog.Error("error initializing mcp client", "error", err, "name", name)
355		_ = c.Close()
356		return nil, err
357	}
358
359	slog.Info("Initialized mcp client", "name", name)
360	return c, nil
361}
362
// maybeTimeoutErr turns a deadline error into a readable "timed out after
// <timeout>" error. Any error for which errors.Is(err, context.DeadlineExceeded)
// is false, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
369
370func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
371	switch m.Type {
372	case config.MCPStdio:
373		if strings.TrimSpace(m.Command) == "" {
374			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
375		}
376		return client.NewStdioMCPClientWithOptions(
377			home.Long(m.Command),
378			m.ResolvedEnv(),
379			m.Args,
380			transport.WithCommandLogger(mcpLogger{name: name}),
381		)
382	case config.MCPHttp:
383		if strings.TrimSpace(m.URL) == "" {
384			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
385		}
386		return client.NewStreamableHttpClient(
387			m.URL,
388			transport.WithHTTPHeaders(m.ResolvedHeaders()),
389			transport.WithHTTPLogger(mcpLogger{name: name}),
390		)
391	case config.MCPSse:
392		if strings.TrimSpace(m.URL) == "" {
393			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
394		}
395		return client.NewSSEMCPClient(
396			m.URL,
397			client.WithHeaders(m.ResolvedHeaders()),
398			transport.WithSSELogger(mcpLogger{name: name}),
399		)
400	default:
401		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
402	}
403}
404
// mcpLogger adapts log/slog to the printf-style logger that the MCP client
// transports are configured with in createMcpClient. Every message carries a
// "name" attribute identifying the MCP it came from.
type mcpLogger struct{ name string }

// Errorf formats the message with fmt.Sprintf and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message with fmt.Sprintf and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
415
416func mcpTimeout(m config.MCPConfig) time.Duration {
417	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
418}