1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/permission"
17 "github.com/charmbracelet/crush/internal/pubsub"
18 "github.com/charmbracelet/crush/internal/version"
19 "github.com/mark3labs/mcp-go/client"
20 "github.com/mark3labs/mcp-go/client/transport"
21 "github.com/mark3labs/mcp-go/mcp"
22)
23
// MCPState represents the lifecycle state of a configured MCP client.
type MCPState int

// Possible MCPState values. The zero value is MCPStateDisabled.
const (
	// MCPStateDisabled marks a client that is turned off in the configuration.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting marks a client whose connection is being set up.
	MCPStateStarting
	// MCPStateConnected marks a client that initialized successfully.
	MCPStateConnected
	// MCPStateError marks a client whose setup failed; the failure is kept
	// alongside the state.
	MCPStateError
)

// String returns the lowercase name of s, or "unknown" for any value outside
// the declared constants.
func (s MCPState) String() string {
	labels := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
48
// MCPEventType identifies the kind of MCPEvent published on the MCP broker.
type MCPEventType string

// MCPEventStateChanged is the event type published whenever an MCP client's
// recorded state is updated.
const MCPEventStateChanged MCPEventType = "state_changed"
55
// MCPEvent describes a change to an MCP client's state. In this file events
// are published by updateMCPState and received through SubscribeMCPEvents.
type MCPEvent struct {
	Type      MCPEventType // kind of event; only MCPEventStateChanged is published in this file
	Name      string       // configured name of the MCP client
	State     MCPState     // state the client moved to
	Error     error        // failure behind an MCPStateError transition, if any
	ToolCount int          // number of tools the client exposes (non-zero only when connected, as used here)
}
64
// MCPClientInfo is the last recorded state of one MCP client, as stored by
// updateMCPState and returned by GetMCPState/GetMCPStates.
type MCPClientInfo struct {
	Name        string         // configured name of the MCP client
	State       MCPState       // most recently recorded state
	Error       error          // failure behind an MCPStateError state, if any
	Client      *client.Client // client handle; callers in this file set it only when connected
	ToolCount   int            // number of tools listed when the client connected
	ConnectedAt time.Time      // time the client became MCPStateConnected; zero for other states
}
74
// Package-level MCP registry shared by the functions in this file:
//   - mcpClients maps a configured MCP name to its initialized client. An
//     entry is added once Initialize succeeds (doGetMCPTools) and removed if
//     listing the client's tools fails (getTools).
//   - mcpStates maps a configured MCP name to its latest MCPClientInfo, as
//     written by updateMCPState.
//   - mcpBroker publishes MCPEvent values to SubscribeMCPEvents subscribers.
//   - mcpToolsOnce and mcpTools are not referenced in this part of the file;
//     presumably they cache the discovered tool list elsewhere in the
//     package (NOTE(review): confirm).
var (
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	mcpClients   = csync.NewMap[string, *client.Client]()
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
82
// McpTool wraps a single tool exposed by an MCP server so it can be used as a
// tools.BaseTool. Each execution is gated behind a permission request.
type McpTool struct {
	mcpName     string             // configured name of the MCP server that owns the tool
	tool        mcp.Tool           // tool definition as returned by the server's ListTools
	permissions permission.Service // service asked to approve each Run
	workingDir  string             // path attached to permission requests
}
89
90func (b *McpTool) Name() string {
91 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
92}
93
94func (b *McpTool) Info() tools.ToolInfo {
95 required := b.tool.InputSchema.Required
96 if required == nil {
97 required = make([]string, 0)
98 }
99 return tools.ToolInfo{
100 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
101 Description: b.tool.Description,
102 Parameters: b.tool.InputSchema.Properties,
103 Required: required,
104 }
105}
106
107func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
108 var args map[string]any
109 if err := json.Unmarshal([]byte(input), &args); err != nil {
110 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
111 }
112 c, ok := mcpClients.Get(name)
113 if !ok {
114 return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
115 }
116 result, err := c.CallTool(ctx, mcp.CallToolRequest{
117 Params: mcp.CallToolParams{
118 Name: toolName,
119 Arguments: args,
120 },
121 })
122 if err != nil {
123 return tools.NewTextErrorResponse(err.Error()), nil
124 }
125
126 output := make([]string, 0, len(result.Content))
127 for _, v := range result.Content {
128 if v, ok := v.(mcp.TextContent); ok {
129 output = append(output, v.Text)
130 } else {
131 output = append(output, fmt.Sprintf("%v", v))
132 }
133 }
134 return tools.NewTextResponse(strings.Join(output, "\n")), nil
135}
136
137func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
138 sessionID, messageID := tools.GetContextValues(ctx)
139 if sessionID == "" || messageID == "" {
140 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
141 }
142 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
143 p := b.permissions.Request(
144 permission.CreatePermissionRequest{
145 SessionID: sessionID,
146 ToolCallID: params.ID,
147 Path: b.workingDir,
148 ToolName: b.Info().Name,
149 Action: "execute",
150 Description: permissionDescription,
151 Params: params.Input,
152 },
153 )
154 if !p {
155 return tools.ToolResponse{}, permission.ErrorPermissionDenied
156 }
157
158 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
159}
160
161func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
162 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
163 if err != nil {
164 slog.Error("error listing tools", "error", err)
165 updateMCPState(name, MCPStateError, err, nil, 0)
166 c.Close()
167 mcpClients.Del(name)
168 return nil
169 }
170 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
171 for _, tool := range result.Tools {
172 mcpTools = append(mcpTools, &McpTool{
173 mcpName: name,
174 tool: tool,
175 permissions: permissions,
176 workingDir: workingDir,
177 })
178 }
179 return mcpTools
180}
181
182// SubscribeMCPEvents returns a channel for MCP events
183func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
184 return mcpBroker.Subscribe(ctx)
185}
186
187// GetMCPStates returns the current state of all MCP clients
188func GetMCPStates() map[string]MCPClientInfo {
189 states := make(map[string]MCPClientInfo)
190 for name, info := range mcpStates.Seq2() {
191 states[name] = info
192 }
193 return states
194}
195
196// GetMCPState returns the state of a specific MCP client
197func GetMCPState(name string) (MCPClientInfo, bool) {
198 return mcpStates.Get(name)
199}
200
201// updateMCPState updates the state of an MCP client and publishes an event
202func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
203 info := MCPClientInfo{
204 Name: name,
205 State: state,
206 Error: err,
207 Client: client,
208 ToolCount: toolCount,
209 }
210 if state == MCPStateConnected {
211 info.ConnectedAt = time.Now()
212 }
213 mcpStates.Set(name, info)
214
215 // Publish state change event
216 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
217 Type: MCPEventStateChanged,
218 Name: name,
219 State: state,
220 Error: err,
221 ToolCount: toolCount,
222 })
223}
224
225// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
226func CloseMCPClients() {
227 for c := range mcpClients.Seq() {
228 _ = c.Close()
229 }
230 mcpBroker.Shutdown()
231}
232
// mcpInitRequest is the initialize request sent to each MCP client once it has
// been created (and started, for non-stdio transports). It advertises
// mcp.LATEST_PROTOCOL_VERSION and identifies this application as "Crush" at
// version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
242
243func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
244 var wg sync.WaitGroup
245 result := csync.NewSlice[tools.BaseTool]()
246
247 // Initialize states for all configured MCPs
248 for name, m := range cfg.MCP {
249 if m.Disabled {
250 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
251 slog.Debug("skipping disabled mcp", "name", name)
252 continue
253 }
254
255 // Set initial starting state
256 updateMCPState(name, MCPStateStarting, nil, nil, 0)
257
258 wg.Add(1)
259 go func(name string, m config.MCPConfig) {
260 defer func() {
261 wg.Done()
262 if r := recover(); r != nil {
263 var err error
264 switch v := r.(type) {
265 case error:
266 err = v
267 case string:
268 err = fmt.Errorf("panic: %s", v)
269 default:
270 err = fmt.Errorf("panic: %v", v)
271 }
272 updateMCPState(name, MCPStateError, err, nil, 0)
273 slog.Error("panic in mcp client initialization", "error", err, "name", name)
274 }
275 }()
276
277 ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
278 defer cancel()
279 c, err := createMcpClient(m)
280 if err != nil {
281 updateMCPState(name, MCPStateError, err, nil, 0)
282 slog.Error("error creating mcp client", "error", err, "name", name)
283 return
284 }
285 // Only call Start() for non-stdio clients, as stdio clients auto-start
286 if m.Type != config.MCPStdio {
287 if err := c.Start(ctx); err != nil {
288 updateMCPState(name, MCPStateError, err, nil, 0)
289 slog.Error("error starting mcp client", "error", err, "name", name)
290 _ = c.Close()
291 return
292 }
293 }
294 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
295 updateMCPState(name, MCPStateError, err, nil, 0)
296 slog.Error("error initializing mcp client", "error", err, "name", name)
297 _ = c.Close()
298 return
299 }
300
301 slog.Info("Initialized mcp client", "name", name)
302 mcpClients.Set(name, c)
303
304 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
305 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
306 result.Append(tools...)
307 }(name, m)
308 }
309 wg.Wait()
310 return slices.Collect(result.Seq())
311}
312
313func createMcpClient(m config.MCPConfig) (*client.Client, error) {
314 switch m.Type {
315 case config.MCPStdio:
316 return client.NewStdioMCPClientWithOptions(
317 m.Command,
318 m.ResolvedEnv(),
319 m.Args,
320 transport.WithCommandLogger(mcpLogger{}),
321 )
322 case config.MCPHttp:
323 return client.NewStreamableHttpClient(
324 m.URL,
325 transport.WithHTTPHeaders(m.ResolvedHeaders()),
326 transport.WithHTTPLogger(mcpLogger{}),
327 )
328 case config.MCPSse:
329 return client.NewSSEMCPClient(
330 m.URL,
331 client.WithHeaders(m.ResolvedHeaders()),
332 transport.WithSSELogger(mcpLogger{}),
333 )
334 default:
335 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
336 }
337}
338
// mcpLogger adapts log/slog to the Errorf/Infof logger that the mcp-go
// transports accept (see createMcpClient). Each printf-style message is
// formatted with fmt.Sprintf and logged through the default slog logger.
type mcpLogger struct{}

// Errorf formats the message and logs it at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	message := fmt.Sprintf(format, v...)
	slog.Error(message)
}

// Infof formats the message and logs it at info level.
func (mcpLogger) Infof(format string, v ...any) {
	message := fmt.Sprintf(format, v...)
	slog.Info(message)
}