1package agent
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "net/http"
12 "os/exec"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/pubsub"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/modelcontextprotocol/go-sdk/mcp"
25)
26
// MCPState represents the lifecycle state of an MCP client.
type MCPState int

// MCP client states. The numeric values are spelled out explicitly; they
// match declaration order starting at zero.
const (
	// MCPStateDisabled means the server is disabled in the configuration and
	// was not started.
	MCPStateDisabled MCPState = 0
	// MCPStateStarting means a connection attempt is in progress.
	MCPStateStarting MCPState = 1
	// MCPStateConnected means a session to the server is established.
	MCPStateConnected MCPState = 2
	// MCPStateError means an operation on the server failed (creating the
	// transport, connecting, listing tools, or a health-check ping).
	MCPStateError MCPState = 3
)

// String returns the lowercase name of s, or "unknown" for any value that
// is not one of the defined states.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
51
// MCPEventType identifies which kind of change an MCPEvent reports.
type MCPEventType string

// MCPEventStateChanged is published each time an MCP client's state is
// recorded.
const MCPEventStateChanged MCPEventType = "state_changed"

// MCPEventToolsListChanged is published when an MCP server notifies the
// client that its tool list changed.
const MCPEventToolsListChanged MCPEventType = "tools_list_changed"
59
// MCPEvent represents an event in the MCP system. Events are published on
// mcpBroker and delivered to SubscribeMCPEvents subscribers.
type MCPEvent struct {
	// Type tells which kind of change this event reports.
	Type MCPEventType
	// Name is the configured name of the MCP server the event concerns.
	Name string
	// State is the client's newly recorded state. Only state-change events
	// fill it in; tools-list-changed events leave it zero.
	State MCPState
	// Error is the error recorded with the state change, if any.
	Error error
	// ToolCount is the tool count recorded with the state change.
	ToolCount int
}
68
// MCPClientInfo holds information about an MCP client's state. A snapshot is
// stored in mcpStates each time updateMCPState runs.
type MCPClientInfo struct {
	// Name is the configured name of the MCP server.
	Name string
	// State is the most recently recorded lifecycle state.
	State MCPState
	// Error is the error recorded with the state, or nil.
	Error error
	// Client is the session recorded with the state; nil unless the caller
	// supplied one (for example when the state is MCPStateConnected).
	Client *mcp.ClientSession
	// ToolCount is the number of tools recorded with the state.
	ToolCount int
	// ConnectedAt is stamped with time.Now() when the state is
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
78
79var (
80 mcpToolsOnce sync.Once
81 mcpTools = csync.NewMap[string, tools.BaseTool]()
82 mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
83 mcpClients = csync.NewMap[string, *mcp.ClientSession]()
84 mcpStates = csync.NewMap[string, MCPClientInfo]()
85 mcpBroker = pubsub.NewBroker[MCPEvent]()
86)
87
// McpTool wraps a single tool exposed by an MCP server so it can be used as
// a tools.BaseTool.
type McpTool struct {
	// mcpName is the configured name of the MCP server that owns the tool.
	mcpName string
	// tool is the tool definition returned by the server's ListTools call.
	tool *mcp.Tool
	// permissions is asked for approval before every execution of the tool.
	permissions permission.Service
	// workingDir is reported as the Path of each permission request.
	workingDir string
}
94
95func (b *McpTool) Name() string {
96 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
97}
98
99func (b *McpTool) Info() tools.ToolInfo {
100 input := b.tool.InputSchema.(map[string]any)
101 required, _ := input["required"].([]string)
102 parameters, _ := input["properties"].(map[string]any)
103 return tools.ToolInfo{
104 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
105 Description: b.tool.Description,
106 Parameters: parameters,
107 Required: required,
108 }
109}
110
// runTool calls toolName on the MCP server registered as name. input is the
// JSON-encoded arguments object for the call.
//
// Every failure is reported as an error ToolResponse, never as a Go error:
// unparsable input, an unavailable server, a failed call, or a result the
// server flagged with IsError. The returned error is always nil.
//
// Text content items contribute their text. Any other content item is
// formatted with %v. Items are joined with newlines.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	c, err := getOrRenewClient(ctx, name)
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}
	result, err := c.CallTool(ctx, &mcp.CallToolParams{
		Name:      toolName,
		Arguments: args,
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	output := make([]string, 0, len(result.Content))
	for _, v := range result.Content {
		if text, ok := v.(*mcp.TextContent); ok {
			output = append(output, text.Text)
			continue
		}
		output = append(output, fmt.Sprintf("%v", v))
	}
	text := strings.Join(output, "\n")

	// MCP servers report tool-level failures inside the result (IsError)
	// rather than as a protocol error. Surface them as error responses so the
	// failure is not mistaken for a successful call.
	if result.IsError {
		return tools.NewTextErrorResponse(text), nil
	}
	return tools.NewTextResponse(text), nil
}
139
// getOrRenewClient returns a usable session for the MCP server registered as
// name. It fails immediately if no session is cached for name.
//
// The cached session is pinged first, bounded by mcpTimeout. If the ping
// fails:
//   - the server is marked MCPStateError (which drops it from mcpClients);
//   - the stale session is closed;
//   - a new session is created from the current configuration and cached.
//
// The previously recorded tool count is carried over because tools are not
// re-listed here.
func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
	sess, ok := mcpClients.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	cfg := config.Get()
	m := cfg.MCP[name]
	state, _ := mcpStates.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := sess.Ping(pingCtx, nil)
	if err == nil {
		return sess, nil
	}
	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)

	// updateMCPState removed the stale session from mcpClients without
	// closing it. Close it here so its transport (e.g. a stdio subprocess)
	// is not leaked. This is best effort: the session is already unhealthy.
	if closeErr := sess.Close(); closeErr != nil {
		slog.Debug("error closing stale mcp session", "error", closeErr, "name", name)
	}

	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
	mcpClients.Set(name, sess)
	return sess, nil
}
168
// Run executes the MCP tool for a tool call.
//
// It requires session and message IDs in ctx. It then asks the permission
// service to approve running the tool with params.Input and returns
// permission.ErrorPermissionDenied when approval is refused. Otherwise the
// call is forwarded to the owning MCP server via runTool.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for running an mcp tool")
	}

	// Name() is all that is needed here; Info() would also re-parse the schema.
	toolName := b.Name()
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters:", toolName),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
192
193func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
194 result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
195 if err != nil {
196 return nil, err
197 }
198 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
199 for _, tool := range result.Tools {
200 mcpTools = append(mcpTools, &McpTool{
201 mcpName: name,
202 tool: tool,
203 permissions: permissions,
204 workingDir: workingDir,
205 })
206 }
207 return mcpTools, nil
208}
209
// SubscribeMCPEvents subscribes to MCP state and tool-list events. The
// subscription is tied to ctx and is handled by the package's event broker.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}
214
// GetMCPStates returns a newly built map holding the last recorded state of
// every known MCP server, keyed by server name.
func GetMCPStates() map[string]MCPClientInfo {
	states := maps.Collect(mcpStates.Seq2())
	return states
}
219
220// GetMCPState returns the state of a specific MCP client
221func GetMCPState(name string) (MCPClientInfo, bool) {
222 return mcpStates.Get(name)
223}
224
225// updateMCPState updates the state of an MCP client and publishes an event
226func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
227 info := MCPClientInfo{
228 Name: name,
229 State: state,
230 Error: err,
231 Client: client,
232 ToolCount: toolCount,
233 }
234 switch state {
235 case MCPStateConnected:
236 info.ConnectedAt = time.Now()
237 case MCPStateError:
238 updateMcpTools(name, nil)
239 mcpClients.Del(name)
240 }
241 mcpStates.Set(name, info)
242
243 // Publish state change event
244 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
245 Type: MCPEventStateChanged,
246 Name: name,
247 State: state,
248 Error: err,
249 ToolCount: toolCount,
250 })
251}
252
// CloseMCPClients closes every cached MCP session and then shuts down the
// event broker. Call it during application shutdown. Close failures do not
// stop the loop; they are returned together as a single joined error.
func CloseMCPClients() error {
	var closeErrs []error
	for name, session := range mcpClients.Seq2() {
		err := session.Close()
		if err == nil {
			continue
		}
		closeErrs = append(closeErrs, fmt.Errorf("close mcp: %s: %w", name, err))
	}
	mcpBroker.Shutdown()
	return errors.Join(closeErrs...)
}
264
// doGetMCPTools connects to every enabled MCP server in cfg concurrently,
// lists its tools and registers them. Each server's progress is recorded
// via updateMCPState, and servers disabled in the configuration are marked
// MCPStateDisabled and skipped.
//
// Each connection attempt is bounded by mcpTimeout. A panic inside an
// attempt is recovered and recorded as MCPStateError. The function returns
// once every attempt has finished and its final state has been recorded.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
	var wg sync.WaitGroup
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		// Publish the starting state before the connection attempt begins.
		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred calls run last-in-first-out. The panic handler below
			// therefore records the error state before wg.Done lets
			// wg.Wait return.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
			defer cancel()

			session, err := createMCPSession(ctx, name, m, cfg.Resolver())
			if err != nil {
				// createMCPSession already recorded the error state and logged it.
				return
			}
			mcpClients.Set(name, session)

			serverTools, err := getTools(ctx, name, permissions, session, cfg.WorkingDir())
			if err != nil {
				slog.Error("error listing tools", "error", err, "name", name)
				updateMCPState(name, MCPStateError, err, nil, 0)
				// Best effort: the server is already marked as failed.
				_ = session.Close()
				return
			}

			updateMcpTools(name, serverTools)
			updateMCPState(name, MCPStateConnected, nil, session, len(serverTools))
		}(name, m)
	}
	wg.Wait()
}
322
323// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
324func updateMcpTools(mcpName string, tools []tools.BaseTool) {
325 if len(tools) == 0 {
326 mcpClient2Tools.Del(mcpName)
327 } else {
328 mcpClient2Tools.Set(mcpName, tools)
329 }
330 for _, tools := range mcpClient2Tools.Seq2() {
331 for _, t := range tools {
332 mcpTools.Set(t.Name(), t)
333 }
334 }
335}
336
// createMCPSession builds the transport for m, connects an MCP client named
// "crush" through it and returns the resulting session.
//
// Tool-list-changed notifications from the server are re-published on
// mcpBroker as MCPEventToolsListChanged events. If the transport cannot be
// built, or the connection fails or exceeds mcpTimeout(m), the server is
// marked MCPStateError, the failure is logged and the error is returned.
// A timeout is reported as "timed out after <timeout>".
func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
	transport, err := createMCPTransport(m, resolver)
	if err != nil {
		updateMCPState(name, MCPStateError, err, nil, 0)
		slog.Error("error creating mcp client", "error", err, "name", name)
		return nil, err
	}

	client := mcp.NewClient(&mcp.Implementation{
		Name:    "crush",
		Version: version.Version,
		Title:   "Crush",
	}, &mcp.ClientOptions{
		ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
			mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
				Type: MCPEventToolsListChanged,
				Name: name,
			})
		},
	})

	// Only the handshake is time-bounded. The timer cancels mcpCtx if
	// Connect runs past the timeout and is stopped on success, so mcpCtx
	// stays alive afterwards.
	// NOTE(review): presumably the session keeps using mcpCtx after Connect
	// returns, which is why context.WithTimeout is not used. Confirm against
	// the SDK.
	timeout := mcpTimeout(m)
	mcpCtx, cancel := context.WithCancel(ctx)
	cancelTimer := time.AfterFunc(timeout, cancel)

	session, err := client.Connect(mcpCtx, transport, nil)
	if err != nil {
		cancelTimer.Stop()
		cancel()
		// Connect returns a nil session on failure. There is nothing to close,
		// and calling Close on the nil session would panic.
		timeoutErr := maybeTimeoutErr(err, timeout)
		updateMCPState(name, MCPStateError, timeoutErr, nil, 0)
		slog.Error("error starting mcp client", "error", err, "name", name)
		return nil, timeoutErr
	}

	cancelTimer.Stop()
	slog.Info("Initialized mcp client", "name", name)
	return session, nil
}
375
// maybeTimeoutErr turns a context cancellation or deadline error into a
// readable "timed out after <timeout>" error. Any other error, including
// nil, is returned unchanged.
//
// Both sentinels are recognised because callers use both forms of timeout:
// createMCPSession cancels a context from a timer (context.Canceled), while
// getOrRenewClient pings with context.WithTimeout
// (context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
382
// createMCPTransport builds the client transport described by m:
//   - config.MCPStdio: runs the resolved command with m.Args and
//     m.ResolvedEnv();
//   - config.MCPHttp: a streamable HTTP transport;
//   - config.MCPSSE: an SSE transport.
//
// The HTTP and SSE transports send m's resolved headers with every request.
// An empty command or URL, or an unknown transport type, is rejected with
// an error.
func createMCPTransport(m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		if err != nil {
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		}
		if strings.TrimSpace(command) == "" {
			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
		}
		cmd := exec.Command(home.Long(command), m.Args...)
		cmd.Env = m.ResolvedEnv()
		return &mcp.CommandTransport{Command: cmd}, nil

	case config.MCPHttp:
		if strings.TrimSpace(m.URL) == "" {
			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
		}
		return &mcp.StreamableClientTransport{
			Endpoint:   m.URL,
			HTTPClient: newMCPHeaderHTTPClient(m),
		}, nil

	case config.MCPSSE:
		if strings.TrimSpace(m.URL) == "" {
			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
		}
		return &mcp.SSEClientTransport{
			Endpoint:   m.URL,
			HTTPClient: newMCPHeaderHTTPClient(m),
		}, nil
	}
	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}

// newMCPHeaderHTTPClient returns an HTTP client that attaches m's resolved
// headers to every request through headerRoundTripper. No client-wide
// Timeout is set.
func newMCPHeaderHTTPClient(m config.MCPConfig) *http.Client {
	return &http.Client{
		Transport: &headerRoundTripper{headers: m.ResolvedHeaders()},
	}
}
428
429type headerRoundTripper struct {
430 headers map[string]string
431}
432
433func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
434 for k, v := range rt.headers {
435 req.Header.Set(k, v)
436 }
437 return http.DefaultTransport.RoundTrip(req)
438}
439
440func mcpTimeout(m config.MCPConfig) time.Duration {
441 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
442}