1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/home"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/version"
23 "github.com/mark3labs/mcp-go/client"
24 "github.com/mark3labs/mcp-go/client/transport"
25 "github.com/mark3labs/mcp-go/mcp"
26)
27
// MCPState is the lifecycle state of a single MCP client.
type MCPState int

// Known MCP client states. The zero value is MCPStateDisabled.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// MarshalText implements encoding.TextMarshaler using the String form, so a
// state is serialized as its name (e.g. "connected") rather than an integer.
func (s MCPState) MarshalText() ([]byte, error) {
	text := s.String()
	return []byte(text), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. It accepts exactly the
// names String produces for the defined states; anything else (including
// "unknown") is rejected and leaves *s unchanged.
func (s *MCPState) UnmarshalText(data []byte) error {
	text := string(data)
	for _, candidate := range [...]MCPState{MCPStateDisabled, MCPStateStarting, MCPStateConnected, MCPStateError} {
		if candidate.String() == text {
			*s = candidate
			return nil
		}
	}
	return fmt.Errorf("unknown mcp state: %s", data)
}

// String returns the lowercase name of the state, or "unknown" for any value
// outside the defined constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
72
// MCPEventType names the kind of an MCPEvent.
type MCPEventType string

// MCPEventStateChanged is the event type published by updateMCPState whenever
// an MCP client's recorded state changes.
const (
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MarshalText implements encoding.TextMarshaler; the text is the raw type name.
func (t MCPEventType) MarshalText() ([]byte, error) {
	raw := string(t)
	return []byte(raw), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Any text is accepted
// verbatim; it is not validated against the known event types.
func (t *MCPEventType) UnmarshalText(data []byte) error {
	text := string(data)
	*t = MCPEventType(text)
	return nil
}
88
// MCPEvent is the payload published on the MCP event broker by
// updateMCPState each time an MCP client's recorded state changes.
type MCPEvent struct {
	// Type is the kind of event; MCPEventStateChanged is the only type
	// defined in this file.
	Type MCPEventType `json:"type"`
	// Name is the MCP server's key in the MCP configuration.
	Name string `json:"name"`
	// State is the newly recorded state; it marshals via MCPState.MarshalText.
	State MCPState `json:"state"`
	// Error is the failure associated with the new state, if any.
	// NOTE(review): encoding/json encodes most error values as `{}` (their
	// fields are unexported), so the message is not carried in JSON output —
	// confirm whether any consumer depends on this field being serialized.
	Error error `json:"error,omitempty"`
	// ToolCount is the tool count recorded with this state change.
	ToolCount int `json:"tool_count,omitempty"`
}
97
// MCPClientInfo is the per-server snapshot stored in mcpStates by
// updateMCPState and returned by GetMCPState / GetMCPStates.
type MCPClientInfo struct {
	// Name is the MCP server's key in the MCP configuration.
	Name string
	// State is the most recently recorded lifecycle state.
	State MCPState
	// Error is the failure recorded with State, if any.
	Error error
	// Client is the client passed with the update; every non-connected
	// update in this file passes nil.
	Client *client.Client
	// ToolCount is the number of tools recorded for the server. On a failed
	// renewal (getOrRenewClient) the previous count is carried over.
	ToolCount int
	// ConnectedAt is the time of the update when State is MCPStateConnected
	// and the zero time otherwise.
	ConnectedAt time.Time
}
107
// Package-level MCP registry used by the functions in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced in this file; presumably
	// they memoize doGetMCPTools' result elsewhere in the package — TODO confirm.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP server name to its initialized client.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP server name to its latest recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker delivers MCPEvent notifications to SubscribeMCPEvents callers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
115
// McpTool exposes one tool of an MCP server to the agent as a tools.BaseTool;
// getTools creates one per tool returned by the server's ListTools call.
type McpTool struct {
	// mcpName is the server's key in the MCP configuration.
	mcpName string
	// tool is the tool definition as reported by the server.
	tool mcp.Tool
	// permissions is consulted before every execution (see Run).
	permissions permission.Service
	// workingDir is sent as the Path of permission requests.
	workingDir string
}
122
123func (b *McpTool) Name() string {
124 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
125}
126
127func (b *McpTool) Info() tools.ToolInfo {
128 required := b.tool.InputSchema.Required
129 if required == nil {
130 required = make([]string, 0)
131 }
132 parameters := b.tool.InputSchema.Properties
133 if parameters == nil {
134 parameters = make(map[string]any)
135 }
136 return tools.ToolInfo{
137 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
138 Description: b.tool.Description,
139 Parameters: parameters,
140 Required: required,
141 }
142}
143
144func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
145 var args map[string]any
146 if err := json.Unmarshal([]byte(input), &args); err != nil {
147 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
148 }
149
150 c, err := getOrRenewClient(ctx, name)
151 if err != nil {
152 return tools.NewTextErrorResponse(err.Error()), nil
153 }
154 result, err := c.CallTool(ctx, mcp.CallToolRequest{
155 Params: mcp.CallToolParams{
156 Name: toolName,
157 Arguments: args,
158 },
159 })
160 if err != nil {
161 return tools.NewTextErrorResponse(err.Error()), nil
162 }
163
164 output := make([]string, 0, len(result.Content))
165 for _, v := range result.Content {
166 if v, ok := v.(mcp.TextContent); ok {
167 output = append(output, v.Text)
168 } else {
169 output = append(output, fmt.Sprintf("%v", v))
170 }
171 }
172 return tools.NewTextResponse(strings.Join(output, "\n")), nil
173}
174
175func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
176 c, ok := mcpClients.Get(name)
177 if !ok {
178 return nil, fmt.Errorf("mcp '%s' not available", name)
179 }
180
181 m := config.Get().MCP[name]
182 state, _ := mcpStates.Get(name)
183
184 timeout := mcpTimeout(m)
185 pingCtx, cancel := context.WithTimeout(ctx, timeout)
186 defer cancel()
187 err := c.Ping(pingCtx)
188 if err == nil {
189 return c, nil
190 }
191 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
192
193 c, err = createAndInitializeClient(ctx, name, m)
194 if err != nil {
195 return nil, err
196 }
197
198 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
199 mcpClients.Set(name, c)
200 return c, nil
201}
202
203func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
204 sessionID, messageID := tools.GetContextValues(ctx)
205 if sessionID == "" || messageID == "" {
206 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
207 }
208 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
209 p := b.permissions.Request(
210 permission.CreatePermissionRequest{
211 SessionID: sessionID,
212 ToolCallID: params.ID,
213 Path: b.workingDir,
214 ToolName: b.Info().Name,
215 Action: "execute",
216 Description: permissionDescription,
217 Params: params.Input,
218 },
219 )
220 if !p {
221 return tools.ToolResponse{}, permission.ErrorPermissionDenied
222 }
223
224 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
225}
226
227func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
228 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
229 if err != nil {
230 slog.Error("error listing tools", "error", err)
231 updateMCPState(name, MCPStateError, err, nil, 0)
232 c.Close()
233 mcpClients.Del(name)
234 return nil
235 }
236 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
237 for _, tool := range result.Tools {
238 mcpTools = append(mcpTools, &McpTool{
239 mcpName: name,
240 tool: tool,
241 permissions: permissions,
242 workingDir: workingDir,
243 })
244 }
245 return mcpTools
246}
247
248// SubscribeMCPEvents returns a channel for MCP events
249func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
250 return mcpBroker.Subscribe(ctx)
251}
252
253// GetMCPStates returns the current state of all MCP clients
254func GetMCPStates() map[string]MCPClientInfo {
255 return maps.Collect(mcpStates.Seq2())
256}
257
258// GetMCPState returns the state of a specific MCP client
259func GetMCPState(name string) (MCPClientInfo, bool) {
260 return mcpStates.Get(name)
261}
262
263// updateMCPState updates the state of an MCP client and publishes an event
264func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
265 info := MCPClientInfo{
266 Name: name,
267 State: state,
268 Error: err,
269 Client: client,
270 ToolCount: toolCount,
271 }
272 if state == MCPStateConnected {
273 info.ConnectedAt = time.Now()
274 }
275 mcpStates.Set(name, info)
276
277 // Publish state change event
278 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
279 Type: MCPEventStateChanged,
280 Name: name,
281 State: state,
282 Error: err,
283 ToolCount: toolCount,
284 })
285}
286
287// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
288func CloseMCPClients() error {
289 var errs []error
290 for c := range mcpClients.Seq() {
291 if err := c.Close(); err != nil {
292 errs = append(errs, err)
293 }
294 }
295 mcpBroker.Shutdown()
296 return errors.Join(errs...)
297}
298
// mcpInitRequest is the initialize request sent to every MCP server by
// createAndInitializeClient. It advertises mcp-go's latest protocol version
// and identifies the client as "Crush" at the current build version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
308
309func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
310 var wg sync.WaitGroup
311 result := csync.NewSlice[tools.BaseTool]()
312
313 // Initialize states for all configured MCPs
314 for name, m := range cfg.MCP {
315 if m.Disabled {
316 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
317 slog.Debug("skipping disabled mcp", "name", name)
318 continue
319 }
320
321 // Set initial starting state
322 updateMCPState(name, MCPStateStarting, nil, nil, 0)
323
324 wg.Add(1)
325 go func(name string, m config.MCPConfig) {
326 defer func() {
327 wg.Done()
328 if r := recover(); r != nil {
329 var err error
330 switch v := r.(type) {
331 case error:
332 err = v
333 case string:
334 err = fmt.Errorf("panic: %s", v)
335 default:
336 err = fmt.Errorf("panic: %v", v)
337 }
338 updateMCPState(name, MCPStateError, err, nil, 0)
339 slog.Error("panic in mcp client initialization", "error", err, "name", name)
340 }
341 }()
342
343 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
344 defer cancel()
345 c, err := createAndInitializeClient(ctx, name, m)
346 if err != nil {
347 return
348 }
349 mcpClients.Set(name, c)
350
351 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
352 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
353 result.Append(tools...)
354 }(name, m)
355 }
356 wg.Wait()
357 return slices.Collect(result.Seq())
358}
359
360func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
361 c, err := createMcpClient(name, m)
362 if err != nil {
363 updateMCPState(name, MCPStateError, err, nil, 0)
364 slog.Error("error creating mcp client", "error", err, "name", name)
365 return nil, err
366 }
367
368 timeout := mcpTimeout(m)
369 initCtx, cancel := context.WithTimeout(ctx, timeout)
370 defer cancel()
371
372 // Only call Start() for non-stdio clients, as stdio clients auto-start
373 if m.Type != config.MCPStdio {
374 if err := c.Start(initCtx); err != nil {
375 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
376 slog.Error("error starting mcp client", "error", err, "name", name)
377 _ = c.Close()
378 return nil, err
379 }
380 }
381 if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
382 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
383 slog.Error("error initializing mcp client", "error", err, "name", name)
384 _ = c.Close()
385 return nil, err
386 }
387
388 slog.Info("Initialized mcp client", "name", name)
389 return c, nil
390}
391
// maybeTimeoutErr replaces an error that is (or wraps) context.DeadlineExceeded
// with a short "timed out after <timeout>" error. The replacement does not
// wrap the original. Any other error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
398
399func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
400 switch m.Type {
401 case config.MCPStdio:
402 if strings.TrimSpace(m.Command) == "" {
403 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
404 }
405 return client.NewStdioMCPClientWithOptions(
406 home.Long(m.Command),
407 m.ResolvedEnv(),
408 m.Args,
409 transport.WithCommandLogger(mcpLogger{name: name}),
410 )
411 case config.MCPHttp:
412 if strings.TrimSpace(m.URL) == "" {
413 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
414 }
415 return client.NewStreamableHttpClient(
416 m.URL,
417 transport.WithHTTPHeaders(m.ResolvedHeaders()),
418 transport.WithHTTPLogger(mcpLogger{name: name}),
419 )
420 case config.MCPSse:
421 if strings.TrimSpace(m.URL) == "" {
422 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
423 }
424 return client.NewSSEMCPClient(
425 m.URL,
426 client.WithHeaders(m.ResolvedHeaders()),
427 transport.WithSSELogger(mcpLogger{name: name}),
428 )
429 default:
430 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
431 }
432}
433
// mcpLogger routes log output from the mcp-go transports (it is passed to the
// transport logger options in createMcpClient) to slog, tagging each record
// with the MCP server name.
type mcpLogger struct{ name string }

// Errorf logs the formatted message at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof logs the formatted message at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
444
445func mcpTimeout(m config.MCPConfig) time.Duration {
446 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
447}