tools.go

  1package mcp
  2
  3import (
  4	"context"
  5	"encoding/base64"
  6	"encoding/json"
  7	"fmt"
  8	"iter"
  9	"log/slog"
 10	"slices"
 11	"strings"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/modelcontextprotocol/go-sdk/mcp"
 16)
 17
// Tool is an alias of mcp.Tool from the MCP Go SDK.
type Tool = mcp.Tool
 19
// ToolResult represents the result of running an MCP tool.
//
// RunTool sets Type to "text", "image", or "media" (used for audio). For
// "image" and "media" results, Data holds the base64-encoded payload and
// MediaType the MIME type reported by the server; Content carries any text
// parts of the response joined by newlines.
type ToolResult struct {
	Type      string // "text", "image", or "media" as set by RunTool
	Content   string // text parts of the result, newline-joined
	Data      []byte // base64-encoded image/audio payload; unset for "text"
	MediaType string // MIME type of Data; unset for "text"
}
 27
// allTools maps an MCP server name to its enabled tools, as last published
// by updateTools (which deletes the entry when no tools remain).
// NOTE(review): csync.Map is presumably safe for concurrent use — confirm.
var allTools = csync.NewMap[string, []*Tool]()

// Tools returns an iterator over all available MCP tools, yielding each MCP
// server name together with that server's enabled tools.
func Tools() iter.Seq2[string, []*Tool] {
	return allTools.Seq2()
}
 34
 35// RunTool runs an MCP tool with the given input parameters.
 36func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string, input string) (ToolResult, error) {
 37	var args map[string]any
 38	if err := json.Unmarshal([]byte(input), &args); err != nil {
 39		return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
 40	}
 41
 42	c, err := getOrRenewClient(ctx, cfg, name)
 43	if err != nil {
 44		return ToolResult{}, err
 45	}
 46	result, err := c.CallTool(ctx, &mcp.CallToolParams{
 47		Name:      toolName,
 48		Arguments: args,
 49	})
 50	if err != nil {
 51		return ToolResult{}, err
 52	}
 53
 54	if len(result.Content) == 0 {
 55		return ToolResult{Type: "text", Content: ""}, nil
 56	}
 57
 58	var textParts []string
 59	var imageData []byte
 60	var imageMimeType string
 61	var audioData []byte
 62	var audioMimeType string
 63
 64	for _, v := range result.Content {
 65		switch content := v.(type) {
 66		case *mcp.TextContent:
 67			textParts = append(textParts, content.Text)
 68		case *mcp.ImageContent:
 69			if imageData == nil {
 70				imageData = content.Data
 71				imageMimeType = content.MIMEType
 72			}
 73		case *mcp.AudioContent:
 74			if audioData == nil {
 75				audioData = content.Data
 76				audioMimeType = content.MIMEType
 77			}
 78		default:
 79			textParts = append(textParts, fmt.Sprintf("%v", v))
 80		}
 81	}
 82
 83	textContent := strings.Join(textParts, "\n")
 84
 85	// We need to to make sure the data is base64
 86	// when using something like docker + playwright the data was not returned correctly
 87	if imageData != nil {
 88		return ToolResult{
 89			Type:      "image",
 90			Content:   textContent,
 91			Data:      ensureBase64(imageData),
 92			MediaType: imageMimeType,
 93		}, nil
 94	}
 95
 96	if audioData != nil {
 97		return ToolResult{
 98			Type:      "media",
 99			Content:   textContent,
100			Data:      ensureBase64(audioData),
101			MediaType: audioMimeType,
102		}, nil
103	}
104
105	return ToolResult{
106		Type:    "text",
107		Content: textContent,
108	}, nil
109}
110
111// RefreshTools gets the updated list of tools from the MCP and updates the
112// global state.
113func RefreshTools(ctx context.Context, cfg *config.ConfigStore, name string) {
114	session, ok := sessions.Get(name)
115	if !ok {
116		slog.Warn("Refresh tools: no session", "name", name)
117		return
118	}
119
120	tools, err := getTools(ctx, session)
121	if err != nil {
122		updateState(name, StateError, err, nil, Counts{})
123		return
124	}
125
126	toolCount := updateTools(cfg, name, tools)
127
128	prev, _ := states.Get(name)
129	prev.Counts.Tools = toolCount
130	updateState(name, StateConnected, nil, session, prev.Counts)
131}
132
133func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
134	// Always call ListTools to get the actual available tools.
135	// The InitializeResult Capabilities.Tools field may be an empty object {},
136	// which is valid per MCP spec, but we still need to call ListTools to discover tools.
137	result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
138	if err != nil {
139		return nil, err
140	}
141	return result.Tools, nil
142}
143
144func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
145	tools = filterDisabledTools(cfg, name, tools)
146	if len(tools) == 0 {
147		allTools.Del(name)
148		return 0
149	}
150	allTools.Set(name, tools)
151	return len(tools)
152}
153
154// filterDisabledTools removes tools that are disabled via config.
155func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool {
156	mcpCfg, ok := cfg.Config().MCP[mcpName]
157	if !ok || len(mcpCfg.DisabledTools) == 0 {
158		return tools
159	}
160
161	filtered := make([]*Tool, 0, len(tools))
162	for _, tool := range tools {
163		if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
164			filtered = append(filtered, tool)
165		}
166	}
167	return filtered
168}
169
// ensureBase64 returns data as standard, padded base64 (RFC 4648 §4).
//
// Data that already decodes as standard base64 is returned unchanged, with
// one exception: encoding/base64 ignores '\r' and '\n' while decoding, so
// line-wrapped input validates but would otherwise be forwarded with its
// line breaks. Such input is re-encoded onto a single line. Anything that
// does not decode is treated as raw binary and encoded.
//
// Raw binary that happens to consist only of base64 alphabet characters is
// indistinguishable from encoded data and is returned as-is. Empty input is
// returned unchanged.
func ensureBase64(data []byte) []byte {
	decoded, ok := decodeStdBase64(data)
	if !ok {
		// Data is raw binary; encode it.
		return encodeStdBase64(data)
	}
	hasLineBreak := slices.ContainsFunc(data, func(b byte) bool {
		return b == '\r' || b == '\n'
	})
	if !hasLineBreak {
		return data
	}
	return encodeStdBase64(decoded)
}

// isValidBase64 reports whether data decodes as standard, padded base64.
// Empty input is considered valid.
func isValidBase64(data []byte) bool {
	_, ok := decodeStdBase64(data)
	return ok
}

// decodeStdBase64 decodes data as standard, padded base64 and reports
// whether decoding succeeded. Empty input decodes to an empty result.
func decodeStdBase64(data []byte) ([]byte, bool) {
	if len(data) == 0 {
		return nil, true
	}
	// Fast reject without allocating: base64 text is pure ASCII, and binary
	// payloads such as PNG or JPEG usually fail on their first byte.
	for _, b := range data {
		if b > 127 {
			return nil, false
		}
	}
	buf := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
	n, err := base64.StdEncoding.Decode(buf, data)
	if err != nil {
		return nil, false
	}
	return buf[:n], true
}

// encodeStdBase64 returns the standard, padded base64 encoding of data.
func encodeStdBase64(data []byte) []byte {
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
	base64.StdEncoding.Encode(encoded, data)
	return encoded
}