diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index b1be1be93b4b428011bfc360e548da560e087f69..a6ec70fbbd6aec71087c250ac8635fd4ffcc7159 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -407,12 +407,6 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan } for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) { - // Check MCP-specific disabled tools. - if mcpCfg, ok := c.cfg.MCP[tool.MCP()]; ok { - if slices.Contains(mcpCfg.DisabledTools, tool.MCPToolName()) { - continue - } - } if agent.AllowedMCP == nil { // No MCP restrictions filteredTools = append(filteredTools, tool) diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 6ad77bcedbf528e9c355bb0533093455ed12bcee..be27ce3f8ae5b9b7f425e496a1726bc23eaf3aae 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -188,12 +188,12 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config return } - updateTools(name, tools) + toolCount := updateTools(name, tools) updatePrompts(name, prompts) sessions.Set(name, session) updateState(name, StateConnected, nil, session, Counts{ - Tools: len(tools), + Tools: toolCount, Prompts: len(prompts), }) }(name, m) diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index 3a874aa8e1e6d790f8a9af2c9df83dfbbf49e942..779baa55d93bc54523bac81c5094bacee7fc68fb 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -6,8 +6,10 @@ import ( "fmt" "iter" "log/slog" + "slices" "strings" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -119,10 +121,10 @@ func RefreshTools(ctx context.Context, name string) { return } - updateTools(name, tools) + toolCount := updateTools(name, tools) prev, _ := states.Get(name) - prev.Counts.Tools = len(tools) + prev.Counts.Tools = toolCount updateState(name, StateConnected, nil, session, prev.Counts) } @@ -137,10 +139,29 @@ func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) return result.Tools, nil } -func updateTools(name string, tools []*Tool) { +func updateTools(name string, tools []*Tool) int { + tools = filterDisabledTools(name, tools) if len(tools) == 0 { allTools.Del(name) - return + return 0 } allTools.Set(name, tools) + return len(tools) +} + +// filterDisabledTools removes tools that are disabled via config. +func filterDisabledTools(mcpName string, tools []*Tool) []*Tool { + cfg := config.Get() + mcpCfg, ok := cfg.MCP[mcpName] + if !ok || len(mcpCfg.DisabledTools) == 0 { + return tools + } + + filtered := make([]*Tool, 0, len(tools)) + for _, tool := range tools { + if !slices.Contains(mcpCfg.DisabledTools, tool.Name) { + filtered = append(filtered, tool) + } + } + return filtered }