1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"

	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
22
// McpTool wraps one tool advertised by an MCP server so it can be used as a
// tools.BaseTool (NewMcpTool returns it under that interface).
type McpTool struct {
	// mcpName is the configured name of the MCP server; it becomes part of
	// the tool name reported by Name and Info.
	mcpName string
	// tool is the tool definition as listed by the server.
	tool mcp.Tool
	// client is the connection used to invoke the tool.
	client MCPClient
	// mcpConfig is the server configuration supplied at construction; the
	// methods in this file do not read it.
	mcpConfig config.MCPConfig
	// permissions is consulted by Run before the tool is executed.
	permissions permission.Service
	// workingDir is passed as the Path of the permission request.
	workingDir string
}
31
// MCPClient is the set of MCP client operations this file depends on. The
// clients built in doGetMCPTools are passed around through this interface.
type MCPClient interface {
	// Initialize is invoked before any ListTools or CallTool request made
	// by this file.
	Initialize(
		ctx context.Context,
		request mcp.InitializeRequest,
	) (*mcp.InitializeResult, error)
	// ListTools returns the tools offered by the server.
	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
	// CallTool executes one tool with the arguments carried by request.
	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
	// Close releases the client; presumably this ends the underlying
	// connection or process — confirm against the mcp-go client.
	Close() error
}
41
42func (b *McpTool) Name() string {
43 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
44}
45
46func (b *McpTool) Info() tools.ToolInfo {
47 required := b.tool.InputSchema.Required
48 if required == nil {
49 required = make([]string, 0)
50 }
51 return tools.ToolInfo{
52 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
53 Description: b.tool.Description,
54 Parameters: b.tool.InputSchema.Properties,
55 Required: required,
56 }
57}
58
59func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
60 initRequest := mcp.InitializeRequest{}
61 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
62 initRequest.Params.ClientInfo = mcp.Implementation{
63 Name: "Crush",
64 Version: version.Version,
65 }
66
67 _, err := c.Initialize(ctx, initRequest)
68 if err != nil {
69 return tools.NewTextErrorResponse(err.Error()), nil
70 }
71
72 toolRequest := mcp.CallToolRequest{}
73 toolRequest.Params.Name = toolName
74 var args map[string]any
75 if err = json.Unmarshal([]byte(input), &args); err != nil {
76 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
77 }
78 toolRequest.Params.Arguments = args
79 result, err := c.CallTool(ctx, toolRequest)
80 if err != nil {
81 return tools.NewTextErrorResponse(err.Error()), nil
82 }
83
84 output := ""
85 for _, v := range result.Content {
86 if v, ok := v.(mcp.TextContent); ok {
87 output = v.Text
88 } else {
89 output = fmt.Sprintf("%v", v)
90 }
91 }
92
93 return tools.NewTextResponse(output), nil
94}
95
96func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
97 sessionID, messageID := tools.GetContextValues(ctx)
98 if sessionID == "" || messageID == "" {
99 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
100 }
101 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
102 p := b.permissions.Request(
103 permission.CreatePermissionRequest{
104 SessionID: sessionID,
105 ToolCallID: params.ID,
106 Path: b.workingDir,
107 ToolName: b.Info().Name,
108 Action: "execute",
109 Description: permissionDescription,
110 Params: params.Input,
111 },
112 )
113 if !p {
114 return tools.ToolResponse{}, permission.ErrorPermissionDenied
115 }
116
117 return runTool(ctx, b.client, b.tool.Name, params.Input)
118}
119
120func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
121 return &McpTool{
122 mcpName: name,
123 client: c,
124 tool: tool,
125 mcpConfig: mcpConfig,
126 permissions: permissions,
127 workingDir: workingDir,
128 }
129}
130
131func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
132 var stdioTools []tools.BaseTool
133 initRequest := mcp.InitializeRequest{}
134 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
135 initRequest.Params.ClientInfo = mcp.Implementation{
136 Name: "Crush",
137 Version: version.Version,
138 }
139
140 _, err := c.Initialize(ctx, initRequest)
141 if err != nil {
142 slog.Error("error initializing mcp client", "error", err)
143 return stdioTools
144 }
145 toolsRequest := mcp.ListToolsRequest{}
146 tools, err := c.ListTools(ctx, toolsRequest)
147 if err != nil {
148 slog.Error("error listing tools", "error", err)
149 return stdioTools
150 }
151 for _, t := range tools.Tools {
152 stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir))
153 }
154 return stdioTools
155}
156
// Package-level cache of the discovered MCP tools, filled by GetMCPTools.
var (
	// mcpToolsOnce ensures the discovery in GetMCPTools runs at most once.
	mcpToolsOnce sync.Once
	// mcpTools holds the result of that single discovery run.
	mcpTools []tools.BaseTool
)
161
162func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
163 mcpToolsOnce.Do(func() {
164 mcpTools = doGetMCPTools(ctx, permissions, cfg)
165 })
166 return mcpTools
167}
168
169func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
170 var wg sync.WaitGroup
171 result := csync.NewSlice[tools.BaseTool]()
172 for name, m := range cfg.MCP {
173 if m.Disabled {
174 slog.Debug("skipping disabled mcp", "name", name)
175 continue
176 }
177 wg.Add(1)
178 go func(name string, m config.MCPConfig) {
179 defer wg.Done()
180 switch m.Type {
181 case config.MCPStdio:
182 c, err := client.NewStdioMCPClient(
183 m.Command,
184 m.ResolvedEnv(),
185 m.Args...,
186 )
187 if err != nil {
188 slog.Error("error creating mcp client", "error", err)
189 return
190 }
191
192 result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
193 case config.MCPHttp:
194 c, err := client.NewStreamableHttpClient(
195 m.URL,
196 transport.WithHTTPHeaders(m.ResolvedHeaders()),
197 )
198 if err != nil {
199 slog.Error("error creating mcp client", "error", err)
200 return
201 }
202 result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
203 case config.MCPSse:
204 c, err := client.NewSSEMCPClient(
205 m.URL,
206 client.WithHeaders(m.ResolvedHeaders()),
207 )
208 if err != nil {
209 slog.Error("error creating mcp client", "error", err)
210 return
211 }
212 result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
213 }
214 }(name, m)
215 }
216 wg.Wait()
217 return slices.Collect(result.Seq())
218}