From f99f50427a16e36701abf1e965dd1c605c42557f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 00:21:19 +0200 Subject: [PATCH] refactor: improve command blocking system and fix test isolation - Simplify command blocking logic by using utility functions instead of complex closures - Add sudo to banned commands list - Move command blocking from bash tool to shell layer for better separation of concerns - Add comprehensive tests for command blocking functionality - Fix test isolation by using temporary directories to prevent npm package files from polluting source tree - Remove redundant command validation logic from bash tool --- internal/llm/tools/bash.go | 24 ++++-- internal/shell/command_block_test.go | 123 +++++++++++++++++++++++++++ internal/shell/shell.go | 81 ++++++++++++++++-- 3 files changed, 213 insertions(+), 15 deletions(-) create mode 100644 internal/shell/command_block_test.go diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..2ae1c2956c46fef6f2cc1f0a8a20114bbb8785c1 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -43,7 +43,7 @@ const ( var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + "http-prompt", "chrome", "firefox", "safari", "sudo", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -244,7 +244,22 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } +func createCommandBlockFuncs() []shell.CommandBlockFunc { + return []shell.CommandBlockFunc{ + shell.CreateSimpleCommandBlocker(bannedCommands), + shell.CreateSubCommandBlocker([][]string{ + {"brew", "install"}, + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + } +} + func NewBashTool(permission permission.Service, workingDir string) BaseTool { + // Set up command blocking on the persistent shell + persistentShell := shell.GetPersistentShell(workingDir) + persistentShell.SetBlockFuncs(createCommandBlockFuncs()) + return &bashTool{ permissions: permission, workingDir: workingDir, @@ -289,13 +304,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } - } - isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go new file mode 100644 index 0000000000000000000000000000000000000000..85971748f882cf79fba3ea86d2682ce6ce4f252d --- /dev/null +++ b/internal/shell/command_block_test.go @@ -0,0 +1,123 @@ +package shell + +import ( + "context" + "os" + "strings" + "testing" +) + +func TestCommandBlocking(t *testing.T) { + tests := []struct { + name string + blockFuncs []CommandBlockFunc + command string + shouldBlock bool + }{ + { + name: "block simple command", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "curl https://example.com", + shouldBlock: true, + }, + { + name: "allow non-blocked command", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "echo hello", + shouldBlock: false, + }, + { + name: "block subcommand", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew install wget", + shouldBlock: true, + }, + { + name: "allow different subcommand", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew list", + shouldBlock: false, + }, + { + name: "block npm global install with -g", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install -g typescript", + shouldBlock: true, + }, + { + name: "block npm global install with --global", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install --global typescript", + shouldBlock: true, + }, + { + name: "allow npm local install", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install typescript", + shouldBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for each test + tmpDir, err := os.MkdirTemp("", "shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + shell := NewShell(&Options{ + WorkingDir: tmpDir, + BlockFuncs: tt.blockFuncs, + }) + + _, _, err = shell.Exec(context.Background(), tt.command) + + if tt.shouldBlock { + if err == nil { + t.Errorf("Expected command to be blocked, but it was allowed") + } else if !strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Expected security error, got: %v", err) + } + } else { + // For non-blocked commands, we might get other errors (like command not found) + // but we shouldn't get the security error + if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Command was unexpectedly blocked: %v", err) + } + } + }) + } +} \ No newline at end of file diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 0467c9072c5111e4b4ea9a5439519e4edf76af46..815be0907a3fd05f996a24f84f751ce5d776b833 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -44,12 +44,16 @@ type noopLogger struct{} func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} +// CommandBlockFunc is a function that determines if a command should be blocked +type CommandBlockFunc func(args []string) bool + // Shell provides cross-platform shell execution with optional state persistence type Shell struct { - env []string - cwd string - mu sync.Mutex - logger Logger + env []string + cwd string + mu sync.Mutex + logger Logger + blockFuncs []CommandBlockFunc } // Options for creating a new shell @@ -57,6 +61,7 @@ type Options struct { WorkingDir string Env []string Logger Logger + BlockFuncs []CommandBlockFunc } // NewShell creates a new shell instance with the given options @@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell { } return &Shell{ - cwd: cwd, - env: env, - logger: logger, + cwd: cwd, + env: env, + logger: logger, + blockFuncs: opts.BlockFuncs, } } @@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) { s.env = append(s.env, keyPrefix+value) } +// SetBlockFuncs sets the command block functions for the shell +func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.blockFuncs = blockFuncs +} + // Windows-specific commands that should use native shell var windowsNativeCommands = map[string]bool{ "dir": true, @@ -203,6 +216,59 @@ func (s *Shell) determineShellType(command string) ShellType { return ShellTypePOSIX } +// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches +func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc { + bannedSet := make(map[string]bool) + for _, cmd := range bannedCommands { + bannedSet[cmd] = true + } + + return func(args []string) bool { + if len(args) == 0 { + return false + } + return bannedSet[args[0]] + } +} + +// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands +func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc { + return func(args []string) bool { + for _, blocked := range blockedSubCommands { + if len(args) >= len(blocked) { + match := true + for i, part := range blocked { + if args[i] != part { + match = false + break + } + } + if match { + return true + } + } + } + return false + } +} +func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + + for _, blockFunc := range s.blockFuncs { + if blockFunc(args) { + return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " ")) + } + } + + return next(ctx, args) + } + } +} + // execWindows executes commands using native Windows shells (cmd.exe or PowerShell) func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { var cmd *exec.Cmd @@ -291,6 +357,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), + interp.ExecHandlers(s.createCommandBlockHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err)