1package mcp
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "iter"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/csync"
12 "github.com/modelcontextprotocol/go-sdk/mcp"
13)
14
15type Tool = mcp.Tool
16
17// ToolResult represents the result of running an MCP tool.
18type ToolResult struct {
19 Type string
20 Content string
21 Data []byte
22 MediaType string
23}
24
25var allTools = csync.NewMap[string, []*Tool]()
26
27// Tools returns all available MCP tools.
28func Tools() iter.Seq2[string, []*Tool] {
29 return allTools.Seq2()
30}
31
32// RunTool runs an MCP tool with the given input parameters.
33func RunTool(ctx context.Context, name, toolName string, input string) (ToolResult, error) {
34 var args map[string]any
35 if err := json.Unmarshal([]byte(input), &args); err != nil {
36 return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
37 }
38
39 c, err := getOrRenewClient(ctx, name)
40 if err != nil {
41 return ToolResult{}, err
42 }
43 result, err := c.CallTool(ctx, &mcp.CallToolParams{
44 Name: toolName,
45 Arguments: args,
46 })
47 if err != nil {
48 return ToolResult{}, err
49 }
50
51 if len(result.Content) == 0 {
52 return ToolResult{Type: "text", Content: ""}, nil
53 }
54
55 var textParts []string
56 var imageData []byte
57 var imageMimeType string
58 var audioData []byte
59 var audioMimeType string
60
61 for _, v := range result.Content {
62 switch content := v.(type) {
63 case *mcp.TextContent:
64 textParts = append(textParts, content.Text)
65 case *mcp.ImageContent:
66 if imageData == nil {
67 imageData = content.Data
68 imageMimeType = content.MIMEType
69 }
70 case *mcp.AudioContent:
71 if audioData == nil {
72 audioData = content.Data
73 audioMimeType = content.MIMEType
74 }
75 default:
76 textParts = append(textParts, fmt.Sprintf("%v", v))
77 }
78 }
79
80 textContent := strings.Join(textParts, "\n")
81
82 // MCP SDK returns Data as already base64-encoded, so we use it directly.
83 if imageData != nil {
84 return ToolResult{
85 Type: "image",
86 Content: textContent,
87 Data: imageData,
88 MediaType: imageMimeType,
89 }, nil
90 }
91
92 if audioData != nil {
93 return ToolResult{
94 Type: "media",
95 Content: textContent,
96 Data: audioData,
97 MediaType: audioMimeType,
98 }, nil
99 }
100
101 return ToolResult{
102 Type: "text",
103 Content: textContent,
104 }, nil
105}
106
107// RefreshTools gets the updated list of tools from the MCP and updates the
108// global state.
109func RefreshTools(ctx context.Context, name string) {
110 session, ok := sessions.Get(name)
111 if !ok {
112 slog.Warn("refresh tools: no session", "name", name)
113 return
114 }
115
116 tools, err := getTools(ctx, session)
117 if err != nil {
118 updateState(name, StateError, err, nil, Counts{})
119 return
120 }
121
122 updateTools(name, tools)
123
124 prev, _ := states.Get(name)
125 prev.Counts.Tools = len(tools)
126 updateState(name, StateConnected, nil, session, prev.Counts)
127}
128
129func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
130 // Always call ListTools to get the actual available tools.
131 // The InitializeResult Capabilities.Tools field may be an empty object {},
132 // which is valid per MCP spec, but we still need to call ListTools to discover tools.
133 result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
134 if err != nil {
135 return nil, err
136 }
137 return result.Tools, nil
138}
139
140func updateTools(name string, tools []*Tool) {
141 if len(tools) == 0 {
142 allTools.Del(name)
143 return
144 }
145 allTools.Set(name, tools)
146}