1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
18
19type mcpTool struct {
20 mcpName string
21 tool mcp.Tool
22 mcpConfig config.MCPConfig
23 permissions permission.Service
24}
25
26type MCPClient interface {
27 Initialize(
28 ctx context.Context,
29 request mcp.InitializeRequest,
30 ) (*mcp.InitializeResult, error)
31 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
32 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
33 Close() error
34}
35
36func (b *mcpTool) Name() string {
37 return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
38}
39
40func (b *mcpTool) Info() tools.ToolInfo {
41 required := b.tool.InputSchema.Required
42 if required == nil {
43 required = make([]string, 0)
44 }
45 return tools.ToolInfo{
46 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
47 Description: b.tool.Description,
48 Parameters: b.tool.InputSchema.Properties,
49 Required: required,
50 }
51}
52
// runTool performs a single tool invocation over the MCP connection c.
//
// It decodes input (expected to be a JSON object) into the tool arguments,
// initializes the MCP session, calls toolName, and flattens every returned
// content item into one newline-separated text response. c is always closed
// before returning, so callers must pass a client dedicated to this call.
//
// All failures are reported to the model as text error responses, and the
// returned Go error is always nil. This covers malformed input, initialization
// or transport errors, and results the server flags with IsError.
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
	// Best effort: a close failure does not affect the tool's result.
	defer c.Close()

	// Decode arguments first so malformed input fails without a handshake.
	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	initRequest := mcp.InitializeRequest{}
	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
	initRequest.Params.ClientInfo = mcp.Implementation{
		Name:    "Crush",
		Version: version.Version,
	}
	if _, err := c.Initialize(ctx, initRequest); err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	toolRequest := mcp.CallToolRequest{}
	toolRequest.Params.Name = toolName
	toolRequest.Params.Arguments = args
	result, err := c.CallTool(ctx, toolRequest)
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// A tool may return several content items; keep all of them, in order.
	parts := make([]string, 0, len(result.Content))
	for _, item := range result.Content {
		if text, ok := item.(mcp.TextContent); ok {
			parts = append(parts, text.Text)
		} else {
			parts = append(parts, fmt.Sprintf("%v", item))
		}
	}
	output := strings.Join(parts, "\n")

	// MCP reports tool-execution failures in-band via isError. Surface them
	// as errors instead of passing them off as successful output.
	if result.IsError {
		return tools.NewTextErrorResponse(output), nil
	}
	return tools.NewTextResponse(output), nil
}
90
91func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
92 sessionID, messageID := tools.GetContextValues(ctx)
93 if sessionID == "" || messageID == "" {
94 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
95 }
96 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
97 p := b.permissions.Request(
98 permission.CreatePermissionRequest{
99 SessionID: sessionID,
100 Path: config.Get().WorkingDir(),
101 ToolName: b.Info().Name,
102 Action: "execute",
103 Description: permissionDescription,
104 Params: params.Input,
105 },
106 )
107 if !p {
108 return tools.NewTextErrorResponse("permission denied"), nil
109 }
110
111 switch b.mcpConfig.Type {
112 case config.MCPStdio:
113 c, err := client.NewStdioMCPClient(
114 b.mcpConfig.Command,
115 b.mcpConfig.Env,
116 b.mcpConfig.Args...,
117 )
118 if err != nil {
119 return tools.NewTextErrorResponse(err.Error()), nil
120 }
121 return runTool(ctx, c, b.tool.Name, params.Input)
122 case config.MCPHttp:
123 c, err := client.NewStreamableHttpClient(
124 b.mcpConfig.URL,
125 transport.WithHTTPHeaders(b.mcpConfig.Headers),
126 )
127 if err != nil {
128 return tools.NewTextErrorResponse(err.Error()), nil
129 }
130 return runTool(ctx, c, b.tool.Name, params.Input)
131 case config.MCPSse:
132 c, err := client.NewSSEMCPClient(
133 b.mcpConfig.URL,
134 client.WithHeaders(b.mcpConfig.Headers),
135 )
136 if err != nil {
137 return tools.NewTextErrorResponse(err.Error()), nil
138 }
139 return runTool(ctx, c, b.tool.Name, params.Input)
140 }
141
142 return tools.NewTextErrorResponse("invalid mcp type"), nil
143}
144
145func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool {
146 return &mcpTool{
147 mcpName: name,
148 tool: tool,
149 mcpConfig: mcpConfig,
150 permissions: permissions,
151 }
152}
153
// mcpTools caches the result of GetMcpTools. Once it is non-empty,
// GetMcpTools returns it without contacting the MCP servers again.
var mcpTools []tools.BaseTool
155
156func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
157 var stdioTools []tools.BaseTool
158 initRequest := mcp.InitializeRequest{}
159 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
160 initRequest.Params.ClientInfo = mcp.Implementation{
161 Name: "Crush",
162 Version: version.Version,
163 }
164
165 _, err := c.Initialize(ctx, initRequest)
166 if err != nil {
167 logging.Error("error initializing mcp client", "error", err)
168 return stdioTools
169 }
170 toolsRequest := mcp.ListToolsRequest{}
171 tools, err := c.ListTools(ctx, toolsRequest)
172 if err != nil {
173 logging.Error("error listing tools", "error", err)
174 return stdioTools
175 }
176 for _, t := range tools.Tools {
177 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
178 }
179 defer c.Close()
180 return stdioTools
181}
182
183func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
184 if len(mcpTools) > 0 {
185 return mcpTools
186 }
187 for name, m := range config.Get().MCP {
188 switch m.Type {
189 case config.MCPStdio:
190 c, err := client.NewStdioMCPClient(
191 m.Command,
192 m.Env,
193 m.Args...,
194 )
195 if err != nil {
196 logging.Error("error creating mcp client", "error", err)
197 continue
198 }
199
200 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
201 case config.MCPHttp:
202 c, err := client.NewStreamableHttpClient(
203 m.URL,
204 transport.WithHTTPHeaders(m.Headers),
205 )
206 if err != nil {
207 logging.Error("error creating mcp client", "error", err)
208 continue
209 }
210 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
211 case config.MCPSse:
212 c, err := client.NewSSEMCPClient(
213 m.URL,
214 client.WithHeaders(m.Headers),
215 )
216 if err != nil {
217 logging.Error("error creating mcp client", "error", err)
218 continue
219 }
220 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
221 }
222 }
223
224 return mcpTools
225}