1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "maps"
10 "slices"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/llm/tools"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/version"
21 "github.com/mark3labs/mcp-go/client"
22 "github.com/mark3labs/mcp-go/client/transport"
23 "github.com/mark3labs/mcp-go/mcp"
24)
25
// MCPState enumerates the lifecycle states an MCP client can be in.
type MCPState int

// Lifecycle states of an MCP client. Values follow declaration order,
// starting at zero.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// that is not one of the declared constants.
func (s MCPState) String() string {
	// Indexed by state value; the keyed literal keeps names tied to constants.
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
50
// MCPEventType names the kind of event published about an MCP client.
type MCPEventType string

const (
	// MCPEventStateChanged is the type of the event published whenever an
	// MCP client's state is updated.
	MCPEventStateChanged = MCPEventType("state_changed")
	// MCPEventToolsListChanged is the type of the event published when an
	// MCP server notifies that its tool list changed.
	MCPEventToolsListChanged = MCPEventType("tools_list_changed")
)
58
59// MCPEvent represents an event in the MCP system
60type MCPEvent struct {
61 Type MCPEventType
62 Name string
63 State MCPState
64 Error error
65 ToolCount int
66}
67
68// MCPClientInfo holds information about an MCP client's state
69type MCPClientInfo struct {
70 Name string
71 State MCPState
72 Error error
73 Client *client.Client
74 ToolCount int
75 ConnectedAt time.Time
76}
77
78var (
79 mcpToolsOnce sync.Once
80 mcpTools = csync.NewMap[string, tools.BaseTool]()
81 mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
82 mcpClients = csync.NewMap[string, *client.Client]()
83 mcpStates = csync.NewMap[string, MCPClientInfo]()
84 mcpBroker = pubsub.NewBroker[MCPEvent]()
85 toolsMaker func(string, []mcp.Tool) []tools.BaseTool = nil
86)
87
88type McpTool struct {
89 mcpName string
90 tool mcp.Tool
91 permissions permission.Service
92 workingDir string
93}
94
95func (b *McpTool) Name() string {
96 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
97}
98
99func (b *McpTool) Info() tools.ToolInfo {
100 required := b.tool.InputSchema.Required
101 if required == nil {
102 required = make([]string, 0)
103 }
104 parameters := b.tool.InputSchema.Properties
105 if parameters == nil {
106 parameters = make(map[string]any)
107 }
108 return tools.ToolInfo{
109 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
110 Description: b.tool.Description,
111 Parameters: parameters,
112 Required: required,
113 }
114}
115
116func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
117 var args map[string]any
118 if err := json.Unmarshal([]byte(input), &args); err != nil {
119 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
120 }
121
122 c, err := getOrRenewClient(ctx, name)
123 if err != nil {
124 return tools.NewTextErrorResponse(err.Error()), nil
125 }
126 result, err := c.CallTool(ctx, mcp.CallToolRequest{
127 Params: mcp.CallToolParams{
128 Name: toolName,
129 Arguments: args,
130 },
131 })
132 if err != nil {
133 return tools.NewTextErrorResponse(err.Error()), nil
134 }
135
136 output := make([]string, 0, len(result.Content))
137 for _, v := range result.Content {
138 if v, ok := v.(mcp.TextContent); ok {
139 output = append(output, v.Text)
140 } else {
141 output = append(output, fmt.Sprintf("%v", v))
142 }
143 }
144 return tools.NewTextResponse(strings.Join(output, "\n")), nil
145}
146
147func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
148 c, ok := mcpClients.Get(name)
149 if !ok {
150 return nil, fmt.Errorf("mcp '%s' not available", name)
151 }
152
153 m := config.Get().MCP[name]
154 state, _ := mcpStates.Get(name)
155
156 pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
157 defer cancel()
158 err := c.Ping(pingCtx)
159 if err == nil {
160 return c, nil
161 }
162 updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
163
164 c, err = createAndInitializeClient(ctx, name, m)
165 if err != nil {
166 return nil, err
167 }
168
169 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
170 mcpClients.Set(name, c)
171 return c, nil
172}
173
174func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
175 sessionID, messageID := tools.GetContextValues(ctx)
176 if sessionID == "" || messageID == "" {
177 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
178 }
179 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
180 p := b.permissions.Request(
181 permission.CreatePermissionRequest{
182 SessionID: sessionID,
183 ToolCallID: params.ID,
184 Path: b.workingDir,
185 ToolName: b.Info().Name,
186 Action: "execute",
187 Description: permissionDescription,
188 Params: params.Input,
189 },
190 )
191 if !p {
192 return tools.ToolResponse{}, permission.ErrorPermissionDenied
193 }
194
195 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
196}
197
198func createToolsMaker(ctx context.Context, permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
199 return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
200 mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
201 for _, tool := range mcpToolsList {
202 mcpTools = append(mcpTools, &McpTool{
203 mcpName: name,
204 tool: tool,
205 permissions: permissions,
206 workingDir: workingDir,
207 })
208 }
209 return mcpTools
210 }
211}
212
213func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
214 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
215 if err != nil {
216 slog.Error("error listing tools", "error", err)
217 updateMCPState(name, MCPStateError, err, nil, 0)
218 c.Close()
219 return nil
220 }
221 return toolsMaker(name, result.Tools)
222}
223
224// SubscribeMCPEvents returns a channel for MCP events
225func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
226 return mcpBroker.Subscribe(ctx)
227}
228
229// GetMCPStates returns the current state of all MCP clients
230func GetMCPStates() map[string]MCPClientInfo {
231 return maps.Collect(mcpStates.Seq2())
232}
233
234// GetMCPState returns the state of a specific MCP client
235func GetMCPState(name string) (MCPClientInfo, bool) {
236 return mcpStates.Get(name)
237}
238
239// updateMCPState updates the state of an MCP client and publishes an event
240func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
241 info := MCPClientInfo{
242 Name: name,
243 State: state,
244 Error: err,
245 Client: client,
246 ToolCount: toolCount,
247 }
248 switch state {
249 case MCPStateConnected:
250 info.ConnectedAt = time.Now()
251 case MCPStateError:
252 updateMcpTools(name, nil)
253 mcpClients.Del(name)
254 }
255 mcpStates.Set(name, info)
256
257 // Publish state change event
258 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
259 Type: MCPEventStateChanged,
260 Name: name,
261 State: state,
262 Error: err,
263 ToolCount: toolCount,
264 })
265}
266
267// publishMCPEventToolsListChanged publishes a tool list changed event
268func publishMCPEventToolsListChanged(name string) {
269 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
270 Type: MCPEventToolsListChanged,
271 Name: name,
272 })
273}
274
275// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
276func CloseMCPClients() {
277 for c := range mcpClients.Seq() {
278 _ = c.Close()
279 }
280 mcpBroker.Shutdown()
281}
282
283var mcpInitRequest = mcp.InitializeRequest{
284 Params: mcp.InitializeParams{
285 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
286 ClientInfo: mcp.Implementation{
287 Name: "Crush",
288 Version: version.Version,
289 },
290 },
291}
292
293func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
294 var wg sync.WaitGroup
295 result := csync.NewSlice[tools.BaseTool]()
296
297 toolsMaker = createToolsMaker(ctx, permissions, cfg.WorkingDir())
298
299 // Initialize states for all configured MCPs
300 for name, m := range cfg.MCP {
301 if m.Disabled {
302 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
303 slog.Debug("skipping disabled mcp", "name", name)
304 continue
305 }
306
307 // Set initial starting state
308 updateMCPState(name, MCPStateStarting, nil, nil, 0)
309
310 wg.Add(1)
311 go func(name string, m config.MCPConfig) {
312 defer func() {
313 wg.Done()
314 if r := recover(); r != nil {
315 var err error
316 switch v := r.(type) {
317 case error:
318 err = v
319 case string:
320 err = fmt.Errorf("panic: %s", v)
321 default:
322 err = fmt.Errorf("panic: %v", v)
323 }
324 updateMCPState(name, MCPStateError, err, nil, 0)
325 slog.Error("panic in mcp client initialization", "error", err, "name", name)
326 }
327 }()
328
329 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
330 defer cancel()
331 c, err := createAndInitializeClient(ctx, name, m)
332 if err != nil {
333 return
334 }
335
336 mcpClients.Set(name, c)
337
338 tools := getTools(ctx, name, c)
339 result.Append(tools...)
340 updateMcpTools(name, tools)
341 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
342 }(name, m)
343 }
344 wg.Wait()
345
346 return slices.Collect(result.Seq())
347}
348
349// updateMcpTools updates the global mcpTools and mcpClientTools maps
350func updateMcpTools(mcpName string, tools []tools.BaseTool) {
351 if len(tools) == 0 {
352 mcpClient2Tools.Del(mcpName)
353 } else {
354 mcpClient2Tools.Set(mcpName, tools)
355 }
356 for _, tools := range mcpClient2Tools.Seq2() {
357 for _, t := range tools {
358 mcpTools.Set(t.Name(), t)
359 }
360 }
361}
362
363func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
364 c, err := createMcpClient(m)
365 if err != nil {
366 updateMCPState(name, MCPStateError, err, nil, 0)
367 slog.Error("error creating mcp client", "error", err, "name", name)
368 return nil, err
369 }
370
371 c.OnNotification(func(n mcp.JSONRPCNotification) {
372 slog.Debug("Received MCP notification", "name", name, "notification", n)
373 switch n.Method {
374 case "notifications/tools/list_changed":
375 publishMCPEventToolsListChanged(name)
376 default:
377 slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
378 }
379 })
380
381 if err := c.Start(ctx); err != nil {
382 updateMCPState(name, MCPStateError, err, nil, 0)
383 slog.Error("error starting mcp client", "error", err, "name", name)
384 _ = c.Close()
385 return nil, err
386 }
387
388 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
389 updateMCPState(name, MCPStateError, err, nil, 0)
390 slog.Error("error initializing mcp client", "error", err, "name", name)
391 _ = c.Close()
392 return nil, err
393 }
394
395 slog.Info("Initialized mcp client", "name", name)
396 return c, nil
397}
398
399func createMcpClient(m config.MCPConfig) (*client.Client, error) {
400 switch m.Type {
401 case config.MCPStdio:
402 if strings.TrimSpace(m.Command) == "" {
403 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
404 }
405 return client.NewStdioMCPClientWithOptions(
406 m.Command,
407 m.ResolvedEnv(),
408 m.Args,
409 transport.WithCommandLogger(mcpLogger{}),
410 )
411 case config.MCPHttp:
412 if strings.TrimSpace(m.URL) == "" {
413 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
414 }
415 return client.NewStreamableHttpClient(
416 m.URL,
417 transport.WithHTTPHeaders(m.ResolvedHeaders()),
418 transport.WithHTTPLogger(mcpLogger{}),
419 )
420 case config.MCPSse:
421 if strings.TrimSpace(m.URL) == "" {
422 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
423 }
424 return client.NewSSEMCPClient(
425 m.URL,
426 client.WithHeaders(m.ResolvedHeaders()),
427 transport.WithSSELogger(mcpLogger{}),
428 )
429 default:
430 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
431 }
432}
433
// mcpLogger forwards the Errorf/Infof calls made by the MCP client transports
// to the default slog logger.
type mcpLogger struct{}

// Errorf formats the message and logs it at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message and logs it at info level.
func (mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
439
440func mcpTimeout(m config.MCPConfig) time.Duration {
441 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
442}