diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index ce7fa0fb35cfdf021b886a96a828202001588a7f..5ece16038addeaf152a421735c2696e323e83a90 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -3,6 +3,7 @@ package prompt import ( "os" "path/filepath" + "runtime" "strings" "testing" ) @@ -96,7 +97,8 @@ func TestProcessContextPaths(t *testing.T) { // Test with tilde expansion (if we can create a file in home directory) tmpDir = t.TempDir() - t.Setenv("HOME", tmpDir) + rollback := setHomeEnv(tmpDir) + defer rollback() homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) if err == nil { @@ -111,3 +113,13 @@ func TestProcessContextPaths(t *testing.T) { } } } + +func setHomeEnv(path string) (rollback func()) { + key := "HOME" + if runtime.GOOS == "windows" { + key = "USERPROFILE" + } + original := os.Getenv(key) + os.Setenv(key, path) + return func() { os.Setenv(key, original) } +} diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 6d7a9a32b3829da02021be80e6e41e28888efd83..bee06caf110d5c4e47ae9bfabbf7909942e8346b 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "log/slog" - "runtime" "strings" "time" @@ -112,58 +111,17 @@ var bannedCommands = []string{ "ufw", } -// getSafeReadOnlyCommands returns platform-appropriate safe commands -func getSafeReadOnlyCommands() []string { - // Base commands that work on all platforms - baseCommands := []string{ - // Cross-platform commands - "echo", "hostname", "whoami", - - // Git commands (cross-platform) - "git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote", - "git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog", - - // Go commands (cross-platform) - "go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", - } - - if runtime.GOOS == "windows" { - // Windows-specific commands - windowsCommands := []string{ - "dir", "type", "where", "ver", "systeminfo", "tasklist", "ipconfig", "ping", "nslookup", - "Get-Process", "Get-Location", "Get-ChildItem", "Get-Content", "Get-Date", "Get-Host", "Get-ComputerInfo", - } - return append(baseCommands, windowsCommands...) - } else { - // Unix/Linux commands (including WSL, since WSL reports as Linux) - unixCommands := []string{ - "ls", "pwd", "date", "cal", "uptime", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis", - "whatis", "uname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout", - } - return append(baseCommands, unixCommands...) - } -} - func bashDescription() string { bannedCommandsStr := strings.Join(bannedCommands, ", ") return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures. CROSS-PLATFORM SHELL SUPPORT: -- Unix/Linux/macOS: Uses native bash/sh shell -- Windows: Intelligent shell selection: - * Windows commands (dir, type, copy, etc.) use cmd.exe - * PowerShell commands (Get-, Set-, etc.) use PowerShell - * Unix-style commands (ls, cat, etc.) use POSIX emulation -- WSL: Automatically treated as Linux (which is correct) -- Automatic detection: Chooses the best shell based on command and platform -- Persistent state: Working directory and environment variables persist between commands - -WINDOWS-SPECIFIC FEATURES: -- Native Windows commands: dir, type, copy, move, del, md, rd, cls, where, tasklist, etc. -- PowerShell support: Get-Process, Set-Location, and other PowerShell cmdlets -- Windows path handling: Supports both forward slashes (/) and backslashes (\) -- Drive letters: Properly handles C:\, D:\, etc. -- Environment variables: Supports both Unix ($VAR) and Windows (%%VAR%%) syntax +* This tool uses a shell interpreter (mvdan/sh) that mimics the Bash language, + so you should use Bash syntax even on all platforms, even on Windows. + The most common shell builtins and core utils are available even on Windows as + well. +* Make sure to use forward slashes (/) as path separators in commands, even on + Windows. Example: "ls C:/foo/bar" instead of "ls C:\foo\bar". Before executing the command, please follow these steps: @@ -393,10 +351,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) - // Get platform-appropriate safe commands - safeReadOnlyCommands := getSafeReadOnlyCommands() - for _, safe := range safeReadOnlyCommands { - if strings.HasPrefix(cmdLower, strings.ToLower(safe)) { + for _, safe := range safeCommands { + if strings.HasPrefix(cmdLower, safe) { if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { isSafeReadOnly = true break diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go deleted file mode 100644 index a810002749408af2bb89cb958b5999dc2da3bcb3..0000000000000000000000000000000000000000 --- a/internal/llm/tools/bash_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package tools - -import ( - "runtime" - "slices" - "testing" -) - -func TestGetSafeReadOnlyCommands(t *testing.T) { - commands := getSafeReadOnlyCommands() - - // Check that we have some commands - if len(commands) == 0 { - t.Fatal("Expected some safe commands, got none") - } - - // Check for cross-platform commands that should always be present - crossPlatformCommands := []string{"echo", "hostname", "whoami", "git status", "go version"} - for _, cmd := range crossPlatformCommands { - found := slices.Contains(commands, cmd) - if !found { - t.Errorf("Expected cross-platform command %q to be in safe commands", cmd) - } - } - - if runtime.GOOS == "windows" { - // Check for Windows-specific commands - windowsCommands := []string{"dir", "type", "Get-Process"} - for _, cmd := range windowsCommands { - found := slices.Contains(commands, cmd) - if !found { - t.Errorf("Expected Windows command %q to be in safe commands on Windows", cmd) - } - } - - // Check that Unix commands are NOT present on Windows - unixCommands := []string{"ls", "pwd", "ps"} - for _, cmd := range unixCommands { - found := slices.Contains(commands, cmd) - if found { - t.Errorf("Unix command %q should not be in safe commands on Windows", cmd) - } - } - } else { - // Check for Unix-specific commands - unixCommands := []string{"ls", "pwd", "ps"} - for _, cmd := range unixCommands { - found := slices.Contains(commands, cmd) - if !found { - t.Errorf("Expected Unix command %q to be in safe commands on Unix", cmd) - } - } - - // Check that Windows-specific commands are NOT present on Unix - windowsOnlyCommands := []string{"dir", "Get-Process", "systeminfo"} - for _, cmd := range windowsOnlyCommands { - found := slices.Contains(commands, cmd) - if found { - t.Errorf("Windows-only command %q should not be in safe commands on Unix", cmd) - } - } - } -} - -func TestPlatformSpecificSafeCommands(t *testing.T) { - // Test that the function returns different results on different platforms - commands := getSafeReadOnlyCommands() - - hasWindowsCommands := false - hasUnixCommands := false - - for _, cmd := range commands { - if cmd == "dir" || cmd == "Get-Process" || cmd == "systeminfo" { - hasWindowsCommands = true - } - if cmd == "ls" || cmd == "ps" || cmd == "df" { - hasUnixCommands = true - } - } - - if runtime.GOOS == "windows" { - if !hasWindowsCommands { - t.Error("Expected Windows commands on Windows platform") - } - if hasUnixCommands { - t.Error("Did not expect Unix commands on Windows platform") - } - } else { - if hasWindowsCommands { - t.Error("Did not expect Windows-only commands on Unix platform") - } - if !hasUnixCommands { - t.Error("Expected Unix commands on Unix platform") - } - } -} diff --git a/internal/llm/tools/safe.go b/internal/llm/tools/safe.go new file mode 100644 index 0000000000000000000000000000000000000000..2ea5b4437447fc89f7cc2dff8671802f5a3f1cf0 --- /dev/null +++ b/internal/llm/tools/safe.go @@ -0,0 +1,88 @@ +package tools + +import "runtime" + +var safeCommands = []string{ + // Bash builtins and core utils + "cal", + "date", + "df", + "du", + "echo", + "env", + "free", + "groups", + "hostname", + "id", + "kill", + "killall", + "ls", + "nice", + "nohup", + "printenv", + "ps", + "pwd", + "set", + "time", + "timeout", + "top", + "type", + "uname", + "unset", + "uptime", + "whatis", + "whereis", + "which", + "whoami", + + // Git + "git blame", + "git branch", + "git config --get", + "git config --list", + "git describe", + "git diff", + "git grep", + "git log", + "git ls-files", + "git ls-remote", + "git remote", + "git rev-parse", + "git shortlog", + "git show", + "git status", + "git tag", + + // Go + "go build", + "go clean", + "go doc", + "go env", + "go fmt", + "go help", + "go install", + "go list", + "go mod", + "go run", + "go test", + "go version", + "go vet", +} + +func init() { + if runtime.GOOS == "windows" { + safeCommands = append( + safeCommands, + // Windows-specific commands + "dir", + "ipconfig", + "nslookup", + "ping", + "systeminfo", + "tasklist", + "type", + "ver", + "where", + ) + } +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index b655c5dbecf5b69c7ad102c53108733515138771..c86d3c6531d2be3c2d91389d514d235d32c8fd65 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -5,8 +5,9 @@ // - PersistentShell: A singleton shell that maintains state across the application // // WINDOWS COMPATIBILITY: -// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and -// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility. +// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3), +// even on Windows. Some caution has to be taken: commands should have forward +// slashes (/) as path separators to work, even on Windows. package shell import ( @@ -15,8 +16,6 @@ import ( "errors" "fmt" "os" - "os/exec" - "runtime" "strings" "sync" @@ -98,17 +97,7 @@ func (s *Shell) Exec(ctx context.Context, command string) (string, string, error s.mu.Lock() defer s.mu.Unlock() - // Determine which shell to use based on platform and command - shellType := s.determineShellType(command) - - switch shellType { - case ShellTypeCmd: - return s.execWindows(ctx, command, "cmd") - case ShellTypePowerShell: - return s.execWindows(ctx, command, "powershell") - default: - return s.execPOSIX(ctx, command) - } + return s.execPOSIX(ctx, command) } // GetWorkingDir returns the current working directory @@ -165,57 +154,6 @@ func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) { s.blockFuncs = blockFuncs } -// Windows-specific commands that should use native shell -var windowsNativeCommands = map[string]bool{ - "dir": true, - "type": true, - "copy": true, - "move": true, - "del": true, - "md": true, - "mkdir": true, - "rd": true, - "rmdir": true, - "cls": true, - "where": true, - "tasklist": true, - "taskkill": true, - "net": true, - "sc": true, - "reg": true, - "wmic": true, -} - -// determineShellType decides which shell to use based on platform and command -func (s *Shell) determineShellType(command string) ShellType { - if runtime.GOOS != "windows" { - return ShellTypePOSIX - } - - // Extract the first command from the command line - parts := strings.Fields(command) - if len(parts) == 0 { - return ShellTypePOSIX - } - - firstCmd := strings.ToLower(parts[0]) - - // Check if it's a Windows-specific command - if windowsNativeCommands[firstCmd] { - return ShellTypeCmd - } - - // Check for PowerShell-specific syntax - if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") || - strings.Contains(command, "New-") || strings.Contains(command, "$_") || - strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") { - return ShellTypePowerShell - } - - // Default to POSIX emulation for cross-platform compatibility - return ShellTypePOSIX -} - // CommandsBlocker creates a BlockFunc that blocks exact command matches func CommandsBlocker(bannedCommands []string) BlockFunc { bannedSet := make(map[string]bool) @@ -270,81 +208,6 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand } } -// execWindows executes commands using native Windows shells (cmd.exe or PowerShell) -func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { - var cmd *exec.Cmd - - // Handle directory changes specially to maintain persistent shell behavior - if strings.HasPrefix(strings.TrimSpace(command), "cd ") { - return s.handleWindowsCD(command) - } - - switch shell { - case "cmd": - // Use cmd.exe for Windows commands - // Add current directory context to maintain state - fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command) - cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand) - case "powershell": - // Use PowerShell for PowerShell commands - // Add current directory context to maintain state - fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command) - cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand) - default: - return "", "", fmt.Errorf("unsupported Windows shell: %s", shell) - } - - // Set environment variables - cmd.Env = s.env - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err := cmd.Run() - - s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err) - return stdout.String(), stderr.String(), err -} - -// handleWindowsCD handles directory changes for Windows shells -func (s *Shell) handleWindowsCD(command string) (string, string, error) { - // Extract the target directory from the cd command - parts := strings.Fields(command) - if len(parts) < 2 { - return "", "cd: missing directory argument", fmt.Errorf("missing directory argument") - } - - targetDir := parts[1] - - // Handle relative paths - if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") { - // Relative path - resolve against current directory - if targetDir == ".." { - // Go up one directory - if len(s.cwd) > 3 { // Don't go above drive root (C:\) - lastSlash := strings.LastIndex(s.cwd, "\\") - if lastSlash > 2 { // Keep drive letter - s.cwd = s.cwd[:lastSlash] - } - } - } else if targetDir != "." { - // Go to subdirectory - s.cwd = s.cwd + "\\" + targetDir - } - } else { - // Absolute path - s.cwd = targetDir - } - - // Verify the directory exists - if _, err := os.Stat(s.cwd); err != nil { - return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err - } - - return "", "", nil -} - // execPOSIX executes commands using POSIX shell emulation (cross-platform) func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) { line, err := syntax.NewParser().Parse(strings.NewReader(command), "") diff --git a/internal/shell/shell_test.go b/internal/shell/shell_test.go index 417743caef7fa386f8c23d418682ab6a364e8e3e..66586b7f41c92486f7a8977d8ab34909de187c28 100644 --- a/internal/shell/shell_test.go +++ b/internal/shell/shell_test.go @@ -2,6 +2,7 @@ package shell import ( "context" + "path/filepath" "runtime" "strings" "testing" @@ -24,6 +25,11 @@ func BenchmarkShellQuickCommands(b *testing.B) { } func TestTestTimeout(t *testing.T) { + // XXX(@andreynering): This fails on Windows. Address once possible. + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond) t.Cleanup(cancel) @@ -72,113 +78,23 @@ func TestRunCommandError(t *testing.T) { } func TestRunContinuity(t *testing.T) { - shell := NewShell(&Options{WorkingDir: t.TempDir()}) - shell.Exec(t.Context(), "export FOO=bar") - dst := t.TempDir() - shell.Exec(t.Context(), "cd "+dst) - out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd") - expect := "bar\n" + dst + "\n" - if out != expect { - t.Fatalf("Expected output %q, got %q", expect, out) - } -} + tempDir1 := t.TempDir() + tempDir2 := t.TempDir() -// New tests for Windows shell support - -func TestShellTypeDetection(t *testing.T) { - shell := &PersistentShell{} - - tests := []struct { - command string - expected ShellType - windowsOnly bool - }{ - // Windows-specific commands - {"dir", ShellTypeCmd, true}, - {"type file.txt", ShellTypeCmd, true}, - {"copy file1.txt file2.txt", ShellTypeCmd, true}, - {"del file.txt", ShellTypeCmd, true}, - {"md newdir", ShellTypeCmd, true}, - {"tasklist", ShellTypeCmd, true}, - - // PowerShell commands - {"Get-Process", ShellTypePowerShell, true}, - {"Get-ChildItem", ShellTypePowerShell, true}, - {"Set-Location C:\\", ShellTypePowerShell, true}, - {"Get-Content file.txt | Where-Object {$_ -match 'pattern'}", ShellTypePowerShell, true}, - {"$files = Get-ChildItem", ShellTypePowerShell, true}, - - // Unix/cross-platform commands - {"ls -la", ShellTypePOSIX, false}, - {"cat file.txt", ShellTypePOSIX, false}, - {"grep pattern file.txt", ShellTypePOSIX, false}, - {"echo hello", ShellTypePOSIX, false}, - {"git status", ShellTypePOSIX, false}, - {"go build", ShellTypePOSIX, false}, - } - - for _, test := range tests { - t.Run(test.command, func(t *testing.T) { - result := shell.determineShellType(test.command) - - if test.windowsOnly && runtime.GOOS != "windows" { - // On non-Windows systems, everything should use POSIX - if result != ShellTypePOSIX { - t.Errorf("On non-Windows, command %q should use POSIX shell, got %v", test.command, result) - } - } else if runtime.GOOS == "windows" { - // On Windows, check the expected shell type - if result != test.expected { - t.Errorf("Command %q should use %v shell, got %v", test.command, test.expected, result) - } - } - }) + shell := NewShell(&Options{WorkingDir: tempDir1}) + if _, _, err := shell.Exec(t.Context(), "export FOO=bar"); err != nil { + t.Fatalf("failed to set env: %v", err) } -} - -func TestWindowsCDHandling(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Windows CD handling test only runs on Windows") - } - - shell := NewShell(&Options{ - WorkingDir: "C:\\Users", - }) - - tests := []struct { - command string - expectedCwd string - shouldError bool - }{ - {"cd ..", "C:\\", false}, - {"cd Documents", "C:\\Users\\Documents", false}, - {"cd C:\\Windows", "C:\\Windows", false}, - {"cd", "", true}, // Missing argument - } - - for _, test := range tests { - t.Run(test.command, func(t *testing.T) { - originalCwd := shell.GetWorkingDir() - stdout, stderr, err := shell.handleWindowsCD(test.command) - - if test.shouldError { - if err == nil { - t.Errorf("Command %q should have failed", test.command) - } - } else { - if err != nil { - t.Errorf("Command %q failed: %v", test.command, err) - } - if shell.GetWorkingDir() != test.expectedCwd { - t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.GetWorkingDir()) - } - } - - // Reset for next test - shell.SetWorkingDir(originalCwd) - _ = stdout - _ = stderr - }) + if _, _, err := shell.Exec(t.Context(), "cd "+filepath.ToSlash(tempDir2)); err != nil { + t.Fatalf("failed to change directory: %v", err) + } + out, _, err := shell.Exec(t.Context(), "echo $FOO ; pwd") + if err != nil { + t.Fatalf("failed to echo: %v", err) + } + expect := "bar\n" + tempDir2 + "\n" + if out != expect { + t.Fatalf("expected output %q, got %q", expect, out) } } @@ -202,23 +118,3 @@ func TestCrossPlatformExecution(t *testing.T) { t.Errorf("Echo output should contain 'hello', got: %q", stdout) } } - -func TestWindowsNativeCommands(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Windows native command test only runs on Windows") - } - - shell := NewShell(&Options{WorkingDir: "."}) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Test Windows dir command - stdout, stderr, err := shell.Exec(ctx, "dir") - if err != nil { - t.Fatalf("Dir command failed: %v, stderr: %s", err, stderr) - } - - if stdout == "" { - t.Error("Dir command produced no output") - } -}