1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/pubsub"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/mark3labs/mcp-go/client"
23 "github.com/mark3labs/mcp-go/client/transport"
24 "github.com/mark3labs/mcp-go/mcp"
25)
26
// MCPState is the lifecycle stage of a configured MCP client.
type MCPState int

// Lifecycle stages. The numeric values are stable and start at zero.
const (
	MCPStateDisabled  MCPState = iota // configured but turned off
	MCPStateStarting                  // client creation/handshake in progress
	MCPStateConnected                 // handshake finished, client usable
	MCPStateError                     // creation, handshake, ping or listing failed
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
51
// MCPEventType labels the kind of change an MCPEvent carries.
type MCPEventType string

// MCPEventStateChanged is the event type published each time an MCP client's
// recorded state is updated.
const MCPEventStateChanged MCPEventType = "state_changed"
58
59// MCPEvent represents an event in the MCP system
60type MCPEvent struct {
61 Type MCPEventType
62 Name string
63 State MCPState
64 Error error
65 ToolCount int
66}
67
68// MCPClientInfo holds information about an MCP client's state
69type MCPClientInfo struct {
70 Name string
71 State MCPState
72 Error error
73 Client *client.Client
74 ToolCount int
75 ConnectedAt time.Time
76}
77
78var (
79 mcpToolsOnce sync.Once
80 mcpTools []tools.BaseTool
81 mcpClients = csync.NewMap[string, *client.Client]()
82 mcpStates = csync.NewMap[string, MCPClientInfo]()
83 mcpBroker = pubsub.NewBroker[MCPEvent]()
84)
85
86type McpTool struct {
87 mcpName string
88 tool mcp.Tool
89 permissions permission.Service
90 workingDir string
91}
92
93func (b *McpTool) Name() string {
94 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
95}
96
97func (b *McpTool) Info() tools.ToolInfo {
98 required := b.tool.InputSchema.Required
99 if required == nil {
100 required = make([]string, 0)
101 }
102 parameters := b.tool.InputSchema.Properties
103 if parameters == nil {
104 parameters = make(map[string]any)
105 }
106 return tools.ToolInfo{
107 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
108 Description: b.tool.Description,
109 Parameters: parameters,
110 Required: required,
111 }
112}
113
114func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
115 var args map[string]any
116 if err := json.Unmarshal([]byte(input), &args); err != nil {
117 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
118 }
119
120 c, err := getOrRenewClient(ctx, name)
121 if err != nil {
122 return tools.NewTextErrorResponse(err.Error()), nil
123 }
124 result, err := c.CallTool(ctx, mcp.CallToolRequest{
125 Params: mcp.CallToolParams{
126 Name: toolName,
127 Arguments: args,
128 },
129 })
130 if err != nil {
131 return tools.NewTextErrorResponse(err.Error()), nil
132 }
133
134 output := make([]string, 0, len(result.Content))
135 for _, v := range result.Content {
136 if v, ok := v.(mcp.TextContent); ok {
137 output = append(output, v.Text)
138 } else {
139 output = append(output, fmt.Sprintf("%v", v))
140 }
141 }
142 return tools.NewTextResponse(strings.Join(output, "\n")), nil
143}
144
145func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
146 c, ok := mcpClients.Get(name)
147 if !ok {
148 return nil, fmt.Errorf("mcp '%s' not available", name)
149 }
150
151 m := config.Get().MCP[name]
152 state, _ := mcpStates.Get(name)
153
154 pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
155 defer cancel()
156 err := c.Ping(pingCtx)
157 if err == nil {
158 return c, nil
159 }
160 updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
161
162 c, err = createAndInitializeClient(ctx, name, m)
163 if err != nil {
164 return nil, err
165 }
166
167 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
168 mcpClients.Set(name, c)
169 return c, nil
170}
171
172func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
173 sessionID, messageID := tools.GetContextValues(ctx)
174 if sessionID == "" || messageID == "" {
175 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
176 }
177 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
178 p := b.permissions.Request(
179 permission.CreatePermissionRequest{
180 SessionID: sessionID,
181 ToolCallID: params.ID,
182 Path: b.workingDir,
183 ToolName: b.Info().Name,
184 Action: "execute",
185 Description: permissionDescription,
186 Params: params.Input,
187 },
188 )
189 if !p {
190 return tools.ToolResponse{}, permission.ErrorPermissionDenied
191 }
192
193 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
194}
195
196func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
197 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
198 if err != nil {
199 slog.Error("error listing tools", "error", err)
200 updateMCPState(name, MCPStateError, err, nil, 0)
201 c.Close()
202 mcpClients.Del(name)
203 return nil
204 }
205 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
206 for _, tool := range result.Tools {
207 mcpTools = append(mcpTools, &McpTool{
208 mcpName: name,
209 tool: tool,
210 permissions: permissions,
211 workingDir: workingDir,
212 })
213 }
214 return mcpTools
215}
216
217// SubscribeMCPEvents returns a channel for MCP events
218func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
219 return mcpBroker.Subscribe(ctx)
220}
221
222// GetMCPStates returns the current state of all MCP clients
223func GetMCPStates() map[string]MCPClientInfo {
224 return maps.Collect(mcpStates.Seq2())
225}
226
227// GetMCPState returns the state of a specific MCP client
228func GetMCPState(name string) (MCPClientInfo, bool) {
229 return mcpStates.Get(name)
230}
231
232// updateMCPState updates the state of an MCP client and publishes an event
233func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
234 info := MCPClientInfo{
235 Name: name,
236 State: state,
237 Error: err,
238 Client: client,
239 ToolCount: toolCount,
240 }
241 if state == MCPStateConnected {
242 info.ConnectedAt = time.Now()
243 }
244 mcpStates.Set(name, info)
245
246 // Publish state change event
247 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
248 Type: MCPEventStateChanged,
249 Name: name,
250 State: state,
251 Error: err,
252 ToolCount: toolCount,
253 })
254}
255
256// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
257func CloseMCPClients() error {
258 var errs []error
259 for c := range mcpClients.Seq() {
260 if err := c.Close(); err != nil {
261 errs = append(errs, err)
262 }
263 }
264 mcpBroker.Shutdown()
265 return errors.Join(errs...)
266}
267
268var mcpInitRequest = mcp.InitializeRequest{
269 Params: mcp.InitializeParams{
270 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
271 ClientInfo: mcp.Implementation{
272 Name: "Crush",
273 Version: version.Version,
274 },
275 },
276}
277
278func doGetMCPTools(permissions permission.Service, cfg *config.Config) []tools.BaseTool {
279 var wg sync.WaitGroup
280 result := csync.NewSlice[tools.BaseTool]()
281
282 // Initialize states for all configured MCPs
283 for name, m := range cfg.MCP {
284 if m.Disabled {
285 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
286 slog.Debug("skipping disabled mcp", "name", name)
287 continue
288 }
289
290 // Set initial starting state
291 updateMCPState(name, MCPStateStarting, nil, nil, 0)
292
293 wg.Add(1)
294 go func(name string, m config.MCPConfig) {
295 defer func() {
296 wg.Done()
297 if r := recover(); r != nil {
298 var err error
299 switch v := r.(type) {
300 case error:
301 err = v
302 case string:
303 err = fmt.Errorf("panic: %s", v)
304 default:
305 err = fmt.Errorf("panic: %v", v)
306 }
307 updateMCPState(name, MCPStateError, err, nil, 0)
308 slog.Error("panic in mcp client initialization", "error", err, "name", name)
309 }
310 }()
311
312 ctx, cancel := context.WithTimeout(context.Background(), mcpTimeout(m))
313 defer cancel()
314 c, err := createAndInitializeClient(ctx, name, m)
315 if err != nil {
316 return
317 }
318 mcpClients.Set(name, c)
319
320 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
321 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
322 result.Append(tools...)
323 }(name, m)
324 }
325 wg.Wait()
326 return slices.Collect(result.Seq())
327}
328
329func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
330 c, err := createMcpClient(m)
331 if err != nil {
332 updateMCPState(name, MCPStateError, err, nil, 0)
333 slog.Error("error creating mcp client", "error", err, "name", name)
334 return nil, err
335 }
336 // Only call Start() for non-stdio clients, as stdio clients auto-start
337 if m.Type != config.MCPStdio {
338 if err := c.Start(ctx); err != nil {
339 updateMCPState(name, MCPStateError, err, nil, 0)
340 slog.Error("error starting mcp client", "error", err, "name", name)
341 _ = c.Close()
342 return nil, err
343 }
344 }
345 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
346 updateMCPState(name, MCPStateError, err, nil, 0)
347 slog.Error("error initializing mcp client", "error", err, "name", name)
348 _ = c.Close()
349 return nil, err
350 }
351
352 slog.Info("Initialized mcp client", "name", name)
353 return c, nil
354}
355
356func createMcpClient(m config.MCPConfig) (*client.Client, error) {
357 switch m.Type {
358 case config.MCPStdio:
359 if strings.TrimSpace(m.Command) == "" {
360 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
361 }
362 return client.NewStdioMCPClientWithOptions(
363 m.Command,
364 m.ResolvedEnv(),
365 m.Args,
366 transport.WithCommandLogger(mcpLogger{}),
367 )
368 case config.MCPHttp:
369 if strings.TrimSpace(m.URL) == "" {
370 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
371 }
372 return client.NewStreamableHttpClient(
373 m.URL,
374 transport.WithHTTPHeaders(m.ResolvedHeaders()),
375 transport.WithHTTPLogger(mcpLogger{}),
376 )
377 case config.MCPSse:
378 if strings.TrimSpace(m.URL) == "" {
379 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
380 }
381 return client.NewSSEMCPClient(
382 m.URL,
383 client.WithHeaders(m.ResolvedHeaders()),
384 transport.WithSSELogger(mcpLogger{}),
385 )
386 default:
387 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
388 }
389}
390
// mcpLogger routes the printf-style log output of the MCP client transports
// (installed via WithCommandLogger, WithHTTPLogger and WithSSELogger in
// createMcpClient) into the application's default slog logger.
type mcpLogger struct{}

// Errorf formats its arguments and logs the result at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	slog.Error(fmt.Sprintf(format, v...))
}

// Infof formats its arguments and logs the result at info level.
func (mcpLogger) Infof(format string, v ...any) {
	slog.Info(fmt.Sprintf(format, v...))
}
396
397func mcpTimeout(m config.MCPConfig) time.Duration {
398 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
399}