1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "slices"
9 "sync"
10
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/llm/tools"
14
15 "github.com/charmbracelet/crush/internal/permission"
16 "github.com/charmbracelet/crush/internal/version"
17
18 "github.com/mark3labs/mcp-go/client"
19 "github.com/mark3labs/mcp-go/client/transport"
20 "github.com/mark3labs/mcp-go/mcp"
21)
22
23type mcpTool struct {
24 mcpName string
25 tool mcp.Tool
26 mcpConfig config.MCPConfig
27 permissions permission.Service
28 workingDir string
29}
30
31type MCPClient interface {
32 Initialize(
33 ctx context.Context,
34 request mcp.InitializeRequest,
35 ) (*mcp.InitializeResult, error)
36 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
37 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
38 Close() error
39}
40
41func (b *mcpTool) Name() string {
42 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
43}
44
45func (b *mcpTool) Info() tools.ToolInfo {
46 required := b.tool.InputSchema.Required
47 if required == nil {
48 required = make([]string, 0)
49 }
50 return tools.ToolInfo{
51 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
52 Description: b.tool.Description,
53 Parameters: b.tool.InputSchema.Properties,
54 Required: required,
55 }
56}
57
58func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
59 defer c.Close()
60 initRequest := mcp.InitializeRequest{}
61 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
62 initRequest.Params.ClientInfo = mcp.Implementation{
63 Name: "Crush",
64 Version: version.Version,
65 }
66
67 _, err := c.Initialize(ctx, initRequest)
68 if err != nil {
69 return tools.NewTextErrorResponse(err.Error()), nil
70 }
71
72 toolRequest := mcp.CallToolRequest{}
73 toolRequest.Params.Name = toolName
74 var args map[string]any
75 if err = json.Unmarshal([]byte(input), &args); err != nil {
76 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
77 }
78 toolRequest.Params.Arguments = args
79 result, err := c.CallTool(ctx, toolRequest)
80 if err != nil {
81 return tools.NewTextErrorResponse(err.Error()), nil
82 }
83
84 output := ""
85 for _, v := range result.Content {
86 if v, ok := v.(mcp.TextContent); ok {
87 output = v.Text
88 } else {
89 output = fmt.Sprintf("%v", v)
90 }
91 }
92
93 return tools.NewTextResponse(output), nil
94}
95
96func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
97 sessionID, messageID := tools.GetContextValues(ctx)
98 if sessionID == "" || messageID == "" {
99 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
100 }
101 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
102 p := b.permissions.Request(
103 permission.CreatePermissionRequest{
104 SessionID: sessionID,
105 Path: b.workingDir,
106 ToolName: b.Info().Name,
107 Action: "execute",
108 Description: permissionDescription,
109 Params: params.Input,
110 },
111 )
112 if !p {
113 return tools.ToolResponse{}, permission.ErrorPermissionDenied
114 }
115
116 switch b.mcpConfig.Type {
117 case config.MCPStdio:
118 c, err := client.NewStdioMCPClient(
119 b.mcpConfig.Command,
120 b.mcpConfig.ResolvedEnv(),
121 b.mcpConfig.Args...,
122 )
123 if err != nil {
124 return tools.NewTextErrorResponse(err.Error()), nil
125 }
126 return runTool(ctx, c, b.tool.Name, params.Input)
127 case config.MCPHttp:
128 c, err := client.NewStreamableHttpClient(
129 b.mcpConfig.URL,
130 transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
131 )
132 if err != nil {
133 return tools.NewTextErrorResponse(err.Error()), nil
134 }
135 return runTool(ctx, c, b.tool.Name, params.Input)
136 case config.MCPSse:
137 c, err := client.NewSSEMCPClient(
138 b.mcpConfig.URL,
139 client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
140 )
141 if err != nil {
142 return tools.NewTextErrorResponse(err.Error()), nil
143 }
144 return runTool(ctx, c, b.tool.Name, params.Input)
145 }
146
147 return tools.NewTextErrorResponse("invalid mcp type"), nil
148}
149
150func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
151 return &mcpTool{
152 mcpName: name,
153 tool: tool,
154 mcpConfig: mcpConfig,
155 permissions: permissions,
156 workingDir: workingDir,
157 }
158}
159
160func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
161 var stdioTools []tools.BaseTool
162 initRequest := mcp.InitializeRequest{}
163 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
164 initRequest.Params.ClientInfo = mcp.Implementation{
165 Name: "Crush",
166 Version: version.Version,
167 }
168
169 _, err := c.Initialize(ctx, initRequest)
170 if err != nil {
171 slog.Error("error initializing mcp client", "error", err)
172 return stdioTools
173 }
174 toolsRequest := mcp.ListToolsRequest{}
175 tools, err := c.ListTools(ctx, toolsRequest)
176 if err != nil {
177 slog.Error("error listing tools", "error", err)
178 return stdioTools
179 }
180 for _, t := range tools.Tools {
181 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
182 }
183 defer c.Close()
184 return stdioTools
185}
186
187var (
188 mcpToolsOnce sync.Once
189 mcpTools []tools.BaseTool
190)
191
192func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
193 mcpToolsOnce.Do(func() {
194 mcpTools = doGetMCPTools(ctx, permissions, cfg)
195 })
196 return mcpTools
197}
198
199func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
200 var wg sync.WaitGroup
201 result := csync.NewSlice[tools.BaseTool]()
202 for name, m := range cfg.MCP {
203 if m.Disabled {
204 slog.Debug("skipping disabled mcp", "name", name)
205 continue
206 }
207 wg.Add(1)
208 go func(name string, m config.MCPConfig) {
209 defer wg.Done()
210 switch m.Type {
211 case config.MCPStdio:
212 c, err := client.NewStdioMCPClient(
213 m.Command,
214 m.ResolvedEnv(),
215 m.Args...,
216 )
217 if err != nil {
218 slog.Error("error creating mcp client", "error", err)
219 return
220 }
221
222 result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
223 case config.MCPHttp:
224 c, err := client.NewStreamableHttpClient(
225 m.URL,
226 transport.WithHTTPHeaders(m.ResolvedHeaders()),
227 )
228 if err != nil {
229 slog.Error("error creating mcp client", "error", err)
230 return
231 }
232 result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
233 case config.MCPSse:
234 c, err := client.NewSSEMCPClient(
235 m.URL,
236 client.WithHeaders(m.ResolvedHeaders()),
237 )
238 if err != nil {
239 slog.Error("error creating mcp client", "error", err)
240 return
241 }
242 result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
243 }
244 }(name, m)
245 }
246 wg.Wait()
247 return slices.Collect(result.Seq())
248}