fix(mcp): centrally filter disabled tools (#1622)

Amolith created

Change summary

internal/agent/coordinator.go     |  6 ------
internal/agent/tools/mcp/init.go  |  4 ++--
internal/agent/tools/mcp/tools.go | 29 +++++++++++++++++++++++++----
3 files changed, 27 insertions(+), 12 deletions(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -407,12 +407,6 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
 	}
 
 	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
-		// Check MCP-specific disabled tools.
-		if mcpCfg, ok := c.cfg.MCP[tool.MCP()]; ok {
-			if slices.Contains(mcpCfg.DisabledTools, tool.MCPToolName()) {
-				continue
-			}
-		}
 		if agent.AllowedMCP == nil {
 			// No MCP restrictions
 			filteredTools = append(filteredTools, tool)

internal/agent/tools/mcp/init.go 🔗

@@ -188,12 +188,12 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
 				return
 			}
 
-			updateTools(name, tools)
+			toolCount := updateTools(name, tools)
 			updatePrompts(name, prompts)
 			sessions.Set(name, session)
 
 			updateState(name, StateConnected, nil, session, Counts{
-				Tools:   len(tools),
+				Tools:   toolCount,
 				Prompts: len(prompts),
 			})
 		}(name, m)

internal/agent/tools/mcp/tools.go 🔗

@@ -6,8 +6,10 @@ import (
 	"fmt"
 	"iter"
 	"log/slog"
+	"slices"
 	"strings"
 
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/modelcontextprotocol/go-sdk/mcp"
 )
@@ -119,10 +121,10 @@ func RefreshTools(ctx context.Context, name string) {
 		return
 	}
 
-	updateTools(name, tools)
+	toolCount := updateTools(name, tools)
 
 	prev, _ := states.Get(name)
-	prev.Counts.Tools = len(tools)
+	prev.Counts.Tools = toolCount
 	updateState(name, StateConnected, nil, session, prev.Counts)
 }
 
@@ -137,10 +139,29 @@ func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error)
 	return result.Tools, nil
 }
 
-func updateTools(name string, tools []*Tool) {
+func updateTools(name string, tools []*Tool) int {
+	tools = filterDisabledTools(name, tools)
 	if len(tools) == 0 {
 		allTools.Del(name)
-		return
+		return 0
 	}
 	allTools.Set(name, tools)
+	return len(tools)
+}
+
+// filterDisabledTools removes tools that are disabled via config.
+func filterDisabledTools(mcpName string, tools []*Tool) []*Tool {
+	cfg := config.Get()
+	mcpCfg, ok := cfg.MCP[mcpName]
+	if !ok || len(mcpCfg.DisabledTools) == 0 {
+		return tools
+	}
+
+	filtered := make([]*Tool, 0, len(tools))
+	for _, tool := range tools {
+		if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
+			filtered = append(filtered, tool)
+		}
+	}
+	return filtered
 }