feat: include default sub commands along with user defined w block funcs

tauraamui created

Change summary

internal/agent/tools/bash.go | 72 +++++++++++++++++++++----------------
internal/config/config.go    |  9 ++--
2 files changed, 46 insertions(+), 35 deletions(-)

Detailed changes

internal/agent/tools/bash.go 🔗

@@ -154,41 +154,51 @@ func bashDescription(attribution *config.Attribution, modelName string) string {
 	return out.String()
 }
 
-func blockFuncs(bannedCommands []string) []shell.BlockFunc {
-	return []shell.BlockFunc{
-		shell.CommandsBlocker(bannedCommands),
-
-		// System package managers
-		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
-		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
-		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
-		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
-		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
-		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
-		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
-		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
-
-		// Language-specific package managers
-		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
-		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
-		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
-		shell.ArgumentsBlocker("go", []string{"install"}, nil),
-		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
-		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
-		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
-		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
-		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
-		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
-		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
-
-		// `go test -exec` can run arbitrary commands
-		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
+var defaultBannedSubCommands = []shell.BlockFunc{
+	// System package managers
+	shell.ArgumentsBlocker("apk", []string{"add"}, nil),
+	shell.ArgumentsBlocker("apt", []string{"install"}, nil),
+	shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
+	shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
+	shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
+	shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
+	shell.ArgumentsBlocker("yum", []string{"install"}, nil),
+	shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
+
+	// Language-specific package managers
+	shell.ArgumentsBlocker("brew", []string{"install"}, nil),
+	shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
+	shell.ArgumentsBlocker("gem", []string{"install"}, nil),
+	shell.ArgumentsBlocker("go", []string{"install"}, nil),
+	shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
+	shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
+	shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
+	shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
+	shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
+	shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
+	shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
+
+	// `go test -exec` can run arbitrary commands
+	shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
+}
+
+func blockFuncs(bannedCommands []string, bannedSubCommands []config.BannedToolArgsAndOrParams, includeSubCommandDefaults bool) []shell.BlockFunc {
+	blockFuncs := []shell.BlockFunc{}
+	blockFuncs = append(blockFuncs, shell.CommandsBlocker(bannedCommands))
+
+	for _, bannedSubCmd := range bannedSubCommands {
+		blockFuncs = append(blockFuncs, shell.ArgumentsBlocker(bannedSubCmd.Command, bannedSubCmd.Args, bannedSubCmd.Flags))
+	}
+
+	if includeSubCommandDefaults {
+		blockFuncs = append(blockFuncs, defaultBannedSubCommands...)
 	}
+	return blockFuncs
 }
 
 func resolveBannedCommandsList(cfg config.ToolBash) []string {
 	bannedCommands := cfg.BannedCommands
-	if !cfg.DisableDefaults {
+	if !cfg.DisableDefaultCommands {
 		if len(bannedCommands) == 0 {
 			return defaultBannedCommands
 		}
@@ -198,7 +208,7 @@ func resolveBannedCommandsList(cfg config.ToolBash) []string {
 }
 
 func resolveBlockFuncs(cfg config.ToolBash) []shell.BlockFunc {
-	return blockFuncs(resolveBannedCommandsList(cfg))
+	return blockFuncs(resolveBannedCommandsList(cfg), cfg.BannedSubCommands, cfg.DisableDefaultSubCommands)
 }
 
 func NewBashTool(

internal/config/config.go 🔗

@@ -323,15 +323,16 @@ type ToolLs struct {
 }
 
 type ToolBash struct {
-	DisableDefaults   bool                        `json:"disable_banned_defaults,omitempty"`
-	BannedCommands    []string                    `json:"banned_commands,omitempty"`
-	BannedSubCommands []BannedToolArgsAndOrParams `json:"banned_sub_commands"`
+	DisableDefaultCommands    bool                        `json:"disable_banned_defaults,omitempty"`
+	BannedCommands            []string                    `json:"banned_commands,omitempty"`
+	DisableDefaultSubCommands bool                        `json:"disable_banned_sub_command_defaults,omitempty"`
+	BannedSubCommands         []BannedToolArgsAndOrParams `json:"banned_sub_commands,omitempty"`
 }
 
 type BannedToolArgsAndOrParams struct {
 	Command string   `json:"command"`
 	Args    []string `json:"args"`
-	Params  []string `json:"params"`
+	Flags   []string `json:"flags"`
 }
 
 func (t ToolLs) Limits() (depth, items int) {