mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/mark3labs/mcp-go/client"
 24	"github.com/mark3labs/mcp-go/client/transport"
 25	"github.com/mark3labs/mcp-go/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
// MCPEventType identifies the kind of an MCPEvent.
type MCPEventType string

const (
	// MCPEventStateChanged is the type of the event published by
	// updateMCPState every time an MCP client's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MCPEvent is published on the MCP event broker (subscribe with
// SubscribeMCPEvents) each time updateMCPState records a new state.
type MCPEvent struct {
	// Type is the event kind; updateMCPState always uses MCPEventStateChanged.
	Type MCPEventType
	// Name is the MCP server name (its key in the MCP configuration).
	Name string
	// State is the newly recorded state.
	State MCPState
	// Error is the error associated with the new state, if any.
	Error error
	// ToolCount is the number of tools the caller reported for this MCP.
	ToolCount int
}
 68
// MCPClientInfo is the snapshot of an MCP client's state that updateMCPState
// stores in mcpStates and that GetMCPStates / GetMCPState return.
type MCPClientInfo struct {
	// Name is the MCP server name (its key in the MCP configuration).
	Name string
	// State is the most recently recorded lifecycle state.
	State MCPState
	// Error explains an MCPStateError state; callers in this file pass nil
	// for every other state.
	Error error
	// Client is the live client; callers in this file only set it together
	// with MCPStateConnected.
	Client *client.Client
	// ToolCount is the number of tools the caller reported for this MCP.
	ToolCount int
	// ConnectedAt is stamped by updateMCPState on a transition to
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
 78
var (
	// NOTE(review): mcpToolsOnce and mcpTools are not referenced in the code
	// visible here; presumably they memoize a one-time doGetMCPTools call
	// elsewhere in the package — TODO confirm.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients holds the live client for each MCP server, keyed by name.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates holds the latest MCPClientInfo for each configured MCP, keyed by name.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes an MCPEvent for every recorded state change.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
 86
// McpTool exposes a single tool of an MCP server as a tools.BaseTool.
// getTools creates one for every tool the server lists.
type McpTool struct {
	// mcpName is the MCP server name, used to look up its client and config.
	mcpName string
	// tool is the tool definition as reported by the server.
	tool mcp.Tool
	// permissions is asked for approval before every call.
	permissions permission.Service
	// workingDir is reported as the Path of permission requests.
	workingDir string
}
 93
 94func (b *McpTool) Name() string {
 95	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 96}
 97
 98func (b *McpTool) Info() tools.ToolInfo {
 99	required := b.tool.InputSchema.Required
100	if required == nil {
101		required = make([]string, 0)
102	}
103	parameters := b.tool.InputSchema.Properties
104	if parameters == nil {
105		parameters = make(map[string]any)
106	}
107	return tools.ToolInfo{
108		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
109		Description: b.tool.Description,
110		Parameters:  parameters,
111		Required:    required,
112	}
113}
114
// runTool calls toolName on the MCP server registered as name and returns the
// textual result.
//
// input is the raw JSON argument object supplied by the model. Blank input is
// sent as an empty object. Text content items are returned verbatim and any
// other content is rendered with %v, one item per line.
//
// Failures are all returned as text error responses with a nil Go error, so
// the model can see and react to them. This covers unparsable input, an
// unavailable client, a failed call, and a result the server flagged with
// isError.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	args := make(map[string]any)
	// Some models send "" for tools without parameters; json.Unmarshal would
	// reject that, so only decode non-blank input.
	if strings.TrimSpace(input) != "" {
		if err := json.Unmarshal([]byte(input), &args); err != nil {
			return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
		}
	}

	c, err := getOrRenewClient(ctx, name)
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}
	result, err := c.CallTool(ctx, mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      toolName,
			Arguments: args,
		},
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	output := make([]string, 0, len(result.Content))
	for _, content := range result.Content {
		switch v := content.(type) {
		case mcp.TextContent:
			output = append(output, v.Text)
		default:
			output = append(output, fmt.Sprintf("%v", v))
		}
	}
	text := strings.Join(output, "\n")

	// isError means the tool ran but failed. Report it as an error so the
	// model does not treat the output as a successful result.
	if result.IsError {
		return tools.NewTextErrorResponse(text), nil
	}
	return tools.NewTextResponse(text), nil
}
145
// getOrRenewClient returns a usable client for the named MCP.
//
// It pings the cached client, bounded by mcpTimeout. If the ping fails, it
// records MCPStateError and tries to create and initialize a replacement. On
// success the replacement is stored in mcpClients and marked connected. The
// previous ToolCount is kept.
//
// The stale client is closed once it has been replaced. Otherwise it would no
// longer be reachable from mcpClients, and CloseMCPClients could never release
// it. If renewal fails, the old client is left in place and the next call
// pings it again.
func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
	c, ok := mcpClients.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	cfg := config.Get()
	m := cfg.MCP[name]
	state, _ := mcpStates.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := c.Ping(pingCtx)
	if err == nil {
		return c, nil
	}
	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)

	renewed, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	updateMCPState(name, MCPStateConnected, nil, renewed, state.ToolCount)
	mcpClients.Set(name, renewed)

	// Close the unhealthy client in the background. Closing may wait for a
	// hung server (e.g. a stdio subprocess to exit), and that must not block
	// this tool call.
	go func(stale *client.Client) {
		if cerr := stale.Close(); cerr != nil {
			slog.Debug("error closing stale mcp client", "error", cerr, "name", name)
		}
	}(c)
	return renewed, nil
}
174
// Run executes the MCP tool on behalf of the model.
//
// ctx must carry session and message IDs (read via tools.GetContextValues).
// The permission service is asked for approval with the raw input as Params.
// If the request is refused, permission.ErrorPermissionDenied is returned.
// Otherwise the call is forwarded to the MCP server through runTool.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	toolName := b.Name()

	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for running %s", toolName)
	}

	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters:", toolName),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
198
// getTools lists the tools exposed by the MCP client c. Each one is wrapped in
// an McpTool bound to name, permissions and workingDir.
//
// If listing fails, the MCP is marked MCPStateError, c is closed and removed
// from mcpClients, and nil is returned. On success the returned slice is
// non-nil, though possibly empty.
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		slog.Error("error listing tools", "error", err, "name", name)
		updateMCPState(name, MCPStateError, err, nil, 0)
		// Best effort: the listing error is what matters, but don't lose the
		// close failure entirely.
		if cerr := c.Close(); cerr != nil {
			slog.Debug("error closing mcp client", "error", cerr, "name", name)
		}
		mcpClients.Del(name)
		return nil
	}

	// Deliberately not named mcpTools, to avoid shadowing the package-level variable.
	discovered := make([]tools.BaseTool, 0, len(result.Tools))
	for _, tool := range result.Tools {
		discovered = append(discovered, &McpTool{
			mcpName:     name,
			tool:        tool,
			permissions: permissions,
			workingDir:  workingDir,
		})
	}
	return discovered
}
219
// SubscribeMCPEvents subscribes to MCP state-change events on the package
// broker and returns the channel they are delivered on.
// NOTE(review): the subscription is tied to ctx; presumably the broker stops
// delivering (and closes the channel) once ctx is done — confirm in pubsub.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}
224
// GetMCPStates returns a freshly built map of every recorded MCP client
// state, keyed by MCP name. Modifying the returned map does not affect the
// stored states.
func GetMCPStates() map[string]MCPClientInfo {
	snapshot := make(map[string]MCPClientInfo)
	maps.Insert(snapshot, mcpStates.Seq2())
	return snapshot
}
229
230// GetMCPState returns the state of a specific MCP client
231func GetMCPState(name string) (MCPClientInfo, bool) {
232	return mcpStates.Get(name)
233}
234
235// updateMCPState updates the state of an MCP client and publishes an event
236func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
237	info := MCPClientInfo{
238		Name:      name,
239		State:     state,
240		Error:     err,
241		Client:    client,
242		ToolCount: toolCount,
243	}
244	if state == MCPStateConnected {
245		info.ConnectedAt = time.Now()
246	}
247	mcpStates.Set(name, info)
248
249	// Publish state change event
250	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
251		Type:      MCPEventStateChanged,
252		Name:      name,
253		State:     state,
254		Error:     err,
255		ToolCount: toolCount,
256	})
257}
258
// CloseMCPClients closes all MCP clients and shuts down the MCP event broker.
// This should be called during application shutdown.
//
// Each client is removed from mcpClients after Close, even if Close reported
// an error. A tool call racing with shutdown then gets "mcp not available".
// Without this, getOrRenewClient would ping a closed client and create a
// replacement. Close errors are joined and returned.
func CloseMCPClients() error {
	var errs []error
	var closed []string
	for name, c := range mcpClients.Seq2() {
		if err := c.Close(); err != nil {
			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
		}
		closed = append(closed, name)
	}
	// Delete after the range loop rather than inside it, so this cannot
	// conflict with any locking the iterator might do.
	for _, name := range closed {
		mcpClients.Del(name)
	}
	mcpBroker.Shutdown()
	return errors.Join(errs...)
}
270
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It requests mcp.LATEST_PROTOCOL_VERSION and
// identifies this client as "Crush" at the current build version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
280
// doGetMCPTools connects to every enabled MCP server in cfg concurrently and
// returns the tools they expose, wrapped as McpTool values.
//
// Disabled servers are recorded as MCPStateDisabled. Every other server starts
// as MCPStateStarting and gets its own goroutine. A single mcpTimeout(m)
// budget covers both the initialize handshake and the tools listing.
//
// A server that fails, including by panicking, is left in MCPStateError and
// contributes no tools. The call blocks until every server has either
// connected or failed. Tool order across servers is unspecified.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	var wg sync.WaitGroup
	result := csync.NewSlice[tools.BaseTool]()

	// Initialize states for all configured MCPs.
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		// Set initial starting state.
		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred calls run last-in-first-out. The recover handler below
			// runs before wg.Done, so a panic's error state is recorded before
			// wg.Wait returns.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			initCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
			defer cancel()
			c, err := createAndInitializeClient(initCtx, name, m, cfg.Resolver())
			if err != nil {
				// createAndInitializeClient has already recorded the error state.
				return
			}
			mcpClients.Set(name, c)

			clientTools := getTools(initCtx, name, permissions, c, cfg.WorkingDir())
			// On a listing failure getTools records MCPStateError and closes the
			// client. Don't overwrite that with a "connected" state.
			if info, ok := mcpStates.Get(name); ok && info.State == MCPStateError {
				return
			}
			updateMCPState(name, MCPStateConnected, nil, c, len(clientTools))
			result.Append(clientTools...)
		}(name, m)
	}
	wg.Wait()
	return slices.Collect(result.Seq())
}
331
// createAndInitializeClient creates the client for the named MCP and prepares
// it for use. Non-stdio clients are started first (stdio clients auto-start),
// then the MCP initialize handshake is performed, bounded by mcpTimeout(m).
//
// On any failure it records MCPStateError for name, logs the raw error,
// closes a partially set-up client and returns the error. Deadline errors are
// converted with maybeTimeoutErr, so the returned error matches the message
// recorded in the state.
//
// On success the client is returned. Storing it and publishing the connected
// state is left to the caller.
func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
	c, err := createMcpClient(name, m, resolver)
	if err != nil {
		updateMCPState(name, MCPStateError, err, nil, 0)
		slog.Error("error creating mcp client", "error", err, "name", name)
		return nil, err
	}

	timeout := mcpTimeout(m)
	initCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// fail records and logs a setup error, releases c, and returns the
	// timeout-aware error.
	fail := func(msg string, err error) (*client.Client, error) {
		reported := maybeTimeoutErr(err, timeout)
		updateMCPState(name, MCPStateError, reported, nil, 0)
		slog.Error(msg, "error", err, "name", name)
		_ = c.Close() // best effort: the setup error is what the caller needs
		return nil, reported
	}

	// Only call Start() for non-stdio clients, as stdio clients auto-start.
	if m.Type != config.MCPStdio {
		if err := c.Start(initCtx); err != nil {
			return fail("error starting mcp client", err)
		}
	}
	if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
		return fail("error initializing mcp client", err)
	}

	slog.Info("Initialized mcp client", "name", name)
	return c, nil
}
363
// maybeTimeoutErr turns an error caused by context.DeadlineExceeded (directly
// or wrapped) into a readable "timed out after <timeout>" error. Any other
// error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
370
// createMcpClient builds a client for the MCP server described by m, but does
// not start or initialize it. The transport is chosen by m.Type:
//
//   - config.MCPStdio: runs m.Command (resolved through resolver, then passed
//     through home.Long) with m.ResolvedEnv() and m.Args.
//   - config.MCPHttp: streamable HTTP client for m.URL with m.ResolvedHeaders().
//   - config.MCPSse: SSE client for m.URL with m.ResolvedHeaders().
//
// A blank command or URL, or an unknown type, is rejected with an error.
// Transport log output is sent to slog through mcpLogger, tagged with name.
func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
	logger := mcpLogger{name: name}
	blank := func(s string) bool { return strings.TrimSpace(s) == "" }

	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		if err != nil {
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		}
		if blank(command) {
			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
		}
		return client.NewStdioMCPClientWithOptions(
			home.Long(command),
			m.ResolvedEnv(),
			m.Args,
			transport.WithCommandLogger(logger),
		)

	case config.MCPHttp:
		if blank(m.URL) {
			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
		}
		headers := m.ResolvedHeaders()
		return client.NewStreamableHttpClient(
			m.URL,
			transport.WithHTTPHeaders(headers),
			transport.WithHTTPLogger(logger),
		)

	case config.MCPSse:
		if blank(m.URL) {
			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
		}
		headers := m.ResolvedHeaders()
		return client.NewSSEMCPClient(
			m.URL,
			client.WithHeaders(headers),
			transport.WithSSELogger(logger),
		)
	}

	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
409
410// for MCP's clients.
411type mcpLogger struct{ name string }
412
413func (l mcpLogger) Errorf(format string, v ...any) {
414	slog.Error(fmt.Sprintf(format, v...), "name", l.name)
415}
416
417func (l mcpLogger) Infof(format string, v ...any) {
418	slog.Info(fmt.Sprintf(format, v...), "name", l.name)
419}
420
// mcpTimeout returns the timeout applied to MCP operations for server m, such
// as the ping in getOrRenewClient and the start/initialize step in
// createAndInitializeClient. It is m.Timeout seconds, or 15 seconds when the
// configured value is zero (unset) or negative. Negative values would
// otherwise yield an already-expired deadline and make every attempt fail
// immediately.
func mcpTimeout(m config.MCPConfig) time.Duration {
	return time.Duration(cmp.Or(max(m.Timeout, 0), 15)) * time.Second
}