1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/llm/tools"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/version"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
// MCPState is the lifecycle state of a configured MCP client.
type MCPState int

// MCPState values. The zero value is MCPStateDisabled.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lowercase name of s ("disabled", "starting",
// "connected" or "error"), or "unknown" for any value outside the declared
// constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
53
// MCPEventType identifies the kind of an MCPEvent published on mcpBroker.
type MCPEventType string

const (
	// MCPEventStateChanged is published by updateMCPState every time a
	// client's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when an MCP server notifies the
	// client that its list of tools changed. Such events only carry Type and
	// Name.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent is a notification about one MCP client, delivered to
// SubscribeMCPEvents subscribers.
type MCPEvent struct {
	// Type is the kind of event.
	Type MCPEventType
	// Name is the MCP server name, as keyed in the config's MCP map.
	Name string
	// State is the newly recorded state (MCPEventStateChanged only).
	State MCPState
	// Error is the error recorded with the state, if any.
	Error error
	// ToolCount is the tool count passed along with the state update.
	ToolCount int
}

// MCPClientInfo is the last recorded state of one MCP client, as stored in
// mcpStates and returned by GetMCPState/GetMCPStates.
type MCPClientInfo struct {
	// Name is the MCP server name, as keyed in the config's MCP map.
	Name string
	// State is the client's current lifecycle state.
	State MCPState
	// Error is the error that moved the client to its state, if any.
	Error error
	// Client is the live session; nil unless a session was supplied with the
	// update (callers pass nil for the disabled, starting and error states).
	Client *mcp.ClientSession
	// ToolCount is the number of tools reported for the client.
	ToolCount int
	// ConnectedAt is set to the update time when State is MCPStateConnected
	// and is the zero time otherwise.
	ConnectedAt time.Time
}
80
var (
	// mcpToolsOnce is not used in this part of the file; presumably it guards
	// a one-time call to doGetMCPTools elsewhere — confirm.
	mcpToolsOnce sync.Once
	// mcpTools indexes every registered MCP tool by its exposed name
	// ("mcp_<server>_<tool>", see McpTool.Name).
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools holds the tools registered for each MCP server, keyed
	// by server name.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients holds the session of each MCP server that has connected,
	// keyed by server name; entries are dropped when a server enters
	// MCPStateError.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates holds the last recorded MCPClientInfo of each configured
	// server, keyed by server name.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvents to SubscribeMCPEvents subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
89
// McpTool exposes a single tool of an MCP server as a tools.BaseTool.
type McpTool struct {
	// mcpName is the configured name of the MCP server providing the tool.
	mcpName string
	// tool is the tool definition as listed by the server.
	tool *mcp.Tool
	// permissions is asked for approval before every execution.
	permissions permission.Service
	// workingDir is reported as the Path of permission requests.
	workingDir string
}
96
97func (b *McpTool) Name() string {
98 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
99}
100
101func (b *McpTool) Info() tools.ToolInfo {
102 parameters := make(map[string]any)
103 required := make([]string, 0)
104
105 if input, ok := b.tool.InputSchema.(map[string]any); ok {
106 if props, ok := input["properties"].(map[string]any); ok {
107 parameters = props
108 }
109 if req, ok := input["required"].([]any); ok {
110 // Convert []any -> []string when elements are strings
111 for _, v := range req {
112 if s, ok := v.(string); ok {
113 required = append(required, s)
114 }
115 }
116 } else if reqStr, ok := input["required"].([]string); ok {
117 // Handle case where it's already []string
118 required = reqStr
119 }
120 }
121
122 return tools.ToolInfo{
123 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
124 Description: b.tool.Description,
125 Parameters: parameters,
126 Required: required,
127 }
128}
129
130func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
131 var args map[string]any
132 if err := json.Unmarshal([]byte(input), &args); err != nil {
133 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
134 }
135
136 c, err := getOrRenewClient(ctx, name)
137 if err != nil {
138 return tools.NewTextErrorResponse(err.Error()), nil
139 }
140 result, err := c.CallTool(ctx, &mcp.CallToolParams{
141 Name: toolName,
142 Arguments: args,
143 })
144 if err != nil {
145 return tools.NewTextErrorResponse(err.Error()), nil
146 }
147
148 output := make([]string, 0, len(result.Content))
149 for _, v := range result.Content {
150 if vv, ok := v.(*mcp.TextContent); ok {
151 output = append(output, vv.Text)
152 } else {
153 output = append(output, fmt.Sprintf("%v", v))
154 }
155 }
156 return tools.NewTextResponse(strings.Join(output, "\n")), nil
157}
158
// getOrRenewClient returns a usable session for the MCP server called name.
//
// The current session is pinged, with the ping bounded by the server's
// timeout. If the ping fails, the server is marked MCPStateError and the stale
// session is closed. A new session is then created, registered together with
// the server's previous tools, and marked MCPStateConnected.
//
// If the caller's ctx is already done, its error is returned and the session
// is left alone, because that failure says nothing about the server's health.
func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
	sess, ok := mcpClients.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	cfg := config.Get()
	m := cfg.MCP[name]
	state, _ := mcpStates.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := sess.Ping(pingCtx, nil)
	if err == nil {
		return sess, nil
	}
	if ctxErr := ctx.Err(); ctxErr != nil {
		// The caller gave up (e.g. the request was canceled); keep the session.
		return nil, ctxErr
	}

	// Marking the server as errored unregisters its tools; remember them so
	// they can be restored once a new session is up.
	prevTools, _ := mcpClient2Tools.Get(name)
	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
	// The stale session was just dropped from mcpClients, so CloseMCPClients
	// would never reach it; close it here (best effort) to avoid leaking a
	// stdio server process.
	_ = sess.Close()

	// ctx belongs to a single tool call, but the renewed session (and a stdio
	// server's process, which is bound to the session context) must outlive
	// it. createMCPSession still bounds the connection attempt itself.
	sess, err = createMCPSession(context.WithoutCancel(ctx), name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	mcpClients.Set(name, sess)
	updateMcpTools(name, prevTools)
	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
	return sess, nil
}
187
// Run executes the MCP tool for the tool call described by params.
//
// The session and message IDs must be present in ctx. Permission to execute
// is requested from the permission service first; a denial returns
// permission.ErrorPermissionDenied. Otherwise the call is forwarded to the
// MCP server via runTool, which reports tool failures as error responses.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	toolName := b.Name()

	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required to run %s", toolName)
	}

	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters:", toolName),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
211
// getTools lists every tool exposed by the MCP session c and wraps each one in
// an McpTool bound to the server name, permission service and working
// directory. It follows pagination cursors until the server reports no
// further page, and returns the first listing error encountered.
func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
	serverTools := make([]tools.BaseTool, 0)
	params := &mcp.ListToolsParams{}
	for {
		result, err := c.ListTools(ctx, params)
		if err != nil {
			return nil, err
		}
		for _, tool := range result.Tools {
			serverTools = append(serverTools, &McpTool{
				mcpName:     name,
				tool:        tool,
				permissions: permissions,
				workingDir:  workingDir,
			})
		}
		// Stop on the last page; also guard against a server that keeps
		// handing back the cursor we just sent.
		next := result.NextCursor
		if next == "" || next == params.Cursor {
			break
		}
		params.Cursor = next
	}
	return serverTools, nil
}
228
// SubscribeMCPEvents subscribes to the MCP event broker and returns the
// channel on which MCPEvents (state changes and tool-list changes) arrive.
// NOTE(review): the subscription lifetime is governed by ctx inside
// pubsub.Broker — presumably it ends when ctx is done; confirm there.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}

// GetMCPStates returns a snapshot of the last recorded state of every MCP
// client, keyed by server name. The map is a fresh copy owned by the caller.
func GetMCPStates() map[string]MCPClientInfo {
	snapshot := maps.Collect(mcpStates.Seq2())
	return snapshot
}

// GetMCPState returns the last recorded state of the MCP client called name,
// and whether any state has been recorded for it.
func GetMCPState(name string) (MCPClientInfo, bool) {
	info, found := mcpStates.Get(name)
	return info, found
}
243
// updateMCPState records a new state for the MCP client called name in
// mcpStates and publishes an MCPEventStateChanged event describing it.
//
// Entering MCPStateConnected stamps ConnectedAt with the current time.
// Entering MCPStateError clears the client's tools via updateMcpTools(name,
// nil) and drops its session from mcpClients; the session is not closed here.
func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
	info := MCPClientInfo{
		Name:      name,
		State:     state,
		Error:     err,
		Client:    client,
		ToolCount: toolCount,
	}

	if state == MCPStateConnected {
		info.ConnectedAt = time.Now()
	} else if state == MCPStateError {
		updateMcpTools(name, nil)
		mcpClients.Del(name)
	}

	mcpStates.Set(name, info)

	event := MCPEvent{
		Type:      MCPEventStateChanged,
		Name:      name,
		State:     state,
		Error:     err,
		ToolCount: toolCount,
	}
	mcpBroker.Publish(pubsub.UpdatedEvent, event)
}
271
// CloseMCPClients closes every session in mcpClients and then shuts down the
// MCP event broker. It is meant to be called once, during application shutdown.
//
// Errors expected while tearing a session down (io.EOF, context.Canceled, or a
// child process reporting "signal: killed") are ignored. Every other close
// error is wrapped with the server name, and all of them are joined into the
// result.
func CloseMCPClients() error {
	isExpected := func(err error) bool {
		return errors.Is(err, io.EOF) ||
			errors.Is(err, context.Canceled) ||
			err.Error() == "signal: killed"
	}

	var errs []error
	for name, session := range mcpClients.Seq2() {
		err := session.Close()
		if err == nil || isExpected(err) {
			continue
		}
		errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
	}

	mcpBroker.Shutdown()
	return errors.Join(errs...)
}
286
// doGetMCPTools connects to every enabled MCP server in cfg.MCP concurrently,
// registers the tools each one exposes, and blocks until every attempt has
// finished.
//
// Disabled servers are recorded as MCPStateDisabled. Every other server is
// recorded as MCPStateStarting and then ends in MCPStateConnected or
// MCPStateError. A panic during one server's initialization is recovered and
// recorded as that server's error.
//
// Sessions are created on ctx and stay bound to it, so ctx must outlive them.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
	var wg sync.WaitGroup
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		// Set initial starting state.
		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred first so it runs last: wg.Wait must not return before a
			// recovered panic has been recorded as the server's error state.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			// Do not hand createMCPSession a context that is canceled when this
			// goroutine returns: for stdio servers the child process is started
			// with exec.CommandContext on a context derived from it and would be
			// killed. createMCPSession applies its own connection timeout.
			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
			if err != nil {
				// createMCPSession has already recorded the error state.
				return
			}

			// Listing tools, unlike the session itself, is a bounded operation.
			listCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
			defer cancel()

			serverTools, err := getTools(listCtx, name, permissions, c, cfg.WorkingDir())
			if err != nil {
				slog.Error("error listing tools", "error", err, "name", name)
				updateMCPState(name, MCPStateError, err, nil, 0)
				_ = c.Close()
				return
			}

			updateMcpTools(name, serverTools)
			mcpClients.Set(name, c)
			updateMCPState(name, MCPStateConnected, nil, c, len(serverTools))
		}(name, m)
	}
	wg.Wait()
}
344
345// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
346func updateMcpTools(mcpName string, tools []tools.BaseTool) {
347 if len(tools) == 0 {
348 mcpClient2Tools.Del(mcpName)
349 } else {
350 mcpClient2Tools.Set(mcpName, tools)
351 }
352 for _, tools := range mcpClient2Tools.Seq2() {
353 for _, t := range tools {
354 mcpTools.Set(t.Name(), t)
355 }
356 }
357}
358
// createMCPSession builds the transport for the MCP server called name (see
// createMCPTransport), performs the MCP handshake and returns the session.
//
// Connecting is bounded by mcpTimeout(m). On failure, the server's state is
// set to MCPStateError and the error is logged and returned; on success,
// recording the connected state is left to the caller.
//
// The session stays bound to a context derived from ctx (for stdio servers the
// child process is started with exec.CommandContext on it), so ctx must
// outlive the session.
func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
	timeout := mcpTimeout(m)
	// A context.WithTimeout would tear the session down when the deadline
	// passed, even after a successful connect. Instead, cancel only if
	// connecting is still in progress when the timer fires, and disarm the
	// timer once connected. cancel is deliberately not called on success: it
	// would kill a stdio server's process.
	mcpCtx, cancel := context.WithCancel(ctx)
	cancelTimer := time.AfterFunc(timeout, cancel)

	// fail releases the connect context and timer, then records, logs and
	// returns err.
	fail := func(msg string, err error) (*mcp.ClientSession, error) {
		cancelTimer.Stop()
		cancel()
		updateMCPState(name, MCPStateError, err, nil, 0)
		slog.Error(msg, "error", err, "name", name)
		return nil, err
	}

	transport, err := createMCPTransport(mcpCtx, m, resolver)
	if err != nil {
		return fail("error creating mcp client", err)
	}

	client := mcp.NewClient(
		&mcp.Implementation{
			Name:    "crush",
			Version: version.Version,
			Title:   "Crush",
		},
		&mcp.ClientOptions{
			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
					Type: MCPEventToolsListChanged,
					Name: name,
				})
			},
			KeepAlive: time.Minute * 10,
		},
	)

	session, err := client.Connect(mcpCtx, transport, nil)
	if err != nil {
		// Record, log and return the same, most descriptive error.
		return fail("error starting mcp client", maybeTimeoutErr(maybeStdioErr(err, transport), timeout))
	}

	if !cancelTimer.Stop() {
		// The timer fired after Connect returned but before it could be
		// disarmed: mcpCtx is already canceled, so the session is unusable.
		_ = session.Close()
		return fail("error starting mcp client", fmt.Errorf("timed out after %s", timeout))
	}

	slog.Info("Initialized mcp client", "name", name)
	return session, nil
}
401
// maybeStdioErr adds detail to a connection error from a stdio MCP server.
//
// A stdio server that prints something other than JSON-RPC fails to parse, and
// the client then closes it, so all that comes back is io.EOF. This happens
// particularly when starting servers with npx, e.g. when node cannot be found.
// When err is io.EOF and transport is an *mcp.CommandTransport, the command is
// run once more via stdioMCPCheck to capture its output, and any resulting
// error is joined onto err. Any other error or transport is returned as is.
func maybeStdioErr(err error, transport mcp.Transport) error {
	ct, isCommand := transport.(*mcp.CommandTransport)
	if !errors.Is(err, io.EOF) || !isCommand {
		return err
	}
	checkErr := stdioMCPCheck(ct.Command)
	if checkErr == nil {
		return err
	}
	return errors.Join(err, checkErr)
}
422
// maybeTimeoutErr rewrites err into a "timed out after <timeout>" error when
// it stems from context cancellation or an expired deadline. Those are the two
// ways the MCP timeouts in this package surface: a timer-driven cancel and a
// context.WithTimeout. Any other error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
429
// createMCPTransport builds the mcp.Transport that matches m.Type:
//   - stdio: the command is resolved through resolver, passed through
//     home.Long, and run with m.Args and the process environment plus
//     m.ResolvedEnv(). It is created with exec.CommandContext(ctx, ...), so
//     the process is bound to ctx.
//   - http: a streamable HTTP transport on m.URL.
//   - sse: an SSE transport on m.URL.
//
// Both HTTP-based transports send m.ResolvedHeaders() on every request. An
// error is returned for an unsupported type or a missing command/URL.
func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
	withHeaders := func() *http.Client {
		return &http.Client{
			Transport: &headerRoundTripper{headers: m.ResolvedHeaders()},
		}
	}

	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		if err != nil {
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		}
		if strings.TrimSpace(command) == "" {
			return nil, errors.New("mcp stdio config requires a non-empty 'command' field")
		}
		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
		return &mcp.CommandTransport{Command: cmd}, nil

	case config.MCPHttp:
		if strings.TrimSpace(m.URL) == "" {
			return nil, errors.New("mcp http config requires a non-empty 'url' field")
		}
		return &mcp.StreamableClientTransport{Endpoint: m.URL, HTTPClient: withHeaders()}, nil

	case config.MCPSSE:
		if strings.TrimSpace(m.URL) == "" {
			return nil, errors.New("mcp sse config requires a non-empty 'url' field")
		}
		return &mcp.SSEClientTransport{Endpoint: m.URL, HTTPClient: withHeaders()}, nil
	}

	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
475
476type headerRoundTripper struct {
477 headers map[string]string
478}
479
480func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
481 for k, v := range rt.headers {
482 req.Header.Set(k, v)
483 }
484 return http.DefaultTransport.RoundTrip(req)
485}
486
487func mcpTimeout(m config.MCPConfig) time.Duration {
488 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
489}
490
491func stdioMCPCheck(old *exec.Cmd) error {
492 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
493 defer cancel()
494 cmd := exec.CommandContext(ctx, old.Path, old.Args...)
495 cmd.Env = old.Env
496 out, err := cmd.CombinedOutput()
497 if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
498 return nil
499 }
500 return fmt.Errorf("%w: %s", err, string(out))
501}