1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "maps"
10 "slices"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/llm/tools"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/version"
21 "github.com/mark3labs/mcp-go/client"
22 "github.com/mark3labs/mcp-go/client/transport"
23 "github.com/mark3labs/mcp-go/mcp"
24)
25
// MCPState is the lifecycle state of a single configured MCP client.
type MCPState int

// Known MCP client states. The zero value is MCPStateDisabled.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// mcpStateNames holds the display name of each known MCPState, indexed
// directly by the state value.
var mcpStateNames = [...]string{
	MCPStateDisabled:  "disabled",
	MCPStateStarting:  "starting",
	MCPStateConnected: "connected",
	MCPStateError:     "error",
}

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants (including negative values).
func (s MCPState) String() string {
	if s < 0 || int(s) >= len(mcpStateNames) {
		return "unknown"
	}
	return mcpStateNames[s]
}
50
// MCPEventType identifies the kind of event published on the MCP event broker.
type MCPEventType string

const (
	// MCPEventStateChanged is the event type updateMCPState publishes each
	// time it records an MCP client's state (the only type published in
	// this file).
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MCPEvent is the payload delivered to SubscribeMCPEvents subscribers.
type MCPEvent struct {
	Type      MCPEventType // kind of event
	Name      string       // configured MCP name (key of the MCP config map)
	State     MCPState     // state the client was set to
	Error     error        // error that accompanied the state change; nil on success
	ToolCount int          // tool count recorded for the client with this change
}
66
// MCPClientInfo is the snapshot of one MCP client's state that updateMCPState
// stores in mcpStates and that GetMCPStates / GetMCPState return.
type MCPClientInfo struct {
	Name        string         // configured MCP name
	State       MCPState       // current lifecycle state
	Error       error          // error that accompanied the last state change; nil on success
	Client      *client.Client // live client; callers in this file pass nil for every state except MCPStateConnected
	ToolCount   int            // number of tools recorded for the client
	ConnectedAt time.Time      // set to time.Now() when State is MCPStateConnected; zero otherwise
}
76
var (
	// mcpToolsOnce and mcpTools are not referenced in this part of the file;
	// presumably they memoize the tools returned by doGetMCPTools — TODO confirm.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients holds the live client for each MCP, keyed by MCP name.
	// doGetMCPTools and getOrRenewClient store clients here; getTools removes
	// a client whose tool listing failed.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates holds the latest MCPClientInfo for each configured MCP,
	// keyed by MCP name; it is written by updateMCPState.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent state changes to SubscribeMCPEvents
	// subscribers; CloseMCPClients shuts it down.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
84
85type McpTool struct {
86 mcpName string
87 tool mcp.Tool
88 permissions permission.Service
89 workingDir string
90}
91
92func (b *McpTool) Name() string {
93 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
94}
95
96func (b *McpTool) Info() tools.ToolInfo {
97 required := b.tool.InputSchema.Required
98 if required == nil {
99 required = make([]string, 0)
100 }
101 parameters := b.tool.InputSchema.Properties
102 if parameters == nil {
103 parameters = make(map[string]any)
104 }
105 return tools.ToolInfo{
106 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
107 Description: b.tool.Description,
108 Parameters: parameters,
109 Required: required,
110 }
111}
112
113func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
114 var args map[string]any
115 if err := json.Unmarshal([]byte(input), &args); err != nil {
116 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
117 }
118
119 c, err := getOrRenewClient(ctx, name)
120 if err != nil {
121 return tools.NewTextErrorResponse(err.Error()), nil
122 }
123 result, err := c.CallTool(ctx, mcp.CallToolRequest{
124 Params: mcp.CallToolParams{
125 Name: toolName,
126 Arguments: args,
127 },
128 })
129 if err != nil {
130 return tools.NewTextErrorResponse(err.Error()), nil
131 }
132
133 output := make([]string, 0, len(result.Content))
134 for _, v := range result.Content {
135 if v, ok := v.(mcp.TextContent); ok {
136 output = append(output, v.Text)
137 } else {
138 output = append(output, fmt.Sprintf("%v", v))
139 }
140 }
141 return tools.NewTextResponse(strings.Join(output, "\n")), nil
142}
143
144func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
145 c, ok := mcpClients.Get(name)
146 if !ok {
147 return nil, fmt.Errorf("mcp '%s' not available", name)
148 }
149
150 m := config.Get().MCP[name]
151 state, _ := mcpStates.Get(name)
152
153 pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
154 defer cancel()
155 err := c.Ping(pingCtx)
156 if err == nil {
157 return c, nil
158 }
159 updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
160
161 c, err = createAndInitializeClient(ctx, name, m)
162 if err != nil {
163 return nil, err
164 }
165
166 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
167 mcpClients.Set(name, c)
168 return c, nil
169}
170
171func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
172 sessionID, messageID := tools.GetContextValues(ctx)
173 if sessionID == "" || messageID == "" {
174 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
175 }
176 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
177 p := b.permissions.Request(
178 permission.CreatePermissionRequest{
179 SessionID: sessionID,
180 ToolCallID: params.ID,
181 Path: b.workingDir,
182 ToolName: b.Info().Name,
183 Action: "execute",
184 Description: permissionDescription,
185 Params: params.Input,
186 },
187 )
188 if !p {
189 return tools.ToolResponse{}, permission.ErrorPermissionDenied
190 }
191
192 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
193}
194
195func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
196 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
197 if err != nil {
198 slog.Error("error listing tools", "error", err)
199 updateMCPState(name, MCPStateError, err, nil, 0)
200 c.Close()
201 mcpClients.Del(name)
202 return nil
203 }
204 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
205 for _, tool := range result.Tools {
206 mcpTools = append(mcpTools, &McpTool{
207 mcpName: name,
208 tool: tool,
209 permissions: permissions,
210 workingDir: workingDir,
211 })
212 }
213 return mcpTools
214}
215
216// SubscribeMCPEvents returns a channel for MCP events
217func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
218 return mcpBroker.Subscribe(ctx)
219}
220
221// GetMCPStates returns the current state of all MCP clients
222func GetMCPStates() map[string]MCPClientInfo {
223 return maps.Collect(mcpStates.Seq2())
224}
225
226// GetMCPState returns the state of a specific MCP client
227func GetMCPState(name string) (MCPClientInfo, bool) {
228 return mcpStates.Get(name)
229}
230
231// updateMCPState updates the state of an MCP client and publishes an event
232func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
233 info := MCPClientInfo{
234 Name: name,
235 State: state,
236 Error: err,
237 Client: client,
238 ToolCount: toolCount,
239 }
240 if state == MCPStateConnected {
241 info.ConnectedAt = time.Now()
242 }
243 mcpStates.Set(name, info)
244
245 // Publish state change event
246 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
247 Type: MCPEventStateChanged,
248 Name: name,
249 State: state,
250 Error: err,
251 ToolCount: toolCount,
252 })
253}
254
255// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
256func CloseMCPClients() {
257 for c := range mcpClients.Seq() {
258 _ = c.Close()
259 }
260 mcpBroker.Shutdown()
261}
262
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It advertises mcp-go's LATEST_PROTOCOL_VERSION
// and identifies this client as "Crush" at version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
272
273func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
274 var wg sync.WaitGroup
275 result := csync.NewSlice[tools.BaseTool]()
276
277 // Initialize states for all configured MCPs
278 for name, m := range cfg.MCP {
279 if m.Disabled {
280 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
281 slog.Debug("skipping disabled mcp", "name", name)
282 continue
283 }
284
285 // Set initial starting state
286 updateMCPState(name, MCPStateStarting, nil, nil, 0)
287
288 wg.Add(1)
289 go func(name string, m config.MCPConfig) {
290 defer func() {
291 wg.Done()
292 if r := recover(); r != nil {
293 var err error
294 switch v := r.(type) {
295 case error:
296 err = v
297 case string:
298 err = fmt.Errorf("panic: %s", v)
299 default:
300 err = fmt.Errorf("panic: %v", v)
301 }
302 updateMCPState(name, MCPStateError, err, nil, 0)
303 slog.Error("panic in mcp client initialization", "error", err, "name", name)
304 }
305 }()
306
307 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
308 defer cancel()
309 c, err := createAndInitializeClient(ctx, name, m)
310 if err != nil {
311 return
312 }
313 mcpClients.Set(name, c)
314
315 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
316 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
317 result.Append(tools...)
318 }(name, m)
319 }
320 wg.Wait()
321 return slices.Collect(result.Seq())
322}
323
324func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
325 c, err := createMcpClient(m)
326 if err != nil {
327 updateMCPState(name, MCPStateError, err, nil, 0)
328 slog.Error("error creating mcp client", "error", err, "name", name)
329 return nil, err
330 }
331 // Only call Start() for non-stdio clients, as stdio clients auto-start
332 if m.Type != config.MCPStdio {
333 if err := c.Start(ctx); err != nil {
334 updateMCPState(name, MCPStateError, err, nil, 0)
335 slog.Error("error starting mcp client", "error", err, "name", name)
336 _ = c.Close()
337 return nil, err
338 }
339 }
340 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
341 updateMCPState(name, MCPStateError, err, nil, 0)
342 slog.Error("error initializing mcp client", "error", err, "name", name)
343 _ = c.Close()
344 return nil, err
345 }
346
347 slog.Info("Initialized mcp client", "name", name)
348 return c, nil
349}
350
351func createMcpClient(m config.MCPConfig) (*client.Client, error) {
352 switch m.Type {
353 case config.MCPStdio:
354 if strings.TrimSpace(m.Command) == "" {
355 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
356 }
357 return client.NewStdioMCPClientWithOptions(
358 m.Command,
359 m.ResolvedEnv(),
360 m.Args,
361 transport.WithCommandLogger(mcpLogger{}),
362 )
363 case config.MCPHttp:
364 if strings.TrimSpace(m.URL) == "" {
365 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
366 }
367 return client.NewStreamableHttpClient(
368 m.URL,
369 transport.WithHTTPHeaders(m.ResolvedHeaders()),
370 transport.WithHTTPLogger(mcpLogger{}),
371 )
372 case config.MCPSse:
373 if strings.TrimSpace(m.URL) == "" {
374 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
375 }
376 return client.NewSSEMCPClient(
377 m.URL,
378 client.WithHeaders(m.ResolvedHeaders()),
379 transport.WithSSELogger(mcpLogger{}),
380 )
381 default:
382 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
383 }
384}
385
// mcpLogger adapts slog to the printf-style logger that the mcp-go transports
// accept (see the With*Logger options in createMcpClient).
type mcpLogger struct{}

// Errorf formats its arguments and logs the result at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats its arguments and logs the result at info level.
func (mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
391
392func mcpTimeout(m config.MCPConfig) time.Duration {
393 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
394}