refactor: improve command blocking system and fix test isolation

Kujtim Hoxha created

- Simplify command blocking logic by using utility functions instead of complex closures
- Add sudo to banned commands list
- Move command blocking from bash tool to shell layer for better separation of concerns
- Add comprehensive tests for command blocking functionality
- Fix test isolation by using temporary directories to prevent npm package files from polluting source tree
- Remove redundant command validation logic from bash tool

Change summary

internal/llm/tools/bash.go           |  24 +++-
internal/shell/command_block_test.go | 123 ++++++++++++++++++++++++++++++
internal/shell/shell.go              |  81 ++++++++++++++++++-
3 files changed, 213 insertions(+), 15 deletions(-)

Detailed changes

internal/llm/tools/bash.go 🔗

@@ -43,7 +43,7 @@ const (
 var bannedCommands = []string{
 	"alias", "curl", "curlie", "wget", "axel", "aria2c",
 	"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
-	"http-prompt", "chrome", "firefox", "safari",
+	"http-prompt", "chrome", "firefox", "safari", "sudo",
 }
 
 // getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -244,7 +244,22 @@ Important:
 - Never update git config`, bannedCommandsStr, MaxOutputLength)
 }
 
+func createCommandBlockFuncs() []shell.CommandBlockFunc {
+	return []shell.CommandBlockFunc{
+		shell.CreateSimpleCommandBlocker(bannedCommands),
+		shell.CreateSubCommandBlocker([][]string{
+			{"brew", "install"},
+			{"npm", "install", "-g"},
+			{"npm", "install", "--global"},
+		}),
+	}
+}
+
 func NewBashTool(permission permission.Service, workingDir string) BaseTool {
+	// Set up command blocking on the persistent shell
+	persistentShell := shell.GetPersistentShell(workingDir)
+	persistentShell.SetBlockFuncs(createCommandBlockFuncs())
+	
 	return &bashTool{
 		permissions: permission,
 		workingDir:  workingDir,
@@ -289,13 +304,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		return NewTextErrorResponse("missing command"), nil
 	}
 
-	baseCmd := strings.Fields(params.Command)[0]
-	for _, banned := range bannedCommands {
-		if strings.EqualFold(baseCmd, banned) {
-			return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
-		}
-	}
-
 	isSafeReadOnly := false
 	cmdLower := strings.ToLower(params.Command)
 

internal/shell/command_block_test.go 🔗

@@ -0,0 +1,123 @@
+package shell
+
+import (
+	"context"
+	"os"
+	"strings"
+	"testing"
+)
+
+func TestCommandBlocking(t *testing.T) {
+	tests := []struct {
+		name        string
+		blockFuncs  []CommandBlockFunc
+		command     string
+		shouldBlock bool
+	}{
+		{
+			name: "block simple command",
+			blockFuncs: []CommandBlockFunc{
+				func(args []string) bool {
+					return len(args) > 0 && args[0] == "curl"
+				},
+			},
+			command:     "curl https://example.com",
+			shouldBlock: true,
+		},
+		{
+			name: "allow non-blocked command",
+			blockFuncs: []CommandBlockFunc{
+				func(args []string) bool {
+					return len(args) > 0 && args[0] == "curl"
+				},
+			},
+			command:     "echo hello",
+			shouldBlock: false,
+		},
+		{
+			name: "block subcommand",
+			blockFuncs: []CommandBlockFunc{
+				func(args []string) bool {
+					return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+				},
+			},
+			command:     "brew install wget",
+			shouldBlock: true,
+		},
+		{
+			name: "allow different subcommand",
+			blockFuncs: []CommandBlockFunc{
+				func(args []string) bool {
+					return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+				},
+			},
+			command:     "brew list",
+			shouldBlock: false,
+		},
+		{
+			name: "block npm global install with -g",
+			blockFuncs: []CommandBlockFunc{
+				CreateSubCommandBlocker([][]string{
+					{"npm", "install", "-g"},
+					{"npm", "install", "--global"},
+				}),
+			},
+			command:     "npm install -g typescript",
+			shouldBlock: true,
+		},
+		{
+			name: "block npm global install with --global",
+			blockFuncs: []CommandBlockFunc{
+				CreateSubCommandBlocker([][]string{
+					{"npm", "install", "-g"},
+					{"npm", "install", "--global"},
+				}),
+			},
+			command:     "npm install --global typescript",
+			shouldBlock: true,
+		},
+		{
+			name: "allow npm local install",
+			blockFuncs: []CommandBlockFunc{
+				CreateSubCommandBlocker([][]string{
+					{"npm", "install", "-g"},
+					{"npm", "install", "--global"},
+				}),
+			},
+			command:     "npm install typescript",
+			shouldBlock: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create a temporary directory for each test
+			tmpDir, err := os.MkdirTemp("", "shell-test-*")
+			if err != nil {
+				t.Fatalf("Failed to create temp dir: %v", err)
+			}
+			defer os.RemoveAll(tmpDir)
+
+			shell := NewShell(&Options{
+				WorkingDir: tmpDir,
+				BlockFuncs: tt.blockFuncs,
+			})
+
+			_, _, err = shell.Exec(context.Background(), tt.command)
+
+			if tt.shouldBlock {
+				if err == nil {
+					t.Errorf("Expected command to be blocked, but it was allowed")
+				} else if !strings.Contains(err.Error(), "not allowed for security reasons") {
+					t.Errorf("Expected security error, got: %v", err)
+				}
+			} else {
+				// For non-blocked commands, we might get other errors (like command not found)
+				// but we shouldn't get the security error
+				if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") {
+					t.Errorf("Command was unexpectedly blocked: %v", err)
+				}
+			}
+		})
+	}
+}

internal/shell/shell.go 🔗

@@ -44,12 +44,16 @@ type noopLogger struct{}
 
 func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
 
+// CommandBlockFunc is a function that determines if a command should be blocked
+type CommandBlockFunc func(args []string) bool
+
 // Shell provides cross-platform shell execution with optional state persistence
 type Shell struct {
-	env    []string
-	cwd    string
-	mu     sync.Mutex
-	logger Logger
+	env        []string
+	cwd        string
+	mu         sync.Mutex
+	logger     Logger
+	blockFuncs []CommandBlockFunc
 }
 
 // Options for creating a new shell
@@ -57,6 +61,7 @@ type Options struct {
 	WorkingDir string
 	Env        []string
 	Logger     Logger
+	BlockFuncs []CommandBlockFunc
 }
 
 // NewShell creates a new shell instance with the given options
@@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell {
 	}
 
 	return &Shell{
-		cwd:    cwd,
-		env:    env,
-		logger: logger,
+		cwd:        cwd,
+		env:        env,
+		logger:     logger,
+		blockFuncs: opts.BlockFuncs,
 	}
 }
 
@@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) {
 	s.env = append(s.env, keyPrefix+value)
 }
 
+// SetBlockFuncs sets the command block functions for the shell
+func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.blockFuncs = blockFuncs
+}
+
 // Windows-specific commands that should use native shell
 var windowsNativeCommands = map[string]bool{
 	"dir":      true,
@@ -203,6 +216,59 @@ func (s *Shell) determineShellType(command string) ShellType {
 	return ShellTypePOSIX
 }
 
+// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches
+func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc {
+	bannedSet := make(map[string]bool)
+	for _, cmd := range bannedCommands {
+		bannedSet[cmd] = true
+	}
+	
+	return func(args []string) bool {
+		if len(args) == 0 {
+			return false
+		}
+		return bannedSet[args[0]]
+	}
+}
+
+// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands
+func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc {
+	return func(args []string) bool {
+		for _, blocked := range blockedSubCommands {
+			if len(args) >= len(blocked) {
+				match := true
+				for i, part := range blocked {
+					if args[i] != part {
+						match = false
+						break
+					}
+				}
+				if match {
+					return true
+				}
+			}
+		}
+		return false
+	}
+}
+func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+		return func(ctx context.Context, args []string) error {
+			if len(args) == 0 {
+				return next(ctx, args)
+			}
+
+			for _, blockFunc := range s.blockFuncs {
+				if blockFunc(args) {
+					return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " "))
+				}
+			}
+
+			return next(ctx, args)
+		}
+	}
+}
+
 // execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
 func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
 	var cmd *exec.Cmd
@@ -291,6 +357,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string,
 		interp.Interactive(false),
 		interp.Env(expand.ListEnviron(s.env...)),
 		interp.Dir(s.cwd),
+		interp.ExecHandlers(s.createCommandBlockHandler()),
 	)
 	if err != nil {
 		return "", "", fmt.Errorf("could not run command: %w", err)