mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/mark3labs/mcp-go/client"
 24	"github.com/mark3labs/mcp-go/client/transport"
 25	"github.com/mark3labs/mcp-go/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
// MCPEventType identifies the kind of an MCPEvent.
type MCPEventType string

const (
	// MCPEventStateChanged is the type of the event published by
	// updateMCPState each time an MCP client's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MCPEvent is the payload published on the MCP event broker; subscribe with
// SubscribeMCPEvents.
type MCPEvent struct {
	Type      MCPEventType // always MCPEventStateChanged in this file
	Name      string       // configured name of the MCP server
	State     MCPState     // state the client was moved to
	Error     error        // error behind MCPStateError; nil for the other states set in this file
	ToolCount int          // tool count recorded with the state (may be carried over from a previous connection)
}
 68
// MCPClientInfo is the most recent state recorded for one MCP client. It is
// stored in mcpStates and returned by GetMCPState and GetMCPStates.
type MCPClientInfo struct {
	Name        string         // configured name of the MCP server
	State       MCPState       // current lifecycle state
	Error       error          // error behind MCPStateError; nil for the other states set in this file
	Client      *client.Client // live client when connected; nil for the other states set in this file
	ToolCount   int            // tool count recorded with the state
	ConnectedAt time.Time      // when the state was recorded, if State is MCPStateConnected; zero otherwise
}
 78
// Package-level MCP bookkeeping shared by the functions in this file.
var (
	// NOTE(review): mcpToolsOnce and mcpTools are not referenced in this part
	// of the file; presumably they cache the result of doGetMCPTools behind a
	// sync.Once — confirm against the caller.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients holds the live client of each MCP that connected, keyed by
	// the MCP's configured name.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates holds the latest MCPClientInfo of every configured MCP,
	// keyed by name; written only through updateMCPState.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker carries the events published by updateMCPState; it is
	// subscribed to via SubscribeMCPEvents and shut down by CloseMCPClients.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
 86
// McpTool adapts a single tool exposed by an MCP server so it can be used as
// a tools.BaseTool by the agent.
type McpTool struct {
	mcpName     string             // configured name of the MCP server providing the tool
	tool        mcp.Tool           // tool definition as returned by the server's ListTools
	permissions permission.Service // asked for approval before every call
	workingDir  string             // reported as the Path of permission requests
}
 93
 94func (b *McpTool) Name() string {
 95	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 96}
 97
 98func (b *McpTool) Info() tools.ToolInfo {
 99	required := b.tool.InputSchema.Required
100	if required == nil {
101		required = make([]string, 0)
102	}
103	parameters := b.tool.InputSchema.Properties
104	if parameters == nil {
105		parameters = make(map[string]any)
106	}
107	return tools.ToolInfo{
108		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
109		Description: b.tool.Description,
110		Parameters:  parameters,
111		Required:    required,
112	}
113}
114
115func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
116	var args map[string]any
117	if err := json.Unmarshal([]byte(input), &args); err != nil {
118		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
119	}
120
121	c, err := getOrRenewClient(ctx, name)
122	if err != nil {
123		return tools.NewTextErrorResponse(err.Error()), nil
124	}
125	result, err := c.CallTool(ctx, mcp.CallToolRequest{
126		Params: mcp.CallToolParams{
127			Name:      toolName,
128			Arguments: args,
129		},
130	})
131	if err != nil {
132		return tools.NewTextErrorResponse(err.Error()), nil
133	}
134
135	output := make([]string, 0, len(result.Content))
136	for _, v := range result.Content {
137		if v, ok := v.(mcp.TextContent); ok {
138			output = append(output, v.Text)
139		} else {
140			output = append(output, fmt.Sprintf("%v", v))
141		}
142	}
143	return tools.NewTextResponse(strings.Join(output, "\n")), nil
144}
145
146func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
147	c, ok := mcpClients.Get(name)
148	if !ok {
149		return nil, fmt.Errorf("mcp '%s' not available", name)
150	}
151
152	cfg := config.Get()
153	m := cfg.MCP[name]
154	state, _ := mcpStates.Get(name)
155
156	timeout := mcpTimeout(m)
157	pingCtx, cancel := context.WithTimeout(ctx, timeout)
158	defer cancel()
159	err := c.Ping(pingCtx)
160	if err == nil {
161		return c, nil
162	}
163	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
164
165	c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
166	if err != nil {
167		return nil, err
168	}
169
170	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
171	mcpClients.Set(name, c)
172	return c, nil
173}
174
175func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
176	sessionID, messageID := tools.GetContextValues(ctx)
177	if sessionID == "" || messageID == "" {
178		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
179	}
180	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
181	p := b.permissions.Request(
182		permission.CreatePermissionRequest{
183			SessionID:   sessionID,
184			ToolCallID:  params.ID,
185			Path:        b.workingDir,
186			ToolName:    b.Info().Name,
187			Action:      "execute",
188			Description: permissionDescription,
189			Params:      params.Input,
190		},
191	)
192	if !p {
193		return tools.ToolResponse{}, permission.ErrorPermissionDenied
194	}
195
196	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
197}
198
199func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) {
200	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
201	if err != nil {
202		return nil, err
203	}
204	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
205	for _, tool := range result.Tools {
206		mcpTools = append(mcpTools, &McpTool{
207			mcpName:     name,
208			tool:        tool,
209			permissions: permissions,
210			workingDir:  workingDir,
211		})
212	}
213	return mcpTools, nil
214}
215
216// SubscribeMCPEvents returns a channel for MCP events
217func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
218	return mcpBroker.Subscribe(ctx)
219}
220
221// GetMCPStates returns the current state of all MCP clients
222func GetMCPStates() map[string]MCPClientInfo {
223	return maps.Collect(mcpStates.Seq2())
224}
225
226// GetMCPState returns the state of a specific MCP client
227func GetMCPState(name string) (MCPClientInfo, bool) {
228	return mcpStates.Get(name)
229}
230
231// updateMCPState updates the state of an MCP client and publishes an event
232func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
233	info := MCPClientInfo{
234		Name:      name,
235		State:     state,
236		Error:     err,
237		Client:    client,
238		ToolCount: toolCount,
239	}
240	if state == MCPStateConnected {
241		info.ConnectedAt = time.Now()
242	}
243	mcpStates.Set(name, info)
244
245	// Publish state change event
246	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
247		Type:      MCPEventStateChanged,
248		Name:      name,
249		State:     state,
250		Error:     err,
251		ToolCount: toolCount,
252	})
253}
254
255// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
256func CloseMCPClients() error {
257	var errs []error
258	for name, c := range mcpClients.Seq2() {
259		if err := c.Close(); err != nil {
260			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
261		}
262	}
263	mcpBroker.Shutdown()
264	return errors.Join(errs...)
265}
266
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It advertises mcp.LATEST_PROTOCOL_VERSION and
// identifies this client as "Crush" at the current build version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
276
277func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
278	var wg sync.WaitGroup
279	result := csync.NewSlice[tools.BaseTool]()
280
281	// Initialize states for all configured MCPs
282	for name, m := range cfg.MCP {
283		if m.Disabled {
284			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
285			slog.Debug("skipping disabled mcp", "name", name)
286			continue
287		}
288
289		// Set initial starting state
290		updateMCPState(name, MCPStateStarting, nil, nil, 0)
291
292		wg.Add(1)
293		go func(name string, m config.MCPConfig) {
294			defer func() {
295				wg.Done()
296				if r := recover(); r != nil {
297					var err error
298					switch v := r.(type) {
299					case error:
300						err = v
301					case string:
302						err = fmt.Errorf("panic: %s", v)
303					default:
304						err = fmt.Errorf("panic: %v", v)
305					}
306					updateMCPState(name, MCPStateError, err, nil, 0)
307					slog.Error("panic in mcp client initialization", "error", err, "name", name)
308				}
309			}()
310
311			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
312			defer cancel()
313
314			c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
315			if err != nil {
316				return
317			}
318
319			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
320			if err != nil {
321				slog.Error("error listing tools", "error", err)
322				updateMCPState(name, MCPStateError, err, nil, 0)
323				c.Close()
324				return
325			}
326
327			mcpClients.Set(name, c)
328			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
329			result.Append(tools...)
330		}(name, m)
331	}
332	wg.Wait()
333	return slices.Collect(result.Seq())
334}
335
336func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
337	c, err := createMcpClient(name, m, resolver)
338	if err != nil {
339		updateMCPState(name, MCPStateError, err, nil, 0)
340		slog.Error("error creating mcp client", "error", err, "name", name)
341		return nil, err
342	}
343
344	// XXX: ideally we should be able to use context.WithTimeout here, but,
345	// the SSE MCP client will start failing once that context is canceled.
346	timeout := mcpTimeout(m)
347	mcpCtx, cancel := context.WithCancel(ctx)
348	cancelTimer := time.AfterFunc(timeout, cancel)
349	if err := c.Start(mcpCtx); err != nil {
350		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
351		slog.Error("error starting mcp client", "error", err, "name", name)
352		_ = c.Close()
353		cancel()
354		return nil, err
355	}
356	if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
357		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
358		slog.Error("error initializing mcp client", "error", err, "name", name)
359		_ = c.Close()
360		cancel()
361		return nil, err
362	}
363	cancelTimer.Stop()
364	slog.Info("Initialized mcp client", "name", name)
365	return c, nil
366}
367
// maybeTimeoutErr turns a context cancellation or deadline error into a
// readable "timed out after <timeout>" error. The original is wrapped, so
// errors.Is still matches it. Any other error, including nil, is returned
// unchanged.
//
// Both sentinels are checked because a timeout can surface either way:
//   - createAndInitializeClient cancels its context from a timer, which
//     yields Canceled;
//   - context.WithTimeout deadlines (the ping in getOrRenewClient, the
//     per-server context in doGetMCPTools) yield DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s: %w", timeout, err)
	}
	return err
}
374
375func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
376	switch m.Type {
377	case config.MCPStdio:
378		command, err := resolver.ResolveValue(m.Command)
379		if err != nil {
380			return nil, fmt.Errorf("invalid mcp command: %w", err)
381		}
382		if strings.TrimSpace(command) == "" {
383			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
384		}
385		return client.NewStdioMCPClientWithOptions(
386			home.Long(command),
387			m.ResolvedEnv(),
388			m.Args,
389			transport.WithCommandLogger(mcpLogger{name: name}),
390		)
391	case config.MCPHttp:
392		if strings.TrimSpace(m.URL) == "" {
393			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
394		}
395		return client.NewStreamableHttpClient(
396			m.URL,
397			transport.WithHTTPHeaders(m.ResolvedHeaders()),
398			transport.WithHTTPLogger(mcpLogger{name: name}),
399		)
400	case config.MCPSse:
401		if strings.TrimSpace(m.URL) == "" {
402			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
403		}
404		return client.NewSSEMCPClient(
405			m.URL,
406			client.WithHeaders(m.ResolvedHeaders()),
407			transport.WithSSELogger(mcpLogger{name: name}),
408		)
409	default:
410		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
411	}
412}
413
// mcpLogger forwards log output from the mcp-go client transports to the
// default slog logger. Every record is tagged with the name of the MCP
// server it came from.
type mcpLogger struct{ name string }

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Log(context.Background(), slog.LevelError, msg, "name", l.name)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Log(context.Background(), slog.LevelInfo, msg, "name", l.name)
}
424
425func mcpTimeout(m config.MCPConfig) time.Duration {
426	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
427}