1package mcp
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "fmt"
8 "iter"
9 "log/slog"
10 "slices"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/modelcontextprotocol/go-sdk/mcp"
16)
17
// Tool is an alias for the MCP SDK's Tool type, so code in this package can
// refer to it without the mcp qualifier.
type Tool = mcp.Tool
19
// ToolResult represents the result of running an MCP tool.
type ToolResult struct {
 // Type is the kind of result. RunTool sets it to "text", "image", or
 // "media" (the latter for audio content).
 Type string
 // Content is the textual part of the result. RunTool joins all text
 // items of the response with newlines.
 Content string
 // Data is the binary payload of an image or audio result. RunTool leaves
 // it unset for "text" results.
 Data []byte
 // MediaType is the MIME type reported for Data.
 MediaType string
}
27
// allTools holds the currently known tools, keyed by MCP name. Entries are
// written by updateTools and removed when an MCP ends up with no tools.
var allTools = csync.NewMap[string, []*Tool]()
29
30// Tools returns all available MCP tools.
31func Tools() iter.Seq2[string, []*Tool] {
32 return allTools.Seq2()
33}
34
35// RunTool runs an MCP tool with the given input parameters.
36func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string, input string) (ToolResult, error) {
37 var args map[string]any
38 if err := json.Unmarshal([]byte(input), &args); err != nil {
39 return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
40 }
41
42 c, err := getOrRenewClient(ctx, cfg, name)
43 if err != nil {
44 return ToolResult{}, err
45 }
46 result, err := c.CallTool(ctx, &mcp.CallToolParams{
47 Name: toolName,
48 Arguments: args,
49 })
50 if err != nil {
51 return ToolResult{}, err
52 }
53
54 if len(result.Content) == 0 {
55 return ToolResult{Type: "text", Content: ""}, nil
56 }
57
58 var textParts []string
59 var imageData []byte
60 var imageMimeType string
61 var audioData []byte
62 var audioMimeType string
63
64 for _, v := range result.Content {
65 switch content := v.(type) {
66 case *mcp.TextContent:
67 textParts = append(textParts, content.Text)
68 case *mcp.ImageContent:
69 if imageData == nil {
70 imageData = content.Data
71 imageMimeType = content.MIMEType
72 }
73 case *mcp.AudioContent:
74 if audioData == nil {
75 audioData = content.Data
76 audioMimeType = content.MIMEType
77 }
78 default:
79 textParts = append(textParts, fmt.Sprintf("%v", v))
80 }
81 }
82
83 textContent := strings.Join(textParts, "\n")
84
85 // We need to make sure the data is base64
86 // when using something like docker + playwright the data was not returned correctly.
87 if imageData != nil {
88 return ToolResult{
89 Type: "image",
90 Content: textContent,
91 Data: ensureRawBytes(imageData),
92 MediaType: imageMimeType,
93 }, nil
94 }
95
96 if audioData != nil {
97 return ToolResult{
98 Type: "media",
99 Content: textContent,
100 Data: ensureRawBytes(audioData),
101 MediaType: audioMimeType,
102 }, nil
103 }
104
105 return ToolResult{
106 Type: "text",
107 Content: textContent,
108 }, nil
109}
110
111// RefreshTools gets the updated list of tools from the MCP and updates the
112// global state.
113func RefreshTools(ctx context.Context, cfg *config.ConfigStore, name string) {
114 session, ok := sessions.Get(name)
115 if !ok {
116 slog.Warn("Refresh tools: no session", "name", name)
117 return
118 }
119
120 tools, err := getTools(ctx, session)
121 if err != nil {
122 updateState(name, StateError, err, nil, Counts{})
123 return
124 }
125
126 toolCount := updateTools(cfg, name, tools)
127
128 prev, _ := states.Get(name)
129 prev.Counts.Tools = toolCount
130 updateState(name, StateConnected, nil, session, prev.Counts)
131}
132
133func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
134 // Always call ListTools to get the actual available tools.
135 // The InitializeResult Capabilities.Tools field may be an empty object {},
136 // which is valid per MCP spec, but we still need to call ListTools to discover tools.
137 result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
138 if err != nil {
139 return nil, err
140 }
141 return result.Tools, nil
142}
143
144func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
145 mcpCfg, ok := cfg.Config().MCP[name]
146 if ok {
147 tools = filterTools(mcpCfg, tools)
148 }
149 if len(tools) == 0 {
150 allTools.Del(name)
151 return 0
152 }
153 allTools.Set(name, tools)
154 return len(tools)
155}
156
157// filterTools filters tools based on enabled_tools (allow list) and
158// disabled_tools (deny list) from the MCP config.
159func filterTools(mcpCfg config.MCPConfig, tools []*Tool) []*Tool {
160 if len(mcpCfg.EnabledTools) > 0 {
161 filtered := make([]*Tool, 0, len(mcpCfg.EnabledTools))
162 for _, tool := range tools {
163 if slices.Contains(mcpCfg.EnabledTools, tool.Name) {
164 filtered = append(filtered, tool)
165 }
166 }
167 tools = filtered
168 }
169
170 if len(mcpCfg.DisabledTools) > 0 {
171 filtered := make([]*Tool, 0, len(tools))
172 for _, tool := range tools {
173 if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
174 filtered = append(filtered, tool)
175 }
176 }
177 tools = filtered
178 }
179
180 return tools
181}
182
183// ensureRawBytes normalizes MCP media data into raw binary bytes.
184//
185// The MCP Go SDK's json.Unmarshal normally base64-decodes
186// ImageContent.Data into raw bytes automatically. However, some MCP
187// transports (notably Docker over stdio) can deliver data in
188// unexpected formats. This function handles both cases:
189//
190// - If data looks like a valid base64 string (ASCII-only, decodable)
191// it is decoded and the raw bytes are returned.
192// - If data is already raw binary (contains bytes > 127) it is
193// returned as-is.
194func ensureRawBytes(data []byte) []byte {
195 if len(data) == 0 {
196 return data
197 }
198
199 normalized := normalizeBase64Input(data)
200 if decoded, ok := decodeBase64(normalized); ok {
201 return decoded
202 }
203
204 // Already raw binary — return unchanged.
205 return data
206}
207
// normalizeBase64Input returns a copy of data with all whitespace removed,
// so base64 text that was wrapped or padded with whitespace can be decoded.
func normalizeBase64Input(data []byte) []byte {
	var compact strings.Builder
	compact.Grow(len(data))
	for _, chunk := range strings.Fields(string(data)) {
		compact.WriteString(chunk)
	}
	return []byte(compact.String())
}
212
// decodeBase64 tries to interpret data as standard base64 text.
//
// It reports true together with the decoded bytes when data is ASCII-only
// and decodes with the standard alphabet, with or without padding. Empty
// input is returned unchanged and counts as success. In every other case it
// returns nil and false.
func decodeBase64(data []byte) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	// Any byte outside the ASCII range rules out base64 text.
	for i := range data {
		if data[i] >= 0x80 {
			return nil, false
		}
	}

	text := string(data)
	// Try the padded form first, then the unpadded one.
	for _, enc := range []*base64.Encoding{base64.StdEncoding, base64.RawStdEncoding} {
		if out, err := enc.DecodeString(text); err == nil {
			return out, true
		}
	}
	return nil, false
}