Detailed changes
@@ -3,6 +3,7 @@ package prompt
import (
"os"
"path/filepath"
+ "runtime"
"strings"
"testing"
)
@@ -96,7 +97,8 @@ func TestProcessContextPaths(t *testing.T) {
// Test with tilde expansion (if we can create a file in home directory)
tmpDir = t.TempDir()
- t.Setenv("HOME", tmpDir)
+ rollback := setHomeEnv(tmpDir)
+ defer rollback()
homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt")
err = os.WriteFile(homeTestFile, []byte(testContent), 0o644)
if err == nil {
@@ -111,3 +113,13 @@ func TestProcessContextPaths(t *testing.T) {
}
}
}
+
+func setHomeEnv(path string) (rollback func()) {
+ key := "HOME"
+ if runtime.GOOS == "windows" {
+ key = "USERPROFILE"
+ }
+ original := os.Getenv(key)
+ os.Setenv(key, path)
+ return func() { os.Setenv(key, original) }
+}
@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"log/slog"
- "runtime"
"strings"
"time"
@@ -112,58 +111,17 @@ var bannedCommands = []string{
"ufw",
}
-// getSafeReadOnlyCommands returns platform-appropriate safe commands
-func getSafeReadOnlyCommands() []string {
- // Base commands that work on all platforms
- baseCommands := []string{
- // Cross-platform commands
- "echo", "hostname", "whoami",
-
- // Git commands (cross-platform)
- "git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote",
- "git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog",
-
- // Go commands (cross-platform)
- "go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean",
- }
-
- if runtime.GOOS == "windows" {
- // Windows-specific commands
- windowsCommands := []string{
- "dir", "type", "where", "ver", "systeminfo", "tasklist", "ipconfig", "ping", "nslookup",
- "Get-Process", "Get-Location", "Get-ChildItem", "Get-Content", "Get-Date", "Get-Host", "Get-ComputerInfo",
- }
- return append(baseCommands, windowsCommands...)
- } else {
- // Unix/Linux commands (including WSL, since WSL reports as Linux)
- unixCommands := []string{
- "ls", "pwd", "date", "cal", "uptime", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis",
- "whatis", "uname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout",
- }
- return append(baseCommands, unixCommands...)
- }
-}
-
func bashDescription() string {
bannedCommandsStr := strings.Join(bannedCommands, ", ")
return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures.
CROSS-PLATFORM SHELL SUPPORT:
-- Unix/Linux/macOS: Uses native bash/sh shell
-- Windows: Intelligent shell selection:
- * Windows commands (dir, type, copy, etc.) use cmd.exe
- * PowerShell commands (Get-, Set-, etc.) use PowerShell
- * Unix-style commands (ls, cat, etc.) use POSIX emulation
-- WSL: Automatically treated as Linux (which is correct)
-- Automatic detection: Chooses the best shell based on command and platform
-- Persistent state: Working directory and environment variables persist between commands
-
-WINDOWS-SPECIFIC FEATURES:
-- Native Windows commands: dir, type, copy, move, del, md, rd, cls, where, tasklist, etc.
-- PowerShell support: Get-Process, Set-Location, and other PowerShell cmdlets
-- Windows path handling: Supports both forward slashes (/) and backslashes (\)
-- Drive letters: Properly handles C:\, D:\, etc.
-- Environment variables: Supports both Unix ($VAR) and Windows (%%VAR%%) syntax
+* This tool uses a shell interpreter (mvdan/sh) that mimics the Bash language,
+ so you should use Bash syntax even on all platforms, even on Windows.
+ The most common shell builtins and core utils are available even on Windows as
+ well.
+* Make sure to use forward slashes (/) as path separators in commands, even on
+ Windows. Example: "ls C:/foo/bar" instead of "ls C:\foo\bar".
Before executing the command, please follow these steps:
@@ -393,10 +351,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
isSafeReadOnly := false
cmdLower := strings.ToLower(params.Command)
- // Get platform-appropriate safe commands
- safeReadOnlyCommands := getSafeReadOnlyCommands()
- for _, safe := range safeReadOnlyCommands {
- if strings.HasPrefix(cmdLower, strings.ToLower(safe)) {
+ for _, safe := range safeCommands {
+ if strings.HasPrefix(cmdLower, safe) {
if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
isSafeReadOnly = true
break
@@ -1,96 +0,0 @@
-package tools
-
-import (
- "runtime"
- "slices"
- "testing"
-)
-
-func TestGetSafeReadOnlyCommands(t *testing.T) {
- commands := getSafeReadOnlyCommands()
-
- // Check that we have some commands
- if len(commands) == 0 {
- t.Fatal("Expected some safe commands, got none")
- }
-
- // Check for cross-platform commands that should always be present
- crossPlatformCommands := []string{"echo", "hostname", "whoami", "git status", "go version"}
- for _, cmd := range crossPlatformCommands {
- found := slices.Contains(commands, cmd)
- if !found {
- t.Errorf("Expected cross-platform command %q to be in safe commands", cmd)
- }
- }
-
- if runtime.GOOS == "windows" {
- // Check for Windows-specific commands
- windowsCommands := []string{"dir", "type", "Get-Process"}
- for _, cmd := range windowsCommands {
- found := slices.Contains(commands, cmd)
- if !found {
- t.Errorf("Expected Windows command %q to be in safe commands on Windows", cmd)
- }
- }
-
- // Check that Unix commands are NOT present on Windows
- unixCommands := []string{"ls", "pwd", "ps"}
- for _, cmd := range unixCommands {
- found := slices.Contains(commands, cmd)
- if found {
- t.Errorf("Unix command %q should not be in safe commands on Windows", cmd)
- }
- }
- } else {
- // Check for Unix-specific commands
- unixCommands := []string{"ls", "pwd", "ps"}
- for _, cmd := range unixCommands {
- found := slices.Contains(commands, cmd)
- if !found {
- t.Errorf("Expected Unix command %q to be in safe commands on Unix", cmd)
- }
- }
-
- // Check that Windows-specific commands are NOT present on Unix
- windowsOnlyCommands := []string{"dir", "Get-Process", "systeminfo"}
- for _, cmd := range windowsOnlyCommands {
- found := slices.Contains(commands, cmd)
- if found {
- t.Errorf("Windows-only command %q should not be in safe commands on Unix", cmd)
- }
- }
- }
-}
-
-func TestPlatformSpecificSafeCommands(t *testing.T) {
- // Test that the function returns different results on different platforms
- commands := getSafeReadOnlyCommands()
-
- hasWindowsCommands := false
- hasUnixCommands := false
-
- for _, cmd := range commands {
- if cmd == "dir" || cmd == "Get-Process" || cmd == "systeminfo" {
- hasWindowsCommands = true
- }
- if cmd == "ls" || cmd == "ps" || cmd == "df" {
- hasUnixCommands = true
- }
- }
-
- if runtime.GOOS == "windows" {
- if !hasWindowsCommands {
- t.Error("Expected Windows commands on Windows platform")
- }
- if hasUnixCommands {
- t.Error("Did not expect Unix commands on Windows platform")
- }
- } else {
- if hasWindowsCommands {
- t.Error("Did not expect Windows-only commands on Unix platform")
- }
- if !hasUnixCommands {
- t.Error("Expected Unix commands on Unix platform")
- }
- }
-}
@@ -0,0 +1,88 @@
+package tools
+
+import "runtime"
+
+var safeCommands = []string{
+ // Bash builtins and core utils
+ "cal",
+ "date",
+ "df",
+ "du",
+ "echo",
+ "env",
+ "free",
+ "groups",
+ "hostname",
+ "id",
+ "kill",
+ "killall",
+ "ls",
+ "nice",
+ "nohup",
+ "printenv",
+ "ps",
+ "pwd",
+ "set",
+ "time",
+ "timeout",
+ "top",
+ "type",
+ "uname",
+ "unset",
+ "uptime",
+ "whatis",
+ "whereis",
+ "which",
+ "whoami",
+
+ // Git
+ "git blame",
+ "git branch",
+ "git config --get",
+ "git config --list",
+ "git describe",
+ "git diff",
+ "git grep",
+ "git log",
+ "git ls-files",
+ "git ls-remote",
+ "git remote",
+ "git rev-parse",
+ "git shortlog",
+ "git show",
+ "git status",
+ "git tag",
+
+ // Go
+ "go build",
+ "go clean",
+ "go doc",
+ "go env",
+ "go fmt",
+ "go help",
+ "go install",
+ "go list",
+ "go mod",
+ "go run",
+ "go test",
+ "go version",
+ "go vet",
+}
+
+func init() {
+ if runtime.GOOS == "windows" {
+ safeCommands = append(
+ safeCommands,
+ // Windows-specific commands
+ "dir",
+ "ipconfig",
+ "nslookup",
+ "ping",
+ "systeminfo",
+ "tasklist",
+ "type",
+ "ver",
+ "where",
+ )
+ }
+}
@@ -5,8 +5,9 @@
// - PersistentShell: A singleton shell that maintains state across the application
//
// WINDOWS COMPATIBILITY:
-// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
-// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
+// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3),
+// even on Windows. Some caution has to be taken: commands should have forward
+// slashes (/) as path separators to work, even on Windows.
package shell
import (
@@ -15,8 +16,6 @@ import (
"errors"
"fmt"
"os"
- "os/exec"
- "runtime"
"strings"
"sync"
@@ -98,17 +97,7 @@ func (s *Shell) Exec(ctx context.Context, command string) (string, string, error
s.mu.Lock()
defer s.mu.Unlock()
- // Determine which shell to use based on platform and command
- shellType := s.determineShellType(command)
-
- switch shellType {
- case ShellTypeCmd:
- return s.execWindows(ctx, command, "cmd")
- case ShellTypePowerShell:
- return s.execWindows(ctx, command, "powershell")
- default:
- return s.execPOSIX(ctx, command)
- }
+ return s.execPOSIX(ctx, command)
}
// GetWorkingDir returns the current working directory
@@ -165,57 +154,6 @@ func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
s.blockFuncs = blockFuncs
}
-// Windows-specific commands that should use native shell
-var windowsNativeCommands = map[string]bool{
- "dir": true,
- "type": true,
- "copy": true,
- "move": true,
- "del": true,
- "md": true,
- "mkdir": true,
- "rd": true,
- "rmdir": true,
- "cls": true,
- "where": true,
- "tasklist": true,
- "taskkill": true,
- "net": true,
- "sc": true,
- "reg": true,
- "wmic": true,
-}
-
-// determineShellType decides which shell to use based on platform and command
-func (s *Shell) determineShellType(command string) ShellType {
- if runtime.GOOS != "windows" {
- return ShellTypePOSIX
- }
-
- // Extract the first command from the command line
- parts := strings.Fields(command)
- if len(parts) == 0 {
- return ShellTypePOSIX
- }
-
- firstCmd := strings.ToLower(parts[0])
-
- // Check if it's a Windows-specific command
- if windowsNativeCommands[firstCmd] {
- return ShellTypeCmd
- }
-
- // Check for PowerShell-specific syntax
- if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
- strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
- strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
- return ShellTypePowerShell
- }
-
- // Default to POSIX emulation for cross-platform compatibility
- return ShellTypePOSIX
-}
-
// CommandsBlocker creates a BlockFunc that blocks exact command matches
func CommandsBlocker(bannedCommands []string) BlockFunc {
bannedSet := make(map[string]bool)
@@ -270,81 +208,6 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand
}
}
-// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
-func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
- var cmd *exec.Cmd
-
- // Handle directory changes specially to maintain persistent shell behavior
- if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
- return s.handleWindowsCD(command)
- }
-
- switch shell {
- case "cmd":
- // Use cmd.exe for Windows commands
- // Add current directory context to maintain state
- fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
- cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
- case "powershell":
- // Use PowerShell for PowerShell commands
- // Add current directory context to maintain state
- fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
- cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
- default:
- return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
- }
-
- // Set environment variables
- cmd.Env = s.env
-
- var stdout, stderr bytes.Buffer
- cmd.Stdout = &stdout
- cmd.Stderr = &stderr
-
- err := cmd.Run()
-
- s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
- return stdout.String(), stderr.String(), err
-}
-
-// handleWindowsCD handles directory changes for Windows shells
-func (s *Shell) handleWindowsCD(command string) (string, string, error) {
- // Extract the target directory from the cd command
- parts := strings.Fields(command)
- if len(parts) < 2 {
- return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
- }
-
- targetDir := parts[1]
-
- // Handle relative paths
- if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
- // Relative path - resolve against current directory
- if targetDir == ".." {
- // Go up one directory
- if len(s.cwd) > 3 { // Don't go above drive root (C:\)
- lastSlash := strings.LastIndex(s.cwd, "\\")
- if lastSlash > 2 { // Keep drive letter
- s.cwd = s.cwd[:lastSlash]
- }
- }
- } else if targetDir != "." {
- // Go to subdirectory
- s.cwd = s.cwd + "\\" + targetDir
- }
- } else {
- // Absolute path
- s.cwd = targetDir
- }
-
- // Verify the directory exists
- if _, err := os.Stat(s.cwd); err != nil {
- return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
- }
-
- return "", "", nil
-}
-
// execPOSIX executes commands using POSIX shell emulation (cross-platform)
func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
@@ -2,6 +2,7 @@ package shell
import (
"context"
+ "path/filepath"
"runtime"
"strings"
"testing"
@@ -24,6 +25,11 @@ func BenchmarkShellQuickCommands(b *testing.B) {
}
func TestTestTimeout(t *testing.T) {
+ // XXX(@andreynering): This fails on Windows. Address once possible.
+ if runtime.GOOS == "windows" {
+ t.Skip("Skipping test on Windows")
+ }
+
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
t.Cleanup(cancel)
@@ -72,113 +78,23 @@ func TestRunCommandError(t *testing.T) {
}
func TestRunContinuity(t *testing.T) {
- shell := NewShell(&Options{WorkingDir: t.TempDir()})
- shell.Exec(t.Context(), "export FOO=bar")
- dst := t.TempDir()
- shell.Exec(t.Context(), "cd "+dst)
- out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd")
- expect := "bar\n" + dst + "\n"
- if out != expect {
- t.Fatalf("Expected output %q, got %q", expect, out)
- }
-}
+ tempDir1 := t.TempDir()
+ tempDir2 := t.TempDir()
-// New tests for Windows shell support
-
-func TestShellTypeDetection(t *testing.T) {
- shell := &PersistentShell{}
-
- tests := []struct {
- command string
- expected ShellType
- windowsOnly bool
- }{
- // Windows-specific commands
- {"dir", ShellTypeCmd, true},
- {"type file.txt", ShellTypeCmd, true},
- {"copy file1.txt file2.txt", ShellTypeCmd, true},
- {"del file.txt", ShellTypeCmd, true},
- {"md newdir", ShellTypeCmd, true},
- {"tasklist", ShellTypeCmd, true},
-
- // PowerShell commands
- {"Get-Process", ShellTypePowerShell, true},
- {"Get-ChildItem", ShellTypePowerShell, true},
- {"Set-Location C:\\", ShellTypePowerShell, true},
- {"Get-Content file.txt | Where-Object {$_ -match 'pattern'}", ShellTypePowerShell, true},
- {"$files = Get-ChildItem", ShellTypePowerShell, true},
-
- // Unix/cross-platform commands
- {"ls -la", ShellTypePOSIX, false},
- {"cat file.txt", ShellTypePOSIX, false},
- {"grep pattern file.txt", ShellTypePOSIX, false},
- {"echo hello", ShellTypePOSIX, false},
- {"git status", ShellTypePOSIX, false},
- {"go build", ShellTypePOSIX, false},
- }
-
- for _, test := range tests {
- t.Run(test.command, func(t *testing.T) {
- result := shell.determineShellType(test.command)
-
- if test.windowsOnly && runtime.GOOS != "windows" {
- // On non-Windows systems, everything should use POSIX
- if result != ShellTypePOSIX {
- t.Errorf("On non-Windows, command %q should use POSIX shell, got %v", test.command, result)
- }
- } else if runtime.GOOS == "windows" {
- // On Windows, check the expected shell type
- if result != test.expected {
- t.Errorf("Command %q should use %v shell, got %v", test.command, test.expected, result)
- }
- }
- })
+ shell := NewShell(&Options{WorkingDir: tempDir1})
+ if _, _, err := shell.Exec(t.Context(), "export FOO=bar"); err != nil {
+ t.Fatalf("failed to set env: %v", err)
}
-}
-
-func TestWindowsCDHandling(t *testing.T) {
- if runtime.GOOS != "windows" {
- t.Skip("Windows CD handling test only runs on Windows")
- }
-
- shell := NewShell(&Options{
- WorkingDir: "C:\\Users",
- })
-
- tests := []struct {
- command string
- expectedCwd string
- shouldError bool
- }{
- {"cd ..", "C:\\", false},
- {"cd Documents", "C:\\Users\\Documents", false},
- {"cd C:\\Windows", "C:\\Windows", false},
- {"cd", "", true}, // Missing argument
- }
-
- for _, test := range tests {
- t.Run(test.command, func(t *testing.T) {
- originalCwd := shell.GetWorkingDir()
- stdout, stderr, err := shell.handleWindowsCD(test.command)
-
- if test.shouldError {
- if err == nil {
- t.Errorf("Command %q should have failed", test.command)
- }
- } else {
- if err != nil {
- t.Errorf("Command %q failed: %v", test.command, err)
- }
- if shell.GetWorkingDir() != test.expectedCwd {
- t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.GetWorkingDir())
- }
- }
-
- // Reset for next test
- shell.SetWorkingDir(originalCwd)
- _ = stdout
- _ = stderr
- })
+ if _, _, err := shell.Exec(t.Context(), "cd "+filepath.ToSlash(tempDir2)); err != nil {
+ t.Fatalf("failed to change directory: %v", err)
+ }
+ out, _, err := shell.Exec(t.Context(), "echo $FOO ; pwd")
+ if err != nil {
+ t.Fatalf("failed to echo: %v", err)
+ }
+ expect := "bar\n" + tempDir2 + "\n"
+ if out != expect {
+ t.Fatalf("expected output %q, got %q", expect, out)
}
}
@@ -202,23 +118,3 @@ func TestCrossPlatformExecution(t *testing.T) {
t.Errorf("Echo output should contain 'hello', got: %q", stdout)
}
}
-
-func TestWindowsNativeCommands(t *testing.T) {
- if runtime.GOOS != "windows" {
- t.Skip("Windows native command test only runs on Windows")
- }
-
- shell := NewShell(&Options{WorkingDir: "."})
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
-
- // Test Windows dir command
- stdout, stderr, err := shell.Exec(ctx, "dir")
- if err != nil {
- t.Fatalf("Dir command failed: %v, stderr: %s", err, stderr)
- }
-
- if stdout == "" {
- t.Error("Dir command produced no output")
- }
-}