1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/logging"
11 "github.com/charmbracelet/crush/internal/permission"
12 "github.com/charmbracelet/crush/internal/version"
13
14 "github.com/mark3labs/mcp-go/client"
15 "github.com/mark3labs/mcp-go/mcp"
16)
17
// mcpTool exposes a single tool advertised by an MCP server through the
// Info and Run methods defined in this file.
type mcpTool struct {
	// mcpName is the name of the MCP server; Info prefixes the tool name with it.
	mcpName string
	// tool is the tool definition (name, description, input schema) from the server.
	tool mcp.Tool
	// mcpConfig selects the transport and supplies the settings used by Run to
	// create a client.
	mcpConfig config.MCPServer
	// permissions is asked for approval by Run before the tool is executed.
	permissions permission.Service
}
24
// MCPClient is the subset of MCP client behavior this file relies on. The
// clients returned by client.NewStdioMCPClient and client.NewSSEMCPClient are
// passed wherever an MCPClient is expected.
type MCPClient interface {
	// Initialize is called once on each client before any other request.
	Initialize(
		ctx context.Context,
		request mcp.InitializeRequest,
	) (*mcp.InitializeResult, error)
	// ListTools returns the tools the server advertises.
	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
	// CallTool invokes a tool by name with the given arguments.
	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
	// Close is called when this file is done with the client.
	Close() error
}
34
35func (b *mcpTool) Info() tools.ToolInfo {
36 return tools.ToolInfo{
37 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
38 Description: b.tool.Description,
39 Parameters: b.tool.InputSchema.Properties,
40 Required: b.tool.InputSchema.Required,
41 }
42}
43
44func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
45 defer c.Close()
46 initRequest := mcp.InitializeRequest{}
47 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
48 initRequest.Params.ClientInfo = mcp.Implementation{
49 Name: "Crush",
50 Version: version.Version,
51 }
52
53 _, err := c.Initialize(ctx, initRequest)
54 if err != nil {
55 return tools.NewTextErrorResponse(err.Error()), nil
56 }
57
58 toolRequest := mcp.CallToolRequest{}
59 toolRequest.Params.Name = toolName
60 var args map[string]any
61 if err = json.Unmarshal([]byte(input), &args); err != nil {
62 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
63 }
64 toolRequest.Params.Arguments = args
65 result, err := c.CallTool(ctx, toolRequest)
66 if err != nil {
67 return tools.NewTextErrorResponse(err.Error()), nil
68 }
69
70 output := ""
71 for _, v := range result.Content {
72 if v, ok := v.(mcp.TextContent); ok {
73 output = v.Text
74 } else {
75 output = fmt.Sprintf("%v", v)
76 }
77 }
78
79 return tools.NewTextResponse(output), nil
80}
81
82func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
83 sessionID, messageID := tools.GetContextValues(ctx)
84 if sessionID == "" || messageID == "" {
85 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
86 }
87 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
88 p := b.permissions.Request(
89 permission.CreatePermissionRequest{
90 SessionID: sessionID,
91 Path: config.WorkingDirectory(),
92 ToolName: b.Info().Name,
93 Action: "execute",
94 Description: permissionDescription,
95 Params: params.Input,
96 },
97 )
98 if !p {
99 return tools.NewTextErrorResponse("permission denied"), nil
100 }
101
102 switch b.mcpConfig.Type {
103 case config.MCPStdio:
104 c, err := client.NewStdioMCPClient(
105 b.mcpConfig.Command,
106 b.mcpConfig.Env,
107 b.mcpConfig.Args...,
108 )
109 if err != nil {
110 return tools.NewTextErrorResponse(err.Error()), nil
111 }
112 return runTool(ctx, c, b.tool.Name, params.Input)
113 case config.MCPSse:
114 c, err := client.NewSSEMCPClient(
115 b.mcpConfig.URL,
116 client.WithHeaders(b.mcpConfig.Headers),
117 )
118 if err != nil {
119 return tools.NewTextErrorResponse(err.Error()), nil
120 }
121 return runTool(ctx, c, b.tool.Name, params.Input)
122 }
123
124 return tools.NewTextErrorResponse("invalid mcp type"), nil
125}
126
127func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
128 return &mcpTool{
129 mcpName: name,
130 tool: tool,
131 mcpConfig: mcpConfig,
132 permissions: permissions,
133 }
134}
135
// mcpTools holds the tools collected by GetMcpTools. Once it is non-empty,
// GetMcpTools returns it as-is instead of querying the servers again.
// NOTE(review): no synchronization is visible in this file; confirm callers
// do not invoke GetMcpTools concurrently.
var mcpTools []tools.BaseTool
137
138func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
139 var stdioTools []tools.BaseTool
140 initRequest := mcp.InitializeRequest{}
141 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
142 initRequest.Params.ClientInfo = mcp.Implementation{
143 Name: "Crush",
144 Version: version.Version,
145 }
146
147 _, err := c.Initialize(ctx, initRequest)
148 if err != nil {
149 logging.Error("error initializing mcp client", "error", err)
150 return stdioTools
151 }
152 toolsRequest := mcp.ListToolsRequest{}
153 tools, err := c.ListTools(ctx, toolsRequest)
154 if err != nil {
155 logging.Error("error listing tools", "error", err)
156 return stdioTools
157 }
158 for _, t := range tools.Tools {
159 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
160 }
161 defer c.Close()
162 return stdioTools
163}
164
165func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
166 if len(mcpTools) > 0 {
167 return mcpTools
168 }
169 for name, m := range config.Get().MCPServers {
170 switch m.Type {
171 case config.MCPStdio:
172 c, err := client.NewStdioMCPClient(
173 m.Command,
174 m.Env,
175 m.Args...,
176 )
177 if err != nil {
178 logging.Error("error creating mcp client", "error", err)
179 continue
180 }
181
182 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
183 case config.MCPSse:
184 c, err := client.NewSSEMCPClient(
185 m.URL,
186 client.WithHeaders(m.Headers),
187 )
188 if err != nil {
189 logging.Error("error creating mcp client", "error", err)
190 continue
191 }
192 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
193 }
194 }
195
196 return mcpTools
197}