1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/resolver"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
22
// MCPType selects the transport used to talk to an MCP server.
type MCPType string

// Supported MCP transports, matched against MCPConfig.Type.
const (
	// MCPStdio runs MCPConfig.Command (with Args and Env) and talks to it
	// through the stdio MCP client.
	MCPStdio MCPType = "stdio"
	// MCPSse connects to MCPConfig.URL with the SSE MCP client.
	MCPSse MCPType = "sse"
	// MCPHttp connects to MCPConfig.URL with the streamable HTTP MCP client.
	MCPHttp MCPType = "http"
)
30
31type MCPConfig struct {
32 Command string `json:"command,omitempty" `
33 Env map[string]string `json:"env,omitempty"`
34 Args []string `json:"args,omitempty"`
35 Type MCPType `json:"type"`
36 URL string `json:"url,omitempty"`
37 Disabled bool `json:"disabled,omitempty"`
38
39 Headers map[string]string `json:"headers,omitempty"`
40}
41
// mcpTool adapts a single tool exposed by an MCP server to tools.BaseTool.
type mcpTool struct {
	mcpName     string             // config key of the MCP server; part of the tool's name
	tool        mcp.Tool           // tool definition as listed by the server
	mcpConfig   MCPConfig          // connection settings; a fresh client is opened per Run
	permissions permission.Service // asked for approval before every execution
	workingDir  string             // reported as the Path of permission requests
}
49
// MCPClient is the subset of the mcp-go client API this package uses. It lets
// the stdio, streamable HTTP and SSE clients be handled uniformly.
type MCPClient interface {
	// Initialize performs the MCP initialization handshake.
	Initialize(
		ctx context.Context,
		request mcp.InitializeRequest,
	) (*mcp.InitializeResult, error)
	// ListTools returns the tools the server exposes.
	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
	// CallTool invokes a single tool on the server.
	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
	// Close releases the client's connection/resources.
	Close() error
}
59
60func (b *mcpTool) Name() string {
61 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
62}
63
64func (b *mcpTool) Info() tools.ToolInfo {
65 required := b.tool.InputSchema.Required
66 if required == nil {
67 required = make([]string, 0)
68 }
69 return tools.ToolInfo{
70 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
71 Description: b.tool.Description,
72 Parameters: b.tool.InputSchema.Properties,
73 Required: required,
74 }
75}
76
77func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
78 defer c.Close()
79 initRequest := mcp.InitializeRequest{}
80 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
81 initRequest.Params.ClientInfo = mcp.Implementation{
82 Name: "crush",
83 Version: version.Version,
84 }
85
86 _, err := c.Initialize(ctx, initRequest)
87 if err != nil {
88 return tools.NewTextErrorResponse(err.Error()), nil
89 }
90
91 toolRequest := mcp.CallToolRequest{}
92 toolRequest.Params.Name = toolName
93 var args map[string]any
94 if err = json.Unmarshal([]byte(input), &args); err != nil {
95 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
96 }
97 toolRequest.Params.Arguments = args
98 result, err := c.CallTool(ctx, toolRequest)
99 if err != nil {
100 return tools.NewTextErrorResponse(err.Error()), nil
101 }
102
103 output := ""
104 for _, v := range result.Content {
105 if v, ok := v.(mcp.TextContent); ok {
106 output = v.Text
107 } else {
108 output = fmt.Sprintf("%v", v)
109 }
110 }
111
112 return tools.NewTextResponse(output), nil
113}
114
115func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
116 sessionID, messageID := tools.GetContextValues(ctx)
117 if sessionID == "" || messageID == "" {
118 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
119 }
120 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
121 p := b.permissions.Request(
122 permission.CreatePermissionRequest{
123 SessionID: sessionID,
124 ToolCallID: params.ID,
125 Path: b.workingDir,
126 ToolName: b.Info().Name,
127 Action: "execute",
128 Description: permissionDescription,
129 Params: params.Input,
130 },
131 )
132 if !p {
133 return tools.ToolResponse{}, permission.ErrorPermissionDenied
134 }
135
136 switch b.mcpConfig.Type {
137 case MCPStdio:
138 c, err := client.NewStdioMCPClient(
139 b.mcpConfig.Command,
140 b.mcpConfig.ResolvedEnv(),
141 b.mcpConfig.Args...,
142 )
143 if err != nil {
144 return tools.NewTextErrorResponse(err.Error()), nil
145 }
146 return runTool(ctx, c, b.tool.Name, params.Input)
147 case MCPHttp:
148 c, err := client.NewStreamableHttpClient(
149 b.mcpConfig.URL,
150 transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
151 )
152 if err != nil {
153 return tools.NewTextErrorResponse(err.Error()), nil
154 }
155 return runTool(ctx, c, b.tool.Name, params.Input)
156 case MCPSse:
157 c, err := client.NewSSEMCPClient(
158 b.mcpConfig.URL,
159 client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
160 )
161 if err != nil {
162 return tools.NewTextErrorResponse(err.Error()), nil
163 }
164 return runTool(ctx, c, b.tool.Name, params.Input)
165 }
166
167 return tools.NewTextErrorResponse("invalid mcp type"), nil
168}
169
170func NewMcpTool(name, cwd string, tool mcp.Tool, permissions permission.Service, mcpConfig MCPConfig) tools.BaseTool {
171 return &mcpTool{
172 mcpName: name,
173 tool: tool,
174 mcpConfig: mcpConfig,
175 permissions: permissions,
176 workingDir: cwd,
177 }
178}
179
180func getTools(ctx context.Context, cwd string, name string, m MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
181 var stdioTools []tools.BaseTool
182 initRequest := mcp.InitializeRequest{}
183 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
184 initRequest.Params.ClientInfo = mcp.Implementation{
185 Name: "dreamlover",
186 }
187
188 _, err := c.Initialize(ctx, initRequest)
189 if err != nil {
190 slog.Error("error initializing mcp client", "error", err)
191 return stdioTools
192 }
193 toolsRequest := mcp.ListToolsRequest{}
194 tools, err := c.ListTools(ctx, toolsRequest)
195 if err != nil {
196 slog.Error("error listing tools", "error", err)
197 return stdioTools
198 }
199 for _, t := range tools.Tools {
200 stdioTools = append(stdioTools, NewMcpTool(name, cwd, t, permissions, m))
201 }
202 defer c.Close()
203 return stdioTools
204}
205
// Process-wide cache for GetMCPTools: mcpTools is filled once, under
// mcpToolsOnce, and reused by every later call.
var (
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
)
210
211func GetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
212 mcpToolsOnce.Do(func() {
213 mcpTools = doGetMCPTools(ctx, cwd, mcps, permissions)
214 })
215 return mcpTools
216}
217
218func doGetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
219 var wg sync.WaitGroup
220 result := csync.NewSlice[tools.BaseTool]()
221 for name, m := range mcps {
222 if m.Disabled {
223 slog.Debug("skipping disabled mcp", "name", name)
224 continue
225 }
226 wg.Add(1)
227 go func(name string, m MCPConfig) {
228 defer wg.Done()
229 switch m.Type {
230 case MCPStdio:
231 c, err := client.NewStdioMCPClient(
232 m.Command,
233 m.ResolvedEnv(),
234 m.Args...,
235 )
236 if err != nil {
237 slog.Error("error creating mcp client", "error", err)
238 return
239 }
240
241 result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
242 case MCPHttp:
243 c, err := client.NewStreamableHttpClient(
244 m.URL,
245 transport.WithHTTPHeaders(m.ResolvedHeaders()),
246 )
247 if err != nil {
248 slog.Error("error creating mcp client", "error", err)
249 return
250 }
251 result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
252 case MCPSse:
253 c, err := client.NewSSEMCPClient(
254 m.URL,
255 client.WithHeaders(m.ResolvedHeaders()),
256 )
257 if err != nil {
258 slog.Error("error creating mcp client", "error", err)
259 return
260 }
261 result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
262 }
263 }(name, m)
264 }
265 wg.Wait()
266 return slices.Collect(result.Seq())
267}
268
269func (m MCPConfig) ResolvedEnv() []string {
270 resolver := resolver.New()
271 for e, v := range m.Env {
272 var err error
273 m.Env[e], err = resolver.ResolveValue(v)
274 if err != nil {
275 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
276 continue
277 }
278 }
279
280 env := make([]string, 0, len(m.Env))
281 for k, v := range m.Env {
282 env = append(env, fmt.Sprintf("%s=%s", k, v))
283 }
284 return env
285}
286
287func (m MCPConfig) ResolvedHeaders() map[string]string {
288 resolver := resolver.New()
289 for e, v := range m.Headers {
290 var err error
291 m.Headers[e], err = resolver.ResolveValue(v)
292 if err != nil {
293 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
294 continue
295 }
296 }
297 return m.Headers
298}