1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os/exec"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/home"
21 "github.com/charmbracelet/crush/internal/llm/tools"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
// MCPState is the lifecycle state of a configured MCP client.
type MCPState int

// Declared with iota, so their numeric values follow declaration order.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lowercase name of s, or "unknown" for any value outside
// the declared set of states (including negative values).
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
52
// MCPEventType identifies which kind of notification an MCPEvent carries.
type MCPEventType string

const (
	// MCPEventToolsListChanged is published when a server notifies that its
	// tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
	// MCPEventPromptsListChanged is published when a server notifies that its
	// prompt list changed.
	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)
60
61// MCPEvent represents an event in the MCP system
62type MCPEvent struct {
63 Type MCPEventType
64 Name string
65 State MCPState
66 Error error
67 Counts MCPCounts
68}
69
// MCPCounts records how many tools, prompts and resources an MCP client
// exposes.
type MCPCounts struct {
	Tools     int // number of tools listed by the server
	Prompts   int // number of prompts listed by the server
	Resources int // number of resources listed by the server
}
76
77// MCPClientInfo holds information about an MCP client's state
78type MCPClientInfo struct {
79 Name string
80 State MCPState
81 Error error
82 Client *mcp.ClientSession
83 Counts MCPCounts
84 ConnectedAt time.Time
85}
86
87var (
88 mcpToolsOnce sync.Once
89 mcpTools = csync.NewMap[string, tools.BaseTool]()
90 mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
91 mcpClients = csync.NewMap[string, *mcp.ClientSession]()
92 mcpStates = csync.NewMap[string, MCPClientInfo]()
93 mcpBroker = pubsub.NewBroker[MCPEvent]()
94 mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
95 mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
96 mcpResources = csync.NewMap[string, *mcp.Resource]()
97 mcpClient2Resources = csync.NewMap[string, []*mcp.Resource]()
98)
99
100type McpTool struct {
101 mcpName string
102 tool *mcp.Tool
103 permissions permission.Service
104 workingDir string
105}
106
107func (b *McpTool) Name() string {
108 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
109}
110
111func (b *McpTool) Info() tools.ToolInfo {
112 input := b.tool.InputSchema.(map[string]any)
113 required, _ := input["required"].([]string)
114 parameters, _ := input["properties"].(map[string]any)
115 return tools.ToolInfo{
116 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
117 Description: b.tool.Description,
118 Parameters: parameters,
119 Required: required,
120 }
121}
122
123func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
124 var args map[string]any
125 if err := json.Unmarshal([]byte(input), &args); err != nil {
126 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
127 }
128
129 c, err := getOrRenewClient(ctx, name)
130 if err != nil {
131 return tools.NewTextErrorResponse(err.Error()), nil
132 }
133 result, err := c.CallTool(ctx, &mcp.CallToolParams{
134 Name: toolName,
135 Arguments: args,
136 })
137 if err != nil {
138 return tools.NewTextErrorResponse(err.Error()), nil
139 }
140
141 output := make([]string, 0, len(result.Content))
142 for _, v := range result.Content {
143 if vv, ok := v.(*mcp.TextContent); ok {
144 output = append(output, vv.Text)
145 } else {
146 output = append(output, fmt.Sprintf("%v", v))
147 }
148 }
149 return tools.NewTextResponse(strings.Join(output, "\n")), nil
150}
151
152func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
153 sess, ok := mcpClients.Get(name)
154 if !ok {
155 return nil, fmt.Errorf("mcp '%s' not available", name)
156 }
157
158 cfg := config.Get()
159 m := cfg.MCP[name]
160 state, _ := mcpStates.Get(name)
161
162 timeout := mcpTimeout(m)
163 pingCtx, cancel := context.WithTimeout(ctx, timeout)
164 defer cancel()
165 err := sess.Ping(pingCtx, nil)
166 if err == nil {
167 return sess, nil
168 }
169 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
170
171 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
172 if err != nil {
173 return nil, err
174 }
175
176 updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
177 mcpClients.Set(name, sess)
178 return sess, nil
179}
180
181func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
182 sessionID, messageID := tools.GetContextValues(ctx)
183 if sessionID == "" || messageID == "" {
184 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
185 }
186 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
187 p := b.permissions.Request(
188 permission.CreatePermissionRequest{
189 SessionID: sessionID,
190 ToolCallID: params.ID,
191 Path: b.workingDir,
192 ToolName: b.Info().Name,
193 Action: "execute",
194 Description: permissionDescription,
195 Params: params.Input,
196 },
197 )
198 if !p {
199 return tools.ToolResponse{}, permission.ErrorPermissionDenied
200 }
201
202 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
203}
204
205func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
206 if c.InitializeResult().Capabilities.Tools == nil {
207 return nil, nil
208 }
209 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
210 if err != nil {
211 return nil, err
212 }
213 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
214 for _, tool := range result.Tools {
215 mcpTools = append(mcpTools, &McpTool{
216 mcpName: name,
217 tool: tool,
218 permissions: permissions,
219 workingDir: workingDir,
220 })
221 }
222 return mcpTools, nil
223}
224
225// SubscribeMCPEvents returns a channel for MCP events
226func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
227 return mcpBroker.Subscribe(ctx)
228}
229
230// GetMCPStates returns the current state of all MCP clients
231func GetMCPStates() map[string]MCPClientInfo {
232 return maps.Collect(mcpStates.Seq2())
233}
234
235// GetMCPState returns the state of a specific MCP client
236func GetMCPState(name string) (MCPClientInfo, bool) {
237 return mcpStates.Get(name)
238}
239
240// updateMCPState updates the state of an MCP client and publishes an event
241func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
242 info := MCPClientInfo{
243 Name: name,
244 State: state,
245 Error: err,
246 Client: client,
247 Counts: counts,
248 }
249 switch state {
250 case MCPStateConnected:
251 info.ConnectedAt = time.Now()
252 case MCPStateError:
253 updateMcpTools(name, nil)
254 updateMcpPrompts(name, nil)
255 mcpClients.Del(name)
256 }
257 mcpStates.Set(name, info)
258}
259
260// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
261func CloseMCPClients() error {
262 var errs []error
263 for name, c := range mcpClients.Seq2() {
264 if err := c.Close(); err != nil &&
265 !errors.Is(err, io.EOF) &&
266 !errors.Is(err, context.Canceled) &&
267 err.Error() != "signal: killed" {
268 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
269 }
270 }
271 mcpBroker.Shutdown()
272 return errors.Join(errs...)
273}
274
275func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
276 var wg sync.WaitGroup
277 // Initialize states for all configured MCPs
278 for name, m := range cfg.MCP {
279 if m.Disabled {
280 updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
281 slog.Debug("skipping disabled mcp", "name", name)
282 continue
283 }
284
285 // Set initial starting state
286 updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
287
288 wg.Add(1)
289 go func(name string, m config.MCPConfig) {
290 defer func() {
291 wg.Done()
292 if r := recover(); r != nil {
293 var err error
294 switch v := r.(type) {
295 case error:
296 err = v
297 case string:
298 err = fmt.Errorf("panic: %s", v)
299 default:
300 err = fmt.Errorf("panic: %v", v)
301 }
302 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
303 slog.Error("panic in mcp client initialization", "error", err, "name", name)
304 }
305 }()
306
307 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
308 defer cancel()
309
310 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
311 if err != nil {
312 return
313 }
314
315 mcpClients.Set(name, c)
316
317 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
318 if err != nil {
319 slog.Error("error listing tools", "error", err)
320 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
321 c.Close()
322 return
323 }
324
325 prompts, err := getPrompts(ctx, c)
326 if err != nil {
327 slog.Error("error listing prompts", "error", err)
328 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
329 c.Close()
330 return
331 }
332
333 resources, err := getResources(ctx, c)
334 if err != nil {
335 slog.Error("error listing resources", "error", err)
336 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
337 c.Close()
338 return
339 }
340
341 updateMcpTools(name, tools)
342 updateMcpPrompts(name, prompts)
343 updateMcpResources(name, resources)
344 mcpClients.Set(name, c)
345 counts := MCPCounts{
346 Tools: len(tools),
347 Prompts: len(prompts),
348 Resources: len(resources),
349 }
350 updateMCPState(name, MCPStateConnected, nil, c, counts)
351 }(name, m)
352 }
353 wg.Wait()
354}
355
356// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
357func updateMcpTools(mcpName string, tools []tools.BaseTool) {
358 if len(tools) == 0 {
359 mcpClient2Tools.Del(mcpName)
360 } else {
361 mcpClient2Tools.Set(mcpName, tools)
362 }
363 for _, tools := range mcpClient2Tools.Seq2() {
364 for _, t := range tools {
365 mcpTools.Set(t.Name(), t)
366 }
367 }
368}
369
370func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
371 timeout := mcpTimeout(m)
372 mcpCtx, cancel := context.WithCancel(ctx)
373 cancelTimer := time.AfterFunc(timeout, cancel)
374
375 transport, err := createMCPTransport(mcpCtx, m, resolver)
376 if err != nil {
377 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
378 slog.Error("error creating mcp client", "error", err, "name", name)
379 return nil, err
380 }
381
382 client := mcp.NewClient(
383 &mcp.Implementation{
384 Name: "crush",
385 Version: version.Version,
386 Title: "Crush",
387 },
388 &mcp.ClientOptions{
389 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
390 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
391 Type: MCPEventToolsListChanged,
392 Name: name,
393 })
394 },
395 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
396 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
397 Type: MCPEventPromptsListChanged,
398 Name: name,
399 })
400 },
401 KeepAlive: time.Minute * 10,
402 },
403 )
404
405 session, err := client.Connect(mcpCtx, transport, nil)
406 if err != nil {
407 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
408 slog.Error("error starting mcp client", "error", err, "name", name)
409 cancel()
410 return nil, err
411 }
412
413 cancelTimer.Stop()
414 slog.Info("Initialized mcp client", "name", name)
415 return session, nil
416}
417
// maybeTimeoutErr rewrites err as "timed out after <timeout>" when it comes
// from a context that ended. The connect timeout in createMCPSession fires by
// cancelling its context (context.Canceled). Contexts from
// context.WithTimeout, such as the ping context, fail with
// context.DeadlineExceeded. Any other error, including nil, is returned
// unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
424
425func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
426 switch m.Type {
427 case config.MCPStdio:
428 command, err := resolver.ResolveValue(m.Command)
429 if err != nil {
430 return nil, fmt.Errorf("invalid mcp command: %w", err)
431 }
432 if strings.TrimSpace(command) == "" {
433 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
434 }
435 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
436 cmd.Env = m.ResolvedEnv()
437 return &mcp.CommandTransport{
438 Command: cmd,
439 }, nil
440 case config.MCPHttp:
441 if strings.TrimSpace(m.URL) == "" {
442 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
443 }
444 client := &http.Client{
445 Transport: &headerRoundTripper{
446 headers: m.ResolvedHeaders(),
447 },
448 }
449 return &mcp.StreamableClientTransport{
450 Endpoint: m.URL,
451 HTTPClient: client,
452 }, nil
453 case config.MCPSSE:
454 if strings.TrimSpace(m.URL) == "" {
455 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
456 }
457 client := &http.Client{
458 Transport: &headerRoundTripper{
459 headers: m.ResolvedHeaders(),
460 },
461 }
462 return &mcp.SSEClientTransport{
463 Endpoint: m.URL,
464 HTTPClient: client,
465 }, nil
466 default:
467 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
468 }
469}
470
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip implements http.RoundTripper. The RoundTripper contract forbids
// modifying the caller's request, so the headers are set on a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
481
482func mcpTimeout(m config.MCPConfig) time.Duration {
483 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
484}
485
486func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
487 if c.InitializeResult().Capabilities.Prompts == nil {
488 return nil, nil
489 }
490 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
491 if err != nil {
492 return nil, err
493 }
494 return result.Prompts, nil
495}
496
497// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
498func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
499 if len(prompts) == 0 {
500 mcpClient2Prompts.Del(mcpName)
501 } else {
502 mcpClient2Prompts.Set(mcpName, prompts)
503 }
504 for clientName, prompts := range mcpClient2Prompts.Seq2() {
505 for _, p := range prompts {
506 key := clientName + ":" + p.Name
507 mcpPrompts.Set(key, p)
508 }
509 }
510}
511
512func getResources(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Resource, error) {
513 if c.InitializeResult().Capabilities.Resources == nil {
514 return nil, nil
515 }
516 result, err := c.ListResources(ctx, &mcp.ListResourcesParams{})
517 if err != nil {
518 return nil, err
519 }
520 return result.Resources, nil
521}
522
523// updateMcpResources updates the global mcpResources and mcpClient2Resources maps.
524func updateMcpResources(mcpName string, resources []*mcp.Resource) {
525 if len(resources) == 0 {
526 mcpClient2Resources.Del(mcpName)
527 } else {
528 mcpClient2Resources.Set(mcpName, resources)
529 }
530 for clientName, resources := range mcpClient2Resources.Seq2() {
531 for _, p := range resources {
532 key := clientName + ":" + p.Name
533 mcpResources.Set(key, p)
534 }
535 }
536}
537
538// GetMCPPrompts returns all available MCP prompts.
539func GetMCPPrompts() map[string]*mcp.Prompt {
540 return maps.Collect(mcpPrompts.Seq2())
541}
542
543// GetMCPPrompt returns a specific MCP prompt by name.
544func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
545 return mcpPrompts.Get(name)
546}
547
548// GetMCPPromptsByClient returns all prompts for a specific MCP client.
549func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
550 return mcpClient2Prompts.Get(clientName)
551}
552
553// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
554func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
555 c, err := getOrRenewClient(ctx, clientName)
556 if err != nil {
557 return nil, err
558 }
559
560 return c.GetPrompt(ctx, &mcp.GetPromptParams{
561 Name: promptName,
562 Arguments: args,
563 })
564}