1package mcp
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "fmt"
8 "iter"
9 "log/slog"
10 "slices"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/modelcontextprotocol/go-sdk/mcp"
16)
17
// Tool is an alias of mcp.Tool, the tool definition type from the MCP go-sdk.
type Tool = mcp.Tool

// ToolResult represents the result of running an MCP tool, flattened by
// RunTool into a single value.
type ToolResult struct {
	// Type is the kind of result RunTool produced: "text", "image", or
	// "media" (the latter is used for audio content).
	Type string
	// Content holds the tool's text output, with multiple text parts joined
	// by newlines. It may accompany image/audio data.
	Content string
	// Data holds the base64-encoded image or audio payload, if any.
	Data []byte
	// MediaType is the MIME type reported by the server for Data.
	MediaType string
}

// allTools maps an MCP server name to its currently enabled (config-filtered)
// tools. It is written by updateTools and read through Tools.
// NOTE(review): csync.Map is presumably safe for concurrent use — confirm.
var allTools = csync.NewMap[string, []*Tool]()
29
// Tools returns an iterator over every MCP server name paired with the
// enabled tools currently registered for that server.
func Tools() iter.Seq2[string, []*Tool] {
	seq := allTools.Seq2()
	return seq
}
34
// RunTool calls the tool toolName on the MCP server registered under name and
// flattens the server's response into a single ToolResult.
//
// input is the tool's arguments encoded as a JSON object. An empty (or
// whitespace-only) input and a JSON null are both sent as an empty argument
// object, so tools that take no parameters can be invoked without the caller
// having to supply "{}". Any other input that does not decode into a JSON
// object is rejected with a wrapped parse error.
//
// The response content is flattened as follows:
//   - text parts, plus a %v rendering of any content type not handled below,
//     are joined with newlines into Content;
//   - if at least one image part is present, the first image part's data and
//     MIME type are returned with Type "image";
//   - otherwise, if at least one audio part is present, the first audio
//     part's data and MIME type are returned with Type "media";
//   - otherwise the result has Type "text".
//
// Image and audio payloads are passed through ensureBase64, so Data holds
// base64 text.
//
// NOTE(review): result.IsError is not inspected, so an error reported by the
// tool itself comes back as ordinary content rather than as a Go error.
func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string, input string) (ToolResult, error) {
	var args map[string]any
	if strings.TrimSpace(input) != "" {
		if err := json.Unmarshal([]byte(input), &args); err != nil {
			return ToolResult{}, fmt.Errorf("error parsing parameters: %w", err)
		}
	}
	if args == nil {
		// Empty input or a literal JSON null leaves args nil, which would
		// otherwise serialize as JSON null. MCP tool arguments are an object,
		// so send {} instead.
		args = map[string]any{}
	}

	c, err := getOrRenewClient(ctx, cfg, name)
	if err != nil {
		return ToolResult{}, err
	}
	result, err := c.CallTool(ctx, &mcp.CallToolParams{
		Name:      toolName,
		Arguments: args,
	})
	if err != nil {
		return ToolResult{}, err
	}

	if len(result.Content) == 0 {
		return ToolResult{Type: "text", Content: ""}, nil
	}

	var (
		textParts     []string
		imageData     []byte
		imageMimeType string
		audioData     []byte
		audioMimeType string
	)

	for _, v := range result.Content {
		switch content := v.(type) {
		case *mcp.TextContent:
			textParts = append(textParts, content.Text)
		case *mcp.ImageContent:
			// Only the first image is returned; later ones are dropped.
			if imageData == nil {
				imageData = content.Data
				imageMimeType = content.MIMEType
			}
		case *mcp.AudioContent:
			// Only the first audio clip is returned; later ones are dropped.
			if audioData == nil {
				audioData = content.Data
				audioMimeType = content.MIMEType
			}
		default:
			textParts = append(textParts, fmt.Sprintf("%v", v))
		}
	}

	textContent := strings.Join(textParts, "\n")

	// Make sure binary payloads are base64: some setups (e.g. Docker +
	// Playwright) were observed returning data that was not encoded
	// correctly.
	if imageData != nil {
		return ToolResult{
			Type:      "image",
			Content:   textContent,
			Data:      ensureBase64(imageData),
			MediaType: imageMimeType,
		}, nil
	}

	if audioData != nil {
		return ToolResult{
			Type:      "media",
			Content:   textContent,
			Data:      ensureBase64(audioData),
			MediaType: audioMimeType,
		}, nil
	}

	return ToolResult{
		Type:    "text",
		Content: textContent,
	}, nil
}
110
// RefreshTools gets the updated list of tools from the named MCP server and
// updates the package's global state.
//
// Behavior:
//   - if there is no live session for name, it logs a warning and returns
//     without touching any state;
//   - if listing the tools fails, the server is moved to StateError with zero
//     counts (the tool list previously stored in allTools is not touched
//     here);
//   - otherwise the config-filtered tools are stored via updateTools and the
//     server is marked StateConnected, carrying over the previously recorded
//     counts with only the tool count replaced.
func RefreshTools(ctx context.Context, cfg *config.ConfigStore, name string) {
	session, ok := sessions.Get(name)
	if !ok {
		slog.Warn("Refresh tools: no session", "name", name)
		return
	}

	tools, err := getTools(ctx, session)
	if err != nil {
		updateState(name, StateError, err, nil, Counts{})
		return
	}

	toolCount := updateTools(cfg, name, tools)

	// The found flag is deliberately ignored so the other counts carry over.
	// NOTE(review): when no state is recorded yet this relies on states.Get
	// returning a usable zero value (and, if it stores pointers, this
	// assignment mutates the stored state in place) — confirm against its type.
	prev, _ := states.Get(name)
	prev.Counts.Tools = toolCount
	updateState(name, StateConnected, nil, session, prev.Counts)
}
132
133func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
134 // Always call ListTools to get the actual available tools.
135 // The InitializeResult Capabilities.Tools field may be an empty object {},
136 // which is valid per MCP spec, but we still need to call ListTools to discover tools.
137 result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
138 if err != nil {
139 return nil, err
140 }
141 return result.Tools, nil
142}
143
144func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
145 tools = filterDisabledTools(cfg, name, tools)
146 if len(tools) == 0 {
147 allTools.Del(name)
148 return 0
149 }
150 allTools.Set(name, tools)
151 return len(tools)
152}
153
// filterDisabledTools returns the tools whose names are not listed in the
// DisabledTools setting of the MCP server mcpName in cfg.
//
// If the server has no config entry or disables nothing, tools itself is
// returned unchanged. Otherwise a newly allocated slice is returned and the
// input slice is left untouched.
func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool {
	serverCfg, found := cfg.Config().MCP[mcpName]
	if !found {
		return tools
	}
	disabled := serverCfg.DisabledTools
	if len(disabled) == 0 {
		return tools
	}

	kept := make([]*Tool, 0, len(tools))
	for _, t := range tools {
		if slices.Contains(disabled, t.Name) {
			continue
		}
		kept = append(kept, t)
	}
	return kept
}
169
// ensureBase64 returns data unchanged when isValidBase64 judges it to already
// be standard (padded) base64 text. Otherwise it treats data as raw binary
// and returns its standard base64 encoding.
//
// NOTE(review): this is a heuristic. Raw bytes that happen to form valid
// base64 (e.g. the ASCII text "abcd") are passed through unencoded.
func ensureBase64(data []byte) []byte {
	if !isValidBase64(data) {
		// Raw binary payload: encode it into a fresh buffer.
		return base64.StdEncoding.AppendEncode(nil, data)
	}
	return data
}
183
// isValidBase64 reports whether data decodes cleanly as standard, padded
// base64. Empty input counts as valid. The standard decoder skips '\r' and
// '\n', so line-wrapped base64 is accepted too.
func isValidBase64(data []byte) bool {
	if len(data) == 0 {
		return true
	}
	// Cheap early exit: a byte with the high bit set can never appear in
	// base64 text.
	for i := range data {
		if data[i] >= 0x80 {
			return false
		}
	}
	scratch := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
	_, err := base64.StdEncoding.Decode(scratch, data)
	return err == nil
}