tools.go

  1package mcp
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"iter"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/csync"
 12	"github.com/modelcontextprotocol/go-sdk/mcp"
 13)
 14
 15type Tool = mcp.Tool
 16
 17var (
 18	allTools    = csync.NewMap[string, *Tool]()
 19	clientTools = csync.NewMap[string, []*Tool]()
 20)
 21
 22// Tools returns all available MCP tools.
 23func Tools() iter.Seq2[string, *Tool] {
 24	return allTools.Seq2()
 25}
 26
 27// RunTool runs an MCP tool with the given input parameters.
 28func RunTool(ctx context.Context, name, toolName string, input string) (string, error) {
 29	var args map[string]any
 30	if err := json.Unmarshal([]byte(input), &args); err != nil {
 31		return "", fmt.Errorf("error parsing parameters: %s", err)
 32	}
 33
 34	c, err := getOrRenewClient(ctx, name)
 35	if err != nil {
 36		return "", err
 37	}
 38	result, err := c.CallTool(ctx, &mcp.CallToolParams{
 39		Name:      toolName,
 40		Arguments: args,
 41	})
 42	if err != nil {
 43		return "", err
 44	}
 45
 46	output := make([]string, 0, len(result.Content))
 47	for _, v := range result.Content {
 48		if vv, ok := v.(*mcp.TextContent); ok {
 49			output = append(output, vv.Text)
 50		} else {
 51			output = append(output, fmt.Sprintf("%v", v))
 52		}
 53	}
 54	return strings.Join(output, "\n"), nil
 55}
 56
 57// RefreshTools gets the updated list of tools from the MCP and updates the
 58// global state.
 59func RefreshTools(ctx context.Context, name string) {
 60	session, ok := sessions.Get(name)
 61	if !ok {
 62		slog.Warn("refresh tools: no session", "name", name)
 63		return
 64	}
 65
 66	tools, err := getTools(ctx, session)
 67	if err != nil {
 68		updateState(name, StateError, err, nil, Counts{})
 69		return
 70	}
 71
 72	updateTools(name, tools)
 73
 74	prev, _ := states.Get(name)
 75	updateState(name, StateConnected, nil, session, Counts{
 76		Tools:   len(tools),
 77		Prompts: prev.Counts.Prompts,
 78	})
 79}
 80
 81func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
 82	if session.InitializeResult().Capabilities.Tools == nil {
 83		return nil, nil
 84	}
 85	result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
 86	if err != nil {
 87		return nil, err
 88	}
 89	return result.Tools, nil
 90}
 91
 92// updateTools updates the global mcpTools and mcpClient2Tools maps
 93func updateTools(name string, tools []*Tool) {
 94	if len(tools) == 0 {
 95		clientTools.Del(name)
 96	} else {
 97		clientTools.Set(name, tools)
 98	}
 99	for name, tools := range clientTools.Seq2() {
100		for _, t := range tools {
101			allTools.Set(name, t)
102		}
103	}
104}