1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/permission"
17 "github.com/charmbracelet/crush/internal/pubsub"
18 "github.com/charmbracelet/crush/internal/version"
19 "github.com/mark3labs/mcp-go/client"
20 "github.com/mark3labs/mcp-go/client/transport"
21 "github.com/mark3labs/mcp-go/mcp"
22)
23
// MCPState represents the current state of an MCP client.
type MCPState int

// Lifecycle states of an MCP client. The zero value is MCPStateDisabled.
const (
	MCPStateDisabled MCPState = iota
	MCPStateStarting
	MCPStateConnected
	MCPStateError
)

// String returns the lower-case name of the state, or "unknown" for any value
// outside the declared constants.
func (s MCPState) String() string {
	labels := [...]string{
		MCPStateDisabled:  "disabled",
		MCPStateStarting:  "starting",
		MCPStateConnected: "connected",
		MCPStateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
48
49// MCPEventType represents the type of MCP event
50type MCPEventType string
51
52const (
53 MCPEventStateChanged MCPEventType = "state_changed"
54)
55
56// MCPEvent represents an event in the MCP system
57type MCPEvent struct {
58 Type MCPEventType
59 Name string
60 State MCPState
61 Error error
62 ToolCount int
63}
64
65// MCPClientInfo holds information about an MCP client's state
66type MCPClientInfo struct {
67 Name string
68 State MCPState
69 Error error
70 Client *client.Client
71 ToolCount int
72 ConnectedAt time.Time
73}
74
75var (
76 mcpToolsOnce sync.Once
77 mcpTools []tools.BaseTool
78 mcpClients = csync.NewMap[string, *client.Client]()
79 mcpStates = csync.NewMap[string, MCPClientInfo]()
80 mcpBroker = pubsub.NewBroker[MCPEvent]()
81)
82
83type McpTool struct {
84 mcpName string
85 tool mcp.Tool
86 permissions permission.Service
87 workingDir string
88}
89
90func (b *McpTool) Name() string {
91 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
92}
93
94func (b *McpTool) Info() tools.ToolInfo {
95 required := b.tool.InputSchema.Required
96 if required == nil {
97 required = make([]string, 0)
98 }
99 return tools.ToolInfo{
100 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
101 Description: b.tool.Description,
102 Parameters: b.tool.InputSchema.Properties,
103 Required: required,
104 }
105}
106
// runTool calls toolName on the MCP client registered under name. It passes
// input (a JSON object) as the call arguments and flattens the result's
// content into a single text response.
//
// Problems the model can react to are returned as text error responses with a
// nil Go error, so the conversation can continue. These are:
//   - malformed input;
//   - an unavailable server;
//   - a failed call;
//   - a result the server itself flags as an error.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	args := map[string]any{}
	// Tools without parameters may be invoked with an empty input. Treat that
	// as "no arguments" rather than a JSON syntax error.
	if strings.TrimSpace(input) != "" {
		if err := json.Unmarshal([]byte(input), &args); err != nil {
			return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
		}
	}
	c, ok := mcpClients.Get(name)
	if !ok {
		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
	}
	result, err := c.CallTool(ctx, mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      toolName,
			Arguments: args,
		},
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// Join all content items, one per line. Non-text items have no text form
	// here, so fall back to their default formatting.
	var output strings.Builder
	for i, content := range result.Content {
		if i > 0 {
			output.WriteString("\n")
		}
		if text, ok := content.(mcp.TextContent); ok {
			output.WriteString(text.Text)
		} else {
			// Writing to a strings.Builder cannot fail.
			_, _ = fmt.Fprintf(&output, "%v", content)
		}
	}

	// MCP servers report tool-level failures through IsError rather than a
	// protocol error. Surface them as error responses instead of successes.
	if result.IsError {
		return tools.NewTextErrorResponse(output.String()), nil
	}
	return tools.NewTextResponse(output.String()), nil
}
137
// Run asks the permission service to approve executing this tool with the
// model-supplied parameters. If approved, it forwards the call to the MCP
// server via runTool.
//
// Run returns a non-nil error when:
//   - the context carries no session or message ID;
//   - the request is rejected, in which case the error is
//     permission.ErrorPermissionDenied.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for running an MCP tool")
	}
	toolName := b.Name()
	description := fmt.Sprintf("execute %s with the following parameters: %s", toolName, params.Input)
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: description,
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
161
// getTools lists the tools exposed by client c, which is registered as name.
// It wraps each tool in an McpTool bound to permissions and workingDir.
//
// If listing fails, the client is treated as unusable. Its state is recorded
// as MCPStateError, it is closed and removed from mcpClients, and nil is
// returned.
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		slog.Error("error listing tools", "error", err, "name", name)
		updateMCPState(name, MCPStateError, err, nil, 0)
		if closeErr := c.Close(); closeErr != nil {
			slog.Warn("error closing mcp client", "error", closeErr, "name", name)
		}
		mcpClients.Del(name)
		return nil
	}
	// Named serverTools so it does not shadow the package-level mcpTools.
	serverTools := make([]tools.BaseTool, 0, len(result.Tools))
	for _, tool := range result.Tools {
		serverTools = append(serverTools, &McpTool{
			mcpName:     name,
			tool:        tool,
			permissions: permissions,
			workingDir:  workingDir,
		})
	}
	return serverTools
}
182
// SubscribeMCPEvents returns a channel on which MCP state-change events are
// delivered. ctx is handed to the broker and presumably ends the subscription
// when cancelled — confirm against pubsub.Broker.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}
187
188// GetMCPStates returns the current state of all MCP clients
189func GetMCPStates() map[string]MCPClientInfo {
190 states := make(map[string]MCPClientInfo)
191 for name, info := range mcpStates.Seq2() {
192 states[name] = info
193 }
194 return states
195}
196
197// GetMCPState returns the state of a specific MCP client
198func GetMCPState(name string) (MCPClientInfo, bool) {
199 return mcpStates.Get(name)
200}
201
// updateMCPState records the new state of the MCP client called name in
// mcpStates. It then publishes a matching MCPEventStateChanged event on
// mcpBroker.
//
// c is stored as MCPClientInfo.Client. ConnectedAt is stamped with the current
// time only when state is MCPStateConnected.
func updateMCPState(name string, state MCPState, err error, c *client.Client, toolCount int) {
	// The parameter is named c (not client) so it does not shadow the
	// imported client package.
	info := MCPClientInfo{
		Name:      name,
		State:     state,
		Error:     err,
		Client:    c,
		ToolCount: toolCount,
	}
	if state == MCPStateConnected {
		info.ConnectedAt = time.Now()
	}
	mcpStates.Set(name, info)

	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
		Type:      MCPEventStateChanged,
		Name:      name,
		State:     state,
		Error:     err,
		ToolCount: toolCount,
	})
}
225
// CloseMCPClients closes all MCP clients and shuts down the MCP event broker.
// This should be called during application shutdown. Close failures are
// logged with the client's name rather than returned, since shutdown proceeds
// regardless.
func CloseMCPClients() {
	for name, c := range mcpClients.Seq2() {
		if err := c.Close(); err != nil {
			slog.Warn("error closing mcp client", "error", err, "name", name)
		}
	}
	mcpBroker.Shutdown()
}
233
234var mcpInitRequest = mcp.InitializeRequest{
235 Params: mcp.InitializeParams{
236 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
237 ClientInfo: mcp.Implementation{
238 Name: "Crush",
239 Version: version.Version,
240 },
241 },
242}
243
// doGetMCPTools starts every enabled MCP server configured in cfg.MCP
// concurrently. For each one it performs the initialize handshake and
// collects the tools the server exposes.
//
// State handling:
//   - Disabled servers are only recorded as MCPStateDisabled.
//   - Every other server is recorded as MCPStateStarting, then ends in either
//     MCPStateConnected or MCPStateError.
//   - Failures, including recovered panics, are logged and recorded but do not
//     affect other servers.
//
// The call blocks until every server goroutine has finished. Each server's
// start-up, handshake and tool listing share a 10-second timeout.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	var wg sync.WaitGroup
	result := csync.NewSlice[tools.BaseTool]()

	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred first so it runs last: a recovered panic's error state is
			// recorded before wg.Wait can return.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				// Don't leak a client that was registered before the panic.
				if c, ok := mcpClients.Get(name); ok {
					_ = c.Close()
					mcpClients.Del(name)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
			defer cancel()
			c, err := createMcpClient(m)
			if err != nil {
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error creating mcp client", "error", err, "name", name)
				return
			}
			if err := c.Start(ctx); err != nil {
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error starting mcp client", "error", err, "name", name)
				_ = c.Close()
				return
			}
			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error initializing mcp client", "error", err, "name", name)
				_ = c.Close()
				return
			}

			slog.Info("Initialized mcp client", "name", name)
			mcpClients.Set(name, c)

			// Named serverTools so it does not shadow the tools package.
			serverTools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
			// When listing fails, getTools records MCPStateError and closes and
			// unregisters the client. Don't overwrite that with "connected".
			if info, ok := mcpStates.Get(name); ok && info.State == MCPStateError {
				return
			}
			updateMCPState(name, MCPStateConnected, nil, c, len(serverTools))
			result.Append(serverTools...)
		}(name, m)
	}
	wg.Wait()
	return slices.Collect(result.Seq())
}
310
// createMcpClient builds the mcp-go client for m according to m.Type:
//   - stdio: runs m.Command with m.Args and the resolved environment;
//   - streamable HTTP and SSE: connect to m.URL with the resolved headers.
//
// Transport logs are routed through mcpLogger. Callers still Start and
// Initialize the returned client.
func createMcpClient(m config.MCPConfig) (*client.Client, error) {
	logger := mcpLogger{}
	switch m.Type {
	case config.MCPStdio:
		return client.NewStdioMCPClientWithOptions(
			m.Command,
			m.ResolvedEnv(),
			m.Args,
			transport.WithCommandLogger(logger),
		)
	case config.MCPHttp:
		headers := m.ResolvedHeaders()
		return client.NewStreamableHttpClient(
			m.URL,
			transport.WithHTTPHeaders(headers),
			transport.WithHTTPLogger(logger),
		)
	case config.MCPSse:
		headers := m.ResolvedHeaders()
		return client.NewSSEMCPClient(
			m.URL,
			client.WithHeaders(headers),
			transport.WithSSELogger(logger),
		)
	}
	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
336
337// for MCP's clients.
338type mcpLogger struct{}
339
340func (l mcpLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) }
341func (l mcpLogger) Infof(format string, v ...any) { slog.Info(fmt.Sprintf(format, v...)) }