tools.go

  1package mcp
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"iter"
  8	"strings"
  9
 10	"charm.land/fantasy"
 11	"github.com/charmbracelet/crush/internal/agent/tools"
 12	"github.com/charmbracelet/crush/internal/csync"
 13	"github.com/charmbracelet/crush/internal/permission"
 14	"github.com/modelcontextprotocol/go-sdk/mcp"
 15)
 16
 17var (
 18	allTools     = csync.NewMap[string, *Tool]()
 19	client2Tools = csync.NewMap[string, []*Tool]()
 20)
 21
 22// GetMCPTools returns all available MCP tools.
 23func GetMCPTools() iter.Seq[*Tool] {
 24	return allTools.Seq()
 25}
 26
// Tool wraps a single tool advertised by an MCP server so that it can be
// described (Info) and executed (Run) through the fantasy tool interface.
type Tool struct {
	// mcpName is the name of the MCP this tool belongs to; it is part of the
	// tool's public name and is used to look up the client when running.
	mcpName string
	// tool is the tool definition as returned by the server's tool listing.
	tool *mcp.Tool
	// permissions is consulted for approval before every Run.
	permissions permission.Service
	// workingDir is reported as the Path of each permission request.
	workingDir string
	// providerOptions is stored and returned verbatim by
	// SetProviderOptions / ProviderOptions.
	providerOptions fantasy.ProviderOptions
}
 34
 35func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) {
 36	m.providerOptions = opts
 37}
 38
 39func (m *Tool) ProviderOptions() fantasy.ProviderOptions {
 40	return m.providerOptions
 41}
 42
 43func (m *Tool) Name() string {
 44	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
 45}
 46
 47func (m *Tool) MCP() string {
 48	return m.mcpName
 49}
 50
 51func (m *Tool) MCPToolName() string {
 52	return m.tool.Name
 53}
 54
 55func (m *Tool) Info() fantasy.ToolInfo {
 56	parameters := make(map[string]any)
 57	required := make([]string, 0)
 58
 59	if input, ok := m.tool.InputSchema.(map[string]any); ok {
 60		if props, ok := input["properties"].(map[string]any); ok {
 61			parameters = props
 62		}
 63		if req, ok := input["required"].([]any); ok {
 64			// Convert []any -> []string when elements are strings
 65			for _, v := range req {
 66				if s, ok := v.(string); ok {
 67					required = append(required, s)
 68				}
 69			}
 70		} else if reqStr, ok := input["required"].([]string); ok {
 71			// Handle case where it's already []string
 72			required = reqStr
 73		}
 74	}
 75
 76	return fantasy.ToolInfo{
 77		Name:        fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name),
 78		Description: m.tool.Description,
 79		Parameters:  parameters,
 80		Required:    required,
 81	}
 82}
 83
 84func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
 85	sessionID := tools.GetSessionFromContext(ctx)
 86	if sessionID == "" {
 87		return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 88	}
 89	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
 90	p := m.permissions.Request(
 91		permission.CreatePermissionRequest{
 92			SessionID:   sessionID,
 93			ToolCallID:  params.ID,
 94			Path:        m.workingDir,
 95			ToolName:    m.Info().Name,
 96			Action:      "execute",
 97			Description: permissionDescription,
 98			Params:      params.Input,
 99		},
100	)
101	if !p {
102		return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
103	}
104
105	return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
106}
107
108func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
109	var args map[string]any
110	if err := json.Unmarshal([]byte(input), &args); err != nil {
111		return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
112	}
113
114	c, err := getOrRenewClient(ctx, name)
115	if err != nil {
116		return fantasy.NewTextErrorResponse(err.Error()), nil
117	}
118	result, err := c.CallTool(ctx, &mcp.CallToolParams{
119		Name:      toolName,
120		Arguments: args,
121	})
122	if err != nil {
123		return fantasy.NewTextErrorResponse(err.Error()), nil
124	}
125
126	output := make([]string, 0, len(result.Content))
127	for _, v := range result.Content {
128		if vv, ok := v.(*mcp.TextContent); ok {
129			output = append(output, vv.Text)
130		} else {
131			output = append(output, fmt.Sprintf("%v", v))
132		}
133	}
134	return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
135}
136
137func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*Tool, error) {
138	if c.InitializeResult().Capabilities.Tools == nil {
139		return nil, nil
140	}
141	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
142	if err != nil {
143		return nil, err
144	}
145	mcpTools := make([]*Tool, 0, len(result.Tools))
146	for _, tool := range result.Tools {
147		mcpTools = append(mcpTools, &Tool{
148			mcpName:     name,
149			tool:        tool,
150			permissions: permissions,
151			workingDir:  workingDir,
152		})
153	}
154	return mcpTools, nil
155}
156
157// updateTools updates the global mcpTools and mcpClient2Tools maps
158func updateTools(mcpName string, tools []*Tool) {
159	if len(tools) == 0 {
160		client2Tools.Del(mcpName)
161	} else {
162		client2Tools.Set(mcpName, tools)
163	}
164	for _, tools := range client2Tools.Seq2() {
165		for _, t := range tools {
166			allTools.Set(t.Name(), t)
167		}
168	}
169}