1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os/exec"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/home"
21 "github.com/charmbracelet/crush/internal/llm/tools"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
// MCPState represents the current state of an MCP client.
type MCPState int

const (
	// MCPStateDisabled marks a client whose config has Disabled set; no
	// connection is attempted.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting is recorded before a connection attempt begins.
	MCPStateStarting
	// MCPStateConnected is recorded once the session is up and its tools are listed.
	MCPStateConnected
	// MCPStateError is recorded when connecting, listing tools or pinging fails.
	MCPStateError
)

// String returns the lowercase name of the state, or "unknown" for values
// outside the declared constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
52
// MCPEventType represents the type of MCP event
type MCPEventType string

const (
	// MCPEventStateChanged is published every time a client's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when an MCP server sends a
	// tool-list-changed notification.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent represents an event in the MCP system. Events are published on the
// package broker and delivered to SubscribeMCPEvents subscribers.
type MCPEvent struct {
	Type MCPEventType
	// Name is the configured name of the MCP server the event refers to.
	Name string
	// State, Error and ToolCount describe the client for state_changed
	// events; tools_list_changed events leave them at their zero values.
	State     MCPState
	Error     error
	ToolCount int
}
69
// MCPClientInfo holds information about an MCP client's state
type MCPClientInfo struct {
	// Name is the configured name of the MCP server.
	Name  string
	State MCPState
	// Error is the error passed along with the state (set for MCPStateError).
	Error error
	// Client is the session recorded with the state; nil when none was given
	// (for example while starting or after an error).
	Client *mcp.ClientSession
	// ToolCount is the number of tools reported for the client.
	ToolCount int
	// ConnectedAt is stamped when the state becomes MCPStateConnected and is
	// the zero time otherwise.
	ConnectedAt time.Time
}
79
var (
	// mcpToolsOnce is not used in this section; presumably it guards a
	// one-time call to doGetMCPTools — TODO confirm.
	mcpToolsOnce sync.Once
	// mcpTools maps a tool's full name (McpTool.Name, "mcp_<server>_<tool>")
	// to the tool; maintained by updateMcpTools.
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP server name to the tools it exposes.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP server name to its live session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP server name to its most recently recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent values to SubscribeMCPEvents subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
88
// McpTool adapts one tool exposed by an MCP server to tools.BaseTool so the
// agent can list and invoke it.
type McpTool struct {
	// mcpName is the configured name of the MCP server that owns the tool.
	mcpName string
	// tool is the tool definition returned by the server's tools/list.
	tool *mcp.Tool
	// permissions is asked for approval before each Run.
	permissions permission.Service
	// workingDir is reported as the Path of permission requests.
	workingDir string
}
95
96func (b *McpTool) Name() string {
97 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
98}
99
100func (b *McpTool) Info() tools.ToolInfo {
101 var parameters map[string]any
102 var required []string
103
104 if input, ok := b.tool.InputSchema.(map[string]any); ok {
105 if props, ok := input["properties"].(map[string]any); ok {
106 parameters = props
107 }
108 if req, ok := input["required"].([]any); ok {
109 // Convert []any -> []string when elements are strings
110 for _, v := range req {
111 if s, ok := v.(string); ok {
112 required = append(required, s)
113 }
114 }
115 } else if reqStr, ok := input["required"].([]string); ok {
116 // Handle case where it's already []string
117 required = reqStr
118 }
119 }
120
121 return tools.ToolInfo{
122 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
123 Description: b.tool.Description,
124 Parameters: parameters,
125 Required: required,
126 }
127}
128
129func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
130 var args map[string]any
131 if err := json.Unmarshal([]byte(input), &args); err != nil {
132 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
133 }
134
135 c, err := getOrRenewClient(ctx, name)
136 if err != nil {
137 return tools.NewTextErrorResponse(err.Error()), nil
138 }
139 result, err := c.CallTool(ctx, &mcp.CallToolParams{
140 Name: toolName,
141 Arguments: args,
142 })
143 if err != nil {
144 return tools.NewTextErrorResponse(err.Error()), nil
145 }
146
147 output := make([]string, 0, len(result.Content))
148 for _, v := range result.Content {
149 if vv, ok := v.(*mcp.TextContent); ok {
150 output = append(output, vv.Text)
151 } else {
152 output = append(output, fmt.Sprintf("%v", v))
153 }
154 }
155 return tools.NewTextResponse(strings.Join(output, "\n")), nil
156}
157
158func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
159 sess, ok := mcpClients.Get(name)
160 if !ok {
161 return nil, fmt.Errorf("mcp '%s' not available", name)
162 }
163
164 cfg := config.Get()
165 m := cfg.MCP[name]
166 state, _ := mcpStates.Get(name)
167
168 timeout := mcpTimeout(m)
169 pingCtx, cancel := context.WithTimeout(ctx, timeout)
170 defer cancel()
171 err := sess.Ping(pingCtx, nil)
172 if err == nil {
173 return sess, nil
174 }
175 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
176
177 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
178 if err != nil {
179 return nil, err
180 }
181
182 updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
183 mcpClients.Set(name, sess)
184 return sess, nil
185}
186
187func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
188 sessionID, messageID := tools.GetContextValues(ctx)
189 if sessionID == "" || messageID == "" {
190 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
191 }
192 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
193 p := b.permissions.Request(
194 permission.CreatePermissionRequest{
195 SessionID: sessionID,
196 ToolCallID: params.ID,
197 Path: b.workingDir,
198 ToolName: b.Info().Name,
199 Action: "execute",
200 Description: permissionDescription,
201 Params: params.Input,
202 },
203 )
204 if !p {
205 return tools.ToolResponse{}, permission.ErrorPermissionDenied
206 }
207
208 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
209}
210
211func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
212 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
213 if err != nil {
214 return nil, err
215 }
216 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
217 for _, tool := range result.Tools {
218 mcpTools = append(mcpTools, &McpTool{
219 mcpName: name,
220 tool: tool,
221 permissions: permissions,
222 workingDir: workingDir,
223 })
224 }
225 return mcpTools, nil
226}
227
228// SubscribeMCPEvents returns a channel for MCP events
229func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
230 return mcpBroker.Subscribe(ctx)
231}
232
233// GetMCPStates returns the current state of all MCP clients
234func GetMCPStates() map[string]MCPClientInfo {
235 return maps.Collect(mcpStates.Seq2())
236}
237
238// GetMCPState returns the state of a specific MCP client
239func GetMCPState(name string) (MCPClientInfo, bool) {
240 return mcpStates.Get(name)
241}
242
243// updateMCPState updates the state of an MCP client and publishes an event
244func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
245 info := MCPClientInfo{
246 Name: name,
247 State: state,
248 Error: err,
249 Client: client,
250 ToolCount: toolCount,
251 }
252 switch state {
253 case MCPStateConnected:
254 info.ConnectedAt = time.Now()
255 case MCPStateError:
256 updateMcpTools(name, nil)
257 mcpClients.Del(name)
258 }
259 mcpStates.Set(name, info)
260
261 // Publish state change event
262 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
263 Type: MCPEventStateChanged,
264 Name: name,
265 State: state,
266 Error: err,
267 ToolCount: toolCount,
268 })
269}
270
271// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
272func CloseMCPClients() error {
273 var errs []error
274 for name, c := range mcpClients.Seq2() {
275 if err := c.Close(); err != nil &&
276 !errors.Is(err, io.EOF) &&
277 !errors.Is(err, context.Canceled) &&
278 err.Error() != "signal: killed" {
279 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
280 }
281 }
282 mcpBroker.Shutdown()
283 return errors.Join(errs...)
284}
285
286func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
287 var wg sync.WaitGroup
288 // Initialize states for all configured MCPs
289 for name, m := range cfg.MCP {
290 if m.Disabled {
291 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
292 slog.Debug("skipping disabled mcp", "name", name)
293 continue
294 }
295
296 // Set initial starting state
297 updateMCPState(name, MCPStateStarting, nil, nil, 0)
298
299 wg.Add(1)
300 go func(name string, m config.MCPConfig) {
301 defer func() {
302 wg.Done()
303 if r := recover(); r != nil {
304 var err error
305 switch v := r.(type) {
306 case error:
307 err = v
308 case string:
309 err = fmt.Errorf("panic: %s", v)
310 default:
311 err = fmt.Errorf("panic: %v", v)
312 }
313 updateMCPState(name, MCPStateError, err, nil, 0)
314 slog.Error("panic in mcp client initialization", "error", err, "name", name)
315 }
316 }()
317
318 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
319 defer cancel()
320
321 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
322 if err != nil {
323 return
324 }
325
326 mcpClients.Set(name, c)
327
328 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
329 if err != nil {
330 slog.Error("error listing tools", "error", err)
331 updateMCPState(name, MCPStateError, err, nil, 0)
332 c.Close()
333 return
334 }
335
336 updateMcpTools(name, tools)
337 mcpClients.Set(name, c)
338 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
339 }(name, m)
340 }
341 wg.Wait()
342}
343
344// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
345func updateMcpTools(mcpName string, tools []tools.BaseTool) {
346 if len(tools) == 0 {
347 mcpClient2Tools.Del(mcpName)
348 } else {
349 mcpClient2Tools.Set(mcpName, tools)
350 }
351 for _, tools := range mcpClient2Tools.Seq2() {
352 for _, t := range tools {
353 mcpTools.Set(t.Name(), t)
354 }
355 }
356}
357
358func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
359 timeout := mcpTimeout(m)
360 mcpCtx, cancel := context.WithCancel(ctx)
361 cancelTimer := time.AfterFunc(timeout, cancel)
362
363 transport, err := createMCPTransport(mcpCtx, m, resolver)
364 if err != nil {
365 updateMCPState(name, MCPStateError, err, nil, 0)
366 slog.Error("error creating mcp client", "error", err, "name", name)
367 return nil, err
368 }
369
370 client := mcp.NewClient(
371 &mcp.Implementation{
372 Name: "crush",
373 Version: version.Version,
374 Title: "Crush",
375 },
376 &mcp.ClientOptions{
377 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
378 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
379 Type: MCPEventToolsListChanged,
380 Name: name,
381 })
382 },
383 KeepAlive: time.Minute * 10,
384 },
385 )
386
387 session, err := client.Connect(mcpCtx, transport, nil)
388 if err != nil {
389 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
390 slog.Error("error starting mcp client", "error", err, "name", name)
391 cancel()
392 return nil, err
393 }
394
395 cancelTimer.Stop()
396 slog.Info("Initialized mcp client", "name", name)
397 return session, nil
398}
399
// maybeTimeoutErr rewrites a cancellation error as a "timed out after <d>"
// error, because createMCPSession's connect timer cancels the context rather
// than setting a deadline. Any other error (including nil) is returned
// unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.Canceled) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
406
407func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
408 switch m.Type {
409 case config.MCPStdio:
410 command, err := resolver.ResolveValue(m.Command)
411 if err != nil {
412 return nil, fmt.Errorf("invalid mcp command: %w", err)
413 }
414 if strings.TrimSpace(command) == "" {
415 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
416 }
417 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
418 cmd.Env = m.ResolvedEnv()
419 return &mcp.CommandTransport{
420 Command: cmd,
421 }, nil
422 case config.MCPHttp:
423 if strings.TrimSpace(m.URL) == "" {
424 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
425 }
426 client := &http.Client{
427 Transport: &headerRoundTripper{
428 headers: m.ResolvedHeaders(),
429 },
430 }
431 return &mcp.StreamableClientTransport{
432 Endpoint: m.URL,
433 HTTPClient: client,
434 }, nil
435 case config.MCPSSE:
436 if strings.TrimSpace(m.URL) == "" {
437 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
438 }
439 client := &http.Client{
440 Transport: &headerRoundTripper{
441 headers: m.ResolvedHeaders(),
442 },
443 }
444 return &mcp.SSEClientTransport{
445 Endpoint: m.URL,
446 HTTPClient: client,
447 }, nil
448 default:
449 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
450 }
451}
452
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// (from the MCP server config) to every outgoing request.
type headerRoundTripper struct {
	// headers are set on each request, replacing existing values for the same keys.
	headers map[string]string
	// base performs the actual round trip; http.DefaultTransport when nil.
	base http.RoundTripper
}

// RoundTrip implements http.RoundTripper. The RoundTripper contract forbids
// modifying the caller's request, so headers are applied to a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	base := rt.base
	if base == nil {
		base = http.DefaultTransport
	}
	if len(rt.headers) == 0 {
		return base.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		// Clone keeps a nil Header nil; Set on it would panic.
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return base.RoundTrip(out)
}
463
464func mcpTimeout(m config.MCPConfig) time.Duration {
465 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
466}