1package mcp
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "iter"
8 "strings"
9
10 "github.com/charmbracelet/crush/internal/csync"
11 "github.com/modelcontextprotocol/go-sdk/mcp"
12)
13
14type Tool = mcp.Tool
15
16var (
17 allTools = csync.NewMap[string, *Tool]()
18 client2Tools = csync.NewMap[string, []*Tool]()
19)
20
21// GetTools returns all available MCP tools.
22func GetTools() iter.Seq2[string, *Tool] {
23 return allTools.Seq2()
24}
25
26// RunTool runs an MCP tool with the given input parameters.
27func RunTool(ctx context.Context, name, toolName string, input string) (string, error) {
28 var args map[string]any
29 if err := json.Unmarshal([]byte(input), &args); err != nil {
30 return "", fmt.Errorf("error parsing parameters: %s", err)
31 }
32
33 c, err := getOrRenewClient(ctx, name)
34 if err != nil {
35 return "", err
36 }
37 result, err := c.CallTool(ctx, &mcp.CallToolParams{
38 Name: toolName,
39 Arguments: args,
40 })
41 if err != nil {
42 return "", err
43 }
44
45 output := make([]string, 0, len(result.Content))
46 for _, v := range result.Content {
47 if vv, ok := v.(*mcp.TextContent); ok {
48 output = append(output, vv.Text)
49 } else {
50 output = append(output, fmt.Sprintf("%v", v))
51 }
52 }
53 return strings.Join(output, "\n"), nil
54}
55
56func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
57 if session.InitializeResult().Capabilities.Tools == nil {
58 return nil, nil
59 }
60 result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
61 if err != nil {
62 return nil, err
63 }
64 return result.Tools, nil
65}
66
67// updateTools updates the global mcpTools and mcpClient2Tools maps
68func updateTools(mcpName string, tools []*Tool) {
69 if len(tools) == 0 {
70 client2Tools.Del(mcpName)
71 } else {
72 client2Tools.Set(mcpName, tools)
73 }
74 for name, tools := range client2Tools.Seq2() {
75 for _, t := range tools {
76 allTools.Set(name, t)
77 }
78 }
79}