1package mcp
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "iter"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/csync"
12 "github.com/modelcontextprotocol/go-sdk/mcp"
13)
14
15type Tool = mcp.Tool
16
17var (
18 allTools = csync.NewMap[string, *Tool]()
19 clientTools = csync.NewMap[string, []*Tool]()
20)
21
22// Tools returns all available MCP tools.
23func Tools() iter.Seq2[string, *Tool] {
24 return allTools.Seq2()
25}
26
27// RunTool runs an MCP tool with the given input parameters.
28func RunTool(ctx context.Context, name, toolName string, input string) (string, error) {
29 var args map[string]any
30 if err := json.Unmarshal([]byte(input), &args); err != nil {
31 return "", fmt.Errorf("error parsing parameters: %s", err)
32 }
33
34 c, err := getOrRenewClient(ctx, name)
35 if err != nil {
36 return "", err
37 }
38 result, err := c.CallTool(ctx, &mcp.CallToolParams{
39 Name: toolName,
40 Arguments: args,
41 })
42 if err != nil {
43 return "", err
44 }
45
46 output := make([]string, 0, len(result.Content))
47 for _, v := range result.Content {
48 if vv, ok := v.(*mcp.TextContent); ok {
49 output = append(output, vv.Text)
50 } else {
51 output = append(output, fmt.Sprintf("%v", v))
52 }
53 }
54 return strings.Join(output, "\n"), nil
55}
56
57// RefreshTools gets the updated list of tools from the MCP and updates the
58// global state.
59func RefreshTools(ctx context.Context, name string) {
60 session, ok := sessions.Get(name)
61 if !ok {
62 slog.Warn("refresh tools: no session", "name", name)
63 return
64 }
65
66 tools, err := getTools(ctx, session)
67 if err != nil {
68 updateState(name, StateError, err, nil, Counts{})
69 return
70 }
71
72 updateTools(name, tools)
73
74 prev, _ := states.Get(name)
75 updateState(name, StateConnected, nil, session, Counts{
76 Tools: len(tools),
77 Prompts: prev.Counts.Prompts,
78 })
79}
80
81func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
82 if session.InitializeResult().Capabilities.Tools == nil {
83 return nil, nil
84 }
85 result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
86 if err != nil {
87 return nil, err
88 }
89 return result.Tools, nil
90}
91
92// updateTools updates the global mcpTools and mcpClient2Tools maps
93func updateTools(name string, tools []*Tool) {
94 if len(tools) == 0 {
95 clientTools.Del(name)
96 } else {
97 clientTools.Set(name, tools)
98 }
99 for name, tools := range clientTools.Seq2() {
100 for _, t := range tools {
101 allTools.Set(name, t)
102 }
103 }
104}