1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/home"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/proto"
22 "github.com/charmbracelet/crush/internal/pubsub"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/mark3labs/mcp-go/client"
25 "github.com/mark3labs/mcp-go/client/transport"
26 "github.com/mark3labs/mcp-go/mcp"
27)
28
29type (
30 MCPState = proto.MCPState
31 MCPEventType = proto.MCPEventType
32 MCPEvent = proto.MCPEvent
33)
34
35const (
36 MCPStateDisabled = proto.MCPStateDisabled
37 MCPStateStarting = proto.MCPStateStarting
38 MCPStateConnected = proto.MCPStateConnected
39 MCPStateError = proto.MCPStateError
40
41 MCPEventStateChanged = proto.MCPEventStateChanged
42)
43
44// MCPClientInfo holds information about an MCP client's state
45type MCPClientInfo struct {
46 Name string
47 State MCPState
48 Error error
49 Client *client.Client
50 ToolCount int
51 ConnectedAt time.Time
52}
53
54var (
55 mcpToolsOnce sync.Once
56 mcpTools []tools.BaseTool
57 mcpClients = csync.NewMap[string, *client.Client]()
58 mcpStates = csync.NewMap[string, MCPClientInfo]()
59 mcpBroker = pubsub.NewBroker[MCPEvent]()
60)
61
62type McpTool struct {
63 mcpName string
64 tool mcp.Tool
65 permissions permission.Service
66 cfg *config.Config
67}
68
69func (b *McpTool) Name() string {
70 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
71}
72
73func (b *McpTool) Info() tools.ToolInfo {
74 required := b.tool.InputSchema.Required
75 if required == nil {
76 required = make([]string, 0)
77 }
78 parameters := b.tool.InputSchema.Properties
79 if parameters == nil {
80 parameters = make(map[string]any)
81 }
82 return tools.ToolInfo{
83 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
84 Description: b.tool.Description,
85 Parameters: parameters,
86 Required: required,
87 }
88}
89
90func runTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (tools.ToolResponse, error) {
91 var args map[string]any
92 if err := json.Unmarshal([]byte(input), &args); err != nil {
93 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
94 }
95
96 c, err := getOrRenewClient(ctx, cfg, name)
97 if err != nil {
98 return tools.NewTextErrorResponse(err.Error()), nil
99 }
100 result, err := c.CallTool(ctx, mcp.CallToolRequest{
101 Params: mcp.CallToolParams{
102 Name: toolName,
103 Arguments: args,
104 },
105 })
106 if err != nil {
107 return tools.NewTextErrorResponse(err.Error()), nil
108 }
109
110 output := make([]string, 0, len(result.Content))
111 for _, v := range result.Content {
112 if v, ok := v.(mcp.TextContent); ok {
113 output = append(output, v.Text)
114 } else {
115 output = append(output, fmt.Sprintf("%v", v))
116 }
117 }
118 return tools.NewTextResponse(strings.Join(output, "\n")), nil
119}
120
121func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*client.Client, error) {
122 c, ok := mcpClients.Get(name)
123 if !ok {
124 return nil, fmt.Errorf("mcp '%s' not available", name)
125 }
126
127 m := cfg.MCP[name]
128 state, _ := mcpStates.Get(name)
129
130 timeout := mcpTimeout(m)
131 pingCtx, cancel := context.WithTimeout(ctx, timeout)
132 defer cancel()
133 err := c.Ping(pingCtx)
134 if err == nil {
135 return c, nil
136 }
137 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
138
139 c, err = createAndInitializeClient(ctx, name, m)
140 if err != nil {
141 return nil, err
142 }
143
144 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
145 mcpClients.Set(name, c)
146 return c, nil
147}
148
149func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
150 sessionID, messageID := tools.GetContextValues(ctx)
151 if sessionID == "" || messageID == "" {
152 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
153 }
154 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
155 p := b.permissions.Request(
156 permission.CreatePermissionRequest{
157 SessionID: sessionID,
158 ToolCallID: params.ID,
159 Path: b.cfg.WorkingDir(),
160 ToolName: b.Info().Name,
161 Action: "execute",
162 Description: permissionDescription,
163 Params: params.Input,
164 },
165 )
166 if !p {
167 return tools.ToolResponse{}, permission.ErrorPermissionDenied
168 }
169
170 return runTool(ctx, b.cfg, b.mcpName, b.tool.Name, params.Input)
171}
172
173func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, cfg *config.Config) []tools.BaseTool {
174 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
175 if err != nil {
176 slog.Error("error listing tools", "error", err)
177 updateMCPState(name, MCPStateError, err, nil, 0)
178 c.Close()
179 mcpClients.Del(name)
180 return nil
181 }
182 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
183 for _, tool := range result.Tools {
184 mcpTools = append(mcpTools, &McpTool{
185 mcpName: name,
186 tool: tool,
187 permissions: permissions,
188 cfg: cfg,
189 })
190 }
191 return mcpTools
192}
193
194// SubscribeMCPEvents returns a channel for MCP events
195func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
196 return mcpBroker.Subscribe(ctx)
197}
198
199// GetMCPStates returns the current state of all MCP clients
200func GetMCPStates() map[string]MCPClientInfo {
201 return maps.Collect(mcpStates.Seq2())
202}
203
204// GetMCPState returns the state of a specific MCP client
205func GetMCPState(name string) (MCPClientInfo, bool) {
206 return mcpStates.Get(name)
207}
208
209// updateMCPState updates the state of an MCP client and publishes an event
210func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
211 info := MCPClientInfo{
212 Name: name,
213 State: state,
214 Error: err,
215 Client: client,
216 ToolCount: toolCount,
217 }
218 if state == MCPStateConnected {
219 info.ConnectedAt = time.Now()
220 }
221 mcpStates.Set(name, info)
222
223 // Publish state change event
224 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
225 Type: MCPEventStateChanged,
226 Name: name,
227 State: state,
228 Error: err,
229 ToolCount: toolCount,
230 })
231}
232
233// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
234func CloseMCPClients() error {
235 var errs []error
236 for c := range mcpClients.Seq() {
237 if err := c.Close(); err != nil {
238 errs = append(errs, err)
239 }
240 }
241 mcpBroker.Shutdown()
242 return errors.Join(errs...)
243}
244
245var mcpInitRequest = mcp.InitializeRequest{
246 Params: mcp.InitializeParams{
247 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
248 ClientInfo: mcp.Implementation{
249 Name: "Crush",
250 Version: version.Version,
251 },
252 },
253}
254
255func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
256 var wg sync.WaitGroup
257 result := csync.NewSlice[tools.BaseTool]()
258
259 // Initialize states for all configured MCPs
260 for name, m := range cfg.MCP {
261 if m.Disabled {
262 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
263 slog.Debug("skipping disabled mcp", "name", name)
264 continue
265 }
266
267 // Set initial starting state
268 updateMCPState(name, MCPStateStarting, nil, nil, 0)
269
270 wg.Add(1)
271 go func(name string, m config.MCPConfig) {
272 defer func() {
273 wg.Done()
274 if r := recover(); r != nil {
275 var err error
276 switch v := r.(type) {
277 case error:
278 err = v
279 case string:
280 err = fmt.Errorf("panic: %s", v)
281 default:
282 err = fmt.Errorf("panic: %v", v)
283 }
284 updateMCPState(name, MCPStateError, err, nil, 0)
285 slog.Error("panic in mcp client initialization", "error", err, "name", name)
286 }
287 }()
288
289 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
290 defer cancel()
291 c, err := createAndInitializeClient(ctx, name, m)
292 if err != nil {
293 return
294 }
295 mcpClients.Set(name, c)
296
297 tools := getTools(ctx, name, permissions, c, cfg)
298 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
299 result.Append(tools...)
300 }(name, m)
301 }
302 wg.Wait()
303 return slices.Collect(result.Seq())
304}
305
306func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
307 c, err := createMcpClient(name, m)
308 if err != nil {
309 updateMCPState(name, MCPStateError, err, nil, 0)
310 slog.Error("error creating mcp client", "error", err, "name", name)
311 return nil, err
312 }
313
314 timeout := mcpTimeout(m)
315 initCtx, cancel := context.WithTimeout(ctx, timeout)
316 defer cancel()
317
318 // Only call Start() for non-stdio clients, as stdio clients auto-start
319 if m.Type != config.MCPStdio {
320 if err := c.Start(initCtx); err != nil {
321 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
322 slog.Error("error starting mcp client", "error", err, "name", name)
323 _ = c.Close()
324 return nil, err
325 }
326 }
327 if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
328 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
329 slog.Error("error initializing mcp client", "error", err, "name", name)
330 _ = c.Close()
331 return nil, err
332 }
333
334 slog.Info("Initialized mcp client", "name", name)
335 return c, nil
336}
337
// mcpTimeoutError reports that an MCP operation exceeded its configured
// timeout. Its message names only the timeout. It unwraps to the original
// error, so errors.Is(err, context.DeadlineExceeded) keeps working.
type mcpTimeoutError struct {
	timeout time.Duration
	err     error
}

func (e *mcpTimeoutError) Error() string {
	return fmt.Sprintf("timed out after %s", e.timeout)
}

func (e *mcpTimeoutError) Unwrap() error { return e.err }

// maybeTimeoutErr replaces err with a friendlier "timed out after <timeout>"
// error when err is, or wraps, context.DeadlineExceeded. The original error
// stays reachable through errors.Is/errors.As. Any other error, including
// nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.DeadlineExceeded) {
		return &mcpTimeoutError{timeout: timeout, err: err}
	}
	return err
}
344
345func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
346 switch m.Type {
347 case config.MCPStdio:
348 if strings.TrimSpace(m.Command) == "" {
349 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
350 }
351 return client.NewStdioMCPClientWithOptions(
352 home.Long(m.Command),
353 m.ResolvedEnv(),
354 m.Args,
355 transport.WithCommandLogger(mcpLogger{name: name}),
356 )
357 case config.MCPHttp:
358 if strings.TrimSpace(m.URL) == "" {
359 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
360 }
361 return client.NewStreamableHttpClient(
362 m.URL,
363 transport.WithHTTPHeaders(m.ResolvedHeaders()),
364 transport.WithHTTPLogger(mcpLogger{name: name}),
365 )
366 case config.MCPSse:
367 if strings.TrimSpace(m.URL) == "" {
368 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
369 }
370 return client.NewSSEMCPClient(
371 m.URL,
372 client.WithHeaders(m.ResolvedHeaders()),
373 transport.WithSSELogger(mcpLogger{name: name}),
374 )
375 default:
376 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
377 }
378}
379
// mcpLogger routes the printf-style log calls of mcp-go transports to slog.
// Every record carries a "name" attribute with the MCP server's name.
type mcpLogger struct{ name string }

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
390
391func mcpTimeout(m config.MCPConfig) time.Duration {
392 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
393}