1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
20
21type mcpTool struct {
22 mcpName string
23 tool mcp.Tool
24 mcpConfig config.MCPConfig
25 permissions permission.Service
26 workingDir string
27}
28
29type MCPClient interface {
30 Initialize(
31 ctx context.Context,
32 request mcp.InitializeRequest,
33 ) (*mcp.InitializeResult, error)
34 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
35 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
36 Close() error
37}
38
39func (b *mcpTool) Name() string {
40 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
41}
42
43func (b *mcpTool) Info() tools.ToolInfo {
44 required := b.tool.InputSchema.Required
45 if required == nil {
46 required = make([]string, 0)
47 }
48 return tools.ToolInfo{
49 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
50 Description: b.tool.Description,
51 Parameters: b.tool.InputSchema.Properties,
52 Required: required,
53 }
54}
55
56func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
57 defer c.Close()
58 initRequest := mcp.InitializeRequest{}
59 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
60 initRequest.Params.ClientInfo = mcp.Implementation{
61 Name: "Crush",
62 Version: version.Version,
63 }
64
65 _, err := c.Initialize(ctx, initRequest)
66 if err != nil {
67 return tools.NewTextErrorResponse(err.Error()), nil
68 }
69
70 toolRequest := mcp.CallToolRequest{}
71 toolRequest.Params.Name = toolName
72 var args map[string]any
73 if err = json.Unmarshal([]byte(input), &args); err != nil {
74 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
75 }
76 toolRequest.Params.Arguments = args
77 result, err := c.CallTool(ctx, toolRequest)
78 if err != nil {
79 return tools.NewTextErrorResponse(err.Error()), nil
80 }
81
82 output := ""
83 for _, v := range result.Content {
84 if v, ok := v.(mcp.TextContent); ok {
85 output = v.Text
86 } else {
87 output = fmt.Sprintf("%v", v)
88 }
89 }
90
91 return tools.NewTextResponse(output), nil
92}
93
94func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
95 sessionID, messageID := tools.GetContextValues(ctx)
96 if sessionID == "" || messageID == "" {
97 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
98 }
99 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
100 p := b.permissions.Request(
101 permission.CreatePermissionRequest{
102 SessionID: sessionID,
103 ToolCallID: params.ID,
104 Path: b.workingDir,
105 ToolName: b.Info().Name,
106 Action: "execute",
107 Description: permissionDescription,
108 Params: params.Input,
109 },
110 )
111 if !p {
112 return tools.ToolResponse{}, permission.ErrorPermissionDenied
113 }
114
115 switch b.mcpConfig.Type {
116 case config.MCPStdio:
117 c, err := client.NewStdioMCPClient(
118 b.mcpConfig.Command,
119 b.mcpConfig.ResolvedEnv(),
120 b.mcpConfig.Args...,
121 )
122 if err != nil {
123 return tools.NewTextErrorResponse(err.Error()), nil
124 }
125 return runTool(ctx, c, b.tool.Name, params.Input)
126 case config.MCPHttp:
127 c, err := client.NewStreamableHttpClient(
128 b.mcpConfig.URL,
129 transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
130 )
131 if err != nil {
132 return tools.NewTextErrorResponse(err.Error()), nil
133 }
134 return runTool(ctx, c, b.tool.Name, params.Input)
135 case config.MCPSse:
136 c, err := client.NewSSEMCPClient(
137 b.mcpConfig.URL,
138 client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
139 )
140 if err != nil {
141 return tools.NewTextErrorResponse(err.Error()), nil
142 }
143 return runTool(ctx, c, b.tool.Name, params.Input)
144 }
145
146 return tools.NewTextErrorResponse("invalid mcp type"), nil
147}
148
149func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
150 return &mcpTool{
151 mcpName: name,
152 tool: tool,
153 mcpConfig: mcpConfig,
154 permissions: permissions,
155 workingDir: workingDir,
156 }
157}
158
159func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
160 var stdioTools []tools.BaseTool
161 initRequest := mcp.InitializeRequest{}
162 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
163 initRequest.Params.ClientInfo = mcp.Implementation{
164 Name: "Crush",
165 Version: version.Version,
166 }
167
168 _, err := c.Initialize(ctx, initRequest)
169 if err != nil {
170 slog.Error("error initializing mcp client", "error", err)
171 return stdioTools
172 }
173 toolsRequest := mcp.ListToolsRequest{}
174 tools, err := c.ListTools(ctx, toolsRequest)
175 if err != nil {
176 slog.Error("error listing tools", "error", err)
177 return stdioTools
178 }
179 for _, t := range tools.Tools {
180 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
181 }
182 defer c.Close()
183 return stdioTools
184}
185
186var (
187 mcpToolsOnce sync.Once
188 mcpTools []tools.BaseTool
189)
190
191func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
192 mcpToolsOnce.Do(func() {
193 mcpTools = doGetMCPTools(ctx, permissions, cfg)
194 })
195 return mcpTools
196}
197
198func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
199 var mu sync.Mutex
200 var wg sync.WaitGroup
201 var result []tools.BaseTool
202 for name, m := range cfg.MCP {
203 if m.Disabled {
204 slog.Debug("skipping disabled mcp", "name", name)
205 continue
206 }
207 wg.Add(1)
208 go func(name string, m config.MCPConfig) {
209 defer wg.Done()
210 switch m.Type {
211 case config.MCPStdio:
212 c, err := client.NewStdioMCPClient(
213 m.Command,
214 m.ResolvedEnv(),
215 m.Args...,
216 )
217 if err != nil {
218 slog.Error("error creating mcp client", "error", err)
219 return
220 }
221
222 mu.Lock()
223 result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
224 mu.Unlock()
225 case config.MCPHttp:
226 c, err := client.NewStreamableHttpClient(
227 m.URL,
228 transport.WithHTTPHeaders(m.ResolvedHeaders()),
229 )
230 if err != nil {
231 slog.Error("error creating mcp client", "error", err)
232 return
233 }
234 mu.Lock()
235 result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
236 mu.Unlock()
237 case config.MCPSse:
238 c, err := client.NewSSEMCPClient(
239 m.URL,
240 client.WithHeaders(m.ResolvedHeaders()),
241 )
242 if err != nil {
243 slog.Error("error creating mcp client", "error", err)
244 return
245 }
246 mu.Lock()
247 result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
248 mu.Unlock()
249 }
250 }(name, m)
251 }
252 wg.Wait()
253 return result
254}