Detailed changes
@@ -142,7 +142,10 @@ func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
}
func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
- tools = filterDisabledTools(cfg, name, tools)
+ mcpCfg, ok := cfg.Config().MCP[name]
+ if ok {
+ tools = filterTools(mcpCfg, tools)
+ }
if len(tools) == 0 {
allTools.Del(name)
return 0
@@ -151,20 +154,30 @@ func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
return len(tools)
}
-// filterDisabledTools removes tools that are disabled via config.
-func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool {
- mcpCfg, ok := cfg.Config().MCP[mcpName]
- if !ok || len(mcpCfg.DisabledTools) == 0 {
- return tools
+// filterTools filters tools based on enabled_tools (allow list) and
+// disabled_tools (deny list) from the MCP config.
+func filterTools(mcpCfg config.MCPConfig, tools []*Tool) []*Tool {
+ if len(mcpCfg.EnabledTools) > 0 {
+ filtered := make([]*Tool, 0, len(mcpCfg.EnabledTools))
+ for _, tool := range tools {
+ if slices.Contains(mcpCfg.EnabledTools, tool.Name) {
+ filtered = append(filtered, tool)
+ }
+ }
+ tools = filtered
}
- filtered := make([]*Tool, 0, len(tools))
- for _, tool := range tools {
- if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
- filtered = append(filtered, tool)
+ if len(mcpCfg.DisabledTools) > 0 {
+ filtered := make([]*Tool, 0, len(tools))
+ for _, tool := range tools {
+ if !slices.Contains(mcpCfg.DisabledTools, tool.Name) {
+ filtered = append(filtered, tool)
+ }
}
+ tools = filtered
}
- return filtered
+
+ return tools
}
// ensureRawBytes normalizes MCP media data into raw binary bytes.
@@ -5,6 +5,7 @@ import (
"encoding/base64"
"testing"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/stretchr/testify/require"
)
@@ -67,3 +68,50 @@ func TestEnsureRawBytes(t *testing.T) {
})
}
}
+
+func TestFilterTools(t *testing.T) {
+ t.Parallel()
+
+ tools := []*Tool{
+ {Name: "tool_a"},
+ {Name: "tool_b"},
+ {Name: "tool_c"},
+ }
+
+ t.Run("no filters returns all tools", func(t *testing.T) {
+ t.Parallel()
+ result := filterTools(config.MCPConfig{}, tools)
+ require.Len(t, result, 3)
+ })
+
+ t.Run("disabled tools filters deny list", func(t *testing.T) {
+ t.Parallel()
+ result := filterTools(config.MCPConfig{DisabledTools: []string{"tool_a"}}, tools)
+ require.Len(t, result, 2)
+ require.Equal(t, "tool_b", result[0].Name)
+ require.Equal(t, "tool_c", result[1].Name)
+ })
+
+ t.Run("enabled tools acts as allow list", func(t *testing.T) {
+ t.Parallel()
+ result := filterTools(config.MCPConfig{EnabledTools: []string{"tool_b"}}, tools)
+ require.Len(t, result, 1)
+ require.Equal(t, "tool_b", result[0].Name)
+ })
+
+ t.Run("enabled and disabled both apply", func(t *testing.T) {
+ t.Parallel()
+ result := filterTools(config.MCPConfig{
+ EnabledTools: []string{"tool_a", "tool_b"},
+ DisabledTools: []string{"tool_b"},
+ }, tools)
+ require.Len(t, result, 1)
+ require.Equal(t, "tool_a", result[0].Name)
+ })
+
+ t.Run("enabled with non-existent tool returns empty", func(t *testing.T) {
+ t.Parallel()
+ result := filterTools(config.MCPConfig{EnabledTools: []string{"non_existent"}}, tools)
+ require.Len(t, result, 0)
+ })
+}
@@ -188,6 +188,7 @@ type MCPConfig struct {
URL string `json:"url,omitempty" jsonschema:"description=URL for HTTP or SSE MCP servers,format=uri,example=http://localhost:3000/mcp"`
Disabled bool `json:"disabled,omitempty" jsonschema:"description=Whether this MCP server is disabled,default=false"`
DisabledTools []string `json:"disabled_tools,omitempty" jsonschema:"description=List of tools from this MCP server to disable,example=get-library-doc"`
+ EnabledTools []string `json:"enabled_tools,omitempty" jsonschema:"description=Allow list of tools from this MCP server,example=get-library-doc"`
Timeout int `json:"timeout,omitempty" jsonschema:"description=Timeout in seconds for MCP server connections,default=15,example=30,example=60,example=120"`
// Headers are HTTP headers for HTTP/SSE MCP servers. Values run
@@ -264,6 +264,16 @@
"type": "array",
"description": "List of tools from this MCP server to disable"
},
+ "enabled_tools": {
+ "items": {
+ "type": "string",
+ "examples": [
+ "get-library-doc"
+ ]
+ },
+ "type": "array",
+ "description": "Allow list of tools from this MCP server"
+ },
"timeout": {
"type": "integer",
"description": "Timeout in seconds for MCP server connections",