1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/pubsub"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/mark3labs/mcp-go/client"
23 "github.com/mark3labs/mcp-go/client/transport"
24 "github.com/mark3labs/mcp-go/mcp"
25)
26
// MCPState represents the current lifecycle state of an MCP client.
type MCPState int

const (
	// MCPStateDisabled marks a server that is disabled in the configuration.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting marks a server whose connection attempt is in progress.
	MCPStateStarting
	// MCPStateConnected marks a server with an initialized, usable client.
	MCPStateConnected
	// MCPStateError marks a server whose client failed to be created, started,
	// initialized, list its tools, or answer a ping.
	MCPStateError
)

// mcpStateText holds the textual form of every known MCPState, indexed by value.
// It drives both String and UnmarshalText so the two cannot drift apart.
var mcpStateText = [...]string{
	MCPStateDisabled:  "disabled",
	MCPStateStarting:  "starting",
	MCPStateConnected: "connected",
	MCPStateError:     "error",
}

// MarshalText encodes s as its String form, so MCPState serializes as text
// (e.g. inside JSON). Unknown values encode as "unknown", which UnmarshalText
// rejects.
func (s MCPState) MarshalText() ([]byte, error) {
	text := s.String()
	return []byte(text), nil
}

// UnmarshalText parses one of "disabled", "starting", "connected" or "error"
// into s. Any other input, including "unknown", leaves s unchanged and returns
// an error.
func (s *MCPState) UnmarshalText(data []byte) error {
	want := string(data)
	for value, text := range mcpStateText {
		if text == want {
			*s = MCPState(value)
			return nil
		}
	}
	return fmt.Errorf("unknown mcp state: %s", data)
}

// String returns the lowercase name of s, or "unknown" for values outside the
// defined states.
func (s MCPState) String() string {
	if s < 0 || int(s) >= len(mcpStateText) {
		return "unknown"
	}
	return mcpStateText[s]
}
71
// MCPEventType identifies the kind of an MCPEvent. It is a free-form string, so
// values outside the declared constants round-trip unchanged.
type MCPEventType string

const (
	// MCPEventStateChanged is published whenever an MCP server's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MarshalText encodes t as its raw string value.
func (t MCPEventType) MarshalText() ([]byte, error) {
	text := string(t)
	return []byte(text), nil
}

// UnmarshalText stores data verbatim into t; it never fails.
func (t *MCPEventType) UnmarshalText(data []byte) error {
	text := string(data)
	*t = MCPEventType(text)
	return nil
}
87
// MCPEvent represents an event in the MCP system. updateMCPState publishes one,
// with Type MCPEventStateChanged, every time it records a server's state.
// Name is the server's key in the MCP config; ToolCount is the number of tools
// reported for it.
//
// NOTE(review): Error is an interface. encoding/json serializes the concrete
// value's exported fields, so typical errors (errors.New / fmt.Errorf) encode as
// {} and the message text is lost. Confirm whether JSON consumers need the text.
type MCPEvent struct {
	Type      MCPEventType `json:"type"`
	Name      string       `json:"name"`
	State     MCPState     `json:"state"`
	Error     error        `json:"error,omitempty"`
	ToolCount int          `json:"tool_count,omitempty"`
}
96
// MCPClientInfo holds information about an MCP client's state: the most recent
// snapshot recorded by updateMCPState for one server (keyed by Name, the
// server's key in the MCP config).
//
// Each update replaces the previous snapshot entirely. ConnectedAt is set to the
// time of the update only when State is MCPStateConnected and is zero otherwise.
// Callers in this file pass a non-nil Client only together with
// MCPStateConnected.
type MCPClientInfo struct {
	Name        string
	State       MCPState
	Error       error
	Client      *client.Client
	ToolCount   int
	ConnectedAt time.Time
}
106
107var (
108 mcpToolsOnce sync.Once
109 mcpTools []tools.BaseTool
110 mcpClients = csync.NewMap[string, *client.Client]()
111 mcpStates = csync.NewMap[string, MCPClientInfo]()
112 mcpBroker = pubsub.NewBroker[MCPEvent]()
113)
114
115type McpTool struct {
116 mcpName string
117 tool mcp.Tool
118 permissions permission.Service
119 workingDir string
120}
121
122func (b *McpTool) Name() string {
123 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
124}
125
126func (b *McpTool) Info() tools.ToolInfo {
127 required := b.tool.InputSchema.Required
128 if required == nil {
129 required = make([]string, 0)
130 }
131 parameters := b.tool.InputSchema.Properties
132 if parameters == nil {
133 parameters = make(map[string]any)
134 }
135 return tools.ToolInfo{
136 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
137 Description: b.tool.Description,
138 Parameters: parameters,
139 Required: required,
140 }
141}
142
143func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
144 var args map[string]any
145 if err := json.Unmarshal([]byte(input), &args); err != nil {
146 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
147 }
148
149 c, err := getOrRenewClient(ctx, name)
150 if err != nil {
151 return tools.NewTextErrorResponse(err.Error()), nil
152 }
153 result, err := c.CallTool(ctx, mcp.CallToolRequest{
154 Params: mcp.CallToolParams{
155 Name: toolName,
156 Arguments: args,
157 },
158 })
159 if err != nil {
160 return tools.NewTextErrorResponse(err.Error()), nil
161 }
162
163 output := make([]string, 0, len(result.Content))
164 for _, v := range result.Content {
165 if v, ok := v.(mcp.TextContent); ok {
166 output = append(output, v.Text)
167 } else {
168 output = append(output, fmt.Sprintf("%v", v))
169 }
170 }
171 return tools.NewTextResponse(strings.Join(output, "\n")), nil
172}
173
174func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
175 c, ok := mcpClients.Get(name)
176 if !ok {
177 return nil, fmt.Errorf("mcp '%s' not available", name)
178 }
179
180 m := config.Get().MCP[name]
181 state, _ := mcpStates.Get(name)
182
183 pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
184 defer cancel()
185 err := c.Ping(pingCtx)
186 if err == nil {
187 return c, nil
188 }
189 updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
190
191 c, err = createAndInitializeClient(ctx, name, m)
192 if err != nil {
193 return nil, err
194 }
195
196 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
197 mcpClients.Set(name, c)
198 return c, nil
199}
200
201func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
202 sessionID, messageID := tools.GetContextValues(ctx)
203 if sessionID == "" || messageID == "" {
204 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
205 }
206 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
207 p := b.permissions.Request(
208 permission.CreatePermissionRequest{
209 SessionID: sessionID,
210 ToolCallID: params.ID,
211 Path: b.workingDir,
212 ToolName: b.Info().Name,
213 Action: "execute",
214 Description: permissionDescription,
215 Params: params.Input,
216 },
217 )
218 if !p {
219 return tools.ToolResponse{}, permission.ErrorPermissionDenied
220 }
221
222 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
223}
224
225func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
226 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
227 if err != nil {
228 slog.Error("error listing tools", "error", err)
229 updateMCPState(name, MCPStateError, err, nil, 0)
230 c.Close()
231 mcpClients.Del(name)
232 return nil
233 }
234 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
235 for _, tool := range result.Tools {
236 mcpTools = append(mcpTools, &McpTool{
237 mcpName: name,
238 tool: tool,
239 permissions: permissions,
240 workingDir: workingDir,
241 })
242 }
243 return mcpTools
244}
245
246// SubscribeMCPEvents returns a channel for MCP events
247func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
248 return mcpBroker.Subscribe(ctx)
249}
250
251// GetMCPStates returns the current state of all MCP clients
252func GetMCPStates() map[string]MCPClientInfo {
253 return maps.Collect(mcpStates.Seq2())
254}
255
256// GetMCPState returns the state of a specific MCP client
257func GetMCPState(name string) (MCPClientInfo, bool) {
258 return mcpStates.Get(name)
259}
260
261// updateMCPState updates the state of an MCP client and publishes an event
262func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
263 info := MCPClientInfo{
264 Name: name,
265 State: state,
266 Error: err,
267 Client: client,
268 ToolCount: toolCount,
269 }
270 if state == MCPStateConnected {
271 info.ConnectedAt = time.Now()
272 }
273 mcpStates.Set(name, info)
274
275 // Publish state change event
276 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
277 Type: MCPEventStateChanged,
278 Name: name,
279 State: state,
280 Error: err,
281 ToolCount: toolCount,
282 })
283}
284
285// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
286func CloseMCPClients() error {
287 var errs []error
288 for c := range mcpClients.Seq() {
289 if err := c.Close(); err != nil {
290 errs = append(errs, err)
291 }
292 }
293 mcpBroker.Shutdown()
294 return errors.Join(errs...)
295}
296
297var mcpInitRequest = mcp.InitializeRequest{
298 Params: mcp.InitializeParams{
299 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
300 ClientInfo: mcp.Implementation{
301 Name: "Crush",
302 Version: version.Version,
303 },
304 },
305}
306
307func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
308 var wg sync.WaitGroup
309 result := csync.NewSlice[tools.BaseTool]()
310
311 // Initialize states for all configured MCPs
312 for name, m := range cfg.MCP {
313 if m.Disabled {
314 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
315 slog.Debug("skipping disabled mcp", "name", name)
316 continue
317 }
318
319 // Set initial starting state
320 updateMCPState(name, MCPStateStarting, nil, nil, 0)
321
322 wg.Add(1)
323 go func(name string, m config.MCPConfig) {
324 defer func() {
325 wg.Done()
326 if r := recover(); r != nil {
327 var err error
328 switch v := r.(type) {
329 case error:
330 err = v
331 case string:
332 err = fmt.Errorf("panic: %s", v)
333 default:
334 err = fmt.Errorf("panic: %v", v)
335 }
336 updateMCPState(name, MCPStateError, err, nil, 0)
337 slog.Error("panic in mcp client initialization", "error", err, "name", name)
338 }
339 }()
340
341 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
342 defer cancel()
343 c, err := createAndInitializeClient(ctx, name, m)
344 if err != nil {
345 return
346 }
347 mcpClients.Set(name, c)
348
349 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
350 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
351 result.Append(tools...)
352 }(name, m)
353 }
354 wg.Wait()
355 return slices.Collect(result.Seq())
356}
357
358func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
359 c, err := createMcpClient(name, m)
360 if err != nil {
361 updateMCPState(name, MCPStateError, err, nil, 0)
362 slog.Error("error creating mcp client", "error", err, "name", name)
363 return nil, err
364 }
365 // Only call Start() for non-stdio clients, as stdio clients auto-start
366 if m.Type != config.MCPStdio {
367 if err := c.Start(ctx); err != nil {
368 updateMCPState(name, MCPStateError, err, nil, 0)
369 slog.Error("error starting mcp client", "error", err, "name", name)
370 _ = c.Close()
371 return nil, err
372 }
373 }
374 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
375 updateMCPState(name, MCPStateError, err, nil, 0)
376 slog.Error("error initializing mcp client", "error", err, "name", name)
377 _ = c.Close()
378 return nil, err
379 }
380
381 slog.Info("Initialized mcp client", "name", name)
382 return c, nil
383}
384
385func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
386 switch m.Type {
387 case config.MCPStdio:
388 if strings.TrimSpace(m.Command) == "" {
389 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
390 }
391 return client.NewStdioMCPClientWithOptions(
392 m.Command,
393 m.ResolvedEnv(),
394 m.Args,
395 transport.WithCommandLogger(mcpLogger{name: name}),
396 )
397 case config.MCPHttp:
398 if strings.TrimSpace(m.URL) == "" {
399 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
400 }
401 return client.NewStreamableHttpClient(
402 m.URL,
403 transport.WithHTTPHeaders(m.ResolvedHeaders()),
404 transport.WithHTTPLogger(mcpLogger{name: name}),
405 )
406 case config.MCPSse:
407 if strings.TrimSpace(m.URL) == "" {
408 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
409 }
410 return client.NewSSEMCPClient(
411 m.URL,
412 client.WithHeaders(m.ResolvedHeaders()),
413 transport.WithSSELogger(mcpLogger{name: name}),
414 )
415 default:
416 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
417 }
418}
419
// mcpLogger adapts log/slog to the printf-style Infof/Errorf logger accepted by
// the MCP client transports. Every record is tagged with the MCP server name.
type mcpLogger struct{ name string }

// Errorf formats the message printf-style and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message printf-style and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
430
431func mcpTimeout(m config.MCPConfig) time.Duration {
432 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
433}