1package tools
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os"
14 "os/exec"
15 "slices"
16 "strings"
17 "sync"
18 "time"
19
20 "charm.land/fantasy"
21 "github.com/charmbracelet/crush/internal/config"
22 "github.com/charmbracelet/crush/internal/csync"
23 "github.com/charmbracelet/crush/internal/home"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/pubsub"
26 "github.com/charmbracelet/crush/internal/version"
27 "github.com/modelcontextprotocol/go-sdk/mcp"
28)
29
// MCPState is the lifecycle state of a configured MCP client.
type MCPState int

// Possible MCPState values. The zero value is MCPStateDisabled.
const (
	// MCPStateDisabled means the MCP is disabled in the configuration.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting means the client is being initialized.
	MCPStateStarting
	// MCPStateConnected means the client has an open session.
	MCPStateConnected
	// MCPStateError means the client failed to start, to list its tools, or
	// to answer a health-check ping.
	MCPStateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
54
// MCPEventType identifies the kind of an MCPEvent.
type MCPEventType string

// MCPEventStateChanged is the type of events announcing that an MCP client
// changed state.
const MCPEventStateChanged MCPEventType = "state_changed"

// MCPEventToolsListChanged is the type of events announcing that an MCP
// server reported a change in its tool list.
const MCPEventToolsListChanged MCPEventType = "tools_list_changed"
62
63// MCPEvent represents an event in the MCP system
64type MCPEvent struct {
65 Type MCPEventType
66 Name string
67 State MCPState
68 Error error
69 ToolCount int
70}
71
72// MCPClientInfo holds information about an MCP client's state
73type MCPClientInfo struct {
74 Name string
75 State MCPState
76 Error error
77 Client *mcp.ClientSession
78 ToolCount int
79 ConnectedAt time.Time
80}
81
82var (
83 mcpToolsOnce sync.Once
84 mcpTools = csync.NewMap[string, *McpTool]()
85 mcpClient2Tools = csync.NewMap[string, []*McpTool]()
86 mcpClients = csync.NewMap[string, *mcp.ClientSession]()
87 mcpStates = csync.NewMap[string, MCPClientInfo]()
88 mcpBroker = pubsub.NewBroker[MCPEvent]()
89)
90
91type McpTool struct {
92 mcpName string
93 tool *mcp.Tool
94 permissions permission.Service
95 workingDir string
96 providerOptions fantasy.ProviderOptions
97}
98
99func (m *McpTool) SetProviderOptions(opts fantasy.ProviderOptions) {
100 m.providerOptions = opts
101}
102
103func (m *McpTool) ProviderOptions() fantasy.ProviderOptions {
104 return m.providerOptions
105}
106
107func (m *McpTool) Name() string {
108 return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
109}
110
111func (m *McpTool) MCP() string {
112 return m.mcpName
113}
114
115func (m *McpTool) MCPToolName() string {
116 return m.tool.Name
117}
118
119func (b *McpTool) Info() fantasy.ToolInfo {
120 parameters := make(map[string]any)
121 required := make([]string, 0)
122
123 if input, ok := b.tool.InputSchema.(map[string]any); ok {
124 if props, ok := input["properties"].(map[string]any); ok {
125 parameters = props
126 }
127 if req, ok := input["required"].([]any); ok {
128 // Convert []any -> []string when elements are strings
129 for _, v := range req {
130 if s, ok := v.(string); ok {
131 required = append(required, s)
132 }
133 }
134 } else if reqStr, ok := input["required"].([]string); ok {
135 // Handle case where it's already []string
136 required = reqStr
137 }
138 }
139
140 return fantasy.ToolInfo{
141 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
142 Description: b.tool.Description,
143 Parameters: parameters,
144 Required: required,
145 }
146}
147
148func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
149 var args map[string]any
150 if err := json.Unmarshal([]byte(input), &args); err != nil {
151 return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
152 }
153
154 c, err := getOrRenewClient(ctx, name)
155 if err != nil {
156 return fantasy.NewTextErrorResponse(err.Error()), nil
157 }
158 result, err := c.CallTool(ctx, &mcp.CallToolParams{
159 Name: toolName,
160 Arguments: args,
161 })
162 if err != nil {
163 return fantasy.NewTextErrorResponse(err.Error()), nil
164 }
165
166 output := make([]string, 0, len(result.Content))
167 for _, v := range result.Content {
168 if vv, ok := v.(*mcp.TextContent); ok {
169 output = append(output, vv.Text)
170 } else {
171 output = append(output, fmt.Sprintf("%v", v))
172 }
173 }
174 return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
175}
176
177func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
178 sess, ok := mcpClients.Get(name)
179 if !ok {
180 return nil, fmt.Errorf("mcp '%s' not available", name)
181 }
182
183 cfg := config.Get()
184 m := cfg.MCP[name]
185 state, _ := mcpStates.Get(name)
186
187 timeout := mcpTimeout(m)
188 pingCtx, cancel := context.WithTimeout(ctx, timeout)
189 defer cancel()
190 err := sess.Ping(pingCtx, nil)
191 if err == nil {
192 return sess, nil
193 }
194 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
195
196 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
197 if err != nil {
198 return nil, err
199 }
200
201 updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
202 mcpClients.Set(name, sess)
203 return sess, nil
204}
205
206func (m *McpTool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
207 sessionID := GetSessionFromContext(ctx)
208 if sessionID == "" {
209 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
210 }
211 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
212 p := m.permissions.Request(
213 permission.CreatePermissionRequest{
214 SessionID: sessionID,
215 ToolCallID: params.ID,
216 Path: m.workingDir,
217 ToolName: m.Info().Name,
218 Action: "execute",
219 Description: permissionDescription,
220 Params: params.Input,
221 },
222 )
223 if !p {
224 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
225 }
226
227 return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
228}
229
230func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) {
231 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
232 if err != nil {
233 return nil, err
234 }
235 mcpTools := make([]*McpTool, 0, len(result.Tools))
236 for _, tool := range result.Tools {
237 mcpTools = append(mcpTools, &McpTool{
238 mcpName: name,
239 tool: tool,
240 permissions: permissions,
241 workingDir: workingDir,
242 })
243 }
244 return mcpTools, nil
245}
246
247// SubscribeMCPEvents returns a channel for MCP events
248func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
249 return mcpBroker.Subscribe(ctx)
250}
251
252// GetMCPStates returns the current state of all MCP clients
253func GetMCPStates() map[string]MCPClientInfo {
254 return maps.Collect(mcpStates.Seq2())
255}
256
257// GetMCPState returns the state of a specific MCP client
258func GetMCPState(name string) (MCPClientInfo, bool) {
259 return mcpStates.Get(name)
260}
261
262// updateMCPState updates the state of an MCP client and publishes an event
263func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
264 info := MCPClientInfo{
265 Name: name,
266 State: state,
267 Error: err,
268 Client: client,
269 ToolCount: toolCount,
270 }
271 switch state {
272 case MCPStateConnected:
273 info.ConnectedAt = time.Now()
274 case MCPStateError:
275 updateMcpTools(name, nil)
276 mcpClients.Del(name)
277 }
278 mcpStates.Set(name, info)
279
280 // Publish state change event
281 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
282 Type: MCPEventStateChanged,
283 Name: name,
284 State: state,
285 Error: err,
286 ToolCount: toolCount,
287 })
288}
289
290// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
291func CloseMCPClients() error {
292 var errs []error
293 for name, c := range mcpClients.Seq2() {
294 if err := c.Close(); err != nil &&
295 !errors.Is(err, io.EOF) &&
296 !errors.Is(err, context.Canceled) &&
297 err.Error() != "signal: killed" {
298 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
299 }
300 }
301 mcpBroker.Shutdown()
302 return errors.Join(errs...)
303}
304
305func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
306 mcpToolsOnce.Do(func() {
307 var wg sync.WaitGroup
308 // Initialize states for all configured MCPs
309 for name, m := range cfg.MCP {
310 if m.Disabled {
311 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
312 slog.Debug("skipping disabled mcp", "name", name)
313 continue
314 }
315
316 // Set initial starting state
317 updateMCPState(name, MCPStateStarting, nil, nil, 0)
318
319 wg.Add(1)
320 go func(name string, m config.MCPConfig) {
321 defer func() {
322 wg.Done()
323 if r := recover(); r != nil {
324 var err error
325 switch v := r.(type) {
326 case error:
327 err = v
328 case string:
329 err = fmt.Errorf("panic: %s", v)
330 default:
331 err = fmt.Errorf("panic: %v", v)
332 }
333 updateMCPState(name, MCPStateError, err, nil, 0)
334 slog.Error("panic in mcp client initialization", "error", err, "name", name)
335 }
336 }()
337
338 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
339 defer cancel()
340
341 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
342 if err != nil {
343 return
344 }
345
346 mcpClients.Set(name, c)
347
348 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
349 if err != nil {
350 slog.Error("error listing tools", "error", err)
351 updateMCPState(name, MCPStateError, err, nil, 0)
352 c.Close()
353 return
354 }
355
356 updateMcpTools(name, tools)
357 mcpClients.Set(name, c)
358 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
359 }(name, m)
360 }
361 wg.Wait()
362 })
363 return slices.Collect(mcpTools.Seq())
364}
365
366// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
367func updateMcpTools(mcpName string, tools []*McpTool) {
368 if len(tools) == 0 {
369 mcpClient2Tools.Del(mcpName)
370 } else {
371 mcpClient2Tools.Set(mcpName, tools)
372 }
373 for _, tools := range mcpClient2Tools.Seq2() {
374 for _, t := range tools {
375 mcpTools.Set(t.Name(), t)
376 }
377 }
378}
379
380func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
381 timeout := mcpTimeout(m)
382 mcpCtx, cancel := context.WithCancel(ctx)
383 cancelTimer := time.AfterFunc(timeout, cancel)
384
385 transport, err := createMCPTransport(mcpCtx, m, resolver)
386 if err != nil {
387 updateMCPState(name, MCPStateError, err, nil, 0)
388 slog.Error("error creating mcp client", "error", err, "name", name)
389 cancel()
390 cancelTimer.Stop()
391 return nil, err
392 }
393
394 client := mcp.NewClient(
395 &mcp.Implementation{
396 Name: "crush",
397 Version: version.Version,
398 Title: "Crush",
399 },
400 &mcp.ClientOptions{
401 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
402 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
403 Type: MCPEventToolsListChanged,
404 Name: name,
405 })
406 },
407 KeepAlive: time.Minute * 10,
408 },
409 )
410
411 session, err := client.Connect(mcpCtx, transport, nil)
412 if err != nil {
413 err = maybeStdioErr(err, transport)
414 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
415 slog.Error("error starting mcp client", "error", err, "name", name)
416 cancel()
417 cancelTimer.Stop()
418 return nil, err
419 }
420
421 cancelTimer.Stop()
422 slog.Info("Initialized mcp client", "name", name)
423 return session, nil
424}
425
426// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
427// to parse, and the cli will then close it, causing the EOF error.
428// so, if we got an EOF err, and the transport is STDIO, we try to exec it
429// again with a timeout and collect the output so we can add details to the
430// error.
431// this happens particularly when starting things with npx, e.g. if node can't
432// be found or some other error like that.
433func maybeStdioErr(err error, transport mcp.Transport) error {
434 if !errors.Is(err, io.EOF) {
435 return err
436 }
437 ct, ok := transport.(*mcp.CommandTransport)
438 if !ok {
439 return err
440 }
441 if err2 := stdioMCPCheck(ct.Command); err2 != nil {
442 err = errors.Join(err, err2)
443 }
444 return err
445}
446
// maybeTimeoutErr replaces a context cancellation or deadline error with a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// Both cases are needed: createMCPSession aborts startup by canceling its
// context (context.Canceled), while pings bounded by context.WithTimeout fail
// with context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
453
454func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
455 switch m.Type {
456 case config.MCPStdio:
457 command, err := resolver.ResolveValue(m.Command)
458 if err != nil {
459 return nil, fmt.Errorf("invalid mcp command: %w", err)
460 }
461 if strings.TrimSpace(command) == "" {
462 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
463 }
464 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
465 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
466 return &mcp.CommandTransport{
467 Command: cmd,
468 }, nil
469 case config.MCPHttp:
470 if strings.TrimSpace(m.URL) == "" {
471 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
472 }
473 client := &http.Client{
474 Transport: &headerRoundTripper{
475 headers: m.ResolvedHeaders(),
476 },
477 }
478 return &mcp.StreamableClientTransport{
479 Endpoint: m.URL,
480 HTTPClient: client,
481 }, nil
482 case config.MCPSSE:
483 if strings.TrimSpace(m.URL) == "" {
484 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
485 }
486 client := &http.Client{
487 Transport: &headerRoundTripper{
488 headers: m.ResolvedHeaders(),
489 },
490 }
491 return &mcp.SSEClientTransport{
492 Endpoint: m.URL,
493 HTTPClient: client,
494 }, nil
495 default:
496 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
497 }
498}
499
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request, then delegates to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with the configured headers set, replacing any existing
// values for the same keys.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone. Clone deep-copies the header map and
// shares the body.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
510
511func mcpTimeout(m config.MCPConfig) time.Duration {
512 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
513}
514
// stdioMCPCheck re-runs the command of a stdio MCP server for at most five
// seconds to find out why it failed to start.
//
// It returns nil if the command exits successfully, or if it is still running
// when the five seconds elapse (i.e. it does not fail on startup). Otherwise
// it returns the command's error, annotated with the combined stdout/stderr
// output when there is any.
func stdioMCPCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args already starts with argv[0] (exec.Command sets Args to
	// append([]string{name}, arg...)). Reuse it as-is instead of passing it
	// as extra arguments, which would hand the program its own name as its
	// first argument.
	if len(old.Args) > 0 {
		cmd.Args = slices.Clone(old.Args)
	}
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	if msg := strings.TrimSpace(string(out)); msg != "" {
		return fmt.Errorf("%w: %s", err, msg)
	}
	return err
}