diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 64d2c83a1bb68b32cd695a6ed411d25b746c12e6..c04593badfa002c71923090992ac1ebda992e6d1 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -18,7 +18,6 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent/prompt" "github.com/charmbracelet/crush/internal/agent/tools" - "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/history" @@ -345,7 +344,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan } } - for tool := range mcp.GetMCPTools() { + for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) { if agent.AllowedMCP == nil { // No MCP restrictions filteredTools = append(filteredTools, tool) diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go new file mode 100644 index 0000000000000000000000000000000000000000..22afb417ef5c4fb2b01046ec4bf3fe90826d371e --- /dev/null +++ b/internal/agent/tools/mcp-tools.go @@ -0,0 +1,110 @@ +package tools + +import ( + "context" + "fmt" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" + "github.com/charmbracelet/crush/internal/permission" +) + +// GetMCPTools gets all the currently available MCP tools. +func GetMCPTools(permissions permission.Service, wd string) []*Tool { + var result []*Tool + for name, tool := range mcp.GetMCPTools() { + result = append(result, &Tool{ + mcpName: name, + tool: tool, + permissions: permissions, + workingDir: wd, + }) + } + return result +} + +// Tool is a tool from a MCP. +type Tool struct { + mcpName string + tool *mcp.Tool + permissions permission.Service + workingDir string + providerOptions fantasy.ProviderOptions +} + +func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) { + m.providerOptions = opts +} + +func (m *Tool) ProviderOptions() fantasy.ProviderOptions { + return m.providerOptions +} + +func (m *Tool) Name() string { + return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name) +} + +func (m *Tool) MCP() string { + return m.mcpName +} + +func (m *Tool) MCPToolName() string { + return m.tool.Name +} + +func (m *Tool) Info() fantasy.ToolInfo { + parameters := make(map[string]any) + required := make([]string, 0) + + if input, ok := m.tool.InputSchema.(map[string]any); ok { + if props, ok := input["properties"].(map[string]any); ok { + parameters = props + } + if req, ok := input["required"].([]any); ok { + // Convert []any -> []string when elements are strings + for _, v := range req { + if s, ok := v.(string); ok { + required = append(required, s) + } + } + } else if reqStr, ok := input["required"].([]string); ok { + // Handle case where it's already []string + required = reqStr + } + } + + return fantasy.ToolInfo{ + Name: fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name), + Description: m.tool.Description, + Parameters: parameters, + Required: required, + } +} + +func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) { + sessionID := GetSessionFromContext(ctx) + if sessionID == "" { + return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") + } + permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name) + p := m.permissions.Request( + permission.CreatePermissionRequest{ + SessionID: sessionID, + ToolCallID: params.ID, + Path: m.workingDir, + ToolName: m.Info().Name, + Action: "execute", + Description: permissionDescription, + Params: params.Input, + }, + ) + if !p { + return fantasy.ToolResponse{}, permission.ErrorPermissionDenied + } + + content, err := mcp.RunTool(ctx, m.mcpName, m.tool.Name, params.Input) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + return fantasy.NewTextResponse(content), nil +} diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 600c47517fc66d22b416c5362cb6019a81e59247..10cfdc87d3bfb78b7d723a13b0182540bd8bc50f 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -1,3 +1,5 @@ +// Package mcp provides functionality for managing Model Context Protocol (MCP) +// clients within the Crush application. package mcp import ( @@ -160,7 +162,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config return } - tools, err := getTools(ctx, name, permissions, session, cfg.WorkingDir()) + tools, err := getTools(ctx, session) if err != nil { slog.Error("error listing tools", "error", err) updateState(name, StateError, err, nil, Counts{}) diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index 5b30514dcc31c69bbce4fee41637fa5505c8e684..a30b9de57b0e5ac35c41a666c105522121164f79 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -7,120 +7,39 @@ import ( "iter" "strings" - "charm.land/fantasy" - "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/permission" "github.com/modelcontextprotocol/go-sdk/mcp" ) +type Tool = mcp.Tool + var ( allTools = csync.NewMap[string, *Tool]() client2Tools = csync.NewMap[string, []*Tool]() ) -// GetMCPTools returns all available MCP tools. -func GetMCPTools() iter.Seq[*Tool] { - return allTools.Seq() -} - -type Tool struct { - mcpName string - tool *mcp.Tool - permissions permission.Service - workingDir string - providerOptions fantasy.ProviderOptions -} - -func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) { - m.providerOptions = opts -} - -func (m *Tool) ProviderOptions() fantasy.ProviderOptions { - return m.providerOptions +// GetTools returns all available MCP tools. +func GetTools() iter.Seq2[string, *Tool] { + return allTools.Seq2() } -func (m *Tool) Name() string { - return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name) -} - -func (m *Tool) MCP() string { - return m.mcpName -} - -func (m *Tool) MCPToolName() string { - return m.tool.Name -} - -func (m *Tool) Info() fantasy.ToolInfo { - parameters := make(map[string]any) - required := make([]string, 0) - - if input, ok := m.tool.InputSchema.(map[string]any); ok { - if props, ok := input["properties"].(map[string]any); ok { - parameters = props - } - if req, ok := input["required"].([]any); ok { - // Convert []any -> []string when elements are strings - for _, v := range req { - if s, ok := v.(string); ok { - required = append(required, s) - } - } - } else if reqStr, ok := input["required"].([]string); ok { - // Handle case where it's already []string - required = reqStr - } - } - - return fantasy.ToolInfo{ - Name: fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name), - Description: m.tool.Description, - Parameters: parameters, - Required: required, - } -} - -func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) { - sessionID := tools.GetSessionFromContext(ctx) - if sessionID == "" { - return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") - } - permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name) - p := m.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - ToolCallID: params.ID, - Path: m.workingDir, - ToolName: m.Info().Name, - Action: "execute", - Description: permissionDescription, - Params: params.Input, - }, - ) - if !p { - return fantasy.ToolResponse{}, permission.ErrorPermissionDenied - } - - return runTool(ctx, m.mcpName, m.tool.Name, params.Input) -} - -func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) { +// RunTool runs an MCP tool with the given input parameters. +func RunTool(ctx context.Context, name, toolName string, input string) (string, error) { var args map[string]any if err := json.Unmarshal([]byte(input), &args); err != nil { - return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil + return "", fmt.Errorf("error parsing parameters: %s", err) } c, err := getOrRenewClient(ctx, name) if err != nil { - return fantasy.NewTextErrorResponse(err.Error()), nil + return "", err } result, err := c.CallTool(ctx, &mcp.CallToolParams{ Name: toolName, Arguments: args, }) if err != nil { - return fantasy.NewTextErrorResponse(err.Error()), nil + return "", err } output := make([]string, 0, len(result.Content)) @@ -131,27 +50,18 @@ func runTool(ctx context.Context, name, toolName string, input string) (fantasy. output = append(output, fmt.Sprintf("%v", v)) } } - return fantasy.NewTextResponse(strings.Join(output, "\n")), nil + return strings.Join(output, "\n"), nil } -func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*Tool, error) { - if c.InitializeResult().Capabilities.Tools == nil { +func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) { + if session.InitializeResult().Capabilities.Tools == nil { return nil, nil } - result, err := c.ListTools(ctx, &mcp.ListToolsParams{}) + result, err := session.ListTools(ctx, &mcp.ListToolsParams{}) if err != nil { return nil, err } - mcpTools := make([]*Tool, 0, len(result.Tools)) - for _, tool := range result.Tools { - mcpTools = append(mcpTools, &Tool{ - mcpName: name, - tool: tool, - permissions: permissions, - workingDir: workingDir, - }) - } - return mcpTools, nil + return result.Tools, nil } // updateTools updates the global mcpTools and mcpClient2Tools maps @@ -161,9 +71,9 @@ func updateTools(mcpName string, tools []*Tool) { } else { client2Tools.Set(mcpName, tools) } - for _, tools := range client2Tools.Seq2() { + for name, tools := range client2Tools.Seq2() { for _, t := range tools { - allTools.Set(t.Name(), t) + allTools.Set(name, t) } } }