1package mcp
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "iter"
8 "log/slog"
9 "slices"
10 "strings"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/csync"
14 "github.com/modelcontextprotocol/go-sdk/mcp"
15)
16
// Tool is an alias for mcp.Tool from the MCP Go SDK. It is the element type
// of the tool lists cached by this package and yielded by Tools.
type Tool = mcp.Tool
18
// ToolResult represents the result of running an MCP tool, as assembled by
// RunTool from the content parts the server returned.
type ToolResult struct {
	// Type is "text", "image", or "media" (the latter used for audio).
	Type string

	// Content holds the textual parts of the result joined with "\n".
	// Unrecognized part types are included as their %v rendering.
	Content string

	// Data is the payload of the first image part or, when there is no
	// image, the first audio part. It is nil for plain text results.
	// NOTE(review): RunTool assumes the SDK hands this over already
	// base64-encoded — confirm against the go-sdk content decoding.
	Data []byte

	// MediaType is the MIME type reported by the part that supplied Data.
	MediaType string
}
26
// allTools maps an MCP server name to its enabled tools, as stored by
// updateTools. A server whose enabled tool list is empty has no entry.
// NOTE(review): csync.Map is presumably safe for concurrent use — confirm in
// the csync package.
var allTools = csync.NewMap[string, []*Tool]()
28
// Tools returns all available MCP tools as an iterator of (server name, tools)
// pairs backed by allTools. Disabled tools are filtered out before storage,
// and servers with no enabled tools are absent.
// NOTE(review): whether iteration sees a snapshot or live updates depends on
// csync.Map.Seq2 — confirm there.
func Tools() iter.Seq2[string, []*Tool] {
	return allTools.Seq2()
}
33
34// RunTool runs an MCP tool with the given input parameters.
35func RunTool(ctx context.Context, name, toolName string, input string) (ToolResult, error) {
36 var args map[string]any
37 if err := json.Unmarshal([]byte(input), &args); err != nil {
38 return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
39 }
40
41 c, err := getOrRenewClient(ctx, name)
42 if err != nil {
43 return ToolResult{}, err
44 }
45 result, err := c.CallTool(ctx, &mcp.CallToolParams{
46 Name: toolName,
47 Arguments: args,
48 })
49 if err != nil {
50 return ToolResult{}, err
51 }
52
53 if len(result.Content) == 0 {
54 return ToolResult{Type: "text", Content: ""}, nil
55 }
56
57 var textParts []string
58 var imageData []byte
59 var imageMimeType string
60 var audioData []byte
61 var audioMimeType string
62
63 for _, v := range result.Content {
64 switch content := v.(type) {
65 case *mcp.TextContent:
66 textParts = append(textParts, content.Text)
67 case *mcp.ImageContent:
68 if imageData == nil {
69 imageData = content.Data
70 imageMimeType = content.MIMEType
71 }
72 case *mcp.AudioContent:
73 if audioData == nil {
74 audioData = content.Data
75 audioMimeType = content.MIMEType
76 }
77 default:
78 textParts = append(textParts, fmt.Sprintf("%v", v))
79 }
80 }
81
82 textContent := strings.Join(textParts, "\n")
83
84 // MCP SDK returns Data as already base64-encoded, so we use it directly.
85 if imageData != nil {
86 return ToolResult{
87 Type: "image",
88 Content: textContent,
89 Data: imageData,
90 MediaType: imageMimeType,
91 }, nil
92 }
93
94 if audioData != nil {
95 return ToolResult{
96 Type: "media",
97 Content: textContent,
98 Data: audioData,
99 MediaType: audioMimeType,
100 }, nil
101 }
102
103 return ToolResult{
104 Type: "text",
105 Content: textContent,
106 }, nil
107}
108
109// RefreshTools gets the updated list of tools from the MCP and updates the
110// global state.
111func RefreshTools(ctx context.Context, name string) {
112 session, ok := sessions.Get(name)
113 if !ok {
114 slog.Warn("refresh tools: no session", "name", name)
115 return
116 }
117
118 tools, err := getTools(ctx, session)
119 if err != nil {
120 updateState(name, StateError, err, nil, Counts{})
121 return
122 }
123
124 toolCount := updateTools(name, tools)
125
126 prev, _ := states.Get(name)
127 prev.Counts.Tools = toolCount
128 updateState(name, StateConnected, nil, session, prev.Counts)
129}
130
131func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
132 // Always call ListTools to get the actual available tools.
133 // The InitializeResult Capabilities.Tools field may be an empty object {},
134 // which is valid per MCP spec, but we still need to call ListTools to discover tools.
135 result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
136 if err != nil {
137 return nil, err
138 }
139 return result.Tools, nil
140}
141
142func updateTools(name string, tools []*Tool) int {
143 tools = filterDisabledTools(name, tools)
144 if len(tools) == 0 {
145 allTools.Del(name)
146 return 0
147 }
148 allTools.Set(name, tools)
149 return len(tools)
150}
151
152// filterDisabledTools removes tools that are disabled via config.
153func filterDisabledTools(mcpName string, tools []*Tool) []*Tool {
154 cfg := config.Get()
155 mcpCfg, ok := cfg.MCP[mcpName]
156 if !ok || len(mcpCfg.DisabledTools) == 0 {
157 return tools
158 }
159
160 filtered := make([]*Tool, 0, len(tools))
161 for _, tool := range tools {
162 if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
163 filtered = append(filtered, tool)
164 }
165 }
166 return filtered
167}