refactor(mcp): some more decoupling

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/agent/coordinator.go     |   3 
internal/agent/tools/mcp-tools.go | 110 +++++++++++++++++++++++++++++
internal/agent/tools/mcp/init.go  |   4 
internal/agent/tools/mcp/tools.go | 124 ++++----------------------------
4 files changed, 131 insertions(+), 110 deletions(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -18,7 +18,6 @@ import (
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/agent/prompt"
 	"github.com/charmbracelet/crush/internal/agent/tools"
-	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/history"
@@ -345,7 +344,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
 		}
 	}
 
-	for tool := range mcp.GetMCPTools() {
+	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
 		if agent.AllowedMCP == nil {
 			// No MCP restrictions
 			filteredTools = append(filteredTools, tool)

internal/agent/tools/mcp-tools.go 🔗

@@ -0,0 +1,110 @@
+package tools
+
+import (
+	"context"
+	"fmt"
+
+	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
+	"github.com/charmbracelet/crush/internal/permission"
+)
+
+// GetMCPTools gets all the currently available MCP tools.
+func GetMCPTools(permissions permission.Service, wd string) []*Tool {
+	var result []*Tool
+	for name, tool := range mcp.GetMCPTools() {
+		result = append(result, &Tool{
+			mcpName:     name,
+			tool:        tool,
+			permissions: permissions,
+			workingDir:  wd,
+		})
+	}
+	return result
+}
+
+// Tool is a tool from a MCP.
+type Tool struct {
+	mcpName         string
+	tool            *mcp.Tool
+	permissions     permission.Service
+	workingDir      string
+	providerOptions fantasy.ProviderOptions
+}
+
+func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) {
+	m.providerOptions = opts
+}
+
+func (m *Tool) ProviderOptions() fantasy.ProviderOptions {
+	return m.providerOptions
+}
+
+func (m *Tool) Name() string {
+	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
+}
+
+func (m *Tool) MCP() string {
+	return m.mcpName
+}
+
+func (m *Tool) MCPToolName() string {
+	return m.tool.Name
+}
+
+func (m *Tool) Info() fantasy.ToolInfo {
+	parameters := make(map[string]any)
+	required := make([]string, 0)
+
+	if input, ok := m.tool.InputSchema.(map[string]any); ok {
+		if props, ok := input["properties"].(map[string]any); ok {
+			parameters = props
+		}
+		if req, ok := input["required"].([]any); ok {
+			// Convert []any -> []string when elements are strings
+			for _, v := range req {
+				if s, ok := v.(string); ok {
+					required = append(required, s)
+				}
+			}
+		} else if reqStr, ok := input["required"].([]string); ok {
+			// Handle case where it's already []string
+			required = reqStr
+		}
+	}
+
+	return fantasy.ToolInfo{
+		Name:        fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name),
+		Description: m.tool.Description,
+		Parameters:  parameters,
+		Required:    required,
+	}
+}
+
+func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
+	sessionID := GetSessionFromContext(ctx)
+	if sessionID == "" {
+		return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
+	}
+	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
+	p := m.permissions.Request(
+		permission.CreatePermissionRequest{
+			SessionID:   sessionID,
+			ToolCallID:  params.ID,
+			Path:        m.workingDir,
+			ToolName:    m.Info().Name,
+			Action:      "execute",
+			Description: permissionDescription,
+			Params:      params.Input,
+		},
+	)
+	if !p {
+		return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
+	}
+
+	content, err := mcp.RunTool(ctx, m.mcpName, m.tool.Name, params.Input)
+	if err != nil {
+		return fantasy.NewTextErrorResponse(err.Error()), nil
+	}
+	return fantasy.NewTextResponse(content), nil
+}

internal/agent/tools/mcp/init.go 🔗

@@ -1,3 +1,5 @@
+// Package mcp provides functionality for managing Model Context Protocol (MCP)
+// clients within the Crush application.
 package mcp
 
 import (
@@ -160,7 +162,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
 				return
 			}
 
-			tools, err := getTools(ctx, name, permissions, session, cfg.WorkingDir())
+			tools, err := getTools(ctx, session)
 			if err != nil {
 				slog.Error("error listing tools", "error", err)
 				updateState(name, StateError, err, nil, Counts{})

internal/agent/tools/mcp/tools.go 🔗

@@ -7,120 +7,39 @@ import (
 	"iter"
 	"strings"
 
-	"charm.land/fantasy"
-	"github.com/charmbracelet/crush/internal/agent/tools"
 	"github.com/charmbracelet/crush/internal/csync"
-	"github.com/charmbracelet/crush/internal/permission"
 	"github.com/modelcontextprotocol/go-sdk/mcp"
 )
 
+type Tool = mcp.Tool
+
 var (
 	allTools     = csync.NewMap[string, *Tool]()
 	client2Tools = csync.NewMap[string, []*Tool]()
 )
 
-// GetMCPTools returns all available MCP tools.
-func GetMCPTools() iter.Seq[*Tool] {
-	return allTools.Seq()
-}
-
-type Tool struct {
-	mcpName         string
-	tool            *mcp.Tool
-	permissions     permission.Service
-	workingDir      string
-	providerOptions fantasy.ProviderOptions
-}
-
-func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) {
-	m.providerOptions = opts
-}
-
-func (m *Tool) ProviderOptions() fantasy.ProviderOptions {
-	return m.providerOptions
+// GetTools returns all available MCP tools.
+func GetTools() iter.Seq2[string, *Tool] {
+	return allTools.Seq2()
 }
 
-func (m *Tool) Name() string {
-	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
-}
-
-func (m *Tool) MCP() string {
-	return m.mcpName
-}
-
-func (m *Tool) MCPToolName() string {
-	return m.tool.Name
-}
-
-func (m *Tool) Info() fantasy.ToolInfo {
-	parameters := make(map[string]any)
-	required := make([]string, 0)
-
-	if input, ok := m.tool.InputSchema.(map[string]any); ok {
-		if props, ok := input["properties"].(map[string]any); ok {
-			parameters = props
-		}
-		if req, ok := input["required"].([]any); ok {
-			// Convert []any -> []string when elements are strings
-			for _, v := range req {
-				if s, ok := v.(string); ok {
-					required = append(required, s)
-				}
-			}
-		} else if reqStr, ok := input["required"].([]string); ok {
-			// Handle case where it's already []string
-			required = reqStr
-		}
-	}
-
-	return fantasy.ToolInfo{
-		Name:        fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name),
-		Description: m.tool.Description,
-		Parameters:  parameters,
-		Required:    required,
-	}
-}
-
-func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
-	sessionID := tools.GetSessionFromContext(ctx)
-	if sessionID == "" {
-		return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
-	}
-	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
-	p := m.permissions.Request(
-		permission.CreatePermissionRequest{
-			SessionID:   sessionID,
-			ToolCallID:  params.ID,
-			Path:        m.workingDir,
-			ToolName:    m.Info().Name,
-			Action:      "execute",
-			Description: permissionDescription,
-			Params:      params.Input,
-		},
-	)
-	if !p {
-		return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
-	}
-
-	return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
-}
-
-func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
+// RunTool runs an MCP tool with the given input parameters.
+func RunTool(ctx context.Context, name, toolName string, input string) (string, error) {
 	var args map[string]any
 	if err := json.Unmarshal([]byte(input), &args); err != nil {
-		return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
+		return "", fmt.Errorf("error parsing parameters: %s", err)
 	}
 
 	c, err := getOrRenewClient(ctx, name)
 	if err != nil {
-		return fantasy.NewTextErrorResponse(err.Error()), nil
+		return "", err
 	}
 	result, err := c.CallTool(ctx, &mcp.CallToolParams{
 		Name:      toolName,
 		Arguments: args,
 	})
 	if err != nil {
-		return fantasy.NewTextErrorResponse(err.Error()), nil
+		return "", err
 	}
 
 	output := make([]string, 0, len(result.Content))
@@ -131,27 +50,18 @@ func runTool(ctx context.Context, name, toolName string, input string) (fantasy.
 			output = append(output, fmt.Sprintf("%v", v))
 		}
 	}
-	return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
+	return strings.Join(output, "\n"), nil
 }
 
-func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*Tool, error) {
-	if c.InitializeResult().Capabilities.Tools == nil {
+func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
+	if session.InitializeResult().Capabilities.Tools == nil {
 		return nil, nil
 	}
-	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
+	result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
 	if err != nil {
 		return nil, err
 	}
-	mcpTools := make([]*Tool, 0, len(result.Tools))
-	for _, tool := range result.Tools {
-		mcpTools = append(mcpTools, &Tool{
-			mcpName:     name,
-			tool:        tool,
-			permissions: permissions,
-			workingDir:  workingDir,
-		})
-	}
-	return mcpTools, nil
+	return result.Tools, nil
 }
 
 // updateTools updates the global mcpTools and mcpClient2Tools maps
@@ -161,9 +71,9 @@ func updateTools(mcpName string, tools []*Tool) {
 	} else {
 		client2Tools.Set(mcpName, tools)
 	}
-	for _, tools := range client2Tools.Seq2() {
+	for name, tools := range client2Tools.Seq2() {
 		for _, t := range tools {
-			allTools.Set(t.Name(), t)
+			allTools.Set(name, t)
 		}
 	}
 }