1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "net/http"
12 "os/exec"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/pubsub"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/modelcontextprotocol/go-sdk/mcp"
25)
26
// MCPState is the lifecycle state of a configured MCP client.
type MCPState int

// MCPState values. The numeric values come from iota and must stay in this
// order.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lower-case label of s, or "unknown" for any value that is
// not one of the declared MCPState constants.
func (s MCPState) String() string {
	labels := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
51
// MCPEventType identifies the kind of MCPEvent published on the MCP event
// broker.
type MCPEventType string

const (
	// MCPEventStateChanged is published by updateMCPState every time a
	// client's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when a server notifies the client
	// that its tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
	// MCPEventPromptsListChanged is published when a server notifies the
	// client that its prompt list changed.
	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)

// MCPEvent is the payload delivered to SubscribeMCPEvents subscribers.
type MCPEvent struct {
	Type MCPEventType // what happened
	Name string       // MCP client name (its key in the config's MCP map)
	// State, Error and Counts are filled in for MCPEventStateChanged events;
	// the list-changed events published in this file only set Type and Name.
	State  MCPState
	Error  error
	Counts MCPCounts
}

// MCPCounts holds how many tools and prompts an MCP client exposes.
type MCPCounts struct {
	Tools   int
	Prompts int
}

// MCPClientInfo is the last recorded state of an MCP client, as stored by
// updateMCPState and returned by GetMCPState / GetMCPStates.
type MCPClientInfo struct {
	Name   string
	State  MCPState
	Error  error              // callers in this file only pass a non-nil error with MCPStateError
	Client *mcp.ClientSession // callers in this file only pass a session with MCPStateConnected
	Counts MCPCounts
	// ConnectedAt is set to time.Now() when the state is recorded as
	// MCPStateConnected and left zero for every other state.
	ConnectedAt time.Time
}
85
// Package-level registry of MCP clients, their tools, prompts and state.
// These maps are written from the concurrent initialization goroutines started
// in doGetMCPTools.
var (
	// mcpToolsOnce is not referenced in the code visible here; presumably it
	// guards one-time tool initialization elsewhere — TODO confirm.
	mcpToolsOnce sync.Once
	// mcpTools maps a tool's prefixed name (McpTool.Name, "mcp_<client>_<tool>")
	// to the tool.
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP client name to the tools it exposes.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP client name to its active session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP client name to its last recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker delivers MCPEvents to SubscribeMCPEvents subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts maps "<client>:<prompt name>" to the prompt.
	mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts maps an MCP client name to the prompts it exposes.
	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)
96
// McpTool adapts a single tool exposed by an MCP server to the tools.BaseTool
// interface used by the agent.
type McpTool struct {
	mcpName     string             // MCP client name; its session is looked up by this name when running
	tool        *mcp.Tool          // tool definition as listed by the server
	permissions permission.Service // asked for approval before every execution
	workingDir  string             // reported as the Path of the permission request
}
103
104func (b *McpTool) Name() string {
105 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
106}
107
108func (b *McpTool) Info() tools.ToolInfo {
109 input := b.tool.InputSchema.(map[string]any)
110 required, _ := input["required"].([]string)
111 parameters, _ := input["properties"].(map[string]any)
112 return tools.ToolInfo{
113 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
114 Description: b.tool.Description,
115 Parameters: parameters,
116 Required: required,
117 }
118}
119
120func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
121 var args map[string]any
122 if err := json.Unmarshal([]byte(input), &args); err != nil {
123 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
124 }
125
126 c, err := getOrRenewClient(ctx, name)
127 if err != nil {
128 return tools.NewTextErrorResponse(err.Error()), nil
129 }
130 result, err := c.CallTool(ctx, &mcp.CallToolParams{
131 Name: toolName,
132 Arguments: args,
133 })
134 if err != nil {
135 return tools.NewTextErrorResponse(err.Error()), nil
136 }
137
138 output := make([]string, 0, len(result.Content))
139 for _, v := range result.Content {
140 if vv, ok := v.(*mcp.TextContent); ok {
141 output = append(output, vv.Text)
142 } else {
143 output = append(output, fmt.Sprintf("%v", v))
144 }
145 }
146 return tools.NewTextResponse(strings.Join(output, "\n")), nil
147}
148
149func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
150 sess, ok := mcpClients.Get(name)
151 if !ok {
152 return nil, fmt.Errorf("mcp '%s' not available", name)
153 }
154
155 cfg := config.Get()
156 m := cfg.MCP[name]
157 state, _ := mcpStates.Get(name)
158
159 timeout := mcpTimeout(m)
160 pingCtx, cancel := context.WithTimeout(ctx, timeout)
161 defer cancel()
162 err := sess.Ping(pingCtx, nil)
163 if err == nil {
164 return sess, nil
165 }
166 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
167
168 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
169 if err != nil {
170 return nil, err
171 }
172
173 updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
174 mcpClients.Set(name, sess)
175 return sess, nil
176}
177
178func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
179 sessionID, messageID := tools.GetContextValues(ctx)
180 if sessionID == "" || messageID == "" {
181 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
182 }
183 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
184 p := b.permissions.Request(
185 permission.CreatePermissionRequest{
186 SessionID: sessionID,
187 ToolCallID: params.ID,
188 Path: b.workingDir,
189 ToolName: b.Info().Name,
190 Action: "execute",
191 Description: permissionDescription,
192 Params: params.Input,
193 },
194 )
195 if !p {
196 return tools.ToolResponse{}, permission.ErrorPermissionDenied
197 }
198
199 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
200}
201
202func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
203 if c.InitializeResult().Capabilities.Tools == nil {
204 return nil, nil
205 }
206 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
207 if err != nil {
208 return nil, err
209 }
210 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
211 for _, tool := range result.Tools {
212 mcpTools = append(mcpTools, &McpTool{
213 mcpName: name,
214 tool: tool,
215 permissions: permissions,
216 workingDir: workingDir,
217 })
218 }
219 return mcpTools, nil
220}
221
222// SubscribeMCPEvents returns a channel for MCP events
223func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
224 return mcpBroker.Subscribe(ctx)
225}
226
227// GetMCPStates returns the current state of all MCP clients
228func GetMCPStates() map[string]MCPClientInfo {
229 return maps.Collect(mcpStates.Seq2())
230}
231
232// GetMCPState returns the state of a specific MCP client
233func GetMCPState(name string) (MCPClientInfo, bool) {
234 return mcpStates.Get(name)
235}
236
237// updateMCPState updates the state of an MCP client and publishes an event
238func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
239 info := MCPClientInfo{
240 Name: name,
241 State: state,
242 Error: err,
243 Client: client,
244 Counts: counts,
245 }
246 switch state {
247 case MCPStateConnected:
248 info.ConnectedAt = time.Now()
249 case MCPStateError:
250 updateMcpTools(name, nil)
251 updateMcpPrompts(name, nil)
252 mcpClients.Del(name)
253 }
254 mcpStates.Set(name, info)
255
256 // Publish state change event
257 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
258 Type: MCPEventStateChanged,
259 Name: name,
260 State: state,
261 Error: err,
262 Counts: counts,
263 })
264}
265
266// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
267func CloseMCPClients() error {
268 var errs []error
269 for name, c := range mcpClients.Seq2() {
270 if err := c.Close(); err != nil {
271 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
272 }
273 }
274 mcpBroker.Shutdown()
275 return errors.Join(errs...)
276}
277
278func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
279 var wg sync.WaitGroup
280 // Initialize states for all configured MCPs
281 for name, m := range cfg.MCP {
282 if m.Disabled {
283 updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
284 slog.Debug("skipping disabled mcp", "name", name)
285 continue
286 }
287
288 // Set initial starting state
289 updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
290
291 wg.Add(1)
292 go func(name string, m config.MCPConfig) {
293 defer func() {
294 wg.Done()
295 if r := recover(); r != nil {
296 var err error
297 switch v := r.(type) {
298 case error:
299 err = v
300 case string:
301 err = fmt.Errorf("panic: %s", v)
302 default:
303 err = fmt.Errorf("panic: %v", v)
304 }
305 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
306 slog.Error("panic in mcp client initialization", "error", err, "name", name)
307 }
308 }()
309
310 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
311 defer cancel()
312
313 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
314 if err != nil {
315 return
316 }
317
318 mcpClients.Set(name, c)
319
320 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
321 if err != nil {
322 slog.Error("error listing tools", "error", err)
323 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
324 c.Close()
325 return
326 }
327
328 prompts, err := getPrompts(ctx, c)
329 if err != nil {
330 slog.Error("error listing prompts", "error", err)
331 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
332 c.Close()
333 return
334 }
335
336 updateMcpTools(name, tools)
337 updateMcpPrompts(name, prompts)
338 mcpClients.Set(name, c)
339 updateMCPState(
340 name, MCPStateConnected, nil, c,
341 MCPCounts{
342 Tools: len(tools),
343 Prompts: len(prompts),
344 },
345 )
346 }(name, m)
347 }
348 wg.Wait()
349}
350
351// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
352func updateMcpTools(mcpName string, tools []tools.BaseTool) {
353 if len(tools) == 0 {
354 mcpClient2Tools.Del(mcpName)
355 } else {
356 mcpClient2Tools.Set(mcpName, tools)
357 }
358 for _, tools := range mcpClient2Tools.Seq2() {
359 for _, t := range tools {
360 mcpTools.Set(t.Name(), t)
361 }
362 }
363}
364
365func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
366 timeout := mcpTimeout(m)
367 mcpCtx, cancel := context.WithCancel(ctx)
368 cancelTimer := time.AfterFunc(timeout, cancel)
369
370 transport, err := createMCPTransport(mcpCtx, m, resolver)
371 if err != nil {
372 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
373 slog.Error("error creating mcp client", "error", err, "name", name)
374 return nil, err
375 }
376
377 client := mcp.NewClient(
378 &mcp.Implementation{
379 Name: "crush",
380 Version: version.Version,
381 Title: "Crush",
382 },
383 &mcp.ClientOptions{
384 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
385 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
386 Type: MCPEventToolsListChanged,
387 Name: name,
388 })
389 },
390 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
391 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
392 Type: MCPEventPromptsListChanged,
393 Name: name,
394 })
395 },
396 KeepAlive: time.Minute * 10,
397 },
398 )
399
400 session, err := client.Connect(mcpCtx, transport, nil)
401 if err != nil {
402 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
403 slog.Error("error starting mcp client", "error", err, "name", name)
404 _ = session.Close()
405 cancel()
406 return nil, err
407 }
408
409 cancelTimer.Stop()
410 slog.Info("Initialized mcp client", "name", name)
411 return session, nil
412}
413
// maybeTimeoutErr replaces context cancellation or deadline errors with a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// Both are handled because MCP connects are aborted via a cancel timer
// (context.Canceled), while pings and initialization use context.WithTimeout
// (context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
420
421func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
422 switch m.Type {
423 case config.MCPStdio:
424 command, err := resolver.ResolveValue(m.Command)
425 if err != nil {
426 return nil, fmt.Errorf("invalid mcp command: %w", err)
427 }
428 if strings.TrimSpace(command) == "" {
429 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
430 }
431 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
432 cmd.Env = m.ResolvedEnv()
433 return &mcp.CommandTransport{
434 Command: cmd,
435 }, nil
436 case config.MCPHttp:
437 if strings.TrimSpace(m.URL) == "" {
438 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
439 }
440 client := &http.Client{
441 Transport: &headerRoundTripper{
442 headers: m.ResolvedHeaders(),
443 },
444 }
445 return &mcp.StreamableClientTransport{
446 Endpoint: m.URL,
447 HTTPClient: client,
448 }, nil
449 case config.MCPSSE:
450 if strings.TrimSpace(m.URL) == "" {
451 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
452 }
453 client := &http.Client{
454 Transport: &headerRoundTripper{
455 headers: m.ResolvedHeaders(),
456 },
457 }
458 return &mcp.SSEClientTransport{
459 Endpoint: m.URL,
460 HTTPClient: client,
461 }, nil
462 default:
463 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
464 }
465}
466
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with the configured headers set. They override any
// values already present for the same keys.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
477
478func mcpTimeout(m config.MCPConfig) time.Duration {
479 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
480}
481
482func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
483 if c.InitializeResult().Capabilities.Prompts == nil {
484 return nil, nil
485 }
486 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
487 if err != nil {
488 return nil, err
489 }
490 return result.Prompts, nil
491}
492
493// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
494func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
495 if len(prompts) == 0 {
496 mcpClient2Prompts.Del(mcpName)
497 } else {
498 mcpClient2Prompts.Set(mcpName, prompts)
499 }
500 for clientName, prompts := range mcpClient2Prompts.Seq2() {
501 for _, p := range prompts {
502 key := clientName + ":" + p.Name
503 mcpPrompts.Set(key, p)
504 }
505 }
506}
507
508// GetMCPPrompts returns all available MCP prompts.
509func GetMCPPrompts() map[string]*mcp.Prompt {
510 return maps.Collect(mcpPrompts.Seq2())
511}
512
513// GetMCPPrompt returns a specific MCP prompt by name.
514func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
515 return mcpPrompts.Get(name)
516}
517
518// GetMCPPromptsByClient returns all prompts for a specific MCP client.
519func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
520 return mcpClient2Prompts.Get(clientName)
521}
522
523// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
524func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
525 c, err := getOrRenewClient(ctx, clientName)
526 if err != nil {
527 return nil, err
528 }
529
530 return c.GetPrompt(ctx, &mcp.GetPromptParams{
531 Name: promptName,
532 Arguments: args,
533 })
534}