1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "net/http"
12 "os/exec"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/pubsub"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/modelcontextprotocol/go-sdk/mcp"
25)
26
// MCPState enumerates the lifecycle states tracked for an MCP client.
type MCPState int

// Lifecycle states, numbered from zero in declaration order.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lowercase label of s, or "unknown" for any value that is
// not one of the declared states.
func (s MCPState) String() string {
	labels := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
51
// MCPEventType represents the type of MCP event.
type MCPEventType string

// Event kinds carried in MCPEvent.Type.
const (
	// MCPEventStateChanged is published by updateMCPState whenever a client's
	// state record is replaced.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published from the client's
	// ToolListChangedHandler (see createMCPSession).
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
	// MCPEventPromptsListChanged is published from the client's
	// PromptListChangedHandler (see createMCPSession).
	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)
60
// MCPEvent represents an event in the MCP system.
//
// Type and Name are set on every event published in this file. State, Error,
// ToolCount and PromptCount are only filled in for MCPEventStateChanged
// events; the list-changed events leave them at their zero values.
type MCPEvent struct {
	Type        MCPEventType // kind of event
	Name        string       // name of the MCP client the event refers to
	State       MCPState     // state passed to updateMCPState
	Error       error        // error passed to updateMCPState, if any
	ToolCount   int          // tool count passed to updateMCPState
	PromptCount int          // prompt count passed to updateMCPState
}
70
// MCPClientInfo holds information about an MCP client's state.
//
// Records are written by updateMCPState and read through GetMCPStates and
// GetMCPState.
type MCPClientInfo struct {
	Name        string             // name of the MCP client
	State       MCPState           // most recently recorded state
	Error       error              // error recorded with the state, if any
	Client      *mcp.ClientSession // session passed to updateMCPState; nil at every non-connected call site in this file
	ToolCount   int                // tool count recorded with the state
	PromptCount int                // prompt count recorded with the state
	ConnectedAt time.Time          // set to time.Now() when State is MCPStateConnected; zero otherwise
}
81
// Package-level MCP registries shared by the functions in this file.
var (
	// mcpToolsOnce is not referenced in the visible part of this file;
	// presumably it guards one-time tool discovery — confirm at its use site.
	mcpToolsOnce sync.Once
	// mcpTools is a flat index of tools keyed by BaseTool.Name(); it is
	// filled by updateMcpTools.
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP client name to the tools it provides.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP client name to its session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP client name to its latest MCPClientInfo.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent values to SubscribeMCPEvents subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts is a flat index of prompts keyed by "<client>:<prompt>"; it
	// is filled by updateMcpPrompts.
	mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts maps an MCP client name to the prompts it provides.
	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)
92
// McpTool wraps a single tool advertised by an MCP server. Its Name, Info and
// Run methods are defined in this file.
type McpTool struct {
	mcpName     string             // name of the MCP client that provides the tool
	tool        *mcp.Tool          // tool definition as listed by the server
	permissions permission.Service // asked for approval before each Run
	workingDir  string             // sent as Path in the permission request
}
99
100func (b *McpTool) Name() string {
101 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
102}
103
104func (b *McpTool) Info() tools.ToolInfo {
105 input := b.tool.InputSchema.(map[string]any)
106 required, _ := input["required"].([]string)
107 parameters, _ := input["properties"].(map[string]any)
108 return tools.ToolInfo{
109 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
110 Description: b.tool.Description,
111 Parameters: parameters,
112 Required: required,
113 }
114}
115
116func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
117 var args map[string]any
118 if err := json.Unmarshal([]byte(input), &args); err != nil {
119 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
120 }
121
122 c, err := getOrRenewClient(ctx, name)
123 if err != nil {
124 return tools.NewTextErrorResponse(err.Error()), nil
125 }
126 result, err := c.CallTool(ctx, &mcp.CallToolParams{
127 Name: toolName,
128 Arguments: args,
129 })
130 if err != nil {
131 return tools.NewTextErrorResponse(err.Error()), nil
132 }
133
134 output := make([]string, 0, len(result.Content))
135 for _, v := range result.Content {
136 if vv, ok := v.(*mcp.TextContent); ok {
137 output = append(output, vv.Text)
138 } else {
139 output = append(output, fmt.Sprintf("%v", v))
140 }
141 }
142 return tools.NewTextResponse(strings.Join(output, "\n")), nil
143}
144
145func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
146 sess, ok := mcpClients.Get(name)
147 if !ok {
148 return nil, fmt.Errorf("mcp '%s' not available", name)
149 }
150
151 cfg := config.Get()
152 m := cfg.MCP[name]
153 state, _ := mcpStates.Get(name)
154
155 timeout := mcpTimeout(m)
156 pingCtx, cancel := context.WithTimeout(ctx, timeout)
157 defer cancel()
158 err := sess.Ping(pingCtx, nil)
159 if err == nil {
160 return sess, nil
161 }
162 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount, state.PromptCount)
163
164 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
165 if err != nil {
166 return nil, err
167 }
168
169 updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount, state.PromptCount)
170 mcpClients.Set(name, sess)
171 return sess, nil
172}
173
174func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
175 sessionID, messageID := tools.GetContextValues(ctx)
176 if sessionID == "" || messageID == "" {
177 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
178 }
179 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
180 p := b.permissions.Request(
181 permission.CreatePermissionRequest{
182 SessionID: sessionID,
183 ToolCallID: params.ID,
184 Path: b.workingDir,
185 ToolName: b.Info().Name,
186 Action: "execute",
187 Description: permissionDescription,
188 Params: params.Input,
189 },
190 )
191 if !p {
192 return tools.ToolResponse{}, permission.ErrorPermissionDenied
193 }
194
195 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
196}
197
198func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
199 if c.InitializeResult().Capabilities.Tools == nil {
200 return nil, nil
201 }
202 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
203 if err != nil {
204 return nil, err
205 }
206 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
207 for _, tool := range result.Tools {
208 mcpTools = append(mcpTools, &McpTool{
209 mcpName: name,
210 tool: tool,
211 permissions: permissions,
212 workingDir: workingDir,
213 })
214 }
215 return mcpTools, nil
216}
217
218// SubscribeMCPEvents returns a channel for MCP events
219func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
220 return mcpBroker.Subscribe(ctx)
221}
222
223// GetMCPStates returns the current state of all MCP clients
224func GetMCPStates() map[string]MCPClientInfo {
225 return maps.Collect(mcpStates.Seq2())
226}
227
228// GetMCPState returns the state of a specific MCP client
229func GetMCPState(name string) (MCPClientInfo, bool) {
230 return mcpStates.Get(name)
231}
232
233// updateMCPState updates the state of an MCP client and publishes an event
234func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount, promptCount int) {
235 info := MCPClientInfo{
236 Name: name,
237 State: state,
238 Error: err,
239 Client: client,
240 ToolCount: toolCount,
241 PromptCount: promptCount,
242 }
243 switch state {
244 case MCPStateConnected:
245 info.ConnectedAt = time.Now()
246 case MCPStateError:
247 updateMcpTools(name, nil)
248 updateMcpPrompts(name, nil)
249 mcpClients.Del(name)
250 }
251 mcpStates.Set(name, info)
252
253 // Publish state change event
254 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
255 Type: MCPEventStateChanged,
256 Name: name,
257 State: state,
258 Error: err,
259 ToolCount: toolCount,
260 PromptCount: promptCount,
261 })
262}
263
264// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
265func CloseMCPClients() error {
266 var errs []error
267 for name, c := range mcpClients.Seq2() {
268 if err := c.Close(); err != nil {
269 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
270 }
271 }
272 mcpBroker.Shutdown()
273 return errors.Join(errs...)
274}
275
276func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
277 var wg sync.WaitGroup
278 // Initialize states for all configured MCPs
279 for name, m := range cfg.MCP {
280 if m.Disabled {
281 updateMCPState(name, MCPStateDisabled, nil, nil, 0, 0)
282 slog.Debug("skipping disabled mcp", "name", name)
283 continue
284 }
285
286 // Set initial starting state
287 updateMCPState(name, MCPStateStarting, nil, nil, 0, 0)
288
289 wg.Add(1)
290 go func(name string, m config.MCPConfig) {
291 defer func() {
292 wg.Done()
293 if r := recover(); r != nil {
294 var err error
295 switch v := r.(type) {
296 case error:
297 err = v
298 case string:
299 err = fmt.Errorf("panic: %s", v)
300 default:
301 err = fmt.Errorf("panic: %v", v)
302 }
303 updateMCPState(name, MCPStateError, err, nil, 0, 0)
304 slog.Error("panic in mcp client initialization", "error", err, "name", name)
305 }
306 }()
307
308 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
309 defer cancel()
310
311 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
312 if err != nil {
313 return
314 }
315
316 mcpClients.Set(name, c)
317
318 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
319 if err != nil {
320 slog.Error("error listing tools", "error", err)
321 updateMCPState(name, MCPStateError, err, nil, 0, 0)
322 c.Close()
323 return
324 }
325
326 prompts, err := getPrompts(ctx, c)
327 if err != nil {
328 slog.Error("error listing prompts", "error", err)
329 updateMCPState(name, MCPStateError, err, nil, 0, 0)
330 c.Close()
331 return
332 }
333
334 updateMcpTools(name, tools)
335 updateMcpPrompts(name, prompts)
336 mcpClients.Set(name, c)
337 updateMCPState(name, MCPStateConnected, nil, c, len(tools), len(prompts))
338 }(name, m)
339 }
340 wg.Wait()
341}
342
343// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
344func updateMcpTools(mcpName string, tools []tools.BaseTool) {
345 if len(tools) == 0 {
346 mcpClient2Tools.Del(mcpName)
347 } else {
348 mcpClient2Tools.Set(mcpName, tools)
349 }
350 for _, tools := range mcpClient2Tools.Seq2() {
351 for _, t := range tools {
352 mcpTools.Set(t.Name(), t)
353 }
354 }
355}
356
357func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
358 transport, err := createMCPTransport(m, resolver)
359 if err != nil {
360 updateMCPState(name, MCPStateError, err, nil, 0, 0)
361 slog.Error("error creating mcp client", "error", err, "name", name)
362 return nil, err
363 }
364
365 client := mcp.NewClient(
366 &mcp.Implementation{
367 Name: "crush",
368 Version: version.Version,
369 Title: "Crush",
370 },
371 &mcp.ClientOptions{
372 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
373 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
374 Type: MCPEventToolsListChanged,
375 Name: name,
376 })
377 },
378 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
379 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
380 Type: MCPEventPromptsListChanged,
381 Name: name,
382 })
383 },
384 KeepAlive: time.Minute * 10,
385 },
386 )
387
388 timeout := mcpTimeout(m)
389 mcpCtx, cancel := context.WithCancel(ctx)
390 cancelTimer := time.AfterFunc(timeout, cancel)
391
392 session, err := client.Connect(mcpCtx, transport, nil)
393 if err != nil {
394 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0, 0)
395 slog.Error("error starting mcp client", "error", err, "name", name)
396 _ = session.Close()
397 cancel()
398 return nil, err
399 }
400
401 cancelTimer.Stop()
402 slog.Info("Initialized mcp client", "name", name)
403 return session, nil
404}
405
// maybeTimeoutErr rewrites context-expiry errors into a "timed out after
// <timeout>" error and returns every other error (including nil) unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are recognised, wrapped
// or not: callers in this file bound operations with a timer-driven cancel as
// well as with context.WithTimeout, and those produce different sentinel
// errors.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
412
413func createMCPTransport(m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
414 switch m.Type {
415 case config.MCPStdio:
416 command, err := resolver.ResolveValue(m.Command)
417 if err != nil {
418 return nil, fmt.Errorf("invalid mcp command: %w", err)
419 }
420 if strings.TrimSpace(command) == "" {
421 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
422 }
423 cmd := exec.Command(home.Long(command), m.Args...)
424 cmd.Env = m.ResolvedEnv()
425 return &mcp.CommandTransport{
426 Command: cmd,
427 }, nil
428 case config.MCPHttp:
429 if strings.TrimSpace(m.URL) == "" {
430 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
431 }
432 client := &http.Client{
433 Transport: &headerRoundTripper{
434 headers: m.ResolvedHeaders(),
435 },
436 }
437 return &mcp.StreamableClientTransport{
438 Endpoint: m.URL,
439 HTTPClient: client,
440 }, nil
441 case config.MCPSSE:
442 if strings.TrimSpace(m.URL) == "" {
443 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
444 }
445 client := &http.Client{
446 Transport: &headerRoundTripper{
447 headers: m.ResolvedHeaders(),
448 },
449 }
450 return &mcp.SSEClientTransport{
451 Endpoint: m.URL,
452 HTTPClient: client,
453 }, nil
454 default:
455 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
456 }
457}
458
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request before handing it to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string // header name -> value, applied with Header.Set
}

// RoundTrip sends a copy of req, with the configured headers set on it,
// through http.DefaultTransport.
//
// The headers are applied to a clone because the http.RoundTripper contract
// forbids modifying the caller's request. A configured header replaces any
// value the request already carries under the same name.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	outgoing := req.Clone(req.Context())
	if outgoing.Header == nil {
		outgoing.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		outgoing.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(outgoing)
}
469
470func mcpTimeout(m config.MCPConfig) time.Duration {
471 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
472}
473
474func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
475 if c.InitializeResult().Capabilities.Prompts == nil {
476 return nil, nil
477 }
478 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
479 if err != nil {
480 return nil, err
481 }
482 return result.Prompts, nil
483}
484
485// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
486func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
487 if len(prompts) == 0 {
488 mcpClient2Prompts.Del(mcpName)
489 } else {
490 mcpClient2Prompts.Set(mcpName, prompts)
491 }
492 for clientName, prompts := range mcpClient2Prompts.Seq2() {
493 for _, p := range prompts {
494 key := clientName + ":" + p.Name
495 mcpPrompts.Set(key, p)
496 }
497 }
498}
499
500// GetMCPPrompts returns all available MCP prompts.
501func GetMCPPrompts() map[string]*mcp.Prompt {
502 return maps.Collect(mcpPrompts.Seq2())
503}
504
505// GetMCPPrompt returns a specific MCP prompt by name.
506func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
507 return mcpPrompts.Get(name)
508}
509
510// GetMCPPromptsByClient returns all prompts for a specific MCP client.
511func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
512 return mcpClient2Prompts.Get(clientName)
513}
514
515// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
516func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
517 c, err := getOrRenewClient(ctx, clientName)
518 if err != nil {
519 return nil, err
520 }
521
522 return c.GetPrompt(ctx, &mcp.GetPromptParams{
523 Name: promptName,
524 Arguments: args,
525 })
526}