feat(tools): create an allow list for MCP tools (#2800)

Bruno Krugel created

Change summary

internal/agent/tools/mcp/tools.go      | 35 ++++++++++++++------
internal/agent/tools/mcp/tools_test.go | 48 ++++++++++++++++++++++++++++
internal/config/config.go              |  1 
schema.json                            | 10 +++++
4 files changed, 83 insertions(+), 11 deletions(-)

Detailed changes

internal/agent/tools/mcp/tools.go 🔗

@@ -142,7 +142,10 @@ func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
 }
 
 func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
-	tools = filterDisabledTools(cfg, name, tools)
+	mcpCfg, ok := cfg.Config().MCP[name]
+	if ok {
+		tools = filterTools(mcpCfg, tools)
+	}
 	if len(tools) == 0 {
 		allTools.Del(name)
 		return 0
@@ -151,20 +154,30 @@ func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
 	return len(tools)
 }
 
-// filterDisabledTools removes tools that are disabled via config.
-func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool {
-	mcpCfg, ok := cfg.Config().MCP[mcpName]
-	if !ok || len(mcpCfg.DisabledTools) == 0 {
-		return tools
+// filterTools filters tools based on enabled_tools (allow list) and
+// disabled_tools (deny list) from the MCP config.
+func filterTools(mcpCfg config.MCPConfig, tools []*Tool) []*Tool {
+	if len(mcpCfg.EnabledTools) > 0 {
+		filtered := make([]*Tool, 0, len(mcpCfg.EnabledTools))
+		for _, tool := range tools {
+			if slices.Contains(mcpCfg.EnabledTools, tool.Name) {
+				filtered = append(filtered, tool)
+			}
+		}
+		tools = filtered
 	}
 
-	filtered := make([]*Tool, 0, len(tools))
-	for _, tool := range tools {
-		if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
-			filtered = append(filtered, tool)
+	if len(mcpCfg.DisabledTools) > 0 {
+		filtered := make([]*Tool, 0, len(tools))
+		for _, tool := range tools {
+			if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
+				filtered = append(filtered, tool)
+			}
 		}
+		tools = filtered
 	}
-	return filtered
+
+	return tools
 }
 
 // ensureRawBytes normalizes MCP media data into raw binary bytes.

internal/agent/tools/mcp/tools_test.go 🔗

@@ -5,6 +5,7 @@ import (
 	"encoding/base64"
 	"testing"
 
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/stretchr/testify/require"
 )
 
@@ -67,3 +68,50 @@ func TestEnsureRawBytes(t *testing.T) {
 		})
 	}
 }
+
+func TestFilterTools(t *testing.T) {
+	t.Parallel()
+
+	tools := []*Tool{
+		{Name: "tool_a"},
+		{Name: "tool_b"},
+		{Name: "tool_c"},
+	}
+
+	t.Run("no filters returns all tools", func(t *testing.T) {
+		t.Parallel()
+		result := filterTools(config.MCPConfig{}, tools)
+		require.Len(t, result, 3)
+	})
+
+	t.Run("disabled tools filters deny list", func(t *testing.T) {
+		t.Parallel()
+		result := filterTools(config.MCPConfig{DisabledTools: []string{"tool_a"}}, tools)
+		require.Len(t, result, 2)
+		require.Equal(t, "tool_b", result[0].Name)
+		require.Equal(t, "tool_c", result[1].Name)
+	})
+
+	t.Run("enabled tools acts as allow list", func(t *testing.T) {
+		t.Parallel()
+		result := filterTools(config.MCPConfig{EnabledTools: []string{"tool_b"}}, tools)
+		require.Len(t, result, 1)
+		require.Equal(t, "tool_b", result[0].Name)
+	})
+
+	t.Run("enabled and disabled both apply", func(t *testing.T) {
+		t.Parallel()
+		result := filterTools(config.MCPConfig{
+			EnabledTools:  []string{"tool_a", "tool_b"},
+			DisabledTools: []string{"tool_b"},
+		}, tools)
+		require.Len(t, result, 1)
+		require.Equal(t, "tool_a", result[0].Name)
+	})
+
+	t.Run("enabled with non-existent tool returns empty", func(t *testing.T) {
+		t.Parallel()
+		result := filterTools(config.MCPConfig{EnabledTools: []string{"non_existent"}}, tools)
+		require.Len(t, result, 0)
+	})
+}

internal/config/config.go 🔗

@@ -188,6 +188,7 @@ type MCPConfig struct {
 	URL           string            `json:"url,omitempty" jsonschema:"description=URL for HTTP or SSE MCP servers,format=uri,example=http://localhost:3000/mcp"`
 	Disabled      bool              `json:"disabled,omitempty" jsonschema:"description=Whether this MCP server is disabled,default=false"`
 	DisabledTools []string          `json:"disabled_tools,omitempty" jsonschema:"description=List of tools from this MCP server to disable,example=get-library-doc"`
+	EnabledTools  []string          `json:"enabled_tools,omitempty" jsonschema:"description=Allow list of tools from this MCP server,example=get-library-doc"`
 	Timeout       int               `json:"timeout,omitempty" jsonschema:"description=Timeout in seconds for MCP server connections,default=15,example=30,example=60,example=120"`
 
 	// Headers are HTTP headers for HTTP/SSE MCP servers. Values run

schema.json 🔗

@@ -264,6 +264,16 @@
           "type": "array",
           "description": "List of tools from this MCP server to disable"
         },
+        "enabled_tools": {
+          "items": {
+            "type": "string",
+            "examples": [
+              "get-library-doc"
+            ]
+          },
+          "type": "array",
+          "description": "Allow list of tools from this MCP server"
+        },
         "timeout": {
           "type": "integer",
           "description": "Timeout in seconds for MCP server connections",