1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os/exec"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/home"
21 "github.com/charmbracelet/crush/internal/llm/tools"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
// MCPState describes where an MCP client is in its lifecycle.
type MCPState int

// Lifecycle states an MCP client can be in.
const (
	// MCPStateDisabled: the client is turned off in configuration.
	MCPStateDisabled MCPState = 0
	// MCPStateStarting: a connection attempt has been scheduled or is running.
	MCPStateStarting MCPState = 1
	// MCPStateConnected: a session was established and its tools/prompts listed.
	MCPStateConnected MCPState = 2
	// MCPStateError: connecting, listing, or a health-check ping failed.
	MCPStateError MCPState = 3
)

// mcpStateNames holds the display name of each MCPState, indexed by value.
var mcpStateNames = [...]string{
	MCPStateDisabled:  "disabled",
	MCPStateStarting:  "starting",
	MCPStateConnected: "connected",
	MCPStateError:     "error",
}

// String returns the lowercase name of s, or "unknown" for any value outside
// the declared states.
func (s MCPState) String() string {
	if s < 0 || int(s) >= len(mcpStateNames) {
		return "unknown"
	}
	return mcpStateNames[s]
}
52
// MCPEventType identifies the kind of notification carried by an MCPEvent.
type MCPEventType string

const (
	// MCPEventToolsListChanged is published when an MCP server notifies the
	// client that its tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
	// MCPEventPromptsListChanged is published when an MCP server notifies the
	// client that its prompt list changed.
	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)

// MCPEvent is a notification about one MCP client, delivered to
// SubscribeMCPEvents subscribers.
type MCPEvent struct {
	// Type is the kind of event.
	Type MCPEventType
	// Name is the configured name of the MCP client the event concerns.
	Name string
	// State is the client state; the publishers in this file leave it unset.
	State MCPState
	// Error is an associated error; the publishers in this file leave it unset.
	Error error
	// Counts are tool/prompt counts; the publishers in this file leave them unset.
	Counts MCPCounts
}
69
// MCPCounts holds how many tools and prompts an MCP client exposes.
type MCPCounts struct {
	Tools   int
	Prompts int
}

// MCPClientInfo is a snapshot of an MCP client's state as recorded by
// updateMCPState.
type MCPClientInfo struct {
	// Name is the configured MCP client name.
	Name string
	// State is the lifecycle state at the time of the update.
	State MCPState
	// Error is the error passed with the update (set alongside MCPStateError).
	Error error
	// Client is the session passed with the update; nil for non-connected states
	// in this file.
	Client *mcp.ClientSession
	// Counts records how many tools/prompts were registered for the client.
	Counts MCPCounts
	// ConnectedAt is set to the update time only when State is MCPStateConnected.
	ConnectedAt time.Time
}
85
// Package-level MCP registries. They are read and written from the
// per-client goroutines started by doGetMCPTools, hence the csync maps.
var (
	// mcpToolsOnce guards one-time MCP initialization. NOTE(review): not used
	// in this file — presumably by the caller of doGetMCPTools; confirm.
	mcpToolsOnce sync.Once
	// mcpTools indexes every registered MCP tool by its Name() ("mcp_<client>_<tool>").
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP client name to the tools it exposes.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP client name to its current session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP client name to its latest recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker fans MCPEvents out to SubscribeMCPEvents callers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts indexes prompts by "<client>:<prompt name>".
	mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts maps an MCP client name to the prompts it exposes.
	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)
96
// McpTool adapts a single tool advertised by an MCP server to the
// tools.BaseTool interface (getTools returns it as one).
type McpTool struct {
	// mcpName is the configured name of the MCP client that owns the tool.
	mcpName string
	// tool is the tool definition as returned by ListTools.
	tool *mcp.Tool
	// permissions is asked for approval before every execution.
	permissions permission.Service
	// workingDir is reported as the Path of the permission request.
	workingDir string
}
103
104func (b *McpTool) Name() string {
105 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
106}
107
108func (b *McpTool) Info() tools.ToolInfo {
109 var parameters map[string]any
110 var required []string
111
112 if input, ok := b.tool.InputSchema.(map[string]any); ok {
113 if props, ok := input["properties"].(map[string]any); ok {
114 parameters = props
115 }
116 if req, ok := input["required"].([]any); ok {
117 // Convert []any -> []string when elements are strings
118 for _, v := range req {
119 if s, ok := v.(string); ok {
120 required = append(required, s)
121 }
122 }
123 } else if reqStr, ok := input["required"].([]string); ok {
124 // Handle case where it's already []string
125 required = reqStr
126 }
127 }
128
129 return tools.ToolInfo{
130 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
131 Description: b.tool.Description,
132 Parameters: parameters,
133 Required: required,
134 }
135}
136
137func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
138 var args map[string]any
139 if err := json.Unmarshal([]byte(input), &args); err != nil {
140 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
141 }
142
143 c, err := getOrRenewClient(ctx, name)
144 if err != nil {
145 return tools.NewTextErrorResponse(err.Error()), nil
146 }
147 result, err := c.CallTool(ctx, &mcp.CallToolParams{
148 Name: toolName,
149 Arguments: args,
150 })
151 if err != nil {
152 return tools.NewTextErrorResponse(err.Error()), nil
153 }
154
155 output := make([]string, 0, len(result.Content))
156 for _, v := range result.Content {
157 if vv, ok := v.(*mcp.TextContent); ok {
158 output = append(output, vv.Text)
159 } else {
160 output = append(output, fmt.Sprintf("%v", v))
161 }
162 }
163 return tools.NewTextResponse(strings.Join(output, "\n")), nil
164}
165
166func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
167 sess, ok := mcpClients.Get(name)
168 if !ok {
169 return nil, fmt.Errorf("mcp '%s' not available", name)
170 }
171
172 cfg := config.Get()
173 m := cfg.MCP[name]
174 state, _ := mcpStates.Get(name)
175
176 timeout := mcpTimeout(m)
177 pingCtx, cancel := context.WithTimeout(ctx, timeout)
178 defer cancel()
179 err := sess.Ping(pingCtx, nil)
180 if err == nil {
181 return sess, nil
182 }
183 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
184
185 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
186 if err != nil {
187 return nil, err
188 }
189
190 updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
191 mcpClients.Set(name, sess)
192 return sess, nil
193}
194
195func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
196 sessionID, messageID := tools.GetContextValues(ctx)
197 if sessionID == "" || messageID == "" {
198 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
199 }
200 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
201 p := b.permissions.Request(
202 permission.CreatePermissionRequest{
203 SessionID: sessionID,
204 ToolCallID: params.ID,
205 Path: b.workingDir,
206 ToolName: b.Info().Name,
207 Action: "execute",
208 Description: permissionDescription,
209 Params: params.Input,
210 },
211 )
212 if !p {
213 return tools.ToolResponse{}, permission.ErrorPermissionDenied
214 }
215
216 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
217}
218
219func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
220 if c.InitializeResult().Capabilities.Tools == nil {
221 return nil, nil
222 }
223 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
224 if err != nil {
225 return nil, err
226 }
227 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
228 for _, tool := range result.Tools {
229 mcpTools = append(mcpTools, &McpTool{
230 mcpName: name,
231 tool: tool,
232 permissions: permissions,
233 workingDir: workingDir,
234 })
235 }
236 return mcpTools, nil
237}
238
239// SubscribeMCPEvents returns a channel for MCP events
240func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
241 return mcpBroker.Subscribe(ctx)
242}
243
244// GetMCPStates returns the current state of all MCP clients
245func GetMCPStates() map[string]MCPClientInfo {
246 return maps.Collect(mcpStates.Seq2())
247}
248
249// GetMCPState returns the state of a specific MCP client
250func GetMCPState(name string) (MCPClientInfo, bool) {
251 return mcpStates.Get(name)
252}
253
254// updateMCPState updates the state of an MCP client and publishes an event
255func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
256 info := MCPClientInfo{
257 Name: name,
258 State: state,
259 Error: err,
260 Client: client,
261 Counts: counts,
262 }
263 switch state {
264 case MCPStateConnected:
265 info.ConnectedAt = time.Now()
266 case MCPStateError:
267 updateMcpTools(name, nil)
268 updateMcpPrompts(name, nil)
269 mcpClients.Del(name)
270 }
271 mcpStates.Set(name, info)
272}
273
274// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
275func CloseMCPClients() error {
276 var errs []error
277 for name, c := range mcpClients.Seq2() {
278 if err := c.Close(); err != nil &&
279 !errors.Is(err, io.EOF) &&
280 !errors.Is(err, context.Canceled) &&
281 err.Error() != "signal: killed" {
282 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
283 }
284 }
285 mcpBroker.Shutdown()
286 return errors.Join(errs...)
287}
288
289func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
290 var wg sync.WaitGroup
291 // Initialize states for all configured MCPs
292 for name, m := range cfg.MCP {
293 if m.Disabled {
294 updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
295 slog.Debug("skipping disabled mcp", "name", name)
296 continue
297 }
298
299 // Set initial starting state
300 updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
301
302 wg.Add(1)
303 go func(name string, m config.MCPConfig) {
304 defer func() {
305 wg.Done()
306 if r := recover(); r != nil {
307 var err error
308 switch v := r.(type) {
309 case error:
310 err = v
311 case string:
312 err = fmt.Errorf("panic: %s", v)
313 default:
314 err = fmt.Errorf("panic: %v", v)
315 }
316 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
317 slog.Error("panic in mcp client initialization", "error", err, "name", name)
318 }
319 }()
320
321 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
322 defer cancel()
323
324 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
325 if err != nil {
326 return
327 }
328
329 mcpClients.Set(name, c)
330
331 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
332 if err != nil {
333 slog.Error("error listing tools", "error", err)
334 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
335 c.Close()
336 return
337 }
338
339 prompts, err := getPrompts(ctx, c)
340 if err != nil {
341 slog.Error("error listing prompts", "error", err)
342 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
343 c.Close()
344 return
345 }
346
347 updateMcpTools(name, tools)
348 updateMcpPrompts(name, prompts)
349 mcpClients.Set(name, c)
350 counts := MCPCounts{
351 Tools: len(tools),
352 Prompts: len(prompts),
353 }
354 updateMCPState(name, MCPStateConnected, nil, c, counts)
355 }(name, m)
356 }
357 wg.Wait()
358}
359
360// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
361func updateMcpTools(mcpName string, tools []tools.BaseTool) {
362 if len(tools) == 0 {
363 mcpClient2Tools.Del(mcpName)
364 } else {
365 mcpClient2Tools.Set(mcpName, tools)
366 }
367 for _, tools := range mcpClient2Tools.Seq2() {
368 for _, t := range tools {
369 mcpTools.Set(t.Name(), t)
370 }
371 }
372}
373
374func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
375 timeout := mcpTimeout(m)
376 mcpCtx, cancel := context.WithCancel(ctx)
377 cancelTimer := time.AfterFunc(timeout, cancel)
378
379 transport, err := createMCPTransport(mcpCtx, m, resolver)
380 if err != nil {
381 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
382 slog.Error("error creating mcp client", "error", err, "name", name)
383 return nil, err
384 }
385
386 client := mcp.NewClient(
387 &mcp.Implementation{
388 Name: "crush",
389 Version: version.Version,
390 Title: "Crush",
391 },
392 &mcp.ClientOptions{
393 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
394 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
395 Type: MCPEventToolsListChanged,
396 Name: name,
397 })
398 },
399 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
400 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
401 Type: MCPEventPromptsListChanged,
402 Name: name,
403 })
404 },
405 KeepAlive: time.Minute * 10,
406 },
407 )
408
409 session, err := client.Connect(mcpCtx, transport, nil)
410 if err != nil {
411 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
412 slog.Error("error starting mcp client", "error", err, "name", name)
413 cancel()
414 return nil, err
415 }
416
417 cancelTimer.Stop()
418 slog.Info("Initialized mcp client", "name", name)
419 return session, nil
420}
421
// maybeTimeoutErr turns a context cancellation or deadline error into the
// readable error "timed out after <timeout>". Any other error, including nil,
// is returned unchanged.
//
// Both sentinels are checked for two reasons:
//   - createMCPSession enforces its timeout by canceling the context
//     (context.Canceled);
//   - the health-check ping uses context.WithTimeout
//     (context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
428
429func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
430 switch m.Type {
431 case config.MCPStdio:
432 command, err := resolver.ResolveValue(m.Command)
433 if err != nil {
434 return nil, fmt.Errorf("invalid mcp command: %w", err)
435 }
436 if strings.TrimSpace(command) == "" {
437 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
438 }
439 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
440 cmd.Env = m.ResolvedEnv()
441 return &mcp.CommandTransport{
442 Command: cmd,
443 }, nil
444 case config.MCPHttp:
445 if strings.TrimSpace(m.URL) == "" {
446 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
447 }
448 client := &http.Client{
449 Transport: &headerRoundTripper{
450 headers: m.ResolvedHeaders(),
451 },
452 }
453 return &mcp.StreamableClientTransport{
454 Endpoint: m.URL,
455 HTTPClient: client,
456 }, nil
457 case config.MCPSSE:
458 if strings.TrimSpace(m.URL) == "" {
459 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
460 }
461 client := &http.Client{
462 Transport: &headerRoundTripper{
463 headers: m.ResolvedHeaders(),
464 },
465 }
466 return &mcp.SSEClientTransport{
467 Endpoint: m.URL,
468 HTTPClient: client,
469 }, nil
470 default:
471 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
472 }
473}
474
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers applied, overriding existing values for
// the same keys.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// headers are applied to a clone. The original request may be retried,
// logged, or reused by the caller.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
485
486func mcpTimeout(m config.MCPConfig) time.Duration {
487 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
488}
489
490func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
491 if c.InitializeResult().Capabilities.Prompts == nil {
492 return nil, nil
493 }
494 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
495 if err != nil {
496 return nil, err
497 }
498 return result.Prompts, nil
499}
500
501// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
502func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
503 if len(prompts) == 0 {
504 mcpClient2Prompts.Del(mcpName)
505 } else {
506 mcpClient2Prompts.Set(mcpName, prompts)
507 }
508 for clientName, prompts := range mcpClient2Prompts.Seq2() {
509 for _, p := range prompts {
510 key := clientName + ":" + p.Name
511 mcpPrompts.Set(key, p)
512 }
513 }
514}
515
516// GetMCPPrompts returns all available MCP prompts.
517func GetMCPPrompts() map[string]*mcp.Prompt {
518 return maps.Collect(mcpPrompts.Seq2())
519}
520
521// GetMCPPrompt returns a specific MCP prompt by name.
522func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
523 return mcpPrompts.Get(name)
524}
525
526// GetMCPPromptsByClient returns all prompts for a specific MCP client.
527func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
528 return mcpClient2Prompts.Get(clientName)
529}
530
531// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
532func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
533 c, err := getOrRenewClient(ctx, clientName)
534 if err != nil {
535 return nil, err
536 }
537
538 return c.GetPrompt(ctx, &mcp.GetPromptParams{
539 Name: promptName,
540 Arguments: args,
541 })
542}