1package tools
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/home"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/pubsub"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/charmbracelet/fantasy/ai"
23 "github.com/mark3labs/mcp-go/client"
24 "github.com/mark3labs/mcp-go/client/transport"
25 "github.com/mark3labs/mcp-go/mcp"
26)
27
// MCPState is the lifecycle state of an MCP client.
type MCPState int

const (
	// MCPStateDisabled marks an MCP server that is disabled in the config.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting marks an MCP server whose client is being set up.
	MCPStateStarting
	// MCPStateConnected marks an MCP server with a working client.
	MCPStateConnected
	// MCPStateError marks an MCP server whose client failed.
	MCPStateError
)

// String returns the lowercase name of s, or "unknown" for any value that is
// not one of the declared MCPState constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
52
// MCPEventType represents the type of an MCP event.
type MCPEventType string

const (
	// MCPEventStateChanged is the type of the event published by
	// updateMCPState each time an MCP client's recorded state is set.
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MCPEvent represents an event in the MCP system. Events are published on the
// package's MCP broker and can be received via SubscribeMCPEvents.
type MCPEvent struct {
	Type      MCPEventType // kind of event; in this file always MCPEventStateChanged
	Name      string       // MCP server name (its key in the config's MCP map)
	State     MCPState     // state the server was set to
	Error     error        // error that accompanied the state change; nil if none
	ToolCount int          // number of tools recorded for the server at that moment
}
68
// MCPClientInfo holds information about an MCP client's state. Values are
// stored by updateMCPState and read back through GetMCPStates/GetMCPState.
type MCPClientInfo struct {
	Name        string         // MCP server name
	State       MCPState       // last recorded lifecycle state
	Error       error          // error that accompanied State; nil if none
	Client      *client.Client // client when connected; this file passes nil for other states
	ToolCount   int            // number of tools last recorded for the server
	ConnectedAt time.Time      // when State was set to MCPStateConnected; zero otherwise
}
78
var (
	// mcpToolsOnce and mcpTools are not referenced in this part of the file
	// (getTools declares its own local mcpTools).
	// NOTE(review): presumably used elsewhere in the package; remove if not.
	mcpToolsOnce sync.Once
	mcpTools     []ai.AgentTool

	// mcpClients holds the current client for each MCP server, keyed by name.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates holds the last recorded MCPClientInfo for each MCP server, keyed by name.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker distributes MCPEvents to subscribers (see SubscribeMCPEvents).
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
86
// McpTool exposes a single tool of an MCP server as an ai.AgentTool.
type McpTool struct {
	mcpName         string             // name of the MCP server (config key) providing the tool
	tool            mcp.Tool           // tool definition as returned by the server's ListTools
	permissions     permission.Service // asked to approve each execution in Run
	workingDir      string             // path attached to permission requests
	providerOptions ai.ProviderOptions // provider-specific options; see SetProviderOptions
}
94
95func (m *McpTool) SetProviderOptions(opts ai.ProviderOptions) {
96 m.providerOptions = opts
97}
98
99func (m *McpTool) ProviderOptions() ai.ProviderOptions {
100 return m.providerOptions
101}
102
103func (b *McpTool) Name() string {
104 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
105}
106
107func (b *McpTool) Info() ai.ToolInfo {
108 required := b.tool.InputSchema.Required
109 if required == nil {
110 required = make([]string, 0)
111 }
112 parameters := b.tool.InputSchema.Properties
113 if parameters == nil {
114 parameters = make(map[string]any)
115 }
116 return ai.ToolInfo{
117 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
118 Description: b.tool.Description,
119 Parameters: parameters,
120 Required: required,
121 }
122}
123
124func runTool(ctx context.Context, name, toolName string, input string) (ai.ToolResponse, error) {
125 var args map[string]any
126 if err := json.Unmarshal([]byte(input), &args); err != nil {
127 return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
128 }
129
130 c, err := getOrRenewClient(ctx, name)
131 if err != nil {
132 return ai.NewTextErrorResponse(err.Error()), nil
133 }
134 result, err := c.CallTool(ctx, mcp.CallToolRequest{
135 Params: mcp.CallToolParams{
136 Name: toolName,
137 Arguments: args,
138 },
139 })
140 if err != nil {
141 return ai.NewTextErrorResponse(err.Error()), nil
142 }
143
144 output := make([]string, 0, len(result.Content))
145 for _, v := range result.Content {
146 if v, ok := v.(mcp.TextContent); ok {
147 output = append(output, v.Text)
148 } else {
149 output = append(output, fmt.Sprintf("%v", v))
150 }
151 }
152 return ai.NewTextResponse(strings.Join(output, "\n")), nil
153}
154
155func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
156 c, ok := mcpClients.Get(name)
157 if !ok {
158 return nil, fmt.Errorf("mcp '%s' not available", name)
159 }
160
161 cfg := config.Get()
162 m := cfg.MCP[name]
163 state, _ := mcpStates.Get(name)
164
165 timeout := mcpTimeout(m)
166 pingCtx, cancel := context.WithTimeout(ctx, timeout)
167 defer cancel()
168 err := c.Ping(pingCtx)
169 if err == nil {
170 return c, nil
171 }
172 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
173
174 c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
175 if err != nil {
176 return nil, err
177 }
178
179 updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
180 mcpClients.Set(name, c)
181 return c, nil
182}
183
184func (b *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) {
185 sessionID := GetSessionFromContext(ctx)
186 if sessionID == "" {
187 return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
188 }
189 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
190 p := b.permissions.Request(
191 permission.CreatePermissionRequest{
192 SessionID: sessionID,
193 ToolCallID: params.ID,
194 Path: b.workingDir,
195 ToolName: b.Info().Name,
196 Action: "execute",
197 Description: permissionDescription,
198 Params: params.Input,
199 },
200 )
201 if !p {
202 return ai.ToolResponse{}, permission.ErrorPermissionDenied
203 }
204
205 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
206}
207
208func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []ai.AgentTool {
209 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
210 if err != nil {
211 slog.Error("error listing tools", "error", err)
212 updateMCPState(name, MCPStateError, err, nil, 0)
213 c.Close()
214 mcpClients.Del(name)
215 return nil
216 }
217 mcpTools := make([]ai.AgentTool, 0, len(result.Tools))
218 for _, tool := range result.Tools {
219 mcpTools = append(mcpTools, &McpTool{
220 mcpName: name,
221 tool: tool,
222 permissions: permissions,
223 workingDir: workingDir,
224 })
225 }
226 return mcpTools
227}
228
229// SubscribeMCPEvents returns a channel for MCP events
230func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
231 return mcpBroker.Subscribe(ctx)
232}
233
234// GetMCPStates returns the current state of all MCP clients
235func GetMCPStates() map[string]MCPClientInfo {
236 return maps.Collect(mcpStates.Seq2())
237}
238
239// GetMCPState returns the state of a specific MCP client
240func GetMCPState(name string) (MCPClientInfo, bool) {
241 return mcpStates.Get(name)
242}
243
244// updateMCPState updates the state of an MCP client and publishes an event
245func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
246 info := MCPClientInfo{
247 Name: name,
248 State: state,
249 Error: err,
250 Client: client,
251 ToolCount: toolCount,
252 }
253 if state == MCPStateConnected {
254 info.ConnectedAt = time.Now()
255 }
256 mcpStates.Set(name, info)
257
258 // Publish state change event
259 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
260 Type: MCPEventStateChanged,
261 Name: name,
262 State: state,
263 Error: err,
264 ToolCount: toolCount,
265 })
266}
267
268// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
269func CloseMCPClients() error {
270 var errs []error
271 for name, c := range mcpClients.Seq2() {
272 if err := c.Close(); err != nil {
273 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
274 }
275 }
276 mcpBroker.Shutdown()
277 return errors.Join(errs...)
278}
279
// mcpInitRequest is the initialize request that createAndInitializeClient sends
// to each MCP server. It advertises mcp.LATEST_PROTOCOL_VERSION as the protocol
// version and identifies this client as "Crush" at the current build version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
289
290func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []ai.AgentTool {
291 var wg sync.WaitGroup
292 result := csync.NewSlice[ai.AgentTool]()
293
294 // Initialize states for all configured MCPs
295 for name, m := range cfg.MCP {
296 if m.Disabled {
297 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
298 slog.Debug("skipping disabled mcp", "name", name)
299 continue
300 }
301
302 // Set initial starting state
303 updateMCPState(name, MCPStateStarting, nil, nil, 0)
304
305 wg.Add(1)
306 go func(name string, m config.MCPConfig) {
307 defer func() {
308 wg.Done()
309 if r := recover(); r != nil {
310 var err error
311 switch v := r.(type) {
312 case error:
313 err = v
314 case string:
315 err = fmt.Errorf("panic: %s", v)
316 default:
317 err = fmt.Errorf("panic: %v", v)
318 }
319 updateMCPState(name, MCPStateError, err, nil, 0)
320 slog.Error("panic in mcp client initialization", "error", err, "name", name)
321 }
322 }()
323
324 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
325 defer cancel()
326 c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
327 if err != nil {
328 return
329 }
330 mcpClients.Set(name, c)
331
332 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
333 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
334 result.Append(tools...)
335 }(name, m)
336 }
337 wg.Wait()
338 return slices.Collect(result.Seq())
339}
340
341func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
342 c, err := createMcpClient(name, m, resolver)
343 if err != nil {
344 updateMCPState(name, MCPStateError, err, nil, 0)
345 slog.Error("error creating mcp client", "error", err, "name", name)
346 return nil, err
347 }
348
349 timeout := mcpTimeout(m)
350 initCtx, cancel := context.WithTimeout(ctx, timeout)
351 defer cancel()
352
353 // Only call Start() for non-stdio clients, as stdio clients auto-start
354 if m.Type != config.MCPStdio {
355 if err := c.Start(initCtx); err != nil {
356 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
357 slog.Error("error starting mcp client", "error", err, "name", name)
358 _ = c.Close()
359 return nil, err
360 }
361 }
362 if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
363 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
364 slog.Error("error initializing mcp client", "error", err, "name", name)
365 _ = c.Close()
366 return nil, err
367 }
368
369 slog.Info("Initialized mcp client", "name", name)
370 return c, nil
371}
372
// mcpTimeoutError reports that an MCP operation hit its deadline. Its message
// names only the configured timeout. It still unwraps to the original error, so
// errors.Is(err, context.DeadlineExceeded) keeps working for callers.
type mcpTimeoutError struct {
	timeout time.Duration
	err     error
}

func (e *mcpTimeoutError) Error() string {
	return fmt.Sprintf("timed out after %s", e.timeout)
}

func (e *mcpTimeoutError) Unwrap() error { return e.err }

// maybeTimeoutErr replaces err with a friendlier "timed out after <timeout>"
// error when err is, or wraps, context.DeadlineExceeded. The replacement keeps
// err as its cause. Any other error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.DeadlineExceeded) {
		return &mcpTimeoutError{timeout: timeout, err: err}
	}
	return err
}
379
380func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
381 switch m.Type {
382 case config.MCPStdio:
383 command, err := resolver.ResolveValue(m.Command)
384 if err != nil {
385 return nil, fmt.Errorf("invalid mcp command: %w", err)
386 }
387 if strings.TrimSpace(command) == "" {
388 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
389 }
390 return client.NewStdioMCPClientWithOptions(
391 home.Long(command),
392 m.ResolvedEnv(),
393 m.Args,
394 transport.WithCommandLogger(mcpLogger{name: name}),
395 )
396 case config.MCPHttp:
397 if strings.TrimSpace(m.URL) == "" {
398 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
399 }
400 return client.NewStreamableHttpClient(
401 m.URL,
402 transport.WithHTTPHeaders(m.ResolvedHeaders()),
403 transport.WithHTTPLogger(mcpLogger{name: name}),
404 )
405 case config.MCPSse:
406 if strings.TrimSpace(m.URL) == "" {
407 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
408 }
409 return client.NewSSEMCPClient(
410 m.URL,
411 client.WithHeaders(m.ResolvedHeaders()),
412 transport.WithSSELogger(mcpLogger{name: name}),
413 )
414 default:
415 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
416 }
417}
418
// mcpLogger is the printf-style logger handed to the MCP client transports. It
// forwards each message to slog, tagged with the MCP server's name.
type mcpLogger struct{ name string }

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
429
430func mcpTimeout(m config.MCPConfig) time.Duration {
431 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
432}