1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/version"
	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
22
// MCPState is the lifecycle state of a configured MCP client.
type MCPState int

const (
	// MCPStateDisabled marks an MCP that is disabled in the configuration and is never started.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting marks an MCP whose client setup is in progress.
	MCPStateStarting
	// MCPStateConnected marks an MCP whose client started and initialized.
	MCPStateConnected
	// MCPStateError marks an MCP whose client setup failed.
	MCPStateError
)

// String returns the lowercase name of s, or "unknown" for any value outside
// the declared MCPState constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
47
// MCPEventType identifies the kind of MCPEvent published on the MCP broker.
type MCPEventType string

// MCPEventStateChanged is the event type published each time an MCP client's
// state is recorded.
const MCPEventStateChanged MCPEventType = "state_changed"
54
// MCPEvent represents an event in the MCP system. updateMCPState publishes one,
// with Type MCPEventStateChanged, every time it records a client's state.
type MCPEvent struct {
	// Type is the kind of event; currently always MCPEventStateChanged.
	Type MCPEventType
	// Name is the MCP's configured name (a key of cfg.MCP).
	Name string
	// State is the state the MCP was just put into.
	State MCPState
	// Error is the failure reason; callers in this file set it only together
	// with MCPStateError.
	Error error
	// ToolCount is the number of tools reported; callers in this file pass a
	// non-zero value only with MCPStateConnected.
	ToolCount int
}
63
// MCPClientInfo holds information about an MCP client's state. updateMCPState
// stores one snapshot per MCP name, and GetMCPState/GetMCPStates return copies.
type MCPClientInfo struct {
	// Name is the MCP's configured name.
	Name string
	// State is the most recently recorded lifecycle state.
	State MCPState
	// Error is the failure reason when State is MCPStateError.
	Error error
	// Client is the live client; callers in this file only set it alongside
	// MCPStateConnected and pass nil otherwise.
	Client *client.Client
	// ToolCount is the number of tools exposed by the client.
	ToolCount int
	// ConnectedAt is stamped when State is MCPStateConnected and is the zero
	// time for every other state.
	ConnectedAt time.Time
}
73
// Package-level MCP registry shared by the functions in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced in this section; presumably
	// they cache doGetMCPTools' result for a one-time loader elsewhere in the
	// package — TODO confirm against the caller.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP name to its initialized client. runTool looks
	// clients up here, and getTools removes them when listing tools fails.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to its latest MCPClientInfo snapshot.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker delivers MCPEvent state changes to SubscribeMCPEvents callers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
81
82type McpTool struct {
83 mcpName string
84 tool mcp.Tool
85 permissions permission.Service
86 workingDir string
87}
88
89func (b *McpTool) Name() string {
90 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
91}
92
93func (b *McpTool) Info() tools.ToolInfo {
94 required := b.tool.InputSchema.Required
95 if required == nil {
96 required = make([]string, 0)
97 }
98 return tools.ToolInfo{
99 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
100 Description: b.tool.Description,
101 Parameters: b.tool.InputSchema.Properties,
102 Required: required,
103 }
104}
105
106func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
107 var args map[string]any
108 if err := json.Unmarshal([]byte(input), &args); err != nil {
109 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
110 }
111 c, ok := mcpClients.Get(name)
112 if !ok {
113 return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
114 }
115 result, err := c.CallTool(ctx, mcp.CallToolRequest{
116 Params: mcp.CallToolParams{
117 Name: toolName,
118 Arguments: args,
119 },
120 })
121 if err != nil {
122 return tools.NewTextErrorResponse(err.Error()), nil
123 }
124
125 output := ""
126 for _, v := range result.Content {
127 if v, ok := v.(mcp.TextContent); ok {
128 output = v.Text
129 } else {
130 output = fmt.Sprintf("%v", v)
131 }
132 }
133
134 return tools.NewTextResponse(output), nil
135}
136
137func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
138 sessionID, messageID := tools.GetContextValues(ctx)
139 if sessionID == "" || messageID == "" {
140 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
141 }
142 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
143 p := b.permissions.Request(
144 permission.CreatePermissionRequest{
145 SessionID: sessionID,
146 ToolCallID: params.ID,
147 Path: b.workingDir,
148 ToolName: b.Info().Name,
149 Action: "execute",
150 Description: permissionDescription,
151 Params: params.Input,
152 },
153 )
154 if !p {
155 return tools.ToolResponse{}, permission.ErrorPermissionDenied
156 }
157
158 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
159}
160
161func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
162 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
163 if err != nil {
164 slog.Error("error listing tools", "error", err)
165 updateMCPState(name, MCPStateError, err, nil, 0)
166 c.Close()
167 mcpClients.Del(name)
168 return nil
169 }
170 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
171 for _, tool := range result.Tools {
172 mcpTools = append(mcpTools, &McpTool{
173 mcpName: name,
174 tool: tool,
175 permissions: permissions,
176 workingDir: workingDir,
177 })
178 }
179 return mcpTools
180}
181
182// SubscribeMCPEvents returns a channel for MCP events
183func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
184 return mcpBroker.Subscribe(ctx)
185}
186
187// GetMCPStates returns the current state of all MCP clients
188func GetMCPStates() map[string]MCPClientInfo {
189 states := make(map[string]MCPClientInfo)
190 for name, info := range mcpStates.Seq2() {
191 states[name] = info
192 }
193 return states
194}
195
196// GetMCPState returns the state of a specific MCP client
197func GetMCPState(name string) (MCPClientInfo, bool) {
198 return mcpStates.Get(name)
199}
200
201// updateMCPState updates the state of an MCP client and publishes an event
202func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
203 info := MCPClientInfo{
204 Name: name,
205 State: state,
206 Error: err,
207 Client: client,
208 ToolCount: toolCount,
209 }
210 if state == MCPStateConnected {
211 info.ConnectedAt = time.Now()
212 }
213 mcpStates.Set(name, info)
214
215 // Publish state change event
216 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
217 Type: MCPEventStateChanged,
218 Name: name,
219 State: state,
220 Error: err,
221 ToolCount: toolCount,
222 })
223}
224
225// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
226func CloseMCPClients() {
227 for c := range mcpClients.Seq() {
228 _ = c.Close()
229 }
230 mcpBroker.Shutdown()
231}
232
// mcpInitRequest is the initialize request sent to each MCP server after its
// client has been started (see doGetMCPTools). It requests mcp-go's
// LATEST_PROTOCOL_VERSION and identifies the client as "Crush" at the build's
// version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
242
243func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
244 var wg sync.WaitGroup
245 result := csync.NewSlice[tools.BaseTool]()
246
247 // Initialize states for all configured MCPs
248 for name, m := range cfg.MCP {
249 if m.Disabled {
250 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
251 slog.Debug("skipping disabled mcp", "name", name)
252 continue
253 }
254
255 // Set initial starting state
256 updateMCPState(name, MCPStateStarting, nil, nil, 0)
257
258 wg.Add(1)
259 go func(name string, m config.MCPConfig) {
260 defer func() {
261 wg.Done()
262 if r := recover(); r != nil {
263 var err error
264 switch v := r.(type) {
265 case error:
266 err = v
267 case string:
268 err = fmt.Errorf("panic: %s", v)
269 default:
270 err = fmt.Errorf("panic: %v", v)
271 }
272 updateMCPState(name, MCPStateError, err, nil, 0)
273 slog.Error("panic in mcp client initialization", "error", err, "name", name)
274 }
275 }()
276
277 c, err := createMcpClient(m)
278 if err != nil {
279 updateMCPState(name, MCPStateError, err, nil, 0)
280 slog.Error("error creating mcp client", "error", err, "name", name)
281 return
282 }
283 if err := c.Start(ctx); err != nil {
284 updateMCPState(name, MCPStateError, err, nil, 0)
285 slog.Error("error starting mcp client", "error", err, "name", name)
286 _ = c.Close()
287 return
288 }
289 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
290 updateMCPState(name, MCPStateError, err, nil, 0)
291 slog.Error("error initializing mcp client", "error", err, "name", name)
292 _ = c.Close()
293 return
294 }
295
296 slog.Info("Initialized mcp client", "name", name)
297 mcpClients.Set(name, c)
298
299 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
300 updateMCPState(name, MCPStateConnected, nil, c, len(tools))
301 result.Append(tools...)
302 }(name, m)
303 }
304 wg.Wait()
305 return slices.Collect(result.Seq())
306}
307
308func createMcpClient(m config.MCPConfig) (*client.Client, error) {
309 switch m.Type {
310 case config.MCPStdio:
311 return client.NewStdioMCPClient(
312 m.Command,
313 m.ResolvedEnv(),
314 m.Args...,
315 )
316 case config.MCPHttp:
317 return client.NewStreamableHttpClient(
318 m.URL,
319 transport.WithHTTPHeaders(m.ResolvedHeaders()),
320 transport.WithLogger(mcpHTTPLogger{}),
321 )
322 case config.MCPSse:
323 return client.NewSSEMCPClient(
324 m.URL,
325 client.WithHeaders(m.ResolvedHeaders()),
326 )
327 default:
328 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
329 }
330}
331
// mcpHTTPLogger adapts the default slog logger to the printf-style logger
// accepted by mcp-go's streamable HTTP transport (transport.WithLogger).
type mcpHTTPLogger struct{}

// Errorf formats the message with fmt.Sprintf and logs it at error level.
func (mcpHTTPLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message with fmt.Sprintf and logs it at info level.
func (mcpHTTPLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}