diff --git a/internal/config/config.go b/internal/config/config.go index c126e7ff11a1f63e9c1ace21984888e76af71479..4e42a56e361c81feca31cd95bd778d14c312cd20 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -143,6 +143,7 @@ type Options struct { DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"` DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"` DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd + DisabledTools []string `json:"disabled_tools" jsonschema:"description=Tools to disable"` } type MCPs map[string]MCPConfig @@ -415,7 +416,51 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return nil } +func allToolNames() []string { + return []string{ + "bash", + "download", + "edit", + "multiedit", + "fetch", + "glob", + "grep", + "ls", + "sourcegraph", + "view", + "write", + } +} + +func resolveAllowedTools(allTools []string, disabledTools []string) []string { + if disabledTools == nil { + return allTools + } + // filter out disabled tools (exclude mode) + return filterSlice(allTools, disabledTools, false) +} + +func resolveReadOnlyTools(tools []string) []string { + readOnlyTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} + // filter to only include tools that are in allowedtools (include mode) + return filterSlice(tools, readOnlyTools, true) +} + +func filterSlice(data []string, mask []string, include bool) []string { + filtered := []string{} + for _, s := range data { + // if include is true, we include items that ARE in the mask + // if include is false, we include items that are NOT in the mask + if include == slices.Contains(mask, s) { + filtered = append(filtered, s) + } + } + return filtered +} + func (c *Config) SetupAgents() { + allowedTools := resolveAllowedTools(allToolNames(), c.Options.DisabledTools) + agents := map[string]Agent{ "coder": { ID: "coder", @@ -423,7 +468,7 @@ func (c *Config) SetupAgents() { Description: "An agent that helps with executing coding tasks.", Model: SelectedModelTypeLarge, ContextPaths: c.Options.ContextPaths, - // All tools allowed + AllowedTools: allowedTools, }, "task": { ID: "task", @@ -431,13 +476,7 @@ func (c *Config) SetupAgents() { Description: "An agent that helps with searching for context and finding implementation details.", Model: SelectedModelTypeLarge, ContextPaths: c.Options.ContextPaths, - AllowedTools: []string{ - "glob", - "grep", - "ls", - "sourcegraph", - "view", - }, + AllowedTools: resolveReadOnlyTools(allowedTools), // NO MCPs or LSPs by default AllowedMCP: map[string][]string{}, AllowedLSP: []string{}, diff --git a/internal/config/load_test.go b/internal/config/load_test.go index a83ab2b94fa29ade149b968c700f22b34b4e86fd..e0ce94f3995fb64cc8f66348723a4e6c62a0ea2b 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -11,6 +11,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -453,6 +454,44 @@ func TestConfig_IsConfigured(t *testing.T) { }) } +func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) { + cfg := &Config{ + Options: &Options{ + DisabledTools: []string{}, + }, + } + + cfg.SetupAgents() + coderAgent, ok := cfg.Agents["coder"] + require.True(t, ok) + assert.Equal(t, allToolNames(), coderAgent.AllowedTools) + + taskAgent, ok := cfg.Agents["task"] + require.True(t, ok) + assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) +} + +func TestConfig_setupAgentsWithDisabledTools(t *testing.T) { + cfg := &Config{ + Options: &Options{ + DisabledTools: []string{ + "edit", + "download", + "grep", + }, + }, + } + + cfg.SetupAgents() + coderAgent, ok := cfg.Agents["coder"] + require.True(t, ok) + assert.Equal(t, []string{"bash", "multiedit", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools) + + taskAgent, ok := cfg.Agents["task"] + require.True(t, ok) + assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) +} + func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { knownProviders := []catwalk.Provider{ { diff --git a/internal/lsp/protocol/tsjson.go b/internal/lsp/protocol/tsjson.go index 24eb515c0482f6259a1ebdfb997e26877f1b9dde..3cf7275245a5dc532c52e03024652fceda6e713a 100644 --- a/internal/lsp/protocol/tsjson.go +++ b/internal/lsp/protocol/tsjson.go @@ -10,10 +10,11 @@ package protocol // https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.9/protocol/metaModel.json // LSP metaData.version = 3.17.0. -import "bytes" -import "encoding/json" - -import "fmt" +import ( + "bytes" + "encoding/json" + "fmt" +) // UnmarshalError indicates that a JSON value did not conform to // one of the expected cases of an LSP union type.