tools.go

  1package mcp
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"iter"
  8	"log/slog"
  9	"slices"
 10	"strings"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/modelcontextprotocol/go-sdk/mcp"
 15)
 16
// Tool is an alias for mcp.Tool from the MCP Go SDK; it is the element type
// stored in this package's tool registry.
type Tool = mcp.Tool
 18
 19// ToolResult represents the result of running an MCP tool.
 20type ToolResult struct {
 21	Type      string
 22	Content   string
 23	Data      []byte
 24	MediaType string
 25}
 26
// allTools records, per MCP server name, the tools that server currently
// exposes after config-based filtering. updateTools sets an entry when at
// least one enabled tool remains and removes it otherwise.
var allTools = csync.NewMap[string, []*Tool]()

// Tools returns all available MCP tools as a sequence of
// (server name, tools) pairs taken from allTools.
//
// NOTE(review): whether the sequence observes a snapshot or live updates is
// determined by csync.Map.Seq2 — confirm before relying on either behavior.
func Tools() iter.Seq2[string, []*Tool] {
	return allTools.Seq2()
}
 33
 34// RunTool runs an MCP tool with the given input parameters.
 35func RunTool(ctx context.Context, name, toolName string, input string) (ToolResult, error) {
 36	var args map[string]any
 37	if err := json.Unmarshal([]byte(input), &args); err != nil {
 38		return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
 39	}
 40
 41	c, err := getOrRenewClient(ctx, name)
 42	if err != nil {
 43		return ToolResult{}, err
 44	}
 45	result, err := c.CallTool(ctx, &mcp.CallToolParams{
 46		Name:      toolName,
 47		Arguments: args,
 48	})
 49	if err != nil {
 50		return ToolResult{}, err
 51	}
 52
 53	if len(result.Content) == 0 {
 54		return ToolResult{Type: "text", Content: ""}, nil
 55	}
 56
 57	var textParts []string
 58	var imageData []byte
 59	var imageMimeType string
 60	var audioData []byte
 61	var audioMimeType string
 62
 63	for _, v := range result.Content {
 64		switch content := v.(type) {
 65		case *mcp.TextContent:
 66			textParts = append(textParts, content.Text)
 67		case *mcp.ImageContent:
 68			if imageData == nil {
 69				imageData = content.Data
 70				imageMimeType = content.MIMEType
 71			}
 72		case *mcp.AudioContent:
 73			if audioData == nil {
 74				audioData = content.Data
 75				audioMimeType = content.MIMEType
 76			}
 77		default:
 78			textParts = append(textParts, fmt.Sprintf("%v", v))
 79		}
 80	}
 81
 82	textContent := strings.Join(textParts, "\n")
 83
 84	// MCP SDK returns Data as already base64-encoded, so we use it directly.
 85	if imageData != nil {
 86		return ToolResult{
 87			Type:      "image",
 88			Content:   textContent,
 89			Data:      imageData,
 90			MediaType: imageMimeType,
 91		}, nil
 92	}
 93
 94	if audioData != nil {
 95		return ToolResult{
 96			Type:      "media",
 97			Content:   textContent,
 98			Data:      audioData,
 99			MediaType: audioMimeType,
100		}, nil
101	}
102
103	return ToolResult{
104		Type:    "text",
105		Content: textContent,
106	}, nil
107}
108
109// RefreshTools gets the updated list of tools from the MCP and updates the
110// global state.
111func RefreshTools(ctx context.Context, name string) {
112	session, ok := sessions.Get(name)
113	if !ok {
114		slog.Warn("refresh tools: no session", "name", name)
115		return
116	}
117
118	tools, err := getTools(ctx, session)
119	if err != nil {
120		updateState(name, StateError, err, nil, Counts{})
121		return
122	}
123
124	toolCount := updateTools(name, tools)
125
126	prev, _ := states.Get(name)
127	prev.Counts.Tools = toolCount
128	updateState(name, StateConnected, nil, session, prev.Counts)
129}
130
131func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
132	// Always call ListTools to get the actual available tools.
133	// The InitializeResult Capabilities.Tools field may be an empty object {},
134	// which is valid per MCP spec, but we still need to call ListTools to discover tools.
135	result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
136	if err != nil {
137		return nil, err
138	}
139	return result.Tools, nil
140}
141
142func updateTools(name string, tools []*Tool) int {
143	tools = filterDisabledTools(name, tools)
144	if len(tools) == 0 {
145		allTools.Del(name)
146		return 0
147	}
148	allTools.Set(name, tools)
149	return len(tools)
150}
151
152// filterDisabledTools removes tools that are disabled via config.
153func filterDisabledTools(mcpName string, tools []*Tool) []*Tool {
154	cfg := config.Get()
155	mcpCfg, ok := cfg.MCP[mcpName]
156	if !ok || len(mcpCfg.DisabledTools) == 0 {
157		return tools
158	}
159
160	filtered := make([]*Tool, 0, len(tools))
161	for _, tool := range tools {
162		if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
163			filtered = append(filtered, tool)
164		}
165	}
166	return filtered
167}