1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/logging"
11 "github.com/charmbracelet/crush/internal/permission"
12 "github.com/charmbracelet/crush/internal/version"
13
14 "github.com/mark3labs/mcp-go/client"
15 "github.com/mark3labs/mcp-go/mcp"
16)
17
// mcpTool adapts a single tool advertised by an MCP server to the
// tools.BaseTool interface used by the agent.
type mcpTool struct {
	// mcpName is the configured name of the MCP server; it prefixes the tool
	// name exposed to the model.
	mcpName string
	// tool is the tool definition reported by the server's ListTools call.
	tool mcp.Tool
	// mcpConfig holds the connection settings used to create a fresh client
	// for every Run.
	mcpConfig config.MCP
	// permissions is asked to approve each invocation before it executes.
	permissions permission.Service
}
24
25type MCPClient interface {
26 Initialize(
27 ctx context.Context,
28 request mcp.InitializeRequest,
29 ) (*mcp.InitializeResult, error)
30 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
31 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
32 Close() error
33}
34
35func (b *mcpTool) Name() string {
36 return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
37}
38
39func (b *mcpTool) Info() tools.ToolInfo {
40 required := b.tool.InputSchema.Required
41 if required == nil {
42 required = make([]string, 0)
43 }
44 return tools.ToolInfo{
45 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
46 Description: b.tool.Description,
47 Parameters: b.tool.InputSchema.Properties,
48 Required: required,
49 }
50}
51
// runTool performs the MCP initialize handshake on c, calls the tool named
// toolName with the JSON object in input as its arguments, and converts the
// result into a tools.ToolResponse. The client is closed before returning.
//
// Failures are reported as text error responses rather than Go errors, so the
// model can see them. This covers failed initialization, unparsable input,
// failed calls, and results the server flags as errors. The returned error is
// always nil.
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
	defer c.Close()

	initRequest := mcp.InitializeRequest{}
	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
	initRequest.Params.ClientInfo = mcp.Implementation{
		Name:    "Crush",
		Version: version.Version,
	}
	if _, err := c.Initialize(ctx, initRequest); err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	toolRequest := mcp.CallToolRequest{}
	toolRequest.Params.Name = toolName
	toolRequest.Params.Arguments = args
	result, err := c.CallTool(ctx, toolRequest)
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// A result may carry several content items; keep all of them (newline
	// separated) instead of only the last one. Content lists are small, so
	// plain concatenation is fine here.
	output := ""
	for i, content := range result.Content {
		if i > 0 {
			output += "\n"
		}
		if text, ok := content.(mcp.TextContent); ok {
			output += text.Text
		} else {
			output += fmt.Sprintf("%v", content)
		}
	}

	// The server signals tool-level failures via isError; surface them as
	// errors so the model does not treat the output as a successful result.
	if result.IsError {
		return tools.NewTextErrorResponse(output), nil
	}
	return tools.NewTextResponse(output), nil
}
89
// Run executes the wrapped MCP tool for one model tool call.
//
// It requires a session ID and a message ID in ctx (read via
// tools.GetContextValues) and returns a Go error if either is missing. It then
// asks the permission service to approve executing the tool with params.Input.
// A denial is reported as a "permission denied" text error.
//
// Once approved, a new client is created from mcpConfig (stdio or SSE) and
// handed to runTool, which closes it when done. Client construction failures
// and unsupported transport types are returned as text error responses.
func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for running an MCP tool")
	}

	toolName := b.Name()
	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", toolName, params.Input)
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			Path:        config.WorkingDirectory(),
			ToolName:    toolName,
			Action:      "execute",
			Description: permissionDescription,
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.NewTextErrorResponse("permission denied"), nil
	}

	switch b.mcpConfig.Type {
	case config.MCPStdio:
		c, err := client.NewStdioMCPClient(
			b.mcpConfig.Command,
			b.mcpConfig.Env,
			b.mcpConfig.Args...,
		)
		if err != nil {
			return tools.NewTextErrorResponse(err.Error()), nil
		}
		return runTool(ctx, c, b.tool.Name, params.Input)
	case config.MCPSse:
		c, err := client.NewSSEMCPClient(
			b.mcpConfig.URL,
			client.WithHeaders(b.mcpConfig.Headers),
		)
		if err != nil {
			return tools.NewTextErrorResponse(err.Error()), nil
		}
		return runTool(ctx, c, b.tool.Name, params.Input)
	}

	return tools.NewTextErrorResponse("invalid mcp type"), nil
}
134
135func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
136 return &mcpTool{
137 mcpName: name,
138 tool: tool,
139 mcpConfig: mcpConfig,
140 permissions: permissions,
141 }
142}
143
// mcpTools caches the tools discovered by GetMcpTools. It is filled by the
// first call that discovers at least one tool and returned unchanged by later
// calls.
// NOTE(review): read and written without locking — assumes GetMcpTools is not
// called concurrently; confirm.
var mcpTools []tools.BaseTool
145
146func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
147 var stdioTools []tools.BaseTool
148 initRequest := mcp.InitializeRequest{}
149 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
150 initRequest.Params.ClientInfo = mcp.Implementation{
151 Name: "Crush",
152 Version: version.Version,
153 }
154
155 _, err := c.Initialize(ctx, initRequest)
156 if err != nil {
157 logging.Error("error initializing mcp client", "error", err)
158 return stdioTools
159 }
160 toolsRequest := mcp.ListToolsRequest{}
161 tools, err := c.ListTools(ctx, toolsRequest)
162 if err != nil {
163 logging.Error("error listing tools", "error", err)
164 return stdioTools
165 }
166 for _, t := range tools.Tools {
167 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
168 }
169 defer c.Close()
170 return stdioTools
171}
172
173func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
174 if len(mcpTools) > 0 {
175 return mcpTools
176 }
177 for name, m := range config.Get().MCP {
178 switch m.Type {
179 case config.MCPStdio:
180 c, err := client.NewStdioMCPClient(
181 m.Command,
182 m.Env,
183 m.Args...,
184 )
185 if err != nil {
186 logging.Error("error creating mcp client", "error", err)
187 continue
188 }
189
190 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
191 case config.MCPSse:
192 c, err := client.NewSSEMCPClient(
193 m.URL,
194 client.WithHeaders(m.Headers),
195 )
196 if err != nil {
197 logging.Error("error creating mcp client", "error", err)
198 continue
199 }
200 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
201 }
202 }
203
204 return mcpTools
205}