tools.go

  1package mcp
  2
  3import (
  4	"context"
  5	"encoding/base64"
  6	"encoding/json"
  7	"fmt"
  8	"iter"
  9	"log/slog"
 10	"slices"
 11	"strings"
 12
 13	"git.secluded.site/crush/internal/config"
 14	"git.secluded.site/crush/internal/csync"
 15	"github.com/modelcontextprotocol/go-sdk/mcp"
 16)
 17
// Tool is an alias of the MCP Go SDK's mcp.Tool type.
type Tool = mcp.Tool

// ToolResult represents the result of running an MCP tool.
type ToolResult struct {
	// Type is "text", "image" (Data holds an image) or "media" (Data holds
	// audio), as produced by RunTool.
	Type      string
	// Content is the text content of the tool's response, joined with
	// newlines. It may be empty.
	Content   string
	// Data is the raw (already base64-decoded) payload of the image or
	// audio item for "image"/"media" results; RunTool leaves it nil for
	// "text" results.
	Data      []byte
	// MediaType is the MIME type the server reported for Data; RunTool
	// leaves it empty for "text" results.
	MediaType string
}
 27
// allTools holds, per MCP server name, the tool list last stored by
// updateTools (after any configured enabled/disabled filtering). Servers
// whose filtered list is empty have no entry.
var allTools = csync.NewMap[string, []*Tool]()

// Tools returns all available MCP tools as an iterator over
// (server name, tools) pairs taken from allTools.
//
// NOTE(review): whether iteration sees a snapshot or concurrent updates
// depends on csync.Map.Seq2 — confirm before relying on either.
func Tools() iter.Seq2[string, []*Tool] {
	return allTools.Seq2()
}
 34
 35// RunTool runs an MCP tool with the given input parameters.
 36func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string, input string) (ToolResult, error) {
 37	var args map[string]any
 38	if err := json.Unmarshal([]byte(input), &args); err != nil {
 39		return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
 40	}
 41
 42	c, err := getOrRenewClient(ctx, cfg, name)
 43	if err != nil {
 44		return ToolResult{}, err
 45	}
 46	result, err := c.CallTool(ctx, &mcp.CallToolParams{
 47		Name:      toolName,
 48		Arguments: args,
 49	})
 50	if err != nil {
 51		return ToolResult{}, err
 52	}
 53
 54	if len(result.Content) == 0 {
 55		return ToolResult{Type: "text", Content: ""}, nil
 56	}
 57
 58	var textParts []string
 59	var imageData []byte
 60	var imageMimeType string
 61	var audioData []byte
 62	var audioMimeType string
 63
 64	for _, v := range result.Content {
 65		switch content := v.(type) {
 66		case *mcp.TextContent:
 67			textParts = append(textParts, content.Text)
 68		case *mcp.ImageContent:
 69			if imageData == nil {
 70				imageData = content.Data
 71				imageMimeType = content.MIMEType
 72			}
 73		case *mcp.AudioContent:
 74			if audioData == nil {
 75				audioData = content.Data
 76				audioMimeType = content.MIMEType
 77			}
 78		default:
 79			textParts = append(textParts, fmt.Sprintf("%v", v))
 80		}
 81	}
 82
 83	textContent := strings.Join(textParts, "\n")
 84
 85	// We need to make sure the data is base64
 86	// when using something like docker + playwright the data was not returned correctly.
 87	if imageData != nil {
 88		return ToolResult{
 89			Type:      "image",
 90			Content:   textContent,
 91			Data:      ensureRawBytes(imageData),
 92			MediaType: imageMimeType,
 93		}, nil
 94	}
 95
 96	if audioData != nil {
 97		return ToolResult{
 98			Type:      "media",
 99			Content:   textContent,
100			Data:      ensureRawBytes(audioData),
101			MediaType: audioMimeType,
102		}, nil
103	}
104
105	return ToolResult{
106		Type:    "text",
107		Content: textContent,
108	}, nil
109}
110
111// RefreshTools gets the updated list of tools from the MCP and updates the
112// global state.
113func RefreshTools(ctx context.Context, cfg *config.ConfigStore, name string) {
114	session, ok := sessions.Get(name)
115	if !ok {
116		slog.Warn("Refresh tools: no session", "name", name)
117		return
118	}
119
120	tools, err := getTools(ctx, session)
121	if err != nil {
122		updateState(name, StateError, err, nil, Counts{})
123		return
124	}
125
126	toolCount := updateTools(cfg, name, tools)
127
128	prev, _ := states.Get(name)
129	prev.Counts.Tools = toolCount
130	updateState(name, StateConnected, nil, session, prev.Counts)
131}
132
133func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
134	// Always call ListTools to get the actual available tools.
135	// The InitializeResult Capabilities.Tools field may be an empty object {},
136	// which is valid per MCP spec, but we still need to call ListTools to discover tools.
137	result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
138	if err != nil {
139		return nil, err
140	}
141	return result.Tools, nil
142}
143
144func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
145	mcpCfg, ok := cfg.Config().MCP[name]
146	if ok {
147		tools = filterTools(mcpCfg, tools)
148	}
149	if len(tools) == 0 {
150		allTools.Del(name)
151		return 0
152	}
153	allTools.Set(name, tools)
154	return len(tools)
155}
156
157// filterTools filters tools based on enabled_tools (allow list) and
158// disabled_tools (deny list) from the MCP config.
159func filterTools(mcpCfg config.MCPConfig, tools []*Tool) []*Tool {
160	if len(mcpCfg.EnabledTools) > 0 {
161		filtered := make([]*Tool, 0, len(mcpCfg.EnabledTools))
162		for _, tool := range tools {
163			if slices.Contains(mcpCfg.EnabledTools, tool.Name) {
164				filtered = append(filtered, tool)
165			}
166		}
167		tools = filtered
168	}
169
170	if len(mcpCfg.DisabledTools) > 0 {
171		filtered := make([]*Tool, 0, len(tools))
172		for _, tool := range tools {
173			if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
174				filtered = append(filtered, tool)
175			}
176		}
177		tools = filtered
178	}
179
180	return tools
181}
182
// ensureRawBytes normalizes MCP media data into raw binary bytes.
//
// The MCP Go SDK's json.Unmarshal normally base64-decodes
// ImageContent.Data into raw bytes automatically. However, some MCP
// transports (notably Docker over stdio) can deliver data that is still
// base64-encoded after that step. This function handles both cases:
//
//   - If data contains any byte > 127 it is treated as raw binary and
//     returned as-is, without further processing.
//   - Otherwise, if data with whitespace removed is valid standard base64
//     (padded or unpadded), it is decoded and the raw bytes are returned.
//   - Anything else (e.g. ASCII markup such as SVG) is returned unchanged.
//
// This is a heuristic: ASCII-only raw data that happens to be valid base64
// is decoded as well.
func ensureRawBytes(data []byte) []byte {
	if len(data) == 0 {
		return data
	}

	// Raw binary is the common case. Detect it on the original bytes before
	// normalizing, so large payloads (screenshots, audio) are not copied
	// several times only to be rejected, and so non-ASCII whitespace cannot
	// be stripped away to make binary data look like base64.
	for _, b := range data {
		if b > 127 {
			return data
		}
	}

	if decoded, ok := decodeBase64(normalizeBase64Input(data)); ok {
		return decoded
	}

	// ASCII but not base64 — return unchanged.
	return data
}

// normalizeBase64Input removes all whitespace (as defined by strings.Fields)
// from data. This lets base64 payloads that were line-wrapped or padded
// with spaces be decoded.
func normalizeBase64Input(data []byte) []byte {
	normalized := strings.Join(strings.Fields(string(data)), "")
	return []byte(normalized)
}

// decodeBase64 tries to decode data as standard base64, first padded
// (base64.StdEncoding) and then unpadded (base64.RawStdEncoding).
//
// It reports false if data contains a non-ASCII byte or is valid in neither
// encoding. Empty input is reported as a successful decode of nothing.
func decodeBase64(data []byte) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	for _, b := range data {
		if b > 127 {
			return nil, false
		}
	}

	s := string(data)
	decoded, err := base64.StdEncoding.DecodeString(s)
	if err == nil {
		return decoded, true
	}
	decoded, err = base64.RawStdEncoding.DecodeString(s)
	if err == nil {
		return decoded, true
	}
	return nil, false
}