1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/home"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/version"
23 "github.com/mark3labs/mcp-go/client"
24 "github.com/mark3labs/mcp-go/client/transport"
25 "github.com/mark3labs/mcp-go/mcp"
26)
27
// MCPState enumerates the lifecycle states an MCP client can be in.
type MCPState int

// The MCPState values. The zero value is MCPStateDisabled.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// that is not one of the declared constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
52
// MCPEventType names a kind of event published about an MCP client.
type MCPEventType string

// MCPEventStateChanged is the event type used when an MCP client's state is
// updated.
const MCPEventStateChanged MCPEventType = "state_changed"
59
60// MCPEvent represents an event in the MCP system
61type MCPEvent struct {
62 Type MCPEventType
63 Name string
64 State MCPState
65 Error error
66 ToolCount int
67}
68
69// MCPClientInfo holds information about an MCP client's state
70type MCPClientInfo struct {
71 Name string
72 State MCPState
73 Error error
74 Client *client.Client
75 ToolCount int
76 ConnectedAt time.Time
77}
78
79var (
80 mcpToolsOnce sync.Once
81 mcpTools []tools.BaseTool
82 mcpClients = csync.NewMap[string, *client.Client]()
83 mcpStates = csync.NewMap[string, MCPClientInfo]()
84 mcpBroker = pubsub.NewBroker[MCPEvent]()
85)
86
87type McpTool struct {
88 mcpName string
89 tool mcp.Tool
90 permissions permission.Service
91 workingDir string
92}
93
94func (b *McpTool) Name() string {
95 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
96}
97
98func (b *McpTool) Info() tools.ToolInfo {
99 required := b.tool.InputSchema.Required
100 if required == nil {
101 required = make([]string, 0)
102 }
103 parameters := b.tool.InputSchema.Properties
104 if parameters == nil {
105 parameters = make(map[string]any)
106 }
107 return tools.ToolInfo{
108 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
109 Description: b.tool.Description,
110 Parameters: parameters,
111 Required: required,
112 }
113}
114
115func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
116 var args map[string]any
117 if err := json.Unmarshal([]byte(input), &args); err != nil {
118 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
119 }
120
121 c, err := getOrRenewClient(ctx, name)
122 if err != nil {
123 return tools.NewTextErrorResponse(err.Error()), nil
124 }
125 result, err := c.CallTool(ctx, mcp.CallToolRequest{
126 Params: mcp.CallToolParams{
127 Name: toolName,
128 Arguments: args,
129 },
130 })
131 if err != nil {
132 return tools.NewTextErrorResponse(err.Error()), nil
133 }
134
135 output := make([]string, 0, len(result.Content))
136 for _, v := range result.Content {
137 if v, ok := v.(mcp.TextContent); ok {
138 output = append(output, v.Text)
139 } else {
140 output = append(output, fmt.Sprintf("%v", v))
141 }
142 }
143 return tools.NewTextResponse(strings.Join(output, "\n")), nil
144}
145
146func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
147 c, ok := mcpClients.Get(name)
148 if !ok {
149 return nil, fmt.Errorf("mcp '%s' not available", name)
150 }
151
152 cfg := config.Get()
153 m := cfg.MCP[name]
154 state, _ := mcpStates.Get(name)
155
156 timeout := mcpTimeout(m)
157 pingCtx, cancel := context.WithTimeout(ctx, timeout)
158 defer cancel()
159 err := c.Ping(pingCtx)
160 if err == nil {
161 return c, nil
162 }
163 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
164
165 c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
166 if err != nil {
167 return nil, err
168 }
169
170 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
171 mcpClients.Set(name, c)
172 return c, nil
173}
174
175func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
176 sessionID, messageID := tools.GetContextValues(ctx)
177 if sessionID == "" || messageID == "" {
178 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
179 }
180 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
181 p := b.permissions.Request(
182 permission.CreatePermissionRequest{
183 SessionID: sessionID,
184 ToolCallID: params.ID,
185 Path: b.workingDir,
186 ToolName: b.Info().Name,
187 Action: "execute",
188 Description: permissionDescription,
189 Params: params.Input,
190 },
191 )
192 if !p {
193 return tools.ToolResponse{}, permission.ErrorPermissionDenied
194 }
195
196 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
197}
198
199func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) {
200 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
201 if err != nil {
202 return nil, err
203 }
204 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
205 for _, tool := range result.Tools {
206 mcpTools = append(mcpTools, &McpTool{
207 mcpName: name,
208 tool: tool,
209 permissions: permissions,
210 workingDir: workingDir,
211 })
212 }
213 return mcpTools, nil
214}
215
216// SubscribeMCPEvents returns a channel for MCP events
217func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
218 return mcpBroker.Subscribe(ctx)
219}
220
221// GetMCPStates returns the current state of all MCP clients
222func GetMCPStates() map[string]MCPClientInfo {
223 return maps.Collect(mcpStates.Seq2())
224}
225
226// GetMCPState returns the state of a specific MCP client
227func GetMCPState(name string) (MCPClientInfo, bool) {
228 return mcpStates.Get(name)
229}
230
231// updateMCPState updates the state of an MCP client and publishes an event
232func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
233 info := MCPClientInfo{
234 Name: name,
235 State: state,
236 Error: err,
237 Client: client,
238 ToolCount: toolCount,
239 }
240 if state == MCPStateConnected {
241 info.ConnectedAt = time.Now()
242 }
243 mcpStates.Set(name, info)
244
245 // Publish state change event
246 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
247 Type: MCPEventStateChanged,
248 Name: name,
249 State: state,
250 Error: err,
251 ToolCount: toolCount,
252 })
253}
254
255// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
256func CloseMCPClients() error {
257 var errs []error
258 for name, c := range mcpClients.Seq2() {
259 if err := c.Close(); err != nil {
260 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
261 }
262 }
263 mcpBroker.Shutdown()
264 return errors.Join(errs...)
265}
266
267var mcpInitRequest = mcp.InitializeRequest{
268 Params: mcp.InitializeParams{
269 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
270 ClientInfo: mcp.Implementation{
271 Name: "Crush",
272 Version: version.Version,
273 },
274 },
275}
276
277func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
278 var wg sync.WaitGroup
279 result := csync.NewSlice[tools.BaseTool]()
280
281 // Initialize states for all configured MCPs
282 for name, m := range cfg.MCP {
283 if m.Disabled {
284 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
285 slog.Debug("skipping disabled mcp", "name", name)
286 continue
287 }
288
289 // Set initial starting state
290 updateMCPState(name, MCPStateStarting, nil, nil, 0)
291
292 wg.Add(1)
293 go func(name string, m config.MCPConfig) {
294 defer func() {
295 wg.Done()
296 if r := recover(); r != nil {
297 var err error
298 switch v := r.(type) {
299 case error:
300 err = v
301 case string:
302 err = fmt.Errorf("panic: %s", v)
303 default:
304 err = fmt.Errorf("panic: %v", v)
305 }
306 updateMCPState(name, MCPStateError, err, nil, 0)
307 slog.Error("panic in mcp client initialization", "error", err, "name", name)
308 }
309 }()
310
311 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
312 defer cancel()
313
314 c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
315 if err != nil {
316 return
317 }
318
319 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
320 if err != nil {
321 slog.Error("error listing tools", "error", err)
322 updateMCPState(name, MCPStateError, err, nil, 0)
323 c.Close()
324 return
325 }
326
327 mcpClients.Set(name, c)
328 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
329 result.Append(tools...)
330 }(name, m)
331 }
332 wg.Wait()
333 return slices.Collect(result.Seq())
334}
335
336func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
337 c, err := createMcpClient(name, m, resolver)
338 if err != nil {
339 updateMCPState(name, MCPStateError, err, nil, 0)
340 slog.Error("error creating mcp client", "error", err, "name", name)
341 return nil, err
342 }
343
344 // XXX: ideally we should be able to use context.WithTimeout here, but,
345 // the SSE MCP client will start failing once that context is canceled.
346 timeout := mcpTimeout(m)
347 mcpCtx, cancel := context.WithCancel(ctx)
348 cancelTimer := time.AfterFunc(timeout, cancel)
349 if err := c.Start(mcpCtx); err != nil {
350 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
351 slog.Error("error starting mcp client", "error", err, "name", name)
352 _ = c.Close()
353 cancel()
354 return nil, err
355 }
356 if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
357 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
358 slog.Error("error initializing mcp client", "error", err, "name", name)
359 _ = c.Close()
360 cancel()
361 return nil, err
362 }
363 cancelTimer.Stop()
364 slog.Info("Initialized mcp client", "name", name)
365 return c, nil
366}
367
// maybeTimeoutErr turns a context-expiry error into a readable
// "timed out after <timeout>" error and returns every other error (including
// nil) unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are recognised, wrapped
// or not: the former is what a timer-driven cancel produces, the latter is
// what a context created with context.WithTimeout produces.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
374
375func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
376 switch m.Type {
377 case config.MCPStdio:
378 command, err := resolver.ResolveValue(m.Command)
379 if err != nil {
380 return nil, fmt.Errorf("invalid mcp command: %w", err)
381 }
382 if strings.TrimSpace(command) == "" {
383 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
384 }
385 return client.NewStdioMCPClientWithOptions(
386 home.Long(command),
387 m.ResolvedEnv(),
388 m.Args,
389 transport.WithCommandLogger(mcpLogger{name: name}),
390 )
391 case config.MCPHttp:
392 if strings.TrimSpace(m.URL) == "" {
393 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
394 }
395 return client.NewStreamableHttpClient(
396 m.URL,
397 transport.WithHTTPHeaders(m.ResolvedHeaders()),
398 transport.WithHTTPLogger(mcpLogger{name: name}),
399 )
400 case config.MCPSse:
401 if strings.TrimSpace(m.URL) == "" {
402 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
403 }
404 return client.NewSSEMCPClient(
405 m.URL,
406 client.WithHeaders(m.ResolvedHeaders()),
407 transport.WithSSELogger(mcpLogger{name: name}),
408 )
409 default:
410 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
411 }
412}
413
// mcpLogger is the logger handed to the MCP client transports. It forwards
// formatted messages to slog and tags every record with the MCP's name.
type mcpLogger struct{ name string }

// Errorf logs the formatted message at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, slog.String("name", l.name))
}

// Infof logs the formatted message at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, slog.String("name", l.name))
}
424
425func mcpTimeout(m config.MCPConfig) time.Duration {
426 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
427}