1package agent
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
22
23var (
24 mcpToolsOnce sync.Once
25 mcpTools []tools.BaseTool
26 mcpClients = csync.NewMap[string, *client.Client]()
27)
28
// McpTool adapts one tool exposed by an MCP server so it can be returned as a
// tools.BaseTool (see getTools).
type McpTool struct {
	// mcpName is the configured name of the MCP server providing the tool;
	// it is also the key used to find the server's client in mcpClients.
	mcpName string
	// tool is the tool definition as returned by the server's ListTools call.
	tool mcp.Tool
	// permissions is asked to approve every invocation in Run.
	permissions permission.Service
	// workingDir is sent as the Path of each permission request.
	workingDir string
}
35
36func (b *McpTool) Name() string {
37 return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
38}
39
40func (b *McpTool) Info() tools.ToolInfo {
41 required := b.tool.InputSchema.Required
42 if required == nil {
43 required = make([]string, 0)
44 }
45 return tools.ToolInfo{
46 Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
47 Description: b.tool.Description,
48 Parameters: b.tool.InputSchema.Properties,
49 Required: required,
50 }
51}
52
53func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
54 var args map[string]any
55 if err := json.Unmarshal([]byte(input), &args); err != nil {
56 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
57 }
58 c, ok := mcpClients.Get(name)
59 if !ok {
60 return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
61 }
62 result, err := c.CallTool(ctx, mcp.CallToolRequest{
63 Params: mcp.CallToolParams{
64 Name: toolName,
65 Arguments: args,
66 },
67 })
68 if err != nil {
69 return tools.NewTextErrorResponse(err.Error()), nil
70 }
71
72 output := ""
73 for _, v := range result.Content {
74 if v, ok := v.(mcp.TextContent); ok {
75 output = v.Text
76 } else {
77 output = fmt.Sprintf("%v", v)
78 }
79 }
80
81 return tools.NewTextResponse(output), nil
82}
83
84func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
85 sessionID, messageID := tools.GetContextValues(ctx)
86 if sessionID == "" || messageID == "" {
87 return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
88 }
89 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
90 p := b.permissions.Request(
91 permission.CreatePermissionRequest{
92 SessionID: sessionID,
93 ToolCallID: params.ID,
94 Path: b.workingDir,
95 ToolName: b.Info().Name,
96 Action: "execute",
97 Description: permissionDescription,
98 Params: params.Input,
99 },
100 )
101 if !p {
102 return tools.ToolResponse{}, permission.ErrorPermissionDenied
103 }
104
105 return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
106}
107
108func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
109 result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
110 if err != nil {
111 slog.Error("error listing tools", "error", err)
112 c.Close()
113 mcpClients.Del(name)
114 return nil
115 }
116 mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
117 for _, tool := range result.Tools {
118 mcpTools = append(mcpTools, &McpTool{
119 mcpName: name,
120 tool: tool,
121 permissions: permissions,
122 workingDir: workingDir,
123 })
124 }
125 return mcpTools
126}
127
128// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
129func CloseMCPClients() {
130 for c := range mcpClients.Seq() {
131 _ = c.Close()
132 }
133}
134
135var mcpInitRequest = mcp.InitializeRequest{
136 Params: mcp.InitializeParams{
137 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
138 ClientInfo: mcp.Implementation{
139 Name: "Crush",
140 Version: version.Version,
141 },
142 },
143}
144
145func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
146 var wg sync.WaitGroup
147 result := csync.NewSlice[tools.BaseTool]()
148 for name, m := range cfg.MCP {
149 if m.Disabled {
150 slog.Debug("skipping disabled mcp", "name", name)
151 continue
152 }
153 wg.Add(1)
154 go func(name string, m config.MCPConfig) {
155 defer wg.Done()
156 c, err := createMcpClient(m)
157 if err != nil {
158 slog.Error("error creating mcp client", "error", err, "name", name)
159 return
160 }
161 if err := c.Start(ctx); err != nil {
162 slog.Error("error starting mcp client", "error", err, "name", name)
163 _ = c.Close()
164 return
165 }
166 if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
167 slog.Error("error initializing mcp client", "error", err, "name", name)
168 _ = c.Close()
169 return
170 }
171
172 slog.Info("Initialized mcp client", "name", name)
173 mcpClients.Set(name, c)
174
175 result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
176 }(name, m)
177 }
178 wg.Wait()
179 return slices.Collect(result.Seq())
180}
181
182func createMcpClient(m config.MCPConfig) (*client.Client, error) {
183 switch m.Type {
184 case config.MCPStdio:
185 return client.NewStdioMCPClient(
186 m.Command,
187 m.ResolvedEnv(),
188 m.Args...,
189 )
190 case config.MCPHttp:
191 return client.NewStreamableHttpClient(
192 m.URL,
193 transport.WithHTTPHeaders(m.ResolvedHeaders()),
194 )
195 case config.MCPSse:
196 return client.NewSSEMCPClient(
197 m.URL,
198 client.WithHeaders(m.ResolvedHeaders()),
199 )
200 default:
201 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
202 }
203}