tools.go

 1package mcp
 2
 3import (
 4	"context"
 5	"encoding/json"
 6	"fmt"
 7	"iter"
 8	"strings"
 9
10	"github.com/charmbracelet/crush/internal/csync"
11	"github.com/modelcontextprotocol/go-sdk/mcp"
12)
13
// Tool is an alias for the MCP SDK's Tool type, re-exported so callers of
// this package do not need to import the SDK to name it.
type Tool = mcp.Tool
15
var (
	// allTools is the registry exposed through GetTools. updateTools fills it
	// using the keys of client2Tools (the MCP client names).
	allTools     = csync.NewMap[string, *Tool]()
	// client2Tools holds, per MCP client name, the tool list most recently
	// passed to updateTools. Clients given an empty list have no entry.
	client2Tools = csync.NewMap[string, []*Tool]()
)
20
21// GetTools returns all available MCP tools.
22func GetTools() iter.Seq2[string, *Tool] {
23	return allTools.Seq2()
24}
25
26// RunTool runs an MCP tool with the given input parameters.
27func RunTool(ctx context.Context, name, toolName string, input string) (string, error) {
28	var args map[string]any
29	if err := json.Unmarshal([]byte(input), &args); err != nil {
30		return "", fmt.Errorf("error parsing parameters: %s", err)
31	}
32
33	c, err := getOrRenewClient(ctx, name)
34	if err != nil {
35		return "", err
36	}
37	result, err := c.CallTool(ctx, &mcp.CallToolParams{
38		Name:      toolName,
39		Arguments: args,
40	})
41	if err != nil {
42		return "", err
43	}
44
45	output := make([]string, 0, len(result.Content))
46	for _, v := range result.Content {
47		if vv, ok := v.(*mcp.TextContent); ok {
48			output = append(output, vv.Text)
49		} else {
50			output = append(output, fmt.Sprintf("%v", v))
51		}
52	}
53	return strings.Join(output, "\n"), nil
54}
55
56func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
57	if session.InitializeResult().Capabilities.Tools == nil {
58		return nil, nil
59	}
60	result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
61	if err != nil {
62		return nil, err
63	}
64	return result.Tools, nil
65}
66
67// updateTools updates the global mcpTools and mcpClient2Tools maps
68func updateTools(mcpName string, tools []*Tool) {
69	if len(tools) == 0 {
70		client2Tools.Del(mcpName)
71	} else {
72		client2Tools.Set(mcpName, tools)
73	}
74	for name, tools := range client2Tools.Seq2() {
75		for _, t := range tools {
76			allTools.Set(name, t)
77		}
78	}
79}