diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index 4da41779abf53cf825c701abab03bd4c84aa8298..b110ab56f9db5ad28ddd53072d7e5bf4fcfda5bb 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -142,7 +142,10 @@ func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) { } func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int { - tools = filterDisabledTools(cfg, name, tools) + mcpCfg, ok := cfg.Config().MCP[name] + if ok { + tools = filterTools(mcpCfg, tools) + } if len(tools) == 0 { allTools.Del(name) return 0 @@ -151,20 +154,30 @@ func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int { return len(tools) } -// filterDisabledTools removes tools that are disabled via config. -func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool { - mcpCfg, ok := cfg.Config().MCP[mcpName] - if !ok || len(mcpCfg.DisabledTools) == 0 { - return tools +// filterTools filters tools based on enabled_tools (allow list) and +// disabled_tools (deny list) from the MCP config. +func filterTools(mcpCfg config.MCPConfig, tools []*Tool) []*Tool { + if len(mcpCfg.EnabledTools) > 0 { + filtered := make([]*Tool, 0, len(mcpCfg.EnabledTools)) + for _, tool := range tools { + if slices.Contains(mcpCfg.EnabledTools, tool.Name) { + filtered = append(filtered, tool) + } + } + tools = filtered } - filtered := make([]*Tool, 0, len(tools)) - for _, tool := range tools { - if !slices.Contains(mcpCfg.DisabledTools, tool.Name) { - filtered = append(filtered, tool) + if len(mcpCfg.DisabledTools) > 0 { + filtered := make([]*Tool, 0, len(tools)) + for _, tool := range tools { + if !slices.Contains(mcpCfg.DisabledTools, tool.Name) { + filtered = append(filtered, tool) + } } + tools = filtered } - return filtered + + return tools } // ensureRawBytes normalizes MCP media data into raw binary bytes. diff --git a/internal/agent/tools/mcp/tools_test.go b/internal/agent/tools/mcp/tools_test.go index 935e17be42be5e45a592a5aed909aa4a2bfb3d48..77892325e94eecf3aacd72332027e3457696a7bd 100644 --- a/internal/agent/tools/mcp/tools_test.go +++ b/internal/agent/tools/mcp/tools_test.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "testing" + "github.com/charmbracelet/crush/internal/config" "github.com/stretchr/testify/require" ) @@ -67,3 +68,50 @@ func TestEnsureRawBytes(t *testing.T) { }) } } + +func TestFilterTools(t *testing.T) { + t.Parallel() + + tools := []*Tool{ + {Name: "tool_a"}, + {Name: "tool_b"}, + {Name: "tool_c"}, + } + + t.Run("no filters returns all tools", func(t *testing.T) { + t.Parallel() + result := filterTools(config.MCPConfig{}, tools) + require.Len(t, result, 3) + }) + + t.Run("disabled tools filters deny list", func(t *testing.T) { + t.Parallel() + result := filterTools(config.MCPConfig{DisabledTools: []string{"tool_a"}}, tools) + require.Len(t, result, 2) + require.Equal(t, "tool_b", result[0].Name) + require.Equal(t, "tool_c", result[1].Name) + }) + + t.Run("enabled tools acts as allow list", func(t *testing.T) { + t.Parallel() + result := filterTools(config.MCPConfig{EnabledTools: []string{"tool_b"}}, tools) + require.Len(t, result, 1) + require.Equal(t, "tool_b", result[0].Name) + }) + + t.Run("enabled and disabled both apply", func(t *testing.T) { + t.Parallel() + result := filterTools(config.MCPConfig{ + EnabledTools: []string{"tool_a", "tool_b"}, + DisabledTools: []string{"tool_b"}, + }, tools) + require.Len(t, result, 1) + require.Equal(t, "tool_a", result[0].Name) + }) + + t.Run("enabled with non-existent tool returns empty", func(t *testing.T) { + t.Parallel() + result := filterTools(config.MCPConfig{EnabledTools: []string{"non_existent"}}, tools) + require.Len(t, result, 0) + }) +} diff --git a/internal/config/config.go b/internal/config/config.go index db6125017282f6e401c76002894c191e23e14d7c..6620120fa07404a3858bf1502e2c8109bf5510a3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -188,6 +188,7 @@ type MCPConfig struct { URL string `json:"url,omitempty" jsonschema:"description=URL for HTTP or SSE MCP servers,format=uri,example=http://localhost:3000/mcp"` Disabled bool `json:"disabled,omitempty" jsonschema:"description=Whether this MCP server is disabled,default=false"` DisabledTools []string `json:"disabled_tools,omitempty" jsonschema:"description=List of tools from this MCP server to disable,example=get-library-doc"` + EnabledTools []string `json:"enabled_tools,omitempty" jsonschema:"description=Allow list of tools from this MCP server,example=get-library-doc"` Timeout int `json:"timeout,omitempty" jsonschema:"description=Timeout in seconds for MCP server connections,default=15,example=30,example=60,example=120"` // Headers are HTTP headers for HTTP/SSE MCP servers. Values run diff --git a/schema.json b/schema.json index 751a5f529f8cb2773286d1dbd98b99b548c6503b..6a0576faf8260081f376a885e771c49d7119caac 100644 --- a/schema.json +++ b/schema.json @@ -264,6 +264,16 @@ "type": "array", "description": "List of tools from this MCP server to disable" }, + "enabled_tools": { + "items": { + "type": "string", + "examples": [ + "get-library-doc" + ] + }, + "type": "array", + "description": "Allow list of tools from this MCP server" + }, "timeout": { "type": "integer", "description": "Timeout in seconds for MCP server connections",