1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "net/http"
13 "os/exec"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/home"
21 "github.com/charmbracelet/crush/internal/llm/tools"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
// MCPState is the lifecycle state of a configured MCP client.
type MCPState int

const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
52
// MCPEventType identifies the kind of MCPEvent published on the MCP event broker.
type MCPEventType string

const (
	// MCPEventStateChanged is published every time updateMCPState records a
	// state for an MCP client.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when an MCP server notifies the
	// client that its tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent is the payload delivered to MCP event subscribers.
type MCPEvent struct {
	Type MCPEventType
	// Name is the configured name of the MCP the event concerns.
	Name string
	// State, Error and ToolCount describe the recorded state for
	// MCPEventStateChanged events; MCPEventToolsListChanged events published
	// in this file set only Type and Name.
	State     MCPState
	Error     error
	ToolCount int
}

// MCPClientInfo is a snapshot of an MCP client's state, as returned by
// GetMCPState and GetMCPStates.
type MCPClientInfo struct {
	Name  string
	State MCPState
	// Error is the error passed along with the state; call sites in this file
	// supply one only for MCPStateError.
	Error error
	// Client is the session passed along with the state; call sites in this
	// file supply a non-nil session only for MCPStateConnected.
	Client    *mcp.ClientSession
	ToolCount int
	// ConnectedAt is stamped when the state is recorded as MCPStateConnected
	// and is the zero time otherwise.
	ConnectedAt time.Time
}
79
var (
	// mcpToolsOnce is not used in this view; presumably it guards a one-time
	// call to doGetMCPTools elsewhere in the package — TODO confirm.
	mcpToolsOnce sync.Once
	// mcpTools is the flattened registry of MCP tools, keyed by tool Name().
	mcpTools = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP's configured name to the tools it exposes.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients maps an MCP's configured name to its live session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP's configured name to its latest recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvents to SubscribeMCPEvents subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
88
// McpTool adapts a single tool exposed by an MCP server to crush's
// tools.BaseTool interface.
type McpTool struct {
	// mcpName is the configured name of the MCP server that owns the tool.
	mcpName string
	// tool is the tool definition as listed by the server.
	tool *mcp.Tool
	// permissions is asked for approval before every execution.
	permissions permission.Service
	// workingDir is reported as the path in permission requests.
	workingDir string
}
95
96func (b *McpTool) Name() string {
97 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
98}
99
100func (b *McpTool) Info() tools.ToolInfo {
101 input := b.tool.InputSchema.(map[string]any)
102 required, _ := input["required"].([]string)
103 parameters, _ := input["properties"].(map[string]any)
104 return tools.ToolInfo{
105 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
106 Description: b.tool.Description,
107 Parameters: parameters,
108 Required: required,
109 }
110}
111
112func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
113 var args map[string]any
114 if err := json.Unmarshal([]byte(input), &args); err != nil {
115 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
116 }
117
118 c, err := getOrRenewClient(ctx, name)
119 if err != nil {
120 return tools.NewTextErrorResponse(err.Error()), nil
121 }
122 result, err := c.CallTool(ctx, &mcp.CallToolParams{
123 Name: toolName,
124 Arguments: args,
125 })
126 if err != nil {
127 return tools.NewTextErrorResponse(err.Error()), nil
128 }
129
130 output := make([]string, 0, len(result.Content))
131 for _, v := range result.Content {
132 if vv, ok := v.(*mcp.TextContent); ok {
133 output = append(output, vv.Text)
134 } else {
135 output = append(output, fmt.Sprintf("%v", v))
136 }
137 }
138 return tools.NewTextResponse(strings.Join(output, "\n")), nil
139}
140
141func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
142 sess, ok := mcpClients.Get(name)
143 if !ok {
144 return nil, fmt.Errorf("mcp '%s' not available", name)
145 }
146
147 cfg := config.Get()
148 m := cfg.MCP[name]
149 state, _ := mcpStates.Get(name)
150
151 timeout := mcpTimeout(m)
152 pingCtx, cancel := context.WithTimeout(ctx, timeout)
153 defer cancel()
154 err := sess.Ping(pingCtx, nil)
155 if err == nil {
156 return sess, nil
157 }
158 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
159
160 sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
161 if err != nil {
162 return nil, err
163 }
164
165 updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
166 mcpClients.Set(name, sess)
167 return sess, nil
168}
169
170func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
171 sessionID, messageID := tools.GetContextValues(ctx)
172 if sessionID == "" || messageID == "" {
173 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
174 }
175 permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
176 p := b.permissions.Request(
177 permission.CreatePermissionRequest{
178 SessionID: sessionID,
179 ToolCallID: params.ID,
180 Path: b.workingDir,
181 ToolName: b.Info().Name,
182 Action: "execute",
183 Description: permissionDescription,
184 Params: params.Input,
185 },
186 )
187 if !p {
188 return tools.ToolResponse{}, permission.ErrorPermissionDenied
189 }
190
191 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
192}
193
194func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
195 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
196 if err != nil {
197 return nil, err
198 }
199 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
200 for _, tool := range result.Tools {
201 mcpTools = append(mcpTools, &McpTool{
202 mcpName: name,
203 tool: tool,
204 permissions: permissions,
205 workingDir: workingDir,
206 })
207 }
208 return mcpTools, nil
209}
210
211// SubscribeMCPEvents returns a channel for MCP events
212func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
213 return mcpBroker.Subscribe(ctx)
214}
215
216// GetMCPStates returns the current state of all MCP clients
217func GetMCPStates() map[string]MCPClientInfo {
218 return maps.Collect(mcpStates.Seq2())
219}
220
221// GetMCPState returns the state of a specific MCP client
222func GetMCPState(name string) (MCPClientInfo, bool) {
223 return mcpStates.Get(name)
224}
225
226// updateMCPState updates the state of an MCP client and publishes an event
227func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
228 info := MCPClientInfo{
229 Name: name,
230 State: state,
231 Error: err,
232 Client: client,
233 ToolCount: toolCount,
234 }
235 switch state {
236 case MCPStateConnected:
237 info.ConnectedAt = time.Now()
238 case MCPStateError:
239 updateMcpTools(name, nil)
240 mcpClients.Del(name)
241 }
242 mcpStates.Set(name, info)
243
244 // Publish state change event
245 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
246 Type: MCPEventStateChanged,
247 Name: name,
248 State: state,
249 Error: err,
250 ToolCount: toolCount,
251 })
252}
253
254// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
255func CloseMCPClients() error {
256 var errs []error
257 for name, c := range mcpClients.Seq2() {
258 if err := c.Close(); err != nil &&
259 !errors.Is(err, io.EOF) &&
260 !errors.Is(err, context.Canceled) &&
261 err.Error() != "signal: killed" {
262 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
263 }
264 }
265 mcpBroker.Shutdown()
266 return errors.Join(errs...)
267}
268
269func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
270 var wg sync.WaitGroup
271 // Initialize states for all configured MCPs
272 for name, m := range cfg.MCP {
273 if m.Disabled {
274 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
275 slog.Debug("skipping disabled mcp", "name", name)
276 continue
277 }
278
279 // Set initial starting state
280 updateMCPState(name, MCPStateStarting, nil, nil, 0)
281
282 wg.Add(1)
283 go func(name string, m config.MCPConfig) {
284 defer func() {
285 wg.Done()
286 if r := recover(); r != nil {
287 var err error
288 switch v := r.(type) {
289 case error:
290 err = v
291 case string:
292 err = fmt.Errorf("panic: %s", v)
293 default:
294 err = fmt.Errorf("panic: %v", v)
295 }
296 updateMCPState(name, MCPStateError, err, nil, 0)
297 slog.Error("panic in mcp client initialization", "error", err, "name", name)
298 }
299 }()
300
301 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
302 defer cancel()
303
304 c, err := createMCPSession(ctx, name, m, cfg.Resolver())
305 if err != nil {
306 return
307 }
308
309 mcpClients.Set(name, c)
310
311 tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
312 if err != nil {
313 slog.Error("error listing tools", "error", err)
314 updateMCPState(name, MCPStateError, err, nil, 0)
315 c.Close()
316 return
317 }
318
319 updateMcpTools(name, tools)
320 mcpClients.Set(name, c)
321 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
322 }(name, m)
323 }
324 wg.Wait()
325}
326
327// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
328func updateMcpTools(mcpName string, tools []tools.BaseTool) {
329 if len(tools) == 0 {
330 mcpClient2Tools.Del(mcpName)
331 } else {
332 mcpClient2Tools.Set(mcpName, tools)
333 }
334 for _, tools := range mcpClient2Tools.Seq2() {
335 for _, t := range tools {
336 mcpTools.Set(t.Name(), t)
337 }
338 }
339}
340
341func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
342 timeout := mcpTimeout(m)
343 mcpCtx, cancel := context.WithCancel(ctx)
344 cancelTimer := time.AfterFunc(timeout, cancel)
345
346 transport, err := createMCPTransport(mcpCtx, m, resolver)
347 if err != nil {
348 updateMCPState(name, MCPStateError, err, nil, 0)
349 slog.Error("error creating mcp client", "error", err, "name", name)
350 return nil, err
351 }
352
353 client := mcp.NewClient(
354 &mcp.Implementation{
355 Name: "crush",
356 Version: version.Version,
357 Title: "Crush",
358 },
359 &mcp.ClientOptions{
360 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
361 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
362 Type: MCPEventToolsListChanged,
363 Name: name,
364 })
365 },
366 KeepAlive: time.Minute * 10,
367 },
368 )
369
370 session, err := client.Connect(mcpCtx, transport, nil)
371 if err != nil {
372 updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
373 slog.Error("error starting mcp client", "error", err, "name", name)
374 cancel()
375 return nil, err
376 }
377
378 cancelTimer.Stop()
379 slog.Info("Initialized mcp client", "name", name)
380 return session, nil
381}
382
// maybeTimeoutErr converts a context cancellation or deadline error into a
// friendlier "timed out after <timeout>" error. All other errors, including
// nil, are returned unchanged.
//
// Cancellation comes from createMCPSession's timer-driven cancel. A deadline
// comes from context.WithTimeout, as used for pings.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
389
390func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
391 switch m.Type {
392 case config.MCPStdio:
393 command, err := resolver.ResolveValue(m.Command)
394 if err != nil {
395 return nil, fmt.Errorf("invalid mcp command: %w", err)
396 }
397 if strings.TrimSpace(command) == "" {
398 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
399 }
400 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
401 cmd.Env = m.ResolvedEnv()
402 return &mcp.CommandTransport{
403 Command: cmd,
404 }, nil
405 case config.MCPHttp:
406 if strings.TrimSpace(m.URL) == "" {
407 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
408 }
409 client := &http.Client{
410 Transport: &headerRoundTripper{
411 headers: m.ResolvedHeaders(),
412 },
413 }
414 return &mcp.StreamableClientTransport{
415 Endpoint: m.URL,
416 HTTPClient: client,
417 }, nil
418 case config.MCPSSE:
419 if strings.TrimSpace(m.URL) == "" {
420 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
421 }
422 client := &http.Client{
423 Transport: &headerRoundTripper{
424 headers: m.ResolvedHeaders(),
425 },
426 }
427 return &mcp.SSEClientTransport{
428 Endpoint: m.URL,
429 HTTPClient: client,
430 }, nil
431 default:
432 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
433 }
434}
435
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip implements http.RoundTripper.
//
// The RoundTripper contract forbids modifying the caller's request, so the
// headers are set on a clone. Configured headers replace any existing values
// with the same key.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
446
447func mcpTimeout(m config.MCPConfig) time.Duration {
448 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
449}