1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"
	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
20
21var (
22 mcpToolsOnce sync.Once
23 mcpTools []tools.BaseTool
24 mcpClients = csync.NewMap[string, *client.Client]()
25)
26
27type McpTool struct {
28 mcpName string
29 tool mcp.Tool
30 permissions permission.Service
31 workingDir string
32}
33
34func (b *McpTool) Name() string {
35 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
36}
37
38func (b *McpTool) Info() tools.ToolInfo {
39 required := b.tool.InputSchema.Required
40 if required == nil {
41 required = make([]string, 0)
42 }
43 return tools.ToolInfo{
44 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
45 Description: b.tool.Description,
46 Parameters: b.tool.InputSchema.Properties,
47 Required: required,
48 }
49}
50
// runTool calls toolName on the MCP server registered under name. input must
// be a JSON object; it is decoded and sent as the call arguments.
//
// The following failures are returned to the model as text error responses
// rather than Go errors:
//   - unparsable input
//   - an unknown server
//   - a CallTool error
//   - a result the server flagged with IsError
//
// Every content part of the result is included in the response, one per line.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}
	c, ok := mcpClients.Get(name)
	if !ok {
		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
	}
	result, err := c.CallTool(ctx, mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      toolName,
			Arguments: args,
		},
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// Collect every content part; a result may carry several, and none should
	// be silently dropped.
	parts := make([]string, 0, len(result.Content))
	for _, content := range result.Content {
		if text, ok := content.(mcp.TextContent); ok {
			parts = append(parts, text.Text)
		} else {
			// Non-text content has no plain-text form here; fall back to its
			// default formatting.
			parts = append(parts, fmt.Sprintf("%v", content))
		}
	}
	output := strings.Join(parts, "\n")

	// MCP servers report tool-level failures in-band via IsError rather than
	// as a protocol error; surface them as error responses.
	if result.IsError {
		return tools.NewTextErrorResponse(output), nil
	}
	return tools.NewTextResponse(output), nil
}
81
// Run executes the MCP tool on behalf of the current session.
//
// The session and message IDs are read from ctx and both are required. The
// user is then asked for permission, described with the tool name and the
// raw JSON input. A refusal returns permission.ErrorPermissionDenied;
// otherwise the call is forwarded to runTool.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing an MCP tool")
	}

	// Compute the exposed name once instead of rebuilding ToolInfo per use.
	exposedName := b.Name()
	p := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    exposedName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters: %s", exposedName, params.Input),
			Params:      params.Input,
		},
	)
	if !p {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
105
// getTools lists the tools offered by client c (registered under name) and
// wraps each one in an McpTool bound to permissions and workingDir.
//
// If listing fails, the client is considered unusable. It is removed from
// mcpClients and closed, and nil is returned.
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		slog.Error("error listing tools", "error", err, "name", name)
		// Unregister before closing so runTool can no longer look up a client
		// that is being torn down.
		mcpClients.Del(name)
		_ = c.Close() // best-effort: the client is being discarded anyway
		return nil
	}

	// Named to avoid shadowing the package-level mcpTools.
	serverTools := make([]tools.BaseTool, 0, len(result.Tools))
	for _, tool := range result.Tools {
		serverTools = append(serverTools, &McpTool{
			mcpName:     name,
			tool:        tool,
			permissions: permissions,
			workingDir:  workingDir,
		})
	}
	return serverTools
}
125
// CloseMCPClients closes every registered MCP client. Call it during
// application shutdown. Close errors are deliberately discarded because
// shutdown is best-effort.
func CloseMCPClients() {
	clients := slices.Collect(mcpClients.Seq())
	for _, cl := range clients {
		_ = cl.Close()
	}
}
132
133var mcpInitRequest = mcp.InitializeRequest{
134 Params: mcp.InitializeParams{
135 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
136 ClientInfo: mcp.Implementation{
137 Name: "Crush",
138 Version: version.Version,
139 },
140 },
141}
142
// doGetMCPTools connects to every enabled MCP server in cfg concurrently and
// returns the tools they expose, wrapped as McpTool values.
//
// Each server is created, started and initialized in its own goroutine.
// Initialized clients are registered in mcpClients under their config name.
// A server that fails at any step (including tool listing) is logged and
// skipped, so the result only contains tools from healthy servers.
//
// The returned slice is sorted by tool name, so its order does not depend on
// configuration iteration order or goroutine scheduling.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	var wg sync.WaitGroup
	result := csync.NewSlice[tools.BaseTool]()
	for name, m := range cfg.MCP {
		if m.Disabled {
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}
		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			defer wg.Done()
			c, err := createMcpClient(m)
			if err != nil {
				slog.Error("error creating mcp client", "error", err, "name", name)
				return
			}
			if err := c.Start(ctx); err != nil {
				slog.Error("error starting mcp client", "error", err, "name", name)
				_ = c.Close()
				return
			}
			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
				slog.Error("error initializing mcp client", "error", err, "name", name)
				_ = c.Close()
				return
			}

			slog.Info("Initialized mcp client", "name", name)
			mcpClients.Set(name, c)

			result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
		}(name, m)
	}
	wg.Wait()

	// Both the iteration order of cfg.MCP and the order in which goroutines
	// finish are arbitrary; sort so the tool list is stable across runs.
	all := slices.Collect(result.Seq())
	slices.SortFunc(all, func(a, b tools.BaseTool) int {
		return strings.Compare(a.Info().Name, b.Info().Name)
	})
	return all
}
179
// createMcpClient builds an MCP client for the transport named by m.Type:
//   - SSE: targets m.URL with the resolved headers.
//   - Streamable HTTP: targets m.URL with the resolved headers and routes
//     transport logs through mcpHTTPLogger.
//   - stdio: runs m.Command with m.Args and the resolved environment.
//
// Any other type yields an "unsupported mcp type" error.
func createMcpClient(m config.MCPConfig) (*client.Client, error) {
	switch kind := m.Type; kind {
	case config.MCPSse:
		return client.NewSSEMCPClient(m.URL, client.WithHeaders(m.ResolvedHeaders()))
	case config.MCPHttp:
		return client.NewStreamableHttpClient(
			m.URL,
			transport.WithHTTPHeaders(m.ResolvedHeaders()),
			transport.WithLogger(mcpHTTPLogger{}),
		)
	case config.MCPStdio:
		return client.NewStdioMCPClient(m.Command, m.ResolvedEnv(), m.Args...)
	default:
		return nil, fmt.Errorf("unsupported mcp type: %s", kind)
	}
}
203
204// for MCP's HTTP client.
205type mcpHTTPLogger struct{}
206
207func (l mcpHTTPLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) }
208func (l mcpHTTPLogger) Infof(format string, v ...any) { slog.Info(fmt.Sprintf(format, v...)) }