1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/llm/tools"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/version"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
// MCPState enumerates the lifecycle states tracked for an MCP client.
type MCPState int

// The states are numbered from zero in declaration order.
const (
	// MCPStateDisabled is recorded for clients that are disabled in the config.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting is recorded before a client's session is created.
	MCPStateStarting
	// MCPStateConnected is recorded once a session has been established.
	MCPStateConnected
	// MCPStateError is recorded when creating or checking a session fails.
	MCPStateError
)

// String returns the lowercase name of the state. Any value that is not one
// of the declared constants yields "unknown".
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
53
// MCPEventType names the kind of notification carried by an MCPEvent.
type MCPEventType string

const (
	// MCPEventStateChanged is published whenever a client's state is updated.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when a server reports that its
	// tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
61
62// MCPEvent represents an event in the MCP system
63type MCPEvent struct {
64 Type MCPEventType
65 Name string
66 State MCPState
67 Error error
68 ToolCount int
69}
70
71// MCPClientInfo holds information about an MCP client's state
72type MCPClientInfo struct {
73 Name string
74 State MCPState
75 Error error
76 Client *mcp.ClientSession
77 ToolCount int
78 ConnectedAt time.Time
79}
80
81var (
82 mcpToolsOnce sync.Once
83 mcpTools = csync.NewMap[string, tools.BaseTool]()
84 mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
85 mcpClients = csync.NewMap[string, *mcp.ClientSession]()
86 mcpStates = csync.NewMap[string, MCPClientInfo]()
87 mcpBroker = pubsub.NewBroker[MCPEvent]()
88)
89
90type McpTool struct {
91 mcpName string
92 tool *mcp.Tool
93 permissions permission.Service
94 workingDir string
95}
96
97func (b *McpTool) Name() string {
98 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
99}
100
101func (b *McpTool) Info() tools.ToolInfo {
102 var parameters map[string]any
103 var required []string
104
105 if input, ok := b.tool.InputSchema.(map[string]any); ok {
106 if props, ok := input["properties"].(map[string]any); ok {
107 parameters = props
108 }
109 if req, ok := input["required"].([]any); ok {
110 // Convert []any -> []string when elements are strings
111 for _, v := range req {
112 if s, ok := v.(string); ok {
113 required = append(required, s)
114 }
115 }
116 } else if reqStr, ok := input["required"].([]string); ok {
117 // Handle case where it's already []string
118 required = reqStr
119 }
120 }
121
122 return tools.ToolInfo{
123 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
124 Description: b.tool.Description,
125 Parameters: parameters,
126 Required: required,
127 }
128}
129
130func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
131 var args map[string]any
132 if err := json.Unmarshal([]byte(input), &args); err != nil {
133 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
134 }
135
136 c, err := getOrRenewClient(ctx, name)
137 if err != nil {
138 return tools.NewTextErrorResponse(err.Error()), nil
139 }
140 result, err := c.CallTool(ctx, &mcp.CallToolParams{
141 Name: toolName,
142 Arguments: args,
143 })
144 if err != nil {
145 return tools.NewTextErrorResponse(err.Error()), nil
146 }
147
148 output := make([]string, 0, len(result.Content))
149 for _, v := range result.Content {
150 if vv, ok := v.(*mcp.TextContent); ok {
151 output = append(output, vv.Text)
152 } else {
153 output = append(output, fmt.Sprintf("%v", v))
154 }
155 }
156 return tools.NewTextResponse(strings.Join(output, "\n")), nil
157}
158
159func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
160 sess, ok := mcpClients.Get(name)
161 if !ok {
162 return nil, fmt.Errorf("mcp '%s' not available", name)
163 }
164
165 cfg := config.Get()
166 m := cfg.MCP[name]
167 state, _ := mcpStates.Get(name)
168
169 timeout := mcpTimeout(m)
170 pingCtx, cancel := context.WithTimeout(ctx, timeout)
171 defer cancel()
172 err := sess.Ping(pingCtx, nil)
173 if err == nil {
174 return sess, nil
175 }
176 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
177
178 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
179 if err != nil {
180 return nil, err
181 }
182
183 updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
184 mcpClients.Set(name, sess)
185 return sess, nil
186}
187
188func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
189 sessionID, messageID := tools.GetContextValues(ctx)
190 if sessionID == "" || messageID == "" {
191 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
192 }
193 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
194 p := b.permissions.Request(
195 permission.CreatePermissionRequest{
196 SessionID: sessionID,
197 ToolCallID: params.ID,
198 Path: b.workingDir,
199 ToolName: b.Info().Name,
200 Action: "execute",
201 Description: permissionDescription,
202 Params: params.Input,
203 },
204 )
205 if !p {
206 return tools.ToolResponse{}, permission.ErrorPermissionDenied
207 }
208
209 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
210}
211
212func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
213 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
214 if err != nil {
215 return nil, err
216 }
217 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
218 for _, tool := range result.Tools {
219 mcpTools = append(mcpTools, &McpTool{
220 mcpName: name,
221 tool: tool,
222 permissions: permissions,
223 workingDir: workingDir,
224 })
225 }
226 return mcpTools, nil
227}
228
229// SubscribeMCPEvents returns a channel for MCP events
230func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
231 return mcpBroker.Subscribe(ctx)
232}
233
234// GetMCPStates returns the current state of all MCP clients
235func GetMCPStates() map[string]MCPClientInfo {
236 return maps.Collect(mcpStates.Seq2())
237}
238
239// GetMCPState returns the state of a specific MCP client
240func GetMCPState(name string) (MCPClientInfo, bool) {
241 return mcpStates.Get(name)
242}
243
244// updateMCPState updates the state of an MCP client and publishes an event
245func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
246 info := MCPClientInfo{
247 Name: name,
248 State: state,
249 Error: err,
250 Client: client,
251 ToolCount: toolCount,
252 }
253 switch state {
254 case MCPStateConnected:
255 info.ConnectedAt = time.Now()
256 case MCPStateError:
257 updateMcpTools(name, nil)
258 mcpClients.Del(name)
259 }
260 mcpStates.Set(name, info)
261
262 // Publish state change event
263 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
264 Type: MCPEventStateChanged,
265 Name: name,
266 State: state,
267 Error: err,
268 ToolCount: toolCount,
269 })
270}
271
272// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
273func CloseMCPClients() error {
274 var errs []error
275 for name, c := range mcpClients.Seq2() {
276 if err := c.Close(); err != nil &&
277 !errors.Is(err, io.EOF) &&
278 !errors.Is(err, context.Canceled) &&
279 err.Error() != "signal: killed" {
280 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
281 }
282 }
283 mcpBroker.Shutdown()
284 return errors.Join(errs...)
285}
286
287func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
288 var wg sync.WaitGroup
289 // Initialize states for all configured MCPs
290 for name, m := range cfg.MCP {
291 if m.Disabled {
292 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
293 slog.Debug("skipping disabled mcp", "name", name)
294 continue
295 }
296
297 // Set initial starting state
298 updateMCPState(name, MCPStateStarting, nil, nil, 0)
299
300 wg.Add(1)
301 go func(name string, m config.MCPConfig) {
302 defer func() {
303 wg.Done()
304 if r := recover(); r != nil {
305 var err error
306 switch v := r.(type) {
307 case error:
308 err = v
309 case string:
310 err = fmt.Errorf("panic: %s", v)
311 default:
312 err = fmt.Errorf("panic: %v", v)
313 }
314 updateMCPState(name, MCPStateError, err, nil, 0)
315 slog.Error("panic in mcp client initialization", "error", err, "name", name)
316 }
317 }()
318
319 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
320 defer cancel()
321
322 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
323 if err != nil {
324 return
325 }
326
327 mcpClients.Set(name, c)
328
329 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
330 if err != nil {
331 slog.Error("error listing tools", "error", err)
332 updateMCPState(name, MCPStateError, err, nil, 0)
333 c.Close()
334 return
335 }
336
337 updateMcpTools(name, tools)
338 mcpClients.Set(name, c)
339 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
340 }(name, m)
341 }
342 wg.Wait()
343}
344
345// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
346func updateMcpTools(mcpName string, tools []tools.BaseTool) {
347 if len(tools) == 0 {
348 mcpClient2Tools.Del(mcpName)
349 } else {
350 mcpClient2Tools.Set(mcpName, tools)
351 }
352 for _, tools := range mcpClient2Tools.Seq2() {
353 for _, t := range tools {
354 mcpTools.Set(t.Name(), t)
355 }
356 }
357}
358
359func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
360 timeout := mcpTimeout(m)
361 mcpCtx, cancel := context.WithCancel(ctx)
362 cancelTimer := time.AfterFunc(timeout, cancel)
363
364 transport, err := createMCPTransport(mcpCtx, m, resolver)
365 if err != nil {
366 updateMCPState(name, MCPStateError, err, nil, 0)
367 slog.Error("error creating mcp client", "error", err, "name", name)
368 return nil, err
369 }
370
371 client := mcp.NewClient(
372 &mcp.Implementation{
373 Name: "crush",
374 Version: version.Version,
375 Title: "Crush",
376 },
377 &mcp.ClientOptions{
378 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
379 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
380 Type: MCPEventToolsListChanged,
381 Name: name,
382 })
383 },
384 KeepAlive: time.Minute * 10,
385 },
386 )
387
388 session, err := client.Connect(mcpCtx, transport, nil)
389 if err != nil {
390 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
391 slog.Error("error starting mcp client", "error", err, "name", name)
392 cancel()
393 return nil, err
394 }
395
396 cancelTimer.Stop()
397 slog.Info("Initialized mcp client", "name", name)
398 return session, nil
399}
400
// maybeTimeoutErr replaces context errors with a "timed out after <timeout>"
// error and returns any other error (including nil) unchanged.
//
// Both context.Canceled (produced when the connect timer cancels the session
// context) and context.DeadlineExceeded (produced by contexts created with
// context.WithTimeout) are translated, including when wrapped.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
407
408func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
409 switch m.Type {
410 case config.MCPStdio:
411 command, err := resolver.ResolveValue(m.Command)
412 if err != nil {
413 return nil, fmt.Errorf("invalid mcp command: %w", err)
414 }
415 if strings.TrimSpace(command) == "" {
416 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
417 }
418 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
419 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
420 return &mcp.CommandTransport{
421 Command: cmd,
422 }, nil
423 case config.MCPHttp:
424 if strings.TrimSpace(m.URL) == "" {
425 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
426 }
427 client := &http.Client{
428 Transport: &headerRoundTripper{
429 headers: m.ResolvedHeaders(),
430 },
431 }
432 return &mcp.StreamableClientTransport{
433 Endpoint: m.URL,
434 HTTPClient: client,
435 }, nil
436 case config.MCPSSE:
437 if strings.TrimSpace(m.URL) == "" {
438 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
439 }
440 client := &http.Client{
441 Transport: &headerRoundTripper{
442 headers: m.ResolvedHeaders(),
443 },
444 }
445 return &mcp.SSEClientTransport{
446 Endpoint: m.URL,
447 HTTPClient: client,
448 }, nil
449 default:
450 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
451 }
452}
453
// headerRoundTripper is an http.RoundTripper that adds a fixed set of
// headers to every request before handing it to http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to the values to set on each request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, overriding any existing values for the same names.
//
// The caller's request is left untouched: the http.RoundTripper contract
// forbids modifying it, so the headers are applied to a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
464
465func mcpTimeout(m config.MCPConfig) time.Duration {
466 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
467}