1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"

	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
19
20type mcpTool struct {
21 mcpName string
22 tool mcp.Tool
23 mcpConfig config.MCPConfig
24 permissions permission.Service
25 workingDir string
26}
27
28type MCPClient interface {
29 Initialize(
30 ctx context.Context,
31 request mcp.InitializeRequest,
32 ) (*mcp.InitializeResult, error)
33 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
34 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
35 Close() error
36}
37
38func (b *mcpTool) Name() string {
39 return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
40}
41
42func (b *mcpTool) Info() tools.ToolInfo {
43 required := b.tool.InputSchema.Required
44 if required == nil {
45 required = make([]string, 0)
46 }
47 return tools.ToolInfo{
48 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
49 Description: b.tool.Description,
50 Parameters: b.tool.InputSchema.Properties,
51 Required: required,
52 }
53}
54
55func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
56 defer c.Close()
57 initRequest := mcp.InitializeRequest{}
58 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
59 initRequest.Params.ClientInfo = mcp.Implementation{
60 Name: "Crush",
61 Version: version.Version,
62 }
63
64 _, err := c.Initialize(ctx, initRequest)
65 if err != nil {
66 return tools.NewTextErrorResponse(err.Error()), nil
67 }
68
69 toolRequest := mcp.CallToolRequest{}
70 toolRequest.Params.Name = toolName
71 var args map[string]any
72 if err = json.Unmarshal([]byte(input), &args); err != nil {
73 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
74 }
75 toolRequest.Params.Arguments = args
76 result, err := c.CallTool(ctx, toolRequest)
77 if err != nil {
78 return tools.NewTextErrorResponse(err.Error()), nil
79 }
80
81 output := ""
82 for _, v := range result.Content {
83 if v, ok := v.(mcp.TextContent); ok {
84 output = v.Text
85 } else {
86 output = fmt.Sprintf("%v", v)
87 }
88 }
89
90 return tools.NewTextResponse(output), nil
91}
92
93func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
94 sessionID, messageID := tools.GetContextValues(ctx)
95 if sessionID == "" || messageID == "" {
96 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
97 }
98 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
99 p := b.permissions.Request(
100 permission.CreatePermissionRequest{
101 SessionID: sessionID,
102 Path: b.workingDir,
103 ToolName: b.Info().Name,
104 Action: "execute",
105 Description: permissionDescription,
106 Params: params.Input,
107 },
108 )
109 if !p {
110 return tools.NewTextErrorResponse("permission denied"), nil
111 }
112
113 switch b.mcpConfig.Type {
114 case config.MCPStdio:
115 c, err := client.NewStdioMCPClient(
116 b.mcpConfig.Command,
117 b.mcpConfig.Env,
118 b.mcpConfig.Args...,
119 )
120 if err != nil {
121 return tools.NewTextErrorResponse(err.Error()), nil
122 }
123 return runTool(ctx, c, b.tool.Name, params.Input)
124 case config.MCPHttp:
125 c, err := client.NewStreamableHttpClient(
126 b.mcpConfig.URL,
127 transport.WithHTTPHeaders(b.mcpConfig.Headers),
128 )
129 if err != nil {
130 return tools.NewTextErrorResponse(err.Error()), nil
131 }
132 return runTool(ctx, c, b.tool.Name, params.Input)
133 case config.MCPSse:
134 c, err := client.NewSSEMCPClient(
135 b.mcpConfig.URL,
136 client.WithHeaders(b.mcpConfig.Headers),
137 )
138 if err != nil {
139 return tools.NewTextErrorResponse(err.Error()), nil
140 }
141 return runTool(ctx, c, b.tool.Name, params.Input)
142 }
143
144 return tools.NewTextErrorResponse("invalid mcp type"), nil
145}
146
147func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
148 return &mcpTool{
149 mcpName: name,
150 tool: tool,
151 mcpConfig: mcpConfig,
152 permissions: permissions,
153 workingDir: workingDir,
154 }
155}
156
157var mcpTools []tools.BaseTool
158
159func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
160 var stdioTools []tools.BaseTool
161 initRequest := mcp.InitializeRequest{}
162 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
163 initRequest.Params.ClientInfo = mcp.Implementation{
164 Name: "Crush",
165 Version: version.Version,
166 }
167
168 _, err := c.Initialize(ctx, initRequest)
169 if err != nil {
170 slog.Error("error initializing mcp client", "error", err)
171 return stdioTools
172 }
173 toolsRequest := mcp.ListToolsRequest{}
174 tools, err := c.ListTools(ctx, toolsRequest)
175 if err != nil {
176 slog.Error("error listing tools", "error", err)
177 return stdioTools
178 }
179 for _, t := range tools.Tools {
180 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
181 }
182 defer c.Close()
183 return stdioTools
184}
185
186func GetMcpTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
187 if len(mcpTools) > 0 {
188 return mcpTools
189 }
190 for name, m := range cfg.MCP {
191 switch m.Type {
192 case config.MCPStdio:
193 c, err := client.NewStdioMCPClient(
194 m.Command,
195 m.Env,
196 m.Args...,
197 )
198 if err != nil {
199 slog.Error("error creating mcp client", "error", err)
200 continue
201 }
202
203 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
204 case config.MCPHttp:
205 c, err := client.NewStreamableHttpClient(
206 m.URL,
207 transport.WithHTTPHeaders(m.Headers),
208 )
209 if err != nil {
210 slog.Error("error creating mcp client", "error", err)
211 continue
212 }
213 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
214 case config.MCPSse:
215 c, err := client.NewSSEMCPClient(
216 m.URL,
217 client.WithHeaders(m.Headers),
218 )
219 if err != nil {
220 slog.Error("error creating mcp client", "error", err)
221 continue
222 }
223 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
224 }
225 }
226
227 return mcpTools
228}