1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "maps"
10 "slices"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/llm/tools"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/version"
21 "github.com/mark3labs/mcp-go/client"
22 "github.com/mark3labs/mcp-go/client/transport"
23 "github.com/mark3labs/mcp-go/mcp"
24)
25
// MCPState is the lifecycle state of a single MCP client connection.
type MCPState int

// Known MCPState values. The zero value is MCPStateDisabled.
const (
	// MCPStateDisabled marks a server whose configuration has Disabled set.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting is recorded before a connection attempt begins.
	MCPStateStarting
	// MCPStateConnected is recorded once the client is initialized and its tools are listed.
	MCPStateConnected
	// MCPStateError is recorded when creating, starting, initializing,
	// listing tools, or a health-check ping fails.
	MCPStateError
)

// String returns the lower-case display name of s, or "unknown" for values
// outside the declared range.
func (s MCPState) String() string {
	labels := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
50
// MCPEventType identifies which kind of change an MCPEvent reports.
type MCPEventType string

// MCPEventStateChanged is published every time a client's MCPState is recorded.
const MCPEventStateChanged MCPEventType = "state_changed"

// MCPEventToolsListChanged is published when a server sends a
// "notifications/tools/list_changed" notification.
const MCPEventToolsListChanged MCPEventType = "tools_list_changed"
58
// MCPEvent is the payload published on the MCP event broker and delivered to
// SubscribeMCPEvents subscribers.
type MCPEvent struct {
	// Type says which kind of change this event reports.
	Type MCPEventType
	// Name is the MCP server's configuration key.
	Name string
	// State is the newly recorded state; only set for MCPEventStateChanged.
	State MCPState
	// Error is the error that caused the state change, if any.
	Error error
	// ToolCount is the number of tools known for the server at the time of a
	// state change; it is left zero for tools-list-changed events.
	ToolCount int
}
67
// MCPClientInfo is the status record kept for each MCP server. It is written
// by updateMCPState and read back through GetMCPState / GetMCPStates.
type MCPClientInfo struct {
	// Name is the MCP server's configuration key.
	Name string
	// State is the most recently recorded lifecycle state.
	State MCPState
	// Error is the error that caused the last state change, or nil.
	Error error
	// Client is the live client; callers in this file only pass a non-nil
	// value together with MCPStateConnected.
	Client *client.Client
	// ToolCount is the number of tools the server exposed when recorded.
	ToolCount int
	// ConnectedAt is set to the recording time only for MCPStateConnected;
	// it is the zero time otherwise.
	ConnectedAt time.Time
}
77
78var (
79 mcpToolsOnce sync.Once
80 mcpTools = csync.NewMap[string, tools.BaseTool]()
81 // mcpClientTools maps MCP name to tool names
82 mcpClientTools = csync.NewMap[string, []string]()
83 mcpClients = csync.NewMap[string, *client.Client]()
84 mcpStates = csync.NewMap[string, MCPClientInfo]()
85 mcpBroker = pubsub.NewBroker[MCPEvent]()
86 toolsMaker func(string, []mcp.Tool) []tools.BaseTool = nil
87)
88
// McpTool adapts a single tool exposed by an MCP server to the tools.BaseTool
// interface. Instances are built by the factory returned from createToolsMaker.
type McpTool struct {
	// mcpName is the configuration key of the MCP server providing the tool.
	mcpName string
	// tool is the tool definition reported by the server's ListTools call.
	tool mcp.Tool
	// permissions is asked for approval before every Run.
	permissions permission.Service
	// workingDir is passed as the Path of the permission request.
	workingDir string
}
95
96func (b *McpTool) Name() string {
97 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
98}
99
100func (b *McpTool) Info() tools.ToolInfo {
101 required := b.tool.InputSchema.Required
102 if required == nil {
103 required = make([]string, 0)
104 }
105 parameters := b.tool.InputSchema.Properties
106 if parameters == nil {
107 parameters = make(map[string]any)
108 }
109 return tools.ToolInfo{
110 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
111 Description: b.tool.Description,
112 Parameters: parameters,
113 Required: required,
114 }
115}
116
117func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
118 var args map[string]any
119 if err := json.Unmarshal([]byte(input), &args); err != nil {
120 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
121 }
122
123 c, err := getOrRenewClient(ctx, name)
124 if err != nil {
125 return tools.NewTextErrorResponse(err.Error()), nil
126 }
127 result, err := c.CallTool(ctx, mcp.CallToolRequest{
128 Params: mcp.CallToolParams{
129 Name: toolName,
130 Arguments: args,
131 },
132 })
133 if err != nil {
134 return tools.NewTextErrorResponse(err.Error()), nil
135 }
136
137 output := make([]string, 0, len(result.Content))
138 for _, v := range result.Content {
139 if v, ok := v.(mcp.TextContent); ok {
140 output = append(output, v.Text)
141 } else {
142 output = append(output, fmt.Sprintf("%v", v))
143 }
144 }
145 return tools.NewTextResponse(strings.Join(output, "\n")), nil
146}
147
148func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
149 c, ok := mcpClients.Get(name)
150 if !ok {
151 return nil, fmt.Errorf("mcp '%s' not available", name)
152 }
153
154 m := config.Get().MCP[name]
155 state, _ := mcpStates.Get(name)
156
157 pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
158 defer cancel()
159 err := c.Ping(pingCtx)
160 if err == nil {
161 return c, nil
162 }
163 updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
164
165 c, err = createAndInitializeClient(ctx, name, m)
166 if err != nil {
167 return nil, err
168 }
169
170 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
171 mcpClients.Set(name, c)
172 return c, nil
173}
174
175func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
176 sessionID, messageID := tools.GetContextValues(ctx)
177 if sessionID == "" || messageID == "" {
178 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
179 }
180 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
181 p := b.permissions.Request(
182 permission.CreatePermissionRequest{
183 SessionID: sessionID,
184 ToolCallID: params.ID,
185 Path: b.workingDir,
186 ToolName: b.Info().Name,
187 Action: "execute",
188 Description: permissionDescription,
189 Params: params.Input,
190 },
191 )
192 if !p {
193 return tools.ToolResponse{}, permission.ErrorPermissionDenied
194 }
195
196 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
197}
198
199func createToolsMaker(ctx context.Context, permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
200 return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
201 mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
202 for _, tool := range mcpToolsList {
203 mcpTools = append(mcpTools, &McpTool{
204 mcpName: name,
205 tool: tool,
206 permissions: permissions,
207 workingDir: workingDir,
208 })
209 }
210 return mcpTools
211 }
212}
213
214func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
215 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
216 if err != nil {
217 slog.Error("error listing tools", "error", err)
218 updateMCPState(name, MCPStateError, err, nil, 0)
219 c.Close()
220 mcpClients.Del(name)
221 return nil
222 }
223 return toolsMaker(name, result.Tools)
224}
225
226// SubscribeMCPEvents returns a channel for MCP events
227func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
228 return mcpBroker.Subscribe(ctx)
229}
230
231// GetMCPStates returns the current state of all MCP clients
232func GetMCPStates() map[string]MCPClientInfo {
233 return maps.Collect(mcpStates.Seq2())
234}
235
236// GetMCPState returns the state of a specific MCP client
237func GetMCPState(name string) (MCPClientInfo, bool) {
238 return mcpStates.Get(name)
239}
240
241// updateMCPState updates the state of an MCP client and publishes an event
242func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
243 info := MCPClientInfo{
244 Name: name,
245 State: state,
246 Error: err,
247 Client: client,
248 ToolCount: toolCount,
249 }
250 if state == MCPStateConnected {
251 info.ConnectedAt = time.Now()
252 }
253 mcpStates.Set(name, info)
254
255 // Publish state change event
256 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
257 Type: MCPEventStateChanged,
258 Name: name,
259 State: state,
260 Error: err,
261 ToolCount: toolCount,
262 })
263}
264
265// publishMCPEventToolsListChanged publishes a tool list changed event
266func publishMCPEventToolsListChanged(name string) {
267 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
268 Type: MCPEventToolsListChanged,
269 Name: name,
270 })
271}
272
273// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
274func CloseMCPClients() {
275 for c := range mcpClients.Seq() {
276 _ = c.Close()
277 }
278 mcpBroker.Shutdown()
279}
280
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It advertises mcp.LATEST_PROTOCOL_VERSION and
// identifies this client as "Crush" at the build's version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
290
291func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
292 var wg sync.WaitGroup
293 result := csync.NewSlice[tools.BaseTool]()
294
295 toolsMaker = createToolsMaker(ctx, permissions, cfg.WorkingDir())
296
297 // Initialize states for all configured MCPs
298 for name, m := range cfg.MCP {
299 if m.Disabled {
300 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
301 slog.Debug("skipping disabled mcp", "name", name)
302 continue
303 }
304
305 // Set initial starting state
306 updateMCPState(name, MCPStateStarting, nil, nil, 0)
307
308 wg.Add(1)
309 go func(name string, m config.MCPConfig) {
310 defer func() {
311 wg.Done()
312 if r := recover(); r != nil {
313 var err error
314 switch v := r.(type) {
315 case error:
316 err = v
317 case string:
318 err = fmt.Errorf("panic: %s", v)
319 default:
320 err = fmt.Errorf("panic: %v", v)
321 }
322 updateMCPState(name, MCPStateError, err, nil, 0)
323 slog.Error("panic in mcp client initialization", "error", err, "name", name)
324 }
325 }()
326
327 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
328 defer cancel()
329 c, err := createAndInitializeClient(ctx, name, m)
330 if err != nil {
331 return
332 }
333
334 mcpClients.Set(name, c)
335
336 tools := getTools(ctx, name, c)
337 result.Append(tools...)
338 updateMcpTools(name, tools)
339 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
340 }(name, m)
341 }
342 wg.Wait()
343
344 return slices.Collect(result.Seq())
345}
346
347// updateMcpTools updates the global mcpTools and mcpClientTools maps
348func updateMcpTools(mcpName string, tools []tools.BaseTool) {
349 toolNames := make([]string, 0, len(tools))
350 for _, tool := range tools {
351 name := tool.Name()
352 if _, ok := mcpTools.Get(name); !ok {
353 slog.Info("Added MCP tool", "name", name, "mcp", mcpName)
354 }
355 mcpTools.Set(name, tool)
356 toolNames = append(toolNames, name)
357 }
358
359 // remove the tools that are no longer available
360 old, ok := mcpClientTools.Get(mcpName)
361 if ok {
362 slices.Sort(toolNames)
363 for _, name := range old {
364 if _, ok := slices.BinarySearch(toolNames, name); !ok {
365 mcpTools.Del(name)
366 slog.Info("Removed MCP tool", "name", name, "mcp", mcpName)
367 }
368 }
369 }
370 mcpClientTools.Set(mcpName, toolNames)
371}
372
373func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
374 c, err := createMcpClient(m)
375 if err != nil {
376 updateMCPState(name, MCPStateError, err, nil, 0)
377 slog.Error("error creating mcp client", "error", err, "name", name)
378 return nil, err
379 }
380
381 c.OnNotification(func(n mcp.JSONRPCNotification) {
382 slog.Debug("Received MCP notification", "name", name, "notification", n)
383 switch n.Method {
384 case "notifications/tools/list_changed":
385 publishMCPEventToolsListChanged(name)
386 default:
387 slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
388 }
389 })
390
391 if err := c.Start(ctx); err != nil {
392 updateMCPState(name, MCPStateError, err, nil, 0)
393 slog.Error("error starting mcp client", "error", err, "name", name)
394 _ = c.Close()
395 return nil, err
396 }
397
398 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
399 updateMCPState(name, MCPStateError, err, nil, 0)
400 slog.Error("error initializing mcp client", "error", err, "name", name)
401 _ = c.Close()
402 return nil, err
403 }
404
405 slog.Info("Initialized mcp client", "name", name)
406 return c, nil
407}
408
409func createMcpClient(m config.MCPConfig) (*client.Client, error) {
410 switch m.Type {
411 case config.MCPStdio:
412 if strings.TrimSpace(m.Command) == "" {
413 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
414 }
415 return client.NewStdioMCPClientWithOptions(
416 m.Command,
417 m.ResolvedEnv(),
418 m.Args,
419 transport.WithCommandLogger(mcpLogger{}),
420 )
421 case config.MCPHttp:
422 if strings.TrimSpace(m.URL) == "" {
423 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
424 }
425 return client.NewStreamableHttpClient(
426 m.URL,
427 transport.WithHTTPHeaders(m.ResolvedHeaders()),
428 transport.WithHTTPLogger(mcpLogger{}),
429 )
430 case config.MCPSse:
431 if strings.TrimSpace(m.URL) == "" {
432 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
433 }
434 return client.NewSSEMCPClient(
435 m.URL,
436 client.WithHeaders(m.ResolvedHeaders()),
437 transport.WithSSELogger(mcpLogger{}),
438 )
439 default:
440 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
441 }
442}
443
// mcpLogger adapts the default slog logger to the printf-style Errorf/Infof
// logger that createMcpClient hands to the mcp-go transports.
type mcpLogger struct{}

// Errorf formats its arguments like fmt.Sprintf and logs the result at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats its arguments like fmt.Sprintf and logs the result at info level.
func (mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
449
450func mcpTimeout(m config.MCPConfig) time.Duration {
451 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
452}