1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/llm/tools"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/version"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
// MCPState represents the lifecycle state of a configured MCP client.
type MCPState int

const (
	// MCPStateDisabled: the client is disabled in the configuration and is not started.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting: a connection attempt is in progress.
	MCPStateStarting
	// MCPStateConnected: the session is established and usable.
	MCPStateConnected
	// MCPStateError: the last connection attempt or health check failed.
	MCPStateError
)

// String returns the lower-case name of s, or "unknown" for any value that is
// not one of the declared MCPState constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
53
// MCPEventType identifies the kind of notification published on the MCP
// event broker (see SubscribeMCPEvents).
type MCPEventType string

const (
	// MCPEventToolsListChanged is published when a server notifies the client
	// that its list of tools changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
	// MCPEventPromptsListChanged is published when a server notifies the
	// client that its list of prompts changed.
	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)

// MCPEvent is the payload published on the MCP event broker.
type MCPEvent struct {
	Type MCPEventType
	// Name is the configured name of the MCP server the event refers to.
	Name string
	// NOTE(review): State, Error and Counts are not set by the publishers in
	// this file (only Type and Name are); confirm whether any consumer relies
	// on them.
	State  MCPState
	Error  error
	Counts MCPCounts
}

// MCPCounts records how many tools and prompts an MCP client exposes.
type MCPCounts struct {
	Tools   int
	Prompts int
}

// MCPClientInfo is a snapshot of an MCP client's state as recorded by
// updateMCPState.
type MCPClientInfo struct {
	Name  string
	State MCPState
	// Error is the failure reported together with MCPStateError; callers in
	// this file pass nil for every other state.
	Error error
	// Client is the live session when State is MCPStateConnected; callers in
	// this file pass nil for every other state.
	Client *mcp.ClientSession
	Counts MCPCounts
	// ConnectedAt is stamped with time.Now() when the client enters
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
86
// Package-level registries shared by every MCP client. They are written from
// the per-client goroutines started by doGetMCPTools; presumably the csync
// maps provide the needed locking — confirm against the csync package.
var (
	// NOTE(review): mcpToolsOnce is not used in this part of the file;
	// presumably it guards a one-time call to doGetMCPTools elsewhere.
	mcpToolsOnce sync.Once
	// mcpTools indexes every registered tool by its Name() (for McpTool:
	// "mcp_<server>_<tool>").
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools holds the tools registered for each MCP name.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients holds the current session for each MCP name.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates holds the latest MCPClientInfo for each MCP name.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker distributes MCPEvent notifications to subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts indexes every registered prompt by "<mcp name>:<prompt name>".
	mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts holds the prompts registered for each MCP name.
	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)
97
// McpTool exposes a single tool offered by an MCP server as a tools.BaseTool
// for the agent.
type McpTool struct {
	mcpName     string             // configured name of the MCP server
	tool        *mcp.Tool          // tool definition as listed by the server
	permissions permission.Service // asked for approval before every run
	workingDir  string             // sent as the Path of permission requests
}
104
105func (b *McpTool) Name() string {
106 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
107}
108
109func (b *McpTool) Info() tools.ToolInfo {
110 parameters := make(map[string]any)
111 required := make([]string, 0)
112
113 if input, ok := b.tool.InputSchema.(map[string]any); ok {
114 if props, ok := input["properties"].(map[string]any); ok {
115 parameters = props
116 }
117 if req, ok := input["required"].([]any); ok {
118 // Convert []any -> []string when elements are strings
119 for _, v := range req {
120 if s, ok := v.(string); ok {
121 required = append(required, s)
122 }
123 }
124 } else if reqStr, ok := input["required"].([]string); ok {
125 // Handle case where it's already []string
126 required = reqStr
127 }
128 }
129
130 return tools.ToolInfo{
131 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
132 Description: b.tool.Description,
133 Parameters: parameters,
134 Required: required,
135 }
136}
137
138func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
139 var args map[string]any
140 if err := json.Unmarshal([]byte(input), &args); err != nil {
141 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
142 }
143
144 c, err := getOrRenewClient(ctx, name)
145 if err != nil {
146 return tools.NewTextErrorResponse(err.Error()), nil
147 }
148 result, err := c.CallTool(ctx, &mcp.CallToolParams{
149 Name: toolName,
150 Arguments: args,
151 })
152 if err != nil {
153 return tools.NewTextErrorResponse(err.Error()), nil
154 }
155
156 output := make([]string, 0, len(result.Content))
157 for _, v := range result.Content {
158 if vv, ok := v.(*mcp.TextContent); ok {
159 output = append(output, vv.Text)
160 } else {
161 output = append(output, fmt.Sprintf("%v", v))
162 }
163 }
164 return tools.NewTextResponse(strings.Join(output, "\n")), nil
165}
166
167func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
168 sess, ok := mcpClients.Get(name)
169 if !ok {
170 return nil, fmt.Errorf("mcp '%s' not available", name)
171 }
172
173 cfg := config.Get()
174 m := cfg.MCP[name]
175 state, _ := mcpStates.Get(name)
176
177 timeout := mcpTimeout(m)
178 pingCtx, cancel := context.WithTimeout(ctx, timeout)
179 defer cancel()
180 err := sess.Ping(pingCtx, nil)
181 if err == nil {
182 return sess, nil
183 }
184 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
185
186 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
187 if err != nil {
188 return nil, err
189 }
190
191 updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
192 mcpClients.Set(name, sess)
193 return sess, nil
194}
195
196func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
197 sessionID, messageID := tools.GetContextValues(ctx)
198 if sessionID == "" || messageID == "" {
199 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
200 }
201 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
202 p := b.permissions.Request(
203 permission.CreatePermissionRequest{
204 SessionID: sessionID,
205 ToolCallID: params.ID,
206 Path: b.workingDir,
207 ToolName: b.Info().Name,
208 Action: "execute",
209 Description: permissionDescription,
210 Params: params.Input,
211 },
212 )
213 if !p {
214 return tools.ToolResponse{}, permission.ErrorPermissionDenied
215 }
216
217 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
218}
219
220func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
221 if c.InitializeResult().Capabilities.Tools == nil {
222 return nil, nil
223 }
224 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
225 if err != nil {
226 return nil, err
227 }
228 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
229 for _, tool := range result.Tools {
230 mcpTools = append(mcpTools, &McpTool{
231 mcpName: name,
232 tool: tool,
233 permissions: permissions,
234 workingDir: workingDir,
235 })
236 }
237 return mcpTools, nil
238}
239
240// SubscribeMCPEvents returns a channel for MCP events
241func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
242 return mcpBroker.Subscribe(ctx)
243}
244
245// GetMCPStates returns the current state of all MCP clients
246func GetMCPStates() map[string]MCPClientInfo {
247 return maps.Collect(mcpStates.Seq2())
248}
249
250// GetMCPState returns the state of a specific MCP client
251func GetMCPState(name string) (MCPClientInfo, bool) {
252 return mcpStates.Get(name)
253}
254
255// updateMCPState updates the state of an MCP client and publishes an event
256func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
257 info := MCPClientInfo{
258 Name: name,
259 State: state,
260 Error: err,
261 Client: client,
262 Counts: counts,
263 }
264 switch state {
265 case MCPStateConnected:
266 info.ConnectedAt = time.Now()
267 case MCPStateError:
268 updateMcpTools(name, nil)
269 updateMcpPrompts(name, nil)
270 mcpClients.Del(name)
271 }
272 mcpStates.Set(name, info)
273}
274
275// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
276func CloseMCPClients() error {
277 var errs []error
278 for name, c := range mcpClients.Seq2() {
279 if err := c.Close(); err != nil &&
280 !errors.Is(err, io.EOF) &&
281 !errors.Is(err, context.Canceled) &&
282 err.Error() != "signal: killed" {
283 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
284 }
285 }
286 mcpBroker.Shutdown()
287 return errors.Join(errs...)
288}
289
290func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
291 var wg sync.WaitGroup
292 // Initialize states for all configured MCPs
293 for name, m := range cfg.MCP {
294 if m.Disabled {
295 updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
296 slog.Debug("skipping disabled mcp", "name", name)
297 continue
298 }
299
300 // Set initial starting state
301 updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
302
303 wg.Add(1)
304 go func(name string, m config.MCPConfig) {
305 defer func() {
306 wg.Done()
307 if r := recover(); r != nil {
308 var err error
309 switch v := r.(type) {
310 case error:
311 err = v
312 case string:
313 err = fmt.Errorf("panic: %s", v)
314 default:
315 err = fmt.Errorf("panic: %v", v)
316 }
317 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
318 slog.Error("panic in mcp client initialization", "error", err, "name", name)
319 }
320 }()
321
322 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
323 defer cancel()
324
325 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
326 if err != nil {
327 return
328 }
329
330 mcpClients.Set(name, c)
331
332 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
333 if err != nil {
334 slog.Error("error listing tools", "error", err)
335 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
336 c.Close()
337 return
338 }
339
340 prompts, err := getPrompts(ctx, c)
341 if err != nil {
342 slog.Error("error listing prompts", "error", err)
343 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
344 c.Close()
345 return
346 }
347
348 updateMcpTools(name, tools)
349 updateMcpPrompts(name, prompts)
350 mcpClients.Set(name, c)
351 counts := MCPCounts{
352 Tools: len(tools),
353 Prompts: len(prompts),
354 }
355 updateMCPState(name, MCPStateConnected, nil, c, counts)
356 }(name, m)
357 }
358 wg.Wait()
359}
360
361// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
362func updateMcpTools(mcpName string, tools []tools.BaseTool) {
363 if len(tools) == 0 {
364 mcpClient2Tools.Del(mcpName)
365 } else {
366 mcpClient2Tools.Set(mcpName, tools)
367 }
368 for _, tools := range mcpClient2Tools.Seq2() {
369 for _, t := range tools {
370 mcpTools.Set(t.Name(), t)
371 }
372 }
373}
374
375func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
376 timeout := mcpTimeout(m)
377 mcpCtx, cancel := context.WithCancel(ctx)
378 cancelTimer := time.AfterFunc(timeout, cancel)
379
380 transport, err := createMCPTransport(mcpCtx, m, resolver)
381 if err != nil {
382 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
383 slog.Error("error creating mcp client", "error", err, "name", name)
384 cancel()
385 cancelTimer.Stop()
386 return nil, err
387 }
388
389 client := mcp.NewClient(
390 &mcp.Implementation{
391 Name: "crush",
392 Version: version.Version,
393 Title: "Crush",
394 },
395 &mcp.ClientOptions{
396 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
397 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
398 Type: MCPEventToolsListChanged,
399 Name: name,
400 })
401 },
402 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
403 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
404 Type: MCPEventPromptsListChanged,
405 Name: name,
406 })
407 },
408 KeepAlive: time.Minute * 10,
409 },
410 )
411
412 session, err := client.Connect(mcpCtx, transport, nil)
413 if err != nil {
414 err = maybeStdioErr(err, transport)
415 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
416 slog.Error("error starting mcp client", "error", err, "name", name)
417 cancel()
418 cancelTimer.Stop()
419 return nil, err
420 }
421
422 cancelTimer.Stop()
423 slog.Info("Initialized mcp client", "name", name)
424 return session, nil
425}
426
// maybeTimeoutErr returns "timed out after <timeout>" when err stems from
// context cancellation or an expired deadline. Otherwise it returns err
// unchanged, including nil.
//
// Both causes are needed:
//   - createMCPSession enforces its timeout by cancelling a context, which
//     surfaces as context.Canceled.
//   - getOrRenewClient bounds its ping with context.WithTimeout, which
//     presumably surfaces as context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
433
434func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
435 switch m.Type {
436 case config.MCPStdio:
437 command, err := resolver.ResolveValue(m.Command)
438 if err != nil {
439 return nil, fmt.Errorf("invalid mcp command: %w", err)
440 }
441 if strings.TrimSpace(command) == "" {
442 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
443 }
444 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
445 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
446 return &mcp.CommandTransport{
447 Command: cmd,
448 }, nil
449 case config.MCPHttp:
450 if strings.TrimSpace(m.URL) == "" {
451 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
452 }
453 client := &http.Client{
454 Transport: &headerRoundTripper{
455 headers: m.ResolvedHeaders(),
456 },
457 }
458 return &mcp.StreamableClientTransport{
459 Endpoint: m.URL,
460 HTTPClient: client,
461 }, nil
462 case config.MCPSSE:
463 if strings.TrimSpace(m.URL) == "" {
464 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
465 }
466 client := &http.Client{
467 Transport: &headerRoundTripper{
468 headers: m.ResolvedHeaders(),
469 },
470 }
471 return &mcp.SSEClientTransport{
472 Endpoint: m.URL,
473 HTTPClient: client,
474 }, nil
475 default:
476 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
477 }
478}
479
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers set, replacing any existing values of
// those header names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone. With no headers configured, the request
// is passed through untouched.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
490
491func mcpTimeout(m config.MCPConfig) time.Duration {
492 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
493}
494
495func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
496 if c.InitializeResult().Capabilities.Prompts == nil {
497 return nil, nil
498 }
499 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
500 if err != nil {
501 return nil, err
502 }
503 return result.Prompts, nil
504}
505
506// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
507func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
508 if len(prompts) == 0 {
509 mcpClient2Prompts.Del(mcpName)
510 } else {
511 mcpClient2Prompts.Set(mcpName, prompts)
512 }
513 for clientName, prompts := range mcpClient2Prompts.Seq2() {
514 for _, p := range prompts {
515 key := clientName + ":" + p.Name
516 mcpPrompts.Set(key, p)
517 }
518 }
519}
520
521// GetMCPPrompts returns all available MCP prompts.
522func GetMCPPrompts() map[string]*mcp.Prompt {
523 return maps.Collect(mcpPrompts.Seq2())
524}
525
526// GetMCPPrompt returns a specific MCP prompt by name.
527func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
528 return mcpPrompts.Get(name)
529}
530
531// GetMCPPromptsByClient returns all prompts for a specific MCP client.
532func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
533 return mcpClient2Prompts.Get(clientName)
534}
535
536// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
537func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
538 c, err := getOrRenewClient(ctx, clientName)
539 if err != nil {
540 return nil, err
541 }
542
543 return c.GetPrompt(ctx, &mcp.GetPromptParams{
544 Name: promptName,
545 Arguments: args,
546 })
547}
548
549// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
550// to parse, and the cli will then close it, causing the EOF error.
551// so, if we got an EOF err, and the transport is STDIO, we try to exec it
552// again with a timeout and collect the output so we can add details to the
553// error.
554// this happens particularly when starting things with npx, e.g. if node can't
555// be found or some other error like that.
556func maybeStdioErr(err error, transport mcp.Transport) error {
557 if !errors.Is(err, io.EOF) {
558 return err
559 }
560 ct, ok := transport.(*mcp.CommandTransport)
561 if !ok {
562 return err
563 }
564 if err2 := stdioMCPCheck(ct.Command); err2 != nil {
565 err = errors.Join(err, err2)
566 }
567 return err
568}
569
// stdioMCPCheck re-runs the command described by old for at most five
// seconds, using the same path, argv, environment and working directory. If
// the command fails, it returns the failure together with the combined
// stdout/stderr.
//
// It returns nil when the command succeeds or is still running when the
// timeout fires, since neither case explains a startup failure.
func stdioMCPCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args is the full argv, including argv[0]. Reuse it verbatim: passing
	// it as the variadic arguments would repeat the command name as the first
	// argument (e.g. "npx npx pkg"). An empty Args makes exec fall back to
	// {Path}.
	if len(old.Args) > 0 {
		cmd.Args = append([]string(nil), old.Args...)
	}
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}