1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "maps"
10 "slices"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/llm/tools"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/version"
21 "github.com/mark3labs/mcp-go/client"
22 "github.com/mark3labs/mcp-go/client/transport"
23 "github.com/mark3labs/mcp-go/mcp"
24)
25
// MCPState describes where an MCP client is in its lifecycle.
type MCPState int

// Possible MCPState values. The zero value is MCPStateDisabled.
const (
	// MCPStateDisabled marks a server whose configuration has Disabled set.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting marks a server whose connection attempt is in progress.
	MCPStateStarting
	// MCPStateConnected marks a server whose client initialized successfully.
	MCPStateConnected
	// MCPStateError marks a server whose most recent operation failed.
	MCPStateError
)

// String returns the lowercase name of s, or "unknown" for any value outside
// the defined MCPState constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
50
51// MCPEventType represents the type of MCP event
52type MCPEventType string
53
54const (
55 MCPEventStateChanged MCPEventType = "state_changed"
56)
57
58// MCPEvent represents an event in the MCP system
59type MCPEvent struct {
60 Type MCPEventType
61 Name string
62 State MCPState
63 Error error
64 ToolCount int
65}
66
67// MCPClientInfo holds information about an MCP client's state
68type MCPClientInfo struct {
69 Name string
70 State MCPState
71 Error error
72 Client *client.Client
73 ToolCount int
74 ConnectedAt time.Time
75}
76
77var (
78 mcpToolsOnce sync.Once
79 mcpTools []tools.BaseTool
80 mcpClients = csync.NewMap[string, *client.Client]()
81 mcpStates = csync.NewMap[string, MCPClientInfo]()
82 mcpBroker = pubsub.NewBroker[MCPEvent]()
83)
84
85type McpTool struct {
86 mcpName string
87 tool mcp.Tool
88 permissions permission.Service
89 workingDir string
90}
91
92func (b *McpTool) Name() string {
93 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
94}
95
96func (b *McpTool) Info() tools.ToolInfo {
97 required := b.tool.InputSchema.Required
98 if required == nil {
99 required = make([]string, 0)
100 }
101 return tools.ToolInfo{
102 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
103 Description: b.tool.Description,
104 Parameters: b.tool.InputSchema.Properties,
105 Required: required,
106 }
107}
108
109func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
110 var args map[string]any
111 if err := json.Unmarshal([]byte(input), &args); err != nil {
112 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
113 }
114
115 c, err := getOrRenewClient(ctx, name)
116 if err != nil {
117 return tools.NewTextErrorResponse(err.Error()), nil
118 }
119 result, err := c.CallTool(ctx, mcp.CallToolRequest{
120 Params: mcp.CallToolParams{
121 Name: toolName,
122 Arguments: args,
123 },
124 })
125 if err != nil {
126 return tools.NewTextErrorResponse(err.Error()), nil
127 }
128
129 output := make([]string, 0, len(result.Content))
130 for _, v := range result.Content {
131 if v, ok := v.(mcp.TextContent); ok {
132 output = append(output, v.Text)
133 } else {
134 output = append(output, fmt.Sprintf("%v", v))
135 }
136 }
137 return tools.NewTextResponse(strings.Join(output, "\n")), nil
138}
139
140func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
141 c, ok := mcpClients.Get(name)
142 if !ok {
143 return nil, fmt.Errorf("mcp '%s' not available", name)
144 }
145
146 m := config.Get().MCP[name]
147 state, _ := mcpStates.Get(name)
148
149 pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
150 defer cancel()
151 err := c.Ping(pingCtx)
152 if err == nil {
153 return c, nil
154 }
155 updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
156
157 c, err = createAndInitializeClient(ctx, name, m)
158 if err != nil {
159 return nil, err
160 }
161
162 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
163 mcpClients.Set(name, c)
164 return c, nil
165}
166
167func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
168 sessionID, messageID := tools.GetContextValues(ctx)
169 if sessionID == "" || messageID == "" {
170 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
171 }
172 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
173 p := b.permissions.Request(
174 permission.CreatePermissionRequest{
175 SessionID: sessionID,
176 ToolCallID: params.ID,
177 Path: b.workingDir,
178 ToolName: b.Info().Name,
179 Action: "execute",
180 Description: permissionDescription,
181 Params: params.Input,
182 },
183 )
184 if !p {
185 return tools.ToolResponse{}, permission.ErrorPermissionDenied
186 }
187
188 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
189}
190
191func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
192 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
193 if err != nil {
194 slog.Error("error listing tools", "error", err)
195 updateMCPState(name, MCPStateError, err, nil, 0)
196 c.Close()
197 mcpClients.Del(name)
198 return nil
199 }
200 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
201 for _, tool := range result.Tools {
202 mcpTools = append(mcpTools, &McpTool{
203 mcpName: name,
204 tool: tool,
205 permissions: permissions,
206 workingDir: workingDir,
207 })
208 }
209 return mcpTools
210}
211
212// SubscribeMCPEvents returns a channel for MCP events
213func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
214 return mcpBroker.Subscribe(ctx)
215}
216
217// GetMCPStates returns the current state of all MCP clients
218func GetMCPStates() map[string]MCPClientInfo {
219 return maps.Collect(mcpStates.Seq2())
220}
221
222// GetMCPState returns the state of a specific MCP client
223func GetMCPState(name string) (MCPClientInfo, bool) {
224 return mcpStates.Get(name)
225}
226
227// updateMCPState updates the state of an MCP client and publishes an event
228func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
229 info := MCPClientInfo{
230 Name: name,
231 State: state,
232 Error: err,
233 Client: client,
234 ToolCount: toolCount,
235 }
236 if state == MCPStateConnected {
237 info.ConnectedAt = time.Now()
238 }
239 mcpStates.Set(name, info)
240
241 // Publish state change event
242 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
243 Type: MCPEventStateChanged,
244 Name: name,
245 State: state,
246 Error: err,
247 ToolCount: toolCount,
248 })
249}
250
251// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
252func CloseMCPClients() {
253 for c := range mcpClients.Seq() {
254 _ = c.Close()
255 }
256 mcpBroker.Shutdown()
257}
258
259var mcpInitRequest = mcp.InitializeRequest{
260 Params: mcp.InitializeParams{
261 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
262 ClientInfo: mcp.Implementation{
263 Name: "Crush",
264 Version: version.Version,
265 },
266 },
267}
268
269func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
270 var wg sync.WaitGroup
271 result := csync.NewSlice[tools.BaseTool]()
272
273 // Initialize states for all configured MCPs
274 for name, m := range cfg.MCP {
275 if m.Disabled {
276 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
277 slog.Debug("skipping disabled mcp", "name", name)
278 continue
279 }
280
281 // Set initial starting state
282 updateMCPState(name, MCPStateStarting, nil, nil, 0)
283
284 wg.Add(1)
285 go func(name string, m config.MCPConfig) {
286 defer func() {
287 wg.Done()
288 if r := recover(); r != nil {
289 var err error
290 switch v := r.(type) {
291 case error:
292 err = v
293 case string:
294 err = fmt.Errorf("panic: %s", v)
295 default:
296 err = fmt.Errorf("panic: %v", v)
297 }
298 updateMCPState(name, MCPStateError, err, nil, 0)
299 slog.Error("panic in mcp client initialization", "error", err, "name", name)
300 }
301 }()
302
303 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
304 defer cancel()
305 c, err := createAndInitializeClient(ctx, name, m)
306 if err != nil {
307 return
308 }
309 mcpClients.Set(name, c)
310
311 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
312 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
313 result.Append(tools...)
314 }(name, m)
315 }
316 wg.Wait()
317 return slices.Collect(result.Seq())
318}
319
320func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
321 c, err := createMcpClient(m)
322 if err != nil {
323 updateMCPState(name, MCPStateError, err, nil, 0)
324 slog.Error("error creating mcp client", "error", err, "name", name)
325 return nil, err
326 }
327 // Only call Start() for non-stdio clients, as stdio clients auto-start
328 if m.Type != config.MCPStdio {
329 if err := c.Start(ctx); err != nil {
330 updateMCPState(name, MCPStateError, err, nil, 0)
331 slog.Error("error starting mcp client", "error", err, "name", name)
332 _ = c.Close()
333 return nil, err
334 }
335 }
336 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
337 updateMCPState(name, MCPStateError, err, nil, 0)
338 slog.Error("error initializing mcp client", "error", err, "name", name)
339 _ = c.Close()
340 return nil, err
341 }
342
343 slog.Info("Initialized mcp client", "name", name)
344 return c, nil
345}
346
347func createMcpClient(m config.MCPConfig) (*client.Client, error) {
348 switch m.Type {
349 case config.MCPStdio:
350 return client.NewStdioMCPClientWithOptions(
351 m.Command,
352 m.ResolvedEnv(),
353 m.Args,
354 transport.WithCommandLogger(mcpLogger{}),
355 )
356 case config.MCPHttp:
357 return client.NewStreamableHttpClient(
358 m.URL,
359 transport.WithHTTPHeaders(m.ResolvedHeaders()),
360 transport.WithHTTPLogger(mcpLogger{}),
361 )
362 case config.MCPSse:
363 return client.NewSSEMCPClient(
364 m.URL,
365 client.WithHeaders(m.ResolvedHeaders()),
366 transport.WithSSELogger(mcpLogger{}),
367 )
368 default:
369 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
370 }
371}
372
// mcpLogger is the logger handed to the mcp-go transports. It formats each
// message printf-style and forwards it to the default slog logger.
type mcpLogger struct{}

// Errorf logs the formatted message at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof logs the formatted message at info level.
func (mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
378
379func mcpTimeout(m config.MCPConfig) time.Duration {
380 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
381}