tools.go

  1package mcp
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"iter"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/csync"
 12	"github.com/modelcontextprotocol/go-sdk/mcp"
 13)
 14
 15type Tool = mcp.Tool
 16
 // ToolResult represents the result of running an MCP tool.
//
// Type tells the caller which fields are meaningful. RunTool produces:
//   - "text":  only Content is set.
//   - "image": Data and MediaType describe one image block; Content carries
//     any accompanying text.
//   - "media": Data and MediaType describe one audio block; Content carries
//     any accompanying text.
type ToolResult struct {
	// Type is the result kind: "text", "image", or "media" (see above).
	Type string
	// Content is the textual part of the tool response, newline-joined.
	Content string
	// Data is the binary payload of the image/audio block as handed over by
	// the MCP SDK; nil for "text" results.
	Data []byte
	// MediaType is the MIME type reported for Data; empty for "text" results.
	MediaType string
}
 24
 25var allTools = csync.NewMap[string, []*Tool]()
 26
 27// Tools returns all available MCP tools.
 28func Tools() iter.Seq2[string, []*Tool] {
 29	return allTools.Seq2()
 30}
 31
 32// RunTool runs an MCP tool with the given input parameters.
 33func RunTool(ctx context.Context, name, toolName string, input string) (ToolResult, error) {
 34	var args map[string]any
 35	if err := json.Unmarshal([]byte(input), &args); err != nil {
 36		return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
 37	}
 38
 39	c, err := getOrRenewClient(ctx, name)
 40	if err != nil {
 41		return ToolResult{}, err
 42	}
 43	result, err := c.CallTool(ctx, &mcp.CallToolParams{
 44		Name:      toolName,
 45		Arguments: args,
 46	})
 47	if err != nil {
 48		return ToolResult{}, err
 49	}
 50
 51	if len(result.Content) == 0 {
 52		return ToolResult{Type: "text", Content: ""}, nil
 53	}
 54
 55	var textParts []string
 56	var imageData []byte
 57	var imageMimeType string
 58	var audioData []byte
 59	var audioMimeType string
 60
 61	for _, v := range result.Content {
 62		switch content := v.(type) {
 63		case *mcp.TextContent:
 64			textParts = append(textParts, content.Text)
 65		case *mcp.ImageContent:
 66			if imageData == nil {
 67				imageData = content.Data
 68				imageMimeType = content.MIMEType
 69			}
 70		case *mcp.AudioContent:
 71			if audioData == nil {
 72				audioData = content.Data
 73				audioMimeType = content.MIMEType
 74			}
 75		default:
 76			textParts = append(textParts, fmt.Sprintf("%v", v))
 77		}
 78	}
 79
 80	textContent := strings.Join(textParts, "\n")
 81
 82	// MCP SDK returns Data as already base64-encoded, so we use it directly.
 83	if imageData != nil {
 84		return ToolResult{
 85			Type:      "image",
 86			Content:   textContent,
 87			Data:      imageData,
 88			MediaType: imageMimeType,
 89		}, nil
 90	}
 91
 92	if audioData != nil {
 93		return ToolResult{
 94			Type:      "media",
 95			Content:   textContent,
 96			Data:      audioData,
 97			MediaType: audioMimeType,
 98		}, nil
 99	}
100
101	return ToolResult{
102		Type:    "text",
103		Content: textContent,
104	}, nil
105}
106
107// RefreshTools gets the updated list of tools from the MCP and updates the
108// global state.
109func RefreshTools(ctx context.Context, name string) {
110	session, ok := sessions.Get(name)
111	if !ok {
112		slog.Warn("refresh tools: no session", "name", name)
113		return
114	}
115
116	tools, err := getTools(ctx, session)
117	if err != nil {
118		updateState(name, StateError, err, nil, Counts{})
119		return
120	}
121
122	updateTools(name, tools)
123
124	prev, _ := states.Get(name)
125	prev.Counts.Tools = len(tools)
126	updateState(name, StateConnected, nil, session, prev.Counts)
127}
128
129func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
130	// Always call ListTools to get the actual available tools.
131	// The InitializeResult Capabilities.Tools field may be an empty object {},
132	// which is valid per MCP spec, but we still need to call ListTools to discover tools.
133	result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
134	if err != nil {
135		return nil, err
136	}
137	return result.Tools, nil
138}
139
140func updateTools(name string, tools []*Tool) {
141	if len(tools) == 0 {
142		allTools.Del(name)
143		return
144	}
145	allTools.Set(name, tools)
146}