diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..6d7a9a32b3829da02021be80e6e41e28888efd83 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "runtime" "strings" "time" @@ -41,9 +42,74 @@ const ( ) var bannedCommands = []string{ - "alias", "curl", "curlie", "wget", "axel", "aria2c", - "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + // Network/Download tools + "alias", + "aria2c", + "axel", + "chrome", + "curl", + "curlie", + "firefox", + "http-prompt", + "httpie", + "links", + "lynx", + "nc", + "safari", + "telnet", + "w3m", + "wget", + "xh", + + // System administration + "doas", + "su", + "sudo", + + // Package managers + "apk", + "apt", + "apt-cache", + "apt-get", + "dnf", + "dpkg", + "emerge", + "home-manager", + "makepkg", + "opkg", + "pacman", + "paru", + "pkg", + "pkg_add", + "pkg_delete", + "portage", + "rpm", + "yay", + "yum", + "zypper", + + // System modification + "at", + "batch", + "chkconfig", + "crontab", + "fdisk", + "mkfs", + "mount", + "parted", + "service", + "systemctl", + "umount", + + // Network configuration + "firewall-cmd", + "ifconfig", + "ip", + "iptables", + "netstat", + "pfctl", + "route", + "ufw", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -244,7 +310,42 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } +func blockFuncs() []shell.BlockFunc { + return []shell.BlockFunc{ + shell.CommandsBlocker(bannedCommands), + shell.ArgumentsBlocker([][]string{ + // System package managers + {"apk", "add"}, + {"apt", "install"}, + {"apt-get", "install"}, + {"dnf", "install"}, + {"emerge"}, + {"pacman", "-S"}, + {"pkg", "install"}, + {"yum", "install"}, + {"zypper", "install"}, + + // Language-specific package managers + {"brew", "install"}, + {"cargo", "install"}, + {"gem", "install"}, + {"go", "install"}, + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + {"pip", "install", "--user"}, + {"pip3", "install", "--user"}, + {"pnpm", "add", "-g"}, + {"pnpm", "add", "--global"}, + {"yarn", "global", "add"}, + }), + } +} + func NewBashTool(permission permission.Service, workingDir string) BaseTool { + // Set up command blocking on the persistent shell + persistentShell := shell.GetPersistentShell(workingDir) + persistentShell.SetBlockFuncs(blockFuncs()) + return &bashTool{ permissions: permission, workingDir: workingDir, @@ -289,13 +390,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } - } - isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) @@ -349,7 +443,20 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) + slog.Info("Bash command executed", + "command", params.Command, + "stdout", stdout, + "stderr", stderr, + "exit_code", exitCode, + "interrupted", interrupted, + "err", err, + ) + errorMessage := stderr + if errorMessage == "" && err != nil { + errorMessage = err.Error() + } + if interrupted { if errorMessage != "" { errorMessage += "\n" diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fd7c46bcd98e54f44abbe982e834f3cbb04cbfa4 --- /dev/null +++ b/internal/shell/command_block_test.go @@ -0,0 +1,123 @@ +package shell + +import ( + "context" + "os" + "strings" + "testing" +) + +func TestCommandBlocking(t *testing.T) { + tests := []struct { + name string + blockFuncs []BlockFunc + command string + shouldBlock bool + }{ + { + name: "block simple command", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "curl https://example.com", + shouldBlock: true, + }, + { + name: "allow non-blocked command", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "echo hello", + shouldBlock: false, + }, + { + name: "block subcommand", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew install wget", + shouldBlock: true, + }, + { + name: "allow different subcommand", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew list", + shouldBlock: false, + }, + { + name: "block npm global install with -g", + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install -g typescript", + shouldBlock: true, + }, + { + name: "block npm global install with --global", + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install --global typescript", + shouldBlock: true, + }, + { + name: "allow npm local install", + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install typescript", + shouldBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for each test + tmpDir, err := os.MkdirTemp("", "shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + shell := NewShell(&Options{ + WorkingDir: tmpDir, + BlockFuncs: tt.blockFuncs, + }) + + _, _, err = shell.Exec(context.Background(), tt.command) + + if tt.shouldBlock { + if err == nil { + t.Errorf("Expected command to be blocked, but it was allowed") + } else if !strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Expected security error, got: %v", err) + } + } else { + // For non-blocked commands, we might get other errors (like command not found) + // but we shouldn't get the security error + if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Command was unexpectedly blocked: %v", err) + } + } + }) + } +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 0467c9072c5111e4b4ea9a5439519e4edf76af46..b655c5dbecf5b69c7ad102c53108733515138771 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -44,12 +44,16 @@ type noopLogger struct{} func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} +// BlockFunc is a function that determines if a command should be blocked +type BlockFunc func(args []string) bool + // Shell provides cross-platform shell execution with optional state persistence type Shell struct { - env []string - cwd string - mu sync.Mutex - logger Logger + env []string + cwd string + mu sync.Mutex + logger Logger + blockFuncs []BlockFunc } // Options for creating a new shell @@ -57,6 +61,7 @@ type Options struct { WorkingDir string Env []string Logger Logger + BlockFuncs []BlockFunc } // NewShell creates a new shell instance with the given options @@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell { } return &Shell{ - cwd: cwd, - env: env, - logger: logger, + cwd: cwd, + env: env, + logger: logger, + blockFuncs: opts.BlockFuncs, } } @@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) { s.env = append(s.env, keyPrefix+value) } +// SetBlockFuncs sets the command block functions for the shell +func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.blockFuncs = blockFuncs +} + // Windows-specific commands that should use native shell var windowsNativeCommands = map[string]bool{ "dir": true, @@ -203,6 +216,60 @@ func (s *Shell) determineShellType(command string) ShellType { return ShellTypePOSIX } +// CommandsBlocker creates a BlockFunc that blocks exact command matches +func CommandsBlocker(bannedCommands []string) BlockFunc { + bannedSet := make(map[string]bool) + for _, cmd := range bannedCommands { + bannedSet[cmd] = true + } + + return func(args []string) bool { + if len(args) == 0 { + return false + } + return bannedSet[args[0]] + } +} + +// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands +func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc { + return func(args []string) bool { + for _, blocked := range blockedSubCommands { + if len(args) >= len(blocked) { + match := true + for i, part := range blocked { + if args[i] != part { + match = false + break + } + } + if match { + return true + } + } + } + return false + } +} + +func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + + for _, blockFunc := range s.blockFuncs { + if blockFunc(args) { + return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " ")) + } + } + + return next(ctx, args) + } + } +} + // execWindows executes commands using native Windows shells (cmd.exe or PowerShell) func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { var cmd *exec.Cmd @@ -291,6 +358,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), + interp.ExecHandlers(s.blockHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err)