1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/home"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/pubsub"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/mark3labs/mcp-go/client"
23 "github.com/mark3labs/mcp-go/client/transport"
24 "github.com/mark3labs/mcp-go/mcp"
25)
26
// MCPState describes the lifecycle stage of an MCP client.
type MCPState int

// Lifecycle stages of an MCP client, numbered with iota in declaration order.
const (
	// MCPStateDisabled means the server is disabled in configuration.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting means the client is being initialized.
	MCPStateStarting
	// MCPStateConnected means the client has been initialized and is usable.
	MCPStateConnected
	// MCPStateError means creating, starting, initializing, listing tools
	// for, or pinging the client failed.
	MCPStateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
51
// MCPEventType identifies the kind of MCPEvent published on the MCP event
// broker.
type MCPEventType string

const (
	// MCPEventStateChanged is published by updateMCPState every time a
	// server's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when a server sends a
	// "notifications/tools/list_changed" notification.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent is the payload delivered to SubscribeMCPEvents subscribers.
// For MCPEventToolsListChanged events only Type and Name are set.
type MCPEvent struct {
	Type      MCPEventType
	Name      string   // name of the MCP server, as configured in cfg.MCP
	State     MCPState // state just recorded (MCPEventStateChanged only)
	Error     error    // error passed to updateMCPState, if any
	ToolCount int      // tool count passed to updateMCPState
}

// MCPClientInfo is the snapshot of an MCP server's state recorded by
// updateMCPState and returned by GetMCPStates / GetMCPState.
type MCPClientInfo struct {
	Name        string
	State       MCPState
	Error       error          // in this file only non-nil together with MCPStateError
	Client      *client.Client // client handle supplied by the caller; nil unless one was passed
	ToolCount   int
	ConnectedAt time.Time // set to time.Now() only when State is MCPStateConnected
}
78
// Package-level MCP registries shared by the functions in this file. All maps
// are csync.Map instances and are updated from multiple goroutines (see
// doGetMCPTools).
var (
	// mcpToolsOnce is not referenced in this part of the file; presumably it
	// guards one-time tool discovery — confirm at its use site.
	mcpToolsOnce sync.Once
	// mcpTools maps a tool's Name() to the tool; written by updateMcpTools.
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP server name to the tools it last reported.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP server name to its registered client; entries
	// are removed when updateMCPState records MCPStateError.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP server name to its latest recorded MCPClientInfo.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent values to SubscribeMCPEvents subscribers
	// and is shut down by CloseMCPClients.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
87
// McpTool adapts one tool exposed by an MCP server to tools.BaseTool, so it
// can be offered to the agent alongside built-in tools.
type McpTool struct {
	mcpName     string             // configured name of the MCP server (key in cfg.MCP)
	tool        mcp.Tool           // tool definition as returned by the server's ListTools
	permissions permission.Service // consulted by Run before every execution
	workingDir  string             // reported as the Path of permission requests
}
94
95func (b *McpTool) Name() string {
96 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
97}
98
99func (b *McpTool) Info() tools.ToolInfo {
100 required := b.tool.InputSchema.Required
101 if required == nil {
102 required = make([]string, 0)
103 }
104 parameters := b.tool.InputSchema.Properties
105 if parameters == nil {
106 parameters = make(map[string]any)
107 }
108 return tools.ToolInfo{
109 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
110 Description: b.tool.Description,
111 Parameters: parameters,
112 Required: required,
113 }
114}
115
116func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
117 var args map[string]any
118 if err := json.Unmarshal([]byte(input), &args); err != nil {
119 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
120 }
121
122 c, err := getOrRenewClient(ctx, name)
123 if err != nil {
124 return tools.NewTextErrorResponse(err.Error()), nil
125 }
126 result, err := c.CallTool(ctx, mcp.CallToolRequest{
127 Params: mcp.CallToolParams{
128 Name: toolName,
129 Arguments: args,
130 },
131 })
132 if err != nil {
133 return tools.NewTextErrorResponse(err.Error()), nil
134 }
135
136 output := make([]string, 0, len(result.Content))
137 for _, v := range result.Content {
138 if v, ok := v.(mcp.TextContent); ok {
139 output = append(output, v.Text)
140 } else {
141 output = append(output, fmt.Sprintf("%v", v))
142 }
143 }
144 return tools.NewTextResponse(strings.Join(output, "\n")), nil
145}
146
// getOrRenewClient returns the registered client for the MCP server name. If
// the client does not answer a ping within mcpTimeout, it reconnects first.
//
// When the ping fails:
//   - the server is marked MCPStateError, which drops the client from
//     mcpClients and its entry from mcpClient2Tools;
//   - the stale client is closed in the background;
//   - a fresh client is created and initialized.
//
// On success the new client and the server's previous tool list are
// registered again, and the server is marked MCPStateConnected with its
// previous tool count.
//
// An error is returned when the server has no registered client or when the
// reconnect fails.
func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
	c, ok := mcpClients.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	cfg := config.Get()
	m := cfg.MCP[name]
	state, _ := mcpStates.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := c.Ping(pingCtx)
	if err == nil {
		return c, nil
	}

	// Capture the tool list before the error transition below discards it,
	// so it can be registered again once the server is reachable.
	previousTools, _ := mcpClient2Tools.Get(name)
	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)

	// The unresponsive client is no longer registered anywhere, so nothing
	// else will ever close it. Close it asynchronously, since Close may block
	// on the underlying transport and must not stall this tool call.
	go func(stale *client.Client) { _ = stale.Close() }(c)

	// The renewed client outlives this tool call, and (per the note in
	// createAndInitializeClient) the SSE transport fails once the context it
	// was started with is canceled. Detach from the caller's cancellation;
	// createAndInitializeClient applies its own timeout.
	c, err = createAndInitializeClient(context.WithoutCancel(ctx), name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	mcpClients.Set(name, c)
	updateMcpTools(name, previousTools)
	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
	return c, nil
}
175
// Run executes the MCP tool for the tool call in params.
//
// The session and message IDs must be present in ctx (as read by
// tools.GetContextValues); otherwise an error is returned. Permission is
// requested through b.permissions before anything is sent to the server, and
// a denial returns permission.ErrorPermissionDenied. Otherwise the raw JSON
// input is forwarded by runTool, which reports failures as error responses.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, errors.New("session ID and message ID are required for running an MCP tool")
	}

	// Name() is what Info() reports as the tool name; computing it directly
	// avoids rebuilding the whole ToolInfo twice.
	toolName := b.Name()
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters:", toolName),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
199
// getTools lists the tools offered through client c and wraps each one in an
// McpTool for the MCP server name. The wrappers check permissions through
// permissions and report workingDir as the permission path. A listing error
// is returned unchanged.
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) {
	listing, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		return nil, err
	}

	wrapped := make([]tools.BaseTool, len(listing.Tools))
	for i := range listing.Tools {
		wrapped[i] = &McpTool{
			mcpName:     name,
			tool:        listing.Tools[i],
			permissions: permissions,
			workingDir:  workingDir,
		}
	}
	return wrapped, nil
}
216
// SubscribeMCPEvents subscribes to the MCP event broker and returns the
// channel on which MCPEvent values (state and tool-list changes) are
// delivered. ctx is handed to the broker's Subscribe; presumably the
// subscription is released when ctx is done — see pubsub.Broker.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}

// GetMCPStates returns a copy of the recorded state of every MCP server,
// keyed by server name. The returned map is never nil.
func GetMCPStates() map[string]MCPClientInfo {
	snapshot := make(map[string]MCPClientInfo)
	for name, info := range mcpStates.Seq2() {
		snapshot[name] = info
	}
	return snapshot
}

// GetMCPState returns the recorded state of the MCP server name and whether
// any state has been recorded for it.
func GetMCPState(name string) (MCPClientInfo, bool) {
	info, found := mcpStates.Get(name)
	return info, found
}
231
232// updateMCPState updates the state of an MCP client and publishes an event
233func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
234 info := MCPClientInfo{
235 Name: name,
236 State: state,
237 Error: err,
238 Client: client,
239 ToolCount: toolCount,
240 }
241 switch state {
242 case MCPStateConnected:
243 info.ConnectedAt = time.Now()
244 case MCPStateError:
245 updateMcpTools(name, nil)
246 mcpClients.Del(name)
247 }
248 mcpStates.Set(name, info)
249
250 // Publish state change event
251 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
252 Type: MCPEventStateChanged,
253 Name: name,
254 State: state,
255 Error: err,
256 ToolCount: toolCount,
257 })
258}
259
260// publishMCPEventToolsListChanged publishes a tool list changed event
261func publishMCPEventToolsListChanged(name string) {
262 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
263 Type: MCPEventToolsListChanged,
264 Name: name,
265 })
266}
267
// CloseMCPClients closes every registered MCP client and then shuts down the
// MCP event broker. It is meant to be called during application shutdown.
// Close failures do not stop the remaining clients from being closed; they
// are combined with errors.Join, each tagged with its server name.
func CloseMCPClients() error {
	var closeErrs []error
	for name, c := range mcpClients.Seq2() {
		err := c.Close()
		if err == nil {
			continue
		}
		closeErrs = append(closeErrs, fmt.Errorf("close mcp: %s: %w", name, err))
	}

	mcpBroker.Shutdown()
	return errors.Join(closeErrs...)
}
279
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends during every MCP handshake. It advertises mcp.LATEST_PROTOCOL_VERSION
// and identifies this client as "Crush" at version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
289
// doGetMCPTools connects to every enabled MCP server in cfg concurrently,
// lists its tools, and registers the client and its tools. It blocks until
// every server has either connected or failed.
//
// Disabled servers are marked MCPStateDisabled. All others are marked
// MCPStateStarting and end in MCPStateConnected or MCPStateError. A panic
// while initializing a server is recovered and recorded as MCPStateError.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
	var wg sync.WaitGroup
	// Initialize states for all configured MCPs
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		// Set initial starting state
		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred first so it runs last: the panic handler below must
			// record the error state before Wait releases the caller.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			// Do not pass createAndInitializeClient a context that is canceled
			// when this goroutine returns. The client keeps using a context
			// derived from it, and the SSE transport fails once that context is
			// canceled. createAndInitializeClient enforces mcpTimeout itself.
			c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
			if err != nil {
				// The failure has already been recorded and logged.
				return
			}

			timeout := mcpTimeout(m)
			listCtx, cancel := context.WithTimeout(ctx, timeout)
			defer cancel()

			serverTools, err := getTools(listCtx, name, permissions, c, cfg.WorkingDir())
			if err != nil {
				slog.Error("error listing tools", "error", err, "name", name)
				updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
				_ = c.Close()
				return
			}

			// Register the client before exposing its tools, so the client is
			// in place by the time the tools become visible.
			mcpClients.Set(name, c)
			updateMcpTools(name, serverTools)
			updateMCPState(name, MCPStateConnected, nil, c, len(serverTools))
		}(name, m)
	}
	wg.Wait()
}
347
348// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
349func updateMcpTools(mcpName string, tools []tools.BaseTool) {
350 if len(tools) == 0 {
351 mcpClient2Tools.Del(mcpName)
352 } else {
353 mcpClient2Tools.Set(mcpName, tools)
354 }
355 for _, tools := range mcpClient2Tools.Seq2() {
356 for _, t := range tools {
357 mcpTools.Set(t.Name(), t)
358 }
359 }
360}
361
// createAndInitializeClient builds the MCP client described by m, registers a
// handler for server notifications, starts the transport, and performs the
// MCP initialize handshake.
//
// Start and Initialize together must finish within mcpTimeout(m).
//
// On any failure:
//   - the server is moved to MCPStateError;
//   - the failure is logged;
//   - a client that was already constructed is closed;
//   - the error is returned.
//
// A cancellation caused by the timeout is reported as "timed out after
// <timeout>", both in the recorded state and in the returned error.
//
// On success the client is returned without being added to mcpClients;
// callers decide when to register it.
//
// ctx should outlive the returned client, because the context passed to
// c.Start is derived from it (see the XXX note below).
func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
	c, err := createMcpClient(name, m, resolver)
	if err != nil {
		updateMCPState(name, MCPStateError, err, nil, 0)
		slog.Error("error creating mcp client", "error", err, "name", name)
		return nil, err
	}

	c.OnNotification(func(n mcp.JSONRPCNotification) {
		slog.Debug("Received MCP notification", "name", name, "notification", n)
		switch n.Method {
		case "notifications/tools/list_changed":
			publishMCPEventToolsListChanged(name)
		default:
			slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
		}
	})

	// XXX: ideally we should be able to use context.WithTimeout here, but,
	// the SSE MCP client will start failing once that context is canceled.
	// Instead, a timer cancels mcpCtx only if the handshake does not finish in
	// time. On success the timer is stopped and mcpCtx stays alive along with
	// the client.
	timeout := mcpTimeout(m)
	mcpCtx, cancel := context.WithCancel(ctx)
	cancelTimer := time.AfterFunc(timeout, cancel)

	// fail records and logs a startup failure, releases the client and its
	// context, and returns the error callers should see.
	fail := func(msg string, cause error) (*client.Client, error) {
		cancelTimer.Stop()
		reported := maybeTimeoutErr(cause, timeout)
		updateMCPState(name, MCPStateError, reported, nil, 0)
		slog.Error(msg, "error", cause, "name", name)
		_ = c.Close()
		cancel()
		return nil, reported
	}

	if err := c.Start(mcpCtx); err != nil {
		return fail("error starting mcp client", err)
	}
	if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
		return fail("error initializing mcp client", err)
	}

	cancelTimer.Stop()
	slog.Info("Initialized mcp client", "name", name)
	return c, nil
}
406
// maybeTimeoutErr rewrites context cancellation errors as "timed out after
// <timeout>". Callers in this file apply timeouts in two ways, and this
// covers both:
//   - a timer that cancels the context, which surfaces as context.Canceled;
//   - context.WithTimeout, which surfaces as context.DeadlineExceeded.
//
// Wrapped errors are recognized via errors.Is. Any other error, including
// nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
413
// createMcpClient constructs, but does not start or initialize, the mcp-go
// client for the MCP server name according to m.Type:
//
//   - stdio: m.Command is resolved through resolver, passed through home.Long
//     (presumably expanding a leading ~), and run with m.ResolvedEnv() and
//     m.Args;
//   - http: a streamable HTTP client for m.URL with m.ResolvedHeaders();
//   - sse: an SSE client for m.URL with m.ResolvedHeaders().
//
// Transport logging goes through an mcpLogger tagged with name. An error is
// returned when the command cannot be resolved, when the command or URL is
// blank, or when the type is unsupported.
func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
	logger := mcpLogger{name: name}
	missingURL := strings.TrimSpace(m.URL) == ""

	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		switch {
		case err != nil:
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		case strings.TrimSpace(command) == "":
			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
		}
		return client.NewStdioMCPClientWithOptions(
			home.Long(command),
			m.ResolvedEnv(),
			m.Args,
			transport.WithCommandLogger(logger),
		)

	case config.MCPHttp:
		if missingURL {
			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
		}
		headers := m.ResolvedHeaders()
		return client.NewStreamableHttpClient(
			m.URL,
			transport.WithHTTPHeaders(headers),
			transport.WithHTTPLogger(logger),
		)

	case config.MCPSse:
		if missingURL {
			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
		}
		headers := m.ResolvedHeaders()
		return client.NewSSEMCPClient(
			m.URL,
			client.WithHeaders(headers),
			transport.WithSSELogger(logger),
		)
	}

	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
452
453// for MCP's clients.
454type mcpLogger struct{ name string }
455
456func (l mcpLogger) Errorf(format string, v ...any) {
457 slog.Error(fmt.Sprintf(format, v...), "name", l.name)
458}
459
460func (l mcpLogger) Infof(format string, v ...any) {
461 slog.Info(fmt.Sprintf(format, v...), "name", l.name)
462}
463
464func mcpTimeout(m config.MCPConfig) time.Duration {
465 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
466}