@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "log/slog"
"runtime"
"strings"
"time"
@@ -41,9 +42,74 @@ const (
)
var bannedCommands = []string{
- "alias", "curl", "curlie", "wget", "axel", "aria2c",
- "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
- "http-prompt", "chrome", "firefox", "safari",
+ // Network/Download tools
+ "alias",
+ "aria2c",
+ "axel",
+ "chrome",
+ "curl",
+ "curlie",
+ "firefox",
+ "http-prompt",
+ "httpie",
+ "links",
+ "lynx",
+ "nc",
+ "safari",
+ "telnet",
+ "w3m",
+ "wget",
+ "xh",
+
+ // System administration
+ "doas",
+ "su",
+ "sudo",
+
+ // Package managers
+ "apk",
+ "apt",
+ "apt-cache",
+ "apt-get",
+ "dnf",
+ "dpkg",
+ "emerge",
+ "home-manager",
+ "makepkg",
+ "opkg",
+ "pacman",
+ "paru",
+ "pkg",
+ "pkg_add",
+ "pkg_delete",
+ "portage",
+ "rpm",
+ "yay",
+ "yum",
+ "zypper",
+
+ // System modification
+ "at",
+ "batch",
+ "chkconfig",
+ "crontab",
+ "fdisk",
+ "mkfs",
+ "mount",
+ "parted",
+ "service",
+ "systemctl",
+ "umount",
+
+ // Network configuration
+ "firewall-cmd",
+ "ifconfig",
+ "ip",
+ "iptables",
+ "netstat",
+ "pfctl",
+ "route",
+ "ufw",
}
// getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -244,7 +310,42 @@ Important:
- Never update git config`, bannedCommandsStr, MaxOutputLength)
}
+func blockFuncs() []shell.BlockFunc {
+ return []shell.BlockFunc{
+ shell.CommandsBlocker(bannedCommands),
+ shell.ArgumentsBlocker([][]string{
+ // System package managers
+ {"apk", "add"},
+ {"apt", "install"},
+ {"apt-get", "install"},
+ {"dnf", "install"},
+ {"emerge"},
+ {"pacman", "-S"},
+ {"pkg", "install"},
+ {"yum", "install"},
+ {"zypper", "install"},
+
+ // Language-specific package managers
+ {"brew", "install"},
+ {"cargo", "install"},
+ {"gem", "install"},
+ {"go", "install"},
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ {"pip", "install", "--user"},
+ {"pip3", "install", "--user"},
+ {"pnpm", "add", "-g"},
+ {"pnpm", "add", "--global"},
+ {"yarn", "global", "add"},
+ }),
+ }
+}
+
func NewBashTool(permission permission.Service, workingDir string) BaseTool {
+ // Set up command blocking on the persistent shell
+ persistentShell := shell.GetPersistentShell(workingDir)
+ persistentShell.SetBlockFuncs(blockFuncs())
+
return &bashTool{
permissions: permission,
workingDir: workingDir,
@@ -289,13 +390,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return NewTextErrorResponse("missing command"), nil
}
- baseCmd := strings.Fields(params.Command)[0]
- for _, banned := range bannedCommands {
- if strings.EqualFold(baseCmd, banned) {
- return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
- }
- }
-
isSafeReadOnly := false
cmdLower := strings.ToLower(params.Command)
@@ -349,7 +443,20 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
stdout = truncateOutput(stdout)
stderr = truncateOutput(stderr)
+ slog.Info("Bash command executed",
+ "command", params.Command,
+ "stdout", stdout,
+ "stderr", stderr,
+ "exit_code", exitCode,
+ "interrupted", interrupted,
+ "err", err,
+ )
+
errorMessage := stderr
+ if errorMessage == "" && err != nil {
+ errorMessage = err.Error()
+ }
+
if interrupted {
if errorMessage != "" {
errorMessage += "\n"
@@ -0,0 +1,123 @@
+package shell
+
+import (
+ "context"
+ "os"
+ "strings"
+ "testing"
+)
+
+func TestCommandBlocking(t *testing.T) {
+ tests := []struct {
+ name string
+ blockFuncs []BlockFunc
+ command string
+ shouldBlock bool
+ }{
+ {
+ name: "block simple command",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) > 0 && args[0] == "curl"
+ },
+ },
+ command: "curl https://example.com",
+ shouldBlock: true,
+ },
+ {
+ name: "allow non-blocked command",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) > 0 && args[0] == "curl"
+ },
+ },
+ command: "echo hello",
+ shouldBlock: false,
+ },
+ {
+ name: "block subcommand",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+ },
+ },
+ command: "brew install wget",
+ shouldBlock: true,
+ },
+ {
+ name: "allow different subcommand",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+ },
+ },
+ command: "brew list",
+ shouldBlock: false,
+ },
+ {
+ name: "block npm global install with -g",
+ blockFuncs: []BlockFunc{
+ ArgumentsBlocker([][]string{
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ }),
+ },
+ command: "npm install -g typescript",
+ shouldBlock: true,
+ },
+ {
+ name: "block npm global install with --global",
+ blockFuncs: []BlockFunc{
+ ArgumentsBlocker([][]string{
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ }),
+ },
+ command: "npm install --global typescript",
+ shouldBlock: true,
+ },
+ {
+ name: "allow npm local install",
+ blockFuncs: []BlockFunc{
+ ArgumentsBlocker([][]string{
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ }),
+ },
+ command: "npm install typescript",
+ shouldBlock: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a temporary directory for each test
+ tmpDir, err := os.MkdirTemp("", "shell-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ shell := NewShell(&Options{
+ WorkingDir: tmpDir,
+ BlockFuncs: tt.blockFuncs,
+ })
+
+ _, _, err = shell.Exec(context.Background(), tt.command)
+
+ if tt.shouldBlock {
+ if err == nil {
+ t.Errorf("Expected command to be blocked, but it was allowed")
+ } else if !strings.Contains(err.Error(), "not allowed for security reasons") {
+ t.Errorf("Expected security error, got: %v", err)
+ }
+ } else {
+ // For non-blocked commands, we might get other errors (like command not found)
+ // but we shouldn't get the security error
+ if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") {
+ t.Errorf("Command was unexpectedly blocked: %v", err)
+ }
+ }
+ })
+ }
+}
@@ -44,12 +44,16 @@ type noopLogger struct{}
func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
+// BlockFunc is a function that determines if a command should be blocked
+type BlockFunc func(args []string) bool
+
// Shell provides cross-platform shell execution with optional state persistence
type Shell struct {
- env []string
- cwd string
- mu sync.Mutex
- logger Logger
+ env []string
+ cwd string
+ mu sync.Mutex
+ logger Logger
+ blockFuncs []BlockFunc
}
// Options for creating a new shell
@@ -57,6 +61,7 @@ type Options struct {
WorkingDir string
Env []string
Logger Logger
+ BlockFuncs []BlockFunc
}
// NewShell creates a new shell instance with the given options
@@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell {
}
return &Shell{
- cwd: cwd,
- env: env,
- logger: logger,
+ cwd: cwd,
+ env: env,
+ logger: logger,
+ blockFuncs: opts.BlockFuncs,
}
}
@@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) {
s.env = append(s.env, keyPrefix+value)
}
+// SetBlockFuncs sets the command block functions for the shell
+func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.blockFuncs = blockFuncs
+}
+
// Windows-specific commands that should use native shell
var windowsNativeCommands = map[string]bool{
"dir": true,
@@ -203,6 +216,60 @@ func (s *Shell) determineShellType(command string) ShellType {
return ShellTypePOSIX
}
+// CommandsBlocker creates a BlockFunc that blocks exact command matches
+func CommandsBlocker(bannedCommands []string) BlockFunc {
+ bannedSet := make(map[string]bool)
+ for _, cmd := range bannedCommands {
+ bannedSet[cmd] = true
+ }
+
+ return func(args []string) bool {
+ if len(args) == 0 {
+ return false
+ }
+ return bannedSet[args[0]]
+ }
+}
+
+// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
+func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
+ return func(args []string) bool {
+ for _, blocked := range blockedSubCommands {
+ if len(args) >= len(blocked) {
+ match := true
+ for i, part := range blocked {
+ if args[i] != part {
+ match = false
+ break
+ }
+ }
+ if match {
+ return true
+ }
+ }
+ }
+ return false
+ }
+}
+
+func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+ return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+ return func(ctx context.Context, args []string) error {
+ if len(args) == 0 {
+ return next(ctx, args)
+ }
+
+ for _, blockFunc := range s.blockFuncs {
+ if blockFunc(args) {
+ return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
+ }
+ }
+
+ return next(ctx, args)
+ }
+ }
+}
+
// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
var cmd *exec.Cmd
@@ -291,6 +358,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string,
interp.Interactive(false),
interp.Env(expand.ListEnviron(s.env...)),
interp.Dir(s.cwd),
+ interp.ExecHandlers(s.blockHandler()),
)
if err != nil {
return "", "", fmt.Errorf("could not run command: %w", err)