1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
24
// MCPState is the lifecycle state of a single MCP client. The zero value is
// MCPStateDisabled.
type MCPState int

// Lifecycle states an MCP client can be in.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lower-case name of s, or "unknown" for any value that
// is not one of the declared MCPState constants.
func (s MCPState) String() string {
	names := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
49
// MCPEventType identifies the kind of MCPEvent published on the MCP event
// broker.
type MCPEventType string

// MCPEventStateChanged is the event type published by updateMCPState each
// time an MCP client's recorded state is updated.
const MCPEventStateChanged MCPEventType = "state_changed"
56
// MCPEvent is the payload published on the MCP event broker when an MCP
// client's state is updated (see updateMCPState).
type MCPEvent struct {
	Type      MCPEventType // kind of event; updateMCPState always sets MCPEventStateChanged
	Name      string       // MCP server name, i.e. its key in cfg.MCP
	State     MCPState     // the state the client was just moved to
	Error     error        // cause of the state change, if any; the calls in this file pass nil except for MCPStateError
	ToolCount int          // number of tools listed from the server; the calls in this file pass 0 except for MCPStateConnected
}
65
// MCPClientInfo is the last-known state of one MCP client, as stored by
// updateMCPState and returned by GetMCPState / GetMCPStates.
type MCPClientInfo struct {
	Name        string         // MCP server name, i.e. its key in cfg.MCP
	State       MCPState       // current lifecycle state
	Error       error          // error behind the current state, if any
	Client      *client.Client // client handle; the calls in this file set it only for MCPStateConnected
	ToolCount   int            // number of tools listed from the server
	ConnectedAt time.Time      // set to time.Now() when State is MCPStateConnected; zero value otherwise
}
75
76var (
77 mcpToolsOnce sync.Once
78 mcpTools []tools.BaseTool
79 mcpClients = csync.NewMap[string, *client.Client]()
80 mcpStates = csync.NewMap[string, MCPClientInfo]()
81 mcpBroker = pubsub.NewBroker[MCPEvent]()
82)
83
// McpTool exposes a single tool from an MCP server as a tools.BaseTool.
// Each execution is gated behind a permission request (see Run).
type McpTool struct {
	mcpName     string             // name of the MCP server providing the tool (key in mcpClients)
	tool        mcp.Tool           // tool definition as returned by the server's ListTools
	permissions permission.Service // consulted before every execution
	workingDir  string             // sent as the Path of each permission request
}
90
91func (b *McpTool) Name() string {
92 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
93}
94
95func (b *McpTool) Info() tools.ToolInfo {
96 required := b.tool.InputSchema.Required
97 if required == nil {
98 required = make([]string, 0)
99 }
100 return tools.ToolInfo{
101 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
102 Description: b.tool.Description,
103 Parameters: b.tool.InputSchema.Properties,
104 Required: required,
105 }
106}
107
108func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
109 var args map[string]any
110 if err := json.Unmarshal([]byte(input), &args); err != nil {
111 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
112 }
113 c, ok := mcpClients.Get(name)
114 if !ok {
115 return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
116 }
117 result, err := c.CallTool(ctx, mcp.CallToolRequest{
118 Params: mcp.CallToolParams{
119 Name: toolName,
120 Arguments: args,
121 },
122 })
123 if err != nil {
124 return tools.NewTextErrorResponse(err.Error()), nil
125 }
126
127 output := ""
128 for _, v := range result.Content {
129 if v, ok := v.(mcp.TextContent); ok {
130 output = v.Text
131 } else {
132 output = fmt.Sprintf("%v", v)
133 }
134 }
135
136 return tools.NewTextResponse(output), nil
137}
138
139func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
140 sessionID, messageID := tools.GetContextValues(ctx)
141 if sessionID == "" || messageID == "" {
142 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
143 }
144 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
145 p := b.permissions.Request(
146 permission.CreatePermissionRequest{
147 SessionID: sessionID,
148 ToolCallID: params.ID,
149 Path: b.workingDir,
150 ToolName: b.Info().Name,
151 Action: "execute",
152 Description: permissionDescription,
153 Params: params.Input,
154 },
155 )
156 if !p {
157 return tools.ToolResponse{}, permission.ErrorPermissionDenied
158 }
159
160 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
161}
162
163func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
164 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
165 if err != nil {
166 slog.Error("error listing tools", "error", err)
167 updateMCPState(name, MCPStateError, err, nil, 0)
168 c.Close()
169 mcpClients.Del(name)
170 return nil
171 }
172 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
173 for _, tool := range result.Tools {
174 mcpTools = append(mcpTools, &McpTool{
175 mcpName: name,
176 tool: tool,
177 permissions: permissions,
178 workingDir: workingDir,
179 })
180 }
181 return mcpTools
182}
183
184// SubscribeMCPEvents returns a channel for MCP events
185func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
186 return mcpBroker.Subscribe(ctx)
187}
188
189// GetMCPStates returns the current state of all MCP clients
190func GetMCPStates() map[string]MCPClientInfo {
191 states := make(map[string]MCPClientInfo)
192 for name, info := range mcpStates.Seq2() {
193 states[name] = info
194 }
195 return states
196}
197
198// GetMCPState returns the state of a specific MCP client
199func GetMCPState(name string) (MCPClientInfo, bool) {
200 return mcpStates.Get(name)
201}
202
203// updateMCPState updates the state of an MCP client and publishes an event
204func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
205 info := MCPClientInfo{
206 Name: name,
207 State: state,
208 Error: err,
209 Client: client,
210 ToolCount: toolCount,
211 }
212 if state == MCPStateConnected {
213 info.ConnectedAt = time.Now()
214 }
215 mcpStates.Set(name, info)
216
217 // Publish state change event
218 mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
219 Type: MCPEventStateChanged,
220 Name: name,
221 State: state,
222 Error: err,
223 ToolCount: toolCount,
224 })
225}
226
227// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
228func CloseMCPClients() {
229 for c := range mcpClients.Seq() {
230 _ = c.Close()
231 }
232 mcpBroker.Shutdown()
233}
234
235var mcpInitRequest = mcp.InitializeRequest{
236 Params: mcp.InitializeParams{
237 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
238 ClientInfo: mcp.Implementation{
239 Name: "Crush",
240 Version: version.Version,
241 },
242 },
243}
244
245func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
246 var wg sync.WaitGroup
247 result := csync.NewSlice[tools.BaseTool]()
248
249 // Initialize states for all configured MCPs
250 for name, m := range cfg.MCP {
251 if m.Disabled {
252 updateMCPState(name, MCPStateDisabled, nil, nil, 0)
253 slog.Debug("skipping disabled mcp", "name", name)
254 continue
255 }
256
257 // Set initial starting state
258 updateMCPState(name, MCPStateStarting, nil, nil, 0)
259
260 wg.Add(1)
261 go func(name string, m config.MCPConfig) {
262 defer func() {
263 wg.Done()
264 if r := recover(); r != nil {
265 var err error
266 switch v := r.(type) {
267 case error:
268 err = v
269 case string:
270 err = fmt.Errorf("panic: %s", v)
271 default:
272 err = fmt.Errorf("panic: %v", v)
273 }
274 updateMCPState(name, MCPStateError, err, nil, 0)
275 slog.Error("panic in mcp client initialization", "error", err, "name", name)
276 }
277 }()
278
279 c, err := createMcpClient(m)
280 if err != nil {
281 updateMCPState(name, MCPStateError, err, nil, 0)
282 slog.Error("error creating mcp client", "error", err, "name", name)
283 return
284 }
285 if err := c.Start(ctx); err != nil {
286 updateMCPState(name, MCPStateError, err, nil, 0)
287 slog.Error("error starting mcp client", "error", err, "name", name)
288 _ = c.Close()
289 return
290 }
291 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
292 updateMCPState(name, MCPStateError, err, nil, 0)
293 slog.Error("error initializing mcp client", "error", err, "name", name)
294 _ = c.Close()
295 return
296 }
297
298 slog.Info("Initialized mcp client", "name", name)
299 mcpClients.Set(name, c)
300
301 tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
302 toolCount := len(tools)
303 updateMCPState(name, MCPStateConnected, nil, c, toolCount)
304 result.Append(tools...)
305 }(name, m)
306 }
307 wg.Wait()
308 return slices.Collect(result.Seq())
309}
310
311func createMcpClient(m config.MCPConfig) (*client.Client, error) {
312 switch m.Type {
313 case config.MCPStdio:
314 return client.NewStdioMCPClient(
315 m.Command,
316 m.ResolvedEnv(),
317 m.Args...,
318 )
319 case config.MCPHttp:
320 return client.NewStreamableHttpClient(
321 m.URL,
322 transport.WithHTTPHeaders(m.ResolvedHeaders()),
323 )
324 case config.MCPSse:
325 return client.NewSSEMCPClient(
326 m.URL,
327 client.WithHeaders(m.ResolvedHeaders()),
328 )
329 default:
330 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
331 }
332}