1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "net/http"
12 "os/exec"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/pubsub"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/modelcontextprotocol/go-sdk/mcp"
25)
26
// MCPState is the lifecycle state of a single MCP client.
type MCPState int

const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns a lowercase label for s, or "unknown" for values outside the
// declared MCPState constants.
func (s MCPState) String() string {
	labels := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
51
// MCPEventType identifies what kind of change an MCPEvent reports.
type MCPEventType string

// MCPEventToolsListChanged is published when an MCP server notifies the
// client that its tool list changed.
const MCPEventToolsListChanged MCPEventType = "tools_list_changed"

// MCPEventPromptsListChanged is published when an MCP server notifies the
// client that its prompt list changed.
const MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
59
// MCPEvent is the payload published on the package's MCP event broker.
//
// NOTE(review): the publishers in this file (the list-changed handlers set up
// in createMCPSession) fill in only Type and Name; State, Error and Counts are
// left at their zero values there. Confirm whether other code sets them.
type MCPEvent struct {
	Type   MCPEventType // kind of change being reported
	Name   string       // name of the MCP client the event concerns
	State  MCPState
	Error  error
	Counts MCPCounts
}
68
// MCPCounts holds how many tools and prompts an MCP client exposes.
type MCPCounts struct {
	Tools   int // number of tools registered for the client
	Prompts int // number of prompts registered for the client
}
74
// MCPClientInfo is a snapshot of one MCP client's state, as stored by
// updateMCPState and returned by GetMCPState and GetMCPStates.
type MCPClientInfo struct {
	Name        string             // configured MCP client name
	State       MCPState           // lifecycle state at the time of the update
	Error       error              // error recorded with the state; callers in this file pass one only for MCPStateError
	Client      *mcp.ClientSession // session; callers in this file pass a non-nil one only with MCPStateConnected
	Counts      MCPCounts          // tool/prompt counts recorded with the state
	ConnectedAt time.Time          // set to the update time when State is MCPStateConnected; zero otherwise
}
84
// Package-level registry of MCP clients and what they expose. The csync maps
// are written from the concurrent per-client goroutines in doGetMCPTools,
// so they are presumably concurrency-safe (TODO confirm in package csync).
var (
	// mcpToolsOnce is not used in this part of the file; presumably it guards
	// a one-time call to doGetMCPTools elsewhere in the package (TODO confirm).
	mcpToolsOnce sync.Once
	// mcpTools indexes every registered tool by its prefixed name (t.Name()).
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP client name to the tools it provides.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP client name to its current session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP client name to its most recently recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker delivers MCPEvents to SubscribeMCPEvents subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
	// mcpPrompts indexes every registered prompt by "<client>:<prompt name>".
	mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
	// mcpClient2Prompts maps an MCP client name to the prompts it provides.
	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)
95
// McpTool adapts one tool advertised by an MCP server to the tools.BaseTool
// interface used by the agent.
type McpTool struct {
	mcpName     string             // configured MCP client name; used in Name() and to route calls
	tool        *mcp.Tool          // tool definition as returned by the server's ListTools
	permissions permission.Service // asked for approval before every Run
	workingDir  string             // reported as the Path of permission requests
}
102
103func (b *McpTool) Name() string {
104 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
105}
106
107func (b *McpTool) Info() tools.ToolInfo {
108 input := b.tool.InputSchema.(map[string]any)
109 required, _ := input["required"].([]string)
110 parameters, _ := input["properties"].(map[string]any)
111 return tools.ToolInfo{
112 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
113 Description: b.tool.Description,
114 Parameters: parameters,
115 Required: required,
116 }
117}
118
119func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
120 var args map[string]any
121 if err := json.Unmarshal([]byte(input), &args); err != nil {
122 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
123 }
124
125 c, err := getOrRenewClient(ctx, name)
126 if err != nil {
127 return tools.NewTextErrorResponse(err.Error()), nil
128 }
129 result, err := c.CallTool(ctx, &mcp.CallToolParams{
130 Name: toolName,
131 Arguments: args,
132 })
133 if err != nil {
134 return tools.NewTextErrorResponse(err.Error()), nil
135 }
136
137 output := make([]string, 0, len(result.Content))
138 for _, v := range result.Content {
139 if vv, ok := v.(*mcp.TextContent); ok {
140 output = append(output, vv.Text)
141 } else {
142 output = append(output, fmt.Sprintf("%v", v))
143 }
144 }
145 return tools.NewTextResponse(strings.Join(output, "\n")), nil
146}
147
148func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
149 sess, ok := mcpClients.Get(name)
150 if !ok {
151 return nil, fmt.Errorf("mcp '%s' not available", name)
152 }
153
154 cfg := config.Get()
155 m := cfg.MCP[name]
156 state, _ := mcpStates.Get(name)
157
158 timeout := mcpTimeout(m)
159 pingCtx, cancel := context.WithTimeout(ctx, timeout)
160 defer cancel()
161 err := sess.Ping(pingCtx, nil)
162 if err == nil {
163 return sess, nil
164 }
165 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
166
167 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
168 if err != nil {
169 return nil, err
170 }
171
172 updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
173 mcpClients.Set(name, sess)
174 return sess, nil
175}
176
177func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
178 sessionID, messageID := tools.GetContextValues(ctx)
179 if sessionID == "" || messageID == "" {
180 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
181 }
182 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
183 p := b.permissions.Request(
184 permission.CreatePermissionRequest{
185 SessionID: sessionID,
186 ToolCallID: params.ID,
187 Path: b.workingDir,
188 ToolName: b.Info().Name,
189 Action: "execute",
190 Description: permissionDescription,
191 Params: params.Input,
192 },
193 )
194 if !p {
195 return tools.ToolResponse{}, permission.ErrorPermissionDenied
196 }
197
198 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
199}
200
201func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
202 if c.InitializeResult().Capabilities.Tools == nil {
203 return nil, nil
204 }
205 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
206 if err != nil {
207 return nil, err
208 }
209 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
210 for _, tool := range result.Tools {
211 mcpTools = append(mcpTools, &McpTool{
212 mcpName: name,
213 tool: tool,
214 permissions: permissions,
215 workingDir: workingDir,
216 })
217 }
218 return mcpTools, nil
219}
220
221// SubscribeMCPEvents returns a channel for MCP events
222func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
223 return mcpBroker.Subscribe(ctx)
224}
225
226// GetMCPStates returns the current state of all MCP clients
227func GetMCPStates() map[string]MCPClientInfo {
228 return maps.Collect(mcpStates.Seq2())
229}
230
231// GetMCPState returns the state of a specific MCP client
232func GetMCPState(name string) (MCPClientInfo, bool) {
233 return mcpStates.Get(name)
234}
235
236// updateMCPState updates the state of an MCP client and publishes an event
237func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
238 info := MCPClientInfo{
239 Name: name,
240 State: state,
241 Error: err,
242 Client: client,
243 Counts: counts,
244 }
245 switch state {
246 case MCPStateConnected:
247 info.ConnectedAt = time.Now()
248 case MCPStateError:
249 updateMcpTools(name, nil)
250 updateMcpPrompts(name, nil)
251 mcpClients.Del(name)
252 }
253 mcpStates.Set(name, info)
254}
255
256// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
257func CloseMCPClients() error {
258 var errs []error
259 for name, c := range mcpClients.Seq2() {
260 if err := c.Close(); err != nil {
261 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
262 }
263 }
264 mcpBroker.Shutdown()
265 return errors.Join(errs...)
266}
267
268func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
269 var wg sync.WaitGroup
270 // Initialize states for all configured MCPs
271 for name, m := range cfg.MCP {
272 if m.Disabled {
273 updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
274 slog.Debug("skipping disabled mcp", "name", name)
275 continue
276 }
277
278 // Set initial starting state
279 updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
280
281 wg.Add(1)
282 go func(name string, m config.MCPConfig) {
283 defer func() {
284 wg.Done()
285 if r := recover(); r != nil {
286 var err error
287 switch v := r.(type) {
288 case error:
289 err = v
290 case string:
291 err = fmt.Errorf("panic: %s", v)
292 default:
293 err = fmt.Errorf("panic: %v", v)
294 }
295 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
296 slog.Error("panic in mcp client initialization", "error", err, "name", name)
297 }
298 }()
299
300 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
301 defer cancel()
302
303 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
304 if err != nil {
305 return
306 }
307
308 mcpClients.Set(name, c)
309
310 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
311 if err != nil {
312 slog.Error("error listing tools", "error", err)
313 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
314 c.Close()
315 return
316 }
317
318 prompts, err := getPrompts(ctx, c)
319 if err != nil {
320 slog.Error("error listing prompts", "error", err)
321 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
322 c.Close()
323 return
324 }
325
326 updateMcpTools(name, tools)
327 updateMcpPrompts(name, prompts)
328 mcpClients.Set(name, c)
329 counts := MCPCounts{
330 Tools: len(tools),
331 Prompts: len(prompts),
332 }
333 updateMCPState(name, MCPStateConnected, nil, c, counts)
334 }(name, m)
335 }
336 wg.Wait()
337}
338
339// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
340func updateMcpTools(mcpName string, tools []tools.BaseTool) {
341 if len(tools) == 0 {
342 mcpClient2Tools.Del(mcpName)
343 } else {
344 mcpClient2Tools.Set(mcpName, tools)
345 }
346 for _, tools := range mcpClient2Tools.Seq2() {
347 for _, t := range tools {
348 mcpTools.Set(t.Name(), t)
349 }
350 }
351}
352
353func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
354 timeout := mcpTimeout(m)
355 mcpCtx, cancel := context.WithCancel(ctx)
356 cancelTimer := time.AfterFunc(timeout, cancel)
357
358 transport, err := createMCPTransport(mcpCtx, m, resolver)
359 if err != nil {
360 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
361 slog.Error("error creating mcp client", "error", err, "name", name)
362 return nil, err
363 }
364
365 client := mcp.NewClient(
366 &mcp.Implementation{
367 Name: "crush",
368 Version: version.Version,
369 Title: "Crush",
370 },
371 &mcp.ClientOptions{
372 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
373 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
374 Type: MCPEventToolsListChanged,
375 Name: name,
376 })
377 },
378 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
379 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
380 Type: MCPEventPromptsListChanged,
381 Name: name,
382 })
383 },
384 KeepAlive: time.Minute * 10,
385 },
386 )
387
388 session, err := client.Connect(mcpCtx, transport, nil)
389 if err != nil {
390 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
391 slog.Error("error starting mcp client", "error", err, "name", name)
392 _ = session.Close()
393 cancel()
394 return nil, err
395 }
396
397 cancelTimer.Stop()
398 slog.Info("Initialized mcp client", "name", name)
399 return session, nil
400}
401
// maybeTimeoutErr rewrites context-expiry errors as a readable
// "timed out after <timeout>" error and returns any other error unchanged.
//
// Both expiry forms are recognized:
//   - context.Canceled, produced when createMCPSession's timer cancels its
//     context;
//   - context.DeadlineExceeded, produced by context.WithTimeout (for example,
//     the ping in getOrRenewClient).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
408
409func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
410 switch m.Type {
411 case config.MCPStdio:
412 command, err := resolver.ResolveValue(m.Command)
413 if err != nil {
414 return nil, fmt.Errorf("invalid mcp command: %w", err)
415 }
416 if strings.TrimSpace(command) == "" {
417 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
418 }
419 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
420 cmd.Env = m.ResolvedEnv()
421 return &mcp.CommandTransport{
422 Command: cmd,
423 }, nil
424 case config.MCPHttp:
425 if strings.TrimSpace(m.URL) == "" {
426 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
427 }
428 client := &http.Client{
429 Transport: &headerRoundTripper{
430 headers: m.ResolvedHeaders(),
431 },
432 }
433 return &mcp.StreamableClientTransport{
434 Endpoint: m.URL,
435 HTTPClient: client,
436 }, nil
437 case config.MCPSSE:
438 if strings.TrimSpace(m.URL) == "" {
439 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
440 }
441 client := &http.Client{
442 Transport: &headerRoundTripper{
443 headers: m.ResolvedHeaders(),
444 },
445 }
446 return &mcp.SSEClientTransport{
447 Endpoint: m.URL,
448 HTTPClient: client,
449 }, nil
450 default:
451 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
452 }
453}
454
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers set, overriding same-named headers.
// It works on a clone because the http.RoundTripper contract forbids
// modifying the caller's request.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		// Header.Clone returns nil for a nil header, and setting on a nil map
		// would panic.
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
465
466func mcpTimeout(m config.MCPConfig) time.Duration {
467 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
468}
469
470func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
471 if c.InitializeResult().Capabilities.Prompts == nil {
472 return nil, nil
473 }
474 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
475 if err != nil {
476 return nil, err
477 }
478 return result.Prompts, nil
479}
480
481// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
482func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
483 if len(prompts) == 0 {
484 mcpClient2Prompts.Del(mcpName)
485 } else {
486 mcpClient2Prompts.Set(mcpName, prompts)
487 }
488 for clientName, prompts := range mcpClient2Prompts.Seq2() {
489 for _, p := range prompts {
490 key := clientName + ":" + p.Name
491 mcpPrompts.Set(key, p)
492 }
493 }
494}
495
496// GetMCPPrompts returns all available MCP prompts.
497func GetMCPPrompts() map[string]*mcp.Prompt {
498 return maps.Collect(mcpPrompts.Seq2())
499}
500
501// GetMCPPrompt returns a specific MCP prompt by name.
502func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
503 return mcpPrompts.Get(name)
504}
505
506// GetMCPPromptsByClient returns all prompts for a specific MCP client.
507func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
508 return mcpClient2Prompts.Get(clientName)
509}
510
511// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
512func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
513 c, err := getOrRenewClient(ctx, clientName)
514 if err != nil {
515 return nil, err
516 }
517
518 return c.GetPrompt(ctx, &mcp.GetPromptParams{
519 Name: promptName,
520 Arguments: args,
521 })
522}