1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/mcp"
)
17
// mcpTool adapts a single tool advertised by an MCP server to the
// tools.BaseTool interface. Instances are built by NewMcpTool.
type mcpTool struct {
	mcpName     string             // configured MCP server name; prefixes the tool name
	tool        mcp.Tool           // tool definition as listed by the server
	mcpConfig   config.MCPServer   // connection settings (stdio command/env/args or SSE URL/headers)
	permissions permission.Service // asked for approval before every Run
}
24
25type MCPClient interface {
26 Initialize(
27 ctx context.Context,
28 request mcp.InitializeRequest,
29 ) (*mcp.InitializeResult, error)
30 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
31 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
32 Close() error
33}
34
35func (b *mcpTool) Name() string {
36 return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
37}
38
39func (b *mcpTool) Info() tools.ToolInfo {
40 return tools.ToolInfo{
41 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
42 Description: b.tool.Description,
43 Parameters: b.tool.InputSchema.Properties,
44 Required: b.tool.InputSchema.Required,
45 }
46}
47
48func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
49 defer c.Close()
50 initRequest := mcp.InitializeRequest{}
51 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
52 initRequest.Params.ClientInfo = mcp.Implementation{
53 Name: "Crush",
54 Version: version.Version,
55 }
56
57 _, err := c.Initialize(ctx, initRequest)
58 if err != nil {
59 return tools.NewTextErrorResponse(err.Error()), nil
60 }
61
62 toolRequest := mcp.CallToolRequest{}
63 toolRequest.Params.Name = toolName
64 var args map[string]any
65 if err = json.Unmarshal([]byte(input), &args); err != nil {
66 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
67 }
68 toolRequest.Params.Arguments = args
69 result, err := c.CallTool(ctx, toolRequest)
70 if err != nil {
71 return tools.NewTextErrorResponse(err.Error()), nil
72 }
73
74 output := ""
75 for _, v := range result.Content {
76 if v, ok := v.(mcp.TextContent); ok {
77 output = v.Text
78 } else {
79 output = fmt.Sprintf("%v", v)
80 }
81 }
82
83 return tools.NewTextResponse(output), nil
84}
85
86func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
87 sessionID, messageID := tools.GetContextValues(ctx)
88 if sessionID == "" || messageID == "" {
89 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
90 }
91 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
92 p := b.permissions.Request(
93 permission.CreatePermissionRequest{
94 SessionID: sessionID,
95 Path: config.WorkingDirectory(),
96 ToolName: b.Info().Name,
97 Action: "execute",
98 Description: permissionDescription,
99 Params: params.Input,
100 },
101 )
102 if !p {
103 return tools.NewTextErrorResponse("permission denied"), nil
104 }
105
106 switch b.mcpConfig.Type {
107 case config.MCPStdio:
108 c, err := client.NewStdioMCPClient(
109 b.mcpConfig.Command,
110 b.mcpConfig.Env,
111 b.mcpConfig.Args...,
112 )
113 if err != nil {
114 return tools.NewTextErrorResponse(err.Error()), nil
115 }
116 return runTool(ctx, c, b.tool.Name, params.Input)
117 case config.MCPSse:
118 c, err := client.NewSSEMCPClient(
119 b.mcpConfig.URL,
120 client.WithHeaders(b.mcpConfig.Headers),
121 )
122 if err != nil {
123 return tools.NewTextErrorResponse(err.Error()), nil
124 }
125 return runTool(ctx, c, b.tool.Name, params.Input)
126 }
127
128 return tools.NewTextErrorResponse("invalid mcp type"), nil
129}
130
131func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
132 return &mcpTool{
133 mcpName: name,
134 tool: tool,
135 mcpConfig: mcpConfig,
136 permissions: permissions,
137 }
138}
139
// mcpTools caches the tools discovered by GetMcpTools. The cache is reused
// only once it is non-empty, so a configuration that yields no tools is
// re-scanned on every call.
// NOTE(review): package-level and unsynchronized; presumably GetMcpTools is
// not called concurrently. Confirm.
var mcpTools []tools.BaseTool
141
142func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
143 var stdioTools []tools.BaseTool
144 initRequest := mcp.InitializeRequest{}
145 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
146 initRequest.Params.ClientInfo = mcp.Implementation{
147 Name: "Crush",
148 Version: version.Version,
149 }
150
151 _, err := c.Initialize(ctx, initRequest)
152 if err != nil {
153 logging.Error("error initializing mcp client", "error", err)
154 return stdioTools
155 }
156 toolsRequest := mcp.ListToolsRequest{}
157 tools, err := c.ListTools(ctx, toolsRequest)
158 if err != nil {
159 logging.Error("error listing tools", "error", err)
160 return stdioTools
161 }
162 for _, t := range tools.Tools {
163 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
164 }
165 defer c.Close()
166 return stdioTools
167}
168
169func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
170 if len(mcpTools) > 0 {
171 return mcpTools
172 }
173 for name, m := range config.Get().MCPServers {
174 switch m.Type {
175 case config.MCPStdio:
176 c, err := client.NewStdioMCPClient(
177 m.Command,
178 m.Env,
179 m.Args...,
180 )
181 if err != nil {
182 logging.Error("error creating mcp client", "error", err)
183 continue
184 }
185
186 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
187 case config.MCPSse:
188 c, err := client.NewSSEMCPClient(
189 m.URL,
190 client.WithHeaders(m.Headers),
191 )
192 if err != nil {
193 logging.Error("error creating mcp client", "error", err)
194 continue
195 }
196 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
197 }
198 }
199
200 return mcpTools
201}