1package tools
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/home"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/pubsub"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/charmbracelet/fantasy/ai"
23 "github.com/mark3labs/mcp-go/client"
24 "github.com/mark3labs/mcp-go/client/transport"
25 "github.com/mark3labs/mcp-go/mcp"
26)
27
// MCPState represents the lifecycle state of an MCP client.
type MCPState int

// The MCPState values, in declaration order.
const (
	// MCPStateDisabled means the MCP is disabled in configuration and is not started.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting means initialization of the MCP client is in progress.
	MCPStateStarting
	// MCPStateConnected means the MCP client has been initialized.
	MCPStateConnected
	// MCPStateError means the last attempt to set up or reach the MCP failed.
	MCPStateError
)

// String returns the lowercase name of s, or "unknown" for any value outside
// the declared constants (including negative values).
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
52
// MCPEventType represents the type of MCP event published on the MCP event
// broker (see SubscribeMCPEvents).
type MCPEventType string

const (
	// MCPEventStateChanged is published whenever updateMCPState records a new
	// state for an MCP.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when an MCP server sends a
	// "notifications/tools/list_changed" notification.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent represents an event in the MCP system.
type MCPEvent struct {
	// Type identifies what happened.
	Type MCPEventType
	// Name is the configured name of the MCP the event concerns.
	Name string
	// State, Error and ToolCount are filled in for MCPEventStateChanged
	// events only; MCPEventToolsListChanged events carry just Type and Name.
	State     MCPState
	Error     error
	ToolCount int
}
69
// MCPClientInfo holds information about an MCP client's state, as recorded by
// updateMCPState and returned by GetMCPState and GetMCPStates.
type MCPClientInfo struct {
	// Name is the MCP's key in the configuration's MCP map.
	Name string
	// State is the most recently recorded lifecycle state.
	State MCPState
	// Error is the error the state was recorded with; nil unless the state
	// was set together with an error (normally MCPStateError).
	Error error
	// Client is the client passed when the state was recorded; nil unless a
	// live client was supplied (as happens on connection).
	Client *client.Client
	// ToolCount is the number of tools reported for the MCP. In error states
	// it may be 0 or the count from before the failure.
	ToolCount int
	// ConnectedAt is set to the recording time only when State is
	// MCPStateConnected; it is the zero time otherwise.
	ConnectedAt time.Time
}
79
var (
	// mcpToolsOnce makes GetMCPTools initialize the configured MCPs only once.
	mcpToolsOnce sync.Once
	// mcpTools holds every registered MCP tool keyed by its full tool name
	// (McpTool.Info().Name). updateMcpTools only ever adds entries to it.
	mcpTools = csync.NewMap[string, *McpTool]()
	// mcpClient2Tools maps an MCP name to the tools it most recently reported.
	mcpClient2Tools = csync.NewMap[string, []*McpTool]()
	// mcpClients maps an MCP name to its live client; the entry is removed
	// when the MCP enters MCPStateError.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to its most recently recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvents to subscribers of SubscribeMCPEvents.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
88
// McpTool wraps a single tool exposed by an MCP server and provides Info and
// Run so that it can be offered to the model.
type McpTool struct {
	// mcpName is the configured name of the MCP server that provides the tool.
	mcpName string
	// tool is the tool definition as returned by the server's ListTools.
	tool mcp.Tool
	// permissions is asked for approval before every Run.
	permissions permission.Service
	// workingDir is reported as the Path of permission requests.
	workingDir string
	// providerOptions are the options set via SetProviderOptions.
	providerOptions ai.ProviderOptions
}
96
97func (m *McpTool) SetProviderOptions(opts ai.ProviderOptions) {
98 m.providerOptions = opts
99}
100
101func (m *McpTool) ProviderOptions() ai.ProviderOptions {
102 return m.providerOptions
103}
104
105func (m *McpTool) Name() string {
106 return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
107}
108
109func (m *McpTool) MCP() string {
110 return m.mcpName
111}
112
113func (m *McpTool) MCPToolName() string {
114 return m.tool.Name
115}
116
117func (m *McpTool) Info() ai.ToolInfo {
118 required := m.tool.InputSchema.Required
119 if required == nil {
120 required = make([]string, 0)
121 }
122 parameters := m.tool.InputSchema.Properties
123 if parameters == nil {
124 parameters = make(map[string]any)
125 }
126 return ai.ToolInfo{
127 Name: fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name),
128 Description: m.tool.Description,
129 Parameters: parameters,
130 Required: required,
131 }
132}
133
134func runTool(ctx context.Context, name, toolName string, input string) (ai.ToolResponse, error) {
135 var args map[string]any
136 if err := json.Unmarshal([]byte(input), &args); err != nil {
137 return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
138 }
139
140 c, err := getOrRenewClient(ctx, name)
141 if err != nil {
142 return ai.NewTextErrorResponse(err.Error()), nil
143 }
144 result, err := c.CallTool(ctx, mcp.CallToolRequest{
145 Params: mcp.CallToolParams{
146 Name: toolName,
147 Arguments: args,
148 },
149 })
150 if err != nil {
151 return ai.NewTextErrorResponse(err.Error()), nil
152 }
153
154 output := make([]string, 0, len(result.Content))
155 for _, v := range result.Content {
156 if v, ok := v.(mcp.TextContent); ok {
157 output = append(output, v.Text)
158 } else {
159 output = append(output, fmt.Sprintf("%v", v))
160 }
161 }
162 return ai.NewTextResponse(strings.Join(output, "\n")), nil
163}
164
165func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
166 c, ok := mcpClients.Get(name)
167 if !ok {
168 return nil, fmt.Errorf("mcp '%s' not available", name)
169 }
170
171 cfg := config.Get()
172 m := cfg.MCP[name]
173 state, _ := mcpStates.Get(name)
174
175 timeout := mcpTimeout(m)
176 pingCtx, cancel := context.WithTimeout(ctx, timeout)
177 defer cancel()
178 err := c.Ping(pingCtx)
179 if err == nil {
180 return c, nil
181 }
182 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
183
184 c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
185 if err != nil {
186 return nil, err
187 }
188
189 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
190 mcpClients.Set(name, c)
191 return c, nil
192}
193
194func (m *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) {
195 sessionID := GetSessionFromContext(ctx)
196 if sessionID == "" {
197 return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
198 }
199 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
200 p := m.permissions.Request(
201 permission.CreatePermissionRequest{
202 SessionID: sessionID,
203 ToolCallID: params.ID,
204 Path: m.workingDir,
205 ToolName: m.Info().Name,
206 Action: "execute",
207 Description: permissionDescription,
208 Params: params.Input,
209 },
210 )
211 if !p {
212 return ai.ToolResponse{}, permission.ErrorPermissionDenied
213 }
214
215 return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
216}
217
218func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]*McpTool, error) {
219 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
220 if err != nil {
221 return nil, err
222 }
223 mcpTools := make([]*McpTool, 0, len(result.Tools))
224 for _, tool := range result.Tools {
225 mcpTools = append(mcpTools, &McpTool{
226 mcpName: name,
227 tool: tool,
228 permissions: permissions,
229 workingDir: workingDir,
230 })
231 }
232 return mcpTools, nil
233}
234
235// SubscribeMCPEvents returns a channel for MCP events
236func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
237 return mcpBroker.Subscribe(ctx)
238}
239
240// GetMCPStates returns the current state of all MCP clients
241func GetMCPStates() map[string]MCPClientInfo {
242 return maps.Collect(mcpStates.Seq2())
243}
244
245// GetMCPState returns the state of a specific MCP client
246func GetMCPState(name string) (MCPClientInfo, bool) {
247 return mcpStates.Get(name)
248}
249
250// updateMCPState updates the state of an MCP client and publishes an event
251func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
252 info := MCPClientInfo{
253 Name: name,
254 State: state,
255 Error: err,
256 Client: client,
257 ToolCount: toolCount,
258 }
259 switch state {
260 case MCPStateConnected:
261 info.ConnectedAt = time.Now()
262 case MCPStateError:
263 updateMcpTools(name, nil)
264 mcpClients.Del(name)
265 }
266 mcpStates.Set(name, info)
267
268 // Publish state change event
269 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
270 Type: MCPEventStateChanged,
271 Name: name,
272 State: state,
273 Error: err,
274 ToolCount: toolCount,
275 })
276}
277
278// publishMCPEventToolsListChanged publishes a tool list changed event
279func publishMCPEventToolsListChanged(name string) {
280 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
281 Type: MCPEventToolsListChanged,
282 Name: name,
283 })
284}
285
286// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
287func CloseMCPClients() error {
288 var errs []error
289 for name, c := range mcpClients.Seq2() {
290 if err := c.Close(); err != nil {
291 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
292 }
293 }
294 mcpBroker.Shutdown()
295 return errors.Join(errs...)
296}
297
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It requests mcp.LATEST_PROTOCOL_VERSION and
// identifies this client as "Crush" at version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
307
308func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
309 mcpToolsOnce.Do(func() {
310 var wg sync.WaitGroup
311 // Initialize states for all configured MCPs
312 for name, m := range cfg.MCP {
313 if m.Disabled {
314 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
315 slog.Debug("skipping disabled mcp", "name", name)
316 continue
317 }
318
319 // Set initial starting state
320 updateMCPState(name, MCPStateStarting, nil, nil, 0)
321
322 wg.Add(1)
323 go func(name string, m config.MCPConfig) {
324 defer func() {
325 wg.Done()
326 if r := recover(); r != nil {
327 var err error
328 switch v := r.(type) {
329 case error:
330 err = v
331 case string:
332 err = fmt.Errorf("panic: %s", v)
333 default:
334 err = fmt.Errorf("panic: %v", v)
335 }
336 updateMCPState(name, MCPStateError, err, nil, 0)
337 slog.Error("panic in mcp client initialization", "error", err, "name", name)
338 }
339 }()
340
341 mcpCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
342 defer cancel()
343
344 c, err := createAndInitializeClient(mcpCtx, name, m, cfg.Resolver())
345 if err != nil {
346 return
347 }
348
349 mcpClients.Set(name, c)
350
351 tools, err := getTools(mcpCtx, name, permissions, c, cfg.WorkingDir())
352 if err != nil {
353 slog.Error("error listing tools", "error", err)
354 updateMCPState(name, MCPStateError, err, nil, 0)
355 c.Close()
356 return
357 }
358
359 updateMcpTools(name, tools)
360 mcpClients.Set(name, c)
361 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
362 }(name, m)
363 }
364 wg.Wait()
365 })
366 return slices.Collect(mcpTools.Seq())
367}
368
369// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
370func updateMcpTools(mcpName string, tools []*McpTool) {
371 if len(tools) == 0 {
372 mcpClient2Tools.Del(mcpName)
373 } else {
374 mcpClient2Tools.Set(mcpName, tools)
375 }
376 for _, tools := range mcpClient2Tools.Seq2() {
377 for _, t := range tools {
378 mcpTools.Set(t.Info().Name, t)
379 }
380 }
381}
382
383func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
384 c, err := createMcpClient(name, m, resolver)
385 if err != nil {
386 updateMCPState(name, MCPStateError, err, nil, 0)
387 slog.Error("error creating mcp client", "error", err, "name", name)
388 return nil, err
389 }
390
391 c.OnNotification(func(n mcp.JSONRPCNotification) {
392 slog.Debug("Received MCP notification", "name", name, "notification", n)
393 switch n.Method {
394 case "notifications/tools/list_changed":
395 publishMCPEventToolsListChanged(name)
396 default:
397 slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
398 }
399 })
400
401 // XXX: ideally we should be able to use context.WithTimeout here, but,
402 // the SSE MCP client will start failing once that context is canceled.
403 timeout := mcpTimeout(m)
404 mcpCtx, cancel := context.WithCancel(ctx)
405 cancelTimer := time.AfterFunc(timeout, cancel)
406
407 if err := c.Start(mcpCtx); err != nil {
408 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
409 slog.Error("error starting mcp client", "error", err, "name", name)
410 _ = c.Close()
411 cancel()
412 return nil, err
413 }
414
415 if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
416 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
417 slog.Error("error initializing mcp client", "error", err, "name", name)
418 _ = c.Close()
419 cancel()
420 return nil, err
421 }
422
423 cancelTimer.Stop()
424 slog.Info("Initialized mcp client", "name", name)
425 return c, nil
426}
427
// maybeTimeoutErr rewrites err into a "timed out after <timeout>" error when
// it was caused by context cancellation or an expired context deadline
// (directly or wrapped). Any other err, including nil, is returned unchanged.
//
// Both sentinels matter. createAndInitializeClient aborts by canceling its
// context (context.Canceled), while the ping in getOrRenewClient uses
// context.WithTimeout, which fails with context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
434
435func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
436 switch m.Type {
437 case config.MCPStdio:
438 command, err := resolver.ResolveValue(m.Command)
439 if err != nil {
440 return nil, fmt.Errorf("invalid mcp command: %w", err)
441 }
442 if strings.TrimSpace(command) == "" {
443 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
444 }
445 return client.NewStdioMCPClientWithOptions(
446 home.Long(command),
447 m.ResolvedEnv(),
448 m.Args,
449 transport.WithCommandLogger(mcpLogger{name: name}),
450 )
451 case config.MCPHttp:
452 if strings.TrimSpace(m.URL) == "" {
453 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
454 }
455 return client.NewStreamableHttpClient(
456 m.URL,
457 transport.WithHTTPHeaders(m.ResolvedHeaders()),
458 transport.WithHTTPLogger(mcpLogger{name: name}),
459 )
460 case config.MCPSse:
461 if strings.TrimSpace(m.URL) == "" {
462 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
463 }
464 return client.NewSSEMCPClient(
465 m.URL,
466 client.WithHeaders(m.ResolvedHeaders()),
467 transport.WithSSELogger(mcpLogger{name: name}),
468 )
469 default:
470 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
471 }
472}
473
// mcpLogger is the logger handed to the mcp-go transports created in
// createMcpClient. It forwards their messages to slog and tags every record
// with the MCP's name.
type mcpLogger struct{ name string }

// Errorf formats the message with fmt.Sprintf and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message with fmt.Sprintf and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
484
485func mcpTimeout(m config.MCPConfig) time.Duration {
486 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
487}