1package tools
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os/exec"
14 "slices"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/charmbracelet/fantasy/ai"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
// MCPState is the lifecycle state of a single MCP client.
type MCPState int

// Lifecycle states an MCP client moves through.
const (
	// MCPStateDisabled marks an MCP that is disabled in the configuration.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting marks an MCP whose client is being initialized.
	MCPStateStarting
	// MCPStateConnected marks an MCP with a live session.
	MCPStateConnected
	// MCPStateError marks an MCP whose client failed.
	MCPStateError
)

// String returns the lowercase name of s, or "unknown" for values outside
// the declared states.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
53
// MCPEventType identifies the kind of an MCPEvent.
type MCPEventType string

// Event kinds published on the MCP event broker.
const (
	// MCPEventStateChanged is published whenever an MCP client's state is updated.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when a server notifies that its tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
61
62// MCPEvent represents an event in the MCP system
63type MCPEvent struct {
64 Type MCPEventType
65 Name string
66 State MCPState
67 Error error
68 ToolCount int
69}
70
71// MCPClientInfo holds information about an MCP client's state
72type MCPClientInfo struct {
73 Name string
74 State MCPState
75 Error error
76 Client *mcp.ClientSession
77 ToolCount int
78 ConnectedAt time.Time
79}
80
81var (
82 mcpToolsOnce sync.Once
83 mcpTools = csync.NewMap[string, *McpTool]()
84 mcpClient2Tools = csync.NewMap[string, []*McpTool]()
85 mcpClients = csync.NewMap[string, *mcp.ClientSession]()
86 mcpStates = csync.NewMap[string, MCPClientInfo]()
87 mcpBroker = pubsub.NewBroker[MCPEvent]()
88)
89
90type McpTool struct {
91 mcpName string
92 tool *mcp.Tool
93 permissions permission.Service
94 workingDir string
95 providerOptions ai.ProviderOptions
96}
97
98func (m *McpTool) SetProviderOptions(opts ai.ProviderOptions) {
99 m.providerOptions = opts
100}
101
102func (m *McpTool) ProviderOptions() ai.ProviderOptions {
103 return m.providerOptions
104}
105
106func (m *McpTool) Name() string {
107 return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
108}
109
110func (m *McpTool) MCP() string {
111 return m.mcpName
112}
113
114func (m *McpTool) MCPToolName() string {
115 return m.tool.Name
116}
117
118func (b *McpTool) Info() ai.ToolInfo {
119 input := b.tool.InputSchema.(map[string]any)
120 required, _ := input["required"].([]string)
121 if required == nil {
122 required = make([]string, 0)
123 }
124 parameters, _ := input["properties"].(map[string]any)
125 if parameters == nil {
126 parameters = make(map[string]any)
127 }
128 return ai.ToolInfo{
129 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
130 Description: b.tool.Description,
131 Parameters: parameters,
132 Required: required,
133 }
134}
135
136func runTool(ctx context.Context, name, toolName string, input string) (ai.ToolResponse, error) {
137 var args map[string]any
138 if err := json.Unmarshal([]byte(input), &args); err != nil {
139 return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
140 }
141
142 c, err := getOrRenewClient(ctx, name)
143 if err != nil {
144 return ai.NewTextErrorResponse(err.Error()), nil
145 }
146 result, err := c.CallTool(ctx, &mcp.CallToolParams{
147 Name: toolName,
148 Arguments: args,
149 })
150 if err != nil {
151 return ai.NewTextErrorResponse(err.Error()), nil
152 }
153
154 output := make([]string, 0, len(result.Content))
155 for _, v := range result.Content {
156 if vv, ok := v.(*mcp.TextContent); ok {
157 output = append(output, vv.Text)
158 } else {
159 output = append(output, fmt.Sprintf("%v", v))
160 }
161 }
162 return ai.NewTextResponse(strings.Join(output, "\n")), nil
163}
164
165func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
166 sess, ok := mcpClients.Get(name)
167 if !ok {
168 return nil, fmt.Errorf("mcp '%s' not available", name)
169 }
170
171 cfg := config.Get()
172 m := cfg.MCP[name]
173 state, _ := mcpStates.Get(name)
174
175 timeout := mcpTimeout(m)
176 pingCtx, cancel := context.WithTimeout(ctx, timeout)
177 defer cancel()
178 err := sess.Ping(pingCtx, nil)
179 if err == nil {
180 return sess, nil
181 }
182 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
183
184 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
185 if err != nil {
186 return nil, err
187 }
188
189 updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
190 mcpClients.Set(name, sess)
191 return sess, nil
192}
193
194func (m *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) {
195 sessionID := GetSessionFromContext(ctx)
196 if sessionID == "" {
197 return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
198 }
199 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
200 p := m.permissions.Request(
201 permission.CreatePermissionRequest{
202 SessionID: sessionID,
203 ToolCallID: params.ID,
204 Path: m.workingDir,
205 ToolName: m.Info().Name,
206 Action: "execute",
207 Description: permissionDescription,
208 Params: params.Input,
209 },
210 )
211 if !p {
212 return ai.ToolResponse{}, permission.ErrorPermissionDenied
213 }
214
215 return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
216}
217
218func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) {
219 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
220 if err != nil {
221 return nil, err
222 }
223 mcpTools := make([]*McpTool, 0, len(result.Tools))
224 for _, tool := range result.Tools {
225 mcpTools = append(mcpTools, &McpTool{
226 mcpName: name,
227 tool: tool,
228 permissions: permissions,
229 workingDir: workingDir,
230 })
231 }
232 return mcpTools, nil
233}
234
235// SubscribeMCPEvents returns a channel for MCP events
236func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
237 return mcpBroker.Subscribe(ctx)
238}
239
240// GetMCPStates returns the current state of all MCP clients
241func GetMCPStates() map[string]MCPClientInfo {
242 return maps.Collect(mcpStates.Seq2())
243}
244
245// GetMCPState returns the state of a specific MCP client
246func GetMCPState(name string) (MCPClientInfo, bool) {
247 return mcpStates.Get(name)
248}
249
250// updateMCPState updates the state of an MCP client and publishes an event
251func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
252 info := MCPClientInfo{
253 Name: name,
254 State: state,
255 Error: err,
256 Client: client,
257 ToolCount: toolCount,
258 }
259 switch state {
260 case MCPStateConnected:
261 info.ConnectedAt = time.Now()
262 case MCPStateError:
263 updateMcpTools(name, nil)
264 mcpClients.Del(name)
265 }
266 mcpStates.Set(name, info)
267
268 // Publish state change event
269 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
270 Type: MCPEventStateChanged,
271 Name: name,
272 State: state,
273 Error: err,
274 ToolCount: toolCount,
275 })
276}
277
278// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
279func CloseMCPClients() error {
280 var errs []error
281 for name, c := range mcpClients.Seq2() {
282 if err := c.Close(); err != nil &&
283 !errors.Is(err, io.EOF) &&
284 !errors.Is(err, context.Canceled) &&
285 err.Error() != "signal: killed" {
286 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
287 }
288 }
289 mcpBroker.Shutdown()
290 return errors.Join(errs...)
291}
292
293func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
294 mcpToolsOnce.Do(func() {
295 var wg sync.WaitGroup
296 // Initialize states for all configured MCPs
297 for name, m := range cfg.MCP {
298 if m.Disabled {
299 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
300 slog.Debug("skipping disabled mcp", "name", name)
301 continue
302 }
303
304 // Set initial starting state
305 updateMCPState(name, MCPStateStarting, nil, nil, 0)
306
307 wg.Add(1)
308 go func(name string, m config.MCPConfig) {
309 defer func() {
310 wg.Done()
311 if r := recover(); r != nil {
312 var err error
313 switch v := r.(type) {
314 case error:
315 err = v
316 case string:
317 err = fmt.Errorf("panic: %s", v)
318 default:
319 err = fmt.Errorf("panic: %v", v)
320 }
321 updateMCPState(name, MCPStateError, err, nil, 0)
322 slog.Error("panic in mcp client initialization", "error", err, "name", name)
323 }
324 }()
325 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
326 if err != nil {
327 return
328 }
329
330 mcpClients.Set(name, c)
331
332 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
333 if err != nil {
334 slog.Error("error listing tools", "error", err)
335 updateMCPState(name, MCPStateError, err, nil, 0)
336 c.Close()
337 return
338 }
339
340 updateMcpTools(name, tools)
341 mcpClients.Set(name, c)
342 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
343 }(name, m)
344 }
345 wg.Wait()
346 })
347 return slices.Collect(mcpTools.Seq())
348}
349
350// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
351func updateMcpTools(mcpName string, tools []*McpTool) {
352 if len(tools) == 0 {
353 mcpClient2Tools.Del(mcpName)
354 } else {
355 mcpClient2Tools.Set(mcpName, tools)
356 }
357 for _, tools := range mcpClient2Tools.Seq2() {
358 for _, t := range tools {
359 mcpTools.Set(t.Name(), t)
360 }
361 }
362}
363
364func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
365 timeout := mcpTimeout(m)
366 mcpCtx, cancel := context.WithCancel(ctx)
367 cancelTimer := time.AfterFunc(timeout, cancel)
368
369 transport, err := createMCPTransport(mcpCtx, m, resolver)
370 if err != nil {
371 updateMCPState(name, MCPStateError, err, nil, 0)
372 slog.Error("error creating mcp client", "error", err, "name", name)
373 return nil, err
374 }
375
376 client := mcp.NewClient(
377 &mcp.Implementation{
378 Name: "crush",
379 Version: version.Version,
380 Title: "Crush",
381 },
382 &mcp.ClientOptions{
383 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
384 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
385 Type: MCPEventToolsListChanged,
386 Name: name,
387 })
388 },
389 KeepAlive: time.Minute * 10,
390 },
391 )
392
393 session, err := client.Connect(mcpCtx, transport, nil)
394 if err != nil {
395 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
396 slog.Error("error starting mcp client", "error", err, "name", name)
397 cancel()
398 return nil, err
399 }
400
401 cancelTimer.Stop()
402 slog.Info("Initialized mcp client", "name", name)
403 return session, nil
404}
405
// maybeTimeoutErr replaces a context-cancellation error with a friendlier
// "timed out after <timeout>" message. The connect-phase contexts in this
// file are cancelled by a timeout timer. Any other error, including nil, is
// returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.Canceled) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
412
413func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
414 switch m.Type {
415 case config.MCPStdio:
416 command, err := resolver.ResolveValue(m.Command)
417 if err != nil {
418 return nil, fmt.Errorf("invalid mcp command: %w", err)
419 }
420 if strings.TrimSpace(command) == "" {
421 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
422 }
423 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
424 cmd.Env = m.ResolvedEnv()
425 return &mcp.CommandTransport{
426 Command: cmd,
427 }, nil
428 case config.MCPHttp:
429 if strings.TrimSpace(m.URL) == "" {
430 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
431 }
432 client := &http.Client{
433 Transport: &headerRoundTripper{
434 headers: m.ResolvedHeaders(),
435 },
436 }
437 return &mcp.StreamableClientTransport{
438 Endpoint: m.URL,
439 HTTPClient: client,
440 }, nil
441 case config.MCPSSE:
442 if strings.TrimSpace(m.URL) == "" {
443 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
444 }
445 client := &http.Client{
446 Transport: &headerRoundTripper{
447 headers: m.ResolvedHeaders(),
448 },
449 }
450 return &mcp.SSEClientTransport{
451 Endpoint: m.URL,
452 HTTPClient: client,
453 }, nil
454 default:
455 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
456 }
457}
458
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers applied, replacing any existing values
// for those header names. The http.RoundTripper contract forbids modifying
// the caller's request, so the headers are set on a clone. The body is shared
// and is consumed by the underlying transport as usual.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	clone := req.Clone(req.Context())
	if clone.Header == nil {
		clone.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		clone.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(clone)
}
469
470func mcpTimeout(m config.MCPConfig) time.Duration {
471 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
472}