1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/home"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/pubsub"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/mark3labs/mcp-go/client"
23 "github.com/mark3labs/mcp-go/client/transport"
24 "github.com/mark3labs/mcp-go/mcp"
25)
26
// MCPState is the lifecycle stage of a configured MCP client.
type MCPState int

const (
	MCPStateDisabled  MCPState = iota // disabled in the configuration; never started
	MCPStateStarting                  // connection attempt in progress
	MCPStateConnected                 // initialized successfully
	MCPStateError                     // creation, initialization or a later check failed
)

// String returns the lowercase label of s, or "unknown" for any value outside
// the declared MCPState constants.
func (s MCPState) String() string {
	labels := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
51
// MCPEventType identifies the kind of MCPEvent published on the MCP event broker.
type MCPEventType string

const (
	// MCPEventStateChanged is published by updateMCPState on every state change.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when a server sends a
	// "notifications/tools/list_changed" notification.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent is the payload delivered to SubscribeMCPEvents subscribers.
type MCPEvent struct {
	Type      MCPEventType // kind of event
	Name      string       // configured name of the MCP server the event concerns
	State     MCPState     // new state; only populated for MCPEventStateChanged
	Error     error        // error recorded with the state change, if any
	ToolCount int          // tool count recorded with the state change
}
68
// MCPClientInfo is the per-server snapshot stored by updateMCPState and
// returned by GetMCPState / GetMCPStates.
type MCPClientInfo struct {
	Name        string         // configured server name
	State       MCPState       // current lifecycle state
	Error       error          // error recorded with the state, if any
	Client      *client.Client // client passed to updateMCPState; this file only passes one for MCPStateConnected
	ToolCount   int            // tool count recorded with the state
	ConnectedAt time.Time      // set when State is MCPStateConnected; zero otherwise
}
78
var (
	// mcpToolsOnce is not referenced in this file section; presumably it
	// guards a one-time call to doGetMCPTools elsewhere — confirm.
	mcpToolsOnce sync.Once
	// mcpTools indexes every registered MCP tool by its full name
	// (McpTool.Name, i.e. "mcp_<server>_<tool>").
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools holds each MCP server's tool list, keyed by server name.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients holds the initialized client of each server, keyed by server name.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates holds the most recent MCPClientInfo of each configured server.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker fans MCPEvents out to SubscribeMCPEvents subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
	// toolsMaker wraps raw MCP tools into tools.BaseTool values. It is assigned
	// by doGetMCPTools and must be set before getTools is called.
	toolsMaker func(string, []mcp.Tool) []tools.BaseTool = nil
)
88
// McpTool exposes one tool advertised by an MCP server as a tools.BaseTool.
// Instances are built by the factory returned from createToolsMaker.
type McpTool struct {
	mcpName     string             // configured name of the MCP server that owns the tool
	tool        mcp.Tool           // tool definition as returned by the server's tool listing
	permissions permission.Service // consulted before every Run
	workingDir  string             // sent as the Path of permission requests
}
95
96func (b *McpTool) Name() string {
97 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
98}
99
100func (b *McpTool) Info() tools.ToolInfo {
101 required := b.tool.InputSchema.Required
102 if required == nil {
103 required = make([]string, 0)
104 }
105 parameters := b.tool.InputSchema.Properties
106 if parameters == nil {
107 parameters = make(map[string]any)
108 }
109 return tools.ToolInfo{
110 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
111 Description: b.tool.Description,
112 Parameters: parameters,
113 Required: required,
114 }
115}
116
117func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
118 var args map[string]any
119 if err := json.Unmarshal([]byte(input), &args); err != nil {
120 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
121 }
122
123 c, err := getOrRenewClient(ctx, name)
124 if err != nil {
125 return tools.NewTextErrorResponse(err.Error()), nil
126 }
127 result, err := c.CallTool(ctx, mcp.CallToolRequest{
128 Params: mcp.CallToolParams{
129 Name: toolName,
130 Arguments: args,
131 },
132 })
133 if err != nil {
134 return tools.NewTextErrorResponse(err.Error()), nil
135 }
136
137 output := make([]string, 0, len(result.Content))
138 for _, v := range result.Content {
139 if v, ok := v.(mcp.TextContent); ok {
140 output = append(output, v.Text)
141 } else {
142 output = append(output, fmt.Sprintf("%v", v))
143 }
144 }
145 return tools.NewTextResponse(strings.Join(output, "\n")), nil
146}
147
148func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
149 c, ok := mcpClients.Get(name)
150 if !ok {
151 return nil, fmt.Errorf("mcp '%s' not available", name)
152 }
153
154 cfg := config.Get()
155 m := cfg.MCP[name]
156 state, _ := mcpStates.Get(name)
157
158 timeout := mcpTimeout(m)
159 pingCtx, cancel := context.WithTimeout(ctx, timeout)
160 defer cancel()
161 err := c.Ping(pingCtx)
162 if err == nil {
163 return c, nil
164 }
165 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
166
167 c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
168 if err != nil {
169 return nil, err
170 }
171
172 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
173 mcpClients.Set(name, c)
174 return c, nil
175}
176
177func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
178 sessionID, messageID := tools.GetContextValues(ctx)
179 if sessionID == "" || messageID == "" {
180 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
181 }
182 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
183 p := b.permissions.Request(
184 permission.CreatePermissionRequest{
185 SessionID: sessionID,
186 ToolCallID: params.ID,
187 Path: b.workingDir,
188 ToolName: b.Info().Name,
189 Action: "execute",
190 Description: permissionDescription,
191 Params: params.Input,
192 },
193 )
194 if !p {
195 return tools.ToolResponse{}, permission.ErrorPermissionDenied
196 }
197
198 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
199}
200
201func createToolsMaker(permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
202 return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
203 mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
204 for _, tool := range mcpToolsList {
205 mcpTools = append(mcpTools, &McpTool{
206 mcpName: name,
207 tool: tool,
208 permissions: permissions,
209 workingDir: workingDir,
210 })
211 }
212 return mcpTools
213 }
214}
215
216func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
217 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
218 if err != nil {
219 slog.Error("error listing tools", "error", err)
220 updateMCPState(name, MCPStateError, err, nil, 0)
221 c.Close()
222 return nil
223 }
224 return toolsMaker(name, result.Tools)
225}
226
227// SubscribeMCPEvents returns a channel for MCP events
228func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
229 return mcpBroker.Subscribe(ctx)
230}
231
232// GetMCPStates returns the current state of all MCP clients
233func GetMCPStates() map[string]MCPClientInfo {
234 return maps.Collect(mcpStates.Seq2())
235}
236
237// GetMCPState returns the state of a specific MCP client
238func GetMCPState(name string) (MCPClientInfo, bool) {
239 return mcpStates.Get(name)
240}
241
242// updateMCPState updates the state of an MCP client and publishes an event
243func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
244 info := MCPClientInfo{
245 Name: name,
246 State: state,
247 Error: err,
248 Client: client,
249 ToolCount: toolCount,
250 }
251 switch state {
252 case MCPStateConnected:
253 info.ConnectedAt = time.Now()
254 case MCPStateError:
255 updateMcpTools(name, nil)
256 mcpClients.Del(name)
257 }
258 mcpStates.Set(name, info)
259
260 // Publish state change event
261 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
262 Type: MCPEventStateChanged,
263 Name: name,
264 State: state,
265 Error: err,
266 ToolCount: toolCount,
267 })
268}
269
270// publishMCPEventToolsListChanged publishes a tool list changed event
271func publishMCPEventToolsListChanged(name string) {
272 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
273 Type: MCPEventToolsListChanged,
274 Name: name,
275 })
276}
277
278// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
279func CloseMCPClients() error {
280 var errs []error
281 for name, c := range mcpClients.Seq2() {
282 if err := c.Close(); err != nil {
283 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
284 }
285 }
286 mcpBroker.Shutdown()
287 return errors.Join(errs...)
288}
289
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It advertises mcp.LATEST_PROTOCOL_VERSION and
// identifies the client as "Crush" with version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
299
300func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
301 var wg sync.WaitGroup
302
303 toolsMaker = createToolsMaker(permissions, cfg.WorkingDir())
304
305 // Initialize states for all configured MCPs
306 for name, m := range cfg.MCP {
307 if m.Disabled {
308 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
309 slog.Debug("skipping disabled mcp", "name", name)
310 continue
311 }
312
313 // Set initial starting state
314 updateMCPState(name, MCPStateStarting, nil, nil, 0)
315
316 wg.Add(1)
317 go func(name string, m config.MCPConfig) {
318 defer func() {
319 wg.Done()
320 if r := recover(); r != nil {
321 var err error
322 switch v := r.(type) {
323 case error:
324 err = v
325 case string:
326 err = fmt.Errorf("panic: %s", v)
327 default:
328 err = fmt.Errorf("panic: %v", v)
329 }
330 updateMCPState(name, MCPStateError, err, nil, 0)
331 slog.Error("panic in mcp client initialization", "error", err, "name", name)
332 }
333 }()
334
335 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
336 defer cancel()
337 c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
338 if err != nil {
339 return
340 }
341
342 mcpClients.Set(name, c)
343
344 tools := getTools(ctx, name, c)
345 updateMcpTools(name, tools)
346 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
347 }(name, m)
348 }
349 wg.Wait()
350}
351
352// updateMcpTools updates the global mcpTools and mcpClientTools maps
353func updateMcpTools(mcpName string, tools []tools.BaseTool) {
354 if len(tools) == 0 {
355 mcpClient2Tools.Del(mcpName)
356 } else {
357 mcpClient2Tools.Set(mcpName, tools)
358 }
359 for _, tools := range mcpClient2Tools.Seq2() {
360 for _, t := range tools {
361 mcpTools.Set(t.Name(), t)
362 }
363 }
364}
365
366func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
367 c, err := createMcpClient(name, m, resolver)
368 if err != nil {
369 updateMCPState(name, MCPStateError, err, nil, 0)
370 slog.Error("error creating mcp client", "error", err, "name", name)
371 return nil, err
372 }
373
374 c.OnNotification(func(n mcp.JSONRPCNotification) {
375 slog.Debug("Received MCP notification", "name", name, "notification", n)
376 switch n.Method {
377 case "notifications/tools/list_changed":
378 publishMCPEventToolsListChanged(name)
379 default:
380 slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
381 }
382 })
383
384 timeout := mcpTimeout(m)
385 initCtx, cancel := context.WithTimeout(ctx, timeout)
386 defer cancel()
387
388 if err := c.Start(ctx); err != nil {
389 updateMCPState(name, MCPStateError, err, nil, 0)
390 slog.Error("error starting mcp client", "error", err, "name", name)
391 _ = c.Close()
392 return nil, err
393 }
394
395 if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
396 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
397 slog.Error("error initializing mcp client", "error", err, "name", name)
398 _ = c.Close()
399 return nil, err
400 }
401
402 slog.Info("Initialized mcp client", "name", name)
403 return c, nil
404}
405
// maybeTimeoutErr replaces an error caused by an expired deadline (anywhere
// in its wrap chain) with a concise "timed out after <timeout>" error. Any
// other error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
412
413func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
414 switch m.Type {
415 case config.MCPStdio:
416 command, err := resolver.ResolveValue(m.Command)
417 if err != nil {
418 return nil, fmt.Errorf("invalid mcp command: %w", err)
419 }
420 if strings.TrimSpace(command) == "" {
421 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
422 }
423 return client.NewStdioMCPClientWithOptions(
424 home.Long(command),
425 m.ResolvedEnv(),
426 m.Args,
427 transport.WithCommandLogger(mcpLogger{name: name}),
428 )
429 case config.MCPHttp:
430 if strings.TrimSpace(m.URL) == "" {
431 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
432 }
433 return client.NewStreamableHttpClient(
434 m.URL,
435 transport.WithHTTPHeaders(m.ResolvedHeaders()),
436 transport.WithHTTPLogger(mcpLogger{name: name}),
437 )
438 case config.MCPSse:
439 if strings.TrimSpace(m.URL) == "" {
440 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
441 }
442 return client.NewSSEMCPClient(
443 m.URL,
444 client.WithHeaders(m.ResolvedHeaders()),
445 transport.WithSSELogger(mcpLogger{name: name}),
446 )
447 default:
448 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
449 }
450}
451
// mcpLogger routes the printf-style log calls of the MCP client transports
// (it is passed to their logger options in createMcpClient) to slog, tagging
// each record with the server name.
type mcpLogger struct{ name string }

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
462
463func mcpTimeout(m config.MCPConfig) time.Duration {
464 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
465}