1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/llm/tools"
11
12 "github.com/charmbracelet/crush/internal/permission"
13 "github.com/charmbracelet/crush/internal/version"
14
15 "github.com/mark3labs/mcp-go/client"
16 "github.com/mark3labs/mcp-go/client/transport"
17 "github.com/mark3labs/mcp-go/mcp"
18)
19
// mcpTool adapts a single tool exposed by an MCP server to the tools.BaseTool
// interface used by the agent. A fresh client connection is created for every
// invocation (see Run).
type mcpTool struct {
	// mcpName is the configured name of the MCP server; it prefixes the
	// model-facing tool name (see Name).
	mcpName string
	// tool is the tool definition as returned by the server's ListTools call.
	tool mcp.Tool
	// mcpConfig describes how to connect to the server when the tool runs.
	mcpConfig config.MCPConfig
	// permissions is asked for approval before every execution.
	permissions permission.Service
}
26
27type MCPClient interface {
28 Initialize(
29 ctx context.Context,
30 request mcp.InitializeRequest,
31 ) (*mcp.InitializeResult, error)
32 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
33 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
34 Close() error
35}
36
37func (b *mcpTool) Name() string {
38 return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
39}
40
41func (b *mcpTool) Info() tools.ToolInfo {
42 required := b.tool.InputSchema.Required
43 if required == nil {
44 required = make([]string, 0)
45 }
46 return tools.ToolInfo{
47 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
48 Description: b.tool.Description,
49 Parameters: b.tool.InputSchema.Properties,
50 Required: required,
51 }
52}
53
54func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
55 defer c.Close()
56 initRequest := mcp.InitializeRequest{}
57 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
58 initRequest.Params.ClientInfo = mcp.Implementation{
59 Name: "Crush",
60 Version: version.Version,
61 }
62
63 _, err := c.Initialize(ctx, initRequest)
64 if err != nil {
65 return tools.NewTextErrorResponse(err.Error()), nil
66 }
67
68 toolRequest := mcp.CallToolRequest{}
69 toolRequest.Params.Name = toolName
70 var args map[string]any
71 if err = json.Unmarshal([]byte(input), &args); err != nil {
72 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
73 }
74 toolRequest.Params.Arguments = args
75 result, err := c.CallTool(ctx, toolRequest)
76 if err != nil {
77 return tools.NewTextErrorResponse(err.Error()), nil
78 }
79
80 output := ""
81 for _, v := range result.Content {
82 if v, ok := v.(mcp.TextContent); ok {
83 output = v.Text
84 } else {
85 output = fmt.Sprintf("%v", v)
86 }
87 }
88
89 return tools.NewTextResponse(output), nil
90}
91
92func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
93 sessionID, messageID := tools.GetContextValues(ctx)
94 if sessionID == "" || messageID == "" {
95 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
96 }
97 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
98 p := b.permissions.Request(
99 permission.CreatePermissionRequest{
100 SessionID: sessionID,
101 Path: config.Get().WorkingDir(),
102 ToolName: b.Info().Name,
103 Action: "execute",
104 Description: permissionDescription,
105 Params: params.Input,
106 },
107 )
108 if !p {
109 return tools.NewTextErrorResponse("permission denied"), nil
110 }
111
112 switch b.mcpConfig.Type {
113 case config.MCPStdio:
114 c, err := client.NewStdioMCPClient(
115 b.mcpConfig.Command,
116 b.mcpConfig.Env,
117 b.mcpConfig.Args...,
118 )
119 if err != nil {
120 return tools.NewTextErrorResponse(err.Error()), nil
121 }
122 return runTool(ctx, c, b.tool.Name, params.Input)
123 case config.MCPHttp:
124 c, err := client.NewStreamableHttpClient(
125 b.mcpConfig.URL,
126 transport.WithHTTPHeaders(b.mcpConfig.Headers),
127 )
128 if err != nil {
129 return tools.NewTextErrorResponse(err.Error()), nil
130 }
131 return runTool(ctx, c, b.tool.Name, params.Input)
132 case config.MCPSse:
133 c, err := client.NewSSEMCPClient(
134 b.mcpConfig.URL,
135 client.WithHeaders(b.mcpConfig.Headers),
136 )
137 if err != nil {
138 return tools.NewTextErrorResponse(err.Error()), nil
139 }
140 return runTool(ctx, c, b.tool.Name, params.Input)
141 }
142
143 return tools.NewTextErrorResponse("invalid mcp type"), nil
144}
145
146func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool {
147 return &mcpTool{
148 mcpName: name,
149 tool: tool,
150 mcpConfig: mcpConfig,
151 permissions: permissions,
152 }
153}
154
// mcpTools caches the tools discovered by GetMcpTools. Once it holds at least
// one tool, GetMcpTools returns it as-is without contacting the servers again.
// NOTE(review): this package-level slice has no synchronization; concurrent
// first calls to GetMcpTools could race on it — confirm callers serialize them.
var mcpTools []tools.BaseTool
156
157func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
158 var stdioTools []tools.BaseTool
159 initRequest := mcp.InitializeRequest{}
160 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
161 initRequest.Params.ClientInfo = mcp.Implementation{
162 Name: "Crush",
163 Version: version.Version,
164 }
165
166 _, err := c.Initialize(ctx, initRequest)
167 if err != nil {
168 slog.Error("error initializing mcp client", "error", err)
169 return stdioTools
170 }
171 toolsRequest := mcp.ListToolsRequest{}
172 tools, err := c.ListTools(ctx, toolsRequest)
173 if err != nil {
174 slog.Error("error listing tools", "error", err)
175 return stdioTools
176 }
177 for _, t := range tools.Tools {
178 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
179 }
180 defer c.Close()
181 return stdioTools
182}
183
184func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
185 if len(mcpTools) > 0 {
186 return mcpTools
187 }
188 for name, m := range config.Get().MCP {
189 switch m.Type {
190 case config.MCPStdio:
191 c, err := client.NewStdioMCPClient(
192 m.Command,
193 m.Env,
194 m.Args...,
195 )
196 if err != nil {
197 slog.Error("error creating mcp client", "error", err)
198 continue
199 }
200
201 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
202 case config.MCPHttp:
203 c, err := client.NewStreamableHttpClient(
204 m.URL,
205 transport.WithHTTPHeaders(m.Headers),
206 )
207 if err != nil {
208 slog.Error("error creating mcp client", "error", err)
209 continue
210 }
211 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
212 case config.MCPSse:
213 c, err := client.NewSSEMCPClient(
214 m.URL,
215 client.WithHeaders(m.Headers),
216 )
217 if err != nil {
218 slog.Error("error creating mcp client", "error", err)
219 continue
220 }
221 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
222 }
223 }
224
225 return mcpTools
226}